diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu index 7cd3c0148d..d3fa03941f 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops.cu @@ -226,58 +226,6 @@ void jagged_dense_elementwise_dense_output_( #undef INVOKE_KERNEL_WITH_DIM } -// Almost identical copy of jagged_to_padded_dense in jagged_tensor_ops_cpu.cpp -Tensor jagged_to_padded_dense( - const Tensor& values, - const std::vector& offsets, - const std::vector& max_lengths, - const int64_t padding_value) { - const size_t num_jagged_dim = offsets.size(); - TORCH_CHECK( - max_lengths.size() == num_jagged_dim, - "max_lengths.size(), ", - max_lengths.size(), - " != num_jagged_dim, ", - num_jagged_dim); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); - - const Tensor values_canonicalized = values.view( - {values.size(0), - std::accumulate( - values.sizes().begin() + 1, - values.sizes().end(), - 1, - std::multiplies())}); - at::DimVector padded_values_shape({offsets[0].size(0) - 1}); - padded_values_shape.insert( - padded_values_shape.end(), max_lengths.begin(), max_lengths.end()); - if (values.dim() > 1) { - padded_values_shape.push_back(values.size(-1)); - } - Tensor padded_values = at::empty(padded_values_shape, values.options()); - Tensor padded_values_view = - values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values; - - AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, - values.scalar_type(), - "jagged_to_padded_dense", - [&] { - jagged_dense_elementwise_dense_output_( - values_canonicalized, - offsets, - padded_values_view, // dummy not used in the lambda function - padded_values_view, - [] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t { - return x; - }, - static_cast(padding_value)); - }); - - return padded_values; -} - template Tensor jagged_dense_elementwise_dense_output_( const Tensor& x_values, @@ -396,6 +344,117 @@ Tensor jagged_dense_elementwise_jagged_output_( return output; } +class JaggedToPaddedDenseGPUOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& values, + const std::vector& offsets, + const std::vector& max_lengths, + const int64_t padding_value) { + ctx->save_for_backward(offsets); + ctx->saved_data["total_L"] = values.size(0); + + const size_t num_jagged_dim = offsets.size(); + TORCH_CHECK( + max_lengths.size() == num_jagged_dim, + "max_lengths.size(), ", + max_lengths.size(), + " != num_jagged_dim, ", + num_jagged_dim); + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(values.get_device()); + + const Tensor values_canonicalized = values.view( + {values.size(0), + std::accumulate( + values.sizes().begin() + 1, + values.sizes().end(), + 1, + std::multiplies())}); + at::DimVector padded_values_shape({offsets[0].size(0) - 1}); + padded_values_shape.insert( + padded_values_shape.end(), max_lengths.begin(), max_lengths.end()); + if (values.dim() > 1) { + padded_values_shape.push_back(values.size(-1)); + } + Tensor padded_values = at::empty(padded_values_shape, values.options()); + Tensor padded_values_view = + values.dim() == 1 ? 
padded_values.unsqueeze(-1) : padded_values;
+
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half,
+        values.scalar_type(),
+        "jagged_to_padded_dense",
+        [&] {
+          jagged_dense_elementwise_dense_output_<scalar_t>(
+              values_canonicalized,
+              offsets,
+              padded_values_view, // dummy not used in the lambda function
+              padded_values_view,
+              [] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t {
+                return x;
+              },
+              static_cast<scalar_t>(padding_value));
+        });
+
+    return {padded_values};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    auto offsets = ctx->get_saved_variables();
+    int32_t total_L = ctx->saved_data["total_L"].toInt();
+    TORCH_CHECK(grad_outputs.size() == 1);
+
+    TORCH_CHECK(total_L >= 0);
+    auto grad_padded_values = grad_outputs[0];
+    at::cuda::OptionalCUDAGuard device_guard;
+    device_guard.set_index(grad_padded_values.get_device());
+
+    int32_t D = grad_padded_values.size(-1);
+    auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        grad_padded_values.scalar_type(),
+        "jagged_2d_to_dense_backward_kernel",
+        [&] {
+          jagged_dense_elementwise_jagged_output_<scalar_t>(
+              grad_values, // dummy not used in the lambda function
+              {offsets},
+              grad_padded_values,
+              grad_values,
+              [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
+                return y;
+              });
+        });
+
+    return {
+        grad_values,
+        torch::autograd::Variable(), // offsets
+        torch::autograd::Variable(), // max_lengths
+        torch::autograd::Variable(), // padding_value
+    };
+  }
+};
+
+Tensor jagged_to_padded_dense(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const std::vector<int64_t>& max_lengths,
+    const int64_t padding_value) {
+  return JaggedToPaddedDenseGPUOp::apply(
+      values, offsets, max_lengths, padding_value)[0];
+}
+
+Tensor
+jagged_2d_to_dense(Tensor values, Tensor offsets, int64_t max_sequence_length) {
+  return jagged_to_padded_dense(
+      values, {offsets}, {max_sequence_length}, /*padding_value=*/0L);
+}
+
 class JaggedDenseAddGPUOp
     : public torch::autograd::Function<JaggedDenseAddGPUOp> {
  public:
@@ -406,7 +465,6 @@ class JaggedDenseAddGPUOp
       const Tensor& y) {
     ctx->save_for_backward(x_offsets);
     ctx->saved_data["x_values_shape"] = x_values.sizes();
-    ctx->saved_data["y_shape"] = y.sizes();
 
     at::cuda::OptionalCUDAGuard device_guard;
     device_guard.set_index(x_values.get_device());
@@ -431,7 +489,6 @@ class JaggedDenseAddGPUOp
       torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
     auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
-    auto y_shape = ctx->saved_data["y_shape"].toIntVector();
     TORCH_CHECK(grad_outputs.size() == 1);
 
     at::cuda::OptionalCUDAGuard device_guard;
@@ -466,6 +523,73 @@ Tensor jagged_dense_elementwise_add(
   return JaggedDenseAddGPUOp::apply(x_values, x_offsets, y)[0];
 }
 
+// Unlike JaggedDenseAddGPUOp, which treats "zeros" as zeros so that adding a
+// dense tensor yields a dense tensor, this operator treats "zeros" as
+// undefined, so the result is a jagged tensor.
+class JaggedDenseAddJaggedOutputGPUOp
+    : public torch::autograd::Function<JaggedDenseAddJaggedOutputGPUOp> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const Tensor& x_values,
+      const std::vector<Tensor>& x_offsets,
+      const Tensor& y) {
+    ctx->save_for_backward(x_offsets);
+    ctx->saved_data["y_shape"] = y.sizes();
+
+    at::cuda::OptionalCUDAGuard device_guard;
+    device_guard.set_index(x_values.get_device());
+
+    Tensor output;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        x_values.scalar_type(), "jagged_dense_add_forward", [&] {
+          output = jagged_dense_elementwise_jagged_output_<scalar_t>(
+              x_values,
+              x_offsets,
+              y,
+              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                return x + y;
+              });
+        });
+
+    return {output};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    auto offsets = ctx->get_saved_variables();
+    auto y_shape = ctx->saved_data["y_shape"].toIntVector();
+    TORCH_CHECK(grad_outputs.size() == 1);
+
+    at::cuda::OptionalCUDAGuard device_guard;
+    device_guard.set_index(grad_outputs[0].get_device());
+
+    Tensor y_values_grad = jagged_to_padded_dense(
+        grad_outputs[0],
+        offsets,
+        std::vector<int64_t>(y_shape.begin() + 1, y_shape.end() - 1),
+        /*padding_value=*/0);
+    TORCH_CHECK(y_values_grad.sizes() == y_shape);
+
+    return {
+        grad_outputs[0],
+        torch::autograd::Variable(), // x_offsets
+        y_values_grad};
+  }
+};
+
+// output = x + y where x is jagged, y is dense, and output is jagged
+std::tuple<Tensor, std::vector<Tensor>>
+jagged_dense_elementwise_add_jagged_output(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y) {
+  return {
+      JaggedDenseAddJaggedOutputGPUOp::apply(x_values, x_offsets, y)[0],
+      x_offsets};
+}
+
 /**
  * output = f(x, y) where x and y are jagged (and share x_offsets), and output
  * is dense.
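For context on what the new jagged + dense -> jagged operator computes, here is a minimal pure-PyTorch reference for the single-jagged-dimension, 2D-values case. It is a sketch for illustration only, not the fused kernel added above; the helper name and the handling of rows longer than `y`'s padded length are assumptions.

```python
import torch

def jagged_dense_add_jagged_output_ref(x_values, x_offsets, y):
    """Reference semantics: x_values is [total_L, D] jagged values,
    x_offsets is [B + 1] row offsets, y is a dense [B, max_L, D] tensor.
    Returns jagged values plus the unchanged offsets."""
    out = x_values.clone()
    for b in range(y.size(0)):
        start, end = int(x_offsets[b]), int(x_offsets[b + 1])
        n = min(end - start, y.size(1))  # only positions covered by y are added
        out[start:start + n] += y[b, :n]
    return out, x_offsets

# batch of 3 rows with lengths 2, 0, 3; D = 2
x_values = torch.arange(10, dtype=torch.float32).view(5, 2)
x_offsets = torch.tensor([0, 2, 2, 5])
y = torch.ones(3, 4, 2)
out_values, out_offsets = jagged_dense_add_jagged_output_ref(x_values, x_offsets, y)
# out_values == x_values + 1 row by row; out_offsets is x_offsets, i.e. the
# "zeros" of the jagged tensor stay undefined instead of becoming dense entries.
```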
@@ -924,45 +1048,6 @@ Tensor batched_dense_vec_jagged_2d_mul( } // namespace -Tensor -jagged_2d_to_dense_forward_cuda(Tensor values, Tensor offsets, int32_t max_L) { - TORCH_CHECK(values.dim() == 2); - TORCH_CHECK(offsets.dim() == 1); - TORCH_CHECK(max_L > 0); - - return jagged_to_padded_dense(values, {offsets}, {max_L}, 0); -} - -Tensor jagged_2d_to_dense_backward_cuda( - Tensor grad_padded_values, - Tensor offsets, - int32_t total_L) { - TORCH_CHECK(grad_padded_values.dim() == 3); - TORCH_CHECK(offsets.dim() == 1); - TORCH_CHECK(total_L >= 0); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_padded_values.get_device()); - - int32_t D = grad_padded_values.size(2); - auto grad_values = at::zeros({total_L, D}, grad_padded_values.options()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad_padded_values.scalar_type(), - "jagged_2d_to_dense_backward_kernel", - [&] { - jagged_dense_elementwise_jagged_output_( - grad_values, // dummy not used in the lambda function - {offsets}, - grad_padded_values, - grad_values, - [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { - return y; - }); - }); - - return grad_values; -} - Tensor jagged_1d_to_dense_gpu( Tensor values, Tensor offsets, @@ -1023,10 +1108,11 @@ stacked_jagged_2d_to_dense_forward_cuda( }); offsets_tensor_per_key.push_back(offsets); - padded_values_per_key.push_back(jagged_2d_to_dense_forward_cuda( + padded_values_per_key.push_back(jagged_to_padded_dense( values.slice(0, offset_per_key[t], offset_per_key[t + 1]), - offsets, - max_L)); + {offsets}, + {max_L}, + /*padding_value=*/0L)); } return std::make_tuple(padded_values_per_key, offsets_tensor_per_key); @@ -1128,8 +1214,12 @@ std::vector stacked_jagged_1d_to_dense_gpu( TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { DISPATCH_TO_CUDA( "jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense); + DISPATCH_TO_CUDA("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense); DISPATCH_TO_CUDA( "jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add); + DISPATCH_TO_CUDA( + "jagged_dense_elementwise_add_jagged_output", + fbgemm_gpu::jagged_dense_elementwise_add_jagged_output); DISPATCH_TO_CUDA( "jagged_dense_elementwise_mul", fbgemm_gpu::jagged_dense_elementwise_mul); DISPATCH_TO_CUDA( diff --git a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp index 93484f5b30..f2811a9a81 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp @@ -219,53 +219,6 @@ void jagged_dense_elementwise_dense_output_( #undef INVOKE_KERNEL_WITH_DIM } -Tensor jagged_to_padded_dense( - const Tensor& values, - const std::vector& offsets, - const std::vector& max_lengths, - const int64_t padding_value = 0) { - const size_t num_jagged_dim = offsets.size(); - TORCH_CHECK( - max_lengths.size() == num_jagged_dim, - "max_lengths.size(), ", - max_lengths.size(), - " != num_jagged_dim, ", - num_jagged_dim); - - const Tensor values_canonicalized = values.view( - {values.size(0), - std::accumulate( - values.sizes().begin() + 1, - values.sizes().end(), - 1, - std::multiplies())}); - at::DimVector padded_values_shape({offsets[0].size(0) - 1}); - padded_values_shape.insert( - padded_values_shape.end(), max_lengths.begin(), max_lengths.end()); - if (values.dim() > 1) { - padded_values_shape.push_back(values.size(-1)); - } - Tensor padded_values = at::empty(padded_values_shape, values.options()); - Tensor padded_values_view = - values.dim() == 1 ? 
padded_values.unsqueeze(-1) : padded_values;
-
-  AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half,
-      values.scalar_type(),
-      "jagged_to_padded_dense",
-      [&] {
-        jagged_dense_elementwise_dense_output_<scalar_t>(
-            values_canonicalized,
-            offsets,
-            padded_values_view, // dummy not used in the lambda function
-            padded_values_view,
-            [](scalar_t x, scalar_t /*unused*/) -> scalar_t { return x; },
-            static_cast<scalar_t>(padding_value));
-      });
-
-  return padded_values;
-}
-
 template <typename scalar_t, typename F>
 Tensor jagged_dense_elementwise_dense_output_(
     const Tensor& x_values,
@@ -413,6 +366,105 @@ Tensor jagged_dense_elementwise_jagged_output_(
   return output;
 }
 
+class JaggedToPaddedDenseCPUOp
+    : public torch::autograd::Function<JaggedToPaddedDenseCPUOp> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const Tensor& values,
+      const std::vector<Tensor>& offsets,
+      const std::vector<int64_t>& max_lengths,
+      const int64_t padding_value) {
+    ctx->save_for_backward(offsets);
+    ctx->saved_data["total_L"] = values.size(0);
+
+    const size_t num_jagged_dim = offsets.size();
+    TORCH_CHECK(
+        max_lengths.size() == num_jagged_dim,
+        "max_lengths.size(), ",
+        max_lengths.size(),
+        " != num_jagged_dim, ",
+        num_jagged_dim);
+
+    const Tensor values_canonicalized = values.view(
+        {values.size(0),
+         std::accumulate(
+             values.sizes().begin() + 1,
+             values.sizes().end(),
+             1,
+             std::multiplies<size_t>())});
+    at::DimVector padded_values_shape({offsets[0].size(0) - 1});
+    padded_values_shape.insert(
+        padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
+    if (values.dim() > 1) {
+      padded_values_shape.push_back(values.size(-1));
+    }
+    Tensor padded_values = at::empty(padded_values_shape, values.options());
+    Tensor padded_values_view =
+        values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;
+
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half,
+        values.scalar_type(),
+        "jagged_to_padded_dense",
+        [&] {
+          jagged_dense_elementwise_dense_output_<scalar_t>(
+              values_canonicalized,
+              offsets,
+              padded_values_view, // dummy not used in the lambda function
+              padded_values_view,
+              [](scalar_t x, scalar_t /*unused*/) -> scalar_t { return x; },
+              static_cast<scalar_t>(padding_value));
+        });
+
+    return {padded_values};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    auto offsets = ctx->get_saved_variables();
+    int32_t total_L = ctx->saved_data["total_L"].toInt();
+    TORCH_CHECK(grad_outputs.size() == 1);
+
+    TORCH_CHECK(total_L >= 0);
+    auto grad_padded_values = grad_outputs[0];
+
+    int32_t D = grad_padded_values.size(-1);
+    auto grad_values = at::empty({total_L, D}, grad_padded_values.options());
+
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half,
+        grad_padded_values.scalar_type(),
+        "jagged_2d_to_dense_backward_kernel",
+        [&] {
+          jagged_dense_elementwise_jagged_output_<scalar_t>(
+              grad_values, // dummy not used in the lambda function
+              {offsets},
+              grad_padded_values,
+              grad_values,
+              [](scalar_t /*unused*/, scalar_t y) -> scalar_t { return y; });
+        });
+
+    return {
+        grad_values,
+        torch::autograd::Variable(), // offsets
+        torch::autograd::Variable(), // max_lengths
+        torch::autograd::Variable(), // padding_value
+    };
+  }
+};
+
+// Almost identical copy of jagged_to_padded_dense in jagged_tensor_ops.cu
+Tensor jagged_to_padded_dense(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const std::vector<int64_t>& max_lengths,
+    const int64_t padding_value = 0) {
+  return JaggedToPaddedDenseCPUOp::apply(
+      values, offsets, max_lengths, padding_value)[0];
+}
+
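With jagged_to_padded_dense now routed through an autograd::Function on both backends, the padding op is differentiable end to end. Below is a small Python sketch of the intended usage, mirroring the gradcheck added to sparse_ops_test.py later in this diff; the example shapes are illustrative and assume the fbgemm_gpu extension is importable.

```python
import torch
import fbgemm_gpu  # noqa: F401  -- importing registers the torch.ops.fbgemm operators

# 3 jagged rows with lengths 2, 0, 3 and an embedding dim of 4
values = torch.randn(5, 4, dtype=torch.double, requires_grad=True)
offsets = torch.tensor([0, 2, 2, 5])
max_lengths = [4]  # pad every row to length 4

padded = torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], max_lengths, 0)
assert padded.shape == (3, 4, 4)

# The backward pass now runs through JaggedToPaddedDenseCPUOp::backward above:
# it scatters the padded gradient back into a [total_L, D] values gradient.
padded.sum().backward()
assert values.grad.shape == values.shape

# Same double-precision gradcheck the test file adds below.
torch.autograd.gradcheck(
    torch.ops.fbgemm.jagged_to_padded_dense, (values, [offsets], max_lengths, 0)
)
```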
 class JaggedDenseAddCPUOp
     : public torch::autograd::Function<JaggedDenseAddCPUOp> {
  public:
@@ -423,7 +475,6 @@ class JaggedDenseAddCPUOp
       const Tensor& y) {
     ctx->save_for_backward(x_offsets);
     ctx->saved_data["x_values_shape"] = x_values.sizes();
-    ctx->saved_data["y_shape"] = y.sizes();
 
     Tensor output;
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -442,7 +493,6 @@ class JaggedDenseAddCPUOp
       torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
     auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
-    auto y_shape = ctx->saved_data["y_shape"].toIntVector();
     TORCH_CHECK(grad_outputs.size() == 1);
 
     Tensor x_values_grad = at::empty(x_values_shape, grad_outputs[0].options());
@@ -472,6 +522,65 @@ Tensor jagged_dense_elementwise_add(
   return JaggedDenseAddCPUOp::apply(x_values, x_offsets, y)[0];
 }
 
+// Unlike JaggedDenseAddCPUOp, which treats "zeros" as zeros so that adding a
+// dense tensor yields a dense tensor, this operator treats "zeros" as
+// undefined, so the result is a jagged tensor.
+class JaggedDenseJaggedOutputAddCPUOp
+    : public torch::autograd::Function<JaggedDenseJaggedOutputAddCPUOp> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const Tensor& x_values,
+      const std::vector<Tensor>& x_offsets,
+      const Tensor& y) {
+    ctx->save_for_backward(x_offsets);
+    ctx->saved_data["y_shape"] = y.sizes();
+
+    Tensor output;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        x_values.scalar_type(), "jagged_scalars", [&] {
+          output = jagged_dense_elementwise_jagged_output_<scalar_t>(
+              x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
+                return x + y;
+              });
+        });
+
+    return {output};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    auto offsets = ctx->get_saved_variables();
+    auto y_shape = ctx->saved_data["y_shape"].toIntVector();
+    TORCH_CHECK(grad_outputs.size() == 1);
+
+    Tensor y_values_grad = jagged_to_padded_dense(
+        grad_outputs[0],
+        offsets,
+        std::vector<int64_t>(y_shape.begin() + 1, y_shape.end() - 1),
+        /*padding_value=*/0);
+    TORCH_CHECK(y_values_grad.sizes() == y_shape);
+
+    return {
+        grad_outputs[0],
+        torch::autograd::Variable(), // x_offsets
+        y_values_grad};
+  }
+};
+
+// output = x + y where x is jagged, y is dense, and output is jagged
+std::tuple<Tensor, std::vector<Tensor>>
+jagged_dense_elementwise_add_jagged_output(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y) {
+  return {
+      JaggedDenseJaggedOutputAddCPUOp::apply(x_values, x_offsets, y)[0],
+      x_offsets};
+}
+
 template <
     int NUM_JAGGED_DIM,
     bool NO_INNER_DENSE,
@@ -868,6 +977,8 @@ class BatchedDenseVecJagged2DMulCPUOp
                   a_values_grad.accessor<scalar_t, 2>());
             });
           });
+    } else {
+      v_grad.zero_();
     }
 
     return {
@@ -893,7 +1004,8 @@ jagged_2d_to_dense_forward_cpu(Tensor values, Tensor offsets, int64_t max_L) {
   TORCH_CHECK(offsets.dim() == 1);
   TORCH_CHECK(max_L > 0);
 
-  return jagged_to_padded_dense(values, {offsets}, {max_L});
+  return jagged_to_padded_dense(
+      values, {offsets}, {max_L}, /*padding_value=*/0);
 }
 
 Tensor jagged_1d_to_dense_cpu(
@@ -927,6 +1039,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   // jagged + dense -> dense
   m.def(
       "jagged_dense_elementwise_add(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor");
+  // jagged + dense -> jagged (treat "zeros" in the jagged tensor as unknowns.
+ // output offsets is same as x_offsets) + m.def( + "jagged_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])"); // jagged * dense -> jagged (its offsets is same as x_offsets) m.def( "jagged_dense_elementwise_mul(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])"); @@ -941,6 +1057,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense); DISPATCH_TO_CPU( "jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add); + DISPATCH_TO_CPU( + "jagged_dense_elementwise_add_jagged_output", + fbgemm_gpu::jagged_dense_elementwise_add_jagged_output); DISPATCH_TO_CPU( "jagged_dense_elementwise_mul", fbgemm_gpu::jagged_dense_elementwise_mul); DISPATCH_TO_CPU( diff --git a/fbgemm_gpu/src/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops_gpu.cpp index b3fddade6b..d3f9520311 100644 --- a/fbgemm_gpu/src/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops_gpu.cpp @@ -58,50 +58,6 @@ Tensor lookup_batched_unary_embedding_function( weight, table_offsets, offsets, indices)[0]; } -class Jagged2DToDenseGPUOp - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - Tensor values, - Tensor offsets, - int32_t max_sequence_length) { - int32_t total_L = values.size(0); - ctx->save_for_backward({offsets}); - ctx->saved_data["total_L"] = total_L; - - return { - jagged_2d_to_dense_forward_cuda(values, offsets, max_sequence_length)}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_outputs) { - const auto saved = ctx->get_saved_variables(); - auto savedItr = std::begin(saved); - auto offsets = *savedItr++; - int32_t total_L = ctx->saved_data["total_L"].toInt(); - - using torch::autograd::Variable; - auto grad_padded_values = grad_outputs[0]; - auto grad_values = - jagged_2d_to_dense_backward_cuda(grad_padded_values, offsets, total_L); - return { - grad_values, - Variable(), // offsets - Variable() // max_sequence_length - }; - } -}; - -Tensor jagged_2d_to_dense_gpu( - Tensor values, - Tensor offsets, - int64_t max_sequence_length) { - return Jagged2DToDenseGPUOp::apply( - values, offsets, static_cast(max_sequence_length))[0]; -} - class StackedJagged2DToDenseGPUOp : public torch::autograd::Function { public: @@ -192,7 +148,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { DISPATCH_TO_CUDA( "batched_unary_embeddings", fbgemm_gpu::lookup_batched_unary_embedding_function); - DISPATCH_TO_CUDA("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense_gpu); DISPATCH_TO_CUDA("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense_gpu); DISPATCH_TO_CUDA( "stacked_jagged_1d_to_dense", fbgemm_gpu::stacked_jagged_1d_to_dense_gpu); diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py index 38b4502058..3e33c4298d 100644 --- a/fbgemm_gpu/test/sparse_ops_test.py +++ b/fbgemm_gpu/test/sparse_ops_test.py @@ -1724,12 +1724,21 @@ def test_jagged_to_padded_dense( torch.testing.assert_close(output, output_ref) + torch.autograd.gradcheck( + torch.ops.fbgemm.jagged_to_padded_dense, + ( + x_values.double().requires_grad_(True), + x_offsets, + max_lengths, + ), + ) + # pyre-ignore [56] @given( num_jagged_dim=st.integers(1, 4), outer_dense_size=st.integers(0, 4), inner_dense_size=st.integers(0, 4), - operation=st.sampled_from(["add", "mul"]), + operation=st.sampled_from(["add", "add_jagged_output", "mul"]), 
use_cpu=st.booleans() if gpu_available else st.just(True), ) @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) @@ -1758,6 +1767,21 @@ def test_jagged_elementwise_binary( output = torch.ops.fbgemm.jagged_dense_elementwise_add( x_values, x_offsets, y ) + elif operation == "add_jagged_output": + # from y values, create a jagged tensor and then densify + y_padded = self._to_padded_dense( + y.view(outer_dense_size * np.prod(max_lengths), inner_dense_size), + x_offsets, + max_lengths, + ) + output_ref = x_padded + y_padded + ( + output, + output_offsets, + ) = torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output( + x_values, x_offsets, y_padded + ) + output = self._to_padded_dense(output, output_offsets, max_lengths) elif operation == "mul": output_ref = x_padded * y output, output_offsets = torch.ops.fbgemm.jagged_dense_elementwise_mul(