diff --git a/plugins/pluginrpc-gen/README.md b/plugins/pluginrpc-gen/README.md index 98720b2..0418a3e 100644 --- a/plugins/pluginrpc-gen/README.md +++ b/plugins/pluginrpc-gen/README.md @@ -43,16 +43,6 @@ supplying `--tag`. This flag can be specified multiple times. ## Known issues -The parser can currently only handle types which are not specifically a map or -a slice. -You can, however, create a type that uses a map or a slice internally, for instance: - -```go -type opts map[string]string -``` - -This `opts` type will work, whreas using a `map[string]string` directly will not. - ## go-generate You can also use this with go-generate, which is pretty awesome. diff --git a/plugins/pluginrpc-gen/fixtures/foo.go b/plugins/pluginrpc-gen/fixtures/foo.go index fcb2b62..5695dcc 100644 --- a/plugins/pluginrpc-gen/fixtures/foo.go +++ b/plugins/pluginrpc-gen/fixtures/foo.go @@ -1,5 +1,17 @@ package foo +import ( + "fmt" + + aliasedio "io" + + "github.com/docker/docker/pkg/plugins/pluginrpc-gen/fixtures/otherfixture" +) + +var ( + errFakeImport = fmt.Errorf("just to import fmt for imports tests") +) + type wobble struct { Some string Val string @@ -22,6 +34,7 @@ type Fooer3 interface { Qux(a, b string) (val string, err error) Wobble() (w *wobble) Wiggle() (w wobble) + WiggleWobble(a []*wobble, b []wobble, c map[string]*wobble, d map[*wobble]wobble, e map[string][]wobble, f []*otherfixture.Spaceship) (g map[*wobble]wobble, h [][]*wobble, i otherfixture.Spaceship, j *otherfixture.Spaceship, k map[*otherfixture.Spaceship]otherfixture.Spaceship, l []otherfixture.Spaceship) } // Fooer4 is an interface used for tests. @@ -39,3 +52,38 @@ type Fooer5 interface { Foo() Bar } + +// Fooer6 is an interface used for tests. +type Fooer6 interface { + Foo(a otherfixture.Spaceship) +} + +// Fooer7 is an interface used for tests. +type Fooer7 interface { + Foo(a *otherfixture.Spaceship) +} + +// Fooer8 is an interface used for tests. +type Fooer8 interface { + Foo(a map[string]otherfixture.Spaceship) +} + +// Fooer9 is an interface used for tests. +type Fooer9 interface { + Foo(a map[string]*otherfixture.Spaceship) +} + +// Fooer10 is an interface used for tests. +type Fooer10 interface { + Foo(a []otherfixture.Spaceship) +} + +// Fooer11 is an interface used for tests. +type Fooer11 interface { + Foo(a []*otherfixture.Spaceship) +} + +// Fooer12 is an interface used for tests. +type Fooer12 interface { + Foo(a aliasedio.Reader) +} diff --git a/plugins/pluginrpc-gen/fixtures/otherfixture/spaceship.go b/plugins/pluginrpc-gen/fixtures/otherfixture/spaceship.go new file mode 100644 index 0000000..1937d17 --- /dev/null +++ b/plugins/pluginrpc-gen/fixtures/otherfixture/spaceship.go @@ -0,0 +1,4 @@ +package otherfixture + +// Spaceship is a fixture for tests +type Spaceship struct{} diff --git a/plugins/pluginrpc-gen/main.go b/plugins/pluginrpc-gen/main.go index 772984c..402044c 100644 --- a/plugins/pluginrpc-gen/main.go +++ b/plugins/pluginrpc-gen/main.go @@ -78,7 +78,7 @@ func main() { errorOut("parser error", generatedTempl.Execute(&buf, analysis)) src, err := format.Source(buf.Bytes()) - errorOut("error formating generated source", err) + errorOut("error formating generated source:\n"+buf.String(), err) errorOut("error writing file", ioutil.WriteFile(*outputFile, src, 0644)) } diff --git a/plugins/pluginrpc-gen/parser.go b/plugins/pluginrpc-gen/parser.go index 3adeb49..6c547e1 100644 --- a/plugins/pluginrpc-gen/parser.go +++ b/plugins/pluginrpc-gen/parser.go @@ -6,7 +6,9 @@ import ( "go/ast" "go/parser" "go/token" + "path" "reflect" + "strings" ) var errBadReturn = errors.New("found return arg with no name: all args must be named") @@ -25,6 +27,7 @@ func (e errUnexpectedType) Error() string { type ParsedPkg struct { Name string Functions []function + Imports []importSpec } type function struct { @@ -35,14 +38,29 @@ type function struct { } type arg struct { - Name string - ArgType string + Name string + ArgType string + PackageSelector string } func (a *arg) String() string { return a.Name + " " + a.ArgType } +type importSpec struct { + Name string + Path string +} + +func (s *importSpec) String() string { + var ss string + if len(s.Name) != 0 { + ss += s.Name + } + ss += s.Path + return ss +} + // Parse parses the given file for an interface definition with the given name. func Parse(filePath string, objName string) (*ParsedPkg, error) { fs := token.NewFileSet() @@ -73,6 +91,44 @@ func Parse(filePath string, objName string) (*ParsedPkg, error) { return nil, err } + // figure out what imports will be needed + imports := make(map[string]importSpec) + for _, f := range p.Functions { + args := append(f.Args, f.Returns...) + for _, arg := range args { + if len(arg.PackageSelector) == 0 { + continue + } + + for _, i := range pkg.Imports { + if i.Name != nil { + if i.Name.Name != arg.PackageSelector { + continue + } + imports[i.Path.Value] = importSpec{Name: arg.PackageSelector, Path: i.Path.Value} + break + } + + _, name := path.Split(i.Path.Value) + splitName := strings.Split(name, "-") + if len(splitName) > 1 { + name = splitName[len(splitName)-1] + } + // import paths have quotes already added in, so need to remove them for name comparison + name = strings.TrimPrefix(name, `"`) + name = strings.TrimSuffix(name, `"`) + if name == arg.PackageSelector { + imports[i.Path.Value] = importSpec{Path: i.Path.Value} + break + } + } + } + } + + for _, spec := range imports { + p.Imports = append(p.Imports, spec) + } + return p, nil } @@ -142,22 +198,66 @@ func parseArgs(fields []*ast.Field) ([]arg, error) { return nil, errBadReturn } for _, name := range f.Names { - var typeName string - switch argType := f.Type.(type) { - case *ast.Ident: - typeName = argType.Name - case *ast.StarExpr: - i, ok := argType.X.(*ast.Ident) - if !ok { - return nil, errUnexpectedType{"*ast.Ident", f.Type} - } - typeName = "*" + i.Name - default: - return nil, errUnexpectedType{"*ast.Ident or *ast.StarExpr", f.Type} + p, err := parseExpr(f.Type) + if err != nil { + return nil, err } - - args = append(args, arg{name.Name, typeName}) + args = append(args, arg{name.Name, p.value, p.pkg}) } } return args, nil } + +type parsedExpr struct { + value string + pkg string +} + +func parseExpr(e ast.Expr) (parsedExpr, error) { + var parsed parsedExpr + switch i := e.(type) { + case *ast.Ident: + parsed.value += i.Name + case *ast.StarExpr: + p, err := parseExpr(i.X) + if err != nil { + return parsed, err + } + parsed.value += "*" + parsed.value += p.value + parsed.pkg = p.pkg + case *ast.SelectorExpr: + p, err := parseExpr(i.X) + if err != nil { + return parsed, err + } + parsed.pkg = p.value + parsed.value += p.value + "." + parsed.value += i.Sel.Name + case *ast.MapType: + parsed.value += "map[" + p, err := parseExpr(i.Key) + if err != nil { + return parsed, err + } + parsed.value += p.value + parsed.value += "]" + p, err = parseExpr(i.Value) + if err != nil { + return parsed, err + } + parsed.value += p.value + parsed.pkg = p.pkg + case *ast.ArrayType: + parsed.value += "[]" + p, err := parseExpr(i.Elt) + if err != nil { + return parsed, err + } + parsed.value += p.value + parsed.pkg = p.pkg + default: + return parsed, errUnexpectedType{"*ast.Ident or *ast.StarExpr", i} + } + return parsed, nil +} diff --git a/plugins/pluginrpc-gen/parser_test.go b/plugins/pluginrpc-gen/parser_test.go index 5a7579c..a1b1ac9 100644 --- a/plugins/pluginrpc-gen/parser_test.go +++ b/plugins/pluginrpc-gen/parser_test.go @@ -47,7 +47,7 @@ func TestParseWithMultipleFuncs(t *testing.T) { } assertName(t, "foo", pkg.Name) - assertNum(t, 6, len(pkg.Functions)) + assertNum(t, 7, len(pkg.Functions)) f := pkg.Functions[0] assertName(t, "Foo", f.Name) @@ -105,6 +105,35 @@ func TestParseWithMultipleFuncs(t *testing.T) { arg = f.Returns[0] assertName(t, "w", arg.Name) assertName(t, "wobble", arg.ArgType) + + f = pkg.Functions[6] + assertName(t, "WiggleWobble", f.Name) + assertNum(t, 6, len(f.Args)) + assertNum(t, 6, len(f.Returns)) + expectedArgs := [][]string{ + {"a", "[]*wobble"}, + {"b", "[]wobble"}, + {"c", "map[string]*wobble"}, + {"d", "map[*wobble]wobble"}, + {"e", "map[string][]wobble"}, + {"f", "[]*otherfixture.Spaceship"}, + } + for i, arg := range f.Args { + assertName(t, expectedArgs[i][0], arg.Name) + assertName(t, expectedArgs[i][1], arg.ArgType) + } + expectedReturns := [][]string{ + {"g", "map[*wobble]wobble"}, + {"h", "[][]*wobble"}, + {"i", "otherfixture.Spaceship"}, + {"j", "*otherfixture.Spaceship"}, + {"k", "map[*otherfixture.Spaceship]otherfixture.Spaceship"}, + {"l", "[]otherfixture.Spaceship"}, + } + for i, ret := range f.Returns { + assertName(t, expectedReturns[i][0], ret.Name) + assertName(t, expectedReturns[i][1], ret.ArgType) + } } func TestParseWithUnamedReturn(t *testing.T) { @@ -150,6 +179,31 @@ func TestEmbeddedInterface(t *testing.T) { assertName(t, "error", arg.ArgType) } +func TestParsedImports(t *testing.T) { + cases := []string{"Fooer6", "Fooer7", "Fooer8", "Fooer9", "Fooer10", "Fooer11"} + for _, testCase := range cases { + pkg, err := Parse(testFixture, testCase) + if err != nil { + t.Fatal(err) + } + + assertNum(t, 1, len(pkg.Imports)) + importPath := strings.Split(pkg.Imports[0].Path, "/") + assertName(t, "otherfixture\"", importPath[len(importPath)-1]) + assertName(t, "", pkg.Imports[0].Name) + } +} + +func TestAliasedImports(t *testing.T) { + pkg, err := Parse(testFixture, "Fooer12") + if err != nil { + t.Fatal(err) + } + + assertNum(t, 1, len(pkg.Imports)) + assertName(t, "aliasedio", pkg.Imports[0].Name) +} + func assertName(t *testing.T, expected, actual string) { if expected != actual { fatalOut(t, fmt.Sprintf("expected name to be `%s`, got: %s", expected, actual)) diff --git a/plugins/pluginrpc-gen/template.go b/plugins/pluginrpc-gen/template.go index d3dc494..50ed929 100644 --- a/plugins/pluginrpc-gen/template.go +++ b/plugins/pluginrpc-gen/template.go @@ -13,6 +13,19 @@ func printArgs(args []arg) string { return strings.Join(argStr, ", ") } +func buildImports(specs []importSpec) string { + if len(specs) == 0 { + return `import "errors"` + } + imports := "import(\n" + imports += "\t\"errors\"\n" + for _, i := range specs { + imports += "\t" + i.String() + "\n" + } + imports += ")" + return imports +} + func marshalType(t string) string { switch t { case "error": @@ -44,6 +57,7 @@ var templFuncs = template.FuncMap{ "lower": strings.ToLower, "title": title, "tag": buildTag, + "imports": buildImports, } func title(s string) string { @@ -60,7 +74,7 @@ var generatedTempl = template.Must(template.New("rpc_cient").Funcs(templFuncs).P package {{ .Name }} -import "errors" +{{ imports .Imports }} type client interface{ Call(string, interface{}, interface{}) error