Skip to content

Commit

Permalink
[Fix] int32/64 mismatch of buffer elem_offset at HandleBufferBindScope (
Browse files Browse the repository at this point in the history
#11755)

Yet another int64/int32 mismatch at the TIR level. `ArgBinder::Bind_` requires the `elem_offset` of the arg and the view to have the same dtype, while `int64-broadcast-concat` produces an int64 `elem_offset`.
  • Loading branch information
ganler authored Jun 22, 2022
1 parent caa0d59 commit c334790
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,12 @@ class BufferBindUnwrapper : public StmtExprMutator {
view = view.MakeStrideView();
}

// Match integer bits of source->elem_offset and view->elem_offset
// as is required by ArgBinder::Bind_
if (view->elem_offset.defined() && source->elem_offset.dtype() != view->elem_offset.dtype()) {
view.CopyOnWrite()->elem_offset = cast(source->elem_offset.dtype(), view->elem_offset);
}

// Bind any variables that reference the view (e.g. elem_offset,
// strides, shape). Pass fuzzy_match=false, because all shape
// transformations should have been handled in
Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,23 @@ def test_broadcast_to_const_shape_int64(executor_kind):
tvm.testing.assert_allclose(op_res.numpy(), ref_res)


def test_broadcast_concat_shape_int64(executor_kind):
    """Check broadcast_to with an int64 constant shape feeding a concatenate.

    Builds a relay function that broadcasts a float32 input to a shape given
    as an int64 constant, concatenates the single result along axis 0, and
    compares the executor output against the equivalent NumPy computation on
    every enabled target.
    """
    in_shape = (1, 2, 1, 1)
    out_shape = [1, 2, 2, 1]

    # Relay graph: data -> broadcast_to(int64 shape) -> concatenate(axis=0)
    data_var = relay.var("data", relay.TensorType(in_shape, "float32"))
    shape_const = relay.const(out_shape, dtype="int64")
    broadcasted = relay.op.broadcast_to(data_var, shape_const)
    concatenated = relay.op.concatenate((broadcasted,), axis=0)
    func = relay.Function([data_var], concatenated)

    # NumPy reference for the same computation
    data_np = np.zeros(in_shape).astype("float32")
    expected = np.concatenate((np.broadcast_to(data_np, out_shape),), axis=0)

    for target, dev in tvm.testing.enabled_targets():
        executor = relay.create_executor(executor_kind, device=dev, target=target)
        result = executor.evaluate(func)(data_np)
        tvm.testing.assert_allclose(result.numpy(), expected)


@tvm.testing.uses_gpu
def test_broadcast_to_like(executor_kind):
shape = (4, 1, 6)
Expand Down

0 comments on commit c334790

Please sign in to comment.