fit new emei api #7

Open
wants to merge 74 commits into
base: main
74 commits
022251d
:hammer: try to refactor the models module
FrankTianTT Oct 24, 2022
3a5a808
:memo: add docs
FrankTianTT Oct 24, 2022
52afe3a
:beetle: update requirements/dev.txt
FrankTianTT Oct 25, 2022
158ae90
:tada: try to import BaseCausalMechanism
FrankTianTT Oct 25, 2022
d0eaa35
:hammer: introduce causal-mech class
FrankTianTT Nov 4, 2022
b65c71e
:beetle: fix parallel-linear bug
FrankTianTT Nov 4, 2022
c161362
:tada: create OfflineDataset
FrankTianTT Nov 4, 2022
b0ffd7a
:tada: add OfflineDataset and EnsembleOfflineDataset
FrankTianTT Nov 4, 2022
311838d
Merge remote-tracking branch 'origin/refactor_model' into refactor_model
FrankTianTT Nov 4, 2022
51d3101
:tada: add binary-variable and update encoder and decoder
FrankTianTT Nov 4, 2022
d450eae
:tada: add forward and loss for PlainMech
FrankTianTT Nov 4, 2022
d0a909a
:tada: mixed use of BufferDataset and EnsembleBufferDataset is supported
FrankTianTT Nov 5, 2022
db4e4fb
:tada: finish plain_mech.py
FrankTianTT Nov 5, 2022
9cba197
:tada: add dynamics
FrankTianTT Nov 5, 2022
c79a0b0
:hammer: refactor dynamics and fake-env
FrankTianTT Nov 6, 2022
1391399
:hammer: add space2dict
FrankTianTT Nov 6, 2022
38aa162
:hammer: refactor algorithm(from function to class)
FrankTianTT Nov 6, 2022
4a9f989
:beetle: fix type check in causal mechs
wz139704646 Nov 7, 2022
6a1e90f
:hammer: refactor CausalMech
FrankTianTT Nov 7, 2022
1d69337
Merge remote-tracking branch 'origin/refactor_model' into refactor_model
FrankTianTT Nov 7, 2022
0492f30
:hammer: refactor BaseCausalMech
FrankTianTT Nov 7, 2022
119bfb0
:wrench: make offline-dyna independent
FrankTianTT Nov 7, 2022
190f4d6
:bug: fix online-RL wrong log
FrankTianTT Nov 7, 2022
09c651c
:bug: fix tests
FrankTianTT Nov 7, 2022
2e34bc9
:tada: add mask in reduce_encoder_output
FrankTianTT Nov 8, 2022
e7c6850
:tada: finish CMI test!
FrankTianTT Nov 9, 2022
c244d8c
:wrench: use index to mask the encoder output
Nov 9, 2022
04e5a03
:bug: fix mask bug
FrankTianTT Nov 9, 2022
0f16f2f
:tada: add binary, weight, neural and prob graphs
wz139704646 Nov 11, 2022
2eac931
:fire: remove a test temp file
Nov 13, 2022
eb444fc
:tada: add reinforce causal mech and corresponding tests
wz139704646 Nov 16, 2022
e168b2f
:wrench: add reinforce config file, fix some bugs in reinforce and lo…
wz139704646 Nov 21, 2022
be1ccf1
:tada: update configuration
FrankTianTT Nov 23, 2022
2927e04
:bug: fix cmi test bug in single step forward
wz139704646 Nov 25, 2022
c5bdfb1
:wrench: update .gitignore
FrankTianTT Nov 28, 2022
3fe3612
:wrench: update hydra's config
FrankTianTT Nov 28, 2022
1c929ba
:tada: fit emei's oracle causal graph
FrankTianTT Nov 28, 2022
a3f1f9a
:tada: add RadianVariable and von_mises_nll_loss
FrankTianTT Nov 29, 2022
dd1ee0c
:wrench: add discovery param in reinforce
wz139704646 Nov 30, 2022
a4d5310
:tada: add save and auto-load
FrankTianTT Dec 1, 2022
2b77130
Merge remote-tracking branch 'origin/refactor_model' into refactor_model
FrankTianTT Dec 1, 2022
5db6803
:bug: save after learn
FrankTianTT Dec 2, 2022
55f6698
:bug: fix eval_model_on_space.py action bug
FrankTianTT Dec 3, 2022
9f5750a
:bug: fix eval_model_on_space.py action bug
FrankTianTT Dec 3, 2022
05c0d0b
:tada: add scheduler for optimizer
FrankTianTT Dec 4, 2022
4ca25e0
:tada: save to work dir
FrankTianTT Dec 6, 2022
289c5b2
:wrench: add graph saving and loading
wz139704646 Dec 7, 2022
bb5203d
:tada: add neural bernoulli graph
Dec 17, 2022
ea72762
:wrench: fix dev requirements
wz139704646 Dec 17, 2022
711d0f2
:tada: update config
FrankTianTT Jan 12, 2023
6914045
Merge remote-tracking branch 'origin/refactor_model' into refactor_model
FrankTianTT Jan 12, 2023
4855152
:tada: add exp_collect
FrankTianTT Jan 15, 2023
f996955
:tada: update config
FrankTianTT Feb 5, 2023
ec49fd9
:tada: update config
FrankTianTT Feb 6, 2023
87fa146
:hammer: update causal mech framework
FrankTianTT Feb 28, 2023
e68768d
:tada: add self-adaption
FrankTianTT Feb 28, 2023
5f9ce3e
:hammer: update kci sample num
FrankTianTT Mar 2, 2023
909600c
:hammer: add parallel
FrankTianTT Mar 12, 2023
8546dc6
:tada: update plain mech
FrankTianTT Mar 12, 2023
92a5f60
:hammer: add state2obs_fn
FrankTianTT Mar 20, 2023
54b250b
Merge remote-tracking branch 'origin/refactor_model' into refactor_model
FrankTianTT Mar 20, 2023
db217ad
:hammer: add state2obs_fn
FrankTianTT Mar 20, 2023
3060b61
:bug: fix maybe_load_offline_model bug
FrankTianTT Mar 21, 2023
b519930
:hammer: plain -> oracle
FrankTianTT Mar 21, 2023
69d047a
:tada: save causal discovery history
FrankTianTT Mar 21, 2023
7995e86
:bug: fix device bug
FrankTianTT Mar 27, 2023
dd0be11
:tada: add exp_reader
FrankTianTT Mar 29, 2023
d1151dd
:tada: add exp_reader
FrankTianTT Mar 30, 2023
52403f3
:tada: update notebook
FrankTianTT Apr 6, 2023
5758629
:tada: multi-process loader
FrankTianTT Apr 6, 2023
d894c0e
:bug: clip log_var to avoid inf
FrankTianTT Apr 6, 2023
c738171
:tada: update new cfg
FrankTianTT Apr 7, 2023
edfafa6
:tada: update new cfg
FrankTianTT Apr 7, 2023
7279e56
Merge pull request #2 from FrankTianTT/refactor_model
FrankTianTT Jun 5, 2023
2 changes: 0 additions & 2 deletions .gitignore
@@ -138,6 +138,4 @@ cython_debug/
.vscode
.idea

/cmrl.egg-info/
/exp/
/stable-baselines3/
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -43,7 +43,7 @@ Emoji | Description
:art: `:art:` | When you improved / added assets like themes.
:rocket: `:rocket:` | When you improved performance.
:memo: `:memo:` | When you wrote documentation.
-:beetle: `:beetle:` | When you fixed a bug.
+:bug: `:bug:` | When you fixed a bug.
:twisted_rightwards_arrows: `:twisted_rightwards_arrows:` | When you merged a branch.
:fire: `:fire:` | When you removed something.
:truck: `:truck:` | When you moved / renamed something.
42 changes: 34 additions & 8 deletions README.md
@@ -1,4 +1,4 @@
-![](/img/cmrl_logo.png)
+![](/docs/cmrl_logo.png)

# Causal-MBRL

@@ -10,7 +10,7 @@
<a href="https://www.python.org/downloads/release/python-380/"><img src="https://img.shields.io/badge/python-3.8-brightgreen"></a>

`cmrl`(short for `Causal-MBRL`) is a toolbox for facilitating the development of Causal Model-based Reinforcement
-learning algorithms. It use [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) as model-free engine and
+learning algorithms. It uses [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) as model-free engine and
allows flexible use of causal models.

`cmrl` is inspired by [MBRL-Lib](https://github.com/facebookresearch/mbrl-lib). Unlike MBRL-Lib, `cmrl` focuses on the
@@ -111,18 +111,44 @@ cd causal-mbrl
# create conda env
conda create -n cmrl python=3.8
conda activate cmrl
# install torch
conda install pytorch -c pytorch
# install cmrl and its dependent packages
pip install -e .
```

If there is no `cuda` in your device, it's convenient to install `cuda` and `pytorch` from conda directly (refer
to [pytorch](https://pytorch.org/get-started/locally/)):
for pytorch

````shell
# for example, in the case of cuda=11.3
conda install pytorch cudatoolkit=11.3 -c pytorch
````
```shell
# for MacOS
conda install pytorch -c pytorch
# for Linux
conda install pytorch pytorch-cuda=11.6 -c pytorch -c nvidia
```

for KCIT and RCIT

```shell
conda install -c conda-forge r-base
conda install -c conda-forge r-devtools
R
```

```shell
# Install the RCIT from Github.
install.packages("devtools")
library(devtools)
install_github("ericstrobl/RCIT")
library(RCIT)

# Install R libraries for RCIT
install.packages("MASS")
install.packages("momentchi2")
install.packages("devtools")

# test RCIT
RCIT(rnorm(1000),rnorm(1000),rnorm(1000))
```
## install using pip

coming soon.
1 change: 0 additions & 1 deletion cmrl/agent/__init__.py

This file was deleted.

137 changes: 0 additions & 137 deletions cmrl/agent/core.py

This file was deleted.

28 changes: 0 additions & 28 deletions cmrl/agent/sac_wrapper.py

This file was deleted.

8 changes: 4 additions & 4 deletions cmrl/algorithms/__init__.py
@@ -1,4 +1,4 @@
-from cmrl.algorithms.offline import mopo
-from cmrl.algorithms.offline import off_dyna
-from cmrl.algorithms.online import mbpo
-from cmrl.algorithms.online import on_dyna
+from cmrl.algorithms.off_dyna import OfflineDyna
+from cmrl.algorithms.mopo import MOPO
+from cmrl.algorithms.on_dyna import OnlineDyna
+from cmrl.algorithms.mbpo import MBPO
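
The package now exports algorithm classes rather than algorithm modules. The sketch below shows one way a caller might pick a class by name and run it. It uses stand-in classes so it runs without `cmrl` installed. Treat the following as assumptions: that the four classes share the `(cfg, work_dir)` constructor and `learn()` entry point of `BaseAlgorithm`, the `ALGORITHMS` table and its keys, the `run` helper, and the plain-dict config.

```python
from typing import Dict, Optional, Type


class BaseAlgorithm:
    """Stand-in with the constructor shape and learn() entry point of cmrl's BaseAlgorithm."""

    def __init__(self, cfg: dict, work_dir: Optional[str] = None):
        self.cfg = cfg
        self.work_dir = work_dir or "."

    def learn(self) -> str:
        # The real method trains an agent; this stand-in only reports what it would do.
        return f"{type(self).__name__} trained for {self.cfg['num_steps']} steps"


# Stand-ins for the four exported classes (assumed to share the base interface).
class OfflineDyna(BaseAlgorithm):
    pass


class MOPO(BaseAlgorithm):
    pass


class OnlineDyna(BaseAlgorithm):
    pass


class MBPO(BaseAlgorithm):
    pass


# Hypothetical lookup table; the keys are illustrative, not taken from cmrl's configs.
ALGORITHMS: Dict[str, Type[BaseAlgorithm]] = {
    "off_dyna": OfflineDyna,
    "mopo": MOPO,
    "on_dyna": OnlineDyna,
    "mbpo": MBPO,
}


def run(cfg: dict) -> str:
    """Instantiate the algorithm named in the config and run it."""
    algorithm_cls = ALGORITHMS[cfg["algorithm"]]
    return algorithm_cls(cfg).learn()
```

With classes, the entry point only needs a name-to-class mapping and one call to `learn()`.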
116 changes: 116 additions & 0 deletions cmrl/algorithms/base_algorithm.py
@@ -0,0 +1,116 @@
import os
from typing import Optional
from functools import partial

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
import wandb

from cmrl.models.fake_env import VecFakeEnv
from cmrl.sb3_extension.logger import configure as logger_configure
from cmrl.sb3_extension.eval_callback import EvalCallback
from cmrl.utils.creator import create_dynamics, create_agent
from cmrl.utils.env import make_env


class BaseAlgorithm:
    def __init__(
        self,
        cfg: DictConfig,
        work_dir: Optional[str] = None,
    ):
        self.cfg = cfg
        self.work_dir = work_dir or os.getcwd()

        self.env, fns = make_env(self.cfg)
        self.reward_fn, self.termination_fn, self.get_init_obs_fn, self.obs2state_fn, self.state2obs_fn = fns

        self.eval_env, *_ = make_env(self.cfg)
        np.random.seed(self.cfg.seed)
        torch.manual_seed(self.cfg.seed)

        format_strings = ["tensorboard", "multi_csv"]
        if self.cfg.verbose:
            format_strings += ["stdout"]
        self.logger = logger_configure("log", format_strings)

        if cfg.wandb:
            wandb.init(
                project="causal-mbrl",
                group=cfg.exp_name,
                config=OmegaConf.to_container(cfg, resolve=True),
                sync_tensorboard=True,
            )

        # create ``cmrl`` dynamics
        self.dynamics = create_dynamics(
            self.cfg, self.env.state_space, self.env.action_space, self.obs2state_fn, self.state2obs_fn, logger=self.logger
        )

        if self.cfg.transition.name == "oracle_transition":
            graph = self.env.get_transition_graph() if self.cfg.transition.oracle == "truth" else None
            self.dynamics.transition.set_oracle_graph(graph)
        if self.cfg.reward_mech.learn and not self.cfg.reward_mech.name == "oracle_reward_mech":
            graph = self.env.get_reward_mech_graph() if self.cfg.transition.oracle == "truth" else None
            self.dynamics.reward_mech.set_oracle_graph(graph)
        if self.cfg.termination_mech.learn and not self.cfg.termination_mech.name == "oracle_termination_mech":
            graph = self.env.get_termination_mech_graph() if self.cfg.transition.oracle == "truth" else None
            self.dynamics.termination_mech.set_oracle_graph(graph)

        # create sb3's replay buffer for real offline data
        self.real_replay_buffer = ReplayBuffer(
            cfg.task.num_steps,
            self.env.observation_space,
            self.env.action_space,
            self.cfg.device,
            handle_timeout_termination=False,
        )

        self.partial_fake_env = partial(
            VecFakeEnv,
            self.cfg.algorithm.num_envs,
            self.env.state_space,
            self.env.action_space,
            self.dynamics,
            self.reward_fn,
            self.termination_fn,
            self.get_init_obs_fn,
            self.real_replay_buffer,
            penalty_coeff=self.cfg.task.penalty_coeff,
            logger=self.logger,
        )
        self.agent = create_agent(self.cfg, self.fake_env, self.logger)

    @property
    def fake_env(self) -> VecFakeEnv:
        return self.partial_fake_env(
            deterministic=self.cfg.algorithm.deterministic,
            max_episode_steps=self.env.spec.max_episode_steps,
            branch_rollout=False,
        )

    @property
    def callback(self) -> BaseCallback:
        fake_eval_env = self.partial_fake_env(
            deterministic=True, max_episode_steps=self.env.spec.max_episode_steps, branch_rollout=False
        )
        return EvalCallback(
            self.eval_env,
            fake_eval_env,
            n_eval_episodes=self.cfg.task.n_eval_episodes,
            best_model_save_path="./",
            eval_freq=self.cfg.task.eval_freq,
            deterministic=True,
            render=False,
        )

    def learn(self):
        self._setup_learn()

        self.agent.learn(total_timesteps=self.cfg.task.num_steps, callback=self.callback)

    def _setup_learn(self):
        pass
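
Two patterns in `BaseAlgorithm` are easier to see in isolation. A `functools.partial` binds the arguments shared by every fake environment once, and `learn()` calls a `_setup_learn()` hook that subclasses can override. The following is a minimal sketch of both, not the project's code. `FakeEnv`, `SketchAlgorithm`, `OfflineSketch`, `eval_fake_env`, and the `events` list are hypothetical stand-ins, and loading offline data in the hook is an assumption about how a subclass might use it.

```python
from functools import partial
from typing import List


class FakeEnv:
    """Hypothetical stand-in for VecFakeEnv; it only records its arguments."""

    def __init__(self, num_envs, dynamics, *, deterministic, branch_rollout, penalty_coeff=0.0):
        self.num_envs = num_envs
        self.dynamics = dynamics
        self.deterministic = deterministic
        self.branch_rollout = branch_rollout
        self.penalty_coeff = penalty_coeff


class SketchAlgorithm:
    def __init__(self, num_envs: int, dynamics: str, deterministic: bool):
        self.deterministic = deterministic
        self.events: List[str] = []
        # Bind the arguments shared by the training and evaluation envs once.
        self.partial_fake_env = partial(FakeEnv, num_envs, dynamics, penalty_coeff=1.0)

    @property
    def fake_env(self) -> FakeEnv:
        # Each access builds a new env from the shared partial.
        return self.partial_fake_env(deterministic=self.deterministic, branch_rollout=False)

    @property
    def eval_fake_env(self) -> FakeEnv:
        # The evaluation env differs only in the late-bound keyword arguments.
        return self.partial_fake_env(deterministic=True, branch_rollout=False)

    def learn(self) -> None:
        self._setup_learn()  # hook runs before training starts
        self.events.append("learn")

    def _setup_learn(self) -> None:
        # Default hook does nothing, as in the base class.
        pass


class OfflineSketch(SketchAlgorithm):
    def _setup_learn(self) -> None:
        # Assumed use of the hook: prepare data before the agent learns.
        self.events.append("load offline data")


algo = OfflineSketch(num_envs=4, dynamics="dynamics-placeholder", deterministic=False)
algo.learn()
```

Because `fake_env` is a property backed by a partial, every access returns a new environment object. Code that needs the same instance twice should store it in a local variable.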