diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu
index 04c0db5457..1971c61477 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1288,8 +1288,9 @@ class DenseToJaggedGPUOp
                   }); // device lambda
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
             at::ScalarType::Long,
+            at::ScalarType::BFloat16,
             [&] {
               jagged_dense_elementwise_jagged_output_<scalar_t>(
                   values,
@@ -1364,17 +1365,19 @@ class JaggedDenseDenseAddJaggedOutputGPUOp
                       -> scalar_t { return x + y_0 + y_1; });
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              offsets,
-              dense_0,
-              dense_1,
-              output,
-              [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-                  -> scalar_t { return x + y_0 + y_1; });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  offsets,
+                  dense_0,
+                  dense_1,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
+                      -> scalar_t { return x + y_0 + y_1; });
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
 
     return {output};
@@ -1447,17 +1450,19 @@ class JaggedDenseAddJaggedOutputGPUOp
                   }); // device lambda
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              offsets,
-              dense,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x + y;
-              }); // device lambda
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  offsets,
+                  dense,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x + y;
+                  }); // device lambda
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
 
     return {output};
@@ -1660,17 +1665,19 @@ class JaggedDenseMulGPUOp
                   });
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              x_offsets,
-              y,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x * y;
-              });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  x_offsets,
+                  y,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x * y;
+                  });
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
 
     return {output};
@@ -1693,8 +1700,12 @@ class JaggedDenseMulGPUOp
     Tensor x_values_grad = at::empty_like(grad_outputs[0]);
     Tensor y_grad = at::empty_like(y);
 
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        x_values.scalar_type(), "jagged_scalars", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        x_values.scalar_type(),
+        "jagged_scalars",
+        [&] {
           jagged_dense_elementwise_jagged_output_<scalar_t>(
               grad_outputs[0],
              x_offsets,
@@ -2115,8 +2126,12 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda(
     Tensor grad_values_slice =
         grad_values.slice(0, offset_per_key[t], offset_per_key[t + 1]);
 
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        grad_values.scalar_type(), "jagged_2d_to_dense_backward_kernel", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        grad_values.scalar_type(),
+        "jagged_2d_to_dense_backward_kernel",
+        [&] {
          jagged_dense_elementwise_jagged_output_<scalar_t>(
              grad_values_slice, // dummy not used in the lambda function
              {offsets_tensor_per_key[t]},
@@ -2293,8 +2308,9 @@ Tensor jagged_index_select_2d_cuda(
       at::empty({num_dense_output_rows, num_cols}, values.options());
 
   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_index_select_2d_kernel_wrapper_1",
         [&] {
@@ -2388,8 +2404,9 @@ Tensor jagged_index_add_2d_cuda(
   Tensor output = at::zeros({num_output_rows, num_cols}, grad.options());
 
   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad.scalar_type(),
         "jagged_index_add_2d_kernel_wrapper_1",
         [&] {
@@ -2835,8 +2852,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
         num_outputs);                                                    \
   }
 
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "keyed_jagged_index_select_dim1_warpper_1",
         [&] {
@@ -2914,8 +2932,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
     Tensor grad_input = at::zeros({num_outputs}, grad.options());
     auto grid_size = cuda_calc_xblock_count(grad.numel(), kMaxThreads);
 
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad.scalar_type(),
         "keyed_jagged_index_add_dim1_wrapper_1",
         [&] {
diff --git a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
index 65b8df5ffc..2d6236ec99 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -405,8 +405,9 @@ class JaggedToPaddedDenseCPUOp
     Tensor padded_values_view =
         values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;
 
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_to_padded_dense",
         [&] {
@@ -437,8 +438,9 @@ class JaggedToPaddedDenseCPUOp
     // in forward.
     auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());
 
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad_padded_values.scalar_type(),
         "jagged_2d_to_dense_backward_kernel",
         [&] {
@@ -519,9 +521,9 @@ class DenseToJaggedCPUOp
     auto values = at::empty({total_L_computed, D}, dense.options());
     auto output = at::zeros({total_L_computed, D}, dense.options());
 
-    AT_DISPATCH_FLOATING_TYPES_AND2(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
-        at::ScalarType::Long,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_scalars",
         [&] {
@@ -886,7 +888,9 @@ class BatchedDenseVecJagged2DMulCPUOp
     if (B > 0 && D > 0) {
       AT_DISPATCH_INDEX_TYPES(
           a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+            AT_DISPATCH_FLOATING_TYPES_AND2(
+                at::ScalarType::Half,
+                at::ScalarType::BFloat16,
                 a_values.scalar_type(),
                 "dense_vec_jagged_2d_bmm_kernel_2",
                 [&] {
@@ -925,7 +929,9 @@ class BatchedDenseVecJagged2DMulCPUOp
           a_offsets.scalar_type(),
           "dense_vec_jagged_2d_bmm_baackward_kernel_1",
           [&] {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+            AT_DISPATCH_FLOATING_TYPES_AND2(
+                at::ScalarType::Half,
+                at::ScalarType::BFloat16,
                 grad_outputs[0].scalar_type(),
                 "dense_vec_jagged_2d_bmm_baackward_kernel_2",
                 [&] {
@@ -974,8 +980,12 @@ Tensor jagged_1d_to_truncated_values_cpu(
   Tensor truncated_values;
   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "jagged_1d_to_truncated_values_cpu_kernel", [&] {
-        AT_DISPATCH_ALL_TYPES_AND_HALF(
-            values.scalar_type(), "copy_values_and_truncate_cpu_kernel", [&] {
+        AT_DISPATCH_ALL_TYPES_AND2(
+            at::ScalarType::Half,
+            at::ScalarType::BFloat16,
+            values.scalar_type(),
+            "copy_values_and_truncate_cpu_kernel",
+            [&] {
              const index_t max_length_int =
                  static_cast<index_t>(max_truncated_length);
              const auto lengths_accessor = lengths.accessor<index_t, 1>();
@@ -1021,8 +1031,12 @@ std::tuple<Tensor, Tensor> masked_select_jagged_1d(
 
   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "mask_select_jagged_1d_kernel1", [&] {
-        AT_DISPATCH_ALL_TYPES_AND_HALF(
-            values.scalar_type(), "mask_select_jagged_1d_kernel2", [&] {
+        AT_DISPATCH_ALL_TYPES_AND2(
+            at::ScalarType::Half,
+            at::ScalarType::BFloat16,
+            values.scalar_type(),
+            "mask_select_jagged_1d_kernel2",
+            [&] {
              const int32_t num_outputs = mask.sum().item<int32_t>();
              masked_values = at::empty({num_outputs}, values.options());
diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py
index 2b1261548a..ef31ad3704 100644
--- a/fbgemm_gpu/test/jagged_tensor_ops_test.py
+++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -136,14 +136,14 @@ def test_expand_into_jagged_permute(
         B=st.integers(min_value=1, max_value=128),
         D=st.integers(min_value=1, max_value=128),
         max_sequence_length=st.integers(min_value=1, max_value=200),
-        is_half=st.booleans(),
+        dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
     )
     def test_jagged_2d_to_dense(
         self,
         B: int,
         D: int,
         max_sequence_length: int,
-        is_half: bool,
+        dtype: torch.dtype,
     ) -> None:
         D = D * 4
         lengths_ = np.random.randint(low=0, high=max_sequence_length, size=B)
@@ -158,14 +158,10 @@ def test_jagged_2d_to_dense(
             max_sequence_length,
             D,
         ).to_dense()
-        if is_half:
-            ref_output_values = ref_output_values.half()
+        ref_output_values = ref_output_values.to(dtype)
 
         # test cpu forward
-        if is_half:
-            values = ref_values.clone().half().detach().requires_grad_(True)
-        else:
-            values = ref_values.clone().detach().requires_grad_(True)
+        values = ref_values.clone().to(dtype).detach().requires_grad_(True)
         output_values = torch.ops.fbgemm.jagged_2d_to_dense(
             values=values,
             offsets=offsets,
@@ -176,10 +172,7 @@
         if torch.cuda.is_available():
             # test gpu forward
             ref_values = ref_values.cuda()
-            if is_half:
-                values = ref_values.clone().half().detach().requires_grad_(True)
-            else:
-                values = ref_values.clone().detach().requires_grad_(True)
+            values = ref_values.clone().to(dtype).detach().requires_grad_(True)
             offsets = offsets.cuda()
             ref_output_values = ref_output_values.cuda()
             output_values = torch.ops.fbgemm.jagged_2d_to_dense(
@@ -191,8 +184,7 @@
 
             # test gpu backward
             output_values.backward(ref_output_values)
-            if is_half:
-                ref_values = ref_values.half()
+            ref_values = ref_values.to(dtype)
             torch.testing.assert_close(ref_values, values.grad)
 
     def test_jagged_2d_to_dense_truncation(self) -> None:
@@ -627,7 +619,7 @@ def _test_dense_to_jagged(
         num_jagged_dim=st.integers(1, 5),
         outer_dense_size=st.integers(0, 5),
         inner_dense_size=st.integers(0, 5),
-        dtype=st.sampled_from([torch.float, torch.half]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
         precompute_total_L=st.booleans(),
     )
@@ -874,7 +866,7 @@ def mul_func(*args) -> torch.Tensor:
         outer_dense_size=st.integers(0, 4),
         inner_dense_size=st.integers(0, 4),
         operation=st.sampled_from(["add", "add_jagged_output", "mul"]),
-        dtype=st.sampled_from([torch.float, torch.half, torch.double]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -998,7 +990,7 @@ def add_jagged_output_func(*args) -> torch.Tensor:
         num_jagged_dim=st.integers(1, 4),
         outer_dense_size=st.integers(0, 4),
         inner_dense_size=st.integers(0, 4),
-        dtype=st.sampled_from([torch.float, torch.half, torch.double]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -1084,22 +1076,22 @@ def test_batched_dense_vec_jagged_2d_mul(
             .reshape(B * H, max_L, D)
         )
         # torch.bmm not implemented for Half on CPU
-        if dtype == torch.half and use_cpu:
+        if dtype in [torch.half, torch.bfloat16] and use_cpu:
             bmm_arg1 = bmm_arg1.float()
             bmm_arg2 = bmm_arg2.float()
         output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze(
             1
         )  # [B H, 1, N] x [B H, N, D] = [B H, 1, D]
-        if dtype == torch.half and use_cpu:
-            output_ref = output_ref.half()
+        if dtype in [torch.half, torch.bfloat16] and use_cpu:
+            output_ref = output_ref.to(dtype)
         output = torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul(
             dense, values, offsets
         )
         torch.testing.assert_close(
             output,
             output_ref,
-            rtol=1e-2 if dtype == torch.half else None,
-            atol=1e-2 if dtype == torch.half else None,
+            rtol=1e-2 if dtype in [torch.half, torch.bfloat16] else None,
+            atol=1e-2 if dtype in [torch.half, torch.bfloat16] else None,
         )
 
         torch.autograd.gradcheck(
@@ -1139,7 +1131,12 @@ def jagged_index_select_2d_ref(
         num_jagged_tensor_rows=st.integers(1, 128),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [
+                torch.float,
+                torch.half,
+                torch.int,
+                torch.long,
+            ]  # Disable torch.bfloat16 due to large error bound
         ),
     )
     @settings(max_examples=20, deadline=None)
@@ -1152,7 +1149,7 @@ def test_jagged_index_select_2d(
         index_dtype: torch.dtype,
         jagged_tensor_dtype: torch.dtype,
     ) -> None:
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_seq_length,
@@ -1207,8 +1204,8 @@ def test_jagged_index_select_2d(
         torch.testing.assert_close(
             values.grad,
             values_ref.grad,
-            rtol=1e-2 if jagged_tensor_dtype == torch.half else None,
-            atol=1e-2 if jagged_tensor_dtype == torch.half else None,
+            rtol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
+            atol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
         )
 
     # pyre-ignore [56]
@@ -1218,7 +1215,7 @@ def test_jagged_index_select_2d(
         max_truncated_length=st.integers(1, 32),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
         ),
         use_cpu=st.just(True),
    )
@@ -1233,7 +1230,7 @@ def test_jagged_1d_to_truncated_values(
         use_cpu: bool,
     ) -> None:
         device = "cpu" if use_cpu else "cuda"
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_length + 1,
@@ -1341,7 +1338,12 @@ def test_masked_select_jagged_1d(
         num_batches=st.integers(1, 3),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [
+                torch.float,
+                torch.half,
+                torch.int,
+                torch.long,
+            ]  # Disable torch.bfloat16 due to large error bound
         ),
         has_weights=st.booleans(),
     )
@@ -1356,7 +1358,7 @@ def test_keyed_jagged_index_select_dim1(
         jagged_tensor_dtype: torch.dtype,
         has_weights: bool,
     ) -> None:
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_seq_length,
@@ -1450,8 +1452,8 @@ def test_keyed_jagged_index_select_dim1(
         torch.testing.assert_close(
             values.grad,
             values_ref.grad,
-            rtol=1e-2 if jagged_tensor_dtype == torch.half else None,
-            atol=1e-2 if jagged_tensor_dtype == torch.half else None,
+            rtol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
+            atol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
         )
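
For reviewers who want to poke at the new dtype coverage locally, here is a minimal usage sketch (not part of the patch) that drives the BFloat16 path through the existing jagged_2d_to_dense op, mirroring the updated test_jagged_2d_to_dense above. It assumes fbgemm_gpu is installed with this change applied; the lengths, batch size, and embedding dimension below are illustrative only.

    # Illustrative sketch only, not part of the patch.
    import torch
    import fbgemm_gpu  # noqa: F401  # loads the extension and registers torch.ops.fbgemm.*

    lengths = torch.tensor([2, 0, 3])  # three jagged rows
    # Complete cumulative sum of lengths -> offsets [0, 2, 2, 5] (same helper the test uses)
    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    values = torch.randn(int(lengths.sum()), 8, dtype=torch.bfloat16)

    dense = torch.ops.fbgemm.jagged_2d_to_dense(
        values=values,
        offsets=offsets,
        max_sequence_length=3,
    )
    print(dense.shape, dense.dtype)  # torch.Size([3, 3, 8]) torch.bfloat16

Before this patch, the CPU dispatch for this op covered floating types plus Half only, so the bfloat16 call above would have raised a "not implemented for 'BFloat16'" dispatch error instead of producing the padded dense tensor.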