439 lines
11 KiB
Go
439 lines
11 KiB
Go
|
package gen
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"hash/fnv"
|
||
|
"io"
|
||
|
"path"
|
||
|
"reflect"
|
||
|
"sort"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"unicode"
|
||
|
)
|
||
|
|
||
|
// Import paths of the easyjson runtime packages that generated code refers to.
const (
	pkgWriter   = "github.com/mailru/easyjson/jwriter"
	pkgLexer    = "github.com/mailru/easyjson/jlexer"
	pkgEasyJSON = "github.com/mailru/easyjson"
)
|
||
|
|
||
|
// FieldNamer is the policy that decides which name a struct field gets in
// JSON.
type FieldNamer interface {
	// GetJSONFieldName returns the JSON name to use for field f; t is passed
	// alongside it (presumably the struct type that declares f).
	GetJSONFieldName(t reflect.Type, f reflect.StructField) string
}
|
||
|
|
||
|
// Generator generates the requested marshallers/unmarshallers.
|
||
|
type Generator struct {
|
||
|
out *bytes.Buffer
|
||
|
|
||
|
pkgName string
|
||
|
pkgPath string
|
||
|
buildTags string
|
||
|
hashString string
|
||
|
|
||
|
varCounter int
|
||
|
|
||
|
noStdMarshalers bool
|
||
|
omitEmpty bool
|
||
|
fieldNamer FieldNamer
|
||
|
|
||
|
// package path to local alias map for tracking imports
|
||
|
imports map[string]string
|
||
|
|
||
|
// types that marshallers were requested for by user
|
||
|
marshallers map[reflect.Type]bool
|
||
|
|
||
|
// types that encoders were already generated for
|
||
|
typesSeen map[reflect.Type]bool
|
||
|
|
||
|
// types that encoders were requested for (e.g. by encoders of other types)
|
||
|
typesUnseen []reflect.Type
|
||
|
|
||
|
// function name to relevant type maps to track names of de-/encoders in
|
||
|
// case of a name clash or unnamed structs
|
||
|
functionNames map[string]reflect.Type
|
||
|
}
|
||
|
|
||
|
// NewGenerator initializes and returns a Generator.
|
||
|
func NewGenerator(filename string) *Generator {
|
||
|
ret := &Generator{
|
||
|
imports: map[string]string{
|
||
|
pkgWriter: "jwriter",
|
||
|
pkgLexer: "jlexer",
|
||
|
pkgEasyJSON: "easyjson",
|
||
|
"encoding/json": "json",
|
||
|
},
|
||
|
fieldNamer: DefaultFieldNamer{},
|
||
|
marshallers: make(map[reflect.Type]bool),
|
||
|
typesSeen: make(map[reflect.Type]bool),
|
||
|
functionNames: make(map[string]reflect.Type),
|
||
|
}
|
||
|
|
||
|
// Use a file-unique prefix on all auxiliary functions to avoid
|
||
|
// name clashes.
|
||
|
hash := fnv.New32()
|
||
|
hash.Write([]byte(filename))
|
||
|
ret.hashString = fmt.Sprintf("%x", hash.Sum32())
|
||
|
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
// SetPkg sets the name and path of output package.
|
||
|
func (g *Generator) SetPkg(name, path string) {
|
||
|
g.pkgName = name
|
||
|
g.pkgPath = path
|
||
|
}
|
||
|
|
||
|
// SetBuildTags sets build tags for the output file.
|
||
|
func (g *Generator) SetBuildTags(tags string) {
|
||
|
g.buildTags = tags
|
||
|
}
|
||
|
|
||
|
// SetFieldNamer sets field naming strategy.
|
||
|
func (g *Generator) SetFieldNamer(n FieldNamer) {
|
||
|
g.fieldNamer = n
|
||
|
}
|
||
|
|
||
|
// UseSnakeCase sets snake_case field naming strategy.
|
||
|
func (g *Generator) UseSnakeCase() {
|
||
|
g.fieldNamer = SnakeCaseFieldNamer{}
|
||
|
}
|
||
|
|
||
|
// NoStdMarshalers instructs not to generate standard MarshalJSON/UnmarshalJSON
|
||
|
// methods (only the custom interface).
|
||
|
func (g *Generator) NoStdMarshalers() {
|
||
|
g.noStdMarshalers = true
|
||
|
}
|
||
|
|
||
|
// OmitEmpty triggers `json=",omitempty"` behaviour by default.
|
||
|
func (g *Generator) OmitEmpty() {
|
||
|
g.omitEmpty = true
|
||
|
}
|
||
|
|
||
|
// addTypes requests to generate en-/decoding functions for the given type.
|
||
|
func (g *Generator) addType(t reflect.Type) {
|
||
|
if g.typesSeen[t] {
|
||
|
return
|
||
|
}
|
||
|
for _, t1 := range g.typesUnseen {
|
||
|
if t1 == t {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
g.typesUnseen = append(g.typesUnseen, t)
|
||
|
}
|
||
|
|
||
|
// Add requests to generate (un-)marshallers and en-/decoding functions for the type of given object.
|
||
|
func (g *Generator) Add(obj interface{}) {
|
||
|
t := reflect.TypeOf(obj)
|
||
|
if t.Kind() == reflect.Ptr {
|
||
|
t = t.Elem()
|
||
|
}
|
||
|
g.addType(t)
|
||
|
g.marshallers[t] = true
|
||
|
}
|
||
|
|
||
|
// printHeader prints package declaration and imports.
|
||
|
func (g *Generator) printHeader() {
|
||
|
if g.buildTags != "" {
|
||
|
fmt.Println("// +build ", g.buildTags)
|
||
|
fmt.Println()
|
||
|
}
|
||
|
fmt.Println("// AUTOGENERATED FILE: easyjson marshaller/unmarshallers.")
|
||
|
fmt.Println()
|
||
|
fmt.Println("package ", g.pkgName)
|
||
|
fmt.Println()
|
||
|
|
||
|
byAlias := map[string]string{}
|
||
|
var aliases []string
|
||
|
for path, alias := range g.imports {
|
||
|
aliases = append(aliases, alias)
|
||
|
byAlias[alias] = path
|
||
|
}
|
||
|
|
||
|
sort.Strings(aliases)
|
||
|
fmt.Println("import (")
|
||
|
for _, alias := range g.imports {
|
||
|
fmt.Printf(" %s %q\n", alias, byAlias[alias])
|
||
|
}
|
||
|
|
||
|
fmt.Println(")")
|
||
|
fmt.Println("")
|
||
|
fmt.Println("// suppress unused package warning")
|
||
|
fmt.Println("var (")
|
||
|
fmt.Println(" _ *json.RawMessage")
|
||
|
fmt.Println(" _ *jlexer.Lexer")
|
||
|
fmt.Println(" _ *jwriter.Writer")
|
||
|
fmt.Println(" _ easyjson.Marshaler")
|
||
|
fmt.Println(")")
|
||
|
|
||
|
fmt.Println()
|
||
|
}
|
||
|
|
||
|
// Run runs the generator and outputs generated code to out.
|
||
|
func (g *Generator) Run(out io.Writer) error {
|
||
|
g.out = &bytes.Buffer{}
|
||
|
|
||
|
for len(g.typesUnseen) > 0 {
|
||
|
t := g.typesUnseen[len(g.typesUnseen)-1]
|
||
|
g.typesUnseen = g.typesUnseen[:len(g.typesUnseen)-1]
|
||
|
g.typesSeen[t] = true
|
||
|
|
||
|
if err := g.genDecoder(t); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if err := g.genEncoder(t); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if !g.marshallers[t] {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if err := g.genStructMarshaller(t); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if err := g.genStructUnmarshaller(t); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
g.printHeader()
|
||
|
_, err := out.Write(g.out.Bytes())
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// fixPkgPathVendoring strips everything up to and including the last
// "/vendor/" element from pkgPath, turning the path of a vendored package
// into its plain import path. Paths without "/vendor/" are returned as is.
func fixPkgPathVendoring(pkgPath string) string {
	const marker = "/vendor/"

	idx := strings.LastIndex(pkgPath, marker)
	if idx < 0 {
		return pkgPath
	}
	return pkgPath[idx+len(marker):]
}
|
||
|
|
||
|
// pkgAlias creates and returns and import alias for a given package.
|
||
|
func (g *Generator) pkgAlias(pkgPath string) string {
|
||
|
pkgPath = fixPkgPathVendoring(pkgPath)
|
||
|
if alias := g.imports[pkgPath]; alias != "" {
|
||
|
return alias
|
||
|
}
|
||
|
|
||
|
for i := 0; ; i++ {
|
||
|
alias := path.Base(pkgPath)
|
||
|
if i > 0 {
|
||
|
alias += fmt.Sprint(i)
|
||
|
}
|
||
|
|
||
|
exists := false
|
||
|
for _, v := range g.imports {
|
||
|
if v == alias {
|
||
|
exists = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if !exists {
|
||
|
g.imports[pkgPath] = alias
|
||
|
return alias
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// getType return the textual type name of given type that can be used in generated code.
|
||
|
func (g *Generator) getType(t reflect.Type) string {
|
||
|
if t.Name() == "" {
|
||
|
switch t.Kind() {
|
||
|
case reflect.Ptr:
|
||
|
return "*" + g.getType(t.Elem())
|
||
|
case reflect.Slice:
|
||
|
return "[]" + g.getType(t.Elem())
|
||
|
case reflect.Array:
|
||
|
return "[" + strconv.Itoa(t.Len()) + "]" + g.getType(t.Elem())
|
||
|
case reflect.Map:
|
||
|
return "map[" + g.getType(t.Key()) + "]" + g.getType(t.Elem())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if t.Name() == "" || t.PkgPath() == "" {
|
||
|
if t.Kind() == reflect.Struct {
|
||
|
// the fields of an anonymous struct can have named types,
|
||
|
// and t.String() will not be sufficient because it does not
|
||
|
// remove the package name when it matches g.pkgPath.
|
||
|
// so we convert by hand
|
||
|
nf := t.NumField()
|
||
|
lines := make([]string, 0, nf)
|
||
|
for i := 0; i < nf; i++ {
|
||
|
f := t.Field(i)
|
||
|
line := f.Name + " " + g.getType(f.Type)
|
||
|
t := f.Tag
|
||
|
if t != "" {
|
||
|
line += " " + escapeTag(t)
|
||
|
}
|
||
|
lines = append(lines, line)
|
||
|
}
|
||
|
return strings.Join([]string{"struct { ", strings.Join(lines, "; "), " }"}, "")
|
||
|
}
|
||
|
return t.String()
|
||
|
} else if t.PkgPath() == g.pkgPath {
|
||
|
return t.Name()
|
||
|
}
|
||
|
return g.pkgAlias(t.PkgPath()) + "." + t.Name()
|
||
|
}
|
||
|
|
||
|
// escapeTag converts a struct field tag back to Go source code. The tag is
// wrapped in backquotes when possible; a tag that itself contains a backquote
// cannot be written as a raw string literal and is emitted as an interpreted
// (double-quoted, escaped) literal instead.
func escapeTag(tag reflect.StructTag) string {
	raw := string(tag)
	if !strings.Contains(raw, "`") {
		return "`" + raw + "`"
	}
	return strconv.Quote(raw)
}
|
||
|
|
||
|
// uniqueVarName returns a file-unique name that can be used for generated variables.
|
||
|
func (g *Generator) uniqueVarName() string {
|
||
|
g.varCounter++
|
||
|
return fmt.Sprint("v", g.varCounter)
|
||
|
}
|
||
|
|
||
|
// safeName escapes unsafe characters in pkg/type name and returns a string that can be used
|
||
|
// in encoder/decoder names for the type.
|
||
|
func (g *Generator) safeName(t reflect.Type) string {
|
||
|
name := t.PkgPath()
|
||
|
if t.Name() == "" {
|
||
|
name += "anonymous"
|
||
|
} else {
|
||
|
name += "." + t.Name()
|
||
|
}
|
||
|
|
||
|
parts := []string{}
|
||
|
part := []rune{}
|
||
|
for _, c := range name {
|
||
|
if unicode.IsLetter(c) || unicode.IsDigit(c) {
|
||
|
part = append(part, c)
|
||
|
} else if len(part) > 0 {
|
||
|
parts = append(parts, string(part))
|
||
|
part = []rune{}
|
||
|
}
|
||
|
}
|
||
|
return joinFunctionNameParts(false, parts...)
|
||
|
}
|
||
|
|
||
|
// functionName returns a function name for a given type with a given prefix. If a function
|
||
|
// with this prefix already exists for a type, it is returned.
|
||
|
//
|
||
|
// Method is used to track encoder/decoder names for the type.
|
||
|
func (g *Generator) functionName(prefix string, t reflect.Type) string {
|
||
|
prefix = joinFunctionNameParts(true, "easyjson", g.hashString, prefix)
|
||
|
name := joinFunctionNameParts(true, prefix, g.safeName(t))
|
||
|
|
||
|
// Most of the names will be unique, try a shortcut first.
|
||
|
if e, ok := g.functionNames[name]; !ok || e == t {
|
||
|
g.functionNames[name] = t
|
||
|
return name
|
||
|
}
|
||
|
|
||
|
// Search if the function already exists.
|
||
|
for name1, t1 := range g.functionNames {
|
||
|
if t1 == t && strings.HasPrefix(name1, prefix) {
|
||
|
return name1
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Create a new name in the case of a clash.
|
||
|
for i := 1; ; i++ {
|
||
|
nm := fmt.Sprint(name, i)
|
||
|
if _, ok := g.functionNames[nm]; ok {
|
||
|
continue
|
||
|
}
|
||
|
g.functionNames[nm] = t
|
||
|
return nm
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// DefaultFieldNamer implements the trivial naming policy equivalent to
// encoding/json.
type DefaultFieldNamer struct{}

// GetJSONFieldName returns the name given in the `json` tag of f (the part
// before the first comma) or, if the tag has no name, the Go name of the
// field. t is not used.
func (DefaultFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
	tagName := f.Tag.Get("json")
	if comma := strings.IndexByte(tagName, ','); comma >= 0 {
		tagName = tagName[:comma]
	}

	if tagName == "" {
		return f.Name
	}
	return tagName
}
|
||
|
|
||
|
// SnakeCaseFieldNamer is a field naming policy that converts CamelCase field
// names to snake_case.
type SnakeCaseFieldNamer struct{}
|
||
|
|
||
|
// camelToSnake converts a CamelCase name to snake_case, treating a run of
// upper-case characters as a single word (e.g. "HTTPServer" becomes
// "http_server"). No delimiter is written at the very beginning of the name
// or directly after an underscore that is already there.
func camelToSnake(name string) string {
	var out bytes.Buffer

	var (
		pending   rune // buffered upper-case rune that is not written yet; 0 if none
		prevPlain rune // last rune that was written without buffering
		inRun     bool // pending is at least the second upper-case rune in a row
	)

	for _, r := range name {
		// A non-lowercase character after an upper-case one counts as
		// upper-case too.
		upper := unicode.IsUpper(r)
		if !upper && pending != 0 {
			upper = !unicode.IsLower(r)
		}

		if pending != 0 {
			// A delimiter goes in front of the buffered rune if it is either
			// the first upper-case rune in a row or the last one in a row
			// (e.g. 'S' in "HTTPServer").
			startsRun := !inRun
			endsRun := !upper
			if (startsRun || endsRun) && out.Len() > 0 && prevPlain != '_' {
				out.WriteByte('_')
			}
			out.WriteRune(unicode.ToLower(pending))
		}

		if !upper {
			out.WriteRune(r)
			pending, prevPlain, inRun = 0, r, false
			continue
		}

		// Buffer the upper-case rune instead of writing it: whether it needs
		// a delimiter depends on the rune that follows.
		inRun = pending != 0
		pending = r
	}

	if pending != 0 {
		out.WriteRune(unicode.ToLower(pending))
	}
	return out.String()
}
|
||
|
|
||
|
func (SnakeCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
|
||
|
jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
|
||
|
if jsonName != "" {
|
||
|
return jsonName
|
||
|
}
|
||
|
|
||
|
return camelToSnake(f.Name)
|
||
|
}
|
||
|
|
||
|
// joinFunctionNameParts concatenates parts into a CamelCase identifier by
// upper-casing the first character of every part. If keepFirst is true, the
// first part is copied verbatim. Empty parts contribute nothing.
func joinFunctionNameParts(keepFirst bool, parts ...string) string {
	var buf bytes.Buffer
	for i, part := range parts {
		if i == 0 && keepFirst {
			buf.WriteString(part)
			continue
		}

		// Split off the first rune, not the first byte: upper-casing a single
		// byte of a multi-byte character would produce invalid output.
		head, tail := part, ""
		for j := range part {
			if j > 0 {
				head, tail = part[:j], part[j:]
				break
			}
		}
		buf.WriteString(strings.ToUpper(head))
		buf.WriteString(tail)
	}
	return buf.String()
}
|