Skip to content

Commit

Permalink
fix bugs that prevents rlds and lr to move forward after iterating th…
Browse files Browse the repository at this point in the history
…rough the dataset
  • Loading branch information
KeplerC committed Sep 2, 2024
1 parent 066abe9 commit 74d6c38
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
9 changes: 4 additions & 5 deletions fog_x/loader/lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None):
self.episode_index = 0

def __len__(self):
return len(self.dataset)
return len(self.dataset.episode_data_index["from"])

def __iter__(self):
return self
Expand All @@ -24,12 +24,11 @@ def _frame_to_numpy(frame):
return {k: np.array(v) for k, v in frame.items()}
for _ in range(self.batch_size):
episode = []
# repeat
if self.episode_index >= len(self.dataset):
self.episode_index = 0

for attempt in range(max_retries):
try:
# repeat
if self.episode_index >= len(self.dataset):
self.episode_index = 0
from_idx = self.dataset.episode_data_index["from"][self.episode_index].item()
to_idx = self.dataset.episode_data_index["to"][self.episode_index].item()
frames = [_frame_to_numpy(self.dataset[idx]) for idx in range(from_idx, to_idx)]
Expand Down
3 changes: 0 additions & 3 deletions fog_x/loader/rlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ def to_numpy(step_data):
return trajectory

def __next__(self):
if self.index >= self.length:
self.index = 0
raise StopIteration
return self.get_batch()

def __getitem__(self, idx):
Expand Down

0 comments on commit 74d6c38

Please sign in to comment.