Merge branch 'main' into malfet-patch-2
malfet authored Oct 16, 2023
2 parents ce122a3 + 924f310, commit 83f0f43
Showing 3 changed files with 10 additions and 7 deletions.
fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h (3 changes: 3 additions & 0 deletions)
@@ -45,6 +45,9 @@ at::Tensor asynchronous_inclusive_cumsum_cpu(const at::Tensor& t_in);
 ///@ingroup sparse-data-cuda
 at::Tensor asynchronous_complete_cumsum_meta(const at::Tensor& t_in);
 
+///@ingroup sparse-data-cuda
+at::Tensor asynchronous_exclusive_cumsum_meta(const at::Tensor& t_in);
+
 ///@ingroup sparse-data-cuda
 at::Tensor offsets_range_cuda(const at::Tensor& offsets, int64_t range_size);
 
fbgemm_gpu/src/sparse_ops/sparse_ops_meta.cpp (8 changes: 4 additions & 4 deletions)
@@ -31,6 +31,10 @@ Tensor asynchronous_complete_cumsum_meta(const Tensor& t_in) {
   return output;
 }
 
+Tensor asynchronous_exclusive_cumsum_meta(const Tensor& t_in) {
+  return at::zeros_symint(t_in.sym_sizes(), t_in.options());
+}
+
 namespace {
 
 Tensor pack_segments_forward_meta(
@@ -77,10 +81,6 @@ Tensor asynchronous_inclusive_cumsum_meta(const Tensor& t_in) {
   return at::empty_symint(t_in.sym_sizes(), t_in.options());
 }
 
-Tensor asynchronous_exclusive_cumsum_meta(const Tensor& t_in) {
-  return at::empty_symint(t_in.sym_sizes(), t_in.options());
-}
-
 } // namespace
 
 } // namespace fbgemm_gpu
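
Note on the change above: a meta kernel computes only output metadata (shape, dtype, device) and never touches tensor data, so registering asynchronous_exclusive_cumsum_meta lets the operator be called on "meta" tensors for shape inference and tracing. A minimal sketch of that usage, assuming fbgemm_gpu is installed so that importing it loads the fbgemm operator library (the test change below exercises the same path):

import torch
import fbgemm_gpu  # noqa: F401  -- importing registers the torch.ops.fbgemm operators

# A "meta" tensor carries only shape/dtype metadata; no storage is allocated.
x = torch.randint(low=0, high=100, size=(8,)).to("meta")

# With the meta kernel registered, this returns a meta tensor of matching
# shape instead of failing with a missing-backend error.
y = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)
print(y.size(), y.device)  # torch.Size([8]) meta
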
fbgemm_gpu/test/sparse_ops_test.py (6 changes: 3 additions & 3 deletions)
@@ -610,11 +610,11 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
 
         # meta tests
         mx = torch.randint(low=0, high=100, size=(n,)).type(index_dtype).to("meta")
-        # mze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(mx)
+        mze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(mx)
+        self.assertEqual(ze.size(), mze.size())
         # mzi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(mx)
-        mzc = torch.ops.fbgemm.asynchronous_complete_cumsum(mx)
-        # self.assertEqual(ze.size(), mze.size())
         # self.assertEqual(zi.size(), mzi.size())
+        mzc = torch.ops.fbgemm.asynchronous_complete_cumsum(mx)
         self.assertEqual(zc.size(), mzc.size())
 
         if gpu_available:
