diff --git a/builtin_test.go b/builtin_test.go index aaea8d5..a05393a 100644 --- a/builtin_test.go +++ b/builtin_test.go @@ -350,6 +350,110 @@ func TestIsTypeEx(t *testing.T) { } } +func TestNewOverloadFunc1(t *testing.T) { + pkg := types.NewPackage("", "foo") + pkg.Scope().Insert(types.NewConst(0, pkg, "XGoPackage", types.Typ[types.Bool], constant.MakeBool(true))) + f1 := types.NewFunc(0, pkg, "bar__0", types.NewSignature(nil, nil, nil, false)) + f2 := types.NewFunc(0, pkg, "bar__1", types.NewSignature(nil, types.NewTuple(types.NewVar(0, pkg, "n", types.Typ[types.Int])), nil, false)) + pkg.Scope().Insert(f1) + pkg.Scope().Insert(f2) + InitXGoPackage(pkg) + of := pkg.Scope().Lookup("bar") + _, objs := CheckSigFuncExObjects(of.Type().(*types.Signature)) + if len(objs) != 2 { + t.Fatal("error") + } +} + +func TestNewOverloadFunc2(t *testing.T) { + pkg := types.NewPackage("", "foo") + pkg.Scope().Insert(types.NewConst(0, pkg, "XGoPackage", types.Typ[types.Bool], constant.MakeBool(true))) + of := NewOverloadFunc(0, pkg, "bar") + f1 := types.NewFunc(0, pkg, "bar__0", types.NewSignature(nil, nil, nil, false)) + f2 := types.NewFunc(0, pkg, "bar__1", types.NewSignature(nil, types.NewTuple(types.NewVar(0, pkg, "n", types.Typ[types.Int])), nil, false)) + pkg.Scope().Insert(of) + pkg.Scope().Insert(f1) + pkg.Scope().Insert(f2) + InitXGoPackage(pkg) + _, objs := CheckSigFuncExObjects(of.Type().(*types.Signature)) + if len(objs) != 2 { + t.Fatal("error") + } +} + +func TestNewOverloadFuncError(t *testing.T) { + defer func() { + err := recover() + if err == nil { + t.Fatal("no error?") + } + }() + pkg := types.NewPackage("", "foo") + pkg.Scope().Insert(types.NewConst(0, pkg, "XGoPackage", types.Typ[types.Bool], constant.MakeBool(true))) + of := types.NewVar(0, pkg, "bar", types.Typ[types.Int]) + f1 := types.NewFunc(0, pkg, "bar__0", types.NewSignature(nil, nil, nil, false)) + f2 := types.NewFunc(0, pkg, "bar__1", types.NewSignature(nil, types.NewTuple(types.NewVar(0, pkg, "n", types.Typ[types.Int])), nil, false)) + pkg.Scope().Insert(of) + pkg.Scope().Insert(f1) + pkg.Scope().Insert(f2) + InitXGoPackage(pkg) +} + +func TestNewOverloadMethod1(t *testing.T) { + pkg := types.NewPackage("", "foo") + pkg.Scope().Insert(types.NewConst(0, pkg, "XGoPackage", types.Typ[types.Bool], constant.MakeBool(true))) + typ := types.NewNamed(types.NewTypeName(0, pkg, "T", nil), types.Typ[types.Int], nil) + pkg.Scope().Insert(typ.Obj()) + f1 := types.NewFunc(0, pkg, "bar__0", types.NewSignature(types.NewVar(0, pkg, "", typ), nil, nil, false)) + f2 := types.NewFunc(0, pkg, "bar__1", types.NewSignature(types.NewVar(0, pkg, "", typ), types.NewTuple(types.NewVar(0, pkg, "n", types.Typ[types.Int])), nil, false)) + typ.AddMethod(f1) + typ.AddMethod(f2) + InitXGoPackage(pkg) + of := findMethod(typ, "bar") + _, objs := CheckSigFuncExObjects(of.Type().(*types.Signature)) + if typ.NumMethods() != 3 || len(objs) != 2 { + t.Fatal("error") + } +} + +func TestNewOverloadMethod2(t *testing.T) { + pkg := types.NewPackage("", "foo") + pkg.Scope().Insert(types.NewConst(0, pkg, "XGoPackage", types.Typ[types.Bool], constant.MakeBool(true))) + typ := types.NewNamed(types.NewTypeName(0, pkg, "T", nil), types.Typ[types.Int], nil) + pkg.Scope().Insert(typ.Obj()) + of := NewOverloadMethod(typ, 0, pkg, "bar") + types.NewFunc(0, pkg, "bar__0", types.NewSignature(types.NewVar(0, pkg, "", typ), nil, nil, false)) + f1 := types.NewFunc(0, pkg, "bar__0", types.NewSignature(types.NewVar(0, pkg, "", typ), nil, nil, false)) + f2 := types.NewFunc(0, pkg, "bar__1", types.NewSignature(types.NewVar(0, pkg, "", typ), types.NewTuple(types.NewVar(0, pkg, "n", types.Typ[types.Int])), nil, false)) + typ.AddMethod(f1) + typ.AddMethod(f2) + InitXGoPackage(pkg) + _, objs := CheckSigFuncExObjects(of.Type().(*types.Signature)) + if typ.NumMethods() != 3 || len(objs) != 2 { + t.Fatal("error") + } +} + +func TestNewOverloadMethodError(t *testing.T) { + defer func() { + err := recover() + if err == nil { + t.Fatal("no error?") + } + }() + pkg := types.NewPackage("", "foo") + pkg.Scope().Insert(types.NewConst(0, pkg, "XGoPackage", types.Typ[types.Bool], constant.MakeBool(true))) + typ := types.NewNamed(types.NewTypeName(0, pkg, "T", nil), types.Typ[types.Int], nil) + pkg.Scope().Insert(typ.Obj()) + of := types.NewFunc(0, pkg, "bar", types.NewSignature(types.NewVar(0, pkg, "", typ), nil, nil, false)) + f1 := types.NewFunc(0, pkg, "bar__0", types.NewSignature(types.NewVar(0, pkg, "", typ), nil, nil, false)) + f2 := types.NewFunc(0, pkg, "bar__1", types.NewSignature(types.NewVar(0, pkg, "", typ), types.NewTuple(types.NewVar(0, pkg, "n", types.Typ[types.Int])), nil, false)) + typ.AddMethod(of) + typ.AddMethod(f1) + typ.AddMethod(f2) + InitXGoPackage(pkg) +} + func TestGetBuiltinTI(t *testing.T) { pkg := NewPackage("", "foo", nil) cb := &pkg.cb diff --git a/import.go b/import.go index d4407d6..79eeccd 100644 --- a/import.go +++ b/import.go @@ -98,11 +98,6 @@ func isXGoCommon(name string) bool { // InitXGoPackage initializes a XGo package. func InitXGoPackage(pkg *types.Package) { - InitXGoPackageEx(pkg, nil) -} - -// InitXGoPackageEx initializes a XGo package. pos map overload name to position. -func InitXGoPackageEx(pkg *types.Package, pos map[string]token.Pos) { scope := pkg.Scope() gopos := make([]string, 0, 4) overloads := make(map[omthd][]types.Object) @@ -162,7 +157,7 @@ func InitXGoPackageEx(pkg *types.Package, pos map[string]token.Pos) { } } if len(fns) > 0 { - newOverload(pkg, scope, m, fns, pos) + newOverload(token.NoPos, pkg, scope, m, fns) } delete(overloads, m) } @@ -170,7 +165,7 @@ func InitXGoPackageEx(pkg *types.Package, pos map[string]token.Pos) { for key, items := range overloads { off := len(key.name) + 2 fns := overloadFuncs(off, items) - newOverload(pkg, scope, key, fns, pos) + newOverload(token.NoPos, pkg, scope, key, fns) } for name, items := range onameds { off := len(name) + 2 @@ -319,21 +314,61 @@ func checkOverloads(scope *types.Scope, gopoName string) (ret []string, exists b return } -func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object, pos map[string]token.Pos) { +func checkOverload[T TyFuncEx](obj types.Object) (t T, ok bool) { + sig, ok := obj.Type().(*types.Signature) + if !ok { + return + } + ext, ok := CheckFuncEx(sig) + if !ok { + return + } + t, ok = ext.(T) + return +} + +func newOverload(pos token.Pos, pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object) { if m.typ == nil { if debugImport { log.Println("==> NewOverloadFunc", m.name) } - o := NewOverloadFunc(pos[m.name], pkg, m.name, fns...) - scope.Insert(o) - checkXGotsx(pkg, scope, m.name, o) + var obj types.Object + if obj = scope.Lookup(m.name); obj != nil { + t, ok := checkOverload[*TyOverloadFunc](obj) + if !ok { + log.Panicf("Object not OverloadFunc: %v", obj) + } + t.Funcs = fns + } else { + obj = NewOverloadFunc(pos, pkg, m.name, fns...) + scope.Insert(obj) + } + checkXGotsx(pkg, scope, m.name, obj) } else { typName := m.typ.Obj().Name() if debugImport { log.Println("==> NewOverloadMethod", typName, m.name) } - NewOverloadMethod(m.typ, pos[typName+"."+m.name], pkg, m.name, fns...) + if obj := findMethod(m.typ, m.name); obj != nil { + t, ok := checkOverload[*TyOverloadMethod](obj) + if !ok { + log.Panicf("Object not OverloadMethod: %v", obj) + } + t.Methods = fns + } else { + NewOverloadMethod(m.typ, pos, pkg, m.name, fns...) + } + } +} + +func findMethod(typ *types.Named, name string) *types.Func { + n := typ.NumMethods() + for i := 0; i < n; i++ { + if m := typ.Method(i); m.Name() == name { + return m + } } + return nil } func overloadFuncs(off int, items []types.Object) []types.Object {