Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions torch/_inductor/kernel/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,10 +931,15 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
static_shape, is_nonzero = _is_static_problem(layout)
name = "addmm"
# Create MMKernelInputs for AddMM at the top
# Create MMKernelInputs for AddMM for Triton fused kernels (built from `inp_expanded`)
kernel_inputs = MMKernelInputs(
[inp_expanded, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
)
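# Create MMKernelInputs for AddMM for ATen kernels. These are built from the
# unexpanded `inp`, because inp vs inp_expanded causes some slight numeric
# differences.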
kernel_inputs_aten = MMKernelInputs(
[inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
)

choices: list[ChoiceCaller] = []

# below is for getting an overview logging info of inductor mms
Expand All @@ -960,15 +965,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
aten_layout = FlexibleLayout(
device=layout.device, dtype=layout.dtype, size=layout.size
)
# TODO(coconutruben): combine this with the main flow of addmm through
# a subgraph or something as inp vs inp_expanded causes some slight numeric
# differences
kernel_inputs = MMKernelInputs(
[inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
)
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
kernel_inputs_aten,
aten_layout,
[aten_addmm],
name,
Expand All @@ -979,15 +978,15 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_aten_gemm_kernels():
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
kernel_inputs_aten,
aten_layout,
[aten_bias_addmm],
name,
)
)
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
kernel_inputs_aten,
aten_layout,
[aten_addmm],
name,
Expand Down