Skip to content

Commit

Permalink
[Relay][Frontend][Torch] add aten:broadcast_to (#16319)
Browse files Browse the repository at this point in the history
Recently, I worked with the Stable Video Diffusion model, which contains the `aten::broadcast_to` op, but TVM does not support it. 

Add support for it here.
  • Loading branch information
huanmei9 authored Dec 31, 2023
1 parent 506eff2 commit 2da3798
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,6 +2328,21 @@ def broadcast_tensors(self, inputs, input_types):
res_shape = list(torch.broadcast_tensors(*map(torch.empty, infer_shape_value))[0].shape)
return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]

def broadcast_to(self, inputs, input_types):
tensor = inputs[0]
new_shape = inputs[1]
import torch

if not isinstance(new_shape, (list, tuple, torch.Size)):
msg = f"Data type {type(new_shape)} could not be parsed in broadcast_to op"
raise AssertionError(msg)

for i, dim in enumerate(new_shape):
if not isinstance(dim, int):
new_shape[i] = int(_infer_value(dim, {}).numpy())

return _op.broadcast_to(tensor, new_shape)

def Bool(self, inputs, input_types):
assert len(inputs) == 1
return inputs[0]
Expand Down Expand Up @@ -4190,6 +4205,7 @@ def create_convert_map(self):
"aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
"aten::expand_as": self.expand_as,
"aten::broadcast_tensors": self.broadcast_tensors,
"aten::broadcast_to": self.broadcast_to,
"aten::lt": self.make_elemwise("less"),
"aten::gt": self.make_elemwise("greater"),
"aten::le": self.make_elemwise("less_equal"),
Expand Down
25 changes: 25 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,31 @@ def forward(self, x, y, z):
verify_model(BroadCastTensors2().float().eval(), input_data=[x, y, z])


@tvm.testing.uses_gpu
def test_forward_broadcast_to():
"""test_forward_broadcast_to"""
torch.set_grad_enabled(False)

class BroadCastTo1(Module):
def forward(self, x):
return torch.broadcast_to(x, (3, 3))

x = torch.tensor([1, 2, 3])
verify_model(BroadCastTo1().float().eval(), input_data=[x])

class BroadCastTo2(Module):
def __init__(self):
super().__init__()
self.y = torch.tensor(1)
self.z = torch.tensor(2)

def forward(self, x):
return torch.broadcast_to(x, (self.y + self.z, 3))

x = torch.tensor([1, 2, 3])
verify_model(BroadCastTo2().float().eval(), input_data=[x])


@tvm.testing.uses_gpu
def test_forward_pow():
"""test_forward_pow"""
Expand Down

0 comments on commit 2da3798

Please sign in to comment.