diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu
index 04c0db5457..5e2bc93faa 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1288,8 +1288,9 @@ class DenseToJaggedGPUOp
                   }); // device lambda
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
             at::ScalarType::Long,
+            at::ScalarType::BFloat16,
             [&] {
               jagged_dense_elementwise_jagged_output_(
                   values,
@@ -1364,17 +1365,19 @@ class JaggedDenseDenseAddJaggedOutputGPUOp
                       -> scalar_t { return x + y_0 + y_1; });
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_dense_elementwise_jagged_output_(
-              x_values,
-              offsets,
-              dense_0,
-              dense_1,
-              output,
-              [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-                  -> scalar_t { return x + y_0 + y_1; });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_dense_elementwise_jagged_output_(
+                  x_values,
+                  offsets,
+                  dense_0,
+                  dense_1,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
+                      -> scalar_t { return x + y_0 + y_1; });
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
 
     return {output};
@@ -1447,17 +1450,19 @@ class JaggedDenseAddJaggedOutputGPUOp
                   }); // device lambda
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_(
-              x_values,
-              offsets,
-              dense,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x + y;
-              }); // device lambda
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_(
+                  x_values,
+                  offsets,
+                  dense,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x + y;
+                  }); // device lambda
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
 
     return {output};
@@ -1660,17 +1665,19 @@ class JaggedDenseMulGPUOp
                   });
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_(
-              x_values,
-              x_offsets,
-              y,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x * y;
-              });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_(
+                  x_values,
+                  x_offsets,
+                  y,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x * y;
+                  });
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
 
     return {output};
@@ -1693,8 +1700,12 @@ class JaggedDenseMulGPUOp
     Tensor x_values_grad = at::empty_like(grad_outputs[0]);
     Tensor y_grad = at::empty_like(y);
 
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        x_values.scalar_type(), "jagged_scalars", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        x_values.scalar_type(),
+        "jagged_scalars",
+        [&] {
           jagged_dense_elementwise_jagged_output_(
               grad_outputs[0],
               x_offsets,
@@ -2293,8 +2304,9 @@ Tensor jagged_index_select_2d_cuda(
       at::empty({num_dense_output_rows, num_cols}, values.options());
 
   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
        values.scalar_type(),
         "jagged_index_select_2d_kernel_wrapper_1",
         [&] {
@@ -2835,8 +2847,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
             num_outputs);                                               \
   }
 
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "keyed_jagged_index_select_dim1_warpper_1",
         [&] {
diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py
index 2b1261548a..ef120fc5d2 100644
--- a/fbgemm_gpu/test/jagged_tensor_ops_test.py
+++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -627,7 +627,7 @@ def _test_dense_to_jagged(
         num_jagged_dim=st.integers(1, 5),
         outer_dense_size=st.integers(0, 5),
         inner_dense_size=st.integers(0, 5),
-        dtype=st.sampled_from([torch.float, torch.half]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
         precompute_total_L=st.booleans(),
     )
@@ -874,7 +874,7 @@ def mul_func(*args) -> torch.Tensor:
         outer_dense_size=st.integers(0, 4),
         inner_dense_size=st.integers(0, 4),
         operation=st.sampled_from(["add", "add_jagged_output", "mul"]),
-        dtype=st.sampled_from([torch.float, torch.half, torch.double]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -998,7 +998,7 @@ def add_jagged_output_func(*args) -> torch.Tensor:
         num_jagged_dim=st.integers(1, 4),
         outer_dense_size=st.integers(0, 4),
         inner_dense_size=st.integers(0, 4),
-        dtype=st.sampled_from([torch.float, torch.half, torch.double]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -1139,7 +1139,7 @@ def jagged_index_select_2d_ref(
         num_jagged_tensor_rows=st.integers(1, 128),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
         ),
     )
     @settings(max_examples=20, deadline=None)
@@ -1152,7 +1152,7 @@ def test_jagged_index_select_2d(
         index_dtype: torch.dtype,
         jagged_tensor_dtype: torch.dtype,
     ) -> None:
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_seq_length,
@@ -1218,7 +1218,7 @@ def test_jagged_index_select_2d(
         max_truncated_length=st.integers(1, 32),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
         ),
         use_cpu=st.just(True),
     )
@@ -1233,7 +1233,7 @@ def test_jagged_1d_to_truncated_values(
         use_cpu: bool,
     ) -> None:
         device = "cpu" if use_cpu else "cuda"
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_length + 1,
@@ -1341,7 +1341,7 @@ def test_masked_select_jagged_1d(
         num_batches=st.integers(1, 3),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
         ),
         has_weights=st.booleans(),
     )
@@ -1356,7 +1356,7 @@ def test_keyed_jagged_index_select_dim1(
         jagged_tensor_dtype: torch.dtype,
         has_weights: bool,
     ) -> None:
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_seq_length,