gobuf/inject_tag.go
2023-06-25 23:02:26 +08:00

351 lines
8.0 KiB
Go

package gobuf
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"unicode"
)
var (
	// rComment matches a case-insensitive `@gotags:` (or `@gotag:`) directive
	// anywhere in a `//` comment; group 1 is the raw tag text to inject.
	rComment = regexp.MustCompile(`^//.*?@(?i:gotags?):\s*(.*)$`)
	// bsonComment matches exactly `@bson:` (case-insensitive); group 1 is the
	// value that is placed after `bson:`. The key must be spelled in full so
	// that typos such as `@bso:` are not silently accepted.
	bsonComment = regexp.MustCompile(`^//.*?@(?i:bson):\s*(.*)$`)
	// jsonComment matches exactly `@json:` (case-insensitive); group 1 is the
	// value that is placed after `json:`.
	jsonComment = regexp.MustCompile(`^//.*?@(?i:json):\s*(.*)$`)
	// rInject matches the backquoted tag at the end of a field's source text.
	rInject = regexp.MustCompile("`.+`$")
	// rTags matches one key:"value" pair inside a struct tag (\w already
	// includes the underscore).
	rTags = regexp.MustCompile(`\w+:"[^"]+"`)
)
// textArea describes one struct-field tag rewrite. All positions are
// token.Pos values from a FileSet that holds a single file, i.e. 1-based
// byte offsets; consumers subtract 1 before slicing the file contents.
type textArea struct {
	// Start and End delimit the field's source text, tag included.
	Start, End int
	// CurrentTag is the field's existing tag without the enclosing backquotes.
	CurrentTag string
	// InjectTag holds the key:"value" pairs to merge into CurrentTag.
	InjectTag string
	// CommentStart and CommentEnd delimit the comment that produced the tag;
	// both stay 0 when the tag did not come from a comment.
	CommentStart, CommentEnd int
}
type TagValueStyle int
const (
Underline TagValueStyle = iota
LowerCase
UpperCase
)
type InjectTagProps struct {
TagName string // tag name
Style TagValueStyle // tag value style, default underline
comment *ast.Comment
fieldName string
fieldValue string
}
// InjectTag rewrites struct tags in the Go files matched by a glob pattern,
// driven by `@gotags:`/`@json:`/`@bson:` field comments and by default tags
// registered with WithTags.
type InjectTag struct {
	// inputFiles is the glob pattern expanded by Inject.
	inputFiles string
	// defaultTags maps a tag key to the style used to derive its value from
	// the field name; it stays nil until WithTags is called.
	defaultTags map[string]TagValueStyle
}

// NewInjectTag returns an InjectTag that operates on the files matched by the
// glob pattern pbFiles. No default tags are configured; add them with WithTags.
func NewInjectTag(pbFiles string) *InjectTag {
	it := new(InjectTag)
	it.inputFiles = pbFiles
	return it
}
// Inject expands the glob pattern given to NewInjectTag and rewrites, in
// place, the struct tags of every matching regular file whose name ends in
// ".go" (case-insensitive). Directories are skipped.
//
// It returns an error when the pattern is malformed, a matched path cannot be
// stat'ed, a file cannot be parsed or rewritten, or no file matched at all.
// Files processed before a failure keep their modifications.
func (it *InjectTag) Inject() error {
	globResults, err := filepath.Glob(it.inputFiles)
	if err != nil {
		return fmt.Errorf("parser input file failed: %w", err)
	}
	var matched int
	for _, path := range globResults {
		fileInfo, err := os.Stat(path)
		if err != nil {
			return fmt.Errorf("read file stat failed: %w", err)
		}
		// It should be a regular entry ending with ".go" at a minimum.
		if fileInfo.IsDir() || !strings.HasSuffix(strings.ToLower(fileInfo.Name()), ".go") {
			continue
		}
		matched++
		// Report failures to the caller instead of terminating the whole
		// process: this method already returns an error for every other case.
		areas, err := it.parserFile(path)
		if err != nil {
			return fmt.Errorf("parse %q: %w", path, err)
		}
		if err := it.writeFile(path, areas); err != nil {
			return fmt.Errorf("inject tags into %q: %w", path, err)
		}
	}
	if matched == 0 {
		return fmt.Errorf("input %q matched no files, see: -help", it.inputFiles)
	}
	return nil
}
// WithTags registers default tags. tagFromComment adds each of them to a
// field unless one of the field's comments already supplies that tag key.
// Only TagName and Style of each entry are used. A later entry with the same
// TagName replaces an earlier one. It returns it so calls can be chained.
func (it *InjectTag) WithTags(tags ...InjectTagProps) *InjectTag {
	if len(it.defaultTags) == 0 {
		it.defaultTags = map[string]TagValueStyle{}
	}
	for i := range tags {
		props := &tags[i]
		it.defaultTags[props.TagName] = props.Style
	}
	return it
}
// parserFile parses the Go source at inputPath and returns, in source order,
// one textArea for every struct field whose tag should be rewritten.
//
// Every struct type is visited, including each type of a grouped
// `type ( ... )` declaration. All tags that apply to one field are merged
// into a single area. writeFile applies areas back to front using the
// original byte offsets, and injectTag always starts from the field's
// original tag. Emitting several areas for the same field would therefore
// leave every area after the first with stale offsets, and at most one tag
// would survive.
func (it *InjectTag) parserFile(inputPath string) (areas []textArea, err error) {
	it.logf("parsing file %q for inject tag comments", inputPath)
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, inputPath, nil, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	for _, decl := range f.Decls {
		// check if is generic declaration
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			// not a struct, skip
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				continue
			}
			for _, field := range structType.Fields.List {
				if area, ok := it.fieldTagArea(field); ok {
					areas = append(areas, area)
				}
			}
		}
	}
	return areas, nil
}

// fieldTagArea builds the merged textArea for one struct field and reports
// false when the field should be left untouched.
func (it *InjectTag) fieldTagArea(field *ast.Field) (textArea, bool) {
	// Embedded fields (no name) and multi-name fields (`A, B int`) are skipped.
	if len(field.Names) != 1 {
		return textArea{}, false
	}
	// Injection rewrites an existing backquoted tag. A field without a tag
	// has nothing to rewrite, and field.Tag would be nil.
	if field.Tag == nil {
		return textArea{}, false
	}
	fieldName := field.Names[0].Name
	// Decode the first rune (not the first byte) so that identifiers starting
	// with a multi-byte letter are classified correctly.
	if unicode.IsLower([]rune(fieldName)[0]) {
		return textArea{}, false
	}
	// The "doc" field (above comment) is more commonly "free-form" due to the
	// ability to have a much larger comment without it being unwieldy. As
	// such, the "comment" field (trailing comment) is appended last so that
	// it takes precedence if the same tag is specified in both places.
	var comments []*ast.Comment
	if field.Doc != nil {
		comments = append(comments, field.Doc.List...)
	}
	if field.Comment != nil {
		comments = append(comments, field.Comment.List...)
	}
	tags := it.tagFromComment(fieldName, comments)
	if len(tags) == 0 {
		return textArea{}, false
	}
	currentTag := field.Tag.Value
	area := textArea{
		Start:      int(field.Pos()),
		End:        int(field.End()),
		CurrentTag: currentTag[1 : len(currentTag)-1], // strip the enclosing quotes
	}
	values := make([]string, 0, len(tags))
	for _, tag := range tags {
		values = append(values, tag.fieldValue)
		// Keep the position of the last comment that contributed a tag.
		if tag.comment != nil {
			area.CommentStart = int(tag.comment.Pos())
			area.CommentEnd = int(tag.comment.End())
		}
	}
	// Space-separated pairs are split again by newTagItems inside injectTag.
	area.InjectTag = strings.Join(values, " ")
	return area, true
}
// writeFile applies areas (as returned by parserFile) to the file at
// inputPath and writes the result back in place. When areas is empty the file
// is not touched at all.
//
// Areas are applied from last to first so that rewriting one span never
// shifts the byte offsets of the spans before it. This relies on areas being
// sorted by ascending Start and not overlapping.
func (it *InjectTag) writeFile(inputPath string, areas []textArea) error {
	if len(areas) == 0 {
		// Nothing to inject: avoid rewriting the file (and bumping its mtime).
		return nil
	}
	f, err := os.Open(inputPath)
	if err != nil {
		return err
	}
	contents, err := io.ReadAll(f)
	// Close on every path, including a failed read; a read error takes
	// precedence over a close error.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		return err
	}
	// inject custom tags from tail of file first to preserve order
	for i := len(areas) - 1; i >= 0; i-- {
		area := areas[i]
		it.logf("inject custom tag %q to expression %q", area.InjectTag, string(contents[area.Start-1:area.End-1]))
		contents = it.injectTag(contents, area)
	}
	// The existing file's permissions are kept; 0o644 only applies if the
	// file had to be created.
	if err := os.WriteFile(inputPath, contents, 0o644); err != nil {
		return err
	}
	it.logf("file %q is injected with custom tags", inputPath)
	return nil
}
// logf formats a progress message like fmt.Sprintf and prints it through the
// standard library's default logger.
func (it *InjectTag) logf(format string, v ...interface{}) {
	msg := fmt.Sprintf(format, v...)
	log.Print(msg)
}
// tagFromComment collects the tags to inject into the field named fieldName.
//
// Each comment is checked, in order, for a `@bson:`, a `@json:` and a
// `@gotags:` directive; the first directive that matches yields one
// InjectTagProps that records the comment. Afterwards, every default tag
// registered via WithTags whose key was not supplied by any comment is
// appended. Its value is derived from fieldName by getFieldValue.
//
// Defaults are emitted in sorted key order. Go map iteration order is
// randomized, and without sorting repeated runs would reorder the generated
// tags.
func (it *InjectTag) tagFromComment(fieldName string, comments []*ast.Comment) (tags []InjectTagProps) {
	// Tag keys already provided explicitly by a comment; defaults must not
	// duplicate them.
	provided := make(map[string]struct{})
	for _, comment := range comments {
		if bsonMatch := bsonComment.FindStringSubmatch(comment.Text); len(bsonMatch) == 2 {
			tags = append(tags, InjectTagProps{
				TagName:    "bson",
				comment:    comment,
				fieldName:  fieldName,
				fieldValue: fmt.Sprintf(`bson:%v`, bsonMatch[1]),
			})
			provided["bson"] = struct{}{}
			continue
		}
		if jsonMatch := jsonComment.FindStringSubmatch(comment.Text); len(jsonMatch) == 2 {
			tags = append(tags, InjectTagProps{
				TagName:    "json",
				comment:    comment,
				fieldName:  fieldName,
				fieldValue: fmt.Sprintf(`json:%v`, jsonMatch[1]),
			})
			provided["json"] = struct{}{}
			continue
		}
		if match := rComment.FindStringSubmatch(comment.Text); len(match) == 2 {
			tagVal := match[1]
			tagName := strings.Split(tagVal, ":")[0]
			tags = append(tags, InjectTagProps{
				TagName:    tagName,
				comment:    comment,
				fieldName:  fieldName,
				fieldValue: tagVal,
			})
			provided[tagName] = struct{}{}
			// One @gotags directive may carry several pairs (`json:"a" bson:"b"`);
			// record every key so that no default tag duplicates one of them.
			for _, pair := range rTags.FindAllString(tagVal, -1) {
				provided[pair[:strings.Index(pair, ":")]] = struct{}{}
			}
		}
	}

	defaults := make([]string, 0, len(it.defaultTags))
	for name := range it.defaultTags {
		if _, exist := provided[name]; !exist {
			defaults = append(defaults, name)
		}
	}
	sort.Strings(defaults)
	for _, name := range defaults {
		tags = append(tags, InjectTagProps{
			TagName:    name,
			fieldName:  fieldName,
			fieldValue: fmt.Sprintf("%s:\"%s\"", name, getFieldValue(fieldName, it.defaultTags[name])),
		})
	}
	return tags
}
// tagItem is one key:"value" pair of a struct tag. value keeps its enclosing
// double quotes exactly as they appear in the source.
type tagItem struct {
	key   string
	value string
}

// tagItems is an ordered list of struct tag pairs.
type tagItems []tagItem

// format renders the items as struct tag text without the enclosing
// backquotes: key:value pairs separated by single spaces.
func (ti *tagItems) format() string {
	parts := make([]string, 0, len(*ti))
	for _, item := range *ti {
		parts = append(parts, item.key+":"+item.value)
	}
	return strings.Join(parts, " ")
}

// override merges nti into the receiver and returns the result. The receiver
// is not modified.
//
// Items of the receiver keep their positions. The first receiver item with a
// given key takes that key's value from nti. Keys of nti that the receiver
// lacks are appended in order of first appearance. If nti contains the same
// key more than once, the last value wins. nti itself is never modified, so
// the caller's slice stays intact.
func (ti *tagItems) override(nti tagItems) tagItems {
	latest := make(map[string]string, len(nti))
	order := make([]string, 0, len(nti))
	for _, item := range nti {
		if _, seen := latest[item.key]; !seen {
			order = append(order, item.key)
		}
		latest[item.key] = item.value
	}

	merged := make(tagItems, 0, len(*ti)+len(order))
	for _, item := range *ti {
		if value, ok := latest[item.key]; ok {
			item.value = value
			// Consume the key so it is neither appended below nor applied to a
			// later receiver item with the same key.
			delete(latest, item.key)
		}
		merged = append(merged, item)
	}
	for _, key := range order {
		if value, ok := latest[key]; ok {
			merged = append(merged, tagItem{key: key, value: value})
		}
	}
	return merged
}
// newTagItems splits raw struct tag text (without backquotes) into its
// key:"value" pairs in order of appearance. Only substrings matching rTags
// are kept, anything else is ignored, and each value retains its quotes.
func (it *InjectTag) newTagItems(tag string) tagItems {
	it.logf("new tag: %v\n", tag)
	pairs := rTags.FindAllString(tag, -1)
	var items []tagItem
	for _, pair := range pairs {
		// rTags guarantees a ':' between the key and the quoted value.
		colon := strings.Index(pair, ":")
		items = append(items, tagItem{key: pair[:colon], value: pair[colon+1:]})
	}
	return items
}
// injectTag returns a new byte slice equal to contents, except that the
// struct tag of the field spanning area.Start..area.End is replaced by
// area.CurrentTag merged with area.InjectTag. Pairs from InjectTag override
// same-key pairs in place, and new keys are appended.
//
// Start and End are 1-based token.Pos values, hence the -1 when converting to
// byte offsets. If the field text does not end in a backquoted tag, the
// result has the same bytes as contents.
func (it *InjectTag) injectTag(contents []byte, area textArea) (injected []byte) {
	start, end := area.Start-1, area.End-1
	field := contents[start:end]
	it.logf("expr: %s", field)
	cti := it.newTagItems(area.CurrentTag)
	it.logf("cti: %v", cti)
	iti := it.newTagItems(area.InjectTag)
	it.logf("iti: %v", iti)
	ti := cti.override(iti)
	it.logf("ti: %v", ti)
	// The merged tag is data, not a template. ReplaceAll would expand any
	// `$name` / `${1}` found in a tag value as a capture-group reference and
	// silently corrupt it, so insert the text literally.
	expr := rInject.ReplaceAllLiteral(field, []byte("`"+ti.format()+"`"))
	injected = make([]byte, 0, start+len(expr)+len(contents)-end)
	injected = append(injected, contents[:start]...)
	injected = append(injected, expr...)
	injected = append(injected, contents[end:]...)
	return injected
}