diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6710763..2594cd4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,7 @@ repos: hooks: - id: ruff args: [ --fix ] + exclude: tzrec/acc/_decompositions.py|tzrec/acc/_aten_lowering_pass.py - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 diff --git a/.pyre_configuration b/.pyre_configuration index 3461691..6c2f954 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -3,7 +3,8 @@ "tzrec/protos/*_pb2.pyi", "tzrec/*/*_test.py", "tzrec/tests/*.py", - "tzrec/utils/load_class.py" + "tzrec/utils/load_class.py", + "tzrec/acc/_*.py" ], "site_package_search_strategy": "all", "source_directories": [ diff --git a/scripts/ci_test.sh b/scripts/ci_test.sh index f212b87..dd4380e 100644 --- a/scripts/ci_test.sh +++ b/scripts/ci_test.sh @@ -3,4 +3,8 @@ pip install -r requirements.txt bash scripts/gen_proto.sh +# just workaround for torch-tensorrt (dynamic shape) https://github.com/pytorch/TensorRT/pull/3289/files +cp tzrec/acc/_aten_lowering_pass.py /opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +cp tzrec/acc/_decompositions.py /opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/lowering/_decompositions.py + MKL_THREADING_LAYER=GNU TORCH_DEVICE_BACKEND_AUTOLOAD=0 PYTHONPATH=. python tzrec/tests/run.py diff --git a/tzrec/acc/_aten_lowering_pass.py b/tzrec/acc/_aten_lowering_pass.py new file mode 100644 index 0000000..1b9d65f --- /dev/null +++ b/tzrec/acc/_aten_lowering_pass.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Callable, Optional, Sequence, Union + +import torch + +from .constant_folding import constant_fold +from .fuse_prims_broadcast import fuse_prims_broadcast +from .lower_linear import lower_linear +from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention +from .pass_manager import DynamoPassManager +from .remove_assert_scalar import remove_assert_scalar +from .remove_detach import remove_detach +from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones +from .repair_input_as_output import repair_input_as_output + +# from .replace_full_like_with_full import replace_full_like_with_full +from .replace_max_pool_with_indices import replace_max_pool_with_indices +from .view_to_reshape import view_to_reshape + +ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist( + [ + remove_input_alias_fixing_clones, + constant_fold, + repair_input_as_output, + lower_scaled_dot_product_attention, + lower_linear, + fuse_prims_broadcast, + replace_max_pool_with_indices, + # replace_full_like_with_full, + view_to_reshape, + remove_assert_scalar, + ] +) + +ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist( + [ + remove_detach, + ] +) + +logger = logging.getLogger(__name__) + + +LoweringPassSignature = Callable[ + [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule +] + + +def _aten_lowering_pass( + *args: LoweringPassSignature, + index: Optional[int] = None, +) -> Union[ + LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature] +]: + """Adds a lowering pass to the registry, at a specified index if desired + + If no index is specified, the lowering pass is inserted at the end of the list + """ + + def add_lowering_pass( + lowering_pass: LoweringPassSignature, + ) -> LoweringPassSignature: + ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) + logger.debug( + f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}" + ) + return lowering_pass + + # If there are arguments specified, the decorator may have been called as-is + if args: + # The decorator may only be called with the lowering pass + # The index must be specified as a keyword argument + if len(args) == 1 and callable(args[0]): + return add_lowering_pass(args[0]) + else: + raise AssertionError( + f"aten_lowering_pass decorator called with invalid arguments {args} " + "To specify an index to insert the pass, use the keyword 'index='" + ) + # If no arguments are specified, the decorator was called with an index keyword + else: + return add_lowering_pass + + +def _remove_lowering_pass(*, index: int) -> None: + """Removes a lowering pass at a specific index from the registry""" + ATEN_POST_LOWERING_PASSES.remove_pass_with_index(index) + logger.debug( + f"Removed lowering pass at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}" + ) + return + + +def post_lowering(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule""" + logging.debug( + f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}" + ) + return ATEN_POST_LOWERING_PASSES(gm) + + +def pre_export_lowering(ep: torch.export.ExportedProgram) -> torch.fx.GraphModule: + """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule""" + logging.debug( + f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}" + ) + gm = ep.graph_module + gm = ATEN_PRE_LOWERING_PASSES(gm) + return ep + + +def dump_lowering_passes() -> str: + """Returns a string containing the lowering passes""" + return str(ATEN_POST_LOWERING_PASSES) diff --git a/tzrec/acc/_decompositions.py b/tzrec/acc/_decompositions.py new file mode 100644 index 0000000..3703650 --- /dev/null +++ b/tzrec/acc/_decompositions.py @@ -0,0 +1,420 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from enum import Enum, auto +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch._decomp import register_decomposition +from torch._ops import OpOverload +from torch_tensorrt.dynamo._defaults import default_device +from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim +from torch_tensorrt.dynamo.utils import to_torch_device + +from ._decomposition_groups import ( + ENABLED_TORCH_DECOMPOSITIONS, + TORCH_TRT_DECOMPOSITIONS, + _core_aten_decompositions, + aten, + torch_disabled_decompositions, + torch_enabled_decompositions, +) + +logger = logging.getLogger(__name__) + + +def register_torch_trt_decomposition( + aten_op: OpOverload, registry: Optional[Any] = None +) -> Callable[[Any], Any]: + """Checks if the decomposition already exists in one of the sets + Registers the decomposition via the Torch utility + + Alerts the user if the decomposition already exists, before registering + Throws an AssertionError if the user attempts to register a decomposition + which is present in the set of explicitly disabled decompositions + """ + if aten_op in torch_enabled_decompositions: + logger.warning( + f"Detected custom decomposition for {aten_op}, which conflicts " + "with an existing Torch decomposition in torch_enabled_decompositions. " + "The custom implementation will take precedence." + ) + elif aten_op in torch_disabled_decompositions: + logger.info( + f"Detected custom decomposition for {aten_op}, which is present " + "in torch_disabled_decompositions." + ) + + # Conflicts with _core_aten_decompositions will only occur if + # enable_experimental_decompositions is True in get_decompositions + if aten_op in _core_aten_decompositions: + logger.debug( + f"Detected custom decomposition for {aten_op}, which conflicts " + "with an existing Torch decomposition in core_aten_decompositions. " + "The custom implementation will take precedence." + ) + + def register(fn: Callable[[Any], Any]) -> Any: + return register_decomposition(aten_op=aten_op, registry=registry)(fn) + + return register + + +def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any: + """Replace inplace operation with functional equivalent + Adapted from: + https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361 + """ + + @register_torch_trt_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS) + def inplace_op(*args, **kwargs): # type: ignore + out = outplace_op(*args, **kwargs) + return args[0].copy_(out) + + return inplace_op + + +replace_inplace_op(aten.add_, aten.add) +replace_inplace_op(aten.addbmm_, aten.addbmm) +replace_inplace_op(aten.addmm_, aten.addmm) +replace_inplace_op(aten.addmv_, aten.addmv) +replace_inplace_op(aten.baddbmm_, aten.baddbmm) +replace_inplace_op(aten.cumprod_, aten.cumprod) +replace_inplace_op(aten.index_put_, aten.index_put) +replace_inplace_op(aten.index_reduce_, aten.index_reduce) +replace_inplace_op(aten.relu_, aten.relu) +replace_inplace_op(aten.round_, aten.round) +replace_inplace_op(aten.scatter_, aten.scatter) +replace_inplace_op(aten.scatter_add_, aten.scatter_add) +replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) + + +@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS) +def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore + return torch.reciprocal(torch.sqrt(*args, **kwargs)) + + +@register_torch_trt_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS) +def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor: # type: ignore + return torch.reshape(x, *args, **kwargs) + + +@register_torch_trt_decomposition( + torch.ops.aten.lift_fresh_copy, registry=TORCH_TRT_DECOMPOSITIONS +) +def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor: + return x + + +@register_torch_trt_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS) +def alias_replacement(x: torch.Tensor) -> torch.Tensor: + return x + + +@register_torch_trt_decomposition( + torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def reciprocal_replacement( + input_: torch.Tensor, +) -> torch.Tensor: + return torch.div(1, input_) + + +@register_torch_trt_decomposition( + torch.ops.prims.var.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def var_decomposition( + input_tensor: torch.Tensor, + dims: Optional[List[int]], + correction: int, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if dims is None: + dims = [] + + # If the dimensions are empty, variance is taken over all dimensions + if isinstance(dims, (tuple, list)) and len(dims) == 0: + N = input_tensor.numel() + # Otherwise, the number of samples is the product of the dimensions reduced over + else: + N = 1 + for dim_i in dims: + N *= input_tensor.shape[dim_i] + + # Compute the mean, difference, and correction term as per the formula: + # https://pytorch.org/docs/stable/generated/torch.var.html + + # Additionally, prims does not support keepdim, and so we only keep dimensions + # on the first reduction, then remove it for the second + sample_mean = torch.mean(input_tensor, dims, keepdim=True) + diff = input_tensor - sample_mean + squared_diff = diff * diff + variance_unnormalized = torch.sum(squared_diff, dims, keepdim=False) + + if correction is None: + correction_term = float(N - 1) + elif isinstance(correction, int): + correction_term = float(N - correction) + elif isinstance(correction, float): + correction_term = float(N) - correction + else: + raise RuntimeError("correction must be int or float") + + if correction_term <= 0: + raise RuntimeError(f"correction term was non-positive, got: {correction_term}") + + variance = variance_unnormalized / correction_term + + return variance + + +@register_torch_trt_decomposition( + torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: # type: ignore + empty_size = args[0] + empty_permute = args[1] + perm = [0] * len(empty_size) + for permute_index, permute_element in enumerate(empty_permute): + perm[permute_element] = permute_index + kwargs["device"] = to_torch_device(default_device()) + return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) + + +@register_torch_trt_decomposition( + torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def slice_scatter_decomposition( + input_tensor: torch.Tensor, + src_tensor: torch.Tensor, + dim: int, + start: Optional[int] = None, + end: Optional[int] = None, + step: Optional[int] = None, +) -> torch.Tensor: + dim_size = input_tensor.shape[dim] + device_input_tensor = input_tensor.device + start = get_positive_dim(start, input_tensor.shape[dim]) + if end is None: + end = dim_size + end = get_positive_dim(end, input_tensor.shape[dim]) + if step is None: + step = 1 + + # Ensure start, end, and step are all integers + assert isinstance(start, int), "start must be an integer" + assert isinstance(end, int), "end must be an integer" + assert isinstance(step, int), "step must be an integer" + + src_dim = src_tensor.shape + # step == 0 is not a valid torch case + # also src_dim should be equal to slice dimension + + if start == 0 and end == dim_size and step == 1: + return src_tensor + + cat_tensors = [] + index_tensor_shape = [] + for i, src_each_dim in enumerate(list(src_dim)): + if i != dim: + index_tensor_shape.append(src_each_dim) + for index in range(start, end, step): + cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64)) + index_tensor = torch.stack(cat_tensors, dim) + index_tensor = index_tensor.to(device_input_tensor) + index_tensor_64 = index_tensor.to(torch.int64) + output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor) + return output_tensor + + +@register_torch_trt_decomposition( + torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def select_scatter_decomposition( + input_tensor: torch.Tensor, + src_tensor: torch.Tensor, + dim: int, + index: int, +) -> torch.Tensor: + src_tensor = torch.unsqueeze(src_tensor, dim) + return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1) + + +@register_torch_trt_decomposition( + torch.ops.aten.empty_strided.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: # type: ignore + empty_size = args[0] + empty_stride = args[1] + return torch.as_strided( + torch.empty(empty_size, device=to_torch_device(default_device())), + empty_size, + empty_stride, + ) + + +@register_torch_trt_decomposition( + torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def scatter_add_decomposition( + input_tensor: torch.Tensor, + dim: int, + index: torch.Tensor, + src_tensor: torch.Tensor, +) -> torch.Tensor: + scatter_add_tensor = input_tensor + src_shape = list(src_tensor.shape) + src_dim = src_shape[dim] + for i in range(0, src_dim): + to_scatter_tensor = torch.zeros(input_tensor.shape, dtype=input_tensor.dtype) + # index and src slice + src_slice = torch.select(src_tensor, dim, i) + index_slice = torch.select(index, dim, i) + + # unsqueeze src and index in dim + src_slice = torch.unsqueeze(src_slice, dim) + index_slice = torch.unsqueeze(index_slice, dim) + + # moving tensor to default device + device = input_tensor.device + scatter_add_tensor = scatter_add_tensor.to(device) + to_scatter_tensor = to_scatter_tensor.to(device) + index_slice = index_slice.to(device) + src_slice = src_slice.to(device) + + scatter_add_tensor = torch.add( + scatter_add_tensor, + torch.scatter(to_scatter_tensor, dim, index_slice, src_slice), + ) + + return scatter_add_tensor + + +# enum class for reduce operation of scatter_reduce +class ReduceOperation(Enum): + SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y)) + PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y)) + MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y)) + AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y)) + AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y)) + + def __new__(cls, description, func): + obj = object.__new__(cls) + obj._value_ = auto() + obj.description = description + obj.func = func + return obj + + def reduce_operation_with_scatter( + self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor + ): + scatter_tensor = None + if self == ReduceOperation.SUM or self == ReduceOperation.MEAN: + scatter_tensor = torch.zeros_like(initial_tensor) + elif self == ReduceOperation.PROD: + scatter_tensor = torch.ones_like(initial_tensor) + elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX: + scatter_tensor = initial_tensor + else: + # This case would not be encountered from torch itself + print("Invalid Operation for Reduce op!!") + + operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor) + device = to_torch_device(scatter_tensor.device) + operation_lhs = operation_lhs.to(device) + operation_rhs = operation_rhs.to(device) + return self.func(operation_lhs, operation_rhs) + + +@register_torch_trt_decomposition( + torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS +) +def scatter_reduce_decomposition( + input_tensor: torch.Tensor, + dim: int, + index: torch.Tensor, + src_tensor: torch.Tensor, + reduce: str, + include_self: bool = True, +) -> torch.Tensor: + scatter_loop_tensor = input_tensor + device_input_tensor = input_tensor.device + # required for mean reduce operation + scatter_count_tensor = torch.zeros_like(input_tensor) + src_shape = list(src_tensor.shape) + src_dim = src_shape[dim] + if include_self == False: + raise AssertionError("include_self False for scatter reduce not yet supported") + for i in range(0, src_dim): + src_slice = torch.select(src_tensor, dim, i) + index_slice = torch.select(index, dim, i) + # unsqueeze src and index in dim + src_slice = torch.unsqueeze(src_slice, dim) + index_slice = torch.unsqueeze(index_slice, dim) + + # moving tensor to default device + scatter_loop_tensor = scatter_loop_tensor.to(device_input_tensor) + index_slice = index_slice.to(device_input_tensor) + src_slice = src_slice.to(device_input_tensor) + if reduce == "sum": + reduceOp = ReduceOperation.SUM + elif reduce == "prod": + reduceOp = ReduceOperation.PROD + elif reduce == "mean": + reduceOp = ReduceOperation.MEAN + scatter_count_tensor = reduceOp.reduce_operation_with_scatter( + scatter_count_tensor, + input_tensor, + dim, + index_slice, + torch.ones_like(src_slice), + ) + elif reduce == "amax": + reduceOp = ReduceOperation.AMAX + elif reduce == "amin": + reduceOp = ReduceOperation.AMIN + scatter_loop_tensor = reduceOp.reduce_operation_with_scatter( + scatter_loop_tensor, input_tensor, dim, index_slice, src_slice + ) + if reduce == "mean": + scatter_loop_tensor = torch.div( + scatter_loop_tensor, + torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)), + rounding_mode="trunc", + ) + return scatter_loop_tensor + + +@register_torch_trt_decomposition( + torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS +) +def full_like_decomposition(*args, **kwargs) -> torch.Tensor: + input = args[0] + shape = args[0].shape + fill_value = args[1] + kwargs["dtype"] = input.dtype + kwargs["device"] = to_torch_device(default_device()) + return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"]) + + +def get_decompositions( + enable_experimental_decompositions: bool = False, +) -> Dict[OpOverload, Callable[[Any], Any]]: + if enable_experimental_decompositions: + CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = { + decomp: _core_aten_decompositions[decomp] + for decomp in _core_aten_decompositions + if decomp not in torch_disabled_decompositions + } + return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS} + else: + return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS} diff --git a/tzrec/acc/trt_utils.py b/tzrec/acc/trt_utils.py index 1ae479a..9ea1279 100644 --- a/tzrec/acc/trt_utils.py +++ b/tzrec/acc/trt_utils.py @@ -46,7 +46,7 @@ def trt_convert( torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT """ logger.info("trt convert start...") - # torch_tensorrt.runtime.set_multi_device_safe_mode(True) + torch_tensorrt.runtime.set_multi_device_safe_mode(True) enabled_precisions = {torch.float32} # Workspace size for TensorRT @@ -149,6 +149,24 @@ def forward( return y +def get_trt_max_batch_size() -> int: + """Get trt max batch size. + + Returns: + int: max_batch_size + """ + return int(os.environ.get("TRT_MAX_BATCH_SIZE", 2048)) + + +def get_trt_max_seq_len() -> int: + """Get trt max seq len. + + Returns: + int: max_seq_len + """ + return int(os.environ.get("TRT_MAX_SEQ_LEN", 100)) + + def export_model_trt( model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str ) -> None: @@ -166,7 +184,9 @@ def export_model_trt( emb_trace_gpu = torch.jit.script(emb_trace_gpu) # dynamic shapes - batch = torch.export.Dim("batch", min=1, max=10000) + max_batch_size = get_trt_max_batch_size() + max_seq_len = get_trt_max_seq_len() + batch = torch.export.Dim("batch", min=1, max=max_batch_size) dynamic_shapes_list = [] values_list_cuda = [] for i, value in enumerate(emb_res): @@ -174,7 +194,7 @@ def export_model_trt( values_list_cuda.append(v) dict_dy = {0: batch} if v.dim() == 3: - dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=10000) + dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=max_seq_len) dynamic_shapes_list.append(dict_dy) # convert dense diff --git a/tzrec/acc/utils.py b/tzrec/acc/utils.py index aee142d..1f353b7 100644 --- a/tzrec/acc/utils.py +++ b/tzrec/acc/utils.py @@ -12,7 +12,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import json import os -from typing import Dict, List +from typing import Dict import torch @@ -150,28 +150,3 @@ def export_acc_config() -> Dict[str, str]: if "ENABLE_TRT" in os.environ: acc_config["ENABLE_TRT"] = os.environ["ENABLE_TRT"] return acc_config - - -def dicts_are_equal( - dict1: Dict[str, torch.Tensor], dict2: Dict[str, torch.Tensor] -) -> bool: - """Compare dict[str,torch.Tensor].""" - if dict1.keys() != dict2.keys(): - return False - - for key in dict1: - if not torch.equal(dict1[key], dict2[key]): - return False - - return True - - -def lists_are_equal(list1: List[torch.Tensor], list2: List[torch.Tensor]) -> bool: - """Compare List[torch.Tensor].""" - if len(list1) != len(list2): - return False - - for i in range(len(list1)): - if not torch.equal(list1[i], list2[i]): - return False - return True diff --git a/tzrec/main.py b/tzrec/main.py index e078ad5..0478faa 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -8,7 +8,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# cpu image has no torch_tensorrt +try: + import torch_tensorrt +except Exception: + pass import copy import itertools import json @@ -43,7 +47,7 @@ from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper from torchrec.optim.optimizers import in_backward_optimizer_filter -from tzrec.acc.trt_utils import export_model_trt +from tzrec.acc.trt_utils import export_model_trt, get_trt_max_batch_size from tzrec.acc.utils import ( export_acc_config, is_input_tile_emb, @@ -805,6 +809,13 @@ def export( assets = asset_files.split(",") data_config = pipeline_config.data_config + is_trt_convert = is_trt() + if is_trt_convert: + # export batch_size too large may OOM in trt convert phase + max_batch_size = get_trt_max_batch_size() + data_config.batch_size = min(data_config.batch_size, max_batch_size) + logger.info("using new batch_size: %s in trt export", data_config.batch_size) + # Build feature features = _create_features(list(pipeline_config.feature_configs), data_config) @@ -849,7 +860,6 @@ def export( else: raise ValueError("checkpoint path should be specified.") - is_trt_convert = is_trt() if is_trt_convert: checkpoint_pg = dist.new_group(backend="nccl") if is_rank_zero: @@ -970,6 +980,19 @@ def predict( ) if batch_size: pipeline_config.data_config.batch_size = batch_size + + is_trt_convert: bool = is_trt_predict(scripted_model_path) + if is_trt_convert: + # predict batch_size too large may out of range + max_batch_size = get_trt_max_batch_size() + pipeline_config.data_config.batch_size = min( + pipeline_config.data_config.batch_size, max_batch_size + ) + logger.info( + "using new batch_size: %s in trt predict", + pipeline_config.data_config.batch_size, + ) + if dataset_type: pipeline_config.data_config.dataset_type = getattr(DatasetType, dataset_type) if edit_config_json: @@ -1012,6 +1035,10 @@ def predict( # disable jit compileļ¼Œ as it compile too slow now. if "PYTORCH_TENSOREXPR_FALLBACK" not in os.environ: os.environ["PYTORCH_TENSOREXPR_FALLBACK"] = "2" + + if is_trt_convert: + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + model: torch.jit.ScriptModule = torch.jit.load( os.path.join(scripted_model_path, "scripted_model.pt"), map_location=device ) @@ -1047,8 +1074,7 @@ def _forward(batch: Batch) -> Tuple[Dict[str, torch.Tensor], RecordBatchTensor]: # when predicting with a model exported using INPUT_TILE, # we set the batch size tensor to 1 to disable tiling. parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64) - is_trt = is_trt_predict(scripted_model_path) - if is_trt: + if is_trt_convert: predictions = model(parsed_inputs) else: predictions = model(parsed_inputs, device) diff --git a/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config b/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config index 5a59bfd..d8d0c03 100644 --- a/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config +++ b/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config @@ -299,6 +299,18 @@ model_config { hidden_units: [512, 256, 128] } } + din_towers { + input: 'seq' + attn_mlp { + hidden_units: [256, 64] + } + } + din_towers { + input: 'seq_item' + attn_mlp { + hidden_units: [256, 64] + } + } final { hidden_units: [64] diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/train_eval_export_test.py index 7e2bd4b..56bb3ce 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/train_eval_export_test.py @@ -16,6 +16,12 @@ import unittest import torch + +# cpu image has no torch_tensorrt +try: + import torch_tensorrt +except Exception: + pass from pyarrow import dataset as ds from tzrec.constant import Mode @@ -674,6 +680,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): self.test_dir, "predict_result_tile_emb_trt" ) + predict_columns = ["user_id", "item_id", "clk", "probs"] # quant and no-input-tile if self.success: self.success = utils.test_export( @@ -710,14 +717,13 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): output_columns="probs", test_dir=trt_dir, ) - # compare INPUT_TILE and no INPUT_TILE result consistency + # compare TRT and origin result consistency df = ds.dataset(pred_output, format="parquet").to_table().to_pandas() df_t = ds.dataset(trt_pred_output, format="parquet").to_table().to_pandas() - df = df.sort_values(by=list(df.columns)).reset_index(drop=True) - df_t = df_t.sort_values(by=list(df_t.columns)).reset_index(drop=True) - # self.assertTrue(df.equals(df_t)) - print(df) - print(df_t) + df = df.sort_values(by=predict_columns).reset_index(drop=True) + df_t = df_t.sort_values(by=predict_columns).reset_index(drop=True) + # differences = df.compare(df_t) + # self.assertTrue(dfs_are_close(df, df_t, 1e-6)) # quant and input-tile and trt if self.success: @@ -739,18 +745,17 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): output_columns="probs", test_dir=input_tile_trt_dir, ) - # compare INPUT_TILE and no INPUT_TILE result consistency + # compare INPUT_TILE+TRT and origin result consistency df = ds.dataset(pred_output, format="parquet").to_table().to_pandas() df_t = ( ds.dataset(tile_trt_pred_output, format="parquet") .to_table() .to_pandas() ) - df = df.sort_values(by=list(df.columns)).reset_index(drop=True) - df_t = df_t.sort_values(by=list(df_t.columns)).reset_index(drop=True) - # self.assertTrue(df.equals(df_t)) - print(df) - print(df_t) + df = df.sort_values(by=predict_columns).reset_index(drop=True) + df_t = df_t.sort_values(by=predict_columns).reset_index(drop=True) + # differences = df.compare(df_t) + # self.assertTrue(dfs_are_close(df, df_t, 1e-6)) # quant and input-tile emb and trt if self.success: @@ -772,18 +777,17 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): output_columns="probs", test_dir=input_tile_emb_trt_dir, ) - # compare INPUT_TILE and no INPUT_TILE result consistency + # compare INPUT_TILE_EMB+TRT and origin result consistency df = ds.dataset(pred_output, format="parquet").to_table().to_pandas() df_t = ( ds.dataset(tile_trt_pred_output_emb, format="parquet") .to_table() .to_pandas() ) - df = df.sort_values(by=list(df.columns)).reset_index(drop=True) - df_t = df_t.sort_values(by=list(df_t.columns)).reset_index(drop=True) - # self.assertTrue(df.equals(df_t)) - print(df) - print(df_t) + df = df.sort_values(by=predict_columns).reset_index(drop=True) + df_t = df_t.sort_values(by=predict_columns).reset_index(drop=True) + # differences = df.compare(df_t) + # self.assertTrue(dfs_are_close(df, df_t, 1e-6)) self.assertTrue(self.success) @@ -842,6 +846,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): utils.save_predict_result_json(result_gpu, result_dict_json_path) # quant and trt + torch_tensorrt.runtime.set_multi_device_safe_mode(True) model_gpu_trt = torch.jit.load( os.path.join(self.test_dir, "trt/export/scripted_model.pt"), map_location=device, @@ -884,19 +889,19 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): # trt is all same sa no-trt for k, v in result_gpu.items(): torch.testing.assert_close( - result_gpu_trt[k].to(device), v, rtol=1e-4, atol=1e-4 + result_gpu_trt[k].to(device), v, rtol=1e-6, atol=1e-6 ) # tile & trt is all same sa no-tile-trt for k, v in result_gpu.items(): torch.testing.assert_close( - result_gpu_input_tile[k].to(device), v, rtol=1e-4, atol=1e-4 + result_gpu_input_tile[k].to(device), v, rtol=1e-6, atol=1e-6 ) # tile emb & trt is all same sa no-tile-trt for k, v in result_gpu.items(): torch.testing.assert_close( - result_gpu_input_tile_emb[k].to(device), v, rtol=1e-4, atol=1e-4 + result_gpu_input_tile_emb[k].to(device), v, rtol=1e-6, atol=1e-6 ) diff --git a/tzrec/utils/test_util.py b/tzrec/utils/test_util.py index 9f36e2f..ba0b986 100644 --- a/tzrec/utils/test_util.py +++ b/tzrec/utils/test_util.py @@ -10,8 +10,10 @@ # limitations under the License. from enum import Enum -from typing import Union +from typing import Dict, List, Union +import numpy as np +import pandas as pd import torch from torch import nn from torch.fx import GraphModule @@ -80,3 +82,38 @@ def parameterized_name_func(func, num, p) -> str: base_name = func.__name__ name_suffix = "_%s" % (num,) return base_name + name_suffix + + +def dicts_are_equal( + dict1: Dict[str, torch.Tensor], dict2: Dict[str, torch.Tensor] +) -> bool: + """Compare dict[str,torch.Tensor].""" + if dict1.keys() != dict2.keys(): + return False + + for key in dict1: + if not torch.equal(dict1[key], dict2[key]): + return False + + return True + + +def lists_are_equal(list1: List[torch.Tensor], list2: List[torch.Tensor]) -> bool: + """Compare List[torch.Tensor].""" + if len(list1) != len(list2): + return False + + for i in range(len(list1)): + if not torch.equal(list1[i], list2[i]): + return False + return True + + +def dfs_are_close(df1: pd.DataFrame, df2: pd.DataFrame, abs_tol: float) -> bool: + """Compare DataFrame.""" + if df1.shape != df2.shape: + return False + abs_diff = np.abs(df1.values - df2.values) + result = np.all(abs_diff <= abs_tol) + # pyre-ignore [7] + return result