Skip to content

Commit

Permalink
add bf16 support for jagged tensor op (#1472)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1472

As title

Reviewed By: brad-mengchi, jiawenliu64

Differential Revision: D41445777

fbshipit-source-id: ae083bb7f534285bf54e8803327ea64e84b0ba23
  • Loading branch information
jianyuh authored and facebook-github-bot committed Nov 21, 2022
1 parent afec5b4 commit 169d0bf
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 61 deletions.
103 changes: 61 additions & 42 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1288,8 +1288,9 @@ class DenseToJaggedGPUOp
}); // device lambda
} // lambda
) // CASE
AT_DISPATCH_CASE_FLOATING_TYPES_AND(
AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
at::ScalarType::Long,
at::ScalarType::BFloat16,
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
values,
Expand Down Expand Up @@ -1364,17 +1365,19 @@ class JaggedDenseDenseAddJaggedOutputGPUOp
-> scalar_t { return x + y_0 + y_1; });
} // lambda
) // CASE
AT_DISPATCH_CASE_FLOATING_TYPES([&] {
jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
x_values,
offsets,
dense_0,
dense_1,
output,
[] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-> scalar_t { return x + y_0 + y_1; });
} // lambda
) // CASE_FLOATING_TYPES_AND
AT_DISPATCH_CASE_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
[&] {
jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
x_values,
offsets,
dense_0,
dense_1,
output,
[] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-> scalar_t { return x + y_0 + y_1; });
} // lambda
) // CASE_FLOATING_TYPES_AND
); // SWITCH
return {output};
Expand Down Expand Up @@ -1447,17 +1450,19 @@ class JaggedDenseAddJaggedOutputGPUOp
}); // device lambda
} // lambda
) // CASE
AT_DISPATCH_CASE_FLOATING_TYPES([&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
offsets,
dense,
output,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x + y;
}); // device lambda
} // lambda
) // CASE_FLOATING_TYPES_AND
AT_DISPATCH_CASE_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
offsets,
dense,
output,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x + y;
}); // device lambda
} // lambda
) // CASE_FLOATING_TYPES_AND
); // SWITCH
return {output};
Expand Down Expand Up @@ -1660,17 +1665,19 @@ class JaggedDenseMulGPUOp
});
} // lambda
) // CASE
AT_DISPATCH_CASE_FLOATING_TYPES([&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
output,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
} // lambda
) // CASE_FLOATING_TYPES_AND
AT_DISPATCH_CASE_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
output,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
} // lambda
) // CASE_FLOATING_TYPES_AND
); // SWITCH
return {output};
Expand All @@ -1693,8 +1700,12 @@ class JaggedDenseMulGPUOp
Tensor x_values_grad = at::empty_like(grad_outputs[0]);
Tensor y_grad = at::empty_like(y);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
x_values.scalar_type(),
"jagged_scalars",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_outputs[0],
x_offsets,
Expand Down Expand Up @@ -2115,8 +2126,12 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda(
Tensor grad_values_slice =
grad_values.slice(0, offset_per_key[t], offset_per_key[t + 1]);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_values.scalar_type(), "jagged_2d_to_dense_backward_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_values_slice, // dummy not used in the lambda function
{offsets_tensor_per_key[t]},
Expand Down Expand Up @@ -2293,8 +2308,9 @@ Tensor jagged_index_select_2d_cuda(
at::empty({num_dense_output_rows, num_cols}, values.options());
if (num_blocks > 0) {
AT_DISPATCH_ALL_TYPES_AND(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
values.scalar_type(),
"jagged_index_select_2d_kernel_wrapper_1",
[&] {
Expand Down Expand Up @@ -2388,8 +2404,9 @@ Tensor jagged_index_add_2d_cuda(
Tensor output = at::zeros({num_output_rows, num_cols}, grad.options());
if (num_blocks > 0) {
AT_DISPATCH_ALL_TYPES_AND(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad.scalar_type(),
"jagged_index_add_2d_kernel_wrapper_1",
[&] {
Expand Down Expand Up @@ -2835,8 +2852,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
num_outputs); \
}
AT_DISPATCH_ALL_TYPES_AND(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
values.scalar_type(),
"keyed_jagged_index_select_dim1_warpper_1",
[&] {
Expand Down Expand Up @@ -2914,8 +2932,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
Tensor grad_input = at::zeros({num_outputs}, grad.options());
auto grid_size = cuda_calc_xblock_count(grad.numel(), kMaxThreads);
AT_DISPATCH_ALL_TYPES_AND(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad.scalar_type(),
"keyed_jagged_index_add_dim1_wrapper_1",
[&] {
Expand Down
34 changes: 24 additions & 10 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,9 @@ class JaggedToPaddedDenseCPUOp
Tensor padded_values_view =
values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;

AT_DISPATCH_ALL_TYPES_AND(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
values.scalar_type(),
"jagged_to_padded_dense",
[&] {
Expand Down Expand Up @@ -437,8 +438,9 @@ class JaggedToPaddedDenseCPUOp
// in forward.
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

AT_DISPATCH_ALL_TYPES_AND(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad_padded_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel",
[&] {
Expand Down Expand Up @@ -519,9 +521,9 @@ class DenseToJaggedCPUOp
auto values = at::empty({total_L_computed, D}, dense.options());
auto output = at::zeros({total_L_computed, D}, dense.options());

AT_DISPATCH_FLOATING_TYPES_AND2(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::Long,
at::ScalarType::BFloat16,
values.scalar_type(),
"jagged_scalars",
[&] {
Expand Down Expand Up @@ -886,7 +888,9 @@ class BatchedDenseVecJagged2DMulCPUOp
if (B > 0 && D > 0) {
AT_DISPATCH_INDEX_TYPES(
a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
a_values.scalar_type(),
"dense_vec_jagged_2d_bmm_kernel_2",
[&] {
Expand Down Expand Up @@ -925,7 +929,9 @@ class BatchedDenseVecJagged2DMulCPUOp
a_offsets.scalar_type(),
"dense_vec_jagged_2d_bmm_baackward_kernel_1",
[&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad_outputs[0].scalar_type(),
"dense_vec_jagged_2d_bmm_baackward_kernel_2",
[&] {
Expand Down Expand Up @@ -974,8 +980,12 @@ Tensor jagged_1d_to_truncated_values_cpu(
Tensor truncated_values;
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "jagged_1d_to_truncated_values_cpu_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND_HALF(
values.scalar_type(), "copy_values_and_truncate_cpu_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
values.scalar_type(),
"copy_values_and_truncate_cpu_kernel",
[&] {
const index_t max_length_int =
static_cast<index_t>(max_truncated_length);
const auto lengths_accessor = lengths.accessor<index_t, 1>();
Expand Down Expand Up @@ -1021,8 +1031,12 @@ std::tuple<Tensor, Tensor> masked_select_jagged_1d(

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "mask_select_jagged_1d_kernel1", [&] {
AT_DISPATCH_ALL_TYPES_AND_HALF(
values.scalar_type(), "mask_select_jagged_1d_kernel2", [&] {
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
values.scalar_type(),
"mask_select_jagged_1d_kernel2",
[&] {
const int32_t num_outputs = mask.sum().item<int32_t>();
masked_values = at::empty({num_outputs}, values.options());

Expand Down
18 changes: 9 additions & 9 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def _test_dense_to_jagged(
num_jagged_dim=st.integers(1, 5),
outer_dense_size=st.integers(0, 5),
inner_dense_size=st.integers(0, 5),
dtype=st.sampled_from([torch.float, torch.half]),
dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
use_cpu=st.booleans() if gpu_available else st.just(True),
precompute_total_L=st.booleans(),
)
Expand Down Expand Up @@ -874,7 +874,7 @@ def mul_func(*args) -> torch.Tensor:
outer_dense_size=st.integers(0, 4),
inner_dense_size=st.integers(0, 4),
operation=st.sampled_from(["add", "add_jagged_output", "mul"]),
dtype=st.sampled_from([torch.float, torch.half, torch.double]),
dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
use_cpu=st.booleans() if gpu_available else st.just(True),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
Expand Down Expand Up @@ -998,7 +998,7 @@ def add_jagged_output_func(*args) -> torch.Tensor:
num_jagged_dim=st.integers(1, 4),
outer_dense_size=st.integers(0, 4),
inner_dense_size=st.integers(0, 4),
dtype=st.sampled_from([torch.float, torch.half, torch.double]),
dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
use_cpu=st.booleans() if gpu_available else st.just(True),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
Expand Down Expand Up @@ -1139,7 +1139,7 @@ def jagged_index_select_2d_ref(
num_jagged_tensor_rows=st.integers(1, 128),
index_dtype=st.sampled_from([torch.int, torch.long]),
jagged_tensor_dtype=st.sampled_from(
[torch.float, torch.half, torch.int, torch.long]
[torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
),
)
@settings(max_examples=20, deadline=None)
Expand All @@ -1152,7 +1152,7 @@ def test_jagged_index_select_2d(
index_dtype: torch.dtype,
jagged_tensor_dtype: torch.dtype,
) -> None:
is_float = jagged_tensor_dtype in [torch.float, torch.half]
is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
lengths = torch.randint(
low=0,
high=max_seq_length,
Expand Down Expand Up @@ -1218,7 +1218,7 @@ def test_jagged_index_select_2d(
max_truncated_length=st.integers(1, 32),
index_dtype=st.sampled_from([torch.int, torch.long]),
jagged_tensor_dtype=st.sampled_from(
[torch.float, torch.half, torch.int, torch.long]
[torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
),
use_cpu=st.just(True),
)
Expand All @@ -1233,7 +1233,7 @@ def test_jagged_1d_to_truncated_values(
use_cpu: bool,
) -> None:
device = "cpu" if use_cpu else "cuda"
is_float = jagged_tensor_dtype in [torch.float, torch.half]
is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
lengths = torch.randint(
low=0,
high=max_length + 1,
Expand Down Expand Up @@ -1341,7 +1341,7 @@ def test_masked_select_jagged_1d(
num_batches=st.integers(1, 3),
index_dtype=st.sampled_from([torch.int, torch.long]),
jagged_tensor_dtype=st.sampled_from(
[torch.float, torch.half, torch.int, torch.long]
[torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
),
has_weights=st.booleans(),
)
Expand All @@ -1356,7 +1356,7 @@ def test_keyed_jagged_index_select_dim1(
jagged_tensor_dtype: torch.dtype,
has_weights: bool,
) -> None:
is_float = jagged_tensor_dtype in [torch.float, torch.half]
is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
lengths = torch.randint(
low=0,
high=max_seq_length,
Expand Down

0 comments on commit 169d0bf

Please sign in to comment.