
Commit

revert command queue (tinygrad#4097)
geohot authored Apr 6, 2024
1 parent 97c402d commit e4a1858
Showing 4 changed files with 42 additions and 120 deletions.
openpilot/compile2.py (3 changes: 1 addition & 2 deletions)
@@ -15,7 +15,6 @@
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
 from tinygrad.engine.realize import run_schedule
-from tinygrad.engine.commandqueue import CommandQueue
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.ops import LoadOps, ScheduleItem
 Device.DEFAULT = "GPU"
@@ -89,7 +88,7 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
   # run code (all buffers have been allocated)
   GlobalCounters.reset()
   output = schedule[-1].outputs[0]
-  CommandQueue(schedule)()
+  run_schedule(schedule)
 
   new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=output.dtype.np)
   np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
tinygrad/engine/commandqueue.py (113 changes: 0 additions & 113 deletions)

This file was deleted.

tinygrad/engine/realize.py (44 changes: 40 additions & 4 deletions)
@@ -1,5 +1,41 @@
-from typing import List
-from tinygrad.ops import ScheduleItem
-from tinygrad.engine.commandqueue import CommandQueue
+from typing import List, Dict, Optional
+from tinygrad.helpers import getenv, colored
+from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
+from tinygrad.device import JITRunner, Device, BufferCopy, BufferXfer, update_stats
+from tinygrad.buffer import Buffer
+from tinygrad.shape.symbolic import Variable
 
-def run_schedule(schedule:List[ScheduleItem]): CommandQueue(schedule)()
+class CustomOp(JITRunner):
+  def __init__(self, fxn):
+    self.fxn = fxn
+    super().__init__()
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
+
+def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
+  assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
+  if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
+  assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
+  out, ast = si.outputs[0], si.ast[0]
+  if ast.op is LoadOps.COPY:
+    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: return BufferXfer()
+    return BufferCopy()
+  if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
+  return None
+
+logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
+def run_schedule(schedule:List[ScheduleItem]):
+  while len(schedule):
+    si = schedule.pop(0)
+    if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
+
+    # get the program
+    prg = lower_schedule_item(si)
+
+    for out in si.outputs:
+      # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
+      if out.size > 0 and not (out.device.startswith("DISK") and si.ast[0].op is BufferOps.STORE) and not hasattr(out, "_buf"): out.allocate()
+
+    # run the function (put it in JIT)
+    real_buffers = [x for x in si.outputs+si.inputs if x.size != 0]
+    if prg: prg.exec(real_buffers, si.var_vals)
+    elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
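
Usage note (not part of this commit): a minimal sketch of how the restored run_schedule path is typically driven, assuming the April-2024 tinygrad API visible in this diff (create_schedule in tinygrad.engine.schedule and Tensor.lazydata); treat it as an illustration, not the project's canonical entry point.

# Minimal sketch: lower lazy ops to ScheduleItems, then execute them with run_schedule.
# Assumes the April-2024 tinygrad API shown in this diff.
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule

a, b = Tensor.ones(4, 4), Tensor.eye(4)
out = (a @ b).relu()

schedule = create_schedule([out.lazydata])  # list of ScheduleItems, nothing executed yet
run_schedule(schedule)                      # allocates output buffers and runs each item in order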
tinygrad/renderer/assembly.py (2 changes: 1 addition & 1 deletion)
@@ -70,7 +70,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
    lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
   ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
     "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
-   lambda root, muls, non_muls : UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
+   lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
   *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
     lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
     for op in lang.asm_for_op.keys() if op not in lang.supports_half],
