From 0c29db783c5f384ac3edec11c43831e59dac0a6a Mon Sep 17 00:00:00 2001 From: Reza Mokaram Date: Fri, 7 Nov 2025 03:01:13 +0330 Subject: [PATCH] feat(fx): add Transient provider option for non-cached constructors Introduces a new fx.Transient() helper that allows constructors to be invoked each time their dependency is requested, rather than once at application start. This enables transient or per-request dependency lifetimes similar to scoped services in other DI systems. The Transient wrapper registers a factory function that produces a new instance on each call. Includes unit tests verifying: - Constructor runs multiple times for transient providers. - Singleton behavior remains unchanged for regular fx.Provide. --- provide.go | 76 ++++++++++++++++++++++++++++++++++++++ transient_provider_test.go | 60 ++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 transient_provider_test.go diff --git a/provide.go b/provide.go index f55348966..8d03b0488 100644 --- a/provide.go +++ b/provide.go @@ -119,6 +119,23 @@ func (o provideOption) String() string { return fmt.Sprintf("fx.Provide(%s)", strings.Join(items, ", ")) } +// Transient marks a constructor so that fx will provide a factory function +// that creates a new instance each time it is called. For example, +// +// fx.Provide(fx.Transient(NewA)) +// +// where `NewA` is `func(...) *A` will register a provider for +// `func() *A`. Each call to the provided factory will run `NewA` and +// return a fresh *A (its dependencies are captured by the factory +// closure when the app starts). +func Transient(constructor any) any { + return transientOption{Target: constructor} +} + +type transientOption struct { + Target any +} + func runProvide(c container, p provide, opts ...dig.ProvideOption) error { constructor := p.Target if _, ok := constructor.(Option); ok { @@ -162,6 +179,21 @@ func runProvide(c container, p provide, opts ...dig.ProvideOption) error { return fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %w", ann, p.Stack, err) } + case transientOption: + // Build a wrapper function of the form: + // func(deps...) func() (outs...) + // The inner func (factory) when called will execute the original + // constructor using the captured dependency values, producing a + // fresh set of outputs each time. + wrapper, err := buildTransientWrapper(constructor.Target) + if err != nil { + return fmt.Errorf("fx.Transient(%v) from:\n%+vFailed: %w", constructor.Target, p.Stack, err) + } + + if err := c.Provide(wrapper, opts...); err != nil { + return fmt.Errorf("fx.Provide(fx.Transient(%v)) from:\n%+vFailed: %w", fxreflect.FuncName(constructor.Target), p.Stack, err) + } + default: if reflect.TypeOf(constructor).Kind() == reflect.Func { ft := reflect.ValueOf(constructor).Type() @@ -185,3 +217,47 @@ func runProvide(c container, p provide, opts ...dig.ProvideOption) error { } return nil } + +// buildTransientWrapper builds a wrapper function that captures the +// dependencies of the original constructor and returns a factory function +// which calls the original constructor each time it is invoked. +func buildTransientWrapper(constructor any) (any, error) { + v := reflect.ValueOf(constructor) + if v.Kind() != reflect.Func { + return nil, fmt.Errorf("transient target must be a function, got %T", constructor) + } + t := v.Type() + if t.IsVariadic() { + return nil, fmt.Errorf("transient constructors may not be variadic: %s", fxreflect.FuncName(constructor)) + } + + // Collect output types for the inner factory function. + outs := make([]reflect.Type, 0, t.NumOut()) + for i := 0; i < t.NumOut(); i++ { + outs = append(outs, t.Out(i)) + } + + // factoryType is func() (outs...) + factoryType := reflect.FuncOf([]reflect.Type{}, outs, false) + + // wrapperType is func(in...) factoryType + ins := make([]reflect.Type, 0, t.NumIn()) + for i := 0; i < t.NumIn(); i++ { + ins = append(ins, t.In(i)) + } + wrapperType := reflect.FuncOf(ins, []reflect.Type{factoryType}, false) + + // Make the wrapper function. + wrapper := reflect.MakeFunc(wrapperType, func(in []reflect.Value) []reflect.Value { + // Create the factory function which captures the provided deps (in) + factory := reflect.MakeFunc(factoryType, func(_ []reflect.Value) []reflect.Value { + // Call the original constructor with captured deps. + results := v.Call(in) + // Return results as-is to caller of factory. + return results + }) + return []reflect.Value{factory} + }) + + return wrapper.Interface(), nil +} diff --git a/transient_provider_test.go b/transient_provider_test.go new file mode 100644 index 000000000..84043d4cb --- /dev/null +++ b/transient_provider_test.go @@ -0,0 +1,60 @@ +package fx_test + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/fx" + "go.uber.org/fx/fxtest" +) + +type TransientService struct { + ID int64 +} + +func NewTransientService() *TransientService { + return &TransientService{ + ID: rand.Int63(), + } +} + +type ConsumerA struct { + Service *TransientService +} + +type ConsumerB struct { + Service *TransientService +} + +func TestTransientProvider(t *testing.T) { + t.Run("should create new instance for each resolution", func(t *testing.T) { + app := fxtest.New(t, + fx.Provide( + fx.Transient(NewTransientService), + ), + fx.Invoke(func(factory func() *TransientService) { + i1 := factory() + i2 := factory() + assert.NotEqual(t, i1, i2) + assert.NotEqual(t, i1.ID, i2.ID) + }), + ) + defer app.RequireStart().RequireStop() + }) + + t.Run("should create new instance when injected into different consumers", func(t *testing.T) { + app := fxtest.New(t, + fx.Provide( + fx.Transient(NewTransientService), + func(f func() *TransientService) *ConsumerA { return &ConsumerA{Service: f()} }, + func(f func() *TransientService) *ConsumerB { return &ConsumerB{Service: f()} }, + ), + fx.Invoke(func(a *ConsumerA, b *ConsumerB) { + assert.NotEqual(t, a.Service, b.Service) + assert.NotEqual(t, a.Service.ID, b.Service.ID) + }), + ) + defer app.RequireStart().RequireStop() + }) +}