Commit
[Relax][PyTorch] Cleanup Statistical, Search and DataType op converters (#17372)

* cleanup `_mean()`

* cleanup `_sum()`

* cleanup `_argmax_argmin()`

* cleanup datatype ops
mshr-h authored Sep 15, 2024
1 parent 4bc61a1 commit 48d661c
Showing 1 changed file with 55 additions and 68 deletions.
123 changes: 55 additions & 68 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -884,6 +884,61 @@ def _unbind(self, node: fx.Node) -> relax.Var:
            ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
        return self.block_builder.emit(relax.Tuple(ret))

    ########## Statistical ##########

    def _mean(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))

    def _sum(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
        if len(args) == 1:
            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
        return self.block_builder.emit(relax.op.sum(args[0], args[1]))

    ########## Search ##########

    def _argmax_argmin(self, op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node):
            x = self.env[node.args[0]]
            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
            return self.block_builder.emit(op(x, dim, keepdim))

        return convert

    ########## DataType ##########

    def _float(self, node: fx.Node) -> relax.Var:
        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))

    def _half(self, node: fx.Node) -> relax.Var:
        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))

    def _to(self, node: fx.Node) -> relax.Var:
        import torch

        x = self.env[node.args[0]]
        if len(node.args) == 2:
            if isinstance(node.args[1], torch.dtype):
                dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
                return self.block_builder.emit(relax.op.astype(x, dtype))
        elif "dtype" in node.kwargs:
            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
            return self.block_builder.emit(relax.op.astype(x, dtype))
        return x

    def _type(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
        return self.block_builder.emit(relax.op.astype(x, dtype))

    ########## Creation ##########

    def _arange(self, node: fx.Node) -> relax.Var:
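
The hunk above converges `_mean`, `_sum`, and the `_argmax_argmin` factory on a single idiom: read each optional argument positionally when present, otherwise fall back to the keyword with its PyTorch default. A minimal standalone sketch of that idiom follows; `FakeNode` is an illustrative stub, not TVM code, since the real converters receive `torch.fx.Node` objects:

    # Sketch of the positional-or-keyword idiom used by the new converters.
    # FakeNode stands in for torch.fx.Node; only .args and .kwargs matter here.
    class FakeNode:
        def __init__(self, args=(), kwargs=None):
            self.args = args
            self.kwargs = kwargs or {}

    def extract_reduce_args(node):
        x = node.args[0]
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
        keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return x, dim, keepdim

    # Positional and keyword call styles resolve to the same triple:
    assert extract_reduce_args(FakeNode(("x", 1, True))) == ("x", 1, True)
    assert extract_reduce_args(FakeNode(("x",), {"dim": 1})) == ("x", 1, False)

The old converters in the hunks below did the same job with separate positional and keyword branches; that duplication is what this commit removes.
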
@@ -1022,48 +1077,6 @@ def _full(self, node: fx.Node) -> relax.Var:
            )
        )

    ########## Statistical ##########

    def _sum(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
        if len(args) == 1:
            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
        return self.block_builder.emit(relax.op.sum(args[0], args[1]))

    def _mean(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
        if len(args) == 1:
            return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim))
        return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim))

    ########## DataType ##########

    def _float(self, node: fx.Node) -> relax.Var:
        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))

    def _half(self, node: fx.Node) -> relax.Var:
        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))

    def _type(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
        return self.block_builder.emit(relax.op.astype(x, dtype))

    def _to(self, node: fx.Node) -> relax.Var:
        import torch

        x = self.env[node.args[0]]
        if len(node.args) == 2:
            if isinstance(node.args[1], torch.dtype):
                dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
                return self.block_builder.emit(relax.op.astype(x, dtype))
        elif "dtype" in node.kwargs:
            dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
            return self.block_builder.emit(relax.op.astype(x, dtype))
        return x

    ########## Manipulation ##########

    def _cat(self, node: fx.Node) -> relax.Var:
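
One behavior worth noting in `_to`, which moves unchanged in this commit: it only emits a cast when an explicit `torch.dtype` is present, while a device-only call such as `x.to("cuda")` falls through and returns `x` as-is. A hedged sketch of how such a call surfaces as an FX node, using plain `torch.fx` tracing with nothing TVM-specific:

    # Sketch: what a Tensor.to(...) call looks like as an FX node.
    import torch
    from torch import fx

    class CastOnly(torch.nn.Module):
        def forward(self, x):
            return x.to(torch.float16)  # call_method "to", args = (x, torch.float16)

    gm = fx.symbolic_trace(CastOnly())
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target == "to":
            # node.args[1] is the torch.dtype that _to checks with isinstance(...)
            print(node.args, node.kwargs)
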
@@ -1220,32 +1233,6 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
        self.env[node.args[0]] = output
        return output

    ########## Search ##########

    def _argmax_argmin(self, op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node):
            x = self.env[node.args[0]]
            dim = None
            keepdims = False

            if len(node.args) > 1:
                dim = node.args[1]
            if len(node.args) > 2:
                keepdims = node.args[2]

            if "dim" in node.kwargs:
                dim = node.kwargs["dim"]
            if "keepdim" in node.kwargs:
                keepdims = node.kwargs["keepdim"]
            if "keepdims" in node.kwargs:
                keepdims = node.kwargs["keepdims"]

            return self.block_builder.emit(op(x, dim, keepdims))

        return convert

    ########## Neural Network ##########

    def _softmax(self, node: fx.Node) -> relax.Var:
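For reference, a hedged end-to-end sketch of exercising the cleaned-up converters; it assumes a TVM build that exposes `from_fx` from this file's module (`tvm.relax.frontend.torch`):

    # Trace a small model and import it through the converters touched
    # by this commit (sum -> argmax -> float).
    import torch
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    class Reduce(torch.nn.Module):
        def forward(self, x):
            return x.sum(1).argmax(dim=0, keepdim=True).float()

    graph_model = fx.symbolic_trace(Reduce())
    # input_info pairs each graph input's shape with its dtype.
    mod = from_fx(graph_model, [((4, 8), "float32")])
    mod.show()  # should contain the relax sum/argmax/astype calls emitted above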
