forked from tinygrad/tinygrad
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
simple multitensor API (tinygrad#2903)

* simple multitensor API
* test multitensor
* mt work
* new api
* copies
* all but data parallel
* allreduce there
* works, but axis sharded
* fix all mt tests
* features/multi
* work
* backprop
* fix tests
* tests passing
* mt progress
* cleanups
* less lines
* tensor cleanup
* save more lines
* mypy passes
* fix tests
* skip for cuda too
* bump download cache
Showing 5 changed files with 285 additions and 20 deletions.
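
The new user-facing API is Tensor.shard(devices, axis) and its in-place variant shard_, backed by a MultiLazyBuffer that keeps one LazyBuffer per device. A minimal usage sketch distilled from the tests below (the f"{Device.DEFAULT}:n" device names follow the tests' convention; nothing here goes beyond what the diff itself exercises):

import numpy as np
from tinygrad import Tensor, Device

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"

X = Tensor.ones(256).contiguous().realize()
X.shard_((d0, d1), 0)                        # split along axis 0: each device holds a (128,) shard
W = Tensor.ones(256).shard((d0, d1), None)   # axis=None: replicate the whole tensor on both devices
O = X + W                                    # elementwise ops run shard-by-shard
np.testing.assert_allclose(O.numpy(), 2)     # .numpy() gathers the shards back into one array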
@@ -0,0 +1,137 @@

import unittest
from tinygrad import Tensor, Device, nn, GlobalCounters
from tinygrad.helpers import CI
from tinygrad.nn.state import get_parameters
import numpy as np

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
d2, d3 = f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4"
N = 128

# shard_x is "data parallel"
# shard_w is "model parallel"

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA"}, "no GPU CI")
class TestMultiTensor(unittest.TestCase):
  def test_shard(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d1), 0)
    for lb in X.lazydata.lbs:
      assert lb.shape == (128,)

  def test_numpy(self):
    X = Tensor.ones(256)
    X.shard_((d0, d1), 0)
    np.testing.assert_allclose(X.numpy(), 1)

  def _test_simple_add_axis(self, shard_x, shard_w):
    X = Tensor.ones(256).contiguous().realize()
    W = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d1), shard_x)
    W.shard_((d0, d1), shard_w)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def test_simple_add(self): return self._test_simple_add_axis(None, None)
  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)

  def test_four_add(self):
    X = Tensor.ones(256, 256).contiguous().realize()
    W = Tensor.ones(256, 256).contiguous().realize()
    X.shard_((d0, d1, d2, d3), 1)
    W.shard_((d0, d1, d2, d3), None)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def _test_simple_reduce_axis(self, shard_x):
    X = Tensor.ones(256, 256).contiguous().realize()
    X.shard_((d0, d1), shard_x)
    O = X.sum(axis=1)
    np.testing.assert_allclose(O.numpy(), 256)

  def test_simple_reduce(self): return self._test_simple_reduce_axis(None)
  def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0)
  def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1)

  def _test_matmul_shard_axis(self, shard_x, shard_w):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard((d0, d1), shard_x)
    Ws = W.shard((d0, d1), shard_w)
    O = (Xs@Ws)
    np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def _test_double_matmul_shard_axis(self, shard_x, shard_w):
    X = Tensor.kaiming_uniform(N, N).realize()
    W1 = Tensor.kaiming_uniform(N, N).realize()
    W2 = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard((d0, d1), shard_x)
    W1s = W1.shard((d0, d1), shard_w)
    W2s = W2.shard((d0, d1), shard_w)
    O = (Xs@W1s)@W2s
    np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None)
  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None)
  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None)
  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0)
  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1)

  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0)
  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1)
  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0)
  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1)

  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None)
  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None)
  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0)
  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1)

  def test_conv_data_shard(self):
    conv = nn.Conv2d(3, 16, 3, bias=False)
    for p in get_parameters(conv): p.shard_((d0, d1))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_conv_bias_data_shard(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_((d0, d1))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_backprop_conv(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_((d0, d1))
    optim = nn.optim.Adam(get_parameters(conv))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    optim.zero_grad()
    out.mean().backward()
    #for p in get_parameters(conv): p.grad.realize()
    optim.step()

  def test_data_parallel_resnet(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18

    fake_image = Tensor.rand((2, 3, 224, 224))
    fake_image_sharded = fake_image.shard((d0, d1), axis=0)
    print(fake_image_sharded.shape)
    m = ResNet18()
    m.load_from_pretrained()
    real_output = m(fake_image).numpy()
    for p in get_parameters(m): p.shard_((d0, d1)).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).realize()
    assert shard_output.lazydata.lbs[0].shape == (1, 1000)
    assert shard_output.lazydata.lbs[1].shape == (1, 1000)
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

if __name__ == '__main__':
  unittest.main()
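
The conv and resnet tests above spell out the data-parallel recipe hinted at by the "shard_x is data parallel" comment: replicate every parameter across the devices (shard_ with no axis) and split only the batch dimension. A condensed sketch of that pattern; nn.Linear here is just a stand-in for the nn.Conv2d / ResNet18 modules the tests actually use:

from tinygrad import Tensor, Device, nn
from tinygrad.nn.state import get_parameters

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"

model = nn.Linear(32, 8)                              # stand-in module (tests use Conv2d / ResNet18)
for p in get_parameters(model): p.shard_((d0, d1))    # no axis: weights are replicated on both devices
opt = nn.optim.Adam(get_parameters(model))

x = Tensor.rand(4, 32).shard((d0, d1), axis=0)        # only the batch is split: 2 samples per device
opt.zero_grad()
model(x).mean().backward()                            # forward/backward run on the sharded graph
opt.step()

The MultiLazyBuffer that makes this work, the "features/multi" item from the commit message, is the second changed file: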
@@ -0,0 +1,103 @@

from __future__ import annotations
from typing import Optional, Union, Any, Tuple, List
import functools
from tinygrad.helpers import all_same, dedup
from tinygrad.dtype import DType
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.shape.shapetracker import ShapeTracker, sint

def all_reduce(lbs):
  # TODO: replace this with ring reduce
  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]

def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
  assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
  sz = lbs[0].shape[axis] // len(lbs)
  return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]

class MultiLazyBuffer:
  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
    assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs) >= 2, "all lbs must be LazyBuffers, and we need at least two of them"
    assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
    self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
    self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))

  def __repr__(self):
    return f"<MLB{chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"

  @staticmethod
  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
    lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
    return MultiLazyBuffer([lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)

  def copy_to_device(self, device:str) -> LazyBuffer:
    if self.axis is None: return self.lbs[0].copy_to_device(device)
    sz = self.lbs[0].shape[self.axis]
    llbs = []
    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
      pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
      llbs.append(lb.pad(pad_arg))
    return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)

  # TODO: fix this
  def is_unrealized_contiguous_const(self): return False

  # passthroughs
  def schedule(self, seen=None): return create_schedule(self.lbs, seen)
  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
  def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)

  # elementwise is simple
  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
    msrcs = (self,)+in_srcs
    assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
    assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"

    # NOTE: they all have to share an axis, we always choose [-1]
    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
    srcs = []
    for mlb in msrcs:
      if mlb.axis == axis: srcs.append(mlb.lbs)
      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)

  def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))

  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
    if self.axis is not None and new_shape[self.axis] == 1:
      # all-reduce on sharded axes
      return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
    # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
    return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)

  # *** movement ops ***

  def reshape(self, arg:Tuple[sint, ...]):
    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
    # TODO: this can be wrong
    st = ShapeTracker.from_shape(self.shape)
    rs = st.real_strides()[self.axis]
    new_axis = st.reshape(arg).real_strides().index(rs)
    narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
    return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)

  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
    assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
  def expand(self, arg:Tuple[sint, ...]):
    # NOTE: this assert isn't needed, sharded axis can have dim 1
    assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
  def permute(self, arg:Tuple[int, ...]):
    # all permutes supported!
    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
    assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
    narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
    return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
  def stride(self, arg:Tuple[int, ...]):
    assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)
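
copy_to_device gathers a sharded axis with a pad-and-add trick: every shard is copied to the target device, zero-padded to the full size at its own offset, and the padded pieces are summed (all_reduce plays the same copy-and-add game across devices). A tiny numpy illustration of that gather, purely to show the arithmetic, not tinygrad code:

import numpy as np

# two shards of a length-4 tensor split on axis 0
shards = [np.array([1., 2.]), np.array([3., 4.])]
sz, full = 2, 4
# pad each shard with zeros so it sits at its own offset in the full shape...
padded = [np.pad(s, (i*sz, full - (i+1)*sz)) for i, s in enumerate(shards)]
# ...then summing the padded pieces recovers the gathered tensor: [1. 2. 3. 4.]
gathered = np.sum(padded, axis=0)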