diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index 90e1f3689b..6ef9ccee1a 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -270,6 +270,7 @@ class PrimIDs(Enum):
     # Linear algebra prims (Mostly experimental)
     MATMUL = auto()
     _GROUPED_MM = auto()  # Used for grouped matmuls
+    SCALED_GROUPED_MM = auto()  # Used for scaled grouped matmuls
     # NN prims (Experimental!)
     CONVOLUTION = auto()
     EMBEDDING = auto()
@@ -3792,6 +3793,128 @@ def _grouped_mm_meta(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) -> Te
     )


+def scaled_grouped_mm_meta(
+    a: TensorProxy,
+    b: TensorProxy,
+    scale_a: TensorProxy,
+    scale_b: TensorProxy,
+    offsets: None | TensorProxy = None,
+    bias: None | TensorProxy = None,
+    scale_result: None | TensorProxy = None,
+    out_dtype: None | dtypes.dtype = None,
+) -> TensorProxy:
+    """Meta function for the scaled_grouped_mm primitive.
+
+    Similar to _grouped_mm, but with scale tensors for quantization/dequantization.
+    With offsets, accepts the following shape combinations:
+    1. (m, k) x (k, n) -> (groups, m, n)
+    2. (groups, m, k) x (k, n) -> (m, n)
+    3. (m, k) x (groups, k, n) -> (m, n)
+    Without offsets, standard (batched) matmul shape rules apply.
+
+    Args:
+        a: Input tensor of shape (groups, m, k) or (m, k)
+        b: Input tensor of shape (groups, k, n) or (k, n)
+        scale_a: Scale tensor for a
+        scale_b: Scale tensor for b
+        offsets: Optional offset tensor of shape (groups,)
+        bias: Optional bias tensor
+        scale_result: Optional scale tensor for the result
+        out_dtype: Optional output dtype
+
+    Returns:
+        TensorProxy with shape (groups, m, n) or (m, n)
+    """
+    # Validate types
+    utils.check_type(a, TensorProxy)
+    utils.check_type(b, TensorProxy)
+    utils.check_type(scale_a, TensorProxy)
+    utils.check_type(scale_b, TensorProxy)
+
+    # Accept 2D or 3D tensors
+    utils.check(a.ndim in (2, 3), lambda: f"Expected a to have 2 or 3 dimensions, got {a.ndim}")
+    utils.check(b.ndim in (2, 3), lambda: f"Expected b to have 2 or 3 dimensions, got {b.ndim}")
+
+    # Compute the output shape using the same logic as _grouped_mm
+    if offsets is not None:
+        utils.check_type(offsets, TensorProxy)
+        utils.check(offsets.ndim == 1, lambda: f"`offsets` must be a vector, got shape {offsets.shape}")
+
+        if a.ndim == 2 and b.ndim == 2:
+            utils.check(a.shape[1] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (offsets.shape[0], a.shape[0], b.shape[1])
+        elif a.ndim == 3 and b.ndim == 2:
+            utils.check(a.shape[2] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            utils.check(a.shape[0] == offsets.shape[0], lambda: f"Group count mismatch: {a.shape} vs {offsets.shape}")
+            out_shape = (a.shape[1], b.shape[1])
+        elif a.ndim == 2 and b.ndim == 3:
+            utils.check(a.shape[1] == b.shape[1], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            utils.check(b.shape[0] == offsets.shape[0], lambda: f"Group count mismatch: {b.shape} vs {offsets.shape}")
+            out_shape = (a.shape[0], b.shape[2])
+        else:
+            utils.check(False, lambda: f"Unexpected shape combination: {a.shape} and {b.shape}")
+    else:
+        # Without offsets, fall back to standard matmul shape logic
+        if a.ndim == 2 and b.ndim == 2:
+            utils.check(a.shape[1] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (a.shape[0], b.shape[1])
+        elif a.ndim == 3 and b.ndim == 2:
+            utils.check(a.shape[2] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (a.shape[0], a.shape[1], b.shape[1])
+        elif a.ndim == 2 and b.ndim == 3:
+            utils.check(a.shape[1] == b.shape[1], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (b.shape[0], a.shape[0], b.shape[2])
+        else:
+            utils.check(False, lambda: f"Unexpected shape combination: {a.shape} and {b.shape}")
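+
+    # Worked example (illustrative): with offsets of shape (4,), a of shape
+    # (128, 64), and b of shape (4, 64, 32), the 2D x 3D branch above applies;
+    # the four groups partition a's 128 rows and out_shape == (128, 32).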
+
+    # Validate scale tensors
+    # Scale tensors are typically 1D with shape matching the number of groups,
+    # or they can be scalars
+    utils.check(
+        scale_a.ndim <= 1,
+        lambda: f"Expected scale_a to be a scalar or 1D tensor, got shape {scale_a.shape}",
+    )
+    utils.check(
+        scale_b.ndim <= 1,
+        lambda: f"Expected scale_b to be a scalar or 1D tensor, got shape {scale_b.shape}",
+    )
+
+    # Validate bias if provided
+    if bias is not None:
+        utils.check_type(bias, TensorProxy)
+        utils.check_same_device(a, bias)
+        utils.check_same_dtype(a, bias)
+
+    # Validate scale_result if provided
+    if scale_result is not None:
+        utils.check_type(scale_result, TensorProxy)
+        utils.check(
+            scale_result.ndim <= 1,
+            lambda: f"Expected scale_result to be a scalar or 1D tensor, got shape {scale_result.shape}",
+        )
+
+    utils.check_same_dtype(a, b)
+    utils.check(a.dtype in dtypes.float_math_dtypes, lambda: f"`a` must be a float math dtype, got {a.dtype}")
+    if offsets is not None:
+        utils.check(utils.is_integer_dtype(offsets.dtype), lambda: f"`offsets` must be integers, got {offsets.dtype}")
+
+    utils.check_same_device(a, b, scale_a, scale_b)
+    if offsets is not None:
+        utils.check_same_device(a, offsets)
+
+    # Determine output dtype
+    result_dtype = out_dtype if out_dtype is not None else a.dtype
+
+    return TensorProxy(like=a, shape=out_shape, dtype=result_dtype)
+
+
+scaled_grouped_mm = make_prim(
+    PrimIDs.SCALED_GROUPED_MM,
+    "scaled_grouped_mm",
+    meta=scaled_grouped_mm_meta,
+    tags=(OpTags.MATMUL_OP,),
+)
+
+
 def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy:
     utils.check_type(a, TensorProxy)
     utils.check_type(permutation, tuple)
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 84a8ea9f93..da006a7641 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1552,6 +1552,119 @@ def _copy_with_setitem_impl(a, key, value):
     "torch.ops.aten._adaptive_avg_pool2d_backward", like=ltorch.adaptive_avg_pool2d_backward
 )
 multi_dot = _register_torch_operation("torch.linalg.multi_dot", like=ltorch.multi_dot)
+if hasattr(torch.nn.functional, "scaled_grouped_mm"):
+    # PyTorch 2.10+ introduced scaled_grouped_mm with ScalingType/SwizzleType enums
+    if hasattr(torch.nn.functional, "ScalingType"):
+        # PyTorch 2.10+: scaled_grouped_mm is a new API with specific requirements
+        def _scaled_grouped_mm_impl(
+            a: torch.Tensor,
+            b: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            offsets: None | torch.Tensor = None,
+            bias: None | torch.Tensor = None,
+            scale_result: None | torch.Tensor = None,
+            out_dtype: None | torch.dtype = None,
+        ) -> torch.Tensor:
+            """Wrapper for the PyTorch 2.10+ scaled_grouped_mm API.
+
+            PyTorch 2.10 introduced scaled_grouped_mm with requirements for:
+            - mat_b to have a transposed memory layout (create as [G, N, K], then .transpose(-2, -1))
+            - specific scale formats based on the scale tensor shapes (RowWise vs TensorWise is inferred)
+            - ScalingType and SwizzleType enums for quantization control
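+
+            Illustrative call (the shapes and values here are assumptions for
+            this sketch, not the only supported combination)::
+
+                a = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
+                b = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+                one = torch.ones((), device="cuda", dtype=torch.bfloat16)
+                out = _scaled_grouped_mm_impl(a, b, one, one)  # TensorWise path, shape [16, 64]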
+            """
+            num_groups = offsets.numel() if offsets is not None else 1
+
+            # Transpose b to match PyTorch's expected memory layout:
+            # mat_b must be logically [G, K, N], but laid out as if it were
+            # created as [G, N, K] and then viewed through .transpose(-2, -1)
+            b_transposed = b.transpose(-2, -1) if b.ndim >= 2 else b
+
+            # Infer the scaling type and format the scales appropriately.
+            # For the 2D x 3D case: a is [M, K], b is [G, K, N] (or [G, N, K] transposed).
+            # PyTorch expects:
+            # - If the scale is a scalar (0D) or has 1 element: TensorWise
+            # - If the scale is 1D: RowWise (needs reshaping) or per-group TensorWise
+            # - If the scale is 2D: RowWise
+
+            # Handle scale_a
+            if scale_a.numel() == 1:
+                # Scalar scale - TensorWise
+                scale_a_list = [scale_a.view(1)] * num_groups
+                scale_recipe_a = [torch.nn.functional.ScalingType.TensorWise] * num_groups
+            elif scale_a.dim() == 1:
+                # 1D scale - per-group TensorWise or RowWise, depending on length
+                # For RowWise in the 2D x 3D case: scale_a should be (num_groups * M,), reshaped to (num_groups, M, 1)
+                if a.dim() == 2:
+                    # a is [M, K]; scale_a should be expandable to the rowwise format
+                    scale_a_2d = scale_a.view(num_groups, -1, 1)  # [G, M, 1]
+                    scale_a_list = [scale_a_2d[i] for i in range(num_groups)]
+                    scale_recipe_a = [torch.nn.functional.ScalingType.RowWise] * num_groups
+                else:
+                    # Fallback to splitting
+                    if scale_a.size(0) == num_groups:
+                        scale_a_list = [scale_a[i : i + 1] for i in range(num_groups)]
+                        scale_recipe_a = [torch.nn.functional.ScalingType.TensorWise] * num_groups
+                    else:
+                        # Try RowWise
+                        scale_a_2d = scale_a.view(num_groups, -1, 1)
+                        scale_a_list = [scale_a_2d[i] for i in range(num_groups)]
+                        scale_recipe_a = [torch.nn.functional.ScalingType.RowWise] * num_groups
+            else:
+                # 2D scale - RowWise
+                scale_a_list = [
+                    scale_a[i].unsqueeze(-1) if scale_a[i].dim() == 1 else scale_a[i] for i in range(scale_a.size(0))
+                ]
+                scale_recipe_a = [torch.nn.functional.ScalingType.RowWise] * num_groups
+
+            # Handle scale_b
+            if scale_b.numel() == 1:
+                # Scalar scale - TensorWise
+                scale_b_list = [scale_b.view(1)] * num_groups
+                scale_recipe_b = [torch.nn.functional.ScalingType.TensorWise] * num_groups
+            elif scale_b.dim() == 1:
+                # 1D scale
+                if scale_b.size(0) == num_groups:
+                    scale_b_list = [scale_b[i : i + 1] for i in range(num_groups)]
+                    scale_recipe_b = [torch.nn.functional.ScalingType.TensorWise] * num_groups
+                else:
+                    # RowWise: reshape to (num_groups, N, 1) for the transposed b
+                    scale_b_2d = scale_b.view(num_groups, -1, 1)
+                    scale_b_list = [scale_b_2d[i] for i in range(num_groups)]
+                    scale_recipe_b = [torch.nn.functional.ScalingType.RowWise] * num_groups
+            else:
+                # 2D scale - RowWise
+                # b is [G, K, N]; after the transpose it becomes [G, N, K]
+                # For RowWise on the transposed b, the scale should be [G, N, 1]
+                scale_b_list = [
+                    scale_b[i].unsqueeze(-1) if scale_b[i].dim() == 1 else scale_b[i] for i in range(scale_b.size(0))
+                ]
+                scale_recipe_b = [torch.nn.functional.ScalingType.RowWise] * num_groups
+
+            # Create swizzle parameters (no swizzle)
+            swizzle_a = [torch.nn.functional.SwizzleType.NO_SWIZZLE] * num_groups
+            swizzle_b = [torch.nn.functional.SwizzleType.NO_SWIZZLE] * num_groups
+
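+            # For example (illustrative): with num_groups == 2 and a scalar
+            # scale_a, scale_a_list == [scale_a.view(1)] * 2 and
+            # scale_recipe_a == [ScalingType.TensorWise] * 2; the enum-based
+            # API takes one scale, recipe, and swizzle entry per group.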
+            return torch.nn.functional.scaled_grouped_mm(
+                a,
+                b_transposed,
+                scale_a_list,
+                scale_recipe_a,
+                scale_b_list,
+                scale_recipe_b,
+                swizzle_a=swizzle_a,
+                swizzle_b=swizzle_b,
+                bias=bias,
+                offs=offsets,
+                output_dtype=out_dtype,
+            )
+
+        scaled_grouped_mm = ex.register_operator(
+            "scaled_grouped_mm", like=ltorch.scaled_grouped_mm, fn=_scaled_grouped_mm_impl
+        )
+    else:
+        # scaled_grouped_mm exists but without the enum-based ScalingType API;
+        # register the torch function directly
+        scaled_grouped_mm = _register_torch_operation("scaled_grouped_mm", module=torch.nn.functional)


 def _max_pool_with_indices_helper(
@@ -1823,6 +1936,32 @@ def _grouped_mm_checker(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) ->
     return a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16 and offsets.dtype == dtypes.int32


+def _scaled_grouped_mm_checker(
+    a: TensorProxy,
+    b: TensorProxy,
+    scale_a: TensorProxy,
+    scale_b: TensorProxy,
+    offsets: None | TensorProxy = None,
+    bias: None | TensorProxy = None,
+    scale_result: None | TensorProxy = None,
+    out_dtype: None | dtypes.dtype = None,
+) -> bool:
+    if not hasattr(torch.nn.functional, "scaled_grouped_mm"):
+        return False
+
+    if not torch.cuda.is_available():
+        return False
+
+    # scaled_grouped_mm requires Hopper (compute capability 9.0) or newer
+    capability = torch.cuda.get_device_capability()
+    if capability < (9, 0):
+        return False
+
+    # nvfp4 (float4_e2m1fn_x2) inputs are not supported by this path yet
+    if dtypes.float4_e2m1fn_x2 in (a.dtype, b.dtype):
+        return False
+
+    return True
+
+
 _register_implementation(ltorch.baddbmm, baddbmm, checker=_always_executable)
 _register_implementation(ltorch.bmm, bmm, checker=_always_executable)
 if LooseVersion(torch.__version__) >= "2.8":
@@ -1846,6 +1985,9 @@ def _grouped_mm_checker(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) ->
     ltorch.log_softmax_backward, checker=_always_executable, execution_transform=_log_softmax_backward_transform
 )
 _register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
+if hasattr(torch.nn.functional, "scaled_grouped_mm"):
+    _register_implementation(prims.scaled_grouped_mm, scaled_grouped_mm, checker=_scaled_grouped_mm_checker)
+    _register_implementation(ltorch.scaled_grouped_mm, scaled_grouped_mm, checker=_scaled_grouped_mm_checker)


 def max_pool2d_bwd_wrapper(
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 33df903a10..67b3a519a7 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -7271,6 +7271,273 @@ def _grouped_mm_reference(a, b, offsets):
 linear_algebra_ops.append(_grouped_mm_opinfo)


+def _quantize_to_fp8_e4m3fn(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize a bfloat16 tensor to the fp8_e4m3fn format.
+
+    Reference: https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py
+    This function follows the quantization pattern used in PyTorch's scaled matmul tests.
+
+    Returns:
+        (quantized_tensor, scale): quantized tensor and scale factor
+    """
+    amax = tensor.abs().max().to(torch.float32)
+    fp8_max = torch.finfo(torch.float8_e4m3fn).max
+    scale = fp8_max / torch.clamp(amax, min=torch.finfo(torch.float32).tiny)
+    quantized = (tensor * scale).clamp(min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn)
+    return quantized, scale.to(torch.bfloat16)
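+
+
+# Round-trip sketch (illustrative): for a bfloat16 tensor x,
+#   q, s = _quantize_to_fp8_e4m3fn(x)
+#   x_approx = q.to(torch.float32) / s.to(torch.float32)
+# recovers x up to fp8 rounding, since q == clamp(x * s). The same round trip
+# applies to _quantize_to_fp8_e5m2 below.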
+
+
+def _quantize_to_fp8_e5m2(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize a bfloat16 tensor to the fp8_e5m2 format.
+
+    Reference: https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py
+    This function follows the quantization pattern used in PyTorch's scaled matmul tests.
+
+    Returns:
+        (quantized_tensor, scale): quantized tensor and scale factor
+    """
+    amax = tensor.abs().max().to(torch.float32)
+    fp8_max = torch.finfo(torch.float8_e5m2).max
+    scale = fp8_max / torch.clamp(amax, min=torch.finfo(torch.float32).tiny)
+    quantized = (tensor * scale).clamp(min=-fp8_max, max=fp8_max).to(torch.float8_e5m2)
+    return quantized, scale.to(torch.bfloat16)
+
+
+def _quantize_to_nvfp4(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Quantize a bfloat16 tensor to the nvfp4 (float4_e2m1fn_x2) format.
+
+    Reference: https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py
+    Uses block-wise quantization similar to pytorch_nvfp4_quantize.
+    Requires the last dimension to be divisible by 16.
+
+    Returns:
+        (quantized_tensor, block_scales, global_scale): quantized tensor, block scales (fp8_e4m3fn), and global scale
+    """
+    BLOCK_SIZE = 16
+    FLOAT4_E2M1_MAX = 6.0
+    FLOAT8_E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
+    FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+    assert tensor.size(-1) % BLOCK_SIZE == 0, f"Last dimension must be divisible by {BLOCK_SIZE}"
+    assert tensor.is_contiguous(), "Tensor must be contiguous"
+
+    original_shape = tensor.shape
+    tensor_fp32 = tensor.float().reshape(original_shape[0], -1, BLOCK_SIZE)
+
+    # Find the absolute maximum along the blockwise dimension
+    max_abs = torch.amax(torch.abs(tensor_fp32), dim=-1)
+    block_scale_fp32 = (max_abs / FLOAT4_E2M1_MAX).float()
+
+    # Compute the global scale
+    global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / tensor_fp32.abs().amax()
+
+    scaled_block_scale_fp32 = block_scale_fp32 * global_scale
+    scaled_block_scale_fp8 = torch.clamp(
+        scaled_block_scale_fp32,
+        min=FLOAT8_E4M3_EPS,
+        max=FLOAT8_E4M3_MAX,
+    ).to(torch.float8_e4m3fn)
+    scaled_block_scale_fp8_fp32 = scaled_block_scale_fp8.to(torch.float)
+    total_scale = scaled_block_scale_fp8_fp32 / global_scale
+    tensor_scaled = tensor_fp32 / total_scale.unsqueeze(-1)
+    tensor_scaled = torch.clamp(tensor_scaled, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX)
+    tensor_scaled = tensor_scaled.view(original_shape)
+
+    # Convert to the fp4 format
+    from torch.testing._internal.common_quantized import _f32_to_floatx_unpacked
+
+    # Convert to uint4 and pack
+    tensor_uint4 = _f32_to_floatx_unpacked(tensor_scaled.flatten().float(), ebits=2, mbits=1)
+    # Pack the uint4 values (2 uint4 values per uint8)
+    assert tensor_uint4.shape[-1] % 2 == 0
+    tensor_uint4 = tensor_uint4.contiguous().view(-1)
+    packed = (tensor_uint4[::2] << 4) | tensor_uint4[1::2]
+    # Reshape the packed tensor - the last dimension is halved
+    if len(original_shape) == 2:
+        packed = packed.view(original_shape[0], original_shape[1] // 2)
+    else:
+        # For 3D, flatten the first two dims, then reshape
+        packed = packed.view(-1, original_shape[-1] // 2)
+    quantized = packed.view(torch.float4_e2m1fn_x2)
+
+    # Return the block scales as-is (shape: [num_blocks])
+    block_scales = scaled_block_scale_fp8
+
+    return quantized, block_scales, global_scale.to(torch.bfloat16)
+
+
+def scaled_grouped_mm_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    """Sample generator for scaled_grouped_mm, based on PyTorch's 2D x 2D test patterns.
+
+    Reference: https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py
+
+    Starts with simple 2D x 2D matmul cases (non-grouped) to match PyTorch's basic patterns.
+    PyTorch 2.10's scaled_grouped_mm is primarily designed for FP8 quantization.
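+
+    Illustrative shape of the sample yielded below (bfloat16 only for now)::
+
+        a = make_tensor((16, 32), dtype=torch.bfloat16, device=device)
+        b = make_tensor((64, 32), dtype=torch.bfloat16, device=device).transpose(-2, -1)
+        SampleInput(a, b, scale_a, scale_b)  # scalar bfloat16 scales, no offsets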
+ """ + # Always create bfloat16 tensors first, then quantize if needed + make_bf16 = partial(make_tensor, device=device, dtype=torch.bfloat16, requires_grad=False) + + # Set appropriate value ranges for bfloat16 tensors + input_low, input_high = -1.0, 1.0 + scale_low, scale_high = 0.1, 1.0 + comp = TorchTensorComp(atol=1e-1, rtol=1e-1) + + # Test case: Simple 2D x 2D matmul [M, K] @ [N, K].T -> [M, N] + # PyTorch expects b to be transposed: create as [N, K] then transpose to [K, N] + M, K, N = 16, 32, 64 + + # Create bfloat16 tensors + a_bf16 = make_bf16((M, K), low=input_low, high=input_high) + # Create b as [N, K] then transpose to [K, N] with transposed strides + b_bf16 = make_bf16((N, K), low=input_low, high=input_high).transpose(-2, -1) + + # For bfloat16: use tensors directly with scalar scales (no quantization) + if dtype == datatypes.bfloat16: + a = a_bf16 + b = b_bf16 + # Use scalar scales for simplicity (TensorWise scaling) + scale_a = make_bf16((), low=scale_low, high=scale_high) + scale_b = make_bf16((), low=scale_low, high=scale_high) + + # Simple 2D x 2D case without offsets + si = SampleInput(a, b, scale_a, scale_b) + si.set_comparator(comp) + yield si + + # TODO: Add support for FP8 and other quantized dtypes + # These require proper quantization and more complex scale handling + + +def scaled_grouped_mm_reference(a, b, scale_a, scale_b, offsets=None, bias=None, scale_result=None, out_dtype=None): + """Reference implementation for scaled_grouped_mm. + + Reference: https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py + Uses manual implementation as reference since PyTorch 2.10's new scaled_grouped_mm API + has very specific requirements for tensor layouts and scale formats. 
+ """ + # Use manual implementation as reference + # PyTorch 2.10 introduced scaled_grouped_mm but it has strict layout requirements + + # Fallback implementation: apply scales and use grouped matmul + # This is a simplified reference - actual implementation may differ + if offsets is not None: + num_groups = offsets.numel() + group_sizes = _group_sizes_from_offsets(offsets) + + if a.dim() == 2 and b.dim() == 3: + # [m, k] @ [g, k, n] => [m, n] + group_as = a.split(group_sizes, 0) + group_bs = b.unbind() + group_scale_as = scale_a.unbind() if scale_a.dim() > 0 else [scale_a] * num_groups + group_scale_bs = scale_b.unbind() if scale_b.dim() > 0 else [scale_b] * num_groups + + out = torch.empty(a.size(0), b.size(-1), dtype=a.dtype, device=a.device) + group_outs = out.split(group_sizes, 0) + + group_scale_results = None + if scale_result is not None: + group_scale_results = scale_result.unbind() if scale_result.dim() > 0 else [scale_result] * num_groups + + for i, (group_a, group_b, group_out) in enumerate(zip(group_as, group_bs, group_outs)): + sa = group_scale_as[i] + sb = group_scale_bs[i] + + # Apply scales + scaled_a = group_a * sa.item() if sa.numel() == 1 else group_a * sa.view(-1, 1) + scaled_b = group_b * sb.item() if sb.numel() == 1 else group_b * sb.view(-1, 1) + + result = torch.matmul(scaled_a, scaled_b) + + # Apply scale_result if provided + if group_scale_results is not None: + sr = group_scale_results[i] + result = result * sr.item() if sr.numel() == 1 else result * sr.view(-1, 1) + + # Apply bias if provided + if bias is not None: + result = result + bias + + group_out.copy_(result) + + if out_dtype is not None: + return out.to(out_dtype) + return out + else: + # Simplified fallback for other shape combinations + scaled_a = a * scale_a.item() if scale_a.numel() == 1 else a * scale_a.view(-1, 1, 1) + scaled_b = b * scale_b.item() if scale_b.numel() == 1 else b * scale_b.view(-1, 1, 1) + result = _grouped_mm_reference(scaled_a, scaled_b, offsets) + + if scale_result is not None: + result = ( + result * scale_result.item() if scale_result.numel() == 1 else result * scale_result.view(-1, 1, 1) + ) + + if bias is not None: + result = result + bias + + if out_dtype is not None: + return result.to(out_dtype) + return result + else: + # Without offsets, simpler case + scaled_a = a * scale_a.item() if scale_a.numel() == 1 else a * scale_a.view(-1, 1) + scaled_b = b * scale_b.item() if scale_b.numel() == 1 else b * scale_b.view(-1, 1, 1) + result = torch.matmul(scaled_a, scaled_b) + + if scale_result is not None: + result = result * scale_result.item() if scale_result.numel() == 1 else result * scale_result.view(-1, 1) + + if bias is not None: + result = result + bias + + if out_dtype is not None: + return result.to(out_dtype) + return result + + +if hasattr(ltorch, "scaled_grouped_mm"): + # scaled_grouped_mm supports quantization dtypes: bfloat16, fp8, and nvfp4 + # fp8 dtypes: float8_e4m3fn and float8_e5m2 are commonly used for quantization + # nvfp4: float4_e2m1fn_x2 is used for 4-bit quantization + # + # Reference: https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py + # The quantization functions and test patterns follow PyTorch's scaled_matmul_cuda test file, + # which quantizes bfloat16 tensors to fp8/nvfp4 formats for testing. 
+    scaled_grouped_mm_dtypes = (
+        datatypes.bfloat16,
+        datatypes.float8_e4m3fn,
+        datatypes.float8_e5m2,
+        datatypes.float4_e2m1fn_x2,
+    )
+
+    scaled_grouped_mm_opinfo = OpInfo(
+        ltorch.scaled_grouped_mm,
+        supports_grad=False,
+        sample_input_generator=scaled_grouped_mm_sample_generator,
+        torch_reference=scaled_grouped_mm_reference,
+        dtypes=scaled_grouped_mm_dtypes,
+        devicetypes=(devices.DeviceType.CUDA,),
+        test_directives=(
+            DecorateInfo(
+                pytest.mark.skip,
+                "test_core_vs_torch_consistency",
+                executors=("torch",),
+                # torch.nn.functional.scaled_grouped_mm may not be supported pre-Hopper.
+                active_if=(not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)),
+            ),
+            # fp8 and nvfp4 require Hopper (compute capability 9.0+)
+            DecorateInfo(
+                pytest.mark.skip,
+                "test_core_vs_torch_consistency",
+                dtypes=(datatypes.float8_e4m3fn, datatypes.float8_e5m2, datatypes.float4_e2m1fn_x2),
+                active_if=(not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)),
+            ),
+        ),
+    )
+    linear_algebra_ops.append(scaled_grouped_mm_opinfo)
+
+
 def einsum_sample_generator(op, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 81e0d199bd..88e929bcd5 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5724,6 +5724,32 @@ def _grouped_mm(
     return prims._grouped_mm(a, b, offsets)


+if hasattr(torch.nn.functional, "scaled_grouped_mm"):
+
+    @torchsymbol(torch.nn.functional.scaled_grouped_mm, is_method=False, is_prim=True)
+    def scaled_grouped_mm(
+        a: TensorProxy,
+        b: TensorProxy,
+        scale_a: TensorProxy,
+        scale_b: TensorProxy,
+        offsets: None | TensorProxy = None,
+        bias: None | TensorProxy = None,
+        scale_result: None | TensorProxy = None,
+        out_dtype: None | dtypeLike = None,
+    ) -> TensorProxy:
+        utils.check(bias is None, lambda: "`bias` is not supported yet.")
+        utils.check(scale_result is None, lambda: "`scale_result` is not supported yet.")
+        utils.check(
+            a.ndim in (2, 3) and b.ndim in (2, 3),
+            lambda: f"Expected a and b to be 2D or 3D, but got {a.ndim=} and {b.ndim=}",
+        )
+        # Uses the primitive implementation
+        out_dtype_thunder = to_dtype(out_dtype) if out_dtype is not None else None
+        return prims.scaled_grouped_mm(
+            a, b, scale_a, scale_b, offsets=offsets, bias=bias, scale_result=scale_result, out_dtype=out_dtype_thunder
+        )
+
+
 @torchsymbol(torch.logsumexp, is_method=True)
 def logsumexp(a: TensorLike, /, dim: int | Sequence[int], keepdim: bool = False) -> TensorLike:
     input_max = amax(a, dim, keepdim=True)