Skip to content

Commit

Permalink
use imagenet spawn (tinygrad#4096)
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot authored Apr 6, 2024
1 parent fffd9b0 commit 97c402d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 2 additions & 0 deletions examples/mlperf/model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import time
from tqdm import tqdm
import multiprocessing

from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO
Expand Down Expand Up @@ -251,6 +252,7 @@ def train_maskrcnn():
pass

if __name__ == "__main__":
multiprocessing.set_start_method('spawn')
with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
nm = f"train_{m}"
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/device.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, NamedTuple
import importlib, inspect, functools, pathlib, time, ctypes
import importlib, inspect, functools, pathlib, time, ctypes, os
from tinygrad.helpers import ansilen, prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.helpers import DEBUG, CACHECOLLECTING, BEAM, NOOPT, GlobalCounters
from tinygrad.shape.symbolic import Variable, sym_infer, sint
Expand All @@ -23,6 +23,7 @@ def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(d
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
x = ix.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
@functools.cached_property
Expand Down

0 comments on commit 97c402d

Please sign in to comment.