Skip to content

Commit

Permalink
feat: Exempt default softmax from decomposition (#2268)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored Aug 26, 2023
1 parent a65c95c commit 31d30e2
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Any, Callable, Dict, Set
from typing import Any, Callable, Dict, Set, Union

import torch
from torch._decomp import core_aten_decompositions
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

_core_aten_decompositions: Dict[
OpOverload, Callable[[Any], Any]
] = core_aten_decompositions()
torch_enabled_decompositions: Set[OpOverload] = {
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
Expand Down Expand Up @@ -140,7 +140,7 @@
aten.smooth_l1_loss_backward,
aten.soft_margin_loss,
aten.soft_margin_loss_backward,
aten._softmax,
aten._softmax.out,
aten._softmax_backward_data,
aten.softplus,
aten.softplus_backward,
Expand Down Expand Up @@ -176,7 +176,9 @@
aten.full,
aten.repeat,
}
torch_disabled_decompositions: Set[OpOverload] = set()
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._softmax.default,
}


ENABLED_TORCH_DECOMPOSITIONS: Dict[
Expand Down

0 comments on commit 31d30e2

Please sign in to comment.