create pack_segments_v2 with additional pad_minf and presence_mask functionality

Summary: As titled. Since we need to change the function signature, it is better to have a separate API.

Differential Revision: D66619431
brad-mengchi authored and facebook-github-bot committed Dec 1, 2024
1 parent 6eb379a commit 5e14515
Showing 6 changed files with 480 additions and 5 deletions.
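Before the file-by-file diff, a minimal caller sketch (an editor's illustration, not part of the commit; the wrapper name call_pack_segments_v2 is hypothetical): it reaches the new op through the dispatcher, mirroring the pattern PackSegmentsV2::forward uses below, with the schema and argument order taken from the m.def registration in sparse_ops_cpu.cpp.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <optional>
#include <tuple>

// Hypothetical call site: v1 returns a single zero-padded tensor, while v2
// also takes pad_minf / return_presence_mask and returns an optional
// presence mask alongside the packed tensor.
std::tuple<at::Tensor, std::optional<at::Tensor>> call_pack_segments_v2(
    const at::Tensor& t_in,
    const at::Tensor& lengths,
    int64_t max_length) {
  static auto op =
      at::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::pack_segments_v2", "")
          .typed<std::tuple<at::Tensor, std::optional<at::Tensor>>(
              const at::Tensor&,
              const at::Tensor&,
              const at::SymInt,
              const bool,
              const bool)>();
  // pad_minf=true pads with -infinity instead of zeros (useful ahead of a
  // max reduction); return_presence_mask=true additionally returns a
  // [batch_size, max_length] bool mask of the valid positions.
  return op.call(
      t_in,
      lengths,
      max_length,
      /*pad_minf=*/true,
      /*return_presence_mask=*/true);
}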
23 changes: 23 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -903,6 +903,29 @@ at::Tensor pack_segments_forward_cuda(
const at::Tensor& lengths,
int64_t max_length);

///@ingroup sparse-data-cpu
std::tuple<at::Tensor, std::optional<at::Tensor>> pack_segments_cpu_v2(
const at::Tensor& t_in,
const at::Tensor& lengths,
int64_t max_length,
const bool pad_minf,
const bool return_presence_mask);

///@ingroup sparse-data-cuda
std::tuple<at::Tensor, std::optional<at::Tensor>> pack_segments_cuda_v2(
const at::Tensor& t_in,
const at::Tensor& lengths,
int64_t max_length,
const bool pad_minf,
const bool return_presence_mask);

std::tuple<at::Tensor, std::optional<at::Tensor>> pack_segments_forward_cuda_v2(
const at::Tensor& t_in,
const at::Tensor& lengths,
int64_t max_length,
const bool pad_minf,
const bool return_presence_mask);

at::Tensor pack_segments_backward_cuda(
const at::Tensor& data,
const at::Tensor& lengths,
186 changes: 186 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -121,6 +121,75 @@ class PackSegments : public torch::autograd::Function<PackSegments> {
}
};

class PackSegmentsV2 : public torch::autograd::Function<PackSegmentsV2> {
public:
static constexpr bool is_traceable = true;
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& t_in,
const Tensor& lengths,
at::SymInt max_length,
const bool pad_minf,
const bool return_presence_mask) {
const at::SymInt total_length = t_in.sym_size(0);

at::AutoDispatchBelowADInplaceOrView guard;

static auto custom_pack_segments_op =
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::pack_segments_v2", "")
.typed<std::tuple<Tensor, std::optional<Tensor>>(
const at::Tensor&,
const at::Tensor&,
const at::SymInt,
const bool,
const bool)>();

const auto& res = custom_pack_segments_op.call(
t_in, lengths, max_length, pad_minf, return_presence_mask);

ctx->saved_data["max_length"] = max_length;
ctx->saved_data["total_length"] = total_length;
ctx->save_for_backward({lengths});

int num_outputs = return_presence_mask ? 2 : 1;
torch::autograd::variable_list outputs(num_outputs);
if (return_presence_mask) {
outputs[1] = std::get<1>(res).value();
}
outputs[0] = std::get<0>(res);
return outputs;
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(grad_output.size() == 2 || grad_output.size() == 1);
const Tensor& grad = grad_output[0];
const auto& max_length = ctx->saved_data["max_length"].toSymInt();
const auto& total_length = ctx->saved_data["total_length"].toSymInt();

// Retrieve saved variables for backward.
const auto& saved_variables = ctx->get_saved_variables();
const auto& lengths = saved_variables[0];

// One gradient slot per forward input (t_in, lengths, max_length, pad_minf,
// return_presence_mask); only t_in is differentiable, so only slot 0 is set.
torch::autograd::variable_list grad_inputs(5);

static auto custom_pack_segments_backward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::pack_segments_backward", "")
.typed<at::Tensor(
const at::Tensor&,
const at::Tensor&,
const at::SymInt,
const at::SymInt)>();

grad_inputs[0] = custom_pack_segments_backward_op.call(
grad, lengths, total_length, max_length);
return grad_inputs;
}
};

Tensor pack_segments_autograd(
const Tensor& t_in,
const Tensor& lengths,
@@ -130,6 +199,21 @@ Tensor pack_segments_autograd(
return PackSegments::apply(t_in, lengths, max_length)[0];
}

std::tuple<Tensor, std::optional<Tensor>> pack_segments_autograd_v2(
const Tensor& t_in,
const Tensor& lengths,
const at::SymInt max_length,
const bool pad_minf,
const bool return_presence_mask) {
const auto& res = PackSegmentsV2::apply(
t_in, lengths, max_length, pad_minf, return_presence_mask);
std::optional<Tensor> presence_mask;
if (return_presence_mask) {
presence_mask = res[1];
}
return {res[0], presence_mask};
}

template <typename T>
void prefix_sum(const int length, const T* const array, T* const presum) {
presum[0] = 0;
@@ -2804,6 +2888,87 @@ Tensor pack_segments_forward_cpu(
return packed_tensor;
}

std::tuple<Tensor, std::optional<Tensor>> pack_segments_forward_cpu_v2(
const Tensor& t_in,
const Tensor& lengths,
const int64_t max_length,
const bool pad_minf,
const bool return_presence_mask) {
TENSOR_NDIM_IS_GE(t_in, 1);
TENSOR_NDIM_EQUALS(lengths, 1);
TORCH_CHECK(
t_in.dtype() == at::ScalarType::Float ||
t_in.dtype() == at::ScalarType::Half ||
t_in.dtype() == at::ScalarType::BFloat16 ||
t_in.dtype() == at::ScalarType::Int ||
t_in.dtype() == at::ScalarType::Long,
"t_in must be of type float, half, bfloat16, int or long");
TORCH_CHECK_GT(max_length, 0);

const auto t_in_cont = t_in.expect_contiguous();
Tensor packed_tensor;
std::optional<Tensor> presence_mask;

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "pack_segments_cpu", ([&]() {
const auto* const lengths_data = lengths.data_ptr<index_t>();

// Shape of output is batch_size x max_len x ...
auto shape = t_in_cont->sizes().vec(); // Get copy of current shape
shape[0] = max_length; // Set first element to max_len
shape.insert(
shape.begin(), lengths.numel()); // Insert batch size at beginning
if (pad_minf) {
// Downcasting double infinity to float should still give infinity
packed_tensor = at::full(
shape,
-std::numeric_limits<double>::infinity(),
t_in_cont->options());
} else {
packed_tensor = at::zeros(shape, t_in_cont->options());
}

bool* presence_mask_data = nullptr;
if (return_presence_mask) {
// Shape of presence is batch_size x max_len
presence_mask = at::zeros({lengths.numel(), max_length}, at::kBool);
presence_mask_data = presence_mask->data_ptr<bool>();
}

if (t_in_cont->sizes()[0] == 0) {
return; // Return empty output (with the proper shape)
}

FBGEMM_DISPATCH_ALL_TYPES(
t_in_cont->scalar_type(), "pack_segments_cpu-packing", ([&]() {
const auto sizes =
t_in_cont->sizes().slice(1, t_in_cont->sizes().size() - 1);
const auto block_size = c10::multiply_integers(sizes);
const auto block_bytesize = t_in_cont->itemsize() * block_size;
const auto* const data_ptr = t_in_cont->data_ptr<scalar_t>();
auto* const out_data = packed_tensor.data_ptr<scalar_t>();
int64_t start = 0;
for (const auto i : c10::irange(lengths.sizes()[0])) {
const auto len =
std::min(static_cast<int64_t>(lengths_data[i]), max_length);
std::memcpy(
out_data + block_size * max_length * i, // dst
data_ptr + block_size * start, // src
len * block_bytesize);
if (return_presence_mask) {
std::fill(
presence_mask_data + max_length * i,
presence_mask_data + max_length * i + len,
true);
}
start += lengths_data[i];
}
}));
}));

return {packed_tensor, presence_mask};
}
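To make the packing semantics above concrete, here is a standalone reference sketch (an editor's illustration in plain C++, independent of ATen and not part of this commit) that mirrors what the CPU kernel does for rows of block_size floats, including pad_minf padding and the presence mask:

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Reference semantics of pack_segments_v2: `packed` is laid out as
// batch_size x max_length x block_size, padded with zeros (or -inf when
// pad_minf is set); `presence_mask` is batch_size x max_length and marks
// which output rows hold real data.
void pack_segments_ref(
    const std::vector<float>& t_in,       // total_length * block_size values
    const std::vector<int64_t>& lengths,  // one entry per segment
    int64_t max_length,
    int64_t block_size,
    bool pad_minf,
    std::vector<float>& packed,
    std::vector<bool>& presence_mask) {
  const int64_t batch_size = static_cast<int64_t>(lengths.size());
  const float pad = pad_minf ? -std::numeric_limits<float>::infinity() : 0.0f;
  packed.assign(batch_size * max_length * block_size, pad);
  presence_mask.assign(batch_size * max_length, false);
  int64_t start = 0;  // running row offset into t_in, as in the kernel above
  for (int64_t i = 0; i < batch_size; ++i) {
    const int64_t len = std::min(lengths[i], max_length);
    std::copy(
        t_in.begin() + start * block_size,
        t_in.begin() + (start + len) * block_size,
        packed.begin() + i * max_length * block_size);
    std::fill_n(presence_mask.begin() + i * max_length, len, true);
    start += lengths[i];  // rows beyond max_length are truncated
  }
}

For example, with lengths = {2, 1}, max_length = 2, block_size = 2, and t_in = {1, 1, 2, 2, 3, 3}, packed comes out as {1, 1, 2, 2, 3, 3, pad, pad} and presence_mask as {true, true, true, false}.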

/// Map N+1 dim tensor to N dim based on lengths tensor
/// Sequences that are shorter than the longest sequence are padded with zeros.
/// @param data N+1 dim Tensor.
@@ -2877,13 +3042,29 @@ Tensor pack_segments_backward_cpu(

return unpacked_tensor;
}

Tensor pack_segments_cpu(
const Tensor& t_in,
const Tensor& lengths,
const int64_t max_length) {
return pack_segments_forward_cpu(t_in, lengths, max_length);
}

std::tuple<Tensor, std::optional<Tensor>> pack_segments_cpu_v2(
const Tensor& t_in,
const Tensor& lengths,
const int64_t max_length,
const bool pad_minf,
const bool return_presence_mask // https://fburl.com/code/ol14vkbn
) {
const auto& res = pack_segments_forward_cpu_v2(
t_in, lengths, max_length, pad_minf, return_presence_mask);
if (return_presence_mask) {
return std::make_tuple(std::get<0>(res), std::get<1>(res));
}
return std::make_tuple(std::get<0>(res), std::nullopt);
}

torch::autograd::variable_list group_index_select_dim0_autograd_impl(
at::TensorList all_indices_input,
const int64_t group_size) {
@@ -3263,6 +3444,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"pack_segments(Tensor t_in, Tensor lengths, SymInt max_length) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"pack_segments_v2(Tensor t_in, Tensor lengths, SymInt max_length, bool pad_minf=False, bool return_presence_mask=False) -> (Tensor packed_segments, Tensor? presence_mask)",
{PT2_COMPLIANT_TAG});
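// Note: with the defaults (pad_minf=False, return_presence_mask=False),
// pack_segments_v2 produces the same zero-padded packing as pack_segments
// and returns (packed_segments, None).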
m.def(
"pack_segments_backward(Tensor data, Tensor lengths, SymInt total_length, SymInt max_length) -> Tensor");
// A specialization of at::index_select for selecting dim 0
@@ -3371,6 +3555,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"permute_sequence_embeddings",
fbgemm_gpu::permute_sequence_embeddings_cpu);
DISPATCH_TO_CPU("pack_segments", fbgemm_gpu::pack_segments_cpu);
DISPATCH_TO_CPU("pack_segments_v2", fbgemm_gpu::pack_segments_cpu_v2);
DISPATCH_TO_CPU(
"pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu);
DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
@@ -3387,6 +3572,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {

TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl("pack_segments", &fbgemm_gpu::pack_segments_autograd);
m.impl("pack_segments_v2", &fbgemm_gpu::pack_segments_autograd_v2);
m.impl(
"group_index_select_dim0_gpu_impl",
&fbgemm_gpu::group_index_select_dim0_autograd_impl);
12 changes: 12 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -543,6 +543,16 @@ Tensor pack_segments_cuda(
return fbgemm_gpu::pack_segments_forward_cuda(t_in, lengths, max_length)[0];
}

std::tuple<Tensor, std::optional<Tensor>> pack_segments_cuda_v2(
const Tensor& t_in,
const Tensor& lengths,
const int64_t max_length,
const bool pad_minf,
const bool return_presence_mask) {
return fbgemm_gpu::pack_segments_forward_cuda_v2(
t_in, lengths, max_length, pad_minf, return_presence_mask);
}

Tensor index_select_dim0_gpu(
const Tensor& input,
const Tensor& indices,
@@ -583,6 +593,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
"generic_histogram_binning_calibration_by_feature",
fbgemm_gpu::generic_histogram_binning_calibration_by_feature_cuda);
DISPATCH_TO_CUDA("pack_segments", fbgemm_gpu::pack_segments_forward_cuda);
DISPATCH_TO_CUDA(
"pack_segments_v2", fbgemm_gpu::pack_segments_forward_cuda_v2);
DISPATCH_TO_CUDA(
"pack_segments_backward", fbgemm_gpu::pack_segments_backward_cuda);
DISPATCH_TO_CUDA("index_select_dim0", fbgemm_gpu::index_select_dim0_gpu);
17 changes: 17 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_meta.cpp
@@ -48,6 +48,21 @@ Tensor pack_segments_forward_meta(
return at::empty_symint(padded_values_shape, t_in.options());
}

std::tuple<Tensor, std::optional<Tensor>> pack_segments_forward_meta_v2(
const Tensor& t_in,
const Tensor& lengths,
const at::SymInt max_length,
const bool pad_minf,
const bool return_presence_mask) {
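// Meta implementation: computes only the packed tensor's symbolic shape for
// tracing; the optional presence mask is not materialized here.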
at::SymDimVector padded_values_shape({lengths.sym_numel(), max_length});

for (const auto i : c10::irange(1, t_in.dim())) {
padded_values_shape.push_back(t_in.sym_size(i));
}
return std::make_tuple(
at::empty_symint(padded_values_shape, t_in.options()), std::nullopt);
}

Tensor pack_segments_backward_meta(
const at::Tensor& data,
const at::Tensor& lengths,
@@ -86,6 +101,8 @@ Tensor asynchronous_inclusive_cumsum_meta(const Tensor& t_in) {

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("pack_segments", TORCH_FN(fbgemm_gpu::pack_segments_forward_meta));
m.impl(
"pack_segments_v2", TORCH_FN(fbgemm_gpu::pack_segments_forward_meta_v2));
m.impl(
"pack_segments_backward",
TORCH_FN(fbgemm_gpu::pack_segments_backward_meta));