Skip to content

Commit

Permalink
Add jagged slice op for cpu (#1690)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1690

The context why this is needed is as follows
1) For really long sparse features we want to split them into multiple chunks that can be fed into the model
2) Slicing requires users to require per row start point & a maximum L.

Based on these requirements, a custom op mimicing the slice semantics of a normal tensor works best.

An example usage using pseudo code

```
input_jagged_tensor = [[1, 2, 3, 4], [1, 2, 3], [1, 2, 3, 4, 5, 6], [1], [1, 2]]
start = [0, 0, 0, 0, 0]
slice_length = 3

>> jagged_slice(input_jagged_tensor, start, slice_length)

output_jagged_tensor = [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1], [1, 2]]

```

A corresponding operation for dense tensor would look like
```
dense_tensor = torch.randn((8, 10))
slice_dense_tensor = dense_tensor[:, 1:3]
```

Differential Revision: D44299744

fbshipit-source-id: 2e4f0ff6901b5a919ff90f3b2e811b7ac071ecce
  • Loading branch information
Devashish Tyagi authored and facebook-github-bot committed Apr 17, 2023
1 parent e07dda2 commit 5cb7395
Show file tree
Hide file tree
Showing 5 changed files with 432 additions and 26 deletions.
112 changes: 112 additions & 0 deletions fbgemm_gpu/bench/jagged_tensor_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

import functools
import logging
import random
from typing import List, Tuple

import click
import fbgemm_gpu
import torch
from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -419,5 +421,115 @@ def keyed_jagged_index_select_dim1_ref(
)


@cli.command()
@click.option("--max-seq-length", type=int, default=400)
@click.option("--input-batch-size", type=int, default=1024)
@click.option("--slice-length", type=int, default=10)
@click.option("--jagged-tensor-type", type=str, default="float")
def jagged_slice_cpu(
max_seq_length: int,
input_batch_size: int,
slice_length: int,
jagged_tensor_type: str,
) -> None:
jagged_tensor_types = {
"float": torch.float,
"half": torch.half,
"int": torch.int,
"long": torch.long,
}

if jagged_tensor_type not in jagged_tensor_types.keys():
raise AssertionError(
f"--jagged-tensor-type ({jagged_tensor_type}) is not supported"
)

jagged_tensor_dtype = jagged_tensor_types[jagged_tensor_type]
is_float = jagged_tensor_dtype in [torch.float, torch.half]

lengths = torch.randint(
low=0,
high=max_seq_length,
size=(input_batch_size,),
dtype=torch.long,
)
start_list = [random.randint(0, max(len_ - 1, 0)) for len_ in lengths.tolist()]
start = torch.tensor(start_list)

offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
if is_float:
values = torch.rand(
int(offsets[-1].item()),
dtype=jagged_tensor_dtype,
)
else:
values = torch.randint(
2**16,
(int(offsets[-1].item()),),
dtype=jagged_tensor_dtype,
)

time, output = benchmark_torch_function(
torch.ops.fbgemm.jagged_slice,
(values, lengths, start, slice_length),
iters=1000,
)

def jagged_slice_ref(
x_values: torch.Tensor,
offsets: torch.Tensor,
start: torch.Tensor,
max_L: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
end_offsets_ = max_L + start + offsets[:-1]
end_offsets = torch.where(end_offsets_ > offsets[1:], offsets[1:], end_offsets_)
start_offsets = start + offsets[:-1]
indices_to_select: List[torch.Tensor] = []
for i in range(end_offsets.size(0)):
indices_to_select.append(
torch.arange(start_offsets[i].item(), end_offsets[i].item())
)
output_ref = torch.index_select(x_values, 0, torch.cat(indices_to_select))
new_lengths = end_offsets - start_offsets
new_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(new_lengths)
return output_ref, new_offsets

time_ref, output = benchmark_torch_function(
jagged_slice_ref, (values, offsets, start, slice_length)
)

logging.info(f"jagged_slice forward time: {time * 1e3} ms, ref {time_ref * 1e3} ms")

profiler = profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=200,
warmup=100,
active=100,
),
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
)

profiler.start()
for _ in range(500):
torch.ops.fbgemm.jagged_slice(values, lengths, start, slice_length)
profiler.step()
profiler.stop()

logging.info(
"\n"
+ profiler.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)
)

flops = sum(e.flops for e in profiler.events())
logging.info(f"Total Compute: {flops / 1e9} gflops")


if __name__ == "__main__":
cli()
15 changes: 15 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -752,5 +752,20 @@ at::Tensor jagged_index_add_2d_forward_cpu(
const int64_t num_dense_grad_rows,
const int64_t num_output_rows);

std::tuple<at::Tensor, at::Tensor> jagged_slice(
const at::Tensor& x_values,
const at::Tensor& x_lengths,
const at::Tensor& start,
const int64_t max_L);

at::Tensor jagged_slice_forward_cpu(
const at::Tensor& x_values,
const at::Tensor& x_lengths,
const at::Tensor& src_start,
const at::Tensor& output_lengths,
const at::Tensor& tgt_start,
const int64_t num_output_rows,
const int64_t max_L,
const bool fill_zeros);
#endif
} // namespace fbgemm_gpu
109 changes: 109 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <torch/csrc/autograd/custom_function.h>
#include <torch/library.h>

#include "ATen/TensorUtils.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

Expand Down Expand Up @@ -559,6 +560,103 @@ class JaggedIndexSelect2dOp
}
};

class JaggedSliceOp : public torch::autograd::Function<JaggedSliceOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& values,
const Tensor& lengths,
const Tensor& start,
const int64_t slice_length) {
TENSOR_NDIM_EQUALS(values, 1);
TENSORS_ON_SAME_DEVICE(values, lengths);
TORCH_CHECK_TENSOR_ALL(start <= lengths, "start should be <= len");
TORCH_CHECK_TENSOR_ALL(start >= 0, "start should be always be positive");

Tensor output_lengths = (lengths - start).clamp_max(std::abs(slice_length));
// D2H sync here
const int64_t num_output_rows = output_lengths.sum().item<int64_t>();
const int64_t num_input_rows = lengths.sum().item<int64_t>();

Tensor tgt_start = at::zeros_like(lengths);

ctx->save_for_backward({lengths, output_lengths, start, tgt_start});
ctx->saved_data["num_output_rows"] = num_input_rows;
ctx->saved_data["slice_length"] = slice_length;

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_slice_forward", "")
.typed<at::Tensor(
const Tensor& values,
const Tensor& lengths,
const Tensor& src_start,
const Tensor& output_lengths,
const Tensor& tgt_start,
const int64_t num_output_rows,
const int64_t max_L,
const bool fill_zeros)>();

return {
op.call(
values,
lengths,
start,
output_lengths,
tgt_start,
num_output_rows,
slice_length,
false),
output_lengths,
};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
TORCH_CHECK(grad_outputs.size() == 2);

const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
Tensor output_lengths = *savedItr++;
Tensor grad_lengths = *savedItr++;
Tensor tgt_start = *savedItr++;
Tensor src_start = *savedItr++;
Tensor grad = grad_outputs[0];

TENSORS_ON_SAME_DEVICE(grad, output_lengths);

const int64_t num_output_rows = ctx->saved_data["num_output_rows"].toInt();
const int64_t slice_length = ctx->saved_data["slice_length"].toInt();

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_slice_forward", "")
.typed<at::Tensor(
const Tensor& values,
const Tensor& lengths,
const Tensor& src_start,
const Tensor& output_lengths,
const Tensor& tgt_start,
const int64_t num_output_rows,
const int64_t slice_length,
const bool fill_zeros)>();

return {
op.call(
grad,
grad_lengths,
src_start,
output_lengths,
tgt_start,
num_output_rows,
slice_length,
true),
torch::autograd::Variable(), // lengths
torch::autograd::Variable(), // start
torch::autograd::Variable() // max_L
};
}
};

} // namespace

///@ingroup jagged-tensor-ops-cpu
Expand Down Expand Up @@ -719,6 +817,16 @@ std::vector<Tensor> jagged_index_select_2d(
return JaggedIndexSelect2dOp::apply(values, lengths, indices);
}

std::tuple<Tensor, Tensor> jagged_slice(
const Tensor& values,
const Tensor& lengths,
const Tensor& start,
const int64_t slice_length) {
const auto output =
JaggedSliceOp::apply(values, lengths, start, slice_length);
return {output[0], output[1]};
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
Expand All @@ -740,4 +848,5 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm));
m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm));
m.impl("jagged_index_select", TORCH_FN(fbgemm_gpu::jagged_index_select_2d));
m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice));
}
95 changes: 95 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,95 @@ Tensor jagged_dense_bmm_forward(
return output;
}

template <typename scalar_t, typename offset_t>
void jagged_slice_forward_cpu_kernel(
at::TensorAccessor<scalar_t, 1> output,
const at::TensorAccessor<offset_t, 1>& output_lengths,
const at::TensorAccessor<offset_t, 1>& output_offsets,
const at::TensorAccessor<offset_t, 1>& tgt_start,
const at::TensorAccessor<scalar_t, 1>& input,
const at::TensorAccessor<offset_t, 1>& input_lengths,
const at::TensorAccessor<offset_t, 1>& input_offsets,
const at::TensorAccessor<offset_t, 1>& src_start,
const int64_t slice_length) {
const auto B = output_offsets.size(0);

// TODO (devashisht) parallelize this loop
for (const auto row_i : c10::irange(B)) {
const int64_t output_offset_start = output_offsets[row_i];
const int64_t input_offset_start = input_offsets[row_i];
for (const auto col_i : c10::irange(slice_length)) {
if (tgt_start[row_i] + col_i >= output_lengths[row_i] &&
src_start[row_i] + col_i >= input_lengths[row_i]) {
break;
}
const int64_t output_offset =
output_offset_start + tgt_start[row_i] + col_i;
const int64_t input_offset =
input_offset_start + src_start[row_i] + col_i;
output[output_offset] = input[input_offset];
}
}
}

/// Slice the jagged dim to max length from slice_length,
/// from start point `start`. This is a jagged -> jagged op
/// @param x_values - X values of shape B * J_DIM where J_DIM is
/// jagged dim
/// @param x_lengths - length along jagged dim
/// @param src_start - start of slice operation from the src tensor
/// @param output_lengths - length of jagged dim for output tensor
/// @param tgt_start - position to start filling in sliced values from source
/// @param num_output_rows - output dense dim
/// @param slice_length - length of jagged dim to slice
/// @param fill_zeros - option exists as an optimization, we can reuse
/// the same code path for forward & backward. For backward
/// we need to fill zeros in output tensor but fwd we don't.
Tensor jagged_slice_forward_cpu(
const Tensor& x_values,
const Tensor& x_lengths,
const Tensor& src_start,
const Tensor& output_lengths,
const Tensor& tgt_start,
const int64_t num_output_rows,
const int64_t slice_length,
const bool fill_zeros) {
TENSOR_ON_CPU(x_values);
TENSOR_ON_CPU(x_lengths);
TENSOR_NDIM_EQUALS(x_values, 1);
TENSOR_NDIM_EQUALS(x_lengths, 1);

auto output_values = [fill_zeros, num_output_rows, &x_values]() -> Tensor {
if (fill_zeros) {
return at::zeros({num_output_rows}, x_values.options());
} else {
return at::empty({num_output_rows}, x_values.options());
}
}();
auto output_offsets = asynchronous_exclusive_cumsum_cpu(output_lengths);
auto input_offsets = asynchronous_exclusive_cumsum_cpu(x_lengths);

AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
x_values.scalar_type(),
"jagged_slice_wrapper_1",
[&] {
jagged_slice_forward_cpu_kernel<scalar_t>(
output_values.accessor<scalar_t, 1>(),
output_lengths.accessor<int64_t, 1>(),
output_offsets.accessor<int64_t, 1>(),
tgt_start.accessor<int64_t, 1>(),
x_values.accessor<scalar_t, 1>(),
x_lengths.accessor<int64_t, 1>(),
input_offsets.accessor<int64_t, 1>(),
src_start.accessor<int64_t, 1>(),
slice_length);
});

return output_values;
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
Expand Down Expand Up @@ -1555,6 +1644,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"jagged_dense_bmm(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> (Tensor, Tensor)");
m.def(
"jagged_dense_bmm_forward(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> Tensor");
// jagged -> jagged
m.def(
"jagged_slice(Tensor x_values, Tensor x_lengths, Tensor start, int slice_length) -> (Tensor, Tensor)");
m.def(
"jagged_slice_forward(Tensor x_values, Tensor x_lengths, Tensor src_start, Tensor output_lengths, Tensor tgt_start, int num_output_rows, int slice_length, bool fill_zeros) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
Expand Down Expand Up @@ -1623,4 +1717,5 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU("jagged_dense_bmm", fbgemm_gpu::jagged_dense_bmm);
DISPATCH_TO_CPU(
"jagged_dense_bmm_forward", fbgemm_gpu::jagged_dense_bmm_forward);
DISPATCH_TO_CPU("jagged_slice_forward", fbgemm_gpu::jagged_slice_forward_cpu);
}
Loading

0 comments on commit 5cb7395

Please sign in to comment.