-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] torch-trt: set multi_device_safe_mode & support dynamic shape in sequence (#32)
- Loading branch information
Showing 11 changed files with 683 additions and 57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) 2024, Alibaba Group; | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
from typing import Callable, Optional, Sequence, Union | ||
|
||
import torch | ||
|
||
from .constant_folding import constant_fold | ||
from .fuse_prims_broadcast import fuse_prims_broadcast | ||
from .lower_linear import lower_linear | ||
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention | ||
from .pass_manager import DynamoPassManager | ||
from .remove_assert_scalar import remove_assert_scalar | ||
from .remove_detach import remove_detach | ||
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones | ||
from .repair_input_as_output import repair_input_as_output | ||
|
||
# from .replace_full_like_with_full import replace_full_like_with_full | ||
from .replace_max_pool_with_indices import replace_max_pool_with_indices | ||
from .view_to_reshape import view_to_reshape | ||
|
||
# Registry of passes invoked by `post_lowering`. It is mutated at runtime by
# `_aten_lowering_pass` (insert) and `_remove_lowering_pass` (remove by index).
# NOTE(review): passes are presumably applied in the order listed here --
# confirm against DynamoPassManager before reordering entries.
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
    [
        remove_input_alias_fixing_clones,
        constant_fold,
        repair_input_as_output,
        lower_scaled_dot_product_attention,
        lower_linear,
        fuse_prims_broadcast,
        replace_max_pool_with_indices,
        # replace_full_like_with_full,  (disabled; its import is commented out too)
        view_to_reshape,
        remove_assert_scalar,
    ]
)
|
||
# Registry of passes invoked by `pre_export_lowering` on the graph module of
# an ExportedProgram.
ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
    [
        remove_detach,
    ]
)
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Type alias used to annotate lowering passes in this module: a callable that
# receives a GraphModule plus a sequence of tensors and returns a GraphModule.
LoweringPassSignature = Callable[
    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
]
|
||
|
||
def _aten_lowering_pass(
    *args: LoweringPassSignature,
    index: Optional[int] = None,
) -> Union[
    LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
]:
    """Register a lowering pass in ATEN_POST_LOWERING_PASSES.

    Works both as a bare decorator (the pass is the single positional
    argument) and as a decorator factory (no positional arguments, optional
    keyword ``index``).

    Args:
        *args: Either empty, or exactly one callable lowering pass.
        index: Position forwarded to ``add_pass_with_index``. Per the
            original documentation, leaving it unset inserts the pass at the
            end of the list.

    Returns:
        The registered pass itself when called with a pass, otherwise a
        one-argument registrar that registers and returns the pass it is
        given.

    Raises:
        AssertionError: If positional arguments are supplied that are not
            exactly one callable.
    """

    def register(candidate: LoweringPassSignature) -> LoweringPassSignature:
        ATEN_POST_LOWERING_PASSES.add_pass_with_index(candidate, index)
        logger.debug(
            f"Added lowering pass {candidate} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
        )
        return candidate

    # No positional arguments: used as `@decorator(index=...)`, so hand back
    # the registrar for Python to apply to the decorated function.
    if not args:
        return register

    # Positional arguments are only valid as a single callable; the index
    # has to come in via the keyword.
    if len(args) != 1 or not callable(args[0]):
        raise AssertionError(
            f"aten_lowering_pass decorator called with invalid arguments {args} "
            "To specify an index to insert the pass, use the keyword 'index='"
        )

    (only_pass,) = args
    return register(only_pass)
|
||
|
||
def _remove_lowering_pass(*, index: int) -> None:
    """Drop the pass stored at ``index`` from ATEN_POST_LOWERING_PASSES.

    Args:
        index: Keyword-only position handed to ``remove_pass_with_index``.
    """
    ATEN_POST_LOWERING_PASSES.remove_pass_with_index(index)
    logger.debug(
        f"Removed lowering pass at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
    )
|
||
|
||
def post_lowering(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Apply the registered post-lowering passes to a graph module.

    Intended to run after torch.export / torch.compile and their
    decompositions.

    Args:
        gm: Graph module handed to ATEN_POST_LOWERING_PASSES.

    Returns:
        The value returned by ``ATEN_POST_LOWERING_PASSES(gm)``, i.e. the
        modified GraphModule.
    """
    # Log through the module logger rather than the module-level
    # logging.debug(): that helper writes to the root logger and implicitly
    # calls basicConfig() when the root logger has no handlers.
    logger.debug(
        "Invoking DynamoPassManager and applying lowering passes: %s",
        ATEN_POST_LOWERING_PASSES,
    )
    return ATEN_POST_LOWERING_PASSES(gm)
|
||
|
||
def pre_export_lowering(
    ep: torch.export.ExportedProgram,
) -> torch.export.ExportedProgram:
    """Run the pre-lowering passes on an exported program's graph module.

    Args:
        ep: Exported program whose ``graph_module`` is handed to
            ATEN_PRE_LOWERING_PASSES.

    Returns:
        The same ``ep`` object that was passed in (not a GraphModule).
    """
    # Log through the module logger rather than the module-level
    # logging.debug(), which targets the root logger and may implicitly call
    # basicConfig().
    logger.debug(
        "Invoking DynamoPassManager and applying lowering passes: %s",
        ATEN_PRE_LOWERING_PASSES,
    )
    # The pass manager's return value is deliberately not stored back on
    # `ep`, matching the original behavior.
    # NOTE(review): this relies on the passes modifying `ep.graph_module` in
    # place -- confirm against DynamoPassManager.
    ATEN_PRE_LOWERING_PASSES(ep.graph_module)
    return ep
|
||
|
||
def dump_lowering_passes() -> str:
    """Return the string form of the current post-lowering pass registry."""
    rendered = str(ATEN_POST_LOWERING_PASSES)
    return rendered
Oops, something went wrong.