feat: support converter for torch.log10
bowang007 committed Mar 4, 2024
1 parent 9a100b6 commit ab08c63
Showing 3 changed files with 82 additions and 0 deletions.
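
The new converter lowers torch.log10 through the change-of-base identity log10(x) = ln(x) / ln(10), reusing the existing natural-log and elementwise-division converters rather than requiring a dedicated TensorRT op. A minimal plain-PyTorch sketch of that identity (illustrative only, not part of the diff):

import torch

x = torch.rand(2, 3) + 0.1  # keep inputs strictly positive
expected = torch.log10(x)
via_identity = torch.log(x) / 2.302585092994046  # ln(10)
assert torch.allclose(expected, via_identity, atol=1e-6)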
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2708,6 +2708,23 @@ def aten_ops_scalar_tensor(
)


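# Descriptive comment (added here for clarity, not part of the commit): registers the
# converter that maps torch.ops.aten.log10.default to the unary log10 impl below.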
@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
def log10(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.log10(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
@enforce_tensor_types(
    {
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -61,6 +61,22 @@ def log(
)


def log10(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    log_layer_output = log(ctx, target, source_ir, f"{name}_log", input_val)

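    # Change of base: log10(x) = ln(x) / ln(10), with ln(10) ≈ 2.302585092994046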
    ln10 = 2.302585092994046

    return impl.elementwise.div(
        ctx, target, source_ir, f"{name}_div", log_layer_output, ln10
    )


def sqrt(
    ctx: ConversionContext,
    target: Target,
49 changes: 49 additions & 0 deletions tests/py/dynamo/conversion/test_log10.py
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestLogConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_log10_float(self, input_shape, dtype):
        class log10(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log10.default(input)

        inputs = [torch.randn(input_shape, dtype=dtype)]
        self.run_test(
            log10(),
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
            ((1, 20), torch.int32, -10, 10),
            ((2, 3, 4), torch.int, -5, 5),
        ]
    )
    def test_log10_int(self, input_shape, dtype, low, high):
        class log10(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log10.default(input)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            log10(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()

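A hypothetical end-to-end check of the new converter (not part of this commit; the compile kwargs such as min_block_size are assumptions and may differ across Torch-TensorRT versions):

import torch
import torch_tensorrt


class Log10Module(torch.nn.Module):
    def forward(self, x):
        return torch.log10(x)


model = Log10Module().eval().cuda()
inputs = [torch.rand(2, 3, 4, device="cuda") + 0.1]  # strictly positive inputs
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
# Compare TensorRT output against eager PyTorch
print(torch.max(torch.abs(trt_model(*inputs) - model(*inputs))))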