simple multitensor API (tinygrad#2903)
* simple multitensor API

* test multitensor

* mt work

* new api

* copies

* all but data parallel

* allreduce there

* works, but axis sharded

* fix all mt tests

* features/multi

* work

* backprop

* fix tests

* tests passing

* mt progress

* cleanups

* less lines

* tensor cleanup

* save more lines

* mypy passes

* fix tests

* skip for cuda too

* bump download cache
geohot authored Jan 3, 2024
1 parent 5522ba2 commit f494b9d
Showing 5 changed files with 285 additions and 20 deletions.
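As context for the diffs below, here is a minimal sketch of the new sharding API as exercised by the tests added in this commit. The behavior described (in-place `shard_` returning the tensor, `axis=None` replicating across devices, `.numpy()` gathering the result) is inferred from test/test_multitensor.py and tinygrad/features/multi.py in this commit; the device names follow the tests' `f"{Device.DEFAULT}:N"` convention.

```python
from tinygrad import Tensor, Device

# two "devices" on the same backend, as used by the tests below
d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"

# shard_ shards a tensor across devices in place (and returns it);
# axis=0 splits along dim 0, axis=None replicates the tensor on every device
X = Tensor.ones(256).contiguous().realize().shard_((d0, d1), 0)
W = Tensor.ones(256).contiguous().realize().shard_((d0, d1), None)

# ops work across the shards; .numpy() gathers the result back
print((X + W).numpy())  # all 2s
```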
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -1,7 +1,7 @@
 name: Unit Tests
 env:
   # increment this when downloads substantially change to avoid the internet
-  DOWNLOAD_CACHE_VERSION: '3'
+  DOWNLOAD_CACHE_VERSION: '4'
 
 on:
   push:
137 changes: 137 additions & 0 deletions test/test_multitensor.py
@@ -0,0 +1,137 @@
import unittest
from tinygrad import Tensor, Device, nn, GlobalCounters
from tinygrad.helpers import CI
from tinygrad.nn.state import get_parameters
import numpy as np

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
d2, d3 = f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4"
N = 128

# shard_x is "data parallel"
# shard_w is "model parallel"

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA"}, "no GPU CI")
class TestMultiTensor(unittest.TestCase):
  def test_shard(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d1), 0)
    for lb in X.lazydata.lbs:
      assert lb.shape == (128,)

  def test_numpy(self):
    X = Tensor.ones(256)
    X.shard_((d0, d1), 0)
    np.testing.assert_allclose(X.numpy(), 1)

  def _test_simple_add_axis(self, shard_x, shard_w):
    X = Tensor.ones(256).contiguous().realize()
    W = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d1), shard_x)
    W.shard_((d0, d1), shard_w)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def test_simple_add(self): return self._test_simple_add_axis(None, None)
  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)

  def test_four_add(self):
    X = Tensor.ones(256, 256).contiguous().realize()
    W = Tensor.ones(256, 256).contiguous().realize()
    X.shard_((d0, d1, d2, d3), 1)
    W.shard_((d0, d1, d2, d3), None)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def _test_simple_reduce_axis(self, shard_x):
    X = Tensor.ones(256, 256).contiguous().realize()
    X.shard_((d0, d1), shard_x)
    O = X.sum(axis=1)
    np.testing.assert_allclose(O.numpy(), 256)

  def test_simple_reduce(self): return self._test_simple_reduce_axis(None)
  def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0)
  def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1)

  def _test_matmul_shard_axis(self, shard_x, shard_w):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard((d0, d1), shard_x)
    Ws = W.shard((d0, d1), shard_w)
    O = (Xs@Ws)
    np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def _test_double_matmul_shard_axis(self, shard_x, shard_w):
    X = Tensor.kaiming_uniform(N, N).realize()
    W1 = Tensor.kaiming_uniform(N, N).realize()
    W2 = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard((d0, d1), shard_x)
    W1s = W1.shard((d0, d1), shard_w)
    W2s = W2.shard((d0, d1), shard_w)
    O = (Xs@W1s)@W2s
    np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None)
  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None)
  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None)
  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0)
  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1)

  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0)
  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1)
  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0)
  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1)

  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None)
  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None)
  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0)
  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1)

  def test_conv_data_shard(self):
    conv = nn.Conv2d(3, 16, 3, bias=False)
    for p in get_parameters(conv): p.shard_((d0, d1))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_conv_bias_data_shard(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_((d0, d1))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_backprop_conv(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_((d0, d1))
    optim = nn.optim.Adam(get_parameters(conv))
    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
    out = conv(fake_image)
    optim.zero_grad()
    out.mean().backward()
    #for p in get_parameters(conv): p.grad.realize()
    optim.step()

  def test_data_parallel_resnet(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18

    fake_image = Tensor.rand((2, 3, 224, 224))
    fake_image_sharded = fake_image.shard((d0, d1), axis=0)
    print(fake_image_sharded.shape)
    m = ResNet18()
    m.load_from_pretrained()
    real_output = m(fake_image).numpy()
    for p in get_parameters(m): p.shard_((d0, d1)).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).realize()
    assert shard_output.lazydata.lbs[0].shape == (1, 1000)
    assert shard_output.lazydata.lbs[1].shape == (1, 1000)
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

if __name__ == '__main__':
  unittest.main()
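The tests above note that sharding X is "data parallel" and sharding W is "model parallel". A plain numpy sketch (no tinygrad; the function names here are illustrative only) of why either sharding of the matmul reproduces the unsharded result:

```python
import numpy as np

def data_parallel_matmul(X, W, n_devices=2):
  # shard X along axis 0 ("data parallel"): each device multiplies its rows by a full copy of W
  parts = np.split(X, n_devices, axis=0)
  return np.concatenate([p @ W for p in parts], axis=0)

def model_parallel_matmul(X, W, n_devices=2):
  # shard W along axis 1 ("model parallel"): each device produces a slice of the output columns
  parts = np.split(W, n_devices, axis=1)
  return np.concatenate([X @ p for p in parts], axis=1)

X, W = np.random.randn(128, 128), np.random.randn(128, 128)
ref = X @ W
np.testing.assert_allclose(data_parallel_matmul(X, W), ref, atol=1e-10)
np.testing.assert_allclose(model_parallel_matmul(X, W), ref, atol=1e-10)
```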
103 changes: 103 additions & 0 deletions tinygrad/features/multi.py
@@ -0,0 +1,103 @@
from __future__ import annotations
from typing import Optional, Union, Any, Tuple, List
import functools
from tinygrad.helpers import all_same, dedup
from tinygrad.dtype import DType
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.shape.shapetracker import ShapeTracker, sint

def all_reduce(lbs):
  # TODO: replace this with ring reduce
  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]

def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
  assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
  sz = lbs[0].shape[axis] // len(lbs)
  return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]

class MultiLazyBuffer:
  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
    assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs) >= 2, "all lbs must be LazyBuffers, and we need at least two of them"
    assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
    self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
    self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))

  def __repr__(self):
    return f"<MLB{chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"

  @staticmethod
  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
    lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
    return MultiLazyBuffer([lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)

  def copy_to_device(self, device:str) -> LazyBuffer:
    if self.axis is None: return self.lbs[0].copy_to_device(device)
    sz = self.lbs[0].shape[self.axis]
    llbs = []
    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
      pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
      llbs.append(lb.pad(pad_arg))
    return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)

  # TODO: fix this
  def is_unrealized_contiguous_const(self): return False

  # passthroughs
  def schedule(self, seen=None): return create_schedule(self.lbs, seen)
  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
  def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)

  # elementwise is simple
  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
    msrcs = (self,)+in_srcs
    assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
    assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"

    # NOTE: they all have to share an axis, we always choose [-1]
    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
    srcs = []
    for mlb in msrcs:
      if mlb.axis == axis: srcs.append(mlb.lbs)
      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)

  def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))

  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
    if self.axis is not None and new_shape[self.axis] == 1:
      # all-reduce on sharded axes
      return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
    # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
    return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)

  # *** movement ops ***

  def reshape(self, arg:Tuple[sint, ...]):
    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
    # TODO: this can be wrong
    st = ShapeTracker.from_shape(self.shape)
    rs = st.real_strides()[self.axis]
    new_axis = st.reshape(arg).real_strides().index(rs)
    narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
    return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)

  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
    assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
  def expand(self, arg:Tuple[sint, ...]):
    # NOTE: this assert isn't needed, sharded axis can have dim 1
    assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
  def permute(self, arg:Tuple[int, ...]):
    # all permutes supported!
    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
    assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
    narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
    return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
  def stride(self, arg:Tuple[int, ...]):
    assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)
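A plain numpy sketch (illustrative only, not the tinygrad API) of the buffer mechanics in MultiLazyBuffer above: `to_sharded` shrinks each copy to its slice of the sharded axis, `copy_to_device` zero-pads each slice back to full size and sums, and `all_reduce` is the naive sum-of-copies that the TODO says should become a ring reduce.

```python
import numpy as np

def to_sharded(x, n_devices, axis=0):
  # each "device" keeps one contiguous slice along the sharded axis
  return np.split(x, n_devices, axis=axis)

def gather(shards, axis=0):
  # mirror copy_to_device: zero-pad each shard to the full shape, then add them
  full = sum(s.shape[axis] for s in shards)
  out, offset = 0, 0
  for s in shards:
    pad = [(0, 0)] * s.ndim
    pad[axis] = (offset, full - offset - s.shape[axis])
    out = out + np.pad(s, pad)
    offset += s.shape[axis]
  return out

def all_reduce(partials):
  # naive all-reduce: every device receives and sums every partial result
  total = sum(partials)
  return [total.copy() for _ in partials]

x = np.arange(8.0)
shards = to_sharded(x, 2)
np.testing.assert_allclose(gather(shards), x)
```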
10 changes: 6 additions & 4 deletions tinygrad/jit.py
@@ -5,6 +5,7 @@
 from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH
 from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
 from tinygrad.tensor import Tensor
+from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node
 from weakref import ref, WeakKeyDictionary
@@ -50,16 +51,17 @@ def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
 
   def __call__(self, *args, **kwargs) -> ReturnType:
     # all inputs (except const) are realized
-    input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} # noqa: E501
-    expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()])
+    input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} # noqa: E501
+    assert all(isinstance(x, LazyBuffer) for x in input_tensors.values()), "multilazybuffer JIT isn't supported"
+    expected_name_sts_dtype = tuple([(k, v.st.unbind(), v.dtype) for k,v in input_tensors.items()])
 
     # get rawbuffers
     # TODO: why can .realized have Any type?
-    input_rawbuffers: List[Buffer] = [v.lazydata.base.realized for v in input_tensors.values() if v.lazydata.base.realized is not None]
+    input_rawbuffers: List[Buffer] = [v.base.realized for v in input_tensors.values() if v.base.realized is not None]
     assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
 
     # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
-    var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501
+    var_vals: Dict[Variable, int] = merge_dicts([arg.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501
     expected_vals = tuple(var_vals.keys())
 
     if self.cnt >= 2: