diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 6410257ec7908..6262f28fa2c78 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -355,15 +355,6 @@ def test_reshape_on_axis_uneven(self):
     np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
     assert t1.lazydata.axis == 2
 
-  def test_mlb_assign_change_axis(self):
-    devices = (d0, d1)
-
-    t_none = Tensor.zeros((16, 16)).shard(devices).contiguous().realize()
-    t_zero = Tensor.ones((16, 16)).shard(devices, axis=0)
-    with self.assertRaises(AssertionError):
-      # don't allow assigns that change axes
-      t_none.assign(t_zero)
-
 @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
 class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
   # shrink a multitensor on sharded axis
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index f823727de1feb..e643b52917594 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -142,8 +142,6 @@ def assign(self, x) -> Tensor:
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
-    if isinstance(self.lazydata, MultiLazyBuffer):
-      assert self.lazydata.axis == x.lazydata.axis
     assert not x.requires_grad  # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
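
For context, a minimal sketch of what this diff permits. Before the change, assign between MultiLazyBuffers sharded on different axes tripped the removed assertion, which is exactly what the deleted test_mlb_assign_change_axis checked for; with the assertion gone, the axis-changing assign proceeds. The device names below and the trailing realize() are illustrative assumptions, not part of this change.

from tinygrad import Tensor

# Two devices; names are illustrative and backend-dependent.
d0, d1 = "CLANG:0", "CLANG:1"
devices = (d0, d1)

# t_none is sharded with no axis (replicated); t_zero is sharded on axis 0.
t_none = Tensor.zeros((16, 16)).shard(devices).contiguous().realize()
t_zero = Tensor.ones((16, 16)).shard(devices, axis=0)

# Previously raised AssertionError ("don't allow assigns that change axes");
# after this diff the assign is no longer rejected up front.
t_none.assign(t_zero).realize()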