protoc-gen-coco/main.go

305 lines
8.2 KiB
Go

package main
import (
"bytes"
"flag"
"fmt"
"github.com/emicklei/proto"
"gitter.top/common/protofmt"
"google.golang.org/protobuf/compiler/protogen"
"os"
"path/filepath"
"runtime/debug"
"strings"
)
// version is the module version of this plugin binary, taken from the Go
// build info embedded in the executable. It is written into the header of
// every generated file. When build info is unavailable it is "unknown".
var version string

// init resolves version from the embedded build information.
func init() {
	version = "unknown"
	if info, ok := debug.ReadBuildInfo(); ok {
		version = info.Main.Version
	}
}
// Coco carries the protoc-gen-coco plugin options; main fills these fields
// from the plugin parameters handed over by protoc.
type Coco struct {
	// DisableGenerateRouter turns off router generation: router map,
	// router implementation, gRPC implementation and router wire.
	DisableGenerateRouter bool
	// DisableGenerateMongoModel turns off MongoDB model generation.
	DisableGenerateMongoModel bool
	// DisableGenerateErrorCode turns off error-code generation.
	DisableGenerateErrorCode bool
	// DisableGenerateRouterWire turns off only the wire (dependency
	// injection) router file generation.
	DisableGenerateRouterWire bool
	// Prefix is prepended to each proto path when the protos were compiled
	// with relative paths, so the source files can be found on disk.
	// Generate appends a trailing "/" when it is missing.
	Prefix string
	// ProjectName is the project name handed to the code generators.
	ProjectName string
}
// Generate is the protogen entry point for the coco plugin.
//
// It does nothing when protoc passed no files. Otherwise it:
//   - makes sure a non-empty Prefix ends with "/",
//   - reformats the input .proto files in place,
//   - runs the router-related generators (router map, router impl,
//     gRPC impl, router wire) unless DisableGenerateRouter is set,
//   - runs the Mongo model generator unless DisableGenerateMongoModel is set,
//   - runs the error-code generator unless DisableGenerateErrorCode is set.
//
// It always returns nil.
func (c *Coco) Generate(plugin *protogen.Plugin) error {
	if len(plugin.Files) == 0 {
		return nil
	}

	// Prefix is concatenated directly with proto paths, so it must end in a slash.
	if c.Prefix != "" && !strings.HasSuffix(c.Prefix, "/") {
		c.Prefix += "/"
	}

	c.format(plugin)

	if routerEnabled := !c.DisableGenerateRouter; routerEnabled {
		c.generateRouterMap(plugin)
		c.generateRouterImpl(plugin)
		c.generateGrpcImpl(plugin)
		c.generateRouterWire(plugin)
	}
	if modelEnabled := !c.DisableGenerateMongoModel; modelEnabled {
		c.generateMongoModel(plugin)
	}
	if errCodeEnabled := !c.DisableGenerateErrorCode; errCodeEnabled {
		c.generateErrorCode(plugin)
	}
	return nil
}
// format rewrites every input .proto file in place with protofmt so the
// sources stay consistently formatted.
//
// The on-disk path is pbFile.Desc.Path(), prefixed with c.Prefix when set.
// Formatting is best-effort: read, parse or write failures are reported on
// stderr and the file is skipped. They never abort code generation.
func (c *Coco) format(plugin *protogen.Plugin) {
	for _, pbFile := range plugin.Files {
		filename := pbFile.Desc.Path()
		if c.Prefix != "" {
			filename = c.Prefix + filename
		}
		// Read the whole file up front. The previous os.Open handle was
		// never closed, which leaked one descriptor per proto file.
		content, err := os.ReadFile(filename)
		if err != nil {
			_, _ = fmt.Fprintf(os.Stderr, "format pbFile %s failed: %v\n", filename, err)
			continue
		}
		parser := proto.NewParser(bytes.NewReader(content))
		parser.Filename(filename)
		def, err := parser.Parse()
		if err != nil {
			_, _ = fmt.Fprintf(os.Stderr, "parse pbFile %s failed: %v\n", filename, err)
			continue
		}
		buf := new(bytes.Buffer)
		protofmt.NewFormatter(buf, " ").Format(def)
		// Skip the write when formatting changes nothing, so untouched
		// sources keep their modification time.
		if bytes.Equal(content, buf.Bytes()) {
			continue
		}
		// The file already exists, so os.WriteFile keeps its current
		// permissions. The perm argument applies only on creation.
		if err := os.WriteFile(filename, buf.Bytes(), os.ModePerm); err != nil {
			_, _ = fmt.Fprintf(os.Stderr, "rewrite pbFile %s failed: %v\n", filename, err)
			continue
		}
	}
}
// generateRouterMap emits autogen_router_<service>.go next to the generated
// code of each proto file. It does this for every service that meets all of
// these conditions:
//   - it is annotated as a router group,
//   - it has at least one method,
//   - it declares a generation target (GetCommentGenerateTo).
//
// The file body is whatever GenerateRouterMap returns for the service. If
// that fails, the error is reported on stderr and no file is emitted for
// that service.
func (c *Coco) generateRouterMap(plugin *protogen.Plugin) {
	for _, pbFile := range plugin.Files {
		for _, service := range pbFile.Services {
			// Check the annotation of *this* service. The previous code looked
			// at pbFile.Services[0] for every service in the file.
			if !IsCommentRouterGroup(service.Comments) {
				continue
			}
			if len(service.Methods) == 0 || GetCommentGenerateTo(service.Comments) == "" {
				continue
			}
			// Build the route table before creating the output file. A failure
			// then leaves no header-only file behind; such a file would carry
			// an unused import and break compilation of the generated package.
			values, err := GenerateRouterMap(service)
			if err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "generate router map failed: %v\n", err)
				continue
			}
			filename := fmt.Sprintf("%s/autogen_router_%s.go",
				separator(filepath.Dir(pbFile.GeneratedFilenamePrefix)), CamelCaseToUnderscore(service.GoName))
			g := plugin.NewGeneratedFile(filename, pbFile.GoImportPath)
			g.P("// Code generated by protoc-gen-coco. DO NOT EDIT.")
			g.P("// source: ", pbFile.GeneratedFilenamePrefix, ".proto")
			g.P("// protoc-gen-coco: ", version)
			g.P()
			g.P("package ", pbFile.GoPackageName)
			g.P()
			g.P("import (")
			g.P(`"gitter.top/coco/coco"`)
			g.P(")")
			g.P()
			g.P(values)
			g.P()
		}
	}
}
// generateRouterImpl creates or updates the controller implementation file
// for every service that is annotated as a router group and declares a
// generation target (GetCommentGenerateTo).
//
// The target's directory is created when missing. The Go package name is
// "controller" followed by the last directory of the proto's generated
// filename prefix (e.g. "api/v1/user" -> "controllerv1"). If the target
// file does not exist yet, generateNewFile is used; otherwise
// generateNewRouters is called on the existing file.
//
// Errors are reported on stderr and processing continues with the next
// service. Previously a `return` here silently skipped all remaining
// services and files.
func (c *Coco) generateRouterImpl(plugin *protogen.Plugin) {
	for _, pbFile := range plugin.Files {
		// New layout: the package name is always "controller"+<api version>.
		// The old string(pbFile.GoPackageName) naming is deprecated. The name
		// apiVersion avoids shadowing the package-level version.
		apiVersion := filepath.Base(filepath.Dir(pbFile.GeneratedFilenamePrefix))
		for _, service := range pbFile.Services {
			if !IsCommentRouterGroup(service.Comments) {
				continue
			}
			genTo := GetCommentGenerateTo(service.Comments)
			if genTo == "" {
				continue
			}
			if err := mkdir(filepath.Dir(genTo)); err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "mkdir genTo path failed: %s\n", err)
				continue
			}
			generator := newRouterImpl(genTo, "controller"+apiVersion, string(pbFile.GoPackageName), service.GoName, service)
			generator.ProjectName = c.ProjectName
			if !generator.IsFileExist() {
				if err := generator.generateNewFile(); err != nil {
					_, _ = fmt.Fprintf(os.Stderr, "generate all router failed: %v\n", err)
				}
				continue
			}
			if err := generator.generateNewRouters(); err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "generate router impl failed: %v\n", err)
			}
		}
	}
}
// generateGrpcImpl creates or updates the gRPC server implementation file
// for every service whose comments declare an RPC generation target
// (GetCommentRpcGenerateTo). Unlike the HTTP router generators, it does not
// require the router-group annotation.
//
// The target's directory is created when missing. The Go package name is
// "rpc" followed by the last directory of the proto's generated filename
// prefix (e.g. "api/v1/user" -> "rpcv1"). If the target file does not exist
// yet, generateNewGrpcFile is used; otherwise generateNewGrpcServers is
// called on the existing file.
//
// Errors are reported on stderr and processing continues with the next
// service. Previously a `return` here silently skipped all remaining
// services and files.
func (c *Coco) generateGrpcImpl(plugin *protogen.Plugin) {
	for _, pbFile := range plugin.Files {
		// apiVersion avoids shadowing the package-level version.
		apiVersion := filepath.Base(filepath.Dir(pbFile.GeneratedFilenamePrefix))
		for _, service := range pbFile.Services {
			genTo := GetCommentRpcGenerateTo(service.Comments)
			if genTo == "" {
				continue
			}
			if err := mkdir(filepath.Dir(genTo)); err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "mkdir genTo path failed: %s\n", err)
				continue
			}
			generator := newRouterImpl(genTo, "rpc"+apiVersion, string(pbFile.GoPackageName), service.GoName, service)
			generator.ProjectName = c.ProjectName
			if !generator.IsFileExist() {
				if err := generator.generateNewGrpcFile(); err != nil {
					_, _ = fmt.Fprintf(os.Stderr, "generate all grpc failed: %v\n", err)
				}
				continue
			}
			if err := generator.generateNewGrpcServers(); err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "generate rpc impl failed: %v\n", err)
			}
		}
	}
}
// generateMongoModel writes autogen_model_<proto base name>.go next to the
// generated code of each proto file that has at least one top-level message
// selected by IsCommentModel. Each selected message is rendered with
// GenerateModel. A message that fails to render is reported on stderr and
// left out of the file.
func (c *Coco) generateMongoModel(plugin *protogen.Plugin) {
	for _, pbFile := range plugin.Files {
		var needMessages []*protogen.Message
		for _, message := range pbFile.Messages {
			if IsCommentModel(message) {
				needMessages = append(needMessages, message)
			}
		}
		// Skip files with no model messages, including files with no
		// messages at all. The previous code returned here, so one
		// message-less proto stopped model generation for every later file.
		if len(needMessages) == 0 {
			continue
		}
		filename := fmt.Sprintf("%s/autogen_model_%s.go",
			separator(filepath.Dir(pbFile.GeneratedFilenamePrefix)), filepath.Base(pbFile.GeneratedFilenamePrefix))
		g := plugin.NewGeneratedFile(filename, pbFile.GoImportPath)
		g.P("// Code generated by protoc-gen-coco. DO NOT EDIT.")
		g.P("// source: ", pbFile.GeneratedFilenamePrefix, ".proto")
		g.P("// protoc-gen-coco: ", version)
		g.P()
		g.P("package ", pbFile.GoPackageName)
		g.P()
		for _, message := range needMessages {
			value, err := GenerateModel(message, string(pbFile.GoPackageName))
			if err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "generate model failed: %v\n", err)
				continue
			}
			g.P(value)
			g.P()
		}
	}
}
// generateRouterWire maintains the wire file at gen/wire/wire.go. It does
// nothing when DisableGenerateRouterWire is set.
//
// For every service annotated as a router group it:
//   - ensures the gen/wire directory exists,
//   - creates wire.go with generateNewWireFile if the file is missing,
//   - calls generateServicesWire for the service.
//
// A failure to create the directory or the initial wire file is reported
// on stderr and aborts the whole pass. A generateServicesWire failure is
// reported and the loop moves on.
func (c *Coco) generateRouterWire(plugin *protogen.Plugin) {
	if c.DisableGenerateRouterWire {
		return
	}
	const wirePath = "gen/wire/wire.go"
	for _, file := range plugin.Files {
		// The API version is the last directory of the generated filename prefix.
		apiVersion := filepath.Base(filepath.Dir(file.GeneratedFilenamePrefix))
		for _, svc := range file.Services {
			if !IsCommentRouterGroup(svc.Comments) {
				continue
			}
			if err := mkdir("gen/wire"); err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "mkdir gen/wire path failed: %s\n", err)
				return
			}
			wire := newRouterImplWire(wirePath, "wire", string(file.GoPackageName), svc.GoName)
			wire.Version = apiVersion
			wire.ProjectName = c.ProjectName
			if !wire.IsFileExist() {
				if err := wire.generateNewWireFile(); err != nil {
					_, _ = fmt.Fprintf(os.Stderr, "generate router wire failed: %v\n", err)
					return
				}
			}
			if err := wire.generateServicesWire(); err != nil {
				_, _ = fmt.Fprintf(os.Stderr, "generate router wire failed: %v\n", err)
			}
		}
	}
}
// main runs protoc-gen-coco as a protoc plugin.
//
// Each key=value plugin parameter from protoc is applied to a flag set that
// writes into the Coco options. The plugin then runs Coco.Generate.
func main() {
	coco := &Coco{}

	var flags flag.FlagSet
	// Generator toggles.
	flags.BoolVar(&coco.DisableGenerateRouter, "disable_router", false, "disable generate router code")
	flags.BoolVar(&coco.DisableGenerateRouterWire, "disable_router_wire", false, "disable generate router wire")
	flags.BoolVar(&coco.DisableGenerateMongoModel, "disable_mongodb_model", false, "disable generate mongodb model")
	flags.BoolVar(&coco.DisableGenerateErrorCode, "disable_error_code", false, "disable generate error code")
	// Path and naming options.
	flags.StringVar(&coco.Prefix, "prefix", "", "proto prefix")
	flags.StringVar(&coco.ProjectName, "project_name", "", "project name")

	opts := protogen.Options{ParamFunc: flags.Set}
	opts.Run(coco.Generate)
}