add bf16 support for jagged tensor op #1472
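This patch widens the `AT_DISPATCH_*` macros in the jagged-tensor operators so that `at::ScalarType::BFloat16` joins the dispatched type set, replacing the `_AND_HALF` / plain `FLOATING_TYPES` forms with their `_AND` / `_AND2` counterparts. A minimal sketch of the pattern being applied throughout (the function and its body are illustrative, not code from this PR; assumes a contiguous CPU tensor and a libtorch build):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Before: AT_DISPATCH_FLOATING_TYPES_AND_HALF covers float/double/Half
// only, so bf16 inputs fail with "not implemented for 'BFloat16'".
// After: the _AND2 variant adds two extra ScalarTypes to the set.
void scale_inplace(at::Tensor& t, double factor) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16, // the type this PR adds everywhere
      t.scalar_type(),
      "scale_inplace",
      [&] {
        // scalar_t is bound per-case to the concrete element type.
        auto* data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          data[i] =
              static_cast<scalar_t>(static_cast<double>(data[i]) * factor);
        }
      });
}
```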

Closed · wants to merge 1 commit
103 changes: 61 additions & 42 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1288,8 +1288,9 @@ class DenseToJaggedGPUOp
             }); // device lambda
           } // lambda
         ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
             at::ScalarType::Long,
+            at::ScalarType::BFloat16,
             [&] {
               jagged_dense_elementwise_jagged_output_<scalar_t>(
                   values,
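For context on the hunk above: these operators build one `AT_DISPATCH_SWITCH` whose cases are supplied by `AT_DISPATCH_CASE_*` macros, and the `_AND2` case macro appends two extra scalar types to the floating-point set. A sketch of that composition, assuming current ATen `Dispatch.h` semantics (`run_kernel` is a hypothetical stand-in for `jagged_dense_elementwise_jagged_output_`):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical kernel launcher, standing in for the real one above.
template <typename scalar_t>
void run_kernel(const at::Tensor& values) { /* ... */ }

void dispatch_example(const at::Tensor& values) {
  AT_DISPATCH_SWITCH(
      values.scalar_type(),
      "dense_to_jagged_example",
      // Each CASE macro contributes case labels to one generated switch;
      // _AND2 appends two extra ScalarTypes to the floating-point set, so
      // this body also instantiates for int64_t and at::BFloat16.
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
          at::ScalarType::Long, at::ScalarType::BFloat16, [&] {
            run_kernel<scalar_t>(values);
          }));
}
```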
@@ -1364,17 +1365,19 @@ class JaggedDenseDenseAddJaggedOutputGPUOp
                 -> scalar_t { return x + y_0 + y_1; });
           } // lambda
         ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              offsets,
-              dense_0,
-              dense_1,
-              output,
-              [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-                  -> scalar_t { return x + y_0 + y_1; });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  offsets,
+                  dense_0,
+                  dense_1,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
+                      -> scalar_t { return x + y_0 + y_1; });
+            } // lambda
+        ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH

     return {output};
@@ -1447,17 +1450,19 @@ class JaggedDenseAddJaggedOutputGPUOp
             }); // device lambda
           } // lambda
         ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              offsets,
-              dense,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x + y;
-              }); // device lambda
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  offsets,
+                  dense,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x + y;
+                  }); // device lambda
+            } // lambda
+        ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH

     return {output};
@@ -1660,17 +1665,19 @@ class JaggedDenseMulGPUOp
               });
           } // lambda
         ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              x_offsets,
-              y,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x * y;
-              });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  x_offsets,
+                  y,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x * y;
+                  });
+            } // lambda
+        ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH

     return {output};
@@ -1693,8 +1700,12 @@ class JaggedDenseMulGPUOp
     Tensor x_values_grad = at::empty_like(grad_outputs[0]);
     Tensor y_grad = at::empty_like(y);

-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        x_values.scalar_type(), "jagged_scalars", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        x_values.scalar_type(),
+        "jagged_scalars",
+        [&] {
           jagged_dense_elementwise_jagged_output_<scalar_t>(
               grad_outputs[0],
               x_offsets,
@@ -2115,8 +2126,12 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda(
     Tensor grad_values_slice =
         grad_values.slice(0, offset_per_key[t], offset_per_key[t + 1]);

-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        grad_values.scalar_type(), "jagged_2d_to_dense_backward_kernel", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        grad_values.scalar_type(),
+        "jagged_2d_to_dense_backward_kernel",
+        [&] {
           jagged_dense_elementwise_jagged_output_<scalar_t>(
               grad_values_slice, // dummy not used in the lambda function
               {offsets_tensor_per_key[t]},
@@ -2293,8 +2308,9 @@ Tensor jagged_index_select_2d_cuda(
       at::empty({num_dense_output_rows, num_cols}, values.options());

   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_index_select_2d_kernel_wrapper_1",
         [&] {
@@ -2388,8 +2404,9 @@ Tensor jagged_index_add_2d_cuda(
   Tensor output = at::zeros({num_output_rows, num_cols}, grad.options());

   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad.scalar_type(),
         "jagged_index_add_2d_kernel_wrapper_1",
         [&] {
@@ -2835,8 +2852,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
         num_outputs); \
   }

-  AT_DISPATCH_ALL_TYPES_AND(
+  AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half,
+      at::ScalarType::BFloat16,
       values.scalar_type(),
       "keyed_jagged_index_select_dim1_warpper_1",
       [&] {
@@ -2914,8 +2932,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
     Tensor grad_input = at::zeros({num_outputs}, grad.options());
     auto grid_size = cuda_calc_xblock_count(grad.numel(), kMaxThreads);

-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad.scalar_type(),
         "keyed_jagged_index_add_dim1_wrapper_1",
         [&] {
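Note that the `[] __device__(scalar_t x, scalar_t y) ...` lambdas in the hunks above need no bf16-specific code: `c10::BFloat16` defines host/device arithmetic that converts operands to `float` and rounds the result back. A standalone sketch of that behavior (hypothetical kernel, not from this PR; assumes nvcc and the c10 headers):

```cpp
#include <c10/util/BFloat16.h>

// Element-wise add on raw bf16 buffers. operator+ on c10::BFloat16
// converts both operands to float and rounds the sum back to bf16,
// which is what the dispatched device lambdas above rely on.
__global__ void bf16_add_kernel(
    const c10::BFloat16* x,
    const c10::BFloat16* y,
    c10::BFloat16* out,
    int64_t n) {
  const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = x[i] + y[i];
  }
}
```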
34 changes: 24 additions & 10 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -405,8 +405,9 @@ class JaggedToPaddedDenseCPUOp
     Tensor padded_values_view =
         values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;

-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_to_padded_dense",
         [&] {
@@ -437,8 +438,9 @@ class JaggedToPaddedDenseCPUOp
     // in forward.
     auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad_padded_values.scalar_type(),
         "jagged_2d_to_dense_backward_kernel",
         [&] {
@@ -519,9 +521,9 @@ class DenseToJaggedCPUOp
     auto values = at::empty({total_L_computed, D}, dense.options());
     auto output = at::zeros({total_L_computed, D}, dense.options());

-    AT_DISPATCH_FLOATING_TYPES_AND2(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
-        at::ScalarType::Long,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_scalars",
         [&] {
@@ -886,7 +888,9 @@ class BatchedDenseVecJagged2DMulCPUOp
   if (B > 0 && D > 0) {
     AT_DISPATCH_INDEX_TYPES(
         a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
-          AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          AT_DISPATCH_FLOATING_TYPES_AND2(
+              at::ScalarType::Half,
+              at::ScalarType::BFloat16,
               a_values.scalar_type(),
               "dense_vec_jagged_2d_bmm_kernel_2",
               [&] {
@@ -925,7 +929,9 @@ class BatchedDenseVecJagged2DMulCPUOp
         a_offsets.scalar_type(),
         "dense_vec_jagged_2d_bmm_baackward_kernel_1",
         [&] {
-          AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          AT_DISPATCH_FLOATING_TYPES_AND2(
+              at::ScalarType::Half,
+              at::ScalarType::BFloat16,
               grad_outputs[0].scalar_type(),
               "dense_vec_jagged_2d_bmm_baackward_kernel_2",
               [&] {
@@ -974,8 +980,12 @@ Tensor jagged_1d_to_truncated_values_cpu(
   Tensor truncated_values;
   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "jagged_1d_to_truncated_values_cpu_kernel", [&] {
-        AT_DISPATCH_ALL_TYPES_AND_HALF(
-            values.scalar_type(), "copy_values_and_truncate_cpu_kernel", [&] {
+        AT_DISPATCH_ALL_TYPES_AND2(
+            at::ScalarType::Half,
+            at::ScalarType::BFloat16,
+            values.scalar_type(),
+            "copy_values_and_truncate_cpu_kernel",
+            [&] {
              const index_t max_length_int =
                  static_cast<index_t>(max_truncated_length);
              const auto lengths_accessor = lengths.accessor<index_t, 1>();
@@ -1021,8 +1031,12 @@ std::tuple<Tensor, Tensor> masked_select_jagged_1d(

   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "mask_select_jagged_1d_kernel1", [&] {
-        AT_DISPATCH_ALL_TYPES_AND_HALF(
-            values.scalar_type(), "mask_select_jagged_1d_kernel2", [&] {
+        AT_DISPATCH_ALL_TYPES_AND2(
+            at::ScalarType::Half,
+            at::ScalarType::BFloat16,
+            values.scalar_type(),
+            "mask_select_jagged_1d_kernel2",
+            [&] {
              const int32_t num_outputs = mask.sum().item<int32_t>();
              masked_values = at::empty({num_outputs}, values.options());

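One pattern worth noting across both files: the index-select/add style ops dispatch with `AT_DISPATCH_ALL_TYPES_AND2` because integral element types must work too, while the arithmetic lambdas stay on the `FLOATING_TYPES` variants. A sketch of the ALL_TYPES case under that reading (`gather_rows` is a hypothetical illustration; assumes a contiguous 2-D CPU tensor):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <cstring>
#include <vector>

// Copy selected rows of a 2-D tensor. ALL_TYPES covers the integral
// types plus float and double; _AND2 tacks on Half and BFloat16,
// matching the jagged_index_select_2d_cuda dispatch above.
at::Tensor gather_rows(
    const at::Tensor& src,
    const std::vector<int64_t>& rows) {
  auto out = at::empty(
      {static_cast<int64_t>(rows.size()), src.size(1)}, src.options());
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      src.scalar_type(),
      "gather_rows",
      [&] {
        const auto* in = src.data_ptr<scalar_t>();
        auto* o = out.data_ptr<scalar_t>();
        const int64_t cols = src.size(1);
        for (size_t r = 0; r < rows.size(); ++r) {
          std::memcpy(
              o + r * cols, in + rows[r] * cols, cols * sizeof(scalar_t));
        }
      });
  return out;
}
```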