From 2893d6ed3d94d63ac5117c41aa61bacb318b2a83 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 9 Nov 2023 11:40:05 -0800
Subject: [PATCH] chore: add additional BN native converter (#2446)

Signed-off-by: Dheeraj Peri
---
 .../dynamo/conversion/aten_ops_converters.py  | 31 +++++++++++++++++++
 .../dynamo/lowering/_decomposition_groups.py  |  1 -
 .../dynamo/conversion/test_batch_norm_aten.py | 19 ++++++++++++
 3 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 7ca0df5ebb..b05713c360 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -89,6 +89,37 @@ def aten_ops_batch_norm(
     )
 
 
+@dynamo_tensorrt_converter(
+    torch.ops.aten._native_batch_norm_legit_no_training.default,
+    capability_validator=one_user_validator,
+)
+def aten_ops_batch_norm_legit_no_training(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.batch_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        running_mean=args[3],
+        running_var=args[4],
+        training=False,
+        momentum=args[5],
+        eps=args[6],
+        cudnn_enabled=False,
+        return_mean_rstd=(
+            target == torch.ops.aten._native_batch_norm_legit_no_training.default
+        ),
+    )
+
+
 @dynamo_tensorrt_converter(
     torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator
 )
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index b95715b5ae..af92a9dc50 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -100,7 +100,6 @@
     aten.native_batch_norm_backward,
     aten._native_batch_norm_legit,
     aten._native_batch_norm_legit_functional,
-    aten._native_batch_norm_legit_no_training,
     aten.native_dropout_backward,
     aten.native_group_norm_backward,
     aten.native_layer_norm_backward,
diff --git a/tests/py/dynamo/conversion/test_batch_norm_aten.py b/tests/py/dynamo/conversion/test_batch_norm_aten.py
index 680e2264d1..bb1e0d8931 100644
--- a/tests/py/dynamo/conversion/test_batch_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_batch_norm_aten.py
@@ -107,6 +107,25 @@ def forward(self, x):
             inputs,
         )
 
+    def test_batchnorm_legit_no_training(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._native_batch_norm_legit_no_training.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    0.1,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
     def test_batchnorm1d_with_dynamic_shape(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
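
Below is a minimal usage sketch (not part of the patch) of how the new converter path can be exercised. It assumes a CUDA-capable environment with torch_tensorrt installed; the model, input shape, and settings are arbitrary placeholders. BatchNorm modules traced in eval() mode lower to aten._native_batch_norm_legit_no_training, which this patch converts natively instead of falling back to a decomposition.

import torch
import torch_tensorrt

# A toy model whose BatchNorm2d, traced in eval() mode, shows up as
# aten._native_batch_norm_legit_no_training in the exported graph.
model = (
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(16),
        torch.nn.ReLU(),
    )
    .eval()
    .cuda()
)

inputs = [torch.randn(1, 3, 224, 224).cuda()]

# ir="dynamo" routes compilation through the Dynamo converter registry,
# where aten_ops_batch_norm_legit_no_training is now registered.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)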