Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions provide.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,23 @@ func (o provideOption) String() string {
return fmt.Sprintf("fx.Provide(%s)", strings.Join(items, ", "))
}

// Transient marks a constructor so that fx will provide a factory function
// that creates a new instance each time it is called. For example,
//
// fx.Provide(fx.Transient(NewA))
//
// where `NewA` is `func(...) *A` will register a provider for
// `func() *A`. Each call to the provided factory will run `NewA` and
// return a fresh *A (its dependencies are captured by the factory
// closure when the app starts).
func Transient(constructor any) any {
return transientOption{Target: constructor}
}

type transientOption struct {
Target any
}

func runProvide(c container, p provide, opts ...dig.ProvideOption) error {
constructor := p.Target
if _, ok := constructor.(Option); ok {
Expand Down Expand Up @@ -162,6 +179,21 @@ func runProvide(c container, p provide, opts ...dig.ProvideOption) error {
return fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %w", ann, p.Stack, err)
}

case transientOption:
// Build a wrapper function of the form:
// func(deps...) func() (outs...)
// The inner func (factory) when called will execute the original
// constructor using the captured dependency values, producing a
// fresh set of outputs each time.
wrapper, err := buildTransientWrapper(constructor.Target)
if err != nil {
return fmt.Errorf("fx.Transient(%v) from:\n%+vFailed: %w", constructor.Target, p.Stack, err)
}

if err := c.Provide(wrapper, opts...); err != nil {
return fmt.Errorf("fx.Provide(fx.Transient(%v)) from:\n%+vFailed: %w", fxreflect.FuncName(constructor.Target), p.Stack, err)
}

default:
if reflect.TypeOf(constructor).Kind() == reflect.Func {
ft := reflect.ValueOf(constructor).Type()
Expand All @@ -185,3 +217,47 @@ func runProvide(c container, p provide, opts ...dig.ProvideOption) error {
}
return nil
}

// buildTransientWrapper builds a wrapper function that captures the
// dependencies of the original constructor and returns a factory function
// which calls the original constructor each time it is invoked.
func buildTransientWrapper(constructor any) (any, error) {
v := reflect.ValueOf(constructor)
if v.Kind() != reflect.Func {
return nil, fmt.Errorf("transient target must be a function, got %T", constructor)
}
t := v.Type()
if t.IsVariadic() {
return nil, fmt.Errorf("transient constructors may not be variadic: %s", fxreflect.FuncName(constructor))
}

// Collect output types for the inner factory function.
outs := make([]reflect.Type, 0, t.NumOut())
for i := 0; i < t.NumOut(); i++ {
outs = append(outs, t.Out(i))
}

// factoryType is func() (outs...)
factoryType := reflect.FuncOf([]reflect.Type{}, outs, false)

// wrapperType is func(in...) factoryType
ins := make([]reflect.Type, 0, t.NumIn())
for i := 0; i < t.NumIn(); i++ {
ins = append(ins, t.In(i))
}
wrapperType := reflect.FuncOf(ins, []reflect.Type{factoryType}, false)

// Make the wrapper function.
wrapper := reflect.MakeFunc(wrapperType, func(in []reflect.Value) []reflect.Value {
// Create the factory function which captures the provided deps (in)
factory := reflect.MakeFunc(factoryType, func(_ []reflect.Value) []reflect.Value {
// Call the original constructor with captured deps.
results := v.Call(in)
// Return results as-is to caller of factory.
return results
})
return []reflect.Value{factory}
})

return wrapper.Interface(), nil
}
60 changes: 60 additions & 0 deletions transient_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package fx_test

import (
"math/rand"
"testing"

"github.com/stretchr/testify/assert"
"go.uber.org/fx"
"go.uber.org/fx/fxtest"
)

type TransientService struct {
ID int64
}

func NewTransientService() *TransientService {
return &TransientService{
ID: rand.Int63(),
}
}

type ConsumerA struct {
Service *TransientService
}

type ConsumerB struct {
Service *TransientService
}

func TestTransientProvider(t *testing.T) {
t.Run("should create new instance for each resolution", func(t *testing.T) {
app := fxtest.New(t,
fx.Provide(
fx.Transient(NewTransientService),
),
fx.Invoke(func(factory func() *TransientService) {
i1 := factory()
i2 := factory()
assert.NotEqual(t, i1, i2)
assert.NotEqual(t, i1.ID, i2.ID)
}),
)
defer app.RequireStart().RequireStop()
})

t.Run("should create new instance when injected into different consumers", func(t *testing.T) {
app := fxtest.New(t,
fx.Provide(
fx.Transient(NewTransientService),
func(f func() *TransientService) *ConsumerA { return &ConsumerA{Service: f()} },
func(f func() *TransientService) *ConsumerB { return &ConsumerB{Service: f()} },
),
fx.Invoke(func(a *ConsumerA, b *ConsumerB) {
assert.NotEqual(t, a.Service, b.Service)
assert.NotEqual(t, a.Service.ID, b.Service.ID)
}),
)
defer app.RequireStart().RequireStop()
})
}