diff --git a/mock/mock.go b/mock/mock.go index a13c37f3b..9b798faba 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -208,6 +208,22 @@ func (c *Call) On(methodName string, arguments ...interface{}) *Call { return c.Parent.On(methodName, arguments...) } +// OnF chains a new expectation description onto the mcked interface using +// a function reference instead of a string method name. +// for example: +// +// mock. +// OnF(mocked.MyMethod, 1).Return(nil). +// OnF(mocked.MyOtherMethod, 'a', 'b', 'c').Return(errors.New("Some Error")) +// +// The `method` argument must be a function; otherwise, this call will panic. +// The function name is resolved using reflection and runtime information. +// +//go:noinline +func (c *Call) OnF(method interface{}, args ...interface{}) *Call { + return c.Parent.On(runtimeMethodName(method), args...) +} + // Unset removes all mock handlers that satisfy the call instance arguments from being // called. Only supported on call instances with static input arguments. // @@ -381,6 +397,18 @@ func (m *Mock) On(methodName string, arguments ...interface{}) *Call { return c } +// OnF starts a description of an expectation of the specified method +// being called using a function reference instead of a string method name. +// +// Mock.OnF(mocked.MyMethod, arg1, arg2) +// +// The `method` argument must be a function; otherwise, OnF will panic. +// The function name is determined using reflection and runtime information, +// and then passed to On(methodName, args...). +func (m *Mock) OnF(method interface{}, args ...interface{}) *Call { + return m.On(runtimeMethodName(method), args...) +} + // /* // Recording and responding to activity // */ @@ -1304,6 +1332,20 @@ func funcName(f *runtime.Func) string { return splitted[len(splitted)-1] } +func runtimeMethodName(f interface{}) string { + t := reflect.TypeOf(f) + + if t.Kind() != reflect.Func { + panic("not a function") + } + + fname := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() + + parts := strings.Split(fname, ".") + + return strings.Split(parts[len(parts)-1], "-")[0] +} + func isFuncSame(f1, f2 *runtime.Func) bool { f1File, f1Loc := f1.FileLine(f1.Entry()) f2File, f2Loc := f2.FileLine(f2.Entry()) diff --git a/mock/mock_test.go b/mock/mock_test.go index 3dc9e0b1e..fb60d3d2e 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -558,6 +558,17 @@ func Test_Mock_On_WithFuncTypeArg(t *testing.T) { }) } +func Test_Mock_OnF(t *testing.T) { + t.Parallel() + + // make a test impl object + var mockedService = new(TestExampleImplementation) + + c := mockedService.OnF(mockedService.TheExampleMethod) + assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls) + assert.Equal(t, "TheExampleMethod", c.Method) +} + func Test_Mock_Unset(t *testing.T) { t.Parallel()