diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index 3350e9d716e0..fa698d22c36b 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -306,7 +306,7 @@ def update(self, net, decay):
   if len(GPUS) > 1:
     for k, x in get_state_dict(model).items():
-      if not getenv('SYNCBN') and ('running_mean' in k or 'running_bias' in k):
+      if not getenv('SYNCBN') and ('running_mean' in k or 'running_var' in k):
         x.shard_(GPUS, axis=0)
       else:
         x.to_(GPUS)
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 6262f28fa2c7..1c5db1375086 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -4,7 +4,7 @@
 from tinygrad.device import BufferCopy
 from tinygrad.ops import LoadOps, ReduceOps
 from tinygrad.helpers import CI
-from tinygrad.nn.state import get_parameters
+from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.realize import create_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
@@ -355,6 +355,15 @@ def test_reshape_on_axis_uneven(self):
     np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
     assert t1.lazydata.axis == 2
 
+  def test_mlb_assign_change_axis(self):
+    devices = (d0, d1)
+
+    t_none = Tensor.zeros((16, 16)).shard(devices).contiguous().realize()
+    t_zero = Tensor.ones((16, 16)).shard(devices, axis=0)
+    with self.assertRaises(AssertionError):
+      # don't allow assigns that change axes
+      t_none.assign(t_zero)
+
 @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
 class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
   # shrink a multitensor on sharded axis
@@ -544,16 +553,17 @@ def __call__(self, x:Tensor):
   def test_unsynced_backprop_sync_weights(self):
     from extra.lr_scheduler import OneCycleLR
     from examples.hlb_cifar10 import UnsyncedBatchNorm
-    from tinygrad.features.multi import MultiLazyBuffer
     GPUS = (d1, d2)
 
     with Tensor.train():
       conv = nn.Conv2d(3, 16, 3)
       bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
 
-      for p in get_parameters([conv, bn]):
-        if not isinstance(p.lazydata, MultiLazyBuffer):
-          p.shard_(GPUS)
+      for k, p in get_state_dict([conv, bn]).items():
+        if 'running_mean' in k or 'running_var' in k:
+          p.shard_(GPUS, axis=0)
+        else:
+          p.to_(GPUS)
       optim = nn.optim.Adam(get_parameters([conv, bn]))
       lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
       lr_sched.step()
@@ -597,8 +607,13 @@ def test_synced_vs_unsynced_bn(self):
     synced_bn = BatchNorm2d(8)
     unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
 
-    for p in get_parameters([synced_bn, unsynced_bn]):
+    for p in get_parameters(synced_bn):
       p.shard_(devices)
+    for k, p in get_state_dict(unsynced_bn).items():
+      if 'running_mean' in k or 'running_var' in k:
+        p.shard_(devices, axis=0)
+      else:
+        p.to_(devices)
 
     synced_out = synced_bn(x)
     synced_si = [si for si in create_schedule(synced_out.lazydata.lbs)]
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e643b5291759..f823727de1fe 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -142,6 +142,8 @@ def assign(self, x) -> Tensor:
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
+    if isinstance(self.lazydata, MultiLazyBuffer):
+      assert self.lazydata.axis == x.lazydata.axis
     assert not x.requires_grad  # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
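For context, a minimal runnable sketch (not part of the diff) of what the new check in Tensor.assign enforces, and of why UnsyncedBatchNorm running statistics are sharded on axis 0 while everything else is replicated. The device names built from Device.DEFAULT and the tensor shapes are illustrative assumptions; any two usable devices of one backend will do.

# sketch only: mirrors the d0/d1 device naming convention used in test_multitensor.py
from tinygrad.tensor import Tensor
from tinygrad.device import Device

d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"

# replicated multitensor (axis=None): each device holds a full copy
t_replicated = Tensor.zeros(16, 16).shard((d0, d1)).contiguous().realize()
# sharded multitensor (axis=0): each device holds half of dim 0, like the per-device running stats
t_sharded = Tensor.ones(16, 16).shard((d0, d1), axis=0)

try:
  t_replicated.assign(t_sharded)  # axis None vs axis 0 -> AssertionError from the new check
except AssertionError:
  print("rejected: assign may not change a multitensor's shard axis")

# assigns that keep the axis are unaffected, e.g. optimizer/EMA updates of
# replicated weights or of the axis-0 sharded running_mean/running_var
t_replicated.assign(Tensor.full((16, 16), 2.0).shard((d0, d1))).realize()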