From 74d6c388677498f1cff88ba904feed4643c87021 Mon Sep 17 00:00:00 2001 From: Kaiyuan Eric Chen <kych@berkeley.edu> Date: Mon, 2 Sep 2024 14:43:55 -0700 Subject: [PATCH] fix bugs that prevents rlds and lr to move forward after iterating through the dataset --- fog_x/loader/lerobot.py | 9 ++++----- fog_x/loader/rlds.py | 3 --- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/fog_x/loader/lerobot.py b/fog_x/loader/lerobot.py index 242d3a6..32c678a 100644 --- a/fog_x/loader/lerobot.py +++ b/fog_x/loader/lerobot.py @@ -11,7 +11,7 @@ def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None): self.episode_index = 0 def __len__(self): - return len(self.dataset) + return len(self.dataset.episode_data_index["from"]) def __iter__(self): return self @@ -24,12 +24,11 @@ def _frame_to_numpy(frame): return {k: np.array(v) for k, v in frame.items()} for _ in range(self.batch_size): episode = [] - # repeat - if self.episode_index >= len(self.dataset): - self.episode_index = 0 - for attempt in range(max_retries): try: + # repeat + if self.episode_index >= len(self.dataset): + self.episode_index = 0 from_idx = self.dataset.episode_data_index["from"][self.episode_index].item() to_idx = self.dataset.episode_data_index["to"][self.episode_index].item() frames = [_frame_to_numpy(self.dataset[idx]) for idx in range(from_idx, to_idx)] diff --git a/fog_x/loader/rlds.py b/fog_x/loader/rlds.py index 5c4d956..0003b6b 100644 --- a/fog_x/loader/rlds.py +++ b/fog_x/loader/rlds.py @@ -63,9 +63,6 @@ def to_numpy(step_data): return trajectory def __next__(self): - if self.index >= self.length: - self.index = 0 - raise StopIteration return self.get_batch() def __getitem__(self, idx):