add bf16 support for jagged tensor op (pytorch#1472)
Summary:
Pull Request resolved: pytorch#1472

As title

Reviewed By: brad-mengchi, jiawenliu64

Differential Revision: D41445777

fbshipit-source-id: 1a5b3cdd1736d342142e0b38e97f3879e7fd7d05
jianyuh authored and facebook-github-bot committed Nov 21, 2022
1 parent afec5b4 commit 4a12c52
Showing 2 changed files with 60 additions and 47 deletions.
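The change is mechanical across both files: each AT_DISPATCH_* site gains BFloat16 either by moving to the macro variant that takes one more extra scalar type (*_AND -> *_AND2) or by adding at::ScalarType::BFloat16 to an existing *_AND case. A minimal sketch of the sugar form used in the backward pass below, with a hypothetical op scale_ that is not part of this commit:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical op, for illustration only. AT_DISPATCH_FLOATING_TYPES_AND2
// covers float/double plus the two listed extra types; the macro expands
// to a switch over the runtime dtype and compiles the lambda once per
// case, with scalar_t bound to the matching C++ type.
void scale_(at::Tensor& t, float factor) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      t.scalar_type(),
      "scale_",
      [&] {
        // Assumes a contiguous CPU tensor for simplicity.
        scalar_t* data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          data[i] = static_cast<scalar_t>(static_cast<float>(data[i]) * factor);
        }
      });
}

Because the kernels are templated on scalar_t, picking up a new dtype needs no kernel source changes, only the wider dispatch, which is exactly what the hunks below do.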
89 changes: 51 additions & 38 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1288,8 +1288,9 @@ class DenseToJaggedGPUOp
               }); // device lambda
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
             at::ScalarType::Long,
+            at::ScalarType::BFloat16,
             [&] {
               jagged_dense_elementwise_jagged_output_<scalar_t>(
                   values,
@@ -1364,17 +1365,19 @@ class JaggedDenseDenseAddJaggedOutputGPUOp
                     -> scalar_t { return x + y_0 + y_1; });
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              offsets,
-              dense_0,
-              dense_1,
-              output,
-              [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
-                  -> scalar_t { return x + y_0 + y_1; });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  offsets,
+                  dense_0,
+                  dense_1,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y_0, scalar_t y_1)
+                      -> scalar_t { return x + y_0 + y_1; });
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
     return {output};
@@ -1447,17 +1450,19 @@ class JaggedDenseAddJaggedOutputGPUOp
               }); // device lambda
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              offsets,
-              dense,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x + y;
-              }); // device lambda
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  offsets,
+                  dense,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x + y;
+                  }); // device lambda
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
     return {output};
@@ -1660,17 +1665,19 @@ class JaggedDenseMulGPUOp
               });
             } // lambda
             ) // CASE
-        AT_DISPATCH_CASE_FLOATING_TYPES([&] {
-          jagged_dense_elementwise_jagged_output_<scalar_t>(
-              x_values,
-              x_offsets,
-              y,
-              output,
-              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-                return x * y;
-              });
-        } // lambda
-        ) // CASE_FLOATING_TYPES_AND
+        AT_DISPATCH_CASE_FLOATING_TYPES_AND(
+            at::ScalarType::BFloat16,
+            [&] {
+              jagged_dense_elementwise_jagged_output_<scalar_t>(
+                  x_values,
+                  x_offsets,
+                  y,
+                  output,
+                  [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                    return x * y;
+                  });
+            } // lambda
+            ) // CASE_FLOATING_TYPES_AND
     ); // SWITCH
     return {output};
@@ -1693,8 +1700,12 @@ class JaggedDenseMulGPUOp
     Tensor x_values_grad = at::empty_like(grad_outputs[0]);
     Tensor y_grad = at::empty_like(y);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        x_values.scalar_type(), "jagged_scalars", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        x_values.scalar_type(),
+        "jagged_scalars",
+        [&] {
           jagged_dense_elementwise_jagged_output_<scalar_t>(
               grad_outputs[0],
               x_offsets,
@@ -2293,8 +2304,9 @@ Tensor jagged_index_select_2d_cuda(
       at::empty({num_dense_output_rows, num_cols}, values.options());
   if (num_blocks > 0) {
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "jagged_index_select_2d_kernel_wrapper_1",
         [&] {
@@ -2835,8 +2847,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
           num_outputs); \
   }
-    AT_DISPATCH_ALL_TYPES_AND(
+    AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half,
+        at::ScalarType::BFloat16,
         values.scalar_type(),
         "keyed_jagged_index_select_dim1_warpper_1",
         [&] {
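The forward paths in the hunks above use the dispatcher's composable form, which the trailing ) // CASE and ); // SWITCH comments refer to: AT_DISPATCH_CASE_* macros emit the individual case blocks inside an AT_DISPATCH_SWITCH. A minimal sketch of that shape, with a hypothetical kernel launcher apply_ that is not from this file:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

template <typename scalar_t>
void apply_(const at::Tensor& values) {
  // Hypothetical kernel launch; just a placeholder here.
  (void)values;
}

void dispatch_example(const at::Tensor& values) {
  AT_DISPATCH_SWITCH(
      values.scalar_type(),
      "dispatch_example",
      // Specialized cases (e.g. a Half-only fast path, as in the file
      // above) can precede the generic one; this commit widens the
      // generic case from *_AND to *_AND2 so BFloat16 dispatches
      // alongside Long and float/double.
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
          at::ScalarType::Long,
          at::ScalarType::BFloat16,
          [&] { apply_<scalar_t>(values); } // scalar_t bound per case
          ) // CASE
  ); // SWITCH
}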
18 changes: 9 additions & 9 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -627,7 +627,7 @@ def _test_dense_to_jagged(
         num_jagged_dim=st.integers(1, 5),
         outer_dense_size=st.integers(0, 5),
         inner_dense_size=st.integers(0, 5),
-        dtype=st.sampled_from([torch.float, torch.half]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
         precompute_total_L=st.booleans(),
     )
@@ -874,7 +874,7 @@ def mul_func(*args) -> torch.Tensor:
         outer_dense_size=st.integers(0, 4),
         inner_dense_size=st.integers(0, 4),
         operation=st.sampled_from(["add", "add_jagged_output", "mul"]),
-        dtype=st.sampled_from([torch.float, torch.half, torch.double]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -998,7 +998,7 @@ def add_jagged_output_func(*args) -> torch.Tensor:
         num_jagged_dim=st.integers(1, 4),
         outer_dense_size=st.integers(0, 4),
         inner_dense_size=st.integers(0, 4),
-        dtype=st.sampled_from([torch.float, torch.half, torch.double]),
+        dtype=st.sampled_from([torch.float, torch.half, torch.double, torch.bfloat16]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -1139,7 +1139,7 @@ def jagged_index_select_2d_ref(
         num_jagged_tensor_rows=st.integers(1, 128),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
         ),
     )
     @settings(max_examples=20, deadline=None)
@@ -1152,7 +1152,7 @@ def test_jagged_index_select_2d(
         index_dtype: torch.dtype,
         jagged_tensor_dtype: torch.dtype,
     ) -> None:
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_seq_length,
@@ -1218,7 +1218,7 @@ def test_jagged_index_select_2d(
         max_truncated_length=st.integers(1, 32),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
         ),
         use_cpu=st.just(True),
     )
@@ -1233,7 +1233,7 @@ def test_jagged_1d_to_truncated_values(
         use_cpu: bool,
     ) -> None:
         device = "cpu" if use_cpu else "cuda"
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_length + 1,
@@ -1341,7 +1341,7 @@ def test_masked_select_jagged_1d(
         num_batches=st.integers(1, 3),
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from(
-            [torch.float, torch.half, torch.int, torch.long]
+            [torch.float, torch.half, torch.bfloat16, torch.int, torch.long]
         ),
         has_weights=st.booleans(),
     )
@@ -1356,7 +1356,7 @@ def test_keyed_jagged_index_select_dim1(
         jagged_tensor_dtype: torch.dtype,
         has_weights: bool,
     ) -> None:
-        is_float = jagged_tensor_dtype in [torch.float, torch.half]
+        is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
         lengths = torch.randint(
             low=0,
             high=max_seq_length,
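On the test side, torch.bfloat16 joins both the sampled dtypes and the is_float checks, since bf16 is a floating type whose results need approximate rather than exact comparison: it keeps float32's 8-bit exponent but only a 7-bit mantissa. A standalone sketch of the round-trip loss through c10::BFloat16, for illustration only and not part of the test suite:

#include <c10/util/BFloat16.h>

#include <cstdio>

int main() {
  // With 7 mantissa bits, bf16 keeps only about 2-3 decimal digits:
  // 1.2345678f comes back as 1.2343750f.
  float x = 1.2345678f;
  c10::BFloat16 b = x; // round-to-nearest-even conversion
  std::printf("%.7f -> %.7f\n", x, static_cast<float>(b));
  return 0;
}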
