Skip to content

Commit

Permalink
✅ Enhanced coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Oct 16, 2024
1 parent c8c4f75 commit 49b4436
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
16 changes: 9 additions & 7 deletions core/learn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,12 @@ class IAsyncDataset(IDataset):
programming languages that can do the 'real' async I/O stuffs.
"""

def __getitems__(self, indices: List[int]) -> Any:
def __getitems__(self, indices: List[int]) -> Any: # pragma: no cover
raise NotImplementedError("should not call `__getitems__` of an async dataset")

@abstractmethod
def async_reset(self) -> None:
pass
"""reset the dataset at the beginning of each epoch"""

@abstractmethod
def async_submit(self, cursor: int, index: Any) -> bool:
Expand All @@ -250,7 +250,7 @@ def poll(self, cursor: int) -> Any:
fetched = self.async_fetch(cursor)
if fetched is not None:
return fetched
time.sleep(0.01)
time.sleep(0.01) # pragma: no cover


class AsyncDataLoaderIter(_SingleProcessDataLoaderIter):
Expand All @@ -264,7 +264,9 @@ def __init__(self, loader: "DataLoader"):
self.enabled = loader.async_prefetch
self.async_prefetch_factor = loader.async_prefetch_factor
if self.enabled and not isinstance(loader.dataset, IAsyncDataset):
raise RuntimeError("async prefetch is only available for `IAsyncDataset`")
raise RuntimeError(
"async prefetch is only available for `IAsyncDataset`"
) # pragma: no cover
self._initialized = False

def _initialize(self) -> None:
Expand All @@ -277,7 +279,7 @@ def _initialize(self) -> None:
def _sumbit_next(self) -> None:
cursor = self._queue_cursor
index = self._next_index()
if not self._dataset.async_submit(cursor, index):
if not self._dataset.async_submit(cursor, index): # pragma: no cover
msg = f"failed to submit async task with cursor={cursor} and index={index}"
console.error(msg)
raise RuntimeError("failed to sumbit async task")
Expand Down Expand Up @@ -306,7 +308,7 @@ def _next_data(self) -> Any:
self._drained = True
cursor = self._queue.pop(0)
data = self._dataset.poll(cursor)
if self._pin_memory:
if self._pin_memory: # pragma: no cover
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data

Expand Down Expand Up @@ -335,7 +337,7 @@ class DataLoader(TorchDataLoader):
def _get_iterator(self) -> _BaseDataLoaderIter:
if self.num_workers == 0:
return AsyncDataLoaderIter(self)
return super()._get_iterator()
return super()._get_iterator() # pragma: no cover

def __iter__(self) -> Iterator[tensor_dict_type]: # type: ignore
self.dataset.reset(for_inference=self.for_inference)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_learn/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def test_array_data(self) -> None:
np.testing.assert_allclose(x0, x1)
np.testing.assert_allclose(y0, y1)

with self.assertRaises(ValueError):
data.fit(None).build_loaders()
with self.assertRaises(ValueError):
data.fit(x, y[:-1]).build_loaders()

def test_array_dict_data(self) -> None:
input_dim = 11
output_dim = 7
Expand Down Expand Up @@ -189,6 +194,14 @@ def test_testing_data(self) -> None:
)
config.to_debug()
cflearn.TrainingPipeline.init(config).fit(data)
data, in_dim, out_dim, _ = cflearn.testing.linear_data(1000, use_async=True)
config = cflearn.Config(
module_name="linear",
module_config=dict(input_dim=in_dim, output_dim=out_dim, bias=False),
loss_name="mse",
num_epoch=1,
)
cflearn.TrainingPipeline.init(config).fit(data)

def test_seeding(self) -> None:
data = cflearn.testing.arange_data()[0]
Expand Down

0 comments on commit 49b4436

Please sign in to comment.