diff --git a/annotated.go b/annotated.go index 03eebc821..88698886b 100644 --- a/annotated.go +++ b/annotated.go @@ -28,6 +28,7 @@ import ( "strings" "go.uber.org/dig" + "go.uber.org/fx/internal/fxreflect" ) @@ -1213,7 +1214,7 @@ func isIn(t reflect.Type) bool { var _ Annotation = (*asAnnotation)(nil) // As is an Annotation that annotates the result of a function (i.e. a -// constructor) to be provided as another interface. +// constructor) to be provided as another interface or convertible type. // // For example, the following code specifies that the return type of // bytes.NewBuffer (bytes.Buffer) should be provided as io.Writer type: @@ -1233,6 +1234,28 @@ var _ Annotation = (*asAnnotation)(nil) // constructor does NOT provide both bytes.Buffer and io.Writer type; it just // provides io.Writer type. // +// Example for function value: +// +// type domainHandler func(ctx context.Context) error +// +// func anyHandlerProvider() func(ctx context.Context) error { +// ... +// } +// +// fx.Provide( +// anyHandlerProvider(), +// fx.As(new(domainHandler)), +// ) +// +// Example for convertible types: +// +// type customStringType string +// +// fx.Provide( +// func() string { return "some string" }, +// fx.As(new(customStringType)), +// ) +// // When multiple values are returned by the annotated function, each type // gets mapped to corresponding positional result of the annotated function. // @@ -1299,8 +1322,8 @@ func (at *asAnnotation) apply(ann *annotated) error { continue } t := reflect.TypeOf(typ) - if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Interface { - return fmt.Errorf("fx.As: argument must be a pointer to an interface: got %v", t) + if t.Kind() != reflect.Ptr { + return fmt.Errorf("fx.As: argument must be a pointer to an interface or convertible type: got %v", t) } t = t.Elem() at.types[i] = asType{typ: t} @@ -1353,8 +1376,12 @@ func (at *asAnnotation) results(ann *annotated) ( continue } - if !t.Implements(at.types[i].typ) { - return nil, nil, fmt.Errorf("invalid fx.As: %v does not implement %v", t, at.types[i]) + if at.types[i].typ.Kind() == reflect.Interface { + if !t.Implements(at.types[i].typ) { + return nil, nil, fmt.Errorf("invalid fx.As: %v does not implement %v", t, at.types[i]) + } + } else if !t.ConvertibleTo(at.types[i].typ) { + return nil, nil, fmt.Errorf("invalid fx.As: %v cannot be converted to %v", t, at.types[i]) } field.Type = at.types[i].typ fields = append(fields, field) @@ -1388,7 +1415,12 @@ func (at *asAnnotation) results(ann *annotated) ( newOutResult := reflect.New(resType).Elem() for i := 1; i < resType.NumField(); i++ { - newOutResult.Field(i).Set(getResult(i, results)) + resultInstance := getResult(i, results) + if newOutResult.Field(i).Kind() == reflect.Interface { + newOutResult.Field(i).Set(resultInstance) + } else if toType := newOutResult.Field(i).Type(); resultInstance.CanConvert(toType) { + newOutResult.Field(i).Set(resultInstance.Convert(toType)) + } } outResults = append(outResults, newOutResult) diff --git a/annotated_test.go b/annotated_test.go index 2baf39da5..07e8a5b95 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -33,6 +33,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/fx" "go.uber.org/fx/fxevent" "go.uber.org/fx/fxtest" @@ -442,6 +443,10 @@ func TestAnnotatedAs(t *testing.T) { type myStringer interface { String() string } + type myProvideFunc func() string + type myInvokeFunc func() string + + type myStringType string newAsStringer := func() *asStringer { return &asStringer{ @@ -477,6 +482,43 @@ func TestAnnotatedAs(t *testing.T) { assert.Equal(t, s.String(), "another stringer") }, }, + { + desc: "function value convertible to target type", + provide: fx.Provide( + fx.Annotate(func() myProvideFunc { + return func() string { + return "provide func example" + } + }, fx.As(new(myInvokeFunc))), + ), + invoke: func(h myInvokeFunc) { + assert.Equal(t, "provide func example", h()) + }, + }, + { + desc: "anonymous function value convertible to target type", + provide: fx.Provide( + fx.Annotate(func() func() string { + return func() string { + return "anonymous func example" + } + }, fx.As(new(myInvokeFunc))), + ), + invoke: func(h myInvokeFunc) { + assert.Equal(t, "anonymous func example", h()) + }, + }, + { + desc: "value type convertible to target type", + provide: fx.Provide( + fx.Annotate(func() string { + return "provide convertible type" + }, fx.As(new(myStringType))), + ), + invoke: func(h myStringType) { + assert.Equal(t, myStringType("provide convertible type"), h) + }, + }, { desc: "provide with multiple types As", provide: fx.Provide(fx.Annotate(func() (*asStringer, *bytes.Buffer) { @@ -834,6 +876,18 @@ func TestAnnotatedAsFailures(t *testing.T) { return nil, errors.New("great sadness") } + type myProvideFunc func() string + + exampleProvideFunc := func() myProvideFunc { + return func() string { + return "i'm string" + } + } + + type myIntType int + + type myInvokeFunc func() int + tests := []struct { desc string provide fx.Option @@ -846,6 +900,18 @@ func TestAnnotatedAsFailures(t *testing.T) { invoke: func() {}, errorContains: "asStringer does not implement io.Writer", }, + { + desc: "provide when an inconvertible function value As", + provide: fx.Provide(fx.Annotate(exampleProvideFunc, fx.As(new(myInvokeFunc)))), + invoke: func() {}, + errorContains: "fx_test.myProvideFunc cannot be converted to fx_test.myInvokeFunc", + }, + { + desc: "provide when an inconvertible type As", + provide: fx.Provide(fx.Annotate(exampleProvideFunc(), fx.As(new(myIntType)))), + invoke: func() {}, + errorContains: "string cannot be converted to fx_test.myIntType", + }, { desc: "provide when an illegal type As with result tag", provide: fx.Provide(fx.Annotate(newAsStringer, fx.ResultTags(`name:"stringer"`), fx.As(new(io.Writer)))), @@ -890,7 +956,7 @@ func TestAnnotatedAsFailures(t *testing.T) { fx.As("foo"), ), ), - errorContains: "argument must be a pointer to an interface: got string", + errorContains: "argument must be a pointer to an interface or convertible type: got string", }, } @@ -1806,9 +1872,9 @@ func TestAnnotateApplySuccess(t *testing.T) { func assertApp( t *testing.T, app interface { - Start(context.Context) error - Stop(context.Context) error - }, + Start(context.Context) error + Stop(context.Context) error +}, started *bool, stopped *bool, invoked *bool,