diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 7024e898938fa4..38cce45ab6e77a 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -753,7 +753,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
   TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar");
   TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat1.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 75d5e0607109a0..68a435392ef979 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5179,6 +5179,56 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias
 
 
+@register_meta([aten._scaled_mm.default])
+def meta_scaled_mm(
+    self: torch.Tensor,
+    mat2: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    scale_a: Optional[torch.Tensor] = None,
+    scale_b: Optional[torch.Tensor] = None,
+    scale_result: Optional[torch.Tensor] = None,
+    use_fast_accum: bool = False,
+):
+    def is_row_major(stride):
+        return stride[0] > stride[1] and stride[1] == 1
+
+    def is_col_major(shape, stride):
+        return stride[0] == 1 and stride[1] == shape[0]
+
+    def is_fp8_type(dtype):
+        return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+
+    torch._check(
+        self.dim() == 2 and mat2.dim() == 2,
+        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
+    )
+    torch._check(
+        is_row_major(self.stride()),
+        lambda: "self must be row_major",
+    )
+    torch._check(
+        is_col_major(mat2.shape, mat2.stride()),
+        lambda: "mat2 must be col_major",
+    )
+    torch._check(
+        self.size(1) % 16 == 0,
+        lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
+    )
+    torch._check(
+        mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
+        lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
+    )
+    torch._check(
+        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
+        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
+    )
+    _out_dtype = out_dtype if out_dtype is not None else self.dtype
+    return torch.empty(
+        self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device
+    ), torch.empty((), dtype=torch.float32, device=self.device)
+
+
 @register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
 @out_wrapper()
 def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py
index 6ded7ff5783860..a46d8cf590e407 100644
--- a/torch/testing/_creation.py
+++ b/torch/testing/_creation.py
@@ -11,6 +11,7 @@
 _INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
 _FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
+_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
 _COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
 _BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
 _FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
@@ -217,6 +218,18 @@ def clamp(a: float, l: float, h: float) -> float:
         _uniform_random_(
             torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
         )
+    elif dtype in _FLOATING_8BIT_TYPES:
+        low, high = modify_low_high(
+            low,
+            high,
+            lowest_inclusive=torch.finfo(dtype).min,
+            highest_exclusive=torch.finfo(dtype).max,
+            default_low=-9,
+            default_high=9,
+        )
+        result = torch.empty(shape, device=device, dtype=torch.float32)
+        _uniform_random_(result, low, high)
+        result = result.to(dtype)
     else:
         raise TypeError(
             f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1817b61b9e9bb3..3a66f5773fc105 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -26,7 +26,7 @@
     skipCPUIfNoMklSparse, toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
-    SM53OrLater, SM60OrLater, SM80OrLater, with_tf32_off, TEST_CUDNN,
+    SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
     _get_torch_cuda_version, _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -8176,6 +8176,25 @@ def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
         yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs), error_type=error_type, error_regex=error_regex)
 
+def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
+    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
+    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
+    M, N, K = 15, 32, 16
+    samples = []
+    # two e4m3
+    mat1 = make_mat_e4m3((M, K))
+    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+    samples.append(SampleInput(mat1, mat2))
+    # mat1 e4m3 mat2 e5m2
+    mat1 = make_mat_e4m3((M, K))
+    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
+    samples.append(SampleInput(mat1, mat2))
+    # mat1 e5m2 mat2 e4m3
+    mat1 = make_mat_e5m2((M, K))
+    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+    samples.append(SampleInput(mat1, mat2))
+
+    yield from samples
 
 def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -13690,6 +13709,17 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                      'TestUnaryUfuncs', device_type='cuda',
                  ), ],
     ),
+    OpInfo(
+        'torch._scaled_mm',
+        sample_inputs_func=sample_inputs_scaled_mm,
+        dtypes=empty_types(),
+        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
+        supports_out=True,
+        supports_forward_ad=False,
+        supports_autograd=False,
+        decorators=[skipCUDAIf(not SM90OrLater, 'Requires CUDA SM >= 9.0')],
+        skips=()
+    ),
     OpInfo(
         'nn.functional.scaled_dot_product_attention',
         op=lambda *args, **kwargs:
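Reviewer note: below is a minimal usage sketch (not part of the diff) of what the new meta registration enables, namely shape and dtype propagation for `torch._scaled_mm` on the `meta` device without a CUDA build. The shapes and layouts are chosen to satisfy the checks in `meta_scaled_mm` (row-major `self`, column-major `mat2`, fp8 dtypes, dimensions divisible by 16); the variable names `mat1`/`mat2` are illustrative only.

```python
import torch

# Sketch assuming the meta registration added in this diff is in place.
M, N, K = 15, 32, 16  # K and both mat2 dimensions are multiples of 16, as checked above

mat1 = torch.empty(M, K, device="meta", dtype=torch.float8_e4m3fn)      # row-major
mat2 = torch.empty(N, K, device="meta", dtype=torch.float8_e4m3fn).t()  # column-major view of shape (K, N)

out, amax = torch._scaled_mm(mat1, mat2, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)    # torch.Size([15, 32]) torch.bfloat16
print(amax.shape, amax.dtype)  # torch.Size([]) torch.float32
```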
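Similarly, a small sketch of the `make_tensor` behavior added in `torch/testing/_creation.py`: values are drawn uniformly in float32 (defaulting to the [-9, 9) range above) and then cast to the requested 8-bit float dtype. The shape and bounds here are arbitrary.

```python
import torch
from torch.testing import make_tensor

# Sketch assuming the _FLOATING_8BIT_TYPES branch added in this diff is in place.
t = make_tensor((16, 16), device="cpu", dtype=torch.float8_e4m3fn, low=-2.0, high=2.0)
print(t.dtype)                       # torch.float8_e4m3fn
print(t.float().abs().max() <= 2.0)  # values were sampled in [-2, 2) before the cast to fp8
```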