Skip to content

Commit

Permalink
🐛✨Enhanced AsyncIterManager
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Nov 13, 2024
1 parent 57e4c8f commit 06a7f61
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
28 changes: 16 additions & 12 deletions core/learn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,24 +274,28 @@ def poll(self, cursor: int, index: Any) -> Any:


class AsyncIterManager:
    """Registry mapping a loader key to its live ``AsyncDataLoaderIter``.

    At most one async iterator is kept alive per key; requesting a new one
    first cleans up whatever iterator was previously registered under the
    same key. Keys are plain ints — presumably ``id(loader)`` (see
    ``DataLoader._get_iterator``), TODO confirm against callers.
    """

    # class-level registry shared by all loaders
    _cur: Dict[int, "AsyncDataLoaderIter"] = {}

    @classmethod
    def new(
        cls,
        id: int,  # NOTE: shadows the builtin; kept for caller compatibility
        fn: Callable[[], "AsyncDataLoaderIter"],
    ) -> "AsyncDataLoaderIter":
        """Create, register, and return a fresh iterator for `id` via `fn`.

        Any iterator previously registered under `id` is cleaned up first,
        so only one iterator per key is ever alive.
        """
        cls.cleanup(id)
        cls._cur[id] = fn()
        return cls._cur[id]

    @classmethod
    def remove(cls, iter: "AsyncDataLoaderIter") -> None:
        """Clean up `iter` if it is currently registered; no-op otherwise.

        Bugfix: the previous implementation called ``cleanup`` (which pops
        from ``_cur``) while iterating ``_cur.items()``, raising
        ``RuntimeError: dictionary changed size during iteration`` whenever
        a match was found. We now locate the key first, then clean up.
        """
        target = next((k for k, v in cls._cur.items() if v is iter), None)
        if target is not None:
            cls.cleanup(target)

    @classmethod
    def cleanup(cls, id: int) -> None:
        """Pop the iterator registered under `id` and finalize it if needed.

        Safe to call with an unregistered `id` (no-op), and idempotent:
        already-finalized iterators are not finalized twice.
        """
        cur = cls._cur.pop(id, None)
        if cur is not None and not cur._finalized:
            cur._cleanup()

Expand Down Expand Up @@ -431,7 +435,7 @@ class DataLoader(TorchDataLoader):

def _get_iterator(self) -> _BaseDataLoaderIter:
    """Return the iterator used to consume this loader.

    When ``num_workers == 0`` an ``AsyncDataLoaderIter`` is created and
    registered in ``AsyncIterManager`` under this loader's ``id(self)``,
    so at most one async iterator per loader object stays alive.
    Otherwise we defer to the parent class implementation.
    (Reconstructed post-change line from the diff: the registration now
    carries ``id(self)`` as the registry key.)
    """
    if self.num_workers == 0:
        return AsyncIterManager.new(id(self), lambda: AsyncDataLoaderIter(self))
    return super()._get_iterator()  # pragma: no cover

def __iter__(self) -> Iterator[tensor_dict_type]: # type: ignore
Expand Down
23 changes: 23 additions & 0 deletions tests/test_learn/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,29 @@ def test_testing_data(self) -> None:
)
cflearn.TrainingPipeline.init(config).fit(data)

def test_async_data(self) -> None:
    """Exercise interleaved iteration over async train / valid loaders.

    The exact interleaving/break pattern below is deliberate: it creates
    new async iterators while previous ones (for the same or another
    loader) are still mid-iteration — the scenario the per-loader
    iterator registry must clean up correctly without leaking or raising.
    """
    data, *_ = cflearn.testing.linear_data(
        100,
        batch_size=4,
        use_validation=True,
        use_async=True,
    )
    train_loader, valid_loader = data.build_loaders()
    for i, _ in enumerate(train_loader):
        if i == 2:
            # restart the validation iterator several times mid-training;
            # each `for` requests a fresh iterator while the train
            # iterator is still alive
            for _ in valid_loader:
                break
            for _ in valid_loader:
                break
            for _ in valid_loader:
                break
        if i == 4:
            # abandon the training iterator part-way through
            break
    # re-entering the train loader must replace the abandoned iterator
    for _ in train_loader:
        break
    for _ in train_loader:
        break

def test_seeding(self) -> None:
data = cflearn.testing.arange_data()[0]
loader = data.build_loaders()[0]
Expand Down

0 comments on commit 06a7f61

Please sign in to comment.