Skip to content

Commit

Permalink
feat: support embedding_bag converter (1D input) (#2395)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored Oct 25, 2023
1 parent d649d12 commit cb20f90
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 3 deletions.
45 changes: 45 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,51 @@ def aten_ops_embedding(
)


def embedding_bag_validator(node: Node) -> bool:
mode = args_bounds_check(node.args, 4, 0)
indices = node.args[1].meta.get("tensor_meta")
if indices is None:
return False
return (
bool(node.args[2].op == "get_attr")
and (mode == 0 or mode == 1 or mode == 2)
and len(indices.shape) == 1
)


@dynamo_tensorrt_converter(torch.ops.aten.embedding_bag.default, capability_validator=embedding_bag_validator) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten._embedding_bag.default, capability_validator=embedding_bag_validator) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (TRTTensor,),
2: (np.ndarray, torch.Tensor),
}
) # type: ignore[misc]
def aten_ops_embedding_bag(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.embedding.embedding_bag(
ctx,
target,
SourceIR.ATEN,
name,
weight=args[0],
indices=args[1],
offsets=args[2],
scale_grad_by_freq=args_bounds_check(args, 3, False),
mode=args_bounds_check(args, 4, 0),
sparse=args_bounds_check(args, 5, False),
per_sample_weights=args_bounds_check(args, 6, None),
include_last_offset=args_bounds_check(args, 7, False),
# padding index is useful for training only
)


@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor) # type: ignore[misc]
def aten_ops_fmod(
Expand Down
138 changes: 135 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Optional
import functools
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

Expand Down Expand Up @@ -40,5 +43,134 @@ def embedding(

# Implement embedding lookup with gather layer
gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0)
set_layer_name(gather_layer, target, name + "_gather", source_ir)
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
return gather_layer.get_output(0)


def embedding_bag(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
weight: TRTTensor,
indices: TRTTensor,
offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
scale_grad_by_freq: bool,
mode: int,
sparse: bool,
per_sample_weights: Optional[TRTTensor],
include_last_offset: bool,
) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
"""
This function is for calculating embedding bags.
In PyTorch, `offsets` is only used when input is 1D. If input is 2D of shape (B, N),
it will be treated as B bags (sequences) each of fixed length N, and this will return
B values aggregated in a way depending on the mode. `offsets` is ignored and required
to be None in this case.
However, according to the schema, `offsets` is required for input with any dimensions.
Accordingly, this function flattens N-D input to 1D and then to calculate embedding bags.
"""

# TODO: support 2D inputs
# indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))

if mode == 0: # sum
reduce_op = functools.partial(
impl.reduce.sum, ctx=ctx, target=target, source_ir=source_ir
)
reduce_name = "sum"
elif mode == 1: # mean
reduce_op = functools.partial(
impl.reduce.mean, ctx=ctx, target=target, source_ir=source_ir
)
reduce_name = "mean"
elif mode == 2: # max
reduce_op = functools.partial(
impl.reduce.max,
ctx=ctx,
target=target,
source_ir=source_ir,
return_indices=False,
)
reduce_name = "max"

# calculate embedding
embed = embedding(
ctx,
target,
source_ir,
f"{name}_embedding",
indices,
weight,
scale_grad_by_freq,
sparse,
)

# give weights to embedding
if per_sample_weights is not None:
assert (
per_sample_weights.shape == indices.shape
), f"`per_sample_weights` (shape: {per_sample_weights.shape}) must have exactly the same shape as indices/input (shape: {indices.shape})!"
per_sample_weights = get_trt_tensor(
ctx, per_sample_weights, f"{name}_per_sample_weights", np.float32
)
per_sample_weights = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_per_sample_weights",
per_sample_weights,
(-1, 1),
)
embed = impl.elementwise.mul(
ctx,
target,
source_ir,
f"{name}_mul_per_sample_weights",
embed,
per_sample_weights,
)

offsets = to_numpy(offsets)

if include_last_offset is False:
# add the end index to offsets
offsets = np.append(offsets, indices.shape[0])
else:
# modify the last index of offsets to the end index
# however, pytorch doc says if `include_last_offset` is True, the size of offsets
# is equal to the number of bags + 1. The last element is the size of the input,
# or the ending index position of the last bag (sequence).

offsets[-1] = indices.shape[0]

# separately reduce embeddings for different bags
reduced_embed = []
len_offsets = len(offsets)
for i in range(len_offsets - 1):
if offsets[i] < offsets[i + 1]:
sliced_embed = impl.slice.slice_op(
ctx,
target,
source_ir,
f"{name}_slice_embed_{i}",
embed,
0,
offsets[i],
offsets[i + 1],
1,
)
reduced_sliced_embed = reduce_op(
name=f"{name}_{reduce_name}_{i}",
input_val=sliced_embed,
dim=0,
keepdim=True,
)
reduced_embed.append(reduced_sliced_embed)

out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed, 0)
# out = reduce_op(input_val=embed, dim=1, keepdim=False) # Note: This implementation doesn't work for N-dim

return out, None, None, None
141 changes: 141 additions & 0 deletions tests/py/dynamo/conversion/test_embedding_bag_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestEmbeddingBagConverter(DispatchTestCase):
@parameterized.expand(
[
# 1D input
param(
test_name="1d_indices_1",
weight=torch.randn((10, 3), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
offsets=torch.tensor([0, 3], dtype=torch.int32),
scale_grad_by_freq=False,
mode=1,
sparse=False,
per_sample_weights=None,
include_last_offset=True,
padding_idx=-1,
),
param(
test_name="1d_indices_2",
weight=torch.randn((10, 3), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
offsets=torch.tensor([0, 5], dtype=torch.int32),
scale_grad_by_freq=False,
mode=0,
sparse=False,
per_sample_weights=torch.randn((6,)),
include_last_offset=False,
padding_idx=-1,
),
param(
test_name="1d_indices_3",
weight=torch.randn((10, 3), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
scale_grad_by_freq=False,
mode=2,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=-1,
),
# 2D input
# param(
# test_name="2d_indices_1",
# weight=torch.randn((5, 10), dtype=torch.float32),
# indices=torch.tensor([[3, 1], [4, 3]], dtype=torch.int32),
# offsets=torch.tensor([0, 1], dtype=torch.int32),
# scale_grad_by_freq=False,
# mode=0,
# sparse=False,
# per_sample_weights=torch.randn((4,)),
# include_last_offset=False,
# padding_idx=-1,
# ),
# param(
# test_name="2d_indices_3",
# weight=torch.tensor([
# [0.0, 0.0, 0.0],
# [1.0, 1.0, 1.0],
# [2.0, 2.0, 2.0],
# [3.0, 3.0, 3.0],
# [4.0, 4.0, 4.0],
# [5.0, 5.0, 5.0],
# ], dtype=torch.float32),
# indices=torch.tensor([[0, 2, 1], [3, 5, 4]], dtype=torch.int32),
# offsets=torch.tensor([0, 1], dtype=torch.int32),
# scale_grad_by_freq=False,
# mode=2,
# sparse=False,
# per_sample_weights=None,
# include_last_offset=False,
# padding_idx=-1,
# ),
# param(
# test_name="2d_indices_2",
# weight=torch.randn((5, 5), dtype=torch.float32),
# indices=torch.tensor([[3, 1, 2], [4, 2, 3]], dtype=torch.int32),
# offsets=torch.tensor([0, 2], dtype=torch.int32),
# scale_grad_by_freq=False,
# mode=1,
# sparse=False,
# per_sample_weights=None,
# include_last_offset=False,
# padding_idx=-1,
# ),
# param(
# test_name="2d_indices_2",
# weight=torch.randn((5, 10), dtype=torch.float32),
# indices=torch.tensor([[3, 1, 2, 4], [4, 1, 3, 1]], dtype=torch.int32),
# offsets=torch.tensor([0, 2], dtype=torch.int32),
# scale_grad_by_freq=False,
# mode=0,
# sparse=False,
# per_sample_weights=torch.randn((8,)),
# include_last_offset=True,
# padding_idx=-1,
# ),
]
)
def test_embedding_bag(
self,
test_name,
weight,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
):
class TestEmbeddingBag(torch.nn.Module):
def forward(self, weight, indices):
return torch.ops.aten._embedding_bag.default(
weight,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset,
padding_idx,
)[0]

self.run_test(
TestEmbeddingBag(),
inputs=[weight, indices],
enable_passes=True,
)


if __name__ == "__main__":
run_tests()

0 comments on commit cb20f90

Please sign in to comment.