From 74d6c388677498f1cff88ba904feed4643c87021 Mon Sep 17 00:00:00 2001
From: Kaiyuan Eric Chen <kych@berkeley.edu>
Date: Mon, 2 Sep 2024 14:43:55 -0700
Subject: [PATCH] fix bugs that prevents rlds and lr to move forward after
 iterating through the dataset

---
 fog_x/loader/lerobot.py | 9 ++++-----
 fog_x/loader/rlds.py    | 3 ---
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/fog_x/loader/lerobot.py b/fog_x/loader/lerobot.py
index 242d3a6..32c678a 100644
--- a/fog_x/loader/lerobot.py
+++ b/fog_x/loader/lerobot.py
@@ -11,7 +11,7 @@ def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None):
         self.episode_index = 0
 
     def __len__(self):
-        return len(self.dataset)
+        return len(self.dataset.episode_data_index["from"])
 
     def __iter__(self):
         return self
@@ -24,12 +24,11 @@ def _frame_to_numpy(frame):
             return {k: np.array(v) for k, v in frame.items()}
         for _ in range(self.batch_size):
             episode = []
-            # repeat
-            if self.episode_index >= len(self.dataset):
-                self.episode_index = 0
-                
             for attempt in range(max_retries):
                 try:
+                    # repeat
+                    if self.episode_index >= len(self.dataset):
+                        self.episode_index = 0
                     from_idx = self.dataset.episode_data_index["from"][self.episode_index].item()
                     to_idx = self.dataset.episode_data_index["to"][self.episode_index].item()
                     frames = [_frame_to_numpy(self.dataset[idx]) for idx in range(from_idx, to_idx)]
diff --git a/fog_x/loader/rlds.py b/fog_x/loader/rlds.py
index 5c4d956..0003b6b 100644
--- a/fog_x/loader/rlds.py
+++ b/fog_x/loader/rlds.py
@@ -63,9 +63,6 @@ def to_numpy(step_data):
         return trajectory
 
     def __next__(self):
-        if self.index >= self.length:
-            self.index = 0
-            raise StopIteration
         return self.get_batch()
 
     def __getitem__(self, idx):