add bf16 support for jagged tensor op (pytorch#1472)
Summary:
Pull Request resolved: pytorch#1472

As title

Reviewed By: brad-mengchi, jiawenliu64

Differential Revision: D41445777

fbshipit-source-id: 2a8814fce8d941594b03ec0cc95122920b8369d3
jianyuh authored and facebook-github-bot committed Nov 22, 2022
1 parent afec5b4 commit b07d833
Showing 3 changed files with 119 additions and 84 deletions.
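
Every hunk in this commit makes the same kind of change: a dispatch macro that instantiated the kernel lambda for float, double, and at::Half (or, for the ALL_TYPES variants, the integral types as well) is replaced by its *_AND2 counterpart so the lambda is also instantiated for at::ScalarType::BFloat16. A minimal CPU-side sketch of the pattern follows; the scale_values_ helper is hypothetical and not part of this commit.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Sketch only: shows the AT_DISPATCH_FLOATING_TYPES_AND_HALF ->
// AT_DISPATCH_FLOATING_TYPES_AND2 change applied throughout this diff.
// Assumes a contiguous CPU tensor.
void scale_values_(at::Tensor& values, double alpha) {
  // Before: instantiated for float/double/Half only.
  //   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
  //       values.scalar_type(), "scale_values_", [&] { /* ... */ });
  // After: the *_AND2 variant names two extra types, adding BFloat16.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      values.scalar_type(),
      "scale_values_",
      [&] {
        // scalar_t is float, double, at::Half, or at::BFloat16 here.
        auto* data = values.data_ptr<scalar_t>();
        const auto a = static_cast<scalar_t>(alpha);
        for (int64_t i = 0; i < values.numel(); ++i) {
          data[i] = data[i] * a;
        }
      });
}
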
103 changes: 61 additions & 42 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1288,8 +1288,9 @@ class DenseToJaggedGPUOp
           }); // device lambda
         } // lambda
     ) // CASE
-    AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+    AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
         at::ScalarType::Long,
+        at::ScalarType::BFloat16,
         [&] {
           jagged_dense_elementwise_jagged_output_<scalar_t>(
               values,
@@ -1364,17 +1365,19 @@ class JaggedDenseDenseAddJaggedOutputGPUOp
               -> scalar_t { return x + y_0 + y_1; });
         } // lambda
     ) // CASE
-    AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-      jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
-          x_values,
-          offsets,
-          dense_0,
-          dense_1,
-          output,
-          [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-              -> scalar_t { return x + y_0 + y_1; });
-    } // lambda
-    ) // CASE_FLOATING_TYPES_AND
+    AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+        at::ScalarType::BFloat16,
+        [&] {
+          jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
+              x_values,
+              offsets,
+              dense_0,
+              dense_1,
+              output,
+              [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
+                  -> scalar_t { return x + y_0 + y_1; });
+        } // lambda
+    ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
     return {output};
@@ -1447,17 +1450,19 @@ class JaggedDenseAddJaggedOutputGPUOp
           }); // device lambda
         } // lambda
     ) // CASE
-    AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-      jagged_dense_elementwise_jagged_output_<scalar_t>(
-          x_values,
-          offsets,
-          dense,
-          output,
-          [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-            return x + y;
-          }); // device lambda
-    } // lambda
-    ) // CASE_FLOATING_TYPES_AND
+    AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+        at::ScalarType::BFloat16,
+        [&] {
+          jagged_dense_elementwise_jagged_output_<scalar_t>(
+              x_values,
+              offsets,
+              dense,
+              output,
+              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                return x + y;
+              }); // device lambda
+        } // lambda
+    ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
     return {output};
@@ -1660,17 +1665,19 @@ class JaggedDenseMulGPUOp
               });
         } // lambda
     ) // CASE
-    AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-      jagged_dense_elementwise_jagged_output_<scalar_t>(
-          x_values,
-          x_offsets,
-          y,
-          output,
-          [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-            return x * y;
-          });
-    } // lambda
-    ) // CASE_FLOATING_TYPES_AND
+    AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+        at::ScalarType::BFloat16,
+        [&] {
+          jagged_dense_elementwise_jagged_output_<scalar_t>(
+              x_values,
+              x_offsets,
+              y,
+              output,
+              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                return x * y;
+              });
+        } // lambda
+    ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
     return {output};
@@ -1693,8 +1700,12 @@ class JaggedDenseMulGPUOp
     Tensor x_values_grad = at::empty_like(grad_outputs[0]);
     Tensor y_grad = at::empty_like(y);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        x_values.scalar_type(), "jagged_scalars", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        x_values.scalar_type(),
+        "jagged_scalars",
+        [&] {
           jagged_dense_elementwise_jagged_output_<scalar_t>(
               grad_outputs[0],
               x_offsets,
@@ -2115,8 +2126,12 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda(
     Tensor grad_values_slice =
         grad_values.slice(0, offset_per_key[t], offset_per_key[t + 1]);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        grad_values.scalar_type(), "jagged_2d_to_dense_backward_kernel", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        grad_values.scalar_type(),
+        "jagged_2d_to_dense_backward_kernel",
+        [&] {
           jagged_dense_elementwise_jagged_output_<scalar_t>(
               grad_values_slice, // dummy not used in the lambda function
               {offsets_tensor_per_key[t]},
@@ -2293,8 +2308,9 @@ Tensor jagged_index_select_2d_cuda(
       at::empty({num_dense_output_rows, num_cols}, values.options());
   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_index_select_2d_kernel_wrapper_1",
         [&] {
@@ -2388,8 +2404,9 @@ Tensor jagged_index_add_2d_cuda(
   Tensor output = at::zeros({num_output_rows, num_cols}, grad.options());
   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad.scalar_type(),
         "jagged_index_add_2d_kernel_wrapper_1",
         [&] {
@@ -2835,8 +2852,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
         num_outputs); \
   }
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "keyed_jagged_index_select_dim1_warpper_1",
         [&] {
@@ -2914,8 +2932,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
     Tensor grad_input = at::zeros({num_outputs}, grad.options());
     auto grid_size = cuda_calc_xblock_count(grad.numel(), kMaxThreads);
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad.scalar_type(),
         "keyed_jagged_index_add_dim1_wrapper_1",
         [&] {
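
For the GPU operators above that dispatch through AT_DISPATCH_SWITCH with per-type CASE macros (DenseToJaggedGPUOp and the jagged add/mul ops), BFloat16 coverage comes from swapping the plain AT_DISPATCH_CASE_FLOATING_TYPES case for its *_AND / *_AND2 variants. A rough sketch of that composition follows; the add_one_ function is illustrative, and the dedicated Half case stands in for the branch that the hunks above only show the tail of.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Sketch only: CASE macros are concatenated inside AT_DISPATCH_SWITCH; each
// expands to case labels that bind scalar_t and run the lambda. Assumes a
// contiguous CPU tensor.
void add_one_(at::Tensor& t) {
  AT_DISPATCH_SWITCH(
      t.scalar_type(),
      "add_one_",
      AT_DISPATCH_CASE(
          at::ScalarType::Half,
          [&] {
            // Dedicated branch, e.g. where Half needs special handling.
            t.add_(1);
          } // lambda
          ) // CASE
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
          at::ScalarType::Long,
          at::ScalarType::BFloat16,
          [&] {
            // Generic branch: float/double plus the two extra types above.
            auto* data = t.data_ptr<scalar_t>();
            for (int64_t i = 0; i < t.numel(); ++i) {
              data[i] = data[i] + static_cast<scalar_t>(1);
            }
          } // lambda
          ) // CASE
      ); // SWITCH
}
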
34 changes: 24 additions & 10 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -405,8 +405,9 @@ class JaggedToPaddedDenseCPUOp
     Tensor padded_values_view =
         values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;

-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_to_padded_dense",
         [&] {
@@ -437,8 +438,9 @@ class JaggedToPaddedDenseCPUOp
     // in forward.
     auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         grad_padded_values.scalar_type(),
         "jagged_2d_to_dense_backward_kernel",
         [&] {
@@ -519,9 +521,9 @@ class DenseToJaggedCPUOp
     auto values = at::empty({total_L_computed, D}, dense.options());
     auto output = at::zeros({total_L_computed, D}, dense.options());

-    AT_DISPATCH_FLOATING_TYPES_AND2(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
-        at::ScalarType::Long,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_scalars",
         [&] {
@@ -886,7 +888,9 @@ class BatchedDenseVecJagged2DMulCPUOp
     if (B > 0 && D > 0) {
       AT_DISPATCH_INDEX_TYPES(
           a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+            AT_DISPATCH_FLOATING_TYPES_AND2(
+                at::ScalarType::Half,
+                at::ScalarType::BFloat16,
                 a_values.scalar_type(),
                 "dense_vec_jagged_2d_bmm_kernel_2",
                 [&] {
@@ -925,7 +929,9 @@ class BatchedDenseVecJagged2DMulCPUOp
           a_offsets.scalar_type(),
           "dense_vec_jagged_2d_bmm_baackward_kernel_1",
           [&] {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+            AT_DISPATCH_FLOATING_TYPES_AND2(
+                at::ScalarType::Half,
+                at::ScalarType::BFloat16,
                 grad_outputs[0].scalar_type(),
                 "dense_vec_jagged_2d_bmm_baackward_kernel_2",
                 [&] {
@@ -974,8 +980,12 @@ Tensor jagged_1d_to_truncated_values_cpu(
   Tensor truncated_values;
   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "jagged_1d_to_truncated_values_cpu_kernel", [&] {
-        AT_DISPATCH_ALL_TYPES_AND_HALF(
-            values.scalar_type(), "copy_values_and_truncate_cpu_kernel", [&] {
+        AT_DISPATCH_ALL_TYPES_AND2(
+            at::ScalarType::Half,
+            at::ScalarType::BFloat16,
+            values.scalar_type(),
+            "copy_values_and_truncate_cpu_kernel",
+            [&] {
              const index_t max_length_int =
                  static_cast<index_t>(max_truncated_length);
              const auto lengths_accessor = lengths.accessor<index_t, 1>();
@@ -1021,8 +1031,12 @@ std::tuple<Tensor, Tensor> masked_select_jagged_1d(

   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "mask_select_jagged_1d_kernel1", [&] {
-        AT_DISPATCH_ALL_TYPES_AND_HALF(
-            values.scalar_type(), "mask_select_jagged_1d_kernel2", [&] {
+        AT_DISPATCH_ALL_TYPES_AND2(
+            at::ScalarType::Half,
+            at::ScalarType::BFloat16,
+            values.scalar_type(),
+            "mask_select_jagged_1d_kernel2",
+            [&] {
              const int32_t num_outputs = mask.sum().item<int32_t>();
              masked_values = at::empty({num_outputs}, values.options());

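
The last two CPU hunks show the nested dispatch used when both an index type (for lengths/offsets) and a value type have to be resolved; only the inner value-type dispatch is widened to BFloat16. A rough sketch of that nesting follows, built around a made-up gather_rows helper that is not part of this change.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Sketch only: nested index-type / value-type dispatch. gather_rows copies
// the first lengths[r] elements of each row of a 2-D values tensor; it
// assumes lengths[r] <= values.size(1) for every row.
at::Tensor gather_rows(const at::Tensor& values, const at::Tensor& lengths) {
  at::Tensor out = at::zeros_like(values);
  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "gather_rows_idx", [&] {
        // index_t is int32_t or int64_t here.
        AT_DISPATCH_ALL_TYPES_AND2(
            at::ScalarType::Half,
            at::ScalarType::BFloat16,
            values.scalar_type(),
            "gather_rows_val",
            [&] {
              // scalar_t covers the integral and floating types plus
              // Half and BFloat16.
              const auto lengths_acc = lengths.accessor<index_t, 1>();
              const auto values_acc = values.accessor<scalar_t, 2>();
              auto out_acc = out.accessor<scalar_t, 2>();
              for (int64_t r = 0; r < values_acc.size(0); ++r) {
                for (index_t c = 0; c < lengths_acc[r]; ++c) {
                  out_acc[r][c] = values_acc[r][c];
                }
              }
            });
      });
  return out;
}
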