RLDS slicing & fixes
Lenoplus42 committed Sep 27, 2024
1 parent 219c7e4 commit f129d37
Showing 5 changed files with 108 additions and 433 deletions.
439 changes: 53 additions & 386 deletions benchmarks/Visualization.ipynb

Large diffs are not rendered by default.

46 changes: 23 additions & 23 deletions benchmarks/openx_by_episode.py
@@ -3,7 +3,7 @@
 import argparse
 import time
 import numpy as np
-from fog_x.loader import RLDSLoader, VLALoader, HDF5Loader
+from fog_x.loader import RLDSLoader, VLALoader
 import tensorflow as tf
 import pandas as pd
 import fog_x
@@ -340,34 +340,34 @@ def evaluation(args):
     logger.debug(f"Evaluating dataset: {dataset_name}")

     handlers = [
-        # VLAHandler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
+        VLAHandler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
         HDF5Handler(
             args.exp_dir,
             dataset_name,
             args.num_batches,
             args.batch_size,
             args.log_frequency,
         ),
-        # LeRobotHandler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
-        # RLDSHandler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
+        LeRobotHandler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
+        RLDSHandler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
         # FFV1Handler(
         #     args.exp_dir,
         #     dataset_name,
@@ -438,4 +438,4 @@ def evaluation(args):
     )
     args = parser.parse_args()

-    evaluation(args)
\ No newline at end of file
+    evaluation(args)
50 changes: 27 additions & 23 deletions benchmarks/openx_by_frame.py
@@ -360,27 +360,27 @@ def evaluation(args):
     logger.debug(f"Evaluating dataset: {dataset_name}")

     handlers = [
-        # VLAHandler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
-        # HDF5Handler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
-        # LeRobotHandler(
-        #     args.exp_dir,
-        #     dataset_name,
-        #     args.num_batches,
-        #     args.batch_size,
-        #     args.log_frequency,
-        # ),
+        VLAHandler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
+        HDF5Handler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
+        LeRobotHandler(
+            args.exp_dir,
+            dataset_name,
+            args.num_batches,
+            args.batch_size,
+            args.log_frequency,
+        ),
         RLDSHandler(
             args.exp_dir,
             dataset_name,
@@ -423,9 +423,13 @@ def evaluation(args):

     # Write all results to CSV
     results_df = pd.DataFrame(all_results)
-    results_df.to_csv(csv_file, index=False)
+    results_df.to_csv(csv_file, index = False)
     logger.debug(f"Results appended to {csv_file}")

+    # if os.path.exists(csv_file):
+    #     print("exist in", os.path.abspath(csv_file))
+    #     print(pd.read_csv(csv_file))
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -458,4 +462,4 @@ def evaluation(args):
     )
     args = parser.parse_args()

-    evaluation(args)
\ No newline at end of file
+    evaluation(args)
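
One note on the CSV hunk above: the log message says results are appended, but a bare results_df.to_csv(csv_file, index=False) rewrites the file on every call. If true append semantics were intended, a common pandas pattern looks like the following sketch (illustrative only, not the repository's code; append_results is a hypothetical helper):

import os
import pandas as pd

def append_results(csv_file: str, results: list) -> None:
    # Append benchmark rows to csv_file, emitting the header only once.
    df = pd.DataFrame(results)
    write_header = not os.path.exists(csv_file)  # header only for a new file
    df.to_csv(csv_file, mode="a", header=write_header, index=False)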
3 changes: 2 additions & 1 deletion evaluation.sh
@@ -3,7 +3,7 @@ sudo echo "Use sudo access for clearning cache"

 # Define a list of batch sizes to iterate through

-batch_sizes=(64)
+batch_sizes=(1 2 4 6 8 10 12 14 16 64)
 num_batches=200
 # batch_sizes=(1 2)

@@ -17,6 +17,7 @@ do

 # python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
 python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
+# python3 benchmarks/openx_by_episode.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
 # python3 benchmarks/openx_by_frame.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
 # python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size
 done
3 changes: 3 additions & 0 deletions fog_x/loader/rlds.py
@@ -142,8 +142,11 @@ def to_numpy(step_data):
         num_frames = len(traj["steps"])
         if num_frames >= self.slice_length:
             random_from = np.random.randint(0, num_frames - self.slice_length + 1)
+            # random_to = random_from + self.slice_length
             trajs = traj["steps"].skip(random_from).take(self.slice_length)
         else:
+            # random_from = 0
+            # random_to = num_frames
             trajs = traj["steps"]
         for step in trajs:
             trajectory.append(to_numpy(step))
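
For context on the slicing above: when an episode has at least slice_length frames, the loader samples a uniformly random contiguous window of steps; shorter episodes are returned whole. A minimal standalone sketch of that behavior (the function name and toy usage are illustrative, not the loader's actual API):

import numpy as np
import tensorflow as tf

def sample_slice(steps: tf.data.Dataset, num_frames: int, slice_length: int) -> tf.data.Dataset:
    # Mirrors the logic in fog_x/loader/rlds.py: choose a random start so the
    # window [random_from, random_from + slice_length) stays inside the episode.
    if num_frames >= slice_length:
        random_from = np.random.randint(0, num_frames - slice_length + 1)
        return steps.skip(random_from).take(slice_length)
    # Episode shorter than the requested slice: use it whole.
    return steps

# Toy usage: a 10-step "episode" of integers.
episode = tf.data.Dataset.range(10)
window = sample_slice(episode, num_frames=10, slice_length=4)
print([int(s) for s in window])  # e.g. [3, 4, 5, 6]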
