diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index a597107510e78..7362a44905d4f 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -931,10 +931,15 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
     static_shape, is_nonzero = _is_static_problem(layout)
     name = "addmm"
-    # Create MMKernelInputs for AddMM at the top
+    # Create MMKernelInputs for AddMM for Triton fused kernels
     kernel_inputs = MMKernelInputs(
         [inp_expanded, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
     )
+    # Create MMKernelInputs for AddMM for ATen kernels
+    kernel_inputs_aten = MMKernelInputs(
+        [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
+    )
+
     choices: list[ChoiceCaller] = []
 
     # below is for getting an overview logging info of inductor mms
@@ -960,15 +965,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         aten_layout = FlexibleLayout(
             device=layout.device, dtype=layout.dtype, size=layout.size
         )
-        # TODO(coconutruben): combine this with the main flow of addmm through
-        # a subgraph or something as inp vs inp_expanded causes some slight numeric
-        # differences
-        kernel_inputs = MMKernelInputs(
-            [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
-        )
         choices.extend(
             V.choices.get_mm_configs(
-                kernel_inputs,
+                kernel_inputs_aten,
                 aten_layout,
                 [aten_addmm],
                 name,
@@ -979,7 +978,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     if use_aten_gemm_kernels():
         choices.extend(
             V.choices.get_mm_configs(
-                kernel_inputs,
+                kernel_inputs_aten,
                 aten_layout,
                 [aten_bias_addmm],
                 name,
@@ -987,7 +986,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
             )
         choices.extend(
             V.choices.get_mm_configs(
-                kernel_inputs,
+                kernel_inputs_aten,
                 aten_layout,
                 [aten_addmm],
                 name,