diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index ecabbc0abcfc..d2d21edd8e02 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,7 +1,7 @@
 name: Unit Tests
 env:
   # increment this when downloads substantially change to avoid the internet
-  DOWNLOAD_CACHE_VERSION: '3'
+  DOWNLOAD_CACHE_VERSION: '4'

 on:
   push:
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
new file mode 100644
index 000000000000..8da4077cd17c
--- /dev/null
+++ b/test/test_multitensor.py
@@ -0,0 +1,137 @@
+import unittest
+from tinygrad import Tensor, Device, nn, GlobalCounters
+from tinygrad.helpers import CI
+from tinygrad.nn.state import get_parameters
+import numpy as np
+
+d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
+d2, d3 = f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4"
+N = 128
+
+# shard_x is "data parallel"
+# shard_w is "model parallel"
+
+@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA"}, "no GPU CI")
+class TestMultiTensor(unittest.TestCase):
+  def test_shard(self):
+    X = Tensor.ones(256).contiguous().realize()
+    X.shard_((d0, d1), 0)
+    for lb in X.lazydata.lbs:
+      assert lb.shape == (128,)
+
+  def test_numpy(self):
+    X = Tensor.ones(256)
+    X.shard_((d0, d1), 0)
+    np.testing.assert_allclose(X.numpy(), 1)
+
+  def _test_simple_add_axis(self, shard_x, shard_w):
+    X = Tensor.ones(256).contiguous().realize()
+    W = Tensor.ones(256).contiguous().realize()
+    X.shard_((d0, d1), shard_x)
+    W.shard_((d0, d1), shard_w)
+    O = X + W
+    np.testing.assert_allclose(O.numpy(), 2)
+
+  def test_simple_add(self): return self._test_simple_add_axis(None, None)
+  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
+  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
+  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)
+
+  def test_four_add(self):
+    X = Tensor.ones(256, 256).contiguous().realize()
+    W = Tensor.ones(256, 256).contiguous().realize()
+    X.shard_((d0, d1, d2, d3), 1)
+    W.shard_((d0, d1, d2, d3), None)
+    O = X + W
+    np.testing.assert_allclose(O.numpy(), 2)
+
+  def _test_simple_reduce_axis(self, shard_x):
+    X = Tensor.ones(256, 256).contiguous().realize()
+    X.shard_((d0, d1), shard_x)
+    O = X.sum(axis=1)
+    np.testing.assert_allclose(O.numpy(), 256)
+
+  def test_simple_reduce(self): return self._test_simple_reduce_axis(None)
+  def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0)
+  def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1)
+
+  def _test_matmul_shard_axis(self, shard_x, shard_w):
+    X = Tensor.kaiming_uniform(N, N).realize()
+    W = Tensor.kaiming_uniform(N, N).realize()
+    Xs = X.shard((d0, d1), shard_x)
+    Ws = W.shard((d0, d1), shard_w)
+    O = (Xs@Ws)
+    np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
+
+  def _test_double_matmul_shard_axis(self, shard_x, shard_w):
+    X = Tensor.kaiming_uniform(N, N).realize()
+    W1 = Tensor.kaiming_uniform(N, N).realize()
+    W2 = Tensor.kaiming_uniform(N, N).realize()
+    Xs = X.shard((d0, d1), shard_x)
+    W1s = W1.shard((d0, d1), shard_w)
+    W2s = W2.shard((d0, d1), shard_w)
+    O = (Xs@W1s)@W2s
+    np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
+
+  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None)
+  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None)
+  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None)
+  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0)
+  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1)
+
+  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0)
+  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1)
+  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0)
+  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1)
+
+  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None)
+  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None)
+  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0)
+  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1)
+
+  def test_conv_data_shard(self):
+    conv = nn.Conv2d(3, 16, 3, bias=False)
+    for p in get_parameters(conv): p.shard_((d0, d1))
+    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
+    out = conv(fake_image)
+    out.numpy()
+
+  def test_conv_bias_data_shard(self):
+    conv = nn.Conv2d(3, 16, 3)
+    for p in get_parameters(conv): p.shard_((d0, d1))
+    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
+    out = conv(fake_image)
+    out.numpy()
+
+  def test_backprop_conv(self):
+    conv = nn.Conv2d(3, 16, 3)
+    for p in get_parameters(conv): p.shard_((d0, d1))
+    optim = nn.optim.Adam(get_parameters(conv))
+    fake_image = Tensor.rand((2, 3, 32, 32)).shard((d0, d1), axis=0)
+    out = conv(fake_image)
+    optim.zero_grad()
+    out.mean().backward()
+    #for p in get_parameters(conv): p.grad.realize()
+    optim.step()
+
+  def test_data_parallel_resnet(self):
+    import sys, pathlib
+    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
+    from resnet import ResNet18
+
+    fake_image = Tensor.rand((2, 3, 224, 224))
+    fake_image_sharded = fake_image.shard((d0, d1), axis=0)
+    print(fake_image_sharded.shape)
+    m = ResNet18()
+    m.load_from_pretrained()
+    real_output = m(fake_image).numpy()
+    for p in get_parameters(m): p.shard_((d0, d1)).realize()
+    GlobalCounters.reset()
+    shard_output = m(fake_image_sharded).realize()
+    assert shard_output.lazydata.lbs[0].shape == (1, 1000)
+    assert shard_output.lazydata.lbs[1].shape == (1, 1000)
+    shard_output_np = shard_output.numpy()
+    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)
+
+if __name__ == '__main__':
+  unittest.main()
\ No newline at end of file
diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py
new file mode 100644
index 000000000000..61b3492526ee
--- /dev/null
+++ b/tinygrad/features/multi.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+from typing import Optional, Union, Any, Tuple, List
+import functools
+from tinygrad.helpers import all_same, dedup
+from tinygrad.dtype import DType
+from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
+from tinygrad.lazy import LazyBuffer, create_schedule
+from tinygrad.shape.shapetracker import ShapeTracker, sint
+
+def all_reduce(lbs):
+  # TODO: replace this with ring reduce
+  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+
+def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
+  assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
+  sz = lbs[0].shape[axis] // len(lbs)
+  return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+
+class MultiLazyBuffer:
+  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
+    assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs) >= 2, "all lbs must be LazyBuffers, and we need at least two of them"
+    assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
+    self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
+    self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
+
+  def __repr__(self):
+    return f""
+
+  @staticmethod
+  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
+    lbs = [lb.contiguous() if lb.base != lb else lb] * len(devices)
+    return MultiLazyBuffer([lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)], axis)
+
+  def copy_to_device(self, device:str) -> LazyBuffer:
+    if self.axis is None: return self.lbs[0].copy_to_device(device)
+    sz = self.lbs[0].shape[self.axis]
+    llbs = []
+    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
+      pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
+      llbs.append(lb.pad(pad_arg))
+    return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
+
+  # TODO: fix this
+  def is_unrealized_contiguous_const(self): return False
+
+  # passthroughs
+  def schedule(self, seen=None): return create_schedule(self.lbs, seen)
+  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
+  def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
+  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)
+
+  # elementwise is simple
+  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+    msrcs = (self,)+in_srcs
+    assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
+    assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+    # NOTE: they all have to share an axis, we always choose [-1]
+    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
+    srcs = []
+    for mlb in msrcs:
+      if mlb.axis == axis: srcs.append(mlb.lbs)
+      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
+      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
+    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)
+
+  def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))
+
+  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
+    if self.axis is not None and new_shape[self.axis] == 1:
+      # all-reduce on sharded axes
+      return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
+    # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+    return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)

+  # *** movement ops ***
+
+  def reshape(self, arg:Tuple[sint, ...]):
+    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
+    # TODO: this can be wrong
+    st = ShapeTracker.from_shape(self.shape)
+    rs = st.real_strides()[self.axis]
+    new_axis = st.reshape(arg).real_strides().index(rs)
+    narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
+    return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)
+
+  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
+    assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
+    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
+  def expand(self, arg:Tuple[sint, ...]):
+    # NOTE: this assert isn't needed, sharded axis can have dim 1
+    assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
+    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
+  def permute(self, arg:Tuple[int, ...]):
+    # all permutes supported!
+    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
+    assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
+    narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
+    return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
+  def stride(self, arg:Tuple[int, ...]):
+    assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
+    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index 449d7f74f5fc..a3a8d79e750a 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -5,6 +5,7 @@
 from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH
 from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
 from tinygrad.tensor import Tensor
+from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node
 from weakref import ref, WeakKeyDictionary
@@ -50,16 +51,17 @@ def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

   def __call__(self, *args, **kwargs) -> ReturnType:
     # all inputs (except const) are realized
-    input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}  # noqa: E501
-    expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()])
+    input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}  # noqa: E501
+    assert all(isinstance(x, LazyBuffer) for x in input_tensors.values()), "multilazybuffer JIT isn't supported"
+    expected_name_sts_dtype = tuple([(k, v.st.unbind(), v.dtype) for k,v in input_tensors.items()])

     # get rawbuffers
     # TODO: why can .realized have Any type?
-    input_rawbuffers: List[Buffer] = [v.lazydata.base.realized for v in input_tensors.values() if v.lazydata.base.realized is not None]
+    input_rawbuffers: List[Buffer] = [v.base.realized for v in input_tensors.values() if v.base.realized is not None]
     assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"

     # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
-    var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])  # noqa: E501
+    var_vals: Dict[Variable, int] = merge_dicts([arg.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])  # noqa: E501
     expected_vals = tuple(var_vals.keys())

     if self.cnt >= 2:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e95d365e0263..c24f9da1c9c5 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -9,6 +9,7 @@
 from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
 from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
 from tinygrad.lazy import LazyBuffer, create_schedule
+from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import sint
@@ -17,7 +18,7 @@
 # **** start with two base classes, Tensor and Function ****

 class Function:
-  def __init__(self, device:str, *tensors:Tensor):
+  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
     self.device = device
     self.needs_input_grad = [t.requires_grad for t in tensors]
     self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
@@ -35,6 +36,10 @@ def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:

 import tinygrad.mlops as mlops

+def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Optional[LazyBuffer]=None):
+  if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
+  return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
+
 class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
@@ -45,10 +50,10 @@
     def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
     def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

   no_grad: ClassVar[bool] = False
-  def __init__(self, data:Union[None, bool, int, float, List, Tuple, LazyBuffer, np.ndarray, bytes],
-               device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[None, bool, int, float, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
+               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
-    device = Device.canonicalize(device)
+    device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
     # tensors have gradients, buffers do not
     self.grad: Optional[Tensor] = None
@@ -59,9 +64,9 @@
     # internal variables used for autograd graph construction
     self._ctx: Optional[Function] = None
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
-    elif isinstance(data, (bool, int, float)): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+    elif isinstance(data, (bool, int, float)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
     elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
-    elif data is None: data = LazyBuffer.loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
+    elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
     elif isinstance(data, list):
       if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
       elif d and all_int(d): dtype = dtype or dtypes.default_int
@@ -69,12 +74,16 @@
       # NOTE: cast at the end for the dtypes that do not have a numpy dtype
       data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
     elif isinstance(data, np.ndarray):
-      if data.shape == (): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
+      if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else: data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)

     # data is a LazyBuffer, but it might be on the wrong device
-    if not isinstance(data, LazyBuffer): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
-    self.lazydata = data if data.device == device else data.copy_to_device(device)
+    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
+    if isinstance(device, tuple):
+      # TODO: what if it's a MultiLazyBuffer on other devices?
+      self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
+    else:
+      self.lazydata = data if data.device == device else data.copy_to_device(device)

   def __repr__(self):
     return f""
@@ -83,7 +92,7 @@
   def __hash__(self): return id(self)

   @property
-  def device(self) -> str: return self.lazydata.device
+  def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device

   @property
   def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
@@ -94,7 +103,8 @@ def dtype(self) -> DType: return self.lazydata.dtype

   # ***** data handlers ****
   @staticmethod
-  def corealize(lst:Iterable[Tensor]): run_schedule(create_schedule([x.lazydata for x in lst]))
+  def corealize(lst:Iterable[Tensor]):
+    return run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))

   def realize(self) -> Tensor:
     run_schedule(self.lazydata.schedule())
@@ -102,7 +112,7 @@

   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
-    if self.device.startswith("DISK"):
+    if isinstance(self.device, str) and self.device.startswith("DISK"):
       if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
       self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
       return self
@@ -111,7 +121,11 @@
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert not x.requires_grad  # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
-    if self.dtype == x.dtype and self.lazydata.base.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.base.realized  # noqa: E501
+    if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
+      if isinstance(self.lazydata, MultiLazyBuffer):
+        for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
+      else:
+        if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
     self.lazydata = x.lazydata
     return self
   def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
@@ -127,7 +141,8 @@ def numpy(self) -> np.ndarray:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
     if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
-    return self.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
+    t = self if isinstance(self.device, str) else self.to("CPU")
+    return t.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)

   def to(self, device:Optional[str]) -> Tensor:
     if device is None or device == self.device: return self
@@ -141,6 +156,14 @@ def to_(self, device:Optional[str]):
     _ret = Tensor(self.lazydata, device)
     self.lazydata = _ret.lazydata

+  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+    assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis), device=devices)
+
+  def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
+    self.lazydata = self.shard(devices, axis).lazydata
+    return self
+
   # ***** creation llop entrypoint *****

   @staticmethod
@@ -751,7 +774,7 @@
     return x.expand(broadcasted_shape), y.expand(broadcasted_shape)

   def _to_const_val(self, x:Union[Tensor, float, int, bool]) -> Union[Tensor, float, int, bool]:
-    return x.lazydata.base.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
+    return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
       and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x

   def add(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
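For reviewers trying this out, here is a minimal usage sketch of the sharding entry points the patch adds (`Tensor.shard`, `Tensor.shard_`, and the `MultiLazyBuffer` they create), following the data-parallel pattern in `test_multitensor.py`. The device names, tensor sizes, and the `main.py` framing are illustrative only, not part of the patch.

```python
from tinygrad import Tensor, Device, nn
from tinygrad.nn.state import get_parameters

# two virtual devices on the default backend, as in test_multitensor.py
d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"

# data parallel: split the batch across devices (axis=0), replicate the weights (axis=None)
x = Tensor.rand(4, 3, 32, 32).shard((d0, d1), axis=0)   # returns a new Tensor backed by a MultiLazyBuffer
conv = nn.Conv2d(3, 16, 3)
for p in get_parameters(conv): p.shard_((d0, d1))       # in-place: each parameter's lazydata becomes a MultiLazyBuffer

out = conv(x)            # elementwise/conv ops keep the axis-0 sharding per MultiLazyBuffer.e
print(out.numpy().shape) # .numpy() copies shards back to a single host array (tensor.py routes non-str devices through .to("CPU"))
```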