-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader_replay_buffer.py
32 lines (24 loc) · 1.06 KB
/
dataloader_replay_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from pathlib import Path
from replay_buffer import ReplayBufferStorage, make_replay_loader, AbstractReplayBuffer
class DataloaderReplayBuffer(AbstractReplayBuffer):
    """Replay buffer built from a ``ReplayBufferStorage`` and a replay loader.

    Time steps passed to :meth:`add` are forwarded to a ``ReplayBufferStorage``
    rooted at ``<work_dir>/buffer``. Samples are obtained with ``next(buffer)``
    (or by iterating the buffer) from a loader created by
    ``make_replay_loader`` over that same directory.
    """

    def __init__(self, buffer_size, batch_size, nstep, discount,
                 save_snapshot, num_workers, data_specs=None, work_dir=None):
        """Create the storage and the loader that share one buffer directory.

        Args:
            buffer_size: Forwarded to ``make_replay_loader``.
            batch_size: Forwarded to ``make_replay_loader``.
            nstep: Forwarded to ``make_replay_loader``.
            discount: Forwarded to ``make_replay_loader``.
            save_snapshot: Forwarded to ``make_replay_loader``.
            num_workers: Forwarded to ``make_replay_loader``.
            data_specs: Forwarded to ``ReplayBufferStorage``. Required despite
                the ``None`` default, which is kept for signature compatibility.
            work_dir: Root directory for the ``buffer`` subdirectory. Defaults
                to the current working directory at construction time.

        Raises:
            ValueError: If ``data_specs`` is None.
        """
        # An explicit check instead of ``assert`` so the validation still runs
        # under ``python -O``.
        if data_specs is None:
            raise ValueError('data_specs is required for DataloaderReplayBuffer')
        self.work_dir = Path.cwd() if work_dir is None else Path(work_dir)
        # Storage and loader are both pointed at this same directory.
        buffer_dir = self.work_dir / 'buffer'
        self.replay_storage = ReplayBufferStorage(data_specs, buffer_dir)
        self.replay_loader = make_replay_loader(
            buffer_dir, buffer_size, batch_size, num_workers,
            save_snapshot, nstep, discount)
        # The loader iterator is not created until it is first needed
        # (see ``replay_iter``).
        self._replay_iter = None

    @property
    def replay_iter(self):
        """Return the iterator over ``replay_loader``, creating it on first access."""
        if self._replay_iter is None:
            self._replay_iter = iter(self.replay_loader)
        return self._replay_iter

    def add(self, time_step):
        """Forward ``time_step`` to the underlying ``ReplayBufferStorage``."""
        self.replay_storage.add(time_step)

    def __iter__(self):
        """Return ``self`` so the buffer satisfies the iterator protocol."""
        return self

    def __next__(self):
        """Return the next item produced by the replay loader's iterator."""
        return next(self.replay_iter)

    def __len__(self):
        """Return ``len()`` of the underlying ``ReplayBufferStorage``."""
        return len(self.replay_storage)