Skip to content

Commit

Permalink
Revert "don't allow MLB assigns with different axes (tinygrad#3483)" (t…
Browse files Browse the repository at this point in the history
…inygrad#3554)

This reverts commit f19d8bb.
  • Loading branch information
chenyuxyz authored Mar 1, 2024
1 parent f19d8bb commit cfd23f3
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 11 deletions.
9 changes: 0 additions & 9 deletions test/test_multitensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,15 +355,6 @@ def test_reshape_on_axis_uneven(self):
np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
assert t1.lazydata.axis == 2

def test_mlb_assign_change_axis(self):
devices = (d0, d1)

t_none = Tensor.zeros((16, 16)).shard(devices).contiguous().realize()
t_zero = Tensor.ones((16, 16)).shard(devices, axis=0)
with self.assertRaises(AssertionError):
# don't allow assigns that change axes
t_none.assign(t_zero)

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
# shrink a multitensor on sharded axis
Expand Down
2 changes: 0 additions & 2 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ def assign(self, x) -> Tensor:
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
# NOTE: we allow cross device assign
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
if isinstance(self.lazydata, MultiLazyBuffer):
assert self.lazydata.axis == x.lazydata.axis
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
Expand Down

0 comments on commit cfd23f3

Please sign in to comment.