Add jagged slice op for cpu (pytorch#1690)
Summary:
Pull Request resolved: pytorch#1690

The context for why this is needed is as follows:
1) For very long sparse features we want to split them into multiple chunks that can be fed into the model.
2) Slicing requires the user to provide a per-row start point and a maximum length L.

Based on these requirements, a custom op that mimics the slice semantics of a normal tensor works best.

Example usage, in pseudocode:

```
input_jagged_tensor = [[1, 2, 3, 4], [1, 2, 3], [1, 2, 3, 4, 5, 6], [1], [1, 2]]
start = [0, 0, 0, 0, 0]
slice_length = 3

>> jagged_slice(input_jagged_tensor, start, slice_length)

output_jagged_tensor = [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1], [1, 2]]

```

The corresponding operation on a dense tensor would look like:
```
dense_tensor = torch.randn((8, 10))
slice_dense_tensor = dense_tensor[:, 1:3]
```
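
For a concrete sketch of how this maps onto the new op: the jagged tensor is passed as a flat 1-D values tensor plus a per-row lengths tensor, and `torch.ops.fbgemm.jagged_slice` returns the sliced values together with the new per-row lengths. The snippet below is a minimal sketch assuming an fbgemm_gpu build that includes this change; the tensor contents mirror the pseudocode above.
```
import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm.* operators

# [[1, 2, 3, 4], [1, 2, 3], [1, 2, 3, 4, 5, 6], [1], [1, 2]] as values + lengths
values = torch.tensor(
    [1, 2, 3, 4, 1, 2, 3, 1, 2, 3, 4, 5, 6, 1, 1, 2], dtype=torch.float
)
lengths = torch.tensor([4, 3, 6, 1, 2])
start = torch.zeros_like(lengths)
slice_length = 3

out_values, out_lengths = torch.ops.fbgemm.jagged_slice(
    values, lengths, start, slice_length
)
# Expected: out_lengths == [3, 3, 3, 1, 2] and
# out_values == [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 2]
```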

Differential Revision: D44299744

fbshipit-source-id: c6988370060bff6dda3d162d7d78a9d2e5b7bb65
Devashish Tyagi authored and facebook-github-bot committed Apr 7, 2023
1 parent 3eaae90 commit debba83
Showing 5 changed files with 434 additions and 26 deletions.
114 changes: 114 additions & 0 deletions fbgemm_gpu/bench/jagged_tensor_benchmark.py
@@ -5,11 +5,13 @@

import functools
import logging
import random
from typing import List, Tuple

import click
import fbgemm_gpu
import torch
from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)

@@ -419,5 +421,117 @@ def keyed_jagged_index_select_dim1_ref(
)


@cli.command()
@click.option("--max-seq-length", type=int, default=400)
@click.option("--input-batch-size", type=int, default=1024)
@click.option("--slice-length", type=int, default=10)
@click.option("--jagged-tensor-type", type=str, default="float")
def jagged_slice_cpu(
    max_seq_length: int,
    input_batch_size: int,
    slice_length: int,
    jagged_tensor_type: str,
) -> None:
    jagged_tensor_types = {
        "float": torch.float,
        "half": torch.half,
        "int": torch.int,
        "long": torch.long,
    }

    if jagged_tensor_type not in jagged_tensor_types.keys():
        raise AssertionError(
            f"--jagged-tensor-type ({jagged_tensor_type}) is not supported"
        )

    jagged_tensor_dtype = jagged_tensor_types[jagged_tensor_type]
    is_float = jagged_tensor_dtype in [torch.float, torch.half]

    lengths = torch.randint(
        low=0,
        high=max_seq_length,
        size=(input_batch_size,),
        dtype=torch.long,
    )
    start_list = [random.randint(0, max(len_ - 1, 0)) for len_ in lengths.tolist()]
    start = torch.tensor(start_list)

    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    if is_float:
        values = torch.rand(
            int(offsets[-1].item()),
            dtype=jagged_tensor_dtype,
        )
    else:
        values = torch.randint(
            2**16,
            (int(offsets[-1].item()),),
            dtype=jagged_tensor_dtype,
        )

    print("Starting to benchmark")
    time, output = benchmark_torch_function(
        torch.ops.fbgemm.jagged_slice,
        (values, lengths, start, slice_length),
        iters=1000,
    )
    print("Done benchmarking")

    def jagged_slice_ref(
        x_values: torch.Tensor,
        offsets: torch.Tensor,
        start: torch.Tensor,
        max_L: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        end_offsets_ = max_L + start + offsets[:-1]
        end_offsets = torch.where(end_offsets_ > offsets[1:], offsets[1:], end_offsets_)
        start_offsets = start + offsets[:-1]
        indices_to_select: List[torch.Tensor] = []
        for i in range(end_offsets.size(0)):
            indices_to_select.append(
                torch.arange(start_offsets[i].item(), end_offsets[i].item())
            )
        output_ref = torch.index_select(x_values, 0, torch.cat(indices_to_select))
        new_lengths = end_offsets - start_offsets
        new_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(new_lengths)
        return output_ref, new_offsets

    time_ref, output = benchmark_torch_function(
        jagged_slice_ref, (values, offsets, start, slice_length)
    )

    logging.info(f"jagged_slice forward time: {time * 1e3} ms, ref {time_ref * 1e3} ms")

    profiler = profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=200,
            warmup=100,
            active=100,
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    )

    profiler.start()
    for _ in range(500):
        torch.ops.fbgemm.jagged_slice(values, lengths, start, slice_length)
        profiler.step()
    profiler.stop()

    logging.info(
        "\n"
        + profiler.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)
    )

    flops = sum(e.flops for e in profiler.events())
    logging.info(f"Total Compute: {flops / 1e9} gflops")


if __name__ == "__main__":
    cli()
15 changes: 15 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -752,5 +752,20 @@ at::Tensor jagged_index_add_2d_forward_cpu(
    const int64_t num_dense_grad_rows,
    const int64_t num_output_rows);

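/// Slice a 1-D jagged tensor (values + per-row lengths): for each row, take up
/// to max_L values starting at start[row]. Returns the sliced values and the
/// new per-row lengths.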
std::tuple<at::Tensor, at::Tensor> jagged_slice(
    const at::Tensor& x_values,
    const at::Tensor& x_lengths,
    const at::Tensor& start,
    const int64_t max_L);

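/// Kernel behind jagged_slice, called by both the forward and the backward pass
/// (the latter with fill_zeros = true so that positions receiving no value are
/// zero-initialized). src_start and tgt_start give per-row read/write offsets;
/// output_lengths and num_output_rows describe the output jagged tensor.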
at::Tensor jagged_slice_forward_cpu(
    const at::Tensor& x_values,
    const at::Tensor& x_lengths,
    const at::Tensor& src_start,
    const at::Tensor& output_lengths,
    const at::Tensor& tgt_start,
    const int64_t num_output_rows,
    const int64_t max_L,
    const bool fill_zeros);
#endif
} // namespace fbgemm_gpu
109 changes: 109 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
Expand Up @@ -11,6 +11,7 @@
#include <torch/csrc/autograd/custom_function.h>
#include <torch/library.h>

#include "ATen/TensorUtils.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

@@ -559,6 +560,103 @@ class JaggedIndexSelect2dOp
  }
};

class JaggedSliceOp : public torch::autograd::Function<JaggedSliceOp> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const Tensor& values,
      const Tensor& lengths,
      const Tensor& start,
      const int64_t slice_length) {
    TENSOR_NDIM_EQUALS(values, 1);
    TENSORS_ON_SAME_DEVICE(values, lengths);
    TORCH_CHECK_TENSOR_ALL(start <= lengths, "start should be <= len");
    TORCH_CHECK_TENSOR_ALL(start >= 0, "start should always be non-negative");

    Tensor output_lengths = (lengths - start).clamp_max(std::abs(slice_length));
    // D2H sync here
    const int64_t num_output_rows = output_lengths.sum().item<int64_t>();
    const int64_t num_input_rows = lengths.sum().item<int64_t>();

    Tensor tgt_start = at::zeros_like(lengths);

    ctx->save_for_backward({lengths, output_lengths, start, tgt_start});
    // The backward output has the shape of the forward input, so the backward
    // call's number of output rows equals the forward's total input length.
    ctx->saved_data["num_output_rows"] = num_input_rows;
    ctx->saved_data["slice_length"] = slice_length;

    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("fbgemm::jagged_slice_forward", "")
                         .typed<at::Tensor(
                             const Tensor& values,
                             const Tensor& lengths,
                             const Tensor& src_start,
                             const Tensor& output_lengths,
                             const Tensor& tgt_start,
                             const int64_t num_output_rows,
                             const int64_t max_L,
                             const bool fill_zeros)>();

    return {
        op.call(
            values,
            lengths,
            start,
            output_lengths,
            tgt_start,
            num_output_rows,
            slice_length,
            false),
        output_lengths,
    };
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    TORCH_CHECK(grad_outputs.size() == 2);

    const auto saved = ctx->get_saved_variables();
    auto savedItr = std::begin(saved);
    // Saved as {lengths, output_lengths, start, tgt_start} in forward; read back
    // here under the names of the roles they play in the backward slice.
    Tensor output_lengths = *savedItr++;
    Tensor grad_lengths = *savedItr++;
    Tensor tgt_start = *savedItr++;
    Tensor src_start = *savedItr++;
    Tensor grad = grad_outputs[0];

    TENSORS_ON_SAME_DEVICE(grad, output_lengths);

    const int64_t num_output_rows = ctx->saved_data["num_output_rows"].toInt();
    const int64_t slice_length = ctx->saved_data["slice_length"].toInt();

    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("fbgemm::jagged_slice_forward", "")
                         .typed<at::Tensor(
                             const Tensor& values,
                             const Tensor& lengths,
                             const Tensor& src_start,
                             const Tensor& output_lengths,
                             const Tensor& tgt_start,
                             const int64_t num_output_rows,
                             const int64_t slice_length,
                             const bool fill_zeros)>();

    return {
        op.call(
            grad,
            grad_lengths,
            src_start,
            output_lengths,
            tgt_start,
            num_output_rows,
            slice_length,
            true),
        torch::autograd::Variable(), // lengths
        torch::autograd::Variable(), // start
        torch::autograd::Variable() // max_L
    };
  }
};

} // namespace

///@ingroup jagged-tensor-ops-cpu
@@ -719,6 +817,16 @@ std::vector<Tensor> jagged_index_select_2d(
  return JaggedIndexSelect2dOp::apply(values, lengths, indices);
}

std::tuple<Tensor, Tensor> jagged_slice(
    const Tensor& values,
    const Tensor& lengths,
    const Tensor& start,
    const int64_t slice_length) {
  const auto output =
      JaggedSliceOp::apply(values, lengths, start, slice_length);
  return {output[0], output[1]};
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
@@ -740,4 +848,5 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
  m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm));
  m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm));
  m.impl("jagged_index_select", TORCH_FN(fbgemm_gpu::jagged_index_select_2d));
  m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice));
}
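
Because the Autograd dispatch key routes `jagged_slice` through `JaggedSliceOp`, gradients flow back to the input values. A minimal sketch, under the same assumptions as the earlier example (an fbgemm_gpu build that includes this change):
```
import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm.* operators

values = torch.arange(16, dtype=torch.float, requires_grad=True)
lengths = torch.tensor([4, 3, 6, 1, 2])
start = torch.zeros_like(lengths)

out_values, out_lengths = torch.ops.fbgemm.jagged_slice(values, lengths, start, 3)
out_values.sum().backward()

# Positions inside each slice receive gradient 1; positions outside receive 0,
# because the backward pass scatters the incoming gradient with fill_zeros=True.
print(values.grad)
```
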
(Diffs for the remaining two changed files are not shown.)
