don't allow MLB assigns with different axes (tinygrad#3557)
* allow LB <- MLB assign, but don't reuse buffer

* update test

* update test

* assign assert axes are the same

* update tests to manually shard running stats

* unused import
chaosagent authored Mar 1, 2024
1 parent cfd23f3 commit d16aa89
Showing 3 changed files with 24 additions and 7 deletions.
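In short, a tensor backed by a MultiLazyBuffer may now only be assigned from a value sharded on the same axis; a mismatched axis raises an AssertionError. Below is a minimal sketch of the rule, assuming two logical devices derived from Device.DEFAULT as test/test_multitensor.py sets up (the variable names are illustrative, not from the diff):

from tinygrad.device import Device
from tinygrad.tensor import Tensor

# two logical devices on the default backend, as the test suite does
devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")

# replicated across devices (no shard axis) vs sharded along axis 0
replicated = Tensor.zeros(16, 16).shard(devices).contiguous().realize()
sharded = Tensor.ones(16, 16).shard(devices, axis=0).contiguous().realize()

# assigning between tensors that share the same shard axis still works
sharded.assign(Tensor.ones(16, 16).shard(devices, axis=0)).realize()

# assigning an axis=0 shard into a replicated tensor now trips the new assert
try:
  replicated.assign(sharded)
except AssertionError:
  print("MLB assign with a different axis is rejected")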
examples/hlb_cifar10.py (2 changes: 1 addition & 1 deletion)
@@ -306,7 +306,7 @@ def update(self, net, decay):
 
   if len(GPUS) > 1:
     for k, x in get_state_dict(model).items():
-      if not getenv('SYNCBN') and ('running_mean' in k or 'running_bias' in k):
+      if not getenv('SYNCBN') and ('running_mean' in k or 'running_var' in k):
         x.shard_(GPUS, axis=0)
       else:
         x.to_(GPUS)
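With SYNCBN off, UnsyncedBatchNorm keeps one set of running statistics per device and writes them back with assign, so those buffers must already be sharded on axis 0 rather than replicated with to_. A sketch of that pattern, assuming a (num_devices, features)-shaped buffer like the ones UnsyncedBatchNorm keeps (names and the update rule here are illustrative):

from tinygrad.device import Device
from tinygrad.tensor import Tensor

devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")

# a per-device running statistic, sharded on its leading device dimension
running_mean = Tensor.zeros(len(devices), 16).shard(devices, axis=0).contiguous().realize()

# the per-device batch statistics are sharded on axis 0 as well, so the
# write-back assign keeps the same axis and passes the new check
batch_mean = Tensor.ones(len(devices), 16).shard(devices, axis=0)
running_mean.assign(running_mean * 0.9 + batch_mean * 0.1).realize()

# had running_mean only been replicated with .to_(devices), this assign would
# now raise, which is why the running stats are sharded explicitly above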
test/test_multitensor.py (27 changes: 21 additions & 6 deletions)
@@ -4,7 +4,7 @@
 from tinygrad.device import BufferCopy
 from tinygrad.ops import LoadOps, ReduceOps
 from tinygrad.helpers import CI
-from tinygrad.nn.state import get_parameters
+from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.realize import create_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
@@ -355,6 +355,15 @@ def test_reshape_on_axis_uneven(self):
     np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
     assert t1.lazydata.axis == 2
 
+  def test_mlb_assign_change_axis(self):
+    devices = (d0, d1)
+
+    t_none = Tensor.zeros((16, 16)).shard(devices).contiguous().realize()
+    t_zero = Tensor.ones((16, 16)).shard(devices, axis=0)
+    with self.assertRaises(AssertionError):
+      # don't allow assigns that change axes
+      t_none.assign(t_zero)
+
 @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
 class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
   # shrink a multitensor on sharded axis
@@ -544,16 +553,17 @@ def __call__(self, x:Tensor):
   def test_unsynced_backprop_sync_weights(self):
     from extra.lr_scheduler import OneCycleLR
     from examples.hlb_cifar10 import UnsyncedBatchNorm
-    from tinygrad.features.multi import MultiLazyBuffer
     GPUS = (d1, d2)
 
     with Tensor.train():
       conv = nn.Conv2d(3, 16, 3)
       bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
 
-      for p in get_parameters([conv, bn]):
-        if not isinstance(p.lazydata, MultiLazyBuffer):
-          p.shard_(GPUS)
+      for k, p in get_state_dict([conv, bn]).items():
+        if 'running_mean' in k or 'running_var' in k:
+          p.shard_(GPUS, axis=0)
+        else:
+          p.to_(GPUS)
       optim = nn.optim.Adam(get_parameters([conv, bn]))
       lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
       lr_sched.step()
@@ -597,8 +607,13 @@ def test_synced_vs_unsynced_bn(self):
       synced_bn = BatchNorm2d(8)
       unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
 
-      for p in get_parameters([synced_bn, unsynced_bn]):
+      for p in get_parameters(synced_bn):
         p.shard_(devices)
+      for k, p in get_state_dict(unsynced_bn).items():
+        if 'running_mean' in k or 'running_var' in k:
+          p.shard_(devices, axis=0)
+        else:
+          p.to_(devices)
 
       synced_out = synced_bn(x)
       synced_si = [si for si in create_schedule(synced_out.lazydata.lbs)]
tinygrad/tensor.py (2 changes: 2 additions & 0 deletions)
@@ -142,6 +142,8 @@ def assign(self, x) -> Tensor:
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
+    if isinstance(self.lazydata, MultiLazyBuffer):
+      assert self.lazydata.axis == x.lazydata.axis
     assert not x.requires_grad # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
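The guard only applies when the destination is itself a MultiLazyBuffer; plain single-device assigns are untouched. If the layouts genuinely differ, one way to satisfy the check is to bring the data back to the destination's layout before assigning. A sketch under the same two-device assumption as above; the round trip through numpy() is just one illustrative option, not an API introduced by this commit:

import numpy as np
from tinygrad.device import Device
from tinygrad.tensor import Tensor

devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
dst = Tensor.zeros(16, 16).shard(devices).contiguous().realize()  # replicated, axis=None
src = Tensor.ones(16, 16).shard(devices, axis=0)                  # sharded, axis=0

# dst.assign(src) would now raise; gather src to host and re-shard it to dst's layout first
dst.assign(Tensor(src.numpy()).shard(devices)).realize()
np.testing.assert_allclose(dst.numpy(), np.ones((16, 16)))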
