diff --git a/mutator/mutator.go b/mutator/mutator.go index 34f05d7..675033b 100644 --- a/mutator/mutator.go +++ b/mutator/mutator.go @@ -46,13 +46,15 @@ func combineErrorsWithDefault(p *pkg.Package) *pkg.Package { return p } -// inlinePrimitiveTypes takes any non-struct type declarations and inlines their -// use as the original type in function parameters, return types, and structs. +// inlinePrimitiveTypes takes any non-struct, non-interface type declarations +// and inlines their use as the original type in function parameters, return +// types, and structs. // // Removal of the type declaration is handled in subsequent mutations. func inlinePrimitiveTypes(p *pkg.Package) *pkg.Package { for _, d := range p.TypeDecls { - if _, ok := d.Type.(*pkg.StructType); ok { + switch d.Type.(type) { + case *pkg.StructType, *pkg.InterfaceType: continue } @@ -354,6 +356,14 @@ func recurseType(typ pkg.Type, parentCtx typeContext, fn func(pkg.Type, typeCont t.Type = recurseType(t.Type, parentCtx|iter, fn) case *pkg.PointerType: t.Type = recurseType(t.Type, parentCtx, fn) + case *pkg.InterfaceType: + for i := range t.Methods { + t.Methods[i].Return = recurseType(t.Methods[i].Return, parentCtx, fn) + } + + for i := range t.Implementors { + t.Implementors[i].Type = recurseType(t.Implementors[i].Type, parentCtx, fn) + } } return fn(typ, parentCtx) diff --git a/mutator/mutator_test.go b/mutator/mutator_test.go index ff1ed1d..6fda200 100644 --- a/mutator/mutator_test.go +++ b/mutator/mutator_test.go @@ -155,6 +155,36 @@ func TestInlinePrimitiveTypes(t *testing.T) { }}}}, }, }, + + {"struct not inlined", + pkg.Package{ + TypeDecls: []pkg.TypeDecl{{Name: "ID", Type: &pkg.StructType{}}}, + Clients: []pkg.Client{{Methods: []pkg.Method{{ + Errors: map[int]pkg.Type{-1: &pkg.IdentType{Name: "ID"}}, + }}}}, + }, + pkg.Package{ + TypeDecls: []pkg.TypeDecl{{Name: "ID", Type: &pkg.StructType{}}}, + Clients: []pkg.Client{{Methods: []pkg.Method{{ + Errors: map[int]pkg.Type{-1: &pkg.IdentType{Name: "ID"}}, + }}}}, + }, + }, + + {"interface not inlined", + pkg.Package{ + TypeDecls: []pkg.TypeDecl{{Name: "ID", Type: &pkg.InterfaceType{}}}, + Clients: []pkg.Client{{Methods: []pkg.Method{{ + Errors: map[int]pkg.Type{-1: &pkg.IdentType{Name: "ID"}}, + }}}}, + }, + pkg.Package{ + TypeDecls: []pkg.TypeDecl{{Name: "ID", Type: &pkg.InterfaceType{}}}, + Clients: []pkg.Client{{Methods: []pkg.Method{{ + Errors: map[int]pkg.Type{-1: &pkg.IdentType{Name: "ID"}}, + }}}}, + }, + }, } for _, tc := range tcs { @@ -229,6 +259,23 @@ func TestInlineResponseStruct(t *testing.T) { } } + interfacePkg := func(c ...pkg.Client) pkg.Package { + return pkg.Package{ + TypeDecls: []pkg.TypeDecl{ + {Name: "FieldThing", Type: &pkg.StructType{ + Fields: []pkg.Field{{ID: "Field", Type: &pkg.IdentType{Name: "string"}}}, + }}, + {Name: "StructThing", Type: &pkg.StructType{ + Fields: []pkg.Field{{ID: "Field", Type: &pkg.IdentType{Name: "FieldThing"}}}, + }}, + {Name: "InterfaceThing", Type: &pkg.InterfaceType{ + Methods: []pkg.InterfaceMethod{{Name: "GetField", Return: &pkg.IdentType{Name: "FieldThing"}}}, + }}, + }, + Clients: c, + } + } + inlineDoublePkg := func(c ...pkg.Client) pkg.Package { return pkg.Package{ TypeDecls: []pkg.TypeDecl{ @@ -329,6 +376,17 @@ func TestInlineResponseStruct(t *testing.T) { Return: []pkg.Type{&pkg.IdentType{Name: "StructThing"}}, }}}), }, + + {"not inlined if used in one struct and interface", + interfacePkg(pkg.Client{Methods: []pkg.Method{{Return: []pkg.Type{ + &pkg.IdentType{Name: "StructThing"}, + &pkg.IdentType{Name: "InterfaceThing"}, + }}}}), + interfacePkg(pkg.Client{Methods: []pkg.Method{{Return: []pkg.Type{ + &pkg.IdentType{Name: "StructThing"}, + &pkg.IdentType{Name: "InterfaceThing"}, + }}}}), + }, } for _, tc := range tcs { diff --git a/openapi/v2/spec.go b/openapi/v2/spec.go index 4b59302..497adbf 100644 --- a/openapi/v2/spec.go +++ b/openapi/v2/spec.go @@ -759,7 +759,7 @@ func (a *AllOfSchema) UnmarshalYAML(um func(interface{}) error) error { // ObjectSchema is a schema definition for an object. type ObjectSchema struct { SchemaFields `yaml:",inline"` - Descriminator *string + Discriminator *string Properties *SchemaMap Required *[]string @@ -775,7 +775,7 @@ type ObjectSchema struct { func (o *ObjectSchema) UnmarshalYAML(um func(interface{}) error) error { var oy struct { SchemaFields `yaml:",inline"` - Descriminator *string + Discriminator *string Properties *SchemaMap Required *[]string @@ -789,7 +789,7 @@ func (o *ObjectSchema) UnmarshalYAML(um func(interface{}) error) error { } o.SchemaFields = oy.SchemaFields - o.Descriminator = oy.Descriminator + o.Discriminator = oy.Discriminator o.Properties = oy.Properties o.Required = oy.Required o.MinProperties = oy.MinProperties diff --git a/pkg/pkg.go b/pkg/pkg.go index 516d8e6..c2c09dd 100644 --- a/pkg/pkg.go +++ b/pkg/pkg.go @@ -71,21 +71,22 @@ type StructType struct { // Equal implements equality for Types func (t *StructType) Equal(o Type) bool { - if ot, ok := o.(*StructType); ok { - if len(t.Fields) != len(ot.Fields) { - return false - } + ot, ok := o.(*StructType) + if !ok { + return false + } - for i, f := range t.Fields { - if !f.equal(ot.Fields[i]) { - return false - } - } + if len(t.Fields) != len(ot.Fields) { + return false + } - return true + for i, f := range t.Fields { + if !f.equal(ot.Fields[i]) { + return false + } } - return false + return true } // IterType is used for return types, indicating they're iterators @@ -117,18 +118,64 @@ func (t *MapType) Equal(o Type) bool { return false } -// InterfaceType is an empty interface -type InterfaceType struct{} +// EmptyInterfaceType is an empty interface +type EmptyInterfaceType struct{} // Equal implements equality for Types -func (t *InterfaceType) Equal(o Type) bool { - _, ok := o.(*InterfaceType) +func (t *EmptyInterfaceType) Equal(o Type) bool { + _, ok := o.(*EmptyInterfaceType) return ok } +// InterfaceType is an interface that represents the common values in an +// OpenAPI discriminator type. +type InterfaceType struct { + Methods []InterfaceMethod + Implementors []InterfaceImplementor +} + +// Equal implements equality for Types +func (t *InterfaceType) Equal(o Type) bool { + ot, ok := o.(*InterfaceType) + if !ok { + return false + } + + if len(t.Methods) != len(ot.Methods) { + return false + } + + for i, m := range t.Methods { + if !m.equal(ot.Methods[i]) { + return false + } + } + + return true +} + +// InterfaceMethod is a method on an interface +type InterfaceMethod struct { + Name string + Return Type + Comment string +} + +func (m InterfaceMethod) equal(o InterfaceMethod) bool { + return m.Name == o.Name && m.Return.Equal(o.Return) +} + +// InterfaceImplementor is an Implementor of an interface created from +// discriminators +type InterfaceImplementor struct { + Discriminator string + Type Type +} + // TypeDecl is a type declaration. type TypeDecl struct { Name string + Orig string Comment string Type Type } diff --git a/pkg/pkg_test.go b/pkg/pkg_test.go index 444e5ac..d1559ae 100644 --- a/pkg/pkg_test.go +++ b/pkg/pkg_test.go @@ -14,7 +14,18 @@ func TestTypeEqual(t *testing.T) { &StructType{Fields: []Field{{ID: "Name", Type: &IdentType{Name: "Thing"}}}}, &MapType{Key: &IdentType{Name: "string"}, Value: &IdentType{Name: "int"}}, &MapType{Key: &IdentType{Name: "int"}, Value: &SliceType{&IdentType{Name: "int"}}}, - &InterfaceType{}, + &EmptyInterfaceType{}, + &InterfaceType{Methods: []InterfaceMethod{}}, + &InterfaceType{Methods: []InterfaceMethod{ + {Name: "GetFoo", Return: &IdentType{Name: "string"}}, + }}, + &InterfaceType{Methods: []InterfaceMethod{ + {Name: "GetFoo", Return: &IdentType{Name: "int"}}, + }}, + &InterfaceType{Methods: []InterfaceMethod{ + {Name: "GetBar", Return: &IdentType{Name: "int"}}, + {Name: "GetFoo", Return: &IdentType{Name: "string"}}, + }}, } for i := range cases { diff --git a/translator/registry.go b/translator/registry.go index 4bda0d7..6d914f3 100644 --- a/translator/registry.go +++ b/translator/registry.go @@ -9,8 +9,9 @@ import ( ) type typeRegistry struct { - strFmt stringFormat - types []pkg.TypeDecl + strFmt stringFormat + types []pkg.TypeDecl + discriminators map[string]*pkg.TypeDecl } func (tr *typeRegistry) add(td pkg.TypeDecl) { @@ -31,17 +32,14 @@ func (tr *typeRegistry) convertSchema(schema v2.Schema, td *pkg.TypeDecl, declAl mt := &pkg.MapType{Key: &pkg.IdentType{Name: "string"}} if s.AnyAdditionalProperties { - mt.Value = &pkg.InterfaceType{} + mt.Value = &pkg.EmptyInterfaceType{} } else { mt.Value = tr.convertSchema(s.AdditionalProperties, &pkg.TypeDecl{ Name: td.Name + "Value"}, false) } - if td != nil { - td.Type = mt - tr.types = append(tr.types, *td) - } - + td.Type = mt + tr.types = append(tr.types, *td) return &pkg.IdentType{Name: td.Name} } @@ -52,6 +50,52 @@ func (tr *typeRegistry) convertSchema(schema v2.Schema, td *pkg.TypeDecl, declAl } } + if s.Discriminator != nil { + t := &pkg.InterfaceType{} + + for _, prop := range *s.Properties { + method := pkg.InterfaceMethod{ + Name: formatID("Get", prop.Name), + } + + method.Comment = commentForPropSchema(prop.Schema) + + sn := td.Name + formatID(prop.Name) + method.Return = tr.convertSchema(prop.Schema, &pkg.TypeDecl{ + Name: sn, + Comment: fmt.Sprintf("%s is a data type for API communication.", sn), + }, false) + + if _, ok := required[prop.Name]; !ok { + method.Return = &pkg.PointerType{Type: method.Return} + if method.Comment == "" { + method.Comment = "Optional" + } + } + + t.Methods = append(t.Methods, method) + } + + td.Type = t + tr.types = append(tr.types, *td) + + if tr.discriminators == nil { + tr.discriminators = make(map[string]*pkg.TypeDecl) + } + tr.discriminators[td.Name] = &tr.types[len(tr.types)-1] + + oldDiscriminator := s.Discriminator + s.Discriminator = nil + mn := td.Name + "Meta" + tr.convertSchema(s, &pkg.TypeDecl{ + Name: mn, + Comment: fmt.Sprintf("%s is an abstract data type for API communication.", mn), + }, false) + + s.Discriminator = oldDiscriminator + return &pkg.IdentType{Name: td.Name} + } + for _, prop := range *s.Properties { field := pkg.Field{ ID: formatID(prop.Name), @@ -61,14 +105,7 @@ func (tr *typeRegistry) convertSchema(schema v2.Schema, td *pkg.TypeDecl, declAl field.Orig = prop.Name } - fieldComment := "" - if prop.Schema.GetTitle() != nil { - fieldComment += *prop.Schema.GetTitle() - } - if prop.Schema.GetDescription() != nil { - fieldComment += *prop.Schema.GetDescription() - } - field.Comment = fieldComment + field.Comment = commentForPropSchema(prop.Schema) sn := td.Name + formatID(prop.Name) field.Type = tr.convertSchema(prop.Schema, &pkg.TypeDecl{ @@ -85,11 +122,9 @@ func (tr *typeRegistry) convertSchema(schema v2.Schema, td *pkg.TypeDecl, declAl t.Fields = append(t.Fields, field) } - if td != nil { - td.Type = t - tr.types = append(tr.types, *td) - } + td.Type = t + tr.types = append(tr.types, *td) return &pkg.IdentType{Name: td.Name} case *v2.StringSchema: ret = tr.strFmt.typeFor(s.Format) @@ -110,6 +145,30 @@ func (tr *typeRegistry) convertSchema(schema v2.Schema, td *pkg.TypeDecl, declAl fields := make([]pkg.Field, len(s.AllOf)) for i := range s.AllOf { + if dr, ok := s.AllOf[i].(*v2.ReferenceSchema); ok { + parts := strings.Split(dr.Reference, "/") + refName := formatID(parts[len(parts)-1]) + + if dt, ok := tr.discriminators[refName]; ok { + fields[i] = pkg.Field{ + Type: &pkg.IdentType{ + Name: refName + "Meta", + }, + } + + td.Comment = strings.Replace(td.Comment, td.Name, td.Name+refName, 1) + td.Name += refName + + disc := dt.Type.(*pkg.InterfaceType) + disc.Implementors = append(disc.Implementors, pkg.InterfaceImplementor{ + Discriminator: td.Orig, + Type: &pkg.IdentType{Name: td.Name}, + }) + + continue + } + } + sn := fmt.Sprintf("%sAllOf%d", td.Name, i) field := pkg.Field{ Type: tr.convertSchema(s.AllOf[i], &pkg.TypeDecl{ @@ -125,11 +184,8 @@ func (tr *typeRegistry) convertSchema(schema v2.Schema, td *pkg.TypeDecl, declAl Fields: fields, } - if td != nil { - td.Type = t - tr.types = append(tr.types, *td) - } - + td.Type = t + tr.types = append(tr.types, *td) return &pkg.IdentType{Name: td.Name} default: // XXX handle this @@ -223,3 +279,15 @@ func (sf stringFormat) typeFor(fmt *string) pkg.Type { return &pkg.IdentType{Name: "string"} } + +func commentForPropSchema(prop v2.Schema) string { + comment := "" + if prop.GetTitle() != nil { + comment += *prop.GetTitle() + } + if prop.GetDescription() != nil { + comment += *prop.GetDescription() + } + + return comment +} diff --git a/translator/registry_test.go b/translator/registry_test.go index bdfe9ea..18d686f 100644 --- a/translator/registry_test.go +++ b/translator/registry_test.go @@ -86,7 +86,7 @@ func TestConvertSchema(t *testing.T) { &pkg.IdentType{Name: "Foo"}, []pkg.TypeDecl{{Name: "Foo", Type: &pkg.MapType{ Key: &pkg.IdentType{Name: "string"}, - Value: &pkg.InterfaceType{}, + Value: &pkg.EmptyInterfaceType{}, }}}, }, { @@ -129,6 +129,71 @@ func TestConvertSchema(t *testing.T) { }, }, }, + { + "object discriminator", + &v2.ObjectSchema{ + Discriminator: ptr("type_field"), + Properties: &v2.SchemaMap{ + {Name: "type_field", Schema: &v2.StringSchema{}}, + }, + Required: &[]string{"type_field"}, + }, + &pkg.TypeDecl{Name: "Foo"}, + &pkg.IdentType{Name: "Foo"}, + []pkg.TypeDecl{ + { + Name: "Foo", + Type: &pkg.InterfaceType{ + Methods: []pkg.InterfaceMethod{ + {Name: "GetTypeField", Return: &pkg.IdentType{Name: "string"}}, + }, + }, + }, + { + Name: "FooMeta", + Comment: "FooMeta is an abstract data type for API communication.", + Type: &pkg.StructType{ + Fields: []pkg.Field{ + {ID: "TypeField", Type: &pkg.IdentType{Name: "string"}, Orig: "type_field"}, + }, + }, + }, + }, + }, + { + "object discriminator optional field", + &v2.ObjectSchema{ + Discriminator: ptr("type_field"), + Properties: &v2.SchemaMap{ + {Name: "field", Schema: &v2.StringSchema{}}, + {Name: "type_field", Schema: &v2.StringSchema{}}, + }, + Required: &[]string{"type_field"}, + }, + &pkg.TypeDecl{Name: "Foo"}, + &pkg.IdentType{Name: "Foo"}, + []pkg.TypeDecl{ + { + Name: "Foo", + Type: &pkg.InterfaceType{ + Methods: []pkg.InterfaceMethod{ + {Name: "GetField", Return: &pkg.PointerType{Type: &pkg.IdentType{Name: "string"}}, Comment: "Optional"}, + {Name: "GetTypeField", Return: &pkg.IdentType{Name: "string"}}, + }, + }, + }, + { + Name: "FooMeta", + Comment: "FooMeta is an abstract data type for API communication.", + Type: &pkg.StructType{ + Fields: []pkg.Field{ + {ID: "Field", Type: &pkg.PointerType{Type: &pkg.IdentType{Name: "string"}}, Orig: "field", Comment: "Optional"}, + {ID: "TypeField", Type: &pkg.IdentType{Name: "string"}, Orig: "type_field"}, + }, + }, + }, + }, + }, } for _, tc := range tcs { @@ -217,3 +282,5 @@ func TestStringFormatTypeFor(t *testing.T) { }) } } + +func ptr(in string) *string { return &in } diff --git a/translator/translator.go b/translator/translator.go index 6342b8f..f823cce 100644 --- a/translator/translator.go +++ b/translator/translator.go @@ -18,10 +18,21 @@ func Translate(doc *v2.Document, qual, name string, types map[string]string, str BaseURL: "https://" + *doc.Host + *doc.BasePath, } + // convert defined types. do two passes: one to convert the discriminators, + // and another for the remaining types. This ensures the discriminators are + // defined for any types that may use them. tr := &typeRegistry{strFmt: stringFormats} if doc.Definitions != nil { for _, def := range *doc.Definitions { - convertDefinition(tr, def.Name, def.Schema, types) + if os, ok := def.Schema.(*v2.ObjectSchema); ok && os.Discriminator != nil { + convertDefinition(tr, def.Name, def.Schema, types) + } + } + + for _, def := range *doc.Definitions { + if os, ok := def.Schema.(*v2.ObjectSchema); !ok || os.Discriminator == nil { + convertDefinition(tr, def.Name, def.Schema, types) + } } } @@ -100,6 +111,7 @@ func convertDefinition(tr *typeRegistry, name string, def v2.Schema, types map[s tr.convertSchema(def, &pkg.TypeDecl{ Name: dataName, + Orig: name, Comment: comment, }, true) } diff --git a/writer/type.go b/writer/type.go index a53b703..2cd0f91 100644 --- a/writer/type.go +++ b/writer/type.go @@ -27,8 +27,10 @@ func writeType(typ pkg.Type) func(s *jen.Statement) { s.Struct(convertFields(t.Fields)...) case *pkg.MapType: s.Map(jen.Do(writeType(t.Key))).Do(writeType(t.Value)) - case *pkg.InterfaceType: + case *pkg.EmptyInterfaceType: s.Interface() + case *pkg.InterfaceType: + s.Interface(convertInterfaceMethods(t.Methods)...) default: panic("unhandled type") } @@ -72,6 +74,32 @@ func convertFields(fields []pkg.Field) []jen.Code { return o } +func convertInterfaceMethods(methods []pkg.InterfaceMethod) []jen.Code { + var o []jen.Code + first := true + for _, m := range methods { + im := jen.Id(m.Name).Params().Do(writeType(m.Return)) + + if m.Comment != "" { + if len(m.Comment) < 80 { + im.Comment(strings.Replace(m.Comment, "\n", " ", -1)) + } else { + if !first { + o = append(o, jen.Empty()) + } + o = append(o, jen.Comment(formatComment(m.Comment))) + } + } + + o = append(o, im) + if first { + first = false + } + } + + return o +} + func hasStruct(typ pkg.Type) bool { switch t := typ.(type) { case *pkg.IterType: diff --git a/writer/type_test.go b/writer/type_test.go index 083ffae..01557cc 100644 --- a/writer/type_test.go +++ b/writer/type_test.go @@ -21,7 +21,17 @@ func TestWriteType(t *testing.T) { {"slice", &pkg.SliceType{Type: &pkg.IdentType{Name: "string"}}, "[]string"}, {"pointer", &pkg.PointerType{Type: &pkg.IdentType{Name: "string"}}, "*string"}, {"empty struct", &pkg.StructType{}, "struct{}"}, - {"empty interface", &pkg.InterfaceType{}, "interface{}"}, + {"empty interface", &pkg.EmptyInterfaceType{}, "interface{}"}, + + {"interface (no methods)", &pkg.InterfaceType{}, "interface{}"}, + {"interface", &pkg.InterfaceType{Methods: []pkg.InterfaceMethod{ + {Name: "GetFoo", Return: &pkg.IdentType{Name: "string"}}, + {Name: "GetBar", Return: &pkg.IdentType{Name: "string"}, Comment: "bar is neat"}, + }}, `interface{ + GetFoo() string + GetBar() string // bar is neat + }`}, + {"struct", &pkg.StructType{Fields: []pkg.Field{ {ID: "Foo", Type: &pkg.IdentType{Name: "string"}}, @@ -51,6 +61,7 @@ func TestWriteType(t *testing.T) { Bar struct{} }`, }, + {"map", &pkg.MapType{ Key: &pkg.IdentType{Name: "string"},