Skip to content

Commit

Permalink
[feat] torch-trt: set multi_device_safe_mode & support dynamic shape …
Browse files Browse the repository at this point in the history
…in sequence (#32)
  • Loading branch information
yjjinjie authored Nov 20, 2024
1 parent 215c49a commit 64b990b
Show file tree
Hide file tree
Showing 11 changed files with 683 additions and 57 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ repos:
hooks:
- id: ruff
args: [ --fix ]
exclude: tzrec/acc/_decompositions.py|tzrec/acc/_aten_lowering_pass.py
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand Down
3 changes: 2 additions & 1 deletion .pyre_configuration
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"tzrec/protos/*_pb2.pyi",
"tzrec/*/*_test.py",
"tzrec/tests/*.py",
"tzrec/utils/load_class.py"
"tzrec/utils/load_class.py",
"tzrec/acc/_*.py"
],
"site_package_search_strategy": "all",
"source_directories": [
Expand Down
4 changes: 4 additions & 0 deletions scripts/ci_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@
pip install -r requirements.txt
bash scripts/gen_proto.sh

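# Workaround for torch-tensorrt dynamic-shape support (see https://github.com/pytorch/TensorRT/pull/3289/files):
# overwrite the installed torch_tensorrt lowering modules with the patched copies kept in tzrec/acc.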
cp tzrec/acc/_aten_lowering_pass.py /opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
cp tzrec/acc/_decompositions.py /opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/lowering/_decompositions.py

MKL_THREADING_LAYER=GNU TORCH_DEVICE_BACKEND_AUTOLOAD=0 PYTHONPATH=. python tzrec/tests/run.py
125 changes: 125 additions & 0 deletions tzrec/acc/_aten_lowering_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Callable, Optional, Sequence, Union

import torch

from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_linear import lower_linear
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output

# from .replace_full_like_with_full import replace_full_like_with_full
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

# Registry of passes applied by post_lowering(). The list order is the order
# handed to DynamoPassManager.build_from_passlist; _aten_lowering_pass and
# _remove_lowering_pass mutate this registry by index.
# NOTE(review): replace_full_like_with_full is commented out both here and in
# the imports -- presumably part of the dynamic-shape workaround; confirm
# before re-enabling it.
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
lower_scaled_dot_product_attention,
lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
# replace_full_like_with_full,
view_to_reshape,
remove_assert_scalar,
]
)

# Registry of passes applied by pre_export_lowering() to an exported
# program's graph module.
ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
remove_detach,
]
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Type alias used to annotate lowering passes: a callable that takes a
# GraphModule plus a sequence of tensors and returns a GraphModule.
LoweringPassSignature = Callable[
[torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
]


def _aten_lowering_pass(
    *args: LoweringPassSignature,
    index: Optional[int] = None,
) -> Union[
    LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
]:
    """Register a lowering pass in ``ATEN_POST_LOWERING_PASSES``.

    Usable both as a bare decorator (``@_aten_lowering_pass``) and as a
    decorator factory (``@_aten_lowering_pass(index=N)``).

    Args:
        *args: Either empty (factory form) or exactly one callable, the
            lowering pass to register (bare-decorator form).
        index: Position passed to ``add_pass_with_index``. Must be given as a
            keyword; ``None`` is forwarded unchanged.

    Returns:
        The registered pass itself in the bare-decorator form, otherwise a
        decorator that registers the pass it is applied to and returns it.

    Raises:
        AssertionError: If positional arguments are given that are not a
            single callable.
    """

    def add_lowering_pass(
        lowering_pass: LoweringPassSignature,
    ) -> LoweringPassSignature:
        # Register the pass, report the resulting pass list, and hand the
        # pass back so the decorated name still refers to it.
        ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
        logger.debug(
            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
        )
        return lowering_pass

    # Factory form: no positional arguments, so return the decorator itself.
    if not args:
        return add_lowering_pass

    # Bare-decorator form accepts exactly one callable; anything else is a
    # misuse (most likely an index passed positionally).
    if len(args) != 1 or not callable(args[0]):
        raise AssertionError(
            f"aten_lowering_pass decorator called with invalid arguments {args} "
            "To specify an index to insert the pass, use the keyword 'index='"
        )

    return add_lowering_pass(args[0])


def _remove_lowering_pass(*, index: int) -> None:
    """Drop the pass at ``index`` from ``ATEN_POST_LOWERING_PASSES``.

    Args:
        index: Keyword-only position forwarded to ``remove_pass_with_index``.
    """
    ATEN_POST_LOWERING_PASSES.remove_pass_with_index(index)
    # Report the pass list as it stands after the removal.
    logger.debug(
        f"Removed lowering pass at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
    )


def post_lowering(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Apply the registered post-lowering passes to a graph module.

    Args:
        gm: Graph module produced by torch.export / torch.compile and their
            decompositions.

    Returns:
        Whatever ``ATEN_POST_LOWERING_PASSES`` returns when called on ``gm``.
    """
    # Log through the module-level logger rather than the root ``logging``
    # module, matching the other functions in this file, so the message is
    # attributed to this module and obeys its logger configuration.
    logger.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
    )
    return ATEN_POST_LOWERING_PASSES(gm)


def pre_export_lowering(
    ep: torch.export.ExportedProgram,
) -> torch.export.ExportedProgram:
    """Apply the registered pre-lowering passes to an exported program.

    The passes in ``ATEN_PRE_LOWERING_PASSES`` are invoked on
    ``ep.graph_module``; the exported program object itself is what gets
    returned.

    Args:
        ep: Exported program whose graph module should be processed.

    Returns:
        The same ``ep`` object that was passed in.
    """
    # Use the module-level logger, consistent with the rest of this file.
    logger.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
    )
    # The manager's return value is not used; only ``ep`` is returned.
    # NOTE(review): this relies on the passes mutating ``ep.graph_module`` in
    # place -- if a pass ever returns a new GraphModule, that result is
    # dropped here. Confirm against DynamoPassManager's behaviour.
    ATEN_PRE_LOWERING_PASSES(ep.graph_module)
    return ep


def dump_lowering_passes() -> str:
    """Return the string form of the post-lowering pass registry."""
    passlist_repr = str(ATEN_POST_LOWERING_PASSES)
    return passlist_repr
Loading

0 comments on commit 64b990b

Please sign in to comment.