
aliang8/varibad_jax


JAX Implementations of Meta-RL / Offline-RL algorithms

This repository provides clean reimplementations of existing Meta-RL and Offline-RL algorithms.

Getting started:

conda create --name jax_metarl python==3.11.8
conda activate jax_metarl
pip3 install -e .  # installs this repo and its dependencies

Example commands

Run basic goal-conditioned RL

python3 main.py \
    --config=configs/rl_config.py:lstm-gridworld \
    --config.smoke_test=True \
    --config.use_wb=False \
    --config.policy.pass_task_to_policy=True \
    --config.env.env_kwargs.grid_size=7 \
    --config.env.env_kwargs.max_episode_steps=20 \
    --config.env.num_episodes_per_rollout=1 \
    --config.env.steps_per_rollout=20 \
    --config.env.env_kwargs.random_init=True \
    --config.exp_name="7x7_rand_init_"

python3 main.py \
    --config=configs/rl_config.py:lstm-xland-5x5 \
    --config.smoke_test=True \
    --config.use_wb=True \
    --config.visualize_rollouts=True
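The gridworld command trains a goal-conditioned agent on a 7x7 grid with 20-step episodes, a randomized start (random_init=True), and the goal passed to the policy (pass_task_to_policy=True). The class below is a hypothetical, stdlib-only sketch of what those options could mean; it is not the repository's environment, and the class name GoalGridworld, the action encoding, the reward of 1.0 at the goal, and the observation layout are all assumptions.

```python
import random
from dataclasses import dataclass, field

# Hypothetical action encoding: 0=up, 1=down, 2=left, 3=right (row, col deltas).
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]


@dataclass
class GoalGridworld:
    grid_size: int = 7
    max_episode_steps: int = 20
    random_init: bool = True
    pass_task_to_policy: bool = True
    rng: random.Random = field(default_factory=random.Random)

    def reset(self, goal=None):
        n = self.grid_size
        # The task is a goal cell; sample one unless the caller fixes it.
        self.goal = goal if goal is not None else (self.rng.randrange(n), self.rng.randrange(n))
        # random_init chooses between a random start cell and a fixed corner.
        self.pos = (self.rng.randrange(n), self.rng.randrange(n)) if self.random_init else (0, 0)
        self.t = 0
        return self._obs()

    def _obs(self):
        # The goal is appended to the observation only if the policy may see the task.
        return self.pos + self.goal if self.pass_task_to_policy else self.pos

    def step(self, action):
        dr, dc = MOVES[action]
        last = self.grid_size - 1
        # Clamp to the grid so moving into a wall leaves the agent in place.
        self.pos = (min(max(self.pos[0] + dr, 0), last), min(max(self.pos[1] + dc, 0), last))
        self.t += 1
        reached = self.pos == self.goal
        done = reached or self.t >= self.max_episode_steps
        return self._obs(), (1.0 if reached else 0.0), done


env = GoalGridworld(random_init=False, rng=random.Random(0))
print(env.reset(goal=(0, 2)))  # (0, 0, 0, 2)
print(env.step(3))             # ((0, 1, 0, 2), 0.0, False)
print(env.step(3))             # ((0, 2, 0, 2), 1.0, True)
```

With pass_task_to_policy=False the observation omits the goal, so the agent has to infer the task from experience, which is the setting meta-RL methods such as VariBAD and RL^2 target.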

Run VariBAD with an LSTM encoder on Gridworld and XLand

# gridworld 
python3 main.py \
    --config=configs/varibad_config.py:lstm-gridworld \
    --config.smoke_test=True \
    --config.use_wb=False \
    --config.overwrite=False

# XLand
CUDA_VISIBLE_DEVICES=2 python3 main.py \
    --config=configs/varibad_config.py:lstm-xland-5x5 \
    --config.smoke_test=True \
    --config.use_wb=False

Collect an offline dataset with a trained model

python3 scripts/generate_data_from_policy.py \
    --config=configs/offline_config.py:gridworld \
    --config.model_ckpt_dir=results/en-gridworld_alg-ppo_pltp-True_t-vae_nvu-3_ed-8

CUDA_VISIBLE_DEVICES=0 python3 scripts/generate_data_from_policy.py \
    --config=configs/offline_config.py:xland-5x5 \
    --config.model_ckpt_dir=results/en-xland_alg-ppo_pltp-True_t-vae_nvu-3_ed-8

python3 scripts/generate_data_from_policy.py \
    --config=configs/offline_config.py:dt-xland-5x5 \
    --config.model_ckpt_dir=results/

Run offline RL experiments with Decision Transformer

CUDA_VISIBLE_DEVICES=4 python3 main.py \
    --config=configs/offline_config.py:dt-gridworld \
    --config.smoke_test=True \
    --config.use_wb=False
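Decision Transformer conditions each action on the return-to-go, the sum of rewards from the current step to the end of the episode (undiscounted in the original Decision Transformer formulation), so offline trajectories need this quantity before training. A minimal sketch, assuming the dataset is stored as flat per-step reward and done lists with several episodes possibly concatenated; that layout and the function name returns_to_go are hypothetical and not necessarily the format written by scripts/generate_data_from_policy.py.

```python
def returns_to_go(rewards, dones):
    """Undiscounted return-to-go per step, restarting the sum at episode boundaries.

    dones[t] is True when step t is the last step of its episode.
    """
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Sweep backwards so each step adds its reward to the sum of everything after it.
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # do not let the next episode's rewards leak into this one
        running += rewards[t]
        rtg[t] = running
    return rtg


# Two concatenated episodes: steps 0-1 and steps 2-3.
print(returns_to_go([1.0, 2.0, 0.5, 0.5], [False, True, False, True]))  # [3.0, 2.0, 1.0, 0.5]
```

At evaluation time a Decision Transformer is usually prompted with a target return and the model's return-to-go input is decremented by each observed reward, mirroring this backward sum.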

The repository also supports Ray Tune for hyperparameter search and Weights & Biases (WandB) for logging experiment metrics. Use --config.smoke_test to toggle Ray Tune.

File organization:

Meta-RL algorithms supported:

  • VariBAD
  • RL^2
  • HyperX (not working yet)

Offline RL:

  • Decision Transformers
  • LAPO

Miscellaneous models:

  • Genie

Environments supported:

  • Gridworld
  • XLand-MiniGrid
  • DM-Alchemy
  • ProcGen
