diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 8d66343254c1..7efc2412eaf7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -19,7 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
 from typing import Callable, Dict, List, Optional, Tuple, Union
-from functools import reduce
+from functools import partial, reduce
 
 import tvm
 from tvm import relax
@@ -119,23 +119,6 @@ def _retrieve_args(self, node):
         else:
             return node
 
-    @staticmethod
-    def _promote_binary_op_args(lhs, rhs):
-        if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
-            return lhs, rhs
-        elif isinstance(lhs, relax.Expr):
-            assert isinstance(lhs.struct_info, relax.TensorStructInfo)
-            return lhs, relax.const(rhs, lhs.struct_info.dtype)
-        elif isinstance(rhs, relax.Expr):
-            assert isinstance(rhs.struct_info, relax.TensorStructInfo)
-            return relax.const(lhs, rhs.struct_info.dtype), rhs
-        else:
-            assert False
-
-    def _call_binary_op(self, op, lhs, rhs):
-        lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs)
-        return self.block_builder.emit(op(lhs, rhs))
-
     ########## Unary Ops ##########
 
     def _unary_op(self, op: Callable) -> Callable:
@@ -240,66 +223,38 @@ def convert(node: fx.Node) -> relax.Var:
 
         return convert
 
-    ########## Arithmetic ##########
+    ########## Binary Ops ##########
 
-    def _add(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.add, lhs, rhs)
-        elif isinstance(lhs, relax.expr.Constant):
-            return self._call_binary_op(
-                relax.op.add, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)
-            )
-        elif isinstance(rhs, relax.expr.Constant):
-            return self._call_binary_op(
-                relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
-            )
-        return lhs + rhs
-
-    def _max(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.maximum, lhs, rhs)
-
-    def _floordiv(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.floor_divide, lhs, rhs)
-        return lhs // rhs
-
-    def _mul(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.multiply, lhs, rhs)
-        return lhs * rhs
-
-    def _pow(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.power, lhs, rhs)
-        return lhs**rhs
-
-    def _sub(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.subtract, lhs, rhs)
-        return lhs - rhs
-
-    def _truediv(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.divide, lhs, rhs)
-        return lhs / rhs
-
-    ########## Compare ##########
-
-    def _lt(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        return self._call_binary_op(relax.op.less, lhs, rhs)
-
-    def _eq(self, node: fx.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        return self._call_binary_op(relax.op.equal, lhs, rhs)
+    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            def promote_binary_op_args(lhs, rhs):
+                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+                    return lhs, rhs
+                elif isinstance(lhs, relax.Expr):
+                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
+                elif isinstance(rhs, relax.Expr):
+                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+                    return relax.const(lhs, rhs.struct_info.dtype), rhs
+                else:
+                    assert False
+
+            def call_binary_op(op, lhs, rhs):
+                lhs, rhs = promote_binary_op_args(lhs, rhs)
+                return self.block_builder.emit(op(lhs, rhs))
+
+            lhs, rhs = self.retrieve_args(node)
+            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+                return call_binary_op(relax_op, lhs, rhs)
+            elif isinstance(lhs, relax.expr.Constant):
+                return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
+            elif isinstance(rhs, relax.expr.Constant):
+                return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
+            return intrinsic_op(lhs, rhs)
+
+        return convert
 
     ########## Creation ##########
 
@@ -486,14 +441,6 @@ def _to(self, node: fx.Node) -> relax.Var:
 
     def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
         return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32"))
-    def _matmul(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        res = self._matmul_impl(
-            args[0],
-            args[1],
-        )
-        return res
-
     def _addmm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         y = self.env[node.args[1]]
@@ -1568,6 +1515,7 @@ def _getitem(self, node: fx.Node) -> relax.Var:
             assert False
 
     def create_convert_map(self):
+        import operator
         from torch import nn
         from torch import fx
 
@@ -1641,23 +1589,27 @@ def create_convert_map(self):
             "triu_": self._inplace_tril_triu(relax.op.triu),
             "triu": self._tril_triu(relax.op.triu),
             # binary
-            "add": self._add,
-            "eq": self._eq,
-            "floordiv": self._floordiv,
-            "iadd": self._add,
-            "lt": self._lt,
-            "matmul": self._matmul,
-            "max": self._max,
-            "mul": self._mul,
-            "pow": self._pow,
-            "sub": self._sub,
-            "truediv": self._truediv,
+            "add": self._binary_op(relax.op.add, operator.add),
+            "eq": self._binary_op(relax.op.equal, operator.eq),
+            "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
+            "iadd": self._binary_op(relax.op.add, operator.add),
+            "lt": self._binary_op(relax.op.less, operator.lt),
+            "matmul": self._binary_op(
+                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+            ),
+            "max": self._binary_op(relax.op.maximum, max),
+            "mul": self._binary_op(relax.op.multiply, operator.mul),
+            "pow": self._binary_op(relax.op.power, operator.pow),
+            "sub": self._binary_op(relax.op.subtract, operator.sub),
+            "truediv": self._binary_op(relax.op.divide, operator.truediv),
             # neural network
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
             "addmm": self._addmm,
             "avg_pool2d": self._avg_pool2d,
             "baddbmm": self._baddbmm,
-            "bmm": self._matmul,
+            "bmm": self._binary_op(
+                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+            ),
             "conv_transpose1d": self._conv1d_transpose_functional,
             "conv_transpose2d": self._conv2d_transpose_functional,
             "conv1d": self._conv1d_functional,
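
Net effect of the diff: eleven near-duplicate converter methods (`_add`, `_mul`, `_lt`, ...) collapse into a single `_binary_op` factory. Each convert-map entry now supplies two callables, the Relax operator for the symbolic path and the Python intrinsic for the constant fallback, with `functools.partial` adapting operators such as `matmul` that need an extra keyword argument (`out_dtype="float32"`). The sketch below reproduces just the dispatch pattern outside TVM; `Sym`, `binary_op`, and `convert_map` are illustrative stand-ins, not the importer's actual API.

```python
# Minimal sketch of the factory-closure pattern introduced by this diff
# (hypothetical names; the real code lives in TorchFXImporter._binary_op).
import operator
from dataclasses import dataclass


@dataclass(frozen=True)
class Sym:
    """Stand-in for a symbolic value such as relax.Var."""

    name: str


def binary_op(graph_op, intrinsic_op):
    """Return a per-op converter closing over both callables."""

    def convert(lhs, rhs):
        if isinstance(lhs, Sym) or isinstance(rhs, Sym):
            # Symbolic operand: emit a graph-level op
            # (the real importer goes through block_builder.emit).
            return Sym(f"{graph_op}({lhs}, {rhs})")
        # Two plain Python constants: fold eagerly with the intrinsic.
        return intrinsic_op(lhs, rhs)

    return convert


convert_map = {
    "add": binary_op("relax.op.add", operator.add),
    "mul": binary_op("relax.op.multiply", operator.mul),
}

assert convert_map["add"](1, 2) == 3          # constants fold to 3
print(convert_map["mul"](Sym("x"), 2))        # symbolic -> graph op
```

Keeping the intrinsic fallback preserves the old behavior where an op on two plain Python constants folds eagerly instead of emitting a Relax node; it also fixes `_max`, which previously fell through and returned `None` when neither operand was a `relax.Var`, whereas the unified path now returns `max(lhs, rhs)`.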