Add meta func for scaled mm (pytorch#112609)
# Summary
Adds a meta implementation for `_scaled_mm`, which is required for dynamic shapes.
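
A minimal sketch (editorial, not part of this commit) of what the meta kernel enables: with the registration below in place, `_scaled_mm` can propagate shapes, dtypes, and strides on `device="meta"` tensors, which is what fake-tensor tracing with dynamic shapes needs; no CUDA device or real fp8 data is required.

```python
import torch

# Hypothetical illustration: only metadata is computed on "meta" tensors,
# no CUDA kernel runs. `a` is row-major; `.t()` on a contiguous tensor
# yields the column-major layout that mat2 must have.
a = torch.empty(16, 32, dtype=torch.float8_e4m3fn, device="meta")
b = torch.empty(48, 32, dtype=torch.float8_e4m3fn, device="meta").t()
out, amax = torch._scaled_mm(a, b)
print(out.shape, out.dtype)    # torch.Size([16, 48]) torch.float8_e4m3fn
print(amax.shape, amax.dtype)  # torch.Size([]) torch.float32
```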

Pull Request resolved: pytorch#112609
Approved by: https://github.com/eellison, https://github.com/malfet
drisspg authored and pytorchmergebot committed Nov 2, 2023
1 parent dd95713 commit 75174c3
Showing 4 changed files with 95 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Blas.cpp
@@ -753,7 +753,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat1.scalar_type());
+TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
50 changes: 50 additions & 0 deletions torch/_meta_registrations.py
@@ -5179,6 +5179,56 @@ def meta__scaled_dot_product_efficient_backward(
    return grad_q, grad_k, grad_v, grad_bias


@register_meta([aten._scaled_mm.default])
def meta_scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None,
    use_fast_accum: bool = False,
):
    def is_row_major(stride):
        return stride[0] > stride[1] and stride[1] == 1

    def is_col_major(shape, stride):
        return stride[0] == 1 and stride[1] == shape[0]

    def is_fp8_type(dtype):
        return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

    torch._check(
        self.dim() == 2 and mat2.dim() == 2,
        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
    )
    torch._check(
        is_row_major(self.stride()),
        lambda: "self must be row_major",
    )
    torch._check(
        is_col_major(mat2.shape, mat2.stride()),
        lambda: "mat2 must be col_major",
    )
    torch._check(
        self.size(1) % 16 == 0,
        lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
    )
    torch._check(
        mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
        lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
    )
    torch._check(
        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
    )
    _out_dtype = out_dtype if out_dtype is not None else self.dtype
    return torch.empty(
        self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device
    ), torch.empty((), dtype=torch.float32, device=self.device)


@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
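As a side note, the two stride predicates in `meta_scaled_mm` above correspond to a contiguous tensor and a transposed view, respectively. A small sanity check (editorial sketch, not part of the commit) that runs on plain CPU tensors:

```python
import torch

def is_row_major(stride):
    return stride[0] > stride[1] and stride[1] == 1

def is_col_major(shape, stride):
    return stride[0] == 1 and stride[1] == shape[0]

a = torch.empty(16, 32)      # contiguous: stride (32, 1) -> row-major
b = torch.empty(48, 32).t()  # transposed view: shape (32, 48), stride (1, 32) -> col-major
assert is_row_major(a.stride())
assert is_col_major(b.shape, b.stride())
```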
13 changes: 13 additions & 0 deletions torch/testing/_creation.py
@@ -11,6 +11,7 @@

_INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
@@ -217,6 +218,18 @@ def clamp(a: float, l: float, h: float) -> float:
        _uniform_random_(
            torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
        )
    elif dtype in _FLOATING_8BIT_TYPES:
        low, high = modify_low_high(
            low,
            high,
            lowest_inclusive=torch.finfo(dtype).min,
            highest_exclusive=torch.finfo(dtype).max,
            default_low=-9,
            default_high=9,
        )
        result = torch.empty(shape, device=device, dtype=torch.float32)
        _uniform_random_(result, low, high)
        result = result.to(dtype)
    else:
        raise TypeError(
            f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
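With this branch in place, `torch.testing.make_tensor` should accept the two fp8 dtypes: values are sampled uniformly in float32 (defaulting to the range [-9, 9)) and then cast down. A hedged usage sketch, assuming a build whose CPU supports fp8 casts:

```python
import torch
from torch.testing import make_tensor

# fp8 values are drawn in float32 and cast, per the new branch above.
t = make_tensor((16, 16), dtype=torch.float8_e4m3fn, device="cpu")
print(t.dtype, t.shape)  # torch.float8_e4m3fn torch.Size([16, 16])
```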
32 changes: 31 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -26,7 +26,7 @@
    skipCPUIfNoMklSparse,
    toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
-    SM53OrLater, SM60OrLater, SM80OrLater, with_tf32_off, TEST_CUDNN,
+    SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
    _get_torch_cuda_version, _get_torch_rocm_version,
)
from torch.testing._internal.common_utils import (
@@ -8176,6 +8176,25 @@ def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
        yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
                         error_type=error_type, error_regex=error_regex)

def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
    M, N, K = 15, 32, 16
    samples = []
    # two e4m3
    mat1 = make_mat_e4m3((M, K))
    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
    samples.append(SampleInput(mat1, mat2))
    # mat1 e4m3 mat2 e5m2
    mat1 = make_mat_e4m3((M, K))
    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
    samples.append(SampleInput(mat1, mat2))
    # mat1 e5m2 mat2 e4m3
    mat1 = make_mat_e5m2((M, K))
    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
    samples.append(SampleInput(mat1, mat2))

    yield from samples
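
The `.t().contiguous().t()` idiom above materializes a column-major tensor with the requested logical shape, which is what the `is_col_major` check in the meta function demands of `mat2`. A quick illustration (editorial, not part of the commit):

```python
import torch

m = torch.empty(16, 32).t().contiguous().t()  # column-major layout for shape (16, 32)
print(m.shape, m.stride())  # torch.Size([16, 32]) (1, 16)
```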

def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -13690,6 +13709,17 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
'TestUnaryUfuncs', device_type='cuda',
), ],
),
    OpInfo(
        'torch._scaled_mm',
        sample_inputs_func=sample_inputs_scaled_mm,
        dtypes=empty_types(),
        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
        supports_out=True,
        supports_forward_ad=False,
        supports_autograd=False,
        decorators=[skipCUDAIf(not SM90OrLater, 'Requires CUDA SM >= 9.0')],
        skips=()
    ),
    OpInfo(
        'nn.functional.scaled_dot_product_attention',
        op=lambda *args, **kwargs:
