diff --git a/.gitignore b/.gitignore
index a25976e..0ad959e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,6 +138,4 @@ cython_debug/
.vscode
.idea
-/cmrl.egg-info/
/exp/
-/stable-baselines3/
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1d44313..c4370e3 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -43,7 +43,7 @@ Emoji | Description
:art: `:art:` | When you improved / added assets like themes.
:rocket: `:rocket:` | When you improved performance.
:memo: `:memo:` | When you wrote documentation.
-:beetle: `:beetle:` | When you fixed a bug.
+:bug: `:bug:` | When you fixed a bug.
:twisted_rightwards_arrows: `:twisted_rightwards_arrows:` | When you merged a branch.
:fire: `:fire:` | When you removed something.
:truck: `:truck:` | When you moved / renamed something.
diff --git a/README.md b/README.md
index 3219478..4091dd8 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-![](/img/cmrl_logo.png)
+![](/docs/cmrl_logo.png)
# Causal-MBRL
@@ -10,7 +10,7 @@
`cmrl`(short for `Causal-MBRL`) is a toolbox for facilitating the development of Causal Model-based Reinforcement
-learning algorithms. It use [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) as model-free engine and
+learning algorithms. It uses [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) as the model-free engine and
allows flexible use of causal models.
`cmrl` is inspired by [MBRL-Lib](https://github.com/facebookresearch/mbrl-lib). Unlike MBRL-Lib, `cmrl` focuses on the
@@ -111,18 +111,44 @@ cd causal-mbrl
# create conda env
conda create -n cmrl python=3.8
conda activate cmrl
+# install torch
+conda install pytorch -c pytorch
# install cmrl and its dependent packages
pip install -e .
```
-If there is no `cuda` in your device, it's convenient to install `cuda` and `pytorch` from conda directly (refer
-to [pytorch](https://pytorch.org/get-started/locally/)):
+For PyTorch (pick the command for your platform; see [pytorch.org](https://pytorch.org/get-started/locally/) for other CUDA versions):
-````shell
-# for example, in the case of cuda=11.3
-conda install pytorch cudatoolkit=11.3 -c pytorch
-````
+```shell
+# for macOS
+conda install pytorch -c pytorch
+# for Linux
+conda install pytorch pytorch-cuda=11.6 -c pytorch -c nvidia
+```
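+
+To verify the PyTorch install, here is a minimal sanity check (independent of `cmrl`):
+
+```python
+import torch
+
+# prints the installed version and whether a CUDA device is visible
+print(torch.__version__, torch.cuda.is_available())
+```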
+
+For KCIT and RCIT (conditional independence tests used for causal discovery), install R and the RCIT package:
+
+```shell
+conda install -c conda-forge r-base
+conda install -c conda-forge r-devtools
+# start an R session and run the R commands below
+R
+```
+```r
+# install R dependencies for RCIT
+install.packages("devtools")
+install.packages("MASS")
+install.packages("momentchi2")
+
+# install RCIT from GitHub
+library(devtools)
+install_github("ericstrobl/RCIT")
+library(RCIT)
+
+# test RCIT
+RCIT(rnorm(1000), rnorm(1000), rnorm(1000))
+```
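+
+If you need to call RCIT from Python, one option is `rpy2`; the snippet below is a minimal sketch assuming `pip install rpy2` (the `cmrl` codebase itself may bridge to R differently):
+
+```python
+import numpy as np
+from rpy2.robjects import FloatVector
+from rpy2.robjects.packages import importr
+
+rcit = importr("RCIT")  # the R package installed above
+
+# test whether x and y are conditionally independent given z
+x, y, z = (FloatVector(np.random.randn(1000)) for _ in range(3))
+print(rcit.RCIT(x, y, z))
+```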
## install using pip
coming soon.
diff --git a/cmrl/agent/__init__.py b/cmrl/agent/__init__.py
deleted file mode 100644
index 012afb4..0000000
--- a/cmrl/agent/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from cmrl.agent.core import Agent, RandomAgent, complete_agent_cfg, load_agent
diff --git a/cmrl/agent/core.py b/cmrl/agent/core.py
deleted file mode 100644
index b6a074e..0000000
--- a/cmrl/agent/core.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import abc
-import pathlib
-from typing import Any, Optional, Union
-
-import gym
-import hydra
-import numpy as np
-import omegaconf
-
-
-class Agent:
- """Abstract class for all agents."""
-
- @abc.abstractmethod
- def act(self, obs: np.ndarray, **kwargs) -> np.ndarray:
- pass
-
- def reset(self):
- pass
-
-
-class RandomAgent(Agent):
- """An agent that samples action from the environments action space.
-
- Args:
- env (gym.Env): the environment on which the agent will act.
- """
-
- def __init__(self, env: gym.Env):
- self.env = env
-
- def act(self, obs: np.ndarray, **kwargs) -> np.ndarray:
- return self.env.action_space.sample()
-
-
-def complete_agent_cfg(env: gym.Env, agent_cfg: omegaconf.DictConfig):
- obs_shape = env.observation_space.shape
- act_shape = env.action_space.shape
-
- def _check_and_replace(key: str, value: Any, cfg: omegaconf.DictConfig):
- if key in cfg.keys() and key not in cfg:
- setattr(cfg, key, value)
-
- # create numpy object by existed object
- def _create_numpy_config(array):
- return {
- "_target_": "numpy.array",
- "object": array.tolist(),
- "dtype": str(array.dtype),
- }
-
- _check_and_replace("num_inputs", obs_shape[0], agent_cfg)
- if "action_space" in agent_cfg.keys() and isinstance(agent_cfg.action_space, omegaconf.DictConfig):
- _check_and_replace("low", _create_numpy_config(env.action_space.low), agent_cfg.action_space)
- _check_and_replace("high", _create_numpy_config(env.action_space.high), agent_cfg.action_space)
- _check_and_replace("shape", env.action_space.shape, agent_cfg.action_space)
-
- if "obs_dim" in agent_cfg.keys() and "obs_dim" not in agent_cfg:
- agent_cfg.obs_dim = obs_shape[0]
- if "action_dim" in agent_cfg.keys() and "action_dim" not in agent_cfg:
- agent_cfg.action_dim = act_shape[0]
- if "action_range" in agent_cfg.keys() and "action_range" not in agent_cfg:
- agent_cfg.action_range = [
- float(env.action_space.low.min()),
- float(env.action_space.high.max()),
- ]
- if "action_lb" in agent_cfg.keys() and "action_lb" not in agent_cfg:
- agent_cfg.action_lb = _create_numpy_config(env.action_space.low)
- if "action_ub" in agent_cfg.keys() and "action_ub" not in agent_cfg:
- agent_cfg.action_ub = _create_numpy_config(env.action_space.high)
-
- if "env" in agent_cfg.keys():
- _check_and_replace(
- "low",
- _create_numpy_config(env.action_space.low),
- agent_cfg.env.action_space,
- )
- _check_and_replace(
- "high",
- _create_numpy_config(env.action_space.high),
- agent_cfg.env.action_space,
- )
- _check_and_replace("shape", env.action_space.shape, agent_cfg.env.action_space)
-
- _check_and_replace(
- "low",
- _create_numpy_config(env.observation_space.low),
- agent_cfg.env.observation_space,
- )
- _check_and_replace(
- "high",
- _create_numpy_config(env.observation_space.high),
- agent_cfg.env.observation_space,
- )
- _check_and_replace("shape", env.observation_space.shape, agent_cfg.env.observation_space)
-
- return agent_cfg
-
-
-def load_agent(
- agent_path: Union[str, pathlib.Path],
- env: gym.Env,
- type: Optional[str] = "best",
- device: Optional[str] = None,
-) -> Agent:
- """Loads an agent from a Hydra config file at the given path.
-
- For agent of type "pytorch_sac.agent.sac.SACAgent", the directory
- must contain the following files:
-
- - ".hydra/config.yaml": the Hydra configuration for the agent.
- - "critic.pth": the saved checkpoint for the critic.
- - "actor.pth": the saved checkpoint for the actor.
-
- Args:
- agent_path (str or pathlib.Path): a path to the directory where the agent is saved.
- env (gym.Env): the environment on which the agent will operate (only used to complete
- the agent's configuration).
-
- Returns:
- (Agent): the new agent.
- """
- agent_path = pathlib.Path(agent_path)
- cfg = omegaconf.OmegaConf.load(agent_path / ".hydra" / "config.yaml")
- cfg.device = device
-
- if cfg.algorithm.agent._target_ == "cmrl.third_party.pytorch_sac.sac.SAC":
- import cmrl.third_party.pytorch_sac as pytorch_sac
-
- from .sac_wrapper import SACAgent
-
- complete_agent_cfg(env, cfg.algorithm.agent)
- agent: pytorch_sac.SAC = hydra.utils.instantiate(cfg.algorithm.agent)
- agent.load_checkpoint(ckpt_path=agent_path / "sac_{}.pth".format(type), device=device)
- return SACAgent(agent)
- else:
- raise ValueError("Invalid agent configuration.")
diff --git a/cmrl/agent/sac_wrapper.py b/cmrl/agent/sac_wrapper.py
deleted file mode 100644
index b83c5a5..0000000
--- a/cmrl/agent/sac_wrapper.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import numpy as np
-import torch
-
-import cmrl.third_party.pytorch_sac as pytorch_sac
-
-from .core import Agent
-
-
-class SACAgent(Agent):
- def __init__(self, sac_agent: pytorch_sac.SAC):
- self.sac_agent = sac_agent
-
- def act(self, obs: np.ndarray, sample: bool = False, batched: bool = False, **kwargs) -> np.ndarray:
- """Issues an action given an observation.
-
- Args:
- obs (np.ndarray): the observation (or batch of observations) for which the action
- is needed.
- sample (bool): if ``True`` the agent samples actions from its policy, otherwise it
- returns the mean policy value. Defaults to ``False``.
- batched (bool): if ``True`` signals to the agent that the obs should be interpreted
- as a batch.
-
- Returns:
- (np.ndarray): the action.
- """
- with torch.no_grad():
- return self.sac_agent.select_action(obs, batched=batched, evaluate=not sample)
diff --git a/cmrl/algorithms/__init__.py b/cmrl/algorithms/__init__.py
index 0f9ea59..5b1eb8c 100644
--- a/cmrl/algorithms/__init__.py
+++ b/cmrl/algorithms/__init__.py
@@ -1,4 +1,4 @@
-from cmrl.algorithms.offline import mopo
-from cmrl.algorithms.offline import off_dyna
-from cmrl.algorithms.online import mbpo
-from cmrl.algorithms.online import on_dyna
+from cmrl.algorithms.off_dyna import OfflineDyna
+from cmrl.algorithms.mopo import MOPO
+from cmrl.algorithms.on_dyna import OnlineDyna
+from cmrl.algorithms.mbpo import MBPO
diff --git a/cmrl/algorithms/base_algorithm.py b/cmrl/algorithms/base_algorithm.py
new file mode 100644
index 0000000..532ea13
--- /dev/null
+++ b/cmrl/algorithms/base_algorithm.py
@@ -0,0 +1,116 @@
+import os
+from typing import Optional
+from functools import partial
+
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+import wandb
+
+from cmrl.models.fake_env import VecFakeEnv
+from cmrl.sb3_extension.logger import configure as logger_configure
+from cmrl.sb3_extension.eval_callback import EvalCallback
+from cmrl.utils.creator import create_dynamics, create_agent
+from cmrl.utils.env import make_env
+
+
+class BaseAlgorithm:
+ def __init__(
+ self,
+ cfg: DictConfig,
+ work_dir: Optional[str] = None,
+ ):
+ self.cfg = cfg
+ self.work_dir = work_dir or os.getcwd()
+
+ self.env, fns = make_env(self.cfg)
+ self.reward_fn, self.termination_fn, self.get_init_obs_fn, self.obs2state_fn, self.state2obs_fn = fns
+
+ self.eval_env, *_ = make_env(self.cfg)
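+        # seed numpy and torch so runs are reproducible under cfg.seed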
+ np.random.seed(self.cfg.seed)
+ torch.manual_seed(self.cfg.seed)
+
+ format_strings = ["tensorboard", "multi_csv"]
+ if self.cfg.verbose:
+ format_strings += ["stdout"]
+ self.logger = logger_configure("log", format_strings)
+
+ if cfg.wandb:
+ wandb.init(
+ project="causal-mbrl",
+ group=cfg.exp_name,
+ config=OmegaConf.to_container(cfg, resolve=True),
+ sync_tensorboard=True,
+ )
+
+ # create ``cmrl`` dynamics
+ self.dynamics = create_dynamics(
+ self.cfg, self.env.state_space, self.env.action_space, self.obs2state_fn, self.state2obs_fn, logger=self.logger
+ )
+
+ if self.cfg.transition.name == "oracle_transition":
+ graph = self.env.get_transition_graph() if self.cfg.transition.oracle == "truth" else None
+ self.dynamics.transition.set_oracle_graph(graph)
+        if self.cfg.reward_mech.learn and self.cfg.reward_mech.name != "oracle_reward_mech":
+            graph = self.env.get_reward_mech_graph() if self.cfg.transition.oracle == "truth" else None
+            self.dynamics.reward_mech.set_oracle_graph(graph)
+        if self.cfg.termination_mech.learn and self.cfg.termination_mech.name != "oracle_termination_mech":
+            graph = self.env.get_termination_mech_graph() if self.cfg.transition.oracle == "truth" else None
+            self.dynamics.termination_mech.set_oracle_graph(graph)
+
+ # create sb3's replay buffer for real offline data
+ self.real_replay_buffer = ReplayBuffer(
+ cfg.task.num_steps,
+ self.env.observation_space,
+ self.env.action_space,
+ self.cfg.device,
+ handle_timeout_termination=False,
+ )
+
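+        # bind the constructor arguments shared by every fake env once;
+        # the fake_env property (overridden by some subclasses) fills in the rest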
+ self.partial_fake_env = partial(
+ VecFakeEnv,
+ self.cfg.algorithm.num_envs,
+ self.env.state_space,
+ self.env.action_space,
+ self.dynamics,
+ self.reward_fn,
+ self.termination_fn,
+ self.get_init_obs_fn,
+ self.real_replay_buffer,
+ penalty_coeff=self.cfg.task.penalty_coeff,
+ logger=self.logger,
+ )
+ self.agent = create_agent(self.cfg, self.fake_env, self.logger)
+
+ @property
+ def fake_env(self) -> VecFakeEnv:
+ return self.partial_fake_env(
+ deterministic=self.cfg.algorithm.deterministic,
+ max_episode_steps=self.env.spec.max_episode_steps,
+ branch_rollout=False,
+ )
+
+ @property
+ def callback(self) -> BaseCallback:
+ fake_eval_env = self.partial_fake_env(
+ deterministic=True, max_episode_steps=self.env.spec.max_episode_steps, branch_rollout=False
+ )
+ return EvalCallback(
+ self.eval_env,
+ fake_eval_env,
+ n_eval_episodes=self.cfg.task.n_eval_episodes,
+ best_model_save_path="./",
+ eval_freq=self.cfg.task.eval_freq,
+ deterministic=True,
+ render=False,
+ )
+
+ def learn(self):
+ self._setup_learn()
+
+ self.agent.learn(total_timesteps=self.cfg.task.num_steps, callback=self.callback)
+
+ def _setup_learn(self):
+ pass
diff --git a/cmrl/algorithms/mbpo.py b/cmrl/algorithms/mbpo.py
new file mode 100644
index 0000000..e34f2f3
--- /dev/null
+++ b/cmrl/algorithms/mbpo.py
@@ -0,0 +1,40 @@
+from typing import Optional
+
+from omegaconf import DictConfig
+from stable_baselines3.common.callbacks import BaseCallback, CallbackList
+
+from cmrl.models.fake_env import VecFakeEnv
+from cmrl.algorithms.base_algorithm import BaseAlgorithm
+from cmrl.sb3_extension.online_mb_callback import OnlineModelBasedCallback
+
+
+class MBPO(BaseAlgorithm):
+ def __init__(
+ self,
+ cfg: DictConfig,
+ work_dir: Optional[str] = None,
+ ):
+ super(MBPO, self).__init__(cfg, work_dir)
+
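+    # MBPO rolls the model out in short branches starting from real states,
+    # so its fake env runs in branch-rollout mode with a short horizon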
+ @property
+ def fake_env(self) -> VecFakeEnv:
+ return self.partial_fake_env(
+ deterministic=self.cfg.algorithm.deterministic,
+ max_episode_steps=self.cfg.algorithm.branch_rollout_length,
+ branch_rollout=True,
+ )
+
+ @property
+ def callback(self) -> BaseCallback:
+ eval_callback = super(MBPO, self).callback
+ omb_callback = OnlineModelBasedCallback(
+ self.env,
+ self.dynamics,
+ self.real_replay_buffer,
+ total_online_timesteps=self.cfg.task.online_num_steps,
+ initial_exploration_steps=self.cfg.algorithm.initial_exploration_steps,
+ freq_train_model=self.cfg.task.freq_train_model,
+ device=self.cfg.device,
+ )
+
+ return CallbackList([eval_callback, omb_callback])
diff --git a/cmrl/algorithms/mopo.py b/cmrl/algorithms/mopo.py
new file mode 100644
index 0000000..4a98517
--- /dev/null
+++ b/cmrl/algorithms/mopo.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from omegaconf import DictConfig
+
+from cmrl.models.fake_env import VecFakeEnv
+from cmrl.algorithms.base_algorithm import BaseAlgorithm
+from cmrl.utils.env import load_offline_data
+from cmrl.algorithms.util import maybe_load_offline_model
+
+
+class MOPO(BaseAlgorithm):
+ def __init__(
+ self,
+ cfg: DictConfig,
+ work_dir: Optional[str] = None,
+ ):
+ super(MOPO, self).__init__(cfg, work_dir)
+
+ @property
+ def fake_env(self) -> VecFakeEnv:
+ return self.partial_fake_env(
+ deterministic=self.cfg.algorithm.deterministic,
+ max_episode_steps=self.cfg.algorithm.branch_rollout_length,
+ branch_rollout=True,
+ )
+
+ def _setup_learn(self):
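+        # offline setting: fill the real buffer from a static dataset, then fit the dynamics once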
+ load_offline_data(self.env, self.real_replay_buffer, self.cfg.task.dataset, self.cfg.task.use_ratio)
+
+ if self.cfg.task.get("auto_load_offline_model", False):
+ existed_trained_model = maybe_load_offline_model(self.dynamics, self.cfg, work_dir=self.work_dir)
+ else:
+ existed_trained_model = None
+ if not existed_trained_model:
+ self.dynamics.learn(
+ real_replay_buffer=self.real_replay_buffer,
+ work_dir=self.work_dir,
+ )
diff --git a/cmrl/algorithms/off_dyna.py b/cmrl/algorithms/off_dyna.py
new file mode 100644
index 0000000..3911c05
--- /dev/null
+++ b/cmrl/algorithms/off_dyna.py
@@ -0,0 +1,29 @@
+from typing import Optional
+
+from omegaconf import DictConfig
+
+from cmrl.algorithms.base_algorithm import BaseAlgorithm
+from cmrl.utils.env import load_offline_data
+from cmrl.algorithms.util import maybe_load_offline_model
+
+
+class OfflineDyna(BaseAlgorithm):
+ def __init__(
+ self,
+ cfg: DictConfig,
+ work_dir: Optional[str] = None,
+ ):
+ super(OfflineDyna, self).__init__(cfg, work_dir)
+
+ def _setup_learn(self):
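+        # same offline setup as MOPO: load the dataset and train the dynamics unless a cached model is found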
+ load_offline_data(self.env, self.real_replay_buffer, self.cfg.task.dataset, self.cfg.task.use_ratio)
+
+ if self.cfg.task.get("auto_load_offline_model", False):
+ existed_trained_model = maybe_load_offline_model(self.dynamics, self.cfg, work_dir=self.work_dir)
+ else:
+ existed_trained_model = None
+ if not existed_trained_model:
+ self.dynamics.learn(
+ real_replay_buffer=self.real_replay_buffer,
+ work_dir=self.work_dir,
+ )
diff --git a/cmrl/algorithms/offline/mopo.py b/cmrl/algorithms/offline/mopo.py
deleted file mode 100644
index 860518e..0000000
--- a/cmrl/algorithms/offline/mopo.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-from typing import Optional, cast
-
-import emei
-import hydra.utils
-import numpy as np
-from omegaconf import DictConfig
-from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.agent import complete_agent_cfg
-from cmrl.algorithms.util import maybe_load_trained_offline_model, setup_fake_env, load_offline_data
-from cmrl.models.dynamics import ConstraintBasedDynamics
-from cmrl.sb3_extension.eval_callback import EvalCallback
-from cmrl.sb3_extension.logger import configure as logger_configure
-from cmrl.types import InitObsFnType, RewardFnType, TermFnType
-from cmrl.util.creator import create_dynamics
-
-
-def train(
- env: emei.EmeiEnv,
- eval_env: emei.EmeiEnv,
- termination_fn: Optional[TermFnType],
- reward_fn: Optional[RewardFnType],
- get_init_obs_fn: Optional[InitObsFnType],
- cfg: DictConfig,
- work_dir: Optional[str] = None,
-):
- obs_shape = env.observation_space.shape
- act_shape = env.action_space.shape
-
- # build model-free agent, which is a stable-baselines3's agent
- complete_agent_cfg(env, cfg.algorithm.agent)
- agent = cast(BaseAlgorithm, hydra.utils.instantiate(cfg.algorithm.agent))
-
- work_dir = work_dir or os.getcwd()
- logger = logger_configure("log", ["tensorboard", "multi_csv", "stdout"])
-
- numpy_generator = np.random.default_rng(seed=cfg.seed)
-
- # create initial dataset and add it to replay buffer
- dynamics = create_dynamics(cfg.dynamics, obs_shape, act_shape, logger=logger)
- real_replay_buffer = ReplayBuffer(
- cfg.task.num_steps, env.observation_space, env.action_space, cfg.device, handle_timeout_termination=False
- )
- load_offline_data(cfg, env, real_replay_buffer)
-
- fake_eval_env = setup_fake_env(
- cfg=cfg,
- agent=agent,
- dynamics=dynamics,
- reward_fn=reward_fn,
- termination_fn=termination_fn,
- get_init_obs_fn=get_init_obs_fn,
- real_replay_buffer=real_replay_buffer,
- logger=logger,
- max_episode_steps=env.spec.max_episode_steps,
- penalty_coeff=cfg.algorithm.penalty_coeff,
- )
-
- if hasattr(env, "get_causal_graph"):
- oracle_causal_graph = env.get_causal_graph()
- else:
- oracle_causal_graph = None
-
- if isinstance(dynamics, ConstraintBasedDynamics):
- dynamics.set_oracle_mask("transition", oracle_causal_graph.T)
-
- existed_trained_model = maybe_load_trained_offline_model(dynamics, cfg, obs_shape, act_shape, work_dir=work_dir)
- if not existed_trained_model:
- dynamics.learn(real_replay_buffer, **cfg.dynamics, work_dir=work_dir)
-
- eval_callback = EvalCallback(
- eval_env,
- fake_eval_env,
- n_eval_episodes=cfg.task.n_eval_episodes,
- best_model_save_path="./",
- eval_freq=1000,
- deterministic=True,
- render=False,
- )
-
- agent.set_logger(logger)
- agent.learn(total_timesteps=cfg.task.num_steps, callback=eval_callback)
diff --git a/cmrl/algorithms/offline/off_dyna.py b/cmrl/algorithms/offline/off_dyna.py
deleted file mode 100644
index 9993c55..0000000
--- a/cmrl/algorithms/offline/off_dyna.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import os
-from typing import Optional, cast
-
-import emei
-import hydra.utils
-import numpy as np
-from omegaconf import DictConfig
-from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.agent import complete_agent_cfg
-from cmrl.algorithms.util import maybe_load_trained_offline_model, setup_fake_env, load_offline_data
-from cmrl.models.dynamics import ConstraintBasedDynamics
-from cmrl.sb3_extension.eval_callback import EvalCallback
-from cmrl.sb3_extension.logger import configure as logger_configure
-from cmrl.types import InitObsFnType, RewardFnType, TermFnType
-from cmrl.util.creator import create_dynamics
-
-
-def train(
- env: emei.EmeiEnv,
- eval_env: emei.EmeiEnv,
- termination_fn: Optional[TermFnType],
- reward_fn: Optional[RewardFnType],
- get_init_obs_fn: Optional[InitObsFnType],
- cfg: DictConfig,
- work_dir: Optional[str] = None,
-):
- obs_shape = env.observation_space.shape
- act_shape = env.action_space.shape
-
- # build model-free agent, which is a stable-baselines3's agent
- complete_agent_cfg(env, cfg.algorithm.agent)
- agent = cast(BaseAlgorithm, hydra.utils.instantiate(cfg.algorithm.agent))
-
- work_dir = work_dir or os.getcwd()
- logger = logger_configure("log", ["tensorboard", "multi_csv", "stdout"])
-
- # create initial dataset and add it to replay buffer
- dynamics = create_dynamics(cfg.dynamics, obs_shape, act_shape, logger=logger)
- real_replay_buffer = ReplayBuffer(
- cfg.task.num_steps, env.observation_space, env.action_space, cfg.device, handle_timeout_termination=False
- )
- load_offline_data(cfg, env, real_replay_buffer)
-
- fake_eval_env = setup_fake_env(
- cfg=cfg,
- agent=agent,
- dynamics=dynamics,
- reward_fn=reward_fn,
- termination_fn=termination_fn,
- get_init_obs_fn=get_init_obs_fn,
- logger=logger,
- max_episode_steps=env.spec.max_episode_steps,
- penalty_coeff=cfg.algorithm.penalty_coeff,
- )
-
- if hasattr(env, "get_causal_graph"):
- oracle_causal_graph = env.get_causal_graph()
- else:
- oracle_causal_graph = None
-
- if isinstance(dynamics, ConstraintBasedDynamics):
- dynamics.set_oracle_mask("transition", oracle_causal_graph.T)
-
- existed_trained_model = maybe_load_trained_offline_model(dynamics, cfg, obs_shape, act_shape, work_dir=work_dir)
- if not existed_trained_model:
- dynamics.learn(real_replay_buffer, **cfg.dynamics, work_dir=work_dir)
-
- eval_callback = EvalCallback(
- eval_env,
- fake_eval_env,
- n_eval_episodes=cfg.task.n_eval_episodes,
- best_model_save_path="./",
- eval_freq=1000,
- deterministic=True,
- render=False,
- )
-
- agent.set_logger(logger)
- agent.learn(total_timesteps=cfg.task.num_steps, callback=eval_callback)
diff --git a/cmrl/algorithms/on_dyna.py b/cmrl/algorithms/on_dyna.py
new file mode 100644
index 0000000..32bbdb4
--- /dev/null
+++ b/cmrl/algorithms/on_dyna.py
@@ -0,0 +1,32 @@
+from typing import Optional
+
+from omegaconf import DictConfig
+from stable_baselines3.common.callbacks import BaseCallback, CallbackList
+
+from cmrl.models.fake_env import VecFakeEnv
+from cmrl.algorithms.base_algorithm import BaseAlgorithm
+from cmrl.sb3_extension.online_mb_callback import OnlineModelBasedCallback
+
+
+class OnlineDyna(BaseAlgorithm):
+ def __init__(
+ self,
+ cfg: DictConfig,
+ work_dir: Optional[str] = None,
+ ):
+ super(OnlineDyna, self).__init__(cfg, work_dir)
+
+ @property
+ def callback(self) -> BaseCallback:
+ eval_callback = super(OnlineDyna, self).callback
+ omb_callback = OnlineModelBasedCallback(
+ self.env,
+ self.dynamics,
+ self.real_replay_buffer,
+ total_online_timesteps=self.cfg.task.online_num_steps,
+ initial_exploration_steps=self.cfg.algorithm.initial_exploration_steps,
+ freq_train_model=self.cfg.task.freq_train_model,
+ device=self.cfg.device,
+ )
+
+ return CallbackList([eval_callback, omb_callback])
diff --git a/cmrl/algorithms/online/mbpo.py b/cmrl/algorithms/online/mbpo.py
deleted file mode 100644
index d858141..0000000
--- a/cmrl/algorithms/online/mbpo.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import os
-from typing import Optional, cast
-
-import emei
-import hydra.utils
-import numpy as np
-from omegaconf import DictConfig
-from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.callbacks import CallbackList
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.agent import complete_agent_cfg
-from cmrl.algorithms.util import setup_fake_env
-from cmrl.models.dynamics import ConstraintBasedDynamics
-from cmrl.sb3_extension.eval_callback import EvalCallback
-from cmrl.sb3_extension.online_mb_callback import OnlineModelBasedCallback
-from cmrl.sb3_extension.logger import configure as logger_configure
-from cmrl.types import InitObsFnType, RewardFnType, TermFnType
-from cmrl.util.creator import create_dynamics
-
-
-def train(
- env: emei.EmeiEnv,
- eval_env: emei.EmeiEnv,
- termination_fn: Optional[TermFnType],
- reward_fn: Optional[RewardFnType],
- get_init_obs_fn: Optional[InitObsFnType],
- cfg: DictConfig,
- work_dir: Optional[str] = None,
-):
- obs_shape = env.observation_space.shape
- act_shape = env.action_space.shape
-
- # build model-free agent, which is a stable-baselines3's agent
- complete_agent_cfg(env, cfg.algorithm.agent)
- agent = cast(BaseAlgorithm, hydra.utils.instantiate(cfg.algorithm.agent))
-
- work_dir = work_dir or os.getcwd()
- logger = logger_configure("log", ["tensorboard", "multi_csv", "stdout"])
-
- numpy_generator = np.random.default_rng(seed=cfg.seed)
-
- dynamics = create_dynamics(cfg.dynamics, obs_shape, act_shape, logger=logger)
- real_replay_buffer = ReplayBuffer(
- cfg.task.online_num_steps,
- env.observation_space,
- env.action_space,
- device=cfg.device,
- n_envs=1,
- optimize_memory_usage=False,
- )
-
- fake_eval_env = setup_fake_env(
- cfg=cfg,
- agent=agent,
- dynamics=dynamics,
- reward_fn=reward_fn,
- termination_fn=termination_fn,
- get_init_obs_fn=get_init_obs_fn,
- real_replay_buffer=real_replay_buffer,
- logger=logger,
- max_episode_steps=env.spec.max_episode_steps,
- )
-
- if hasattr(env, "causal_graph"):
- oracle_causal_graph = env.causal_graph
- else:
- oracle_causal_graph = None
-
- if isinstance(dynamics, ConstraintBasedDynamics):
- dynamics.set_oracle_mask("transition", oracle_causal_graph.T)
-
- eval_callback = EvalCallback(
- eval_env,
- fake_eval_env,
- n_eval_episodes=cfg.task.n_eval_episodes,
- best_model_save_path="./",
- eval_freq=cfg.task.eval_freq,
- deterministic=True,
- render=False,
- )
-
- omb_callback = OnlineModelBasedCallback(
- env,
- dynamics,
- real_replay_buffer,
- total_num_steps=cfg.task.online_num_steps,
- initial_exploration_steps=cfg.algorithm.initial_exploration_steps,
- freq_train_model=cfg.task.freq_train_model,
- device=cfg.device,
- )
-
- agent.set_logger(logger)
- agent.learn(total_timesteps=int(1e10), callback=CallbackList([eval_callback, omb_callback]))
diff --git a/cmrl/algorithms/online/on_dyna.py b/cmrl/algorithms/online/on_dyna.py
deleted file mode 100644
index a1056ca..0000000
--- a/cmrl/algorithms/online/on_dyna.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import os
-from typing import Optional, cast
-
-import emei
-import hydra.utils
-import numpy as np
-from omegaconf import DictConfig
-from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.callbacks import CallbackList
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.agent import complete_agent_cfg
-from cmrl.algorithms.util import setup_fake_env
-from cmrl.models.dynamics import ConstraintBasedDynamics
-from cmrl.sb3_extension.eval_callback import EvalCallback
-from cmrl.sb3_extension.online_mb_callback import OnlineModelBasedCallback
-from cmrl.sb3_extension.logger import configure as logger_configure
-from cmrl.types import InitObsFnType, RewardFnType, TermFnType
-from cmrl.util.creator import create_dynamics
-
-
-def train(
- env: emei.EmeiEnv,
- eval_env: emei.EmeiEnv,
- termination_fn: Optional[TermFnType],
- reward_fn: Optional[RewardFnType],
- get_init_obs_fn: Optional[InitObsFnType],
- cfg: DictConfig,
- work_dir: Optional[str] = None,
-):
- obs_shape = env.observation_space.shape
- act_shape = env.action_space.shape
-
- # build model-free agent, which is a stable-baselines3's agent
- complete_agent_cfg(env, cfg.algorithm.agent)
- agent = cast(BaseAlgorithm, hydra.utils.instantiate(cfg.algorithm.agent))
-
- work_dir = work_dir or os.getcwd()
- logger = logger_configure("log", ["tensorboard", "multi_csv", "stdout"])
-
- numpy_generator = np.random.default_rng(seed=cfg.seed)
-
- dynamics = create_dynamics(cfg.dynamics, obs_shape, act_shape, logger=logger)
- real_replay_buffer = ReplayBuffer(
- cfg.task.online_num_steps,
- env.observation_space,
- env.action_space,
- device=cfg.device,
- n_envs=1,
- optimize_memory_usage=False,
- )
-
- fake_eval_env = setup_fake_env(
- cfg=cfg,
- agent=agent,
- dynamics=dynamics,
- reward_fn=reward_fn,
- termination_fn=termination_fn,
- get_init_obs_fn=get_init_obs_fn,
- logger=logger,
- max_episode_steps=env.spec.max_episode_steps,
- )
-
- if hasattr(env, "causal_graph"):
- oracle_causal_graph = env.causal_graph
- else:
- oracle_causal_graph = None
-
- if isinstance(dynamics, ConstraintBasedDynamics):
- dynamics.set_oracle_mask("transition", oracle_causal_graph.T)
-
- eval_callback = EvalCallback(
- eval_env,
- fake_eval_env,
- n_eval_episodes=cfg.task.n_eval_episodes,
- best_model_save_path="./",
- eval_freq=cfg.task.eval_freq,
- deterministic=True,
- render=False,
- )
-
- omb_callback = OnlineModelBasedCallback(
- env,
- dynamics,
- real_replay_buffer,
- total_num_steps=cfg.task.online_num_steps,
- initial_exploration_steps=cfg.algorithm.initial_exploration_steps,
- freq_train_model=cfg.task.freq_train_model,
- device=cfg.device,
- )
-
- agent.set_logger(logger)
- agent.learn(total_timesteps=int(1e10), callback=CallbackList([eval_callback, omb_callback]))
diff --git a/cmrl/algorithms/util.py b/cmrl/algorithms/util.py
index 6e901c4..fd594f1 100644
--- a/cmrl/algorithms/util.py
+++ b/cmrl/algorithms/util.py
@@ -1,29 +1,29 @@
-import pathlib
from typing import Optional, cast
from copy import deepcopy
+import pathlib
-import emei
import hydra
-import numpy as np
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import ReplayBuffer
from cmrl.types import InitObsFnType, RewardFnType, TermFnType
-from cmrl.models.dynamics import BaseDynamics
-from cmrl.util.config import get_complete_dynamics_cfg, load_hydra_cfg
-from cmrl.models.fake_env import VecFakeEnv
+
+from cmrl.models.dynamics import Dynamics
+from cmrl.utils.config import load_hydra_cfg
-def is_same_dict(dict1, dict2):
+def compare_dict(dict1, dict2):
+    if len(dict1) != len(dict2):
+ return False
for key in dict1:
if key not in dict2:
return False
else:
- if isinstance(dict1[key], DictConfig) and isinstance(dict2[key], DictConfig):
- if not is_same_dict(dict1[key], dict2[key]):
+ if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
+ if not compare_dict(dict1[key], dict2[key]):
return False
else:
if dict1[key] != dict2[key]:
@@ -31,94 +31,38 @@ def is_same_dict(dict1, dict2):
return True
-def maybe_load_trained_offline_model(dynamics: BaseDynamics, cfg, obs_shape, act_shape, work_dir):
+def maybe_load_offline_model(
+ dynamics: Dynamics,
+ cfg: DictConfig,
+ work_dir,
+):
work_dir = pathlib.Path(work_dir)
if "." not in work_dir.name: # exp by hydra's MULTIRUN mode
- task_exp_dir = work_dir.parent.parent.parent
- else:
task_exp_dir = work_dir.parent.parent
- dynamics_cfg = cfg.dynamics
-
- for date_dir in task_exp_dir.glob(r"*"):
- for time_dir in date_dir.glob(r"*"):
- if (time_dir / "multirun.yaml").exists(): # exp by hydra's MULTIRUN mode, multi exp in this time
- this_time_exp_dir_list = list(time_dir.glob(r"*"))
- else: # only one exp in this time
- this_time_exp_dir_list = [time_dir]
-
- for exp_dir in this_time_exp_dir_list:
- if not (exp_dir / ".hydra").exists():
- continue
- exp_cfg = load_hydra_cfg(exp_dir)
- exp_dynamics_cfg = get_complete_dynamics_cfg(exp_cfg.dynamics, obs_shape, act_shape)
-
- if exp_cfg.seed == cfg.seed and is_same_dict(dynamics_cfg, exp_dynamics_cfg):
- exist_model_file = True
- for mech in dynamics.learn_mech:
- mech_file_name = getattr(dynamics, mech).model_file_name
- if not (exp_dir / mech_file_name).exists():
- exist_model_file = False
- if exist_model_file:
- dynamics.load(exp_dir)
- print("loaded dynamics from {}".format(exp_dir))
- return True
+ else:
+ task_exp_dir = work_dir.parent
+
+ transition_cfg = OmegaConf.to_container(cfg.transition, resolve=True)
+
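+    # scan sibling experiment directories for a finished run whose seed,
+    # use_ratio and transition config match, and reuse its trained model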
+ for time_dir in task_exp_dir.glob(r"*"):
+ if (time_dir / "multirun.yaml").exists(): # exp by hydra's MULTIRUN mode, multi exp in this time
+ this_time_exp_dir_list = list(time_dir.glob(r"*"))
+ else: # only one exp in this time
+ this_time_exp_dir_list = [time_dir]
+
+ for exp_dir in this_time_exp_dir_list:
+ if not (exp_dir / ".hydra").exists():
+ continue
+ exp_cfg = load_hydra_cfg(exp_dir)
+
+            exp_transition_cfg = OmegaConf.to_container(exp_cfg.transition, resolve=True)
+            if (
+                cfg.seed == exp_cfg.seed
+                and cfg.task.use_ratio == exp_cfg.task.use_ratio
+                and compare_dict(exp_transition_cfg, transition_cfg)
+                and (exp_dir / "transition").exists()
+            ):
+ dynamics.transition.load(exp_dir / "transition")
+ print("loaded dynamics from {}".format(exp_dir))
+ return True
return False
-
-
-def setup_fake_env(
- cfg: DictConfig,
- agent: BaseAlgorithm,
- dynamics,
- reward_fn: Optional[RewardFnType],
- termination_fn: Optional[TermFnType],
- get_init_obs_fn: Optional[InitObsFnType],
- real_replay_buffer: Optional[ReplayBuffer] = None,
- logger=None,
- max_episode_steps: int = 1000,
- penalty_coeff: Optional[float] = 0,
-):
- fake_env = cast(VecFakeEnv, agent.env)
- fake_env.set_up(
- dynamics,
- reward_fn,
- termination_fn,
- get_init_obs_fn=get_init_obs_fn,
- real_replay_buffer=real_replay_buffer,
- logger=logger,
- max_episode_steps=max_episode_steps,
- penalty_coeff=penalty_coeff,
- )
- agent.env = VecMonitor(fake_env)
-
- fake_eval_env_cfg = deepcopy(cfg.algorithm.agent.env)
- fake_eval_env_cfg.num_envs = cfg.task.n_eval_episodes
- fake_eval_env = cast(VecFakeEnv, hydra.utils.instantiate(fake_eval_env_cfg))
- fake_eval_env.set_up(
- dynamics,
- reward_fn,
- termination_fn,
- get_init_obs_fn=get_init_obs_fn,
- max_episode_steps=max_episode_steps,
- penalty_coeff=penalty_coeff,
- )
- fake_eval_env.seed(seed=cfg.seed)
- return fake_eval_env
-
-
-def load_offline_data(cfg: DictConfig, env, replay_buffer: ReplayBuffer):
- assert hasattr(env, "get_dataset"), "env must have `get_dataset` method"
-
- params, dataset_type = cfg.task.env.split("___")[-2:]
- data_dict = env.get_dataset("{}-{}".format(params, dataset_type))
- all_data_num = len(data_dict["observations"])
- sample_data_num = int(cfg.task.use_ratio * all_data_num)
- sample_idx = np.random.permutation(all_data_num)[:sample_data_num]
-
- replay_buffer.extend(
- data_dict["observations"][sample_idx],
- data_dict["next_observations"][sample_idx],
- data_dict["actions"][sample_idx],
- data_dict["rewards"][sample_idx],
- data_dict["terminals"][sample_idx].astype(bool) | data_dict["timeouts"][sample_idx].astype(bool),
- [{}] * sample_data_num,
- )
diff --git a/cmrl/diagnostics/base_diagnostic.py b/cmrl/diagnostics/base_diagnostic.py
new file mode 100644
index 0000000..3039cd4
--- /dev/null
+++ b/cmrl/diagnostics/base_diagnostic.py
@@ -0,0 +1,11 @@
+from typing import Union
+import pathlib
+
+
+class BaseDiagnostic:
+    def __init__(self, exp_dir: Union[str, pathlib.Path]):
+        # pathlib.Path accepts str and Path alike, so no isinstance branch is needed
+        self.exp_dir = pathlib.Path(exp_dir)
diff --git a/cmrl/diagnostics/eval_model_on_dataset.py b/cmrl/diagnostics/eval_model_on_dataset.py
index a0166c2..faf379a 100644
--- a/cmrl/diagnostics/eval_model_on_dataset.py
+++ b/cmrl/diagnostics/eval_model_on_dataset.py
@@ -8,9 +8,10 @@
import matplotlib.pylab as plt
import numpy as np
-import cmrl.util.creator
-import cmrl.util.env
-from cmrl.util.config import load_hydra_cfg
+import cmrl.utils.creator
+import cmrl.utils.env
+from cmrl.utils.config import load_hydra_cfg
+from cmrl.utils.transition_iterator import TransitionIterator
class DatasetEvaluator:
@@ -24,7 +25,7 @@ def __init__(self, model_dir: str, dataset: str, batch_size: int = 4096, device=
-        self.dynamics = cmrl.util.creator.create_dynamics(
+        self.dynamics = cmrl.utils.creator.create_dynamics(
self.cfg.dynamics,
- self.env.observation_space.shape,
+ self.env.state_space.shape,
self.env.action_space.shape,
load_dir=self.model_path,
load_device=device,
@@ -32,7 +33,7 @@ def __init__(self, model_dir: str, dataset: str, batch_size: int = 4096, device=
-        self.replay_buffer = cmrl.util.creator.create_replay_buffer(
+        self.replay_buffer = cmrl.utils.creator.create_replay_buffer(
self.cfg,
- self.env.observation_space.shape,
+ self.env.state_space.shape,
self.env.action_space.shape,
)
@@ -62,7 +63,7 @@ def __init__(self, model_dir: str, dataset: str, batch_size: int = 4096, device=
def plot_dataset_results(
self,
- dataset: cmrl.util.TransitionIterator,
+ dataset: TransitionIterator,
hist_bins: int = 20,
hist_log: bool = True,
):
diff --git a/cmrl/diagnostics/eval_model_on_space.py b/cmrl/diagnostics/eval_model_on_space.py
index 645b10e..4e17e53 100644
--- a/cmrl/diagnostics/eval_model_on_space.py
+++ b/cmrl/diagnostics/eval_model_on_space.py
@@ -11,22 +11,16 @@
import numpy as np
from matplotlib.widgets import Button, RadioButtons, Slider
-import cmrl.util.creator
-import cmrl.util.env
-from cmrl.util.config import load_hydra_cfg
+import cmrl.utils.creator
+import cmrl.utils.env
+from cmrl.utils.config import load_hydra_cfg
+from cmrl.utils.creator import create_dynamics, create_agent
+from cmrl.models.fake_env import get_penalty
-mpl.use("Qt5Agg")
+mpl.use("TKAgg")
SIN_COS_BINDINGS = {"BoundaryInvertedPendulumSwingUp-v0": [1]}
-def calculate_penalty(ensemble_mean):
- avg_ensemble_mean = np.mean(ensemble_mean, axis=0) # average predictions over models
- diffs = ensemble_mean - avg_ensemble_mean
- dists = np.linalg.norm(diffs, axis=2) # distance in obs space
- penalty = np.max(dists, axis=0) # max distances over models
- return penalty
-
-
def set_ylim(y_min, y_max, ax):
if y_max - y_min > 0.1:
obs_y_lim = [y_min - 0.05, y_max + 0.05]
@@ -73,25 +67,30 @@ def __init__(
self.cfg = load_hydra_cfg(self.model_path)
self.cfg.device = device
- self.env, *_ = cmrl.util.env.make_env(self.cfg)
+ self.env, *_ = cmrl.utils.env.make_env(self.cfg)
if penalty_coeff is None:
self.penalty_coeff = self.cfg.task.penalty_coeff
else:
self.penalty_coeff = penalty_coeff
- self.dynamics = cmrl.util.creator.create_dynamics(
- self.cfg.dynamics,
- self.env.observation_space.shape,
- self.env.action_space.shape,
- load_dir=self.model_path,
- load_device=device,
+ self.dynamics = create_dynamics(
+ self.cfg,
+ self.env.state_space,
+ self.env.action_space,
)
+ if not self.cfg.transition.discovery:
+ self.dynamics.transition.set_oracle_graph(self.env.get_transition_graph())
+ if self.cfg.reward_mech.learn and not self.cfg.reward_mech.discovery:
+ self.dynamics.reward_mech.set_oracle_graph(self.env.get_reward_mech_graph())
+ if self.cfg.termination_mech.learn and not self.cfg.termination_mech.discovery:
+ self.dynamics.termination_mech.set_oracle_graph(self.env.get_termination_mech_graph())
+ self.dynamics.transition.load(self.model_path / "transition")
self.bindings = []
self.obs_range, self.action_range = self.get_range()
self.range = np.concatenate([self.obs_range, self.action_range], axis=0)
- self.real_obs_dim_num = self.env.observation_space.shape[0]
+ self.real_obs_dim_num = self.env.state_space.shape[0]
self.compact_obs_dim_num, self.action_dim_num = (
self.obs_range.shape[0],
self.action_range.shape[0],
@@ -213,20 +212,19 @@ def slider_changed(value, dim=dim):
self.draw_button.on_clicked(self.draw)
def get_range(self, dataset_type="SAC-expert-replay"):
- universe, basic_env_name, params, origin_dataset_type = self.cfg.task.env.split("___")
- data_dict = self.env.get_dataset("{}-{}".format(params, dataset_type))
+ data_dict = self.env.get_dataset(dataset_type)
obs_min = np.percentile(data_dict["observations"], self.range_quantile, axis=0)
obs_max = np.percentile(data_dict["observations"], 100 - self.range_quantile, axis=0)
action_min = np.percentile(data_dict["actions"], self.range_quantile, axis=0)
action_max = np.percentile(data_dict["actions"], 100 - self.range_quantile, axis=0)
obs_range, action_range = np.array(list(zip(obs_min, obs_max))), np.array(list(zip(action_min, action_max)))
- if basic_env_name in SIN_COS_BINDINGS:
- self.bindings = SIN_COS_BINDINGS[basic_env_name]
- for idx, binding_idx in enumerate(self.bindings):
- theta_idx = binding_idx - idx
- obs_range = np.delete(obs_range, [binding_idx, binding_idx + 1], axis=0)
- obs_range = np.insert(obs_range, theta_idx, np.array([0, 2 * np.pi]), axis=0)
+ # if basic_env_name in SIN_COS_BINDINGS:
+ # self.bindings = SIN_COS_BINDINGS[basic_env_name]
+ # for idx, binding_idx in enumerate(self.bindings):
+ # theta_idx = binding_idx - idx
+ # obs_range = np.delete(obs_range, [binding_idx, binding_idx + 1], axis=0)
+ # obs_range = np.insert(obs_range, theta_idx, np.array([0, 2 * np.pi]), axis=0)
return obs_range, action_range
def build_model_in(self):
@@ -250,7 +248,7 @@ def build_model_in(self):
real_model_in[:, dim] = np.cos(compact_model_in[:, compact_dim].copy())
else: # is an action
compact_dim = dim - (self.real_obs_dim_num - self.compact_obs_dim_num)
- real_model_in[:, dim] = np.cos(compact_model_in[:, compact_dim].copy())
+ real_model_in[:, dim] = compact_model_in[:, compact_dim].copy()
return x, real_model_in
def draw(self, event):
@@ -285,32 +283,29 @@ def get_model_out(self, model_in):
penalized_reward = np.empty(self.plot_dot_num)
for batch_idx in range(batch_num):
+ f, t = self.batch_size * batch_idx, self.batch_size * (batch_idx + 1)
+
-            batch_input = model_in[self.batch_size * batch_idx : self.batch_size * (batch_idx + 1)]
+            batch_input = model_in[f:t]
batch_obs, batch_action = (
batch_input[:, : self.real_obs_dim_num],
batch_input[:, self.real_obs_dim_num :],
)
- dynamics_result = self.dynamics.query(batch_obs, batch_action, return_as_np=True)
- gt_next_obs, gt_reward, gt_terminal, gt_truncated, _ = self.env.query(batch_obs, batch_action)
- # predict and ground truth
- batch_predict_obs = dynamics_result["batch_next_obs"]["mean"].mean(0)
- batch_gt_obs = gt_next_obs
+
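+            # query the learned dynamics and the ground-truth env on the same batch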
+ predict_next_obs, predict_reward, terminal, info = self.dynamics.step(batch_obs, batch_action)
+ gt_next_obs = self.env.get_batch_next_obs(batch_obs, batch_action)
+ gt_reward = self.env.get_batch_reward(gt_next_obs)
+
if self.draw_diff:
- batch_predict_obs -= batch_obs
- batch_gt_obs -= batch_obs
- batch_predict = batch_predict_obs[:, self.current_out_dim]
- batch_ground_truth = batch_gt_obs[:, self.current_out_dim]
- # reward
- # batch_reward = dynamics_result["batch_reward"]["mean"].mean(0)[:, 0]
- batch_reward = gt_reward
- # penalized_reward
- batch_penalty = calculate_penalty(dynamics_result["batch_next_obs"]["mean"])
- batch_penalized_reward = batch_reward - batch_penalty * self.penalty_coeff
-
- predict[self.batch_size * batch_idx : self.batch_size * (batch_idx + 1)] = batch_predict
- ground_truth[self.batch_size * batch_idx : self.batch_size * (batch_idx + 1)] = batch_ground_truth
- reward[self.batch_size * batch_idx : self.batch_size * (batch_idx + 1)] = batch_reward
- penalized_reward[self.batch_size * batch_idx : self.batch_size * (batch_idx + 1)] = batch_penalized_reward
+ predict_next_obs -= batch_obs
+ gt_next_obs -= batch_obs
+
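+            # penalty: ensemble disagreement over the predicted next observations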
+ batch_penalty = get_penalty(info["origin-next_obs"])
+
+ predict[f:t] = predict_next_obs[:, self.current_out_dim]
+ ground_truth[f:t] = gt_next_obs[:, self.current_out_dim]
+ reward[f:t] = gt_reward[:, 0]
+ penalized_reward[f:t] = gt_reward[:, 0] - batch_penalty * self.penalty_coeff
+
return predict, ground_truth, reward, penalized_reward
def run(self):
@@ -326,7 +321,7 @@ def run(self):
np.linspace(0, 1, 100),
color="black",
lw=2,
- label="gt",
+ label="ground truth",
)
(self.reward_line,) = self.reward_ax.plot(
np.linspace(0, 1, 100),
@@ -358,10 +353,10 @@ def run(self):
parser = argparse.ArgumentParser()
parser.add_argument("model_dir", type=str, default=None)
parser.add_argument("--penalty_coeff", type=float, default=None)
- parser.add_argument("--draw_diff", action="store_true")
+ parser.add_argument("--not_draw_diff", action="store_true", default=False)
args = parser.parse_args()
- evaluator = DatasetEvaluator(args.model_dir, penalty_coeff=args.penalty_coeff, draw_diff=args.draw_diff)
+ evaluator = DatasetEvaluator(args.model_dir, penalty_coeff=args.penalty_coeff, draw_diff=not args.not_draw_diff)
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["font.size"] = 14
diff --git a/cmrl/diagnostics/run_trained_model.py b/cmrl/diagnostics/run_trained_model.py
index 379c842..0871d43 100644
--- a/cmrl/diagnostics/run_trained_model.py
+++ b/cmrl/diagnostics/run_trained_model.py
@@ -17,9 +17,9 @@
import cmrl
import cmrl.agent
import cmrl.models
-import cmrl.util.creator
-from cmrl.util.config import load_hydra_cfg
-from cmrl.util.env import make_env
+import cmrl.utils.creator
+from cmrl.utils.config import load_hydra_cfg
+from cmrl.utils.env import make_env
class Runner:
@@ -34,7 +34,7 @@ def __init__(self, model_dir: str, device: str = "cuda:0", render: bool = False)
-        self.dynamics = cmrl.util.creator.create_dynamics(
+        self.dynamics = cmrl.utils.creator.create_dynamics(
self.cfg.dynamics,
- self.env.observation_space.shape,
+ self.env.state_space.shape,
self.env.action_space.shape,
load_dir=self.model_path,
load_device=device,
@@ -61,7 +61,7 @@ def __init__(self, model_dir: str, device: str = "cuda:0", render: bool = False)
self.agent = agent_class.load(self.model_path / "best_model")
def run(self):
- # from emei.util import random_policy_test
+ # from emei.utils import random_policy_test
obs = self.fake_eval_env.reset()
if self.render:
self.fake_eval_env.render()
diff --git a/cmrl/diagnostics/run_trained_policy.py b/cmrl/diagnostics/run_trained_policy.py
index 2e8f8a8..94f2008 100644
--- a/cmrl/diagnostics/run_trained_policy.py
+++ b/cmrl/diagnostics/run_trained_policy.py
@@ -10,8 +10,8 @@
import cmrl
import cmrl.agent
import cmrl.models
-from cmrl.util.config import load_hydra_cfg
-from cmrl.util.env import make_env
+from cmrl.utils.config import load_hydra_cfg
+from cmrl.utils.env import make_env
class Runner:
@@ -26,7 +26,7 @@ def __init__(self, agent_dir: str, type: str = "best", device="cuda:0"):
self.agent = agent_class.load(self.agent_dir / "best_model")
def run(self):
- # from emei.util import random_policy_test
+ # from emei.utils import random_policy_test
obs = self.env.reset()
self.env.render()
total_reward = 0
diff --git a/cmrl/examples/conf/algorithm/mbpo.yaml b/cmrl/examples/conf/algorithm/mbpo.yaml
index 55fab6d..c53e292 100644
--- a/cmrl/examples/conf/algorithm/mbpo.yaml
+++ b/cmrl/examples/conf/algorithm/mbpo.yaml
@@ -1,32 +1,22 @@
name: "mbpo"
+algo:
+ _partial_: true
+ _target_: cmrl.algorithms.MBPO
+
freq_train_model: ${task.freq_train_model}
-real_data_ratio: 0.0
-sac_samples_action: true
-initial_exploration_steps: 5000
-random_initial_explore: false
num_eval_episodes: 5
-# --------------------------------------------
-# SAC Agent configuration
-# --------------------------------------------
+initial_exploration_steps: 1000
+
+num_envs: 1000
+deterministic: false
agent:
+ _partial_: true
_target_: stable_baselines3.sac.SAC
policy: "MlpPolicy"
- env:
- _target_: cmrl.models.fake_env.VecFakeEnv
- num_envs: 1000
- action_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
- observation_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
+ env: ???
learning_starts: 0
batch_size: 256
tau: 0.005
diff --git a/cmrl/examples/conf/algorithm/mopo.yaml b/cmrl/examples/conf/algorithm/mopo.yaml
index 9daa227..9ff83d8 100644
--- a/cmrl/examples/conf/algorithm/mopo.yaml
+++ b/cmrl/examples/conf/algorithm/mopo.yaml
@@ -1,33 +1,21 @@
name: "mopo"
-freq_train_model: ${task.freq_train_model}
-real_data_ratio: 0.0
-
-sac_samples_action: true
-num_eval_episodes: 5
+algo:
+ _partial_: true
+ _target_: cmrl.algorithms.MOPO
dataset_size: 1000000
penalty_coeff: ${task.penalty_coeff}
-# --------------------------------------------
-# SAC Agent configuration
-# --------------------------------------------
+branch_rollout_length: 5
+
+num_envs: 100
+deterministic: false
agent:
+ _partial_: true
_target_: stable_baselines3.sac.SAC
policy: "MlpPolicy"
- env:
- _target_: cmrl.models.fake_env.VecFakeEnv
- num_envs: 1000
- action_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
- observation_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
+ env: ???
learning_starts: 0
batch_size: 256
tau: 0.005
diff --git a/cmrl/examples/conf/algorithm/off_dyna.yaml b/cmrl/examples/conf/algorithm/off_dyna.yaml
index d22c034..c4d0bec 100644
--- a/cmrl/examples/conf/algorithm/off_dyna.yaml
+++ b/cmrl/examples/conf/algorithm/off_dyna.yaml
@@ -1,33 +1,19 @@
name: "off_dyna"
-freq_train_model: ${task.freq_train_model}
-real_data_ratio: 0.0
-
-sac_samples_action: true
-num_eval_episodes: 5
+algo:
+ _partial_: true
+ _target_: cmrl.algorithms.OfflineDyna
dataset_size: 1000000
penalty_coeff: ${task.penalty_coeff}
-# --------------------------------------------
-# SAC Agent configuration
-# --------------------------------------------
+num_envs: 8
+deterministic: false
agent:
+ _partial_: true
_target_: stable_baselines3.sac.SAC
policy: "MlpPolicy"
- env:
- _target_: cmrl.models.fake_env.VecFakeEnv
- num_envs: 16
- action_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
- observation_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
+ env: ???
learning_starts: 0
batch_size: 256
tau: 0.005
diff --git a/cmrl/examples/conf/algorithm/on_dyna.yaml b/cmrl/examples/conf/algorithm/on_dyna.yaml
index d3ae6e6..e66c5b6 100644
--- a/cmrl/examples/conf/algorithm/on_dyna.yaml
+++ b/cmrl/examples/conf/algorithm/on_dyna.yaml
@@ -1,32 +1,22 @@
name: "on_dyna"
+algo:
+ _partial_: true
+ _target_: cmrl.algorithms.OnlineDyna
+
freq_train_model: ${task.freq_train_model}
-real_data_ratio: 0.0
-sac_samples_action: true
num_eval_episodes: 5
initial_exploration_steps: 1000
-# --------------------------------------------
-# SAC Agent configuration
-# --------------------------------------------
+num_envs: 16
+deterministic: false
agent:
+ _partial_: true
_target_: stable_baselines3.sac.SAC
policy: "MlpPolicy"
- env:
- _target_: cmrl.models.fake_env.VecFakeEnv
- num_envs: 16
- action_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
- observation_space:
- _target_: gym.spaces.Box
- low: ???
- high: ???
- shape: ???
+ env: ???
learning_starts: 0
batch_size: 256
tau: 0.005
diff --git a/cmrl/examples/conf/dynamics/constraint_based_dynamics.yaml b/cmrl/examples/conf/dynamics/constraint_based_dynamics.yaml
deleted file mode 100644
index 2fefed1..0000000
--- a/cmrl/examples/conf/dynamics/constraint_based_dynamics.yaml
+++ /dev/null
@@ -1,63 +0,0 @@
-name: constraint_based_dynamics
-
-multi_step: ${task.multi_step}
-
-transition:
- _target_: cmrl.models.transition.ExternalMaskEnsembleGaussianTransition
- # transition info
- obs_size: ???
- action_size: ???
- deterministic: false
- # algorithm parameters
- ensemble_num: ${task.ensemble_num}
- elite_num: ${task.elite_num}
- residual: true
- learn_logvar_bounds: false # so far this works better
- # network parameters
- num_layers: 4
- hid_size: 200
- activation_fn_cfg:
- _target_: torch.nn.SiLU
- # others
- device: ${device}
-
-learned_reward: ${task.learning_reward}
-reward_mech:
- _target_: cmrl.models.BaseRewardMech
- # transition info
- obs_size: ???
- action_size: ???
- deterministic: false
- # algorithm parameters
- learn_logvar_bounds: false # so far this works better
- ensemble_num: ${task.ensemble_num}
- elite_num: ${task.elite_num}
- # network parameters
- num_layers: 4
- hid_size: 200
- activation_fn_cfg:
- _target_: torch.nn.SiLU
- # others
- device: ${device}
-
-learned_termination: ${task.learning_terminal}
-termination_mech:
- _target_: cmrl.models.BaseTerminationMech
- # transition info
- obs_size: ???
- action_size: ???
- deterministic: false
-
-optim_lr: ${task.optim_lr}
-weight_decay: ${task.weight_decay}
-patience: ${task.patience}
-batch_size: ${task.batch_size}
-use_ratio: ${task.use_ratio}
-validation_ratio: ${task.validation_ratio}
-shuffle_each_epoch: ${task.shuffle_each_epoch}
-bootstrap_permutes: ${task.bootstrap_permutes}
-longest_epoch: ${task.longest_epoch}
-improvement_threshold: ${task.improvement_threshold}
-
-normalize: true
-normalize_double_precision: true
diff --git a/cmrl/examples/conf/dynamics/plain_dynamics.yaml b/cmrl/examples/conf/dynamics/plain_dynamics.yaml
deleted file mode 100644
index 74914c2..0000000
--- a/cmrl/examples/conf/dynamics/plain_dynamics.yaml
+++ /dev/null
@@ -1,63 +0,0 @@
-name: plain_dynamics
-
-multi_step: ${task.multi_step}
-
-transition:
- _target_: cmrl.models.transition.PlainEnsembleGaussianTransition
- # transition info
- obs_size: ???
- action_size: ???
- deterministic: false
- # algorithm parameters
- ensemble_num: ${task.ensemble_num}
- elite_num: ${task.elite_num}
- residual: true
- learn_logvar_bounds: false # so far this works better
- # network parameters
- num_layers: 4
- hid_size: 200
- activation_fn_cfg:
- _target_: torch.nn.SiLU
- # others
- device: ${device}
-
-learned_reward: ${task.learning_reward}
-reward_mech:
- _target_: cmrl.models.BaseRewardMech
- # transition info
- obs_size: ???
- action_size: ???
- deterministic: fase
- # algorithm parameters
- learn_logvar_bounds: false # so far this works better
- ensemble_num: ${task.ensemble_num}
- elite_num: ${task.elite_num}
- # network parameters
- num_layers: 4
- hid_size: 200
- activation_fn_cfg:
- _target_: torch.nn.SiLU
- # others
- device: ${device}
-
-learned_termination: ${task.learning_terminal}
-termination_mech:
- _target_: cmrl.models.BaseTerminationMech
- # transition info
- obs_size: ???
- action_size: ???
- deterministic: false
-
-optim_lr: ${task.optim_lr}
-weight_decay: ${task.weight_decay}
-patience: ${task.patience}
-batch_size: ${task.batch_size}
-use_ratio: ${task.use_ratio}
-validation_ratio: ${task.validation_ratio}
-shuffle_each_epoch: ${task.shuffle_each_epoch}
-bootstrap_permutes: ${task.bootstrap_permutes}
-longest_epoch: ${task.longest_epoch}
-improvement_threshold: ${task.improvement_threshold}
-
-normalize: true
-normalize_double_precision: true
diff --git a/cmrl/examples/conf/main.yaml b/cmrl/examples/conf/main.yaml
index 65eda8e..39abbe5 100644
--- a/cmrl/examples/conf/main.yaml
+++ b/cmrl/examples/conf/main.yaml
@@ -1,20 +1,23 @@
defaults:
- algorithm: off_dyna
- - dynamics: constraint_based_dynamics
- - task: BIPS
+ - task: continuous_cart_pole_swingup
+ - transition: oracle
+ - reward_mech: oracle
+ - termination_mech: oracle
- _self_
seed: 0
-device: "cuda:0"
+device: "cpu"
exp_name: default
wandb: false
+verbose: false
root_dir: "./exp"
hydra:
run:
- dir: ${root_dir}/${exp_name}/${task.env}/${dynamics.name}/${now:%Y.%m.%d}/${now:%H.%M.%S}
+ dir: ${root_dir}/${exp_name}/${task.env_id}/${to_str:${task.params}}/${task.dataset}/${now:%Y.%m.%d.%H%M%S}
sweep:
- dir: ${root_dir}/${exp_name}/${task.env}/${dynamics.name}/${now:%Y.%m.%d}/${now:%H.%M.%S}
+ dir: ${root_dir}/${exp_name}/${task.env_id}/${to_str:${task.params}}/${task.dataset}/${now:%Y.%m.%d.%H%M%S}
job:
chdir: true
diff --git a/cmrl/examples/conf/reward_mech/oracle.yaml b/cmrl/examples/conf/reward_mech/oracle.yaml
new file mode 100644
index 0000000..ffe3f12
--- /dev/null
+++ b/cmrl/examples/conf/reward_mech/oracle.yaml
@@ -0,0 +1,62 @@
+name: "oracle_reward_mech"
+learn: false
+discovery: false
+
+encoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableEncoder
+ output_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+decoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableDecoder
+ input_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+network_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.ParallelMLP
+ hidden_dims: [ 200, 200 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+optimizer_cfg:
+ _partial_: true
+ _target_: torch.optim.Adam
+ lr: 1e-4
+ weight_decay: 1e-5
+ eps: 1e-8
+
+mech:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.causal_mech.OracleMech
+ # base causal-mech params
+ name: reward_mech
+ input_variables: ???
+ output_variables: ???
+ ensemble_num: 7
+ elite_num: 5
+ # cfgs
+  network_cfg: ${reward_mech.network_cfg}
+  encoder_cfg: ${reward_mech.encoder_cfg}
+  decoder_cfg: ${reward_mech.decoder_cfg}
+  optimizer_cfg: ${reward_mech.optimizer_cfg}
+ # forward method
+ residual: true
+ multi_step: "none"
+ # logger
+ logger: ???
+ # others
+ device: ${device}
diff --git a/cmrl/examples/conf/task/BI2PB.yaml b/cmrl/examples/conf/task/BI2PB.yaml
index 8d13511..8684ca1 100644
--- a/cmrl/examples/conf/task/BI2PB.yaml
+++ b/cmrl/examples/conf/task/BI2PB.yaml
@@ -1,7 +1,12 @@
-env: "emei___BoundaryInvertedDoublePendulumBalancing-v0___freq_rate=${task.freq_rate}&time_step=${task.time_step}___${task.dataset}"
+# env parameters
+env_id: "BoundaryInvertedDoublePendulumBalancing-v0"
+
+params:
+ freq_rate: 1
+ real_time_scale: 0.02
+ integrator: "euler"
+
dataset: "SAC-expert-replay"
-freq_rate: 1
-time_step: 0.02
# basic RL params
num_steps: 3000000
diff --git a/cmrl/examples/conf/task/BI2PS.yaml b/cmrl/examples/conf/task/BI2PS.yaml
index 06918ab..c5feda6 100644
--- a/cmrl/examples/conf/task/BI2PS.yaml
+++ b/cmrl/examples/conf/task/BI2PS.yaml
@@ -1,5 +1,12 @@
-env: "emei___BoundaryInvertedDoublePendulumSwingUp-v0___freq_rate=1&time_step=0.02___${task.dataset}"
-dataset: "expert-replay"
+# env parameters
+env_id: "BoundaryInvertedDoublePendulumSwingUp-v0"
+
+params:
+ freq_rate: 1
+ real_time_scale: 0.02
+ integrator: "euler"
+
+dataset: "SAC-expert-replay"
# basic RL params
num_steps: 1000000
diff --git a/cmrl/examples/conf/task/BIPB.yaml b/cmrl/examples/conf/task/BIPB.yaml
index a0536e4..321c5d2 100644
--- a/cmrl/examples/conf/task/BIPB.yaml
+++ b/cmrl/examples/conf/task/BIPB.yaml
@@ -1,7 +1,12 @@
-env: "emei___BoundaryInvertedPendulumBalancing-v0___freq_rate=${task.freq_rate}&time_step=${task.time_step}___${task.dataset}"
+# env parameters
+env_id: "BoundaryInvertedPendulumBalancing-v0"
+
+params:
+ freq_rate: 1
+ real_time_scale: 0.02
+ integrator: "euler"
+
dataset: "SAC-expert-replay"
-freq_rate: 1
-time_step: 0.02
# basic RL params
num_steps: 2000000
diff --git a/cmrl/examples/conf/task/BIPS.yaml b/cmrl/examples/conf/task/BIPS.yaml
index c9c0ea2..13de2af 100644
--- a/cmrl/examples/conf/task/BIPS.yaml
+++ b/cmrl/examples/conf/task/BIPS.yaml
@@ -1,10 +1,15 @@
-env: "emei___BoundaryInvertedPendulumSwingUp-v0___freq_rate=${task.freq_rate}&time_step=${task.time_step}___${task.dataset}"
+# env parameters
+env_id: "BoundaryInvertedPendulumSwingUp-v0"
+
+params:
+ freq_rate: 1
+ real_time_scale: 0.02
+ integrator: "euler"
+
dataset: "SAC-expert-replay"
-freq_rate: 1
-time_step: 0.02
# basic RL params
-num_steps: 300000
+num_steps: 10000000
online_num_steps: 10000
epoch_length: 10000
n_eval_episodes: 8
@@ -15,7 +20,7 @@ learning_reward: false
learning_terminal: false
ensemble_num: 7
elite_num: 5
-multi_step: "none"
+multi_step: "forward_euler_5"
# conditional mutual information test(causal discovery)
oracle: true
@@ -26,36 +31,13 @@ update_causal_mask_ratio: 0.25
discovery_schedule: [ 1, 30, 250, 250 ]
# offline
-penalty_coeff: 0.5
+penalty_coeff: 0.2
use_ratio: 1
# dyna
freq_train_model: 100
+
# model learning
patience: 20
-optim_lr: 0.0001
-weight_decay: 0.00001
-batch_size: 256
-validation_ratio: 0.2
-shuffle_each_epoch: true
-bootstrap_permutes: false
longest_epoch: -1
improvement_threshold: 0.01
-# model using
-effective_model_rollouts_per_step: 50
-rollout_schedule: [ 1, 15, 1, 1 ]
-num_sac_updates_per_step: 1
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-# SAC
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_hidden_size: 256
-sac_lr: 0.0003
-sac_batch_size: 256
-sac_target_entropy: -1
diff --git a/cmrl/examples/conf/task/continuous_cart_pole_swingup.yaml b/cmrl/examples/conf/task/continuous_cart_pole_swingup.yaml
new file mode 100644
index 0000000..bdf1483
--- /dev/null
+++ b/cmrl/examples/conf/task/continuous_cart_pole_swingup.yaml
@@ -0,0 +1,29 @@
+# env parameters
+env_id: "ContinuousCartPoleSwingUp-v0"
+
+params:
+ freq_rate: 1
+ real_time_scale: 0.02
+ integrator: "euler"
+ gravity: 9.8
+ length: 0.5
+ force_mag: 10.0
+
+dataset: "SAC-expert-replay"
+
+extra_variable_info:
+ Radian:
+ - "obs_1"
+
+# basic RL params
+num_steps: 3000000
+online_num_steps: 10000
+n_eval_episodes: 5
+eval_freq: 10000
+
+# offline
+penalty_coeff: 1
+use_ratio: 1
+
+# dyna
+freq_train_model: 100
diff --git a/cmrl/examples/conf/task/hopper.yaml b/cmrl/examples/conf/task/hopper.yaml
new file mode 100644
index 0000000..9a851ce
--- /dev/null
+++ b/cmrl/examples/conf/task/hopper.yaml
@@ -0,0 +1,43 @@
+# env parameters
+env_id: "HopperRunning-v0"
+
+params:
+ freq_rate: 1
+ real_time_scale: 0.01
+ integrator: "euler"
+
+dataset: "SAC-medium"
+
+# basic RL params
+num_steps: 10000000
+online_num_steps: 10000
+epoch_length: 10000
+n_eval_episodes: 8
+eval_freq: 100
+
+# dynamics
+learning_reward: false
+learning_terminal: false
+ensemble_num: 7
+elite_num: 5
+multi_step: "none"
+
+# conditional mutual information test (causal discovery)
+oracle: true
+cit_threshold: 0.02
+test_freq: 100
+# causal
+update_causal_mask_ratio: 0.25
+discovery_schedule: [ 1, 30, 250, 250 ]
+
+# offline
+penalty_coeff: 1.0
+use_ratio: 1
+
+# dyna
+freq_train_model: 100
+
+# model learning
+patience: 10
+longest_epoch: -1
+improvement_threshold: 0.01
\ No newline at end of file
diff --git a/cmrl/examples/conf/task/mbpo_ant.yaml b/cmrl/examples/conf/task/mbpo_ant.yaml
deleted file mode 100644
index 813cd89..0000000
--- a/cmrl/examples/conf/task/mbpo_ant.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-env: "ant_truncated_obs"
-# term_fn is set automatically by cmrl.util.env.EnvHandler.make_env
-
-num_steps: 300000
-epoch_length: 1000
-num_elites: 5
-patience: 10
-model_lr: 0.0003
-model_wd: 5e-5
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 100, 1, 25]
-num_sac_updates_per_step: 20
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: -1 # ignored, since entropy tuning is false
-sac_hidden_size: 1024
-sac_lr: 0.0001
-sac_batch_size: 256
diff --git a/cmrl/examples/conf/task/mbpo_boundary_inverted_double_pendulum_swing_up.yaml b/cmrl/examples/conf/task/mbpo_boundary_inverted_double_pendulum_swing_up.yaml
deleted file mode 100644
index e31a548..0000000
--- a/cmrl/examples/conf/task/mbpo_boundary_inverted_double_pendulum_swing_up.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-env: "emei___BoundaryInvertedDoublePendulumSwingUp-v0___freq_rate=1&time_step=0.02"
-
-oracle: true
-cit_threshold: 0.02
-test_freq: 500
-
-num_steps: 800000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-update_causal_mask_ratio: 0.25
-discovery_schedule: [1, 30, 250, 250]
-effective_model_rollouts_per_step: 400
-rollout_schedule: [1, 15, 100, 100]
-num_sac_updates_per_step: 10
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_hidden_size: 256
-sac_lr: 0.0003
-sac_batch_size: 256
-sac_target_entropy: -1
diff --git a/cmrl/examples/conf/task/mbpo_boundary_inverted_pendulum_holding.yaml b/cmrl/examples/conf/task/mbpo_boundary_inverted_pendulum_holding.yaml
deleted file mode 100644
index a8f7da6..0000000
--- a/cmrl/examples/conf/task/mbpo_boundary_inverted_pendulum_holding.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-env: "emei___BoundaryInvertedPendulumHolding-v0___freq_rate=1&time_step=0.02"
-
-oracle: true
-cit_threshold: 0.02
-test_freq: 1000
-
-num_steps: 20000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-update_causal_mask_ratio: 0.25
-discovery_schedule: [1, 30, 250, 250]
-effective_model_rollouts_per_step: 400
-rollout_schedule: [1, 15, 100, 100]
-num_sac_updates_per_step: 10
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_hidden_size: 256
-sac_lr: 0.0003
-sac_batch_size: 256
-sac_target_entropy: -1
diff --git a/cmrl/examples/conf/task/mbpo_boundary_inverted_pendulum_swing_up.yaml b/cmrl/examples/conf/task/mbpo_boundary_inverted_pendulum_swing_up.yaml
deleted file mode 100644
index 6dbbc1a..0000000
--- a/cmrl/examples/conf/task/mbpo_boundary_inverted_pendulum_swing_up.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-env: "emei___BoundaryInvertedPendulumSwingUp-v0___freq_rate=1&time_step=0.02"
-
-oracle: true
-cit_threshold: 0.02
-test_freq: 500
-
-num_steps: 8000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-update_causal_mask_ratio: 0.25
-discovery_schedule: [1, 30, 250, 250]
-effective_model_rollouts_per_step: 400
-rollout_schedule: [1, 15, 1, 1]
-num_sac_updates_per_step: 10
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_hidden_size: 256
-sac_lr: 0.0003
-sac_batch_size: 256
-sac_target_entropy: -1
diff --git a/cmrl/examples/conf/task/mbpo_cartpole.yaml b/cmrl/examples/conf/task/mbpo_cartpole.yaml
deleted file mode 100644
index a6016cf..0000000
--- a/cmrl/examples/conf/task/mbpo_cartpole.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-env: "cartpole_continuous"
-trial_length: 200
-
-num_steps: 5000
-epoch_length: 200
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00005
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 200
-effective_model_rollouts_per_step: 400
-rollout_schedule: [1, 15, 1, 1]
-num_sac_updates_per_step: 20
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: true
-sac_target_entropy: -0.05
-sac_hidden_size: 256
-sac_lr: 0.0003
-sac_batch_size: 256
diff --git a/cmrl/examples/conf/task/mbpo_halfcheetah.yaml b/cmrl/examples/conf/task/mbpo_halfcheetah.yaml
deleted file mode 100644
index 3b5e3f9..0000000
--- a/cmrl/examples/conf/task/mbpo_halfcheetah.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-env: "gym___HalfCheetah-v2"
-term_fn: "no_termination"
-
-num_steps: 400000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 1]
-num_sac_updates_per_step: 10
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_target_entropy: -1
-sac_hidden_size: 512
-sac_lr: 0.0003
-sac_batch_size: 256
diff --git a/cmrl/examples/conf/task/mbpo_hopper.yaml b/cmrl/examples/conf/task/mbpo_hopper.yaml
deleted file mode 100644
index 5ee267c..0000000
--- a/cmrl/examples/conf/task/mbpo_hopper.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-env: "gym___Hopper-v2"
-term_fn: "hopper"
-
-num_steps: 500000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 200, 200]
-num_sac_updates_per_step: 100
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: 1 # ignored, since entropy tuning is false
-sac_hidden_size: 512
-sac_lr: 0.0003
-sac_batch_size: 256
diff --git a/cmrl/examples/conf/task/mbpo_humanoid.yaml b/cmrl/examples/conf/task/mbpo_humanoid.yaml
deleted file mode 100644
index 4cf0d8a..0000000
--- a/cmrl/examples/conf/task/mbpo_humanoid.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-env: "humanoid_truncated_obs"
-# term_fn is set automatically by cmrl.util.env.EnvHandler.make_env
-
-num_steps: 300000
-epoch_length: 1000
-num_elites: 5
-patience: 10
-model_lr: 0.0003
-model_wd: 5e-5
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 300, 1, 25]
-num_sac_updates_per_step: 20
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 5
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: -1 # ignored, since entropy tuning is false
-sac_hidden_size: 1024
-sac_lr: 0.0001
-sac_batch_size: 256
diff --git a/cmrl/examples/conf/task/mbpo_inv_pendulum.yaml b/cmrl/examples/conf/task/mbpo_inv_pendulum.yaml
deleted file mode 100644
index 912d4f5..0000000
--- a/cmrl/examples/conf/task/mbpo_inv_pendulum.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-env: "inv_pendulum___0.25___1"
-
-test_freq: 500
-
-num_steps: 500000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-effective_model_rollouts_per_step: 400
-rollout_schedule: [1, 15, 10, 10]
-num_sac_updates_per_step: 10
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_hidden_size: 256
-sac_lr: 0.0003
-sac_batch_size: 256
-sac_target_entropy: -1
diff --git a/cmrl/examples/conf/task/mbpo_pusher.yaml b/cmrl/examples/conf/task/mbpo_pusher.yaml
deleted file mode 100644
index c90c0b1..0000000
--- a/cmrl/examples/conf/task/mbpo_pusher.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-env: "pets_pusher"
-term_fn: "no_termination"
-trial_length: 150
-
-num_steps: 20000
-epoch_length: 150
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00005
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-effective_model_rollouts_per_step: 400
-rollout_schedule: [1, 15, 1, 1]
-num_sac_updates_per_step: 20
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: true
-sac_target_entropy: -0.05
-sac_hidden_size: 256
-sac_lr: 0.0003
-sac_batch_size: 256
diff --git a/cmrl/examples/conf/task/mbpo_walker.yaml b/cmrl/examples/conf/task/mbpo_walker.yaml
deleted file mode 100644
index 5ef2095..0000000
--- a/cmrl/examples/conf/task/mbpo_walker.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-env: "gym___Walker2d-v2"
-term_fn: "walker2d"
-
-num_steps: 300000
-epoch_length: 1000
-num_elites: 5
-patience: 10
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 250
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 1]
-num_sac_updates_per_step: 20
-sac_updates_every_steps: 1
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: -1 # ignored, since entropy tuning is false
-sac_hidden_size: 1024
-sac_lr: 0.0001
-sac_batch_size: 256
diff --git a/cmrl/examples/conf/task/parallel_cart_pole.yaml b/cmrl/examples/conf/task/parallel_cart_pole.yaml
new file mode 100644
index 0000000..35145cd
--- /dev/null
+++ b/cmrl/examples/conf/task/parallel_cart_pole.yaml
@@ -0,0 +1,29 @@
+# env parameters
+env_id: "ParallelContinuousCartPoleSwingUp-v0"
+
+params:
+ freq_rate: 1
+ real_time_scale: 0.02
+ integrator: "euler"
+ parallel_num: 3
+
+dataset: "SAC-expert-replay"
+
+extra_variable_info:
+ Radian:
+ - "obs_1"
+ - "obs_5"
+ - "obs_9"
+
+# basic RL params
+num_steps: 10000000
+online_num_steps: 10000
+n_eval_episodes: 5
+eval_freq: 10000
+
+# offline
+penalty_coeff: 1
+use_ratio: 1
+
+# dyna
+freq_train_model: 100
diff --git a/cmrl/examples/conf/termination_mech/oracle.yaml b/cmrl/examples/conf/termination_mech/oracle.yaml
new file mode 100644
index 0000000..c6077ce
--- /dev/null
+++ b/cmrl/examples/conf/termination_mech/oracle.yaml
@@ -0,0 +1,62 @@
+name: "oracle_termination_mech"
+learn: false
+discovery: false
+
+encoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableEncoder
+ output_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+decoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableDecoder
+ input_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+network_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.ParallelMLP
+ hidden_dims: [ 200, 200 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+optimizer_cfg:
+ _partial_: true
+ _target_: torch.optim.Adam
+ lr: 1e-4
+ weight_decay: 1e-5
+ eps: 1e-8
+
+mech:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.causal_mech.OracleMech
+ # base causal-mech params
+ name: termination_mech
+ input_variables: ???
+ output_variables: ???
+ ensemble_num: 7
+ elite_num: 5
+ # cfgs
+ network_cfg: ${transition.network_cfg}
+ encoder_cfg: ${transition.encoder_cfg}
+ decoder_cfg: ${transition.decoder_cfg}
+ optimizer_cfg: ${transition.optimizer_cfg}
+ # forward method
+ residual: true
+ multi_step: "none"
+ # logger
+ logger: ???
+ # others
+ device: ${device}
diff --git a/cmrl/examples/conf/transition/CMI_test.yaml b/cmrl/examples/conf/transition/CMI_test.yaml
new file mode 100644
index 0000000..4461b49
--- /dev/null
+++ b/cmrl/examples/conf/transition/CMI_test.yaml
@@ -0,0 +1,74 @@
+name: "CMI_test_transition"
+learn: true
+
+encoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableEncoder
+ output_dim: 200
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+decoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableDecoder
+ input_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+network_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.ParallelMLP
+ hidden_dims: [ 100, 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+optimizer_cfg:
+ _partial_: true
+ _target_: torch.optim.Adam
+ lr: 1e-4
+ weight_decay: 1e-5
+ eps: 1e-8
+
+scheduler_cfg:
+ _partial_: true
+ _target_: torch.optim.lr_scheduler.StepLR
+ step_size: 1
+ gamma: 1
+
+mech:
+ _partial_: true
+ _recursive_: false
+  _target_: cmrl.models.causal_mech.CMITestMech
+ # base causal-mech params
+ name: transition
+ input_variables: ???
+ output_variables: ???
+ # model learning
+ patience: 5
+ longest_epoch: -1
+ improvement_threshold: 0.01
+ batch_size: 256
+ # ensemble
+ ensemble_num: 7
+ elite_num: 5
+ # cfgs
+ network_cfg: ${transition.network_cfg}
+ encoder_cfg: ${transition.encoder_cfg}
+ decoder_cfg: ${transition.decoder_cfg}
+ optimizer_cfg: ${transition.optimizer_cfg}
+ scheduler_cfg: ${transition.scheduler_cfg}
+ # forward method
+ residual: true
+ encoder_reduction: "sum"
+ # logger
+ logger: ???
+ # others
+ device: ${device}
diff --git a/cmrl/examples/conf/transition/kernel_test.yaml b/cmrl/examples/conf/transition/kernel_test.yaml
new file mode 100644
index 0000000..f7d750b
--- /dev/null
+++ b/cmrl/examples/conf/transition/kernel_test.yaml
@@ -0,0 +1,77 @@
+name: "kernal_test_transition"
+learn: true
+
+encoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableEncoder
+ output_dim: 200
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+decoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableDecoder
+ input_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+network_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.ParallelMLP
+ hidden_dims: [ 100, 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+optimizer_cfg:
+ _partial_: true
+ _target_: torch.optim.Adam
+ lr: 1e-4
+ weight_decay: 1e-5
+ eps: 1e-8
+
+scheduler_cfg:
+ _partial_: true
+ _target_: torch.optim.lr_scheduler.StepLR
+ step_size: 1
+ gamma: 1
+
+mech:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.causal_mech.KernelTestMech
+ # base causal-mech params
+ name: transition
+ input_variables: ???
+ output_variables: ???
+ logger: ???
+ # model learning
+ patience: 5
+ longest_epoch: -1
+ improvement_threshold: 0.01
+ batch_size: 256
+ # ensemble
+ ensemble_num: 7
+ elite_num: 5
+ # cfgs
+ network_cfg: ${transition.network_cfg}
+ encoder_cfg: ${transition.encoder_cfg}
+ decoder_cfg: ${transition.decoder_cfg}
+ optimizer_cfg: ${transition.optimizer_cfg}
+ scheduler_cfg: ${transition.scheduler_cfg}
+ # forward method
+ residual: true
+ encoder_reduction: "sum"
+ # others
+ device: ${device}
+ # KCI
+ sample_num: 256
+ kci_times: 16
+ not_confident_bound: 0.2
diff --git a/cmrl/examples/conf/transition/oracle.yaml b/cmrl/examples/conf/transition/oracle.yaml
new file mode 100644
index 0000000..45fae3b
--- /dev/null
+++ b/cmrl/examples/conf/transition/oracle.yaml
@@ -0,0 +1,74 @@
+name: "oracle_transition"
+learn: true
+oracle: "truth"
+
+encoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableEncoder
+ output_dim: 200
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+decoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableDecoder
+ input_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+network_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.ParallelMLP
+ hidden_dims: [ 100, 100]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+optimizer_cfg:
+ _partial_: true
+ _target_: torch.optim.Adam
+ lr: 1e-4
+ weight_decay: 1e-5
+ eps: 1e-8
+
+scheduler_cfg:
+ _partial_: true
+ _target_: torch.optim.lr_scheduler.StepLR
+ step_size: 1
+ gamma: 0.8
+
+mech:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.causal_mech.OracleMech
+ # base causal-mech params
+ name: transition
+ input_variables: ???
+ output_variables: ???
+ # model learning
+ patience: 5
+ longest_epoch: -1
+ improvement_threshold: 0.01
+ batch_size: 1024
+ # ensemble
+ ensemble_num: 7
+ elite_num: 5
+ # cfgs
+ network_cfg: ${transition.network_cfg}
+ encoder_cfg: ${transition.encoder_cfg}
+ decoder_cfg: ${transition.decoder_cfg}
+ optimizer_cfg: ${transition.optimizer_cfg}
+ scheduler_cfg: ${transition.scheduler_cfg}
+ # forward method
+ residual: true
+ # logger
+ logger: ???
+ # others
+ device: ${device}
diff --git a/cmrl/examples/conf/transition/reinforce.yaml b/cmrl/examples/conf/transition/reinforce.yaml
new file mode 100644
index 0000000..44f9102
--- /dev/null
+++ b/cmrl/examples/conf/transition/reinforce.yaml
@@ -0,0 +1,81 @@
+name: "reinforce_transition"
+learn: true
+discovery: true
+
+encoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableEncoder
+ output_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+decoder_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.VariableDecoder
+ input_dim: 100
+ hidden_dims: [ 100 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+network_cfg:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.networks.ParallelMLP
+ hidden_dims: [ 200, 200 ]
+ bias: true
+ activation_fn_cfg:
+ _target_: torch.nn.SiLU
+
+optimizer_cfg:
+ _partial_: true
+ _target_: torch.optim.Adam
+ lr: 1e-4
+ weight_decay: 1e-5
+ eps: 1e-8
+
+graph_optimizer_cfg:
+ _partial_: true
+ _target_: torch.optim.Adam
+ lr: 1e-3
+ weight_decay: 0.0
+ eps: 1e-8
+
+mech:
+ _partial_: true
+ _recursive_: false
+ _target_: cmrl.models.causal_mech.ReinforceCausalMech
+ # base causal-mech params
+ name: transition
+ input_variables: ???
+ output_variables: ???
+ # model learning
+ patience: 5
+ longest_epoch: -1
+ improvement_threshold: 0.01
+ # ensemble
+ ensemble_num: 7
+ elite_num: 5
+ # cfgs
+ network_cfg: ${transition.network_cfg}
+ encoder_cfg: ${transition.encoder_cfg}
+ decoder_cfg: ${transition.decoder_cfg}
+ optimizer_cfg: ${transition.optimizer_cfg}
+ graph_optimizer_cfg: ${transition.graph_optimizer_cfg}
+ # graph params
+ concat_mask: true
+ graph_MC_samples: 20
+ graph_max_stack: 20
+ lambda_sparse: 5e-2
+ # forward method
+ residual: true
+ encoder_reduction: "sum"
+ multi_step: "forward-euler 1"
+ # logger
+ logger: ???
+ # others
+ device: ${device}
diff --git a/cmrl/examples/main.py b/cmrl/examples/main.py
index 59282b0..e64eb57 100644
--- a/cmrl/examples/main.py
+++ b/cmrl/examples/main.py
@@ -1,43 +1,15 @@
import hydra
-import numpy as np
-import torch
-import wandb
+from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
-
-from cmrl.algorithms import mopo, mbpo, off_dyna, on_dyna
-from cmrl.util.env import make_env
+from emei.core import get_params_str
@hydra.main(version_base=None, config_path="conf", config_name="main")
def run(cfg: DictConfig):
- if cfg.wandb:
- wandb.init(
- project="causal-mbrl",
- group=cfg.exp_name,
- config=OmegaConf.to_container(cfg, resolve=True),
- sync_tensorboard=True,
- )
-
- env, term_fn, reward_fn, init_obs_fn = make_env(cfg)
- test_env, *_ = make_env(cfg)
- np.random.seed(cfg.seed)
- torch.manual_seed(cfg.seed)
-
- if cfg.algorithm.name == "on_dyna":
- test_env, *_ = make_env(cfg)
- return on_dyna.train(env, test_env, term_fn, reward_fn, init_obs_fn, cfg)
- elif cfg.algorithm.name == "mopo":
- test_env, *_ = make_env(cfg)
- return mopo.train(env, test_env, term_fn, reward_fn, init_obs_fn, cfg)
- elif cfg.algorithm.name == "off_dyna":
- test_env, *_ = make_env(cfg)
- return off_dyna.train(env, test_env, term_fn, reward_fn, init_obs_fn, cfg)
- elif cfg.algorithm.name == "mbpo":
- test_env, *_ = make_env(cfg)
- return mbpo.train(env, test_env, term_fn, reward_fn, init_obs_fn, cfg)
- else:
- raise NotImplementedError
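+    # the algorithm config is expected to be partial; instantiate it, then bind the full cfg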
+ algo = instantiate(cfg.algorithm.algo)(cfg=cfg)
+ algo.learn()
if __name__ == "__main__":
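+    # register the resolver used by hydra.run.dir to render ${task.params} as a string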
+ OmegaConf.register_new_resolver("to_str", get_params_str)
run()
diff --git a/cmrl/models/causal_discovery/CMI_test.py b/cmrl/models/causal_discovery/CMI_test.py
deleted file mode 100644
index 34a6668..0000000
--- a/cmrl/models/causal_discovery/CMI_test.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from typing import Dict, Optional, Sequence, Tuple, Union
-
-import hydra
-import omegaconf
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-from cmrl.models.util import gaussian_nll
-from cmrl.models.layers import ParallelEnsembleLinearLayer, truncated_normal_init
-from cmrl.models.networks.mlp import EnsembleMLP
-from cmrl.models.util import to_tensor
-
-
-class TransitionConditionalMutualInformationTest(EnsembleMLP):
- _MODEL_FILENAME = "conditional_mutual_information_test.pth"
-
- def __init__(
- self,
- # transition info
- obs_size: int,
- action_size: int,
- # algorithm parameters
- ensemble_num: int = 7,
- elite_num: int = 5,
- residual: bool = True,
- learn_logvar_bounds: bool = False,
- # network parameters
- num_layers: int = 4,
- hid_size: int = 200,
- activation_fn_cfg: Optional[Union[Dict, omegaconf.DictConfig]] = None,
- # others
- device: Union[str, torch.device] = "cpu",
- ):
- super().__init__(
- ensemble_num=ensemble_num,
- elite_num=elite_num,
- device=device,
- )
- self.obs_size = obs_size
- self.action_size = action_size
- self.residual = residual
- self.learn_logvar_bounds = learn_logvar_bounds
-
- self.num_layers = num_layers
- self.hid_size = hid_size
-
- self.parallel_num = self.obs_size + self.action_size + 1
-
- self._input_mask = 1 - torch.eye(self.parallel_num, self.obs_size + self.action_size).to(self.device)
-
- def create_activation():
- if activation_fn_cfg is None:
- return nn.ReLU()
- else:
- return hydra.utils.instantiate(activation_fn_cfg)
-
- hidden_layers = [
- nn.Sequential(
- self.create_linear_layer(obs_size + action_size, hid_size),
- create_activation(),
- )
- ]
- for i in range(num_layers - 1):
- hidden_layers.append(
- nn.Sequential(
- self.create_linear_layer(hid_size, hid_size),
- create_activation(),
- )
- )
- self.hidden_layers = nn.Sequential(*hidden_layers)
-
- self.mean_and_logvar = self.create_linear_layer(hid_size, 2 * self.obs_size)
- self.min_logvar = nn.Parameter(
- -10 * torch.ones(self.parallel_num, 1, 1, self.obs_size), requires_grad=learn_logvar_bounds
- )
- self.max_logvar = nn.Parameter(
- 0.5 * torch.ones(self.parallel_num, 1, 1, self.obs_size), requires_grad=learn_logvar_bounds
- )
-
- self.apply(truncated_normal_init)
- self.to(self.device)
-
- def create_linear_layer(self, l_in, l_out):
- return ParallelEnsembleLinearLayer(l_in, l_out, parallel_num=self.parallel_num, ensemble_num=self.ensemble_num)
-
- @property
- def input_mask(self):
- return self._input_mask
-
- def mask_input(self, x: torch.Tensor) -> torch.Tensor:
- assert x.ndim == 4
- assert self._input_mask.ndim == 2
- input_mask = self._input_mask[:, None, None, :]
- return x * input_mask
-
- def forward(
- self,
- batch_obs: torch.Tensor, # shape: (parallel_num, )ensemble_num, batch_size, obs_size
- batch_action: torch.Tensor, # shape: ensemble_num, batch_size, action_size
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- assert len(batch_action.shape) == 3 and batch_action.shape[-1] == self.action_size
-
- batch_action = batch_action.repeat((self.parallel_num, 1, 1, 1))
- if len(batch_obs.shape) == 3: # non-repeat or first repeat
- batch_obs = batch_obs.repeat((self.parallel_num, 1, 1, 1))
-
- batch_input = torch.concat([batch_obs, batch_action], dim=-1)
-
- masked_input = self.mask_input(batch_input)
- hidden = self.hidden_layers(masked_input)
- mean_and_logvar = self.mean_and_logvar(hidden)
-
- mean = mean_and_logvar[..., : self.obs_size]
- logvar = mean_and_logvar[..., self.obs_size :]
- logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
- logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
-
- if self.residual:
- mean += batch_obs
-
- return mean, logvar
-
- def get_nll_loss(self, model_in: Dict[(str, torch.Tensor)], target: torch.Tensor) -> torch.Tensor:
- pred_mean, pred_logvar = self.forward(**model_in)
- target = target.repeat((self.parallel_num, 1, 1, 1))
-
- nll_loss = gaussian_nll(pred_mean, pred_logvar, target, reduce=False)
- nll_loss += 0.01 * (self.max_logvar.sum() - self.min_logvar.sum())
- return nll_loss
diff --git a/cmrl/models/causal_mech/CMI_test.py b/cmrl/models/causal_mech/CMI_test.py
new file mode 100644
index 0000000..4cadde4
--- /dev/null
+++ b/cmrl/models/causal_mech/CMI_test.py
@@ -0,0 +1,274 @@
+from typing import Optional, List, Dict, Union, MutableMapping
+import pathlib
+from functools import partial
+from itertools import count
+
+import torch
+import numpy as np
+from torch.utils.data import DataLoader
+from omegaconf import DictConfig
+from hydra.utils import instantiate
+from stable_baselines3.common.logger import Logger
+
+from cmrl.utils.variables import Variable
+from cmrl.models.causal_mech.base import EnsembleNeuralMech
+from cmrl.models.graphs.binary_graph import BinaryGraph
+from cmrl.models.causal_mech.util import variable_loss_func, train_func, eval_func
+
+
+class CMITestMech(EnsembleNeuralMech):
+ def __init__(
+ self,
+ name: str,
+ input_variables: List[Variable],
+ output_variables: List[Variable],
+ logger: Optional[Logger] = None,
+ # model learning
+ longest_epoch: int = -1,
+ improvement_threshold: float = 0.01,
+ patience: int = 5,
+ batch_size: int = 256,
+ # ensemble
+ ensemble_num: int = 7,
+ elite_num: int = 5,
+ # cfgs
+ network_cfg: Optional[DictConfig] = None,
+ encoder_cfg: Optional[DictConfig] = None,
+ decoder_cfg: Optional[DictConfig] = None,
+        optimizer_cfg: Optional[DictConfig] = None,
+        scheduler_cfg: Optional[DictConfig] = None,
+ # forward method
+ residual: bool = True,
+ encoder_reduction: str = "sum",
+ # others
+ device: Union[str, torch.device] = "cpu",
+ ):
+ EnsembleNeuralMech.__init__(
+ self,
+ name=name,
+ input_variables=input_variables,
+ output_variables=output_variables,
+ logger=logger,
+ longest_epoch=longest_epoch,
+ improvement_threshold=improvement_threshold,
+ patience=patience,
+ batch_size=batch_size,
+ ensemble_num=ensemble_num,
+ elite_num=elite_num,
+ network_cfg=network_cfg,
+ encoder_cfg=encoder_cfg,
+ decoder_cfg=decoder_cfg,
+            optimizer_cfg=optimizer_cfg,
+            scheduler_cfg=scheduler_cfg,
+ residual=residual,
+ encoder_reduction=encoder_reduction,
+ device=device,
+ )
+
+ self.total_CMI_epoch = 0
+
+ def build_network(self):
+ self.network = instantiate(self.network_cfg)(
+ input_dim=self.encoder_output_dim,
+ output_dim=self.decoder_input_dim,
+ extra_dims=[self.output_var_num, self.ensemble_num],
+ ).to(self.device)
+
+ def build_graph(self):
+ self.graph = BinaryGraph(self.input_var_num, self.output_var_num, device=self.device)
+
+ @property
+ def CMI_mask(self) -> torch.Tensor:
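+        # stack of input masks for the CMI test: slice i zeroes out input variable i
+        # (leave-one-out) and the last slice keeps all inputs (the full model); e.g. with
+        # 2 inputs and 1 output the slices are [[0, 1]], [[1, 0]] and [[1, 1]]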
+ mask = torch.zeros(self.input_var_num + 1, self.output_var_num, self.input_var_num, dtype=torch.long)
+ for i in range(self.input_var_num + 1):
+ m = torch.ones(self.output_var_num, self.input_var_num)
+ if i != self.input_var_num:
+ m[:, i] = 0
+ mask[i] = m
+ return mask.to(self.device)
+
+ def multi_graph_forward(self, inputs: MutableMapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """when first step, inputs should be dict of str and Tensor with (ensemble-num, batch-size, specific-dim) shape,
+ since twice step, the shape of Tensor becomes (input-var-num + 1, ensemble-num, batch-size, specific-dim)
+
+ Args:
+ inputs:
+
+ Returns:
+
+ """
+ batch_size, extra_dim = self.get_inputs_info(inputs)
+
+ inputs_tensor = torch.empty(*extra_dim, self.ensemble_num, batch_size, self.input_var_num, self.encoder_output_dim).to(
+ self.device
+ )
+ for i, var in enumerate(self.input_variables):
+ out = self.variable_encoders[var.name](inputs[var.name].to(self.device))
+ inputs_tensor[..., i, :] = out
+
+ # if len(extra_dim) == 0:
+ # # [..., output-var-num, input-var-num]
+ # mask = self.CMI_mask
+ # # [..., output-var-num, ensemble-num, batch-size, input-var-num]
+ # mask = mask.unsqueeze(-2).unsqueeze(-2)
+ # mask = mask.repeat((1,) * len(mask.shape[:-3]) + (self.ensemble_num, batch_size, 1))
+ # reduced_inputs_tensor = self.reduce_encoder_output(inputs_tensor, mask)
+ # assert (
+ # not torch.isinf(reduced_inputs_tensor).any() and not torch.isnan(reduced_inputs_tensor).any()
+ # ), "tensor must not be inf or nan"
+ # output_tensor = self.network(reduced_inputs_tensor)
+ # else:
+ # output_tensor = torch.empty(
+ # *extra_dim, self.output_var_num, self.ensemble_num, batch_size, self.decoder_input_dim
+ # ).to(self.device)
+ #
+ # CMI_mask = self.CMI_mask
+ # for i in range(self.input_var_num + 1):
+ # # [..., output-var-num, input-var-num]
+ # mask = CMI_mask[i]
+ # # [..., output-var-num, ensemble-num, batch-size, input-var-num]
+ # mask = mask.unsqueeze(-2).unsqueeze(-2)
+ # mask = mask.repeat((1,) * len(mask.shape[:-3]) + (self.ensemble_num, batch_size, 1))
+ # if i == len(inputs_tensor) - 1:
+ # reduced_inputs_tensor = self.reduce_encoder_output(inputs_tensor[i], mask)
+ # outs = self.network(reduced_inputs_tensor)
+ # output_tensor[i] = outs
+ # else:
+ # for j in range(self.output_var_num):
+ # ins = inputs_tensor[-1]
+ # ins[:, :, j] = inputs_tensor[i, :, :, j, :]
+ # reduced_inputs_tensor = self.reduce_encoder_output(inputs_tensor[i], mask)
+ # outs = self.network(reduced_inputs_tensor)
+ # output_tensor[i, j] = outs[j]
+
+ mask = self.CMI_mask
+ # [..., output-var-num, ensemble-num, batch-size, input-var-num]
+ mask = mask.unsqueeze(-2).unsqueeze(-2)
+ mask = mask.repeat((1,) * len(mask.shape[:-3]) + (self.ensemble_num, batch_size, 1))
+ reduced_inputs_tensor = self.reduce_encoder_output(inputs_tensor, mask)
+ assert (
+ not torch.isinf(reduced_inputs_tensor).any() and not torch.isnan(reduced_inputs_tensor).any()
+ ), "tensor must not be inf or nan"
+ output_tensor = self.network(reduced_inputs_tensor)
+
+ outputs = {}
+ for i, var in enumerate(self.output_variables):
+ hid = output_tensor[:, i]
+ outputs[var.name] = self.variable_decoders[var.name](hid)
+
+ if self.residual:
+ outputs = self.residual_outputs(inputs, outputs)
+ return outputs
+
+ def calculate_CMI(self, nll_loss: torch.Tensor, threshold=1):
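+        # nll_loss is stacked as [leave-one-out masks..., full model]; the CMI proxy for
+        # input i is the mean extra NLL incurred when input i is masked out, and the edge
+        # is kept when that difference exceeds the threshold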
+ nll_loss_diff = nll_loss[:-1] - nll_loss[-1]
+ graph_data = (nll_loss_diff.mean(dim=(1, 2)) > threshold).to(torch.long)
+ return graph_data, nll_loss_diff.mean(dim=(1, 2))
+
+ def learn(
+ self,
+ inputs: MutableMapping[str, np.ndarray],
+ outputs: MutableMapping[str, np.ndarray],
+ work_dir: Optional[pathlib.Path] = None,
+ **kwargs
+ ):
+ work_dir = pathlib.Path(".") if work_dir is None else work_dir
+
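+        # create/truncate the files recording the discovered mask and CMI values on each improvement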
+ open(work_dir / "history_mask.txt", "w")
+ open(work_dir / "history_cmi.txt", "w")
+ train_loader, valid_loader = self.get_data_loaders(inputs, outputs)
+
+ final_graph_data = None
+
+ epoch_iter = range(self.longest_epoch) if self.longest_epoch >= 0 else count()
+ epochs_since_update = 0
+
+ loss_func = partial(variable_loss_func, output_variables=self.output_variables, device=self.device)
+ train = partial(train_func, forward=self.multi_graph_forward, optimizer=self.optimizer, loss_func=loss_func)
+ eval = partial(eval_func, forward=self.multi_graph_forward, loss_func=loss_func)
+
+ best_eval_loss = eval(valid_loader).mean(dim=(0, 2, 3))
+
+ for epoch in epoch_iter:
+ train_loss = train(train_loader)
+ eval_loss = eval(valid_loader)
+
+ improvement = (best_eval_loss - eval_loss.mean(dim=(0, 2, 3))) / torch.abs(best_eval_loss)
+ if (improvement > self.improvement_threshold).any().item():
+ best_eval_loss = torch.minimum(best_eval_loss, eval_loss.mean(dim=(0, 2, 3)))
+ epochs_since_update = 0
+
+ final_graph_data, mean_nll_loss_diff = self.calculate_CMI(eval_loss)
+ with open(work_dir / "history_mask.txt", "a") as f:
+ f.write(str(final_graph_data) + "\n")
+ with open(work_dir / "history_cmi.txt", "a") as f:
+ f.write(str(mean_nll_loss_diff) + "\n")
+ print(
+ "new best valid, CMI test result:\n{}\nwith mean nll loss diff:\n{}".format(
+ final_graph_data, mean_nll_loss_diff
+ )
+ )
+ else:
+ epochs_since_update += 1
+
+ # log
+ self.total_CMI_epoch += 1
+ if self.logger is not None:
+ self.logger.record("{}-CMI-test/epoch".format(self.name), epoch)
+ self.logger.record("{}-CMI-test/epochs_since_update".format(self.name), epochs_since_update)
+ self.logger.record("{}-CMI-test/train_dataset_size".format(self.name), len(train_loader.dataset))
+ self.logger.record("{}-CMI-test/valid_dataset_size".format(self.name), len(valid_loader.dataset))
+ self.logger.record("{}-CMI-test/train_loss".format(self.name), train_loss.mean().item())
+ self.logger.record("{}-CMI-test/val_loss".format(self.name), eval_loss.mean().item())
+ self.logger.record("{}-CMI-test/best_val_loss".format(self.name), best_eval_loss.mean().item())
+ self.logger.record("{}-CMI-test/lr".format(self.name), self.optimizer.param_groups[0]["lr"])
+
+ self.logger.dump(self.total_CMI_epoch)
+
+ if self.patience and epochs_since_update >= self.patience:
+ break
+
+ self.scheduler.step()
+ print(self.optimizer)
+
+ assert final_graph_data is not None
+ self.graph.set_data(final_graph_data)
+ self.build_optimizer()
+
+ super(CMITestMech, self).learn(inputs, outputs, work_dir=work_dir, **kwargs)
+
+
+if __name__ == "__main__":
+ import gym
+    from stable_baselines3.common.buffers import ReplayBuffer
+
+    from cmrl.models.data_loader import buffer_to_dict
+    from cmrl.utils.creator import parse_space
+    from cmrl.sb3_extension.logger import configure as logger_configure
+    from cmrl.utils.env import load_offline_data
+
+ def unwrap_env(env):
+ while isinstance(env, gym.Wrapper):
+ env = env.env
+ return env
+
+ env = unwrap_env(gym.make("ParallelContinuousCartPoleSwingUp-v0"))
+ real_replay_buffer = ReplayBuffer(
+ int(1e6), env.observation_space, env.action_space, "cpu", handle_timeout_termination=False
+ )
+ load_offline_data(env, real_replay_buffer, "SAC-expert", use_ratio=0.01)
+
+ extra_info = {"Radian": ["obs_1", "obs_5", "obs_9"]}
+ # extra_info = {"Radian": ["obs_1"]}
+
+ input_variables = parse_space(env.state_space, "obs", extra_info=extra_info) + parse_space(env.action_space, "act")
+ output_variables = parse_space(env.state_space, "next_obs", extra_info=extra_info)
+
+ logger = logger_configure("cmi-log", ["tensorboard", "stdout"])
+
+    mech = CMITestMech("cmi_test_mech", input_variables, output_variables)
+
+ inputs, outputs = buffer_to_dict(env.state_space, env.action_space, env.obs2state, real_replay_buffer, "transition")
+
+ mech.learn(inputs, outputs)
diff --git a/cmrl/models/causal_mech/__init__.py b/cmrl/models/causal_mech/__init__.py
new file mode 100644
index 0000000..dd309df
--- /dev/null
+++ b/cmrl/models/causal_mech/__init__.py
@@ -0,0 +1,4 @@
+from cmrl.models.causal_mech.oracle_mech import OracleMech
+from cmrl.models.causal_mech.CMI_test import CMITestMech
+# from cmrl.models.causal_mech.reinforce import ReinforceCausalMech
+from cmrl.models.causal_mech.kernel_test import KernelTestMech
\ No newline at end of file
diff --git a/cmrl/models/causal_mech/base.py b/cmrl/models/causal_mech/base.py
new file mode 100644
index 0000000..4a46b5a
--- /dev/null
+++ b/cmrl/models/causal_mech/base.py
@@ -0,0 +1,419 @@
+from typing import Optional, List, Dict, Union, MutableMapping
+from abc import abstractmethod, ABC
+from itertools import chain, count
+import pathlib
+from functools import partial
+import copy
+from multiprocessing import cpu_count
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.optim import Optimizer
+from omegaconf import DictConfig
+from stable_baselines3.common.logger import Logger
+from hydra.utils import instantiate
+
+from cmrl.models.graphs.base_graph import BaseGraph
+from cmrl.models.graphs.binary_graph import BinaryGraph
+from cmrl.models.constant import NETWORK_CFG, ENCODER_CFG, DECODER_CFG, OPTIMIZER_CFG, SCHEDULER_CFG
+from cmrl.models.networks.base_network import BaseNetwork
+from cmrl.models.networks.coder import VariableEncoder, VariableDecoder
+from cmrl.utils.variables import Variable, ContinuousVariable, DiscreteVariable, BinaryVariable
+from cmrl.models.causal_mech.util import variable_loss_func, train_func, eval_func
+from cmrl.models.data_loader import EnsembleBufferDataset, collate_fn
+
+
+class BaseCausalMech(ABC):
+ """The base class of causal-mech learned by neural networks.
+ Pay attention that the causal discovery maybe not realized through a neural way.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ input_variables: List[Variable],
+ output_variables: List[Variable],
+ logger: Optional[Logger] = None,
+ ):
+ self.name = name
+ self.input_variables = input_variables
+ self.output_variables = output_variables
+ self.logger = logger
+
+ self.input_variables_dict = dict([(v.name, v) for v in self.input_variables])
+ self.output_variables_dict = dict([(v.name, v) for v in self.output_variables])
+
+ self.input_var_num = len(self.input_variables)
+ self.output_var_num = len(self.output_variables)
+ self.graph: Optional[BaseGraph] = None
+
+ @abstractmethod
+ def learn(
+ self,
+ inputs: MutableMapping[str, np.ndarray],
+ outputs: MutableMapping[str, np.ndarray],
+ work_dir: Optional[Union[str, pathlib.Path]] = None,
+ **kwargs
+ ):
+ raise NotImplementedError
+
+ @abstractmethod
+ def forward(self, inputs: MutableMapping[str, np.ndarray]) -> Dict[str, torch.Tensor]:
+ raise NotImplementedError
+
+ @property
+ def causal_graph(self) -> torch.Tensor:
+ """property causal graph"""
+ if self.graph is None:
+ return torch.ones(len(self.input_variables), len(self.output_variables), dtype=torch.int, device=self.device)
+ else:
+ return self.graph.get_binary_adj_matrix()
+
+ def save(self, save_dir: Union[str, pathlib.Path]):
+ pass
+
+ def load(self, load_dir: Union[str, pathlib.Path]):
+ pass
+
+
+class EnsembleNeuralMech(BaseCausalMech):
+ def __init__(
+ self,
+ # base
+ name: str,
+ input_variables: List[Variable],
+ output_variables: List[Variable],
+ logger: Optional[Logger] = None,
+ # model learning
+ longest_epoch: int = -1,
+ improvement_threshold: float = 0.01,
+ patience: int = 5,
+ batch_size: int = 256,
+ # ensemble
+ ensemble_num: int = 7,
+ elite_num: int = 5,
+ # cfgs
+ network_cfg: Optional[DictConfig] = None,
+ encoder_cfg: Optional[DictConfig] = None,
+ decoder_cfg: Optional[DictConfig] = None,
+ optimizer_cfg: Optional[DictConfig] = None,
+ scheduler_cfg: Optional[DictConfig] = None,
+ # forward method
+ residual: bool = True,
+ encoder_reduction: str = "sum",
+ # others
+ device: Union[str, torch.device] = "cpu",
+ ):
+ BaseCausalMech.__init__(
+ self, name=name, input_variables=input_variables, output_variables=output_variables, logger=logger
+ )
+ # model learning
+ self.longest_epoch = longest_epoch
+ self.improvement_threshold = improvement_threshold
+ self.patience = patience
+ self.batch_size = batch_size
+ # ensemble
+ self.ensemble_num = ensemble_num
+ self.elite_num = elite_num
+ # cfgs
+ self.network_cfg = NETWORK_CFG if network_cfg is None else network_cfg
+ self.encoder_cfg = ENCODER_CFG if encoder_cfg is None else encoder_cfg
+ self.decoder_cfg = DECODER_CFG if decoder_cfg is None else decoder_cfg
+ self.optimizer_cfg = OPTIMIZER_CFG if optimizer_cfg is None else optimizer_cfg
+ self.scheduler_cfg = SCHEDULER_CFG if scheduler_cfg is None else scheduler_cfg
+ # forward method
+ self.residual = residual
+ self.encoder_reduction = encoder_reduction
+ # others
+ self.device = device
+
+ # build member object
+ self.variable_encoders: Optional[Dict[str, VariableEncoder]] = None
+        self.variable_decoders: Optional[Dict[str, VariableDecoder]] = None
+ self.network: Optional[BaseNetwork] = None
+ self.graph: Optional[BaseGraph] = None
+ self.optimizer: Optional[Optimizer] = None
+ self.scheduler: Optional[object] = None
+ self.build_coders()
+ self.build_network()
+ self.build_graph()
+ self.build_optimizer()
+
+ self.total_epoch = 0
+ self.elite_indices: List[int] = []
+
+ @property
+ def encoder_output_dim(self):
+ return self.encoder_cfg.output_dim
+
+ @property
+ def decoder_input_dim(self):
+ return self.decoder_cfg.input_dim
+
+ def build_network(self):
+ self.network = instantiate(self.network_cfg)(
+ input_dim=self.encoder_output_dim,
+ output_dim=self.decoder_input_dim,
+ extra_dims=[self.output_var_num, self.ensemble_num],
+ ).to(self.device)
+
+ def build_optimizer(self):
+ assert self.network, "you must build network first"
+ assert self.variable_encoders and self.variable_decoders, "you must build coders first"
+ params = (
+ [self.network.parameters()]
+ + [encoder.parameters() for encoder in self.variable_encoders.values()]
+ + [decoder.parameters() for decoder in self.variable_decoders.values()]
+ )
+
+ self.optimizer = instantiate(self.optimizer_cfg)(params=chain(*params))
+ self.scheduler = instantiate(self.scheduler_cfg)(optimizer=self.optimizer)
+
+ def forward(self, inputs: MutableMapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
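+        # encode each input variable, aggregate the per-variable codes according to the causal
+        # graph (all ones if no graph has been learned yet), run the ensemble network, then
+        # decode one prediction per output variable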
+ batch_size, _ = self.get_inputs_batch_size(inputs)
+
+ inputs_tensor = torch.zeros(self.ensemble_num, batch_size, self.input_var_num, self.encoder_output_dim).to(self.device)
+ for i, var in enumerate(self.input_variables):
+ out = self.variable_encoders[var.name](inputs[var.name].to(self.device))
+ inputs_tensor[:, :, i] = out
+
+ output_tensor = self.network(self.reduce_encoder_output(inputs_tensor))
+
+ outputs = {}
+ for i, var in enumerate(self.output_variables):
+ hid = output_tensor[i]
+ outputs[var.name] = self.variable_decoders[var.name](hid)
+
+ if self.residual:
+ outputs = self.residual_outputs(inputs, outputs)
+ return outputs
+
+ def build_graph(self):
+ pass
+
+ def build_coders(self):
+ self.variable_encoders = {}
+ for var in self.input_variables:
+ assert var.name not in self.variable_encoders, "duplicate name in encoders: {}".format(var.name)
+ self.variable_encoders[var.name] = instantiate(self.encoder_cfg)(variable=var).to(self.device)
+
+ self.variable_decoders = {}
+ for var in self.output_variables:
+ assert var.name not in self.variable_decoders, "duplicate name in decoders: {}".format(var.name)
+ self.variable_decoders[var.name] = instantiate(self.decoder_cfg)(variable=var).to(self.device)
+
+ def save(self, save_dir: Union[str, pathlib.Path]):
+ if isinstance(save_dir, str):
+ save_dir = pathlib.Path(save_dir)
+ save_dir = save_dir / pathlib.Path(self.name)
+ save_dir.mkdir(exist_ok=True)
+
+ self.network.save(save_dir)
+ if self.graph is not None:
+ self.graph.save(save_dir)
+ for coder in self.variable_encoders.values():
+ coder.save(save_dir)
+ for coder in self.variable_decoders.values():
+ coder.save(save_dir)
+
+ def load(self, load_dir: Union[str, pathlib.Path]):
+ if isinstance(load_dir, str):
+ load_dir = pathlib.Path(load_dir)
+ assert load_dir.exists()
+
+ self.network.load(load_dir)
+ if self.graph is not None:
+ self.graph.load(load_dir)
+ for coder in self.variable_encoders.values():
+ coder.load(load_dir)
+ for coder in self.variable_decoders.values():
+ coder.load(load_dir)
+
+ def get_inputs_info(self, inputs: MutableMapping[str, torch.Tensor]):
+ assert len(set(inputs.keys()) & set(self.input_variables_dict.keys())) == len(inputs)
+ data_shape = next(iter(inputs.values())).shape
+ # assert len(data_shape) == 3, "{}".format(data_shape) # ensemble-num, batch-size, specific-dim
+ ensemble, batch_size, specific_dim = data_shape[-3:]
+ assert ensemble == self.ensemble_num
+
+ return batch_size, data_shape[:-3]
+
+ def residual_outputs(
+ self,
+ inputs: MutableMapping[str, torch.Tensor],
+ outputs: MutableMapping[str, torch.Tensor],
+ ) -> MutableMapping[str, torch.Tensor]:
+ for name in filter(lambda s: s.startswith("obs"), inputs.keys()):
+ # assert inputs[name].shape[:2] == outputs["next_{}".format(name)].shape[:2]
+ # assert inputs[name].shape[2] * 2 == outputs["next_{}".format(name)].shape[2]
+ var_dim = inputs[name].shape[-1]
+ outputs["next_{}".format(name)][..., :var_dim] += inputs[name].to(self.device)
+ return outputs
+
+ def reduce_encoder_output(
+ self,
+ encoder_output: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
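+        # aggregate per-variable encoder outputs into a single vector per output variable,
+        # masking out inputs that the graph marks as non-parents (mask value 0 for sum/mean,
+        # -inf for max)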
+ assert len(encoder_output.shape) == 4, (
+ "shape of `encoder_output` should be (ensemble-num, batch-size, input-var-num, encoder-output-dim), "
+ "rather than {}".format(encoder_output.shape)
+ )
+
+ if mask is None:
+ # [..., input-var-num]
+ mask = self.forward_mask
+ # [..., ensemble-num, batch-size, input-var-num]
+ mask = mask.unsqueeze(-2).unsqueeze(-2)
+ mask = mask.repeat((1,) * len(mask.shape[:-3]) + (*encoder_output.shape[:2], 1))
+
+ # mask shape [..., ensemble-num, batch-size, input-var-num]
+ assert (
+ mask.shape[-3:] == encoder_output.shape[:-1]
+ ), "mask shape should be (..., ensemble-num, batch-size, input-var-num)"
+
+ # [*mask-extra-dims, ensemble-num, batch-size, input-var-num, encoder-output-dim]
+ mask = mask[..., None].repeat([1] * len(mask.shape) + [encoder_output.shape[-1]])
+ masked_encoder_output = encoder_output.repeat(tuple(mask.shape[:-4]) + (1,) * 4)
+
+ # choose mask value
+ mask_value = 0
+ if self.encoder_reduction == "max":
+ mask_value = -float("inf")
+ masked_encoder_output[mask == 0] = mask_value
+
+ if self.encoder_reduction == "sum":
+ return masked_encoder_output.sum(-2)
+ elif self.encoder_reduction == "mean":
+ return masked_encoder_output.mean(-2)
+ elif self.encoder_reduction == "max":
+ values, indices = masked_encoder_output.max(-2)
+ return values
+ else:
+ raise NotImplementedError("not implemented encoder reduction method: {}".format(self.encoder_reduction))
+
+ @property
+ def forward_mask(self) -> torch.Tensor:
+ """property input masks"""
+ return self.causal_graph.T
+
+ def get_data_loaders(
+ self,
+ inputs: MutableMapping[str, np.ndarray],
+ outputs: MutableMapping[str, np.ndarray],
+ ):
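+        # deterministic 80/20 train/valid split of the offline data (seed fixed to 1)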
+ train_set = EnsembleBufferDataset(
+ inputs=inputs, outputs=outputs, training=True, train_ratio=0.8, ensemble_num=self.ensemble_num, seed=1
+ )
+ valid_set = EnsembleBufferDataset(
+ inputs=inputs, outputs=outputs, training=False, train_ratio=0.8, ensemble_num=self.ensemble_num, seed=1
+ )
+
+ train_loader = DataLoader(train_set, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=cpu_count())
+ valid_loader = DataLoader(valid_set, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=cpu_count())
+
+ return train_loader, valid_loader
+
+ def learn(
+ self,
+ inputs: MutableMapping[str, np.ndarray],
+ outputs: MutableMapping[str, np.ndarray],
+ work_dir: Optional[Union[str, pathlib.Path]] = None,
+ **kwargs
+ ):
+ train_loader, valid_loader = self.get_data_loaders(inputs, outputs)
+
+ best_weights: Optional[Dict] = None
+ epoch_iter = range(self.longest_epoch) if self.longest_epoch >= 0 else count()
+ epochs_since_update = 0
+
+ loss_func = partial(variable_loss_func, output_variables=self.output_variables, device=self.device)
+ train = partial(train_func, forward=self.forward, optimizer=self.optimizer, loss_func=loss_func)
+ eval = partial(eval_func, forward=self.forward, loss_func=loss_func)
+
+ best_eval_loss = eval(valid_loader).mean(dim=(-2, -1))
+
+ for epoch in epoch_iter:
+ train_loss = train(train_loader)
+ eval_loss = eval(valid_loader)
+
+ maybe_best_weights = self._maybe_get_best_weights(
+ best_eval_loss, eval_loss.mean(dim=(-2, -1)), self.improvement_threshold
+ )
+ if maybe_best_weights:
+ # best loss
+ best_eval_loss = torch.minimum(best_eval_loss, eval_loss.mean(dim=(-2, -1)))
+ best_weights = maybe_best_weights
+ epochs_since_update = 0
+ else:
+ epochs_since_update += 1
+
+ # log
+ self.total_epoch += 1
+ if self.logger is not None:
+ self.logger.record("{}/epoch".format(self.name), epoch)
+ self.logger.record("{}/epochs_since_update".format(self.name), epochs_since_update)
+ self.logger.record("{}/train_dataset_size".format(self.name), len(train_loader.dataset))
+ self.logger.record("{}/valid_dataset_size".format(self.name), len(valid_loader.dataset))
+ self.logger.record("{}/train_loss".format(self.name), train_loss.mean().item())
+ self.logger.record("{}/val_loss".format(self.name), eval_loss.mean().item())
+ self.logger.record("{}/best_val_loss".format(self.name), best_eval_loss.mean().item())
+ self.logger.record("{}/lr".format(self.name), self.optimizer.param_groups[0]["lr"])
+
+ self.logger.dump(self.total_epoch)
+
+ if self.patience and epochs_since_update >= self.patience:
+ break
+
+ self.scheduler.step()
+
+ # saving the best models:
+ self._maybe_set_best_weights_and_elite(best_weights, best_eval_loss)
+
+ self.save(save_dir=work_dir)
+
+ def _maybe_get_best_weights(
+ self,
+ best_val_loss: torch.Tensor,
+ val_loss: torch.Tensor,
+ threshold: float = 0.01,
+ ) -> Optional[Dict]:
+ """Return the current model state dict if the validation score improves.
+ For ensembles, this checks the validation for each ensemble member separately.
+        Copied from https://github.com/facebookresearch/mbrl-lib/blob/main/mbrl/models/model_trainer.py
+
+ Args:
+            best_val_loss (tensor): the current best validation losses per model.
+            val_loss (tensor): the new validation loss per model.
+ threshold (float): the threshold for relative improvement.
+ Returns:
+ (dict, optional): if the validation score's relative improvement over the
+ best validation score is higher than the threshold, returns the state dictionary
+ of the stored model, otherwise returns ``None``.
+ """
+ improvement = (best_val_loss - val_loss) / torch.abs(best_val_loss)
+ if (improvement > threshold).any().item():
+ best_weights = copy.deepcopy(self.network.state_dict())
+ else:
+ best_weights = None
+
+ return best_weights
+
+ def _maybe_set_best_weights_and_elite(self, best_weights: Optional[Dict], best_val_score: torch.Tensor):
+ if best_weights is not None:
+ self.network.load_state_dict(best_weights)
+
+ sorted_indices = np.argsort(best_val_score.tolist())
+ self.elite_indices = sorted_indices[: self.elite_num]
+
+    def get_inputs_batch_size(self, inputs: MutableMapping[str, torch.Tensor]):
+ assert len(set(inputs.keys()) & set(self.variable_encoders.keys())) == len(inputs)
+ data_shape = list(inputs.values())[0].shape
+ # assert len(data_shape) == 3, "{}".format(data_shape) # ensemble-num, batch-size, specific-dim
+ ensemble, batch_size, specific_dim = data_shape[-3:]
+ assert ensemble == self.ensemble_num
+
+ return batch_size, data_shape[:-3]
diff --git a/cmrl/models/causal_mech/kernel_test.py b/cmrl/models/causal_mech/kernel_test.py
new file mode 100644
index 0000000..f4389c3
--- /dev/null
+++ b/cmrl/models/causal_mech/kernel_test.py
@@ -0,0 +1,256 @@
+from typing import Optional, List, Dict, Union, MutableMapping
+from functools import partial
+from collections import defaultdict
+
+import pathlib
+import numpy
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from stable_baselines3.common.logger import Logger
+from hydra.utils import instantiate
+
+# from cmrl.utils.RCIT import KCI_CInd
+from causallearn.utils.KCI.KCI import KCI_CInd
+from tqdm import tqdm
+
+from cmrl.models.causal_mech.base import EnsembleNeuralMech
+from cmrl.utils.variables import Variable, ContinuousVariable, DiscreteVariable, BinaryVariable, RadianVariable
+from cmrl.models.graphs.binary_graph import BinaryGraph
+
+
+class KernelTestMech(EnsembleNeuralMech):
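+ """Causal mech whose graph is discovered with kernel-based conditional
+ independence (KCI) tests before the ensemble of neural models is fitted.
+
+ Each candidate edge is decided by majority vote over ``kci_times`` tests on
+ random sub-samples; low-confidence edges are re-tested on larger samples.
+ """
+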
+ def __init__(
+ self,
+ # base
+ name: str,
+ input_variables: List[Variable],
+ output_variables: List[Variable],
+ logger: Optional[Logger] = None,
+ # model learning
+ longest_epoch: int = -1,
+ improvement_threshold: float = 0.01,
+ patience: int = 5,
+ batch_size: int = 256,
+ # ensemble
+ ensemble_num: int = 7,
+ elite_num: int = 5,
+ # cfgs
+ network_cfg: Optional[DictConfig] = None,
+ encoder_cfg: Optional[DictConfig] = None,
+ decoder_cfg: Optional[DictConfig] = None,
+ optimizer_cfg: Optional[DictConfig] = None,
+ scheduler_cfg: Optional[DictConfig] = None,
+ # forward method
+ residual: bool = True,
+ encoder_reduction: str = "sum",
+ # others
+ device: Union[str, torch.device] = "cpu",
+ # KCI
+ sample_num: int = 2000,
+ kci_times: int = 10,
+ not_confident_bound: float = 0.25,
+ longest_sample: int = 5000,
+ ):
+ EnsembleNeuralMech.__init__(
+ self,
+ name=name,
+ input_variables=input_variables,
+ output_variables=output_variables,
+ logger=logger,
+ longest_epoch=longest_epoch,
+ improvement_threshold=improvement_threshold,
+ patience=patience,
+ batch_size=batch_size,
+ ensemble_num=ensemble_num,
+ elite_num=elite_num,
+ network_cfg=network_cfg,
+ encoder_cfg=encoder_cfg,
+ decoder_cfg=decoder_cfg,
+ optimizer_cfg=optimizer_cfg,
+ scheduler_cfg=scheduler_cfg,
+ residual=residual,
+ encoder_reduction=encoder_reduction,
+ device=device,
+ )
+ self.sample_num = sample_num
+ self.kci_times = kci_times
+ self.not_confident_bound = not_confident_bound
+ self.longest_sample = longest_sample
+
+ def kci(
+ self,
+ input_idx: int,
+ output_idx: int,
+ inputs: MutableMapping[str, numpy.ndarray],
+ outputs: MutableMapping[str, numpy.ndarray],
+ sample_indices: np.ndarray,
+ ):
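+ """Run a single kernel-based conditional independence (KCI) test for one edge.
+
+ Tests whether the output variable (or its residual, when ``self.residual``)
+ is independent of the chosen input variable given all remaining inputs, on
+ the rows selected by ``sample_indices``. Returns the p-value; small values
+ indicate dependence, i.e. evidence for the edge.
+ """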
+ in_name, out_name = list(inputs.keys())[input_idx], list(outputs.keys())[output_idx]
+
+ if self.residual:
+ data_x = outputs[out_name][sample_indices] - inputs[out_name.replace("next_", "")][sample_indices]
+ else:
+ data_x = outputs[out_name][sample_indices]
+
+ def deal_with_radian_input(name, data):
+ if isinstance(self.input_variables_dict[name], RadianVariable):
+ return (data + np.pi) % (2 * np.pi) - np.pi
+ else:
+ return data
+
+ data_y = deal_with_radian_input(in_name, inputs[in_name])[sample_indices]
+ data_z = [
+ deal_with_radian_input(other_in_name, in_data)[sample_indices]
+ for other_in_name, in_data in inputs.items()
+ if other_in_name != in_name
+ ]
+ data_z = np.concatenate(data_z, axis=1)
+
+ kci = KCI_CInd()
+ p_value, test_stat = kci.compute_pvalue(data_x, data_y, data_z)
+ return p_value
+
+ def kci_compute_graph(
+ self,
+ inputs: MutableMapping[str, numpy.ndarray],
+ outputs: MutableMapping[str, numpy.ndarray],
+ work_dir: Optional[pathlib.Path] = None,
+ **kwargs
+ ):
+
+ # truncate the voting history file before appending to it during recomputation
+ (work_dir / "history_vote.txt").write_text("")
+
+ length = next(iter(inputs.values())).shape[0]
+ sample_length = min(length, self.sample_num) if self.sample_num > 0 else length
+
+ init_pvalues_array = np.empty((self.kci_times, self.input_var_num, self.output_var_num))
+ with tqdm(
+ total=self.kci_times * self.input_var_num * self.output_var_num,
+ desc="init kci of {} samples".format(sample_length),
+ ) as pbar:
+ for time in range(self.kci_times):
+ sample_indices = np.random.permutation(length)[:sample_length]
+ kci = partial(self.kci, inputs=inputs, outputs=outputs, sample_indices=sample_indices)
+ for out_idx in range(len(outputs)):
+ for in_idx in range(len(inputs)):
+ init_pvalues_array[time][in_idx][out_idx] = kci(in_idx, out_idx)
+ pbar.update(1)
+
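+ # each edge receives kci_times votes; a vote is cast whenever the test
+ # rejects independence at the 5% level. Edges whose vote fraction lies
+ # strictly between the confidence bounds are re-tested on more samples.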
+ votes = (init_pvalues_array < 0.05).mean(axis=0)
+ is_not_confident = np.logical_and(votes > self.not_confident_bound, votes < 1 - self.not_confident_bound)
+ not_confident_list = np.array(np.where(is_not_confident)).T
+
+ recompute_times = 1
+ while len(not_confident_list) != 0:
+ with open(work_dir / "history_vote.txt", "a") as f:
+ f.write(str(votes) + "\n")
+ print(votes)
+
+ new_sample_length = int(sample_length * 1.5**recompute_times)
+ if new_sample_length > min(self.longest_sample, length):
+ break
+
+ pvalues_dict = defaultdict(list)
+ with tqdm(
+ total=self.kci_times * len(not_confident_list),
+ desc="{}th re-compute kci of {} samples".format(recompute_times, new_sample_length),
+ ) as pbar:
+ for time in range(self.kci_times):
+ sample_indices = np.random.permutation(length)[:new_sample_length]
+ kci = partial(self.kci, inputs=inputs, outputs=outputs, sample_indices=sample_indices)
+ for in_idx, out_idx in not_confident_list:
+ pvalues_dict[(in_idx, out_idx)].append(kci(in_idx, out_idx))
+ pbar.update(1)
+
+ not_confident_list = []
+ for key, value in pvalues_dict.items():
+ vote = (np.array(value) < 0.05).mean()
+ if self.not_confident_bound < vote < 1 - self.not_confident_bound:
+ not_confident_list.append(key)
+ else:
+ votes[key] = vote
+ recompute_times += 1
+
+ return votes > 0.5
+
+ def build_network(self):
+ self.network = instantiate(self.network_cfg)(
+ input_dim=self.encoder_output_dim,
+ output_dim=self.decoder_input_dim,
+ extra_dims=[self.ensemble_num],
+ ).to(self.device)
+
+ def forward(self, inputs: MutableMapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ batch_size, _ = self.get_inputs_batch_size(inputs)
+
+ inputs_tensor = torch.zeros(self.ensemble_num, batch_size, self.input_var_num, self.encoder_output_dim).to(self.device)
+ for i, var in enumerate(self.input_variables):
+ out = self.variable_encoders[var.name](inputs[var.name].to(self.device))
+ inputs_tensor[:, :, i] = out
+
+ output_tensor = self.network(self.reduce_encoder_output(inputs_tensor))
+
+ outputs = {}
+ for i, var in enumerate(self.output_variables):
+ hid = output_tensor[i]
+ outputs[var.name] = self.variable_decoders[var.name](hid)
+
+ if self.residual:
+ outputs = self.residual_outputs(inputs, outputs)
+ return outputs
+
+ def build_graph(self):
+ self.graph = BinaryGraph(self.input_var_num, self.output_var_num, device=self.device)
+
+ def learn(
+ self,
+ inputs: MutableMapping[str, np.ndarray],
+ outputs: MutableMapping[str, np.ndarray],
+ work_dir: Optional[pathlib.Path] = None,
+ **kwargs
+ ):
+ work_dir = pathlib.Path(".") if work_dir is None else work_dir
+ graph = self.kci_compute_graph(inputs, outputs, work_dir)
+ self.graph.set_data(graph)
+
+ super(KernelTestMech, self).learn(inputs, outputs, work_dir=work_dir, **kwargs)
+
+
+if __name__ == "__main__":
+ import gym
+ # emei registers its gym environments on import, so keep this import even
+ # though EmeiEnv itself is not referenced below
+ from emei import EmeiEnv
+ from stable_baselines3.common.buffers import ReplayBuffer
+
+ from cmrl.models.data_loader import buffer_to_dict
+ from cmrl.utils.creator import parse_space
+ from cmrl.utils.env import load_offline_data
+ from cmrl.sb3_extension.logger import configure as logger_configure
+
+ def unwrap_env(env):
+ while isinstance(env, gym.Wrapper):
+ env = env.env
+ return env
+
+ env = unwrap_env(gym.make("ParallelContinuousCartPoleSwingUp-v0"))
+ real_replay_buffer = ReplayBuffer(
+ int(1e6), env.observation_space, env.action_space, "cpu", handle_timeout_termination=False
+ )
+ load_offline_data(env, real_replay_buffer, "SAC-expert", use_ratio=1)
+
+ extra_info = {"Radian": ["obs_1", "obs_5", "obs_9"]}
+ # extra_info = {"Radian": ["obs_1"]}
+
+ input_variables = parse_space(env.state_space, "obs", extra_info=extra_info) + parse_space(env.action_space, "act")
+ output_variables = parse_space(env.state_space, "next_obs", extra_info=extra_info)
+
+ logger = logger_configure("kci-log", ["tensorboard", "stdout"])
+
+ mech = KernelTestMech("kernel_test_mech", input_variables, output_variables, sample_num=100, kci_times=20, logger=logger)
+
+ inputs, outputs = buffer_to_dict(env.state_space, env.action_space, env.obs2state, real_replay_buffer, "transition")
+
+ mech.learn(inputs, outputs)
diff --git a/cmrl/models/causal_mech/oracle_mech.py b/cmrl/models/causal_mech/oracle_mech.py
new file mode 100644
index 0000000..dda20c1
--- /dev/null
+++ b/cmrl/models/causal_mech/oracle_mech.py
@@ -0,0 +1,103 @@
+from typing import Optional, List, Dict, Union, MutableMapping
+
+import numpy
+import torch
+from torch.utils.data import DataLoader
+import numpy as np
+from omegaconf import DictConfig
+from hydra.utils import instantiate
+from stable_baselines3.common.logger import Logger
+
+from cmrl.utils.variables import Variable
+from cmrl.models.causal_mech.base import EnsembleNeuralMech
+from cmrl.models.graphs.binary_graph import BinaryGraph
+from cmrl.models.data_loader import EnsembleBufferDataset, collate_fn
+
+
+class OracleMech(EnsembleNeuralMech):
+ def __init__(
+ self,
+ name: str,
+ input_variables: List[Variable],
+ output_variables: List[Variable],
+ logger: Optional[Logger] = None,
+ # model learning
+ longest_epoch: int = -1,
+ improvement_threshold: float = 0.01,
+ patience: int = 5,
+ batch_size: int = 256,
+ # ensemble
+ ensemble_num: int = 7,
+ elite_num: int = 5,
+ # cfgs
+ network_cfg: Optional[DictConfig] = None,
+ encoder_cfg: Optional[DictConfig] = None,
+ decoder_cfg: Optional[DictConfig] = None,
+ optimizer_cfg: Optional[DictConfig] = None,
+ scheduler_cfg: Optional[DictConfig] = None,
+ # forward method
+ residual: bool = True,
+ encoder_reduction: str = "sum",
+ # others
+ device: Union[str, torch.device] = "cpu",
+ ):
+ EnsembleNeuralMech.__init__(
+ self,
+ name=name,
+ input_variables=input_variables,
+ output_variables=output_variables,
+ logger=logger,
+ longest_epoch=longest_epoch,
+ improvement_threshold=improvement_threshold,
+ patience=patience,
+ batch_size=batch_size,
+ ensemble_num=ensemble_num,
+ elite_num=elite_num,
+ network_cfg=network_cfg,
+ encoder_cfg=encoder_cfg,
+ decoder_cfg=decoder_cfg,
+ optimizer_cfg=optimizer_cfg,
+ scheduler_cfg=scheduler_cfg,
+ residual=residual,
+ encoder_reduction=encoder_reduction,
+ device=device,
+ )
+
+ def set_oracle_graph(self, graph_data: Optional[numpy.ndarray]):
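+ """Fix the causal graph to a known (oracle) adjacency matrix.
+
+ ``None`` falls back to a fully-connected graph, i.e. no causal pruning.
+ """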
+ self.graph = BinaryGraph(self.input_var_num, self.output_var_num, device=self.device)
+ if graph_data is None:
+ graph_data = np.ones([self.input_var_num, self.output_var_num])
+ self.graph.set_data(graph_data=graph_data)
+ print("set oracle causal graph successfully: \n{}".format(graph_data))
+
+
+if __name__ == "__main__":
+ from typing import cast
+
+ import gym
+ from stable_baselines3.common.buffers import ReplayBuffer
+ from emei import EmeiEnv
+
+ from cmrl.models.data_loader import buffer_to_dict
+ from cmrl.utils.creator import parse_space
+ from cmrl.utils.env import load_offline_data
+ from cmrl.sb3_extension.logger import configure as logger_configure
+
+ env = cast(EmeiEnv, gym.make("ParallelContinuousCartPoleSwingUp-v0"))
+ real_replay_buffer = ReplayBuffer(
+ int(1e6), env.observation_space, env.action_space, "cpu", handle_timeout_termination=False
+ )
+ load_offline_data(env, real_replay_buffer, "SAC-expert", use_ratio=1)
+
+ input_variables = parse_space(env.state_space, "obs") + parse_space(env.action_space, "act")
+ output_variables = parse_space(env.state_space, "next_obs")
+
+ logger = logger_configure("kci-log", ["tensorboard", "stdout"])
+
+ mech = OracleMech("plain_mech", input_variables, output_variables, logger=logger, device="cuda:1")
+
+ inputs, outputs = buffer_to_dict(env.state_space, env.action_space, env.obs2state, real_replay_buffer, "transition")
+
+ mech.learn(inputs, outputs)
diff --git a/cmrl/models/causal_mech/reinforce.py b/cmrl/models/causal_mech/reinforce.py
new file mode 100644
index 0000000..cc089a8
--- /dev/null
+++ b/cmrl/models/causal_mech/reinforce.py
@@ -0,0 +1,389 @@
+from typing import List, Optional, Dict, Union, MutableMapping, Tuple
+import math
+import pathlib
+from itertools import count
+from functools import partial
+import copy
+
+import torch
+import numpy as np
+from torch.utils.data import DataLoader
+from stable_baselines3.common.logger import Logger
+from omegaconf import DictConfig
+from hydra.utils import instantiate
+
+from cmrl.utils.variables import Variable
+from cmrl.models.causal_mech.neural_causal_mech import NeuralCausalMech
+from cmrl.models.graphs.prob_graph import BernoulliGraph
+from cmrl.models.causal_mech.util import variable_loss_func, train_func, eval_func
+
+default_graph_optimizer_cfg = DictConfig(
+ dict(
+ _target_="torch.optim.Adam",
+ _partial_=True,
+ lr=1e-3,
+ weight_decay=0.0,
+ eps=1e-8,
+ )
+)
+
+
+class ReinforceCausalMech(NeuralCausalMech):
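+ """Causal mech that learns a Bernoulli distribution over graphs, updating the
+ edge probabilities with REINFORCE (score-function) gradients estimated from
+ Monte-Carlo sampled adjacency matrices and their induced model losses.
+ """
+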
+ def __init__(
+ self,
+ name: str,
+ input_variables: List[Variable],
+ output_variables: List[Variable],
+ # model learning
+ longest_epoch: int = -1,
+ improvement_threshold: float = 0.01,
+ patience: int = 5,
+ # ensemble
+ ensemble_num: int = 7,
+ elite_num: int = 5,
+ # cfgs
+ network_cfg: Optional[DictConfig] = None,
+ encoder_cfg: Optional[DictConfig] = None,
+ decoder_cfg: Optional[DictConfig] = None,
+ optimizer_cfg: Optional[DictConfig] = None,
+ graph_optimizer_cfg: Optional[DictConfig] = default_graph_optimizer_cfg,
+ # graph params
+ concat_mask: bool = True,
+ graph_MC_samples: int = 100,
+ graph_max_stack: int = 200,
+ lambda_sparse: float = 1e-3,
+ # forward method
+ residual: bool = True,
+ encoder_reduction: str = "sum",
+ multi_step: str = "none",
+ # logger
+ logger: Optional[Logger] = None,
+ # others
+ device: Union[str, torch.device] = "cpu",
+ **kwargs
+ ):
+ if multi_step == "none":
+ multi_step = "forward-euler 1"
+
+ # cfgs
+ self.graph_optimizer_cfg = graph_optimizer_cfg
+
+ # graph params
+ self._concat_mask = concat_mask
+ self._graph_MC_samples = graph_MC_samples
+ self._graph_max_stack = graph_max_stack
+ self._lambda_sparse = lambda_sparse
+
+ self.graph_optimizer = None
+
+ super(ReinforceCausalMech, self).__init__(
+ name=name,
+ input_variables=input_variables,
+ output_variables=output_variables,
+ longest_epoch=longest_epoch,
+ improvement_threshold=improvement_threshold,
+ patience=patience,
+ ensemble_num=ensemble_num,
+ elite_num=elite_num,
+ network_cfg=network_cfg,
+ encoder_cfg=encoder_cfg,
+ decoder_cfg=decoder_cfg,
+ optimizer_cfg=optimizer_cfg,
+ residual=residual,
+ encoder_reduction=encoder_reduction,
+ multi_step=multi_step,
+ logger=logger,
+ device=device,
+ **kwargs
+ )
+
+ def build_network(self):
+ input_dim = self.encoder_output_dim
+ if self._concat_mask:
+ input_dim += self.input_var_num
+
+ self.network = instantiate(self.network_cfg)(
+ input_dim=input_dim,
+ output_dim=self.decoder_input_dim,
+ extra_dims=[self.output_var_num, self.ensemble_num],
+ ).to(self.device)
+
+ def build_graph(self):
+ self.graph = BernoulliGraph(
+ in_dim=self.input_var_num,
+ out_dim=self.output_var_num,
+ include_input=False,
+ init_param=1e-6,
+ requires_grad=True,
+ device=self.device,
+ )
+
+ def build_optimizer(self):
+ assert (
+ self.network is not None and self.graph is not None
+ ), "network and graph are both required when building optimizer"
+ super().build_optimizer()
+
+ # graph optimizer
+ self.graph_optimizer = instantiate(self.graph_optimizer_cfg)(self.graph.parameters)
+
+ @property
+ def causal_graph(self) -> torch.Tensor:
+ """property causal graph"""
+ assert self.graph is not None, "graph incorrectly initialized"
+
+ return self.graph.get_binary_adj_matrix(threshold=0.5)
+
+ def single_step_forward(
+ self,
+ inputs: MutableMapping[str, torch.Tensor],
+ train: bool = False,
+ mask: Optional[torch.Tensor] = None,
+ ) -> Dict[str, torch.Tensor]:
+ batch_size, extra_dim = self.get_inputs_batch_size(inputs)
+ assert len(extra_dim) == 0, "unexpected dimension in the inputs"
+
+ inputs_tensor = torch.zeros(self.ensemble_num, batch_size, self.input_var_num, self.encoder_output_dim).to(self.device)
+ for i, var in enumerate(self.input_variables):
+ out = self.variable_encoders[var.name](inputs[var.name].to(self.device))
+ inputs_tensor[..., i, :] = out
+
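+ # during graph learning, sample a fresh adjacency matrix per ensemble member
+ # and batch element from the Bernoulli graph; otherwise broadcast the given
+ # (or current binary) mask over the ensemble and batch dimensions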
+ if train and self.discovery:
+ # [ensemble-num, batch-size, input-var-num, output-var-num]
+ adj_matrix = self.graph.sample(None, sample_size=(self.ensemble_num, batch_size))
+ # [ensemble-num, batch-size, output-var-num, input-var-num]
+ mask = adj_matrix.transpose(-1, -2)
+ # [output-var-num, ensemble-num, batch-size, input-var-num]
+ mask = mask.permute(2, 0, 1, 3)
+ else:
+ if mask is None:
+ mask = self.forward_mask
+ mask = mask.unsqueeze(-2).unsqueeze(-2)
+ mask = mask.repeat(1, self.ensemble_num, batch_size, 1)
+
+ # [output-var-num, ensemble-num, batch-size, encoder-output-dim]
+ reduced_inputs_tensor = self.reduce_encoder_output(inputs_tensor, mask=mask)
+ if self._concat_mask:
+ # [output-var-num, ensemble-num, batch-size, encoder-output-dim + input-var-num]
+ reduced_inputs_tensor = torch.cat([reduced_inputs_tensor, mask], dim=-1)
+ output_tensor = self.network(reduced_inputs_tensor)
+
+ outputs = {}
+ for i, var in enumerate(self.output_variables):
+ hid = output_tensor[i]
+ outputs[var.name] = self.variable_decoders[var.name](hid)
+
+ if self.residual:
+ outputs = self.residual_outputs(inputs, outputs)
+ return outputs
+
+ def forward(
+ self,
+ inputs: MutableMapping[str, torch.Tensor],
+ train: bool = False,
+ mask: Optional[torch.Tensor] = None,
+ ) -> Dict[str, torch.Tensor]:
+ if self.multi_step.startswith("forward-euler"):
+ step_num = int(self.multi_step.split()[-1])
+
+ outputs = {}
+ for step in range(step_num):
+ outputs = self.single_step_forward(inputs, train=train, mask=mask)
+ if step < step_num - 1:
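+ # feed the prediction back as the next sub-step's observation; the first
+ # half of each decoder output is the mean, matching the obs dimension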
+ for name in filter(lambda s: s.startswith("obs"), inputs.keys()):
+ inputs[name] = outputs["next_{}".format(name)][..., : inputs[name].shape[-1]]
+ else:
+ raise NotImplementedError("multi-step method {} is not supported".format(self.multi_step))
+
+ return outputs
+
+ def train_graph(self, loader: DataLoader, data_ratio: float):
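+ """Update the graph on the first ``data_ratio`` fraction of batches and
+ return the accumulated REINFORCE gradients for logging."""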
+ num_batches = len(loader)
+ train_num = int(num_batches * data_ratio)
+
+ grads = torch.tensor([0], dtype=torch.float32)
+ for i, (inputs, targets) in enumerate(loader):
+ if train_num <= i:
+ break
+
+ grads = grads + self._update_graph(inputs, targets)
+
+ return grads
+
+ def _update_graph(
+ self,
+ inputs: MutableMapping[str, torch.Tensor],
+ targets: MutableMapping[str, torch.Tensor],
+ ) -> torch.Tensor:
+ # do Monte-Carlo sampling to obtain adjacent matrices and corresponding model losses
+ adj_matrices, losses = self._MC_sample(inputs, targets)
+
+ # calculate graph gradients
+ graph_grads = self._estimate_graph_grads(adj_matrices, losses)
+
+ # update graph
+ graph_params = self.graph.parameters[0] # only one tensor parameter
+ self.graph_optimizer.zero_grad()
+ graph_params.grad = graph_grads
+ self.graph_optimizer.step()
+
+ return graph_grads.detach().cpu()
+
+ def _MC_sample(
+ self,
+ inputs: MutableMapping[str, torch.Tensor],
+ targets: MutableMapping[str, torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
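+ # evaluate the MC-sampled graphs in chunks of at most graph_max_stack to
+ # bound memory; num_graph_list holds (chunk_size, start_offset) pairs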
+ num_graph_list = [
+ min(self._graph_max_stack, self._graph_MC_samples - i * self._graph_max_stack)
+ for i in range(math.ceil(self._graph_MC_samples / self._graph_max_stack))
+ ]
+ num_graph_list = [(num_graph_list[i], sum(num_graph_list[:i])) for i in range(len(num_graph_list))]
+
+ # sample graphs
+ adj_mats = self.graph.sample(None, sample_size=self._graph_MC_samples)
+
+ # evaluate scores using the sampled adjacency matrices and data
+ batch_size, extra_dim = self.get_inputs_batch_size(inputs)
+ assert len(extra_dim) == 0, "unexpected dimension in the inputs"
+
+ losses = []
+ for graph_count, start_idx in num_graph_list:
+ # [ensemble-num, samples*batch_size, input-var-num, output-var-num]
+ expanded_adj_mats = (
+ adj_mats[None, start_idx : start_idx + graph_count, None]
+ .expand(self.ensemble_num, -1, batch_size, -1, -1)
+ .flatten(1, 2)
+ )
+ expanded_masks = expanded_adj_mats.transpose(-1, -2).permute(2, 0, 1, 3)
+
+ expanded_inputs = {}
+ expanded_targets = {}
+ # expand inputs and targets
+ for in_key in inputs:
+ expanded_inputs[in_key] = inputs[in_key].repeat(1, graph_count, 1)
+ for tar_key in targets:
+ expanded_targets[tar_key] = targets[tar_key].repeat(1, graph_count, 1)
+
+ with torch.no_grad():
+ outputs = self.forward(expanded_inputs, train=False, mask=expanded_masks)
+ loss = variable_loss_func(outputs, expanded_targets, self.output_variables, device=self.device)
+ loss = loss.reshape(loss.shape[0], graph_count, batch_size, -1)
+ losses.append(loss.mean(dim=(0, 2)))
+ # concatenate per-chunk losses so they align one-to-one with adj_mats
+ losses = torch.cat(losses, dim=0)
+
+ return adj_mats, losses
+
+ def _estimate_graph_grads(
+ self,
+ adj_matrices: torch.Tensor,
+ losses: torch.Tensor,
+ ) -> torch.Tensor:
+ """Use MC samples and corresponding losses to estimate gradients via REINFORCE.
+
+ Args:
+ adj_matrices (tensor): MC sampled adjacent matrices from current graph,
+ shaped [num-samples, input-var-num, output-var-num].
+ losses (tensor): the model losses corresponding to the adjacent matrices,
+ shaped [num-samples, output-var-num]
+
+ """
+ num_graphs = adj_matrices.shape[0]
+ losses = losses.unsqueeze(dim=1)
+
+ # calculate graph gradients
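+ # REINFORCE estimator for Bernoulli edges: the gradient w.r.t. each edge
+ # probability p is approximated by p * (1 - p) * (mean loss with the edge
+ # on - mean loss with it off + sparsity penalty); edges that were never
+ # sampled both on and off are masked out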
+ edge_prob = self.graph.get_adj_matrix()
+ num_pos = adj_matrices.sum(dim=0)
+ num_neg = num_graphs - num_pos
+ mask = ((num_pos > 0) * (num_neg > 0)).float()
+ pos_grads = (losses * adj_matrices).sum(dim=0) / num_pos.clamp_(min=1e-5)
+ neg_grads = (losses * (1 - adj_matrices)).sum(dim=0) / num_neg.clamp_(min=1e-5)
+ graph_grads = mask * edge_prob * (1 - edge_prob) * (pos_grads - neg_grads + self._lambda_sparse)
+
+ return graph_grads
+
+ def learn(
+ self,
+ train_loader: DataLoader,
+ valid_loader: DataLoader,
+ graph_data_ratio: float = 0.5,
+ train_graph_freq: int = 2,
+ work_dir: Optional[Union[str, pathlib.Path]] = None,
+ **kwargs
+ ):
+ assert 0 <= graph_data_ratio <= 1, "graph data ratio should be in [0, 1]"
+
+ best_weights: Optional[Dict] = None
+ epoch_iter = range(self.longest_epoch) if self.longest_epoch >= 0 else count()
+ epochs_since_update = 0
+
+ loss_fn = partial(variable_loss_func, output_variables=self.output_variables, device=self.device)
+ train_fn = partial(train_func, forward=partial(self.forward, train=True), optimizer=self.optimizer, loss_func=loss_fn)
+ eval_fn = partial(eval_func, forward=partial(self.forward, train=False), loss_func=loss_fn)
+
+ best_eval_loss = eval_fn(valid_loader).mean(dim=(-2, -1))
+ for epoch in epoch_iter:
+ if self.discovery and epoch % train_graph_freq == 0:
+ grads = self.train_graph(train_loader, data_ratio=graph_data_ratio)
+ print(self.graph.parameters[0])
+ print(self.graph.get_binary_adj_matrix())
+
+ train_loss = train_fn(train_loader)
+ eval_loss = eval_fn(valid_loader)
+
+ maybe_best_weights = self._maybe_get_best_weights(
+ best_eval_loss, eval_loss.mean(dim=(-2, -1)), self.improvement_threshold
+ )
+ if maybe_best_weights:
+ # best loss
+ best_eval_loss = torch.minimum(best_eval_loss, eval_loss.mean(dim=(-2, -1)))
+ best_weights = maybe_best_weights
+ epochs_since_update = 0
+ else:
+ epochs_since_update += 1
+
+ # log
+ self.total_epoch += 1
+ if self.logger is not None:
+ self.logger.record("{}/epoch".format(self.name), epoch)
+ self.logger.record("{}/epoch_since_update".format(self.name), epochs_since_update)
+ self.logger.record("{}/train_dataset_size".format(self.name), len(train_loader.dataset))
+ self.logger.record("{}/valid_dataset_size".format(self.name), len(valid_loader.dataset))
+ self.logger.record("{}/train_loss".format(self.name), train_loss.mean().item())
+ self.logger.record("{}/val_loss".format(self.name), eval_loss.mean().item())
+ self.logger.record("{}/best_val_loss".format(self.name), best_eval_loss.mean().item())
+
+ if self.discovery and epoch % train_graph_freq == 0:
+ self.logger.record("{}/graph_update_grads".format(self.name), grads.abs().mean().item())
+
+ self.logger.dump(self.total_epoch)
+
+ if self.patience and epochs_since_update >= self.patience:
+ break
+
+ # saving the best models
+ self._maybe_set_best_weights_and_elite(best_weights, best_eval_loss)
+
+ self.save(save_dir=work_dir)
+
+ def _maybe_get_best_weights(
+ self, best_val_loss: torch.Tensor, val_loss: torch.Tensor, threshold: float = 0.01
+ ) -> Optional[Dict]:
+ improvement = (best_val_loss - val_loss) / torch.abs(best_val_loss)
+ if (improvement > threshold).any().item():
+ best_weights = {
+ "graph": copy.deepcopy(self.graph.parameters[0].detach().clone()),
+ "model": copy.deepcopy(self.network.state_dict()),
+ }
+ else:
+ best_weights = None
+
+ return best_weights
+
+ def _maybe_set_best_weights_and_elite(self, best_weights: Optional[Dict], best_val_score: torch.Tensor):
+ if best_weights is not None:
+ self.network.load_state_dict(best_weights["model"])
+ self.graph.set_data(best_weights["graph"])
+
+ sorted_indices = np.argsort(best_val_score.tolist())
+ self.elite_indices = sorted_indices[: self.elite_num]
diff --git a/cmrl/models/causal_mech/util.py b/cmrl/models/causal_mech/util.py
new file mode 100644
index 0000000..3ec9d6c
--- /dev/null
+++ b/cmrl/models/causal_mech/util.py
@@ -0,0 +1,193 @@
+from typing import Callable, Dict, List, Union, MutableMapping
+import math
+
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.optim import Optimizer
+from torch.distributions.von_mises import _log_modified_bessel_fn
+from tqdm import tqdm
+
+from cmrl.utils.variables import Variable, ContinuousVariable, DiscreteVariable, BinaryVariable, RadianVariable
+
+
+def von_mises_nll_loss(
+ input: Tensor,
+ target: Tensor,
+ var: Tensor,
+ full: bool = False,
+ eps: float = 1e-6,
+ reduction: str = "mean",
+) -> Tensor:
+ r"""Von Mises negative log likelihood loss.
+
+ Args:
+ input: loc of the Von Mises distribution.
+ target: sample from the Von Mises distribution.
+ var: tensor of positive var(s), one for each of the expectations
+ in the input (heteroscedastic), or a single one (homoscedastic).
+ full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
+ eps (float, optional): value added to var, for stability. Default: 1e-6.
+ reduction (string, optional): specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the output is the average of all batch member losses,
+ ``'sum'``: the output is the sum of all batch member losses.
+ Default: ``'mean'``.
+ """
+ # Entries of var must be non-negative
+ if torch.any(var < 0):
+ raise ValueError("var has negative entry/entries")
+
+ # Clamp for stability
+ var = var.clone()
+ with torch.no_grad():
+ var.clamp_(min=eps)
+
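+ # kappa = 1 / var plays the role of precision; log I0(kappa) is the
+ # kappa-dependent part of the log normalizer (log 2*pi is added only
+ # when full=True)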
+ concentration = 1 / var
+ loss = -concentration * torch.cos(input - target) + _log_modified_bessel_fn(concentration, order=0)
+ if full:
+ loss += math.log(2 * math.pi)
+
+ if reduction == "mean":
+ return loss.mean()
+ elif reduction == "sum":
+ return loss.sum()
+ else:
+ return loss
+
+
+def circular_gaussian_nll_loss(
+ input: Tensor,
+ target: Tensor,
+ var: Tensor,
+ full: bool = False,
+ eps: float = 1e-6,
+ reduction: str = "mean",
+) -> Tensor:
+ # Entries of var must be non-negative
+ if torch.any(var < 0):
+ raise ValueError("var has negative entry/entries")
+
+ # Clamp for stability
+ var = var.clone()
+ with torch.no_grad():
+ var.clamp_(min=eps)
+
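+ # wrap the angular difference into [0, pi] so nearly identical angles on
+ # either side of the discontinuity are not penalized as ~2*pi apart, then
+ # apply a standard Gaussian NLL to the wrapped difference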
+ diff = torch.remainder(input - target, 2 * torch.pi)
+ diff[diff > torch.pi] = 2 * torch.pi - diff[diff > torch.pi]
+ loss = 0.5 * (torch.log(var) + diff**2 / var)
+ if full:
+ loss += 0.5 * math.log(2 * math.pi)
+
+ if reduction == "mean":
+ return loss.mean()
+ elif reduction == "sum":
+ return loss.sum()
+ else:
+ return loss
+
+
+def variable_loss_func(
+ outputs: Dict[str, torch.Tensor],
+ targets: Dict[str, torch.Tensor],
+ output_variables: List[Variable],
+ device: Union[str, torch.device] = "cpu",
+):
+ dims = list(outputs.values())[0].shape[:-1]
+ total_loss = torch.zeros(*dims, len(outputs)).to(device)
+
+ for i, var in enumerate(output_variables):
+ output = outputs[var.name]
+ target = targets[var.name].to(device)
+ if isinstance(var, ContinuousVariable):
+ dim = target.shape[-1] # (xxx, ensemble-num, batch-size, dim)
+ assert output.shape[-1] == 2 * dim
+ mean, log_var = output[..., :dim], output[..., dim:]
+ # clip log_var to avoid nan loss
+ log_var = torch.clamp(log_var, min=-10, max=10)
+ loss = F.gaussian_nll_loss(mean, target, log_var.exp(), reduction="none", full=True, eps=1e-4).mean(dim=-1)
+ total_loss[..., i] = loss
+ elif isinstance(var, RadianVariable):
+ dim = target.shape[-1] # (xxx, ensemble-num, batch-size, dim)
+ assert output.shape[-1] == 2 * dim
+ mean, log_var = output[..., :dim], output[..., dim:]
+ loss = circular_gaussian_nll_loss(mean, target, log_var.exp(), reduction="none").mean(dim=-1)
+ total_loss[..., i] = loss
+ elif isinstance(var, DiscreteVariable):
+ # TODO: onehot to int?
+ raise NotImplementedError
+ elif isinstance(var, BinaryVariable):
+ total_loss[..., i] = F.binary_cross_entropy(output, target, reduction="none")
+ else:
+ raise NotImplementedError
+
+ if torch.isnan(total_loss[..., i]).any():
+ raise ValueError(f"nan loss for {var.name} ({type(var)})")
+ elif torch.isinf(total_loss[..., i]).any():
+ raise ValueError(f"inf loss for {var.name} ({type(var)})")
+ return total_loss
+
+
+def train_func(
+ loader: DataLoader,
+ forward: Callable[[MutableMapping[str, torch.Tensor]], Dict[str, torch.Tensor]],
+ optimizer: Optimizer,
+ loss_func: Callable[[MutableMapping[str, torch.Tensor], MutableMapping[str, torch.Tensor]], torch.Tensor],
+):
+ """train for data
+
+ Args:
+ forward: forward function.
+ loader: train data-loader.
+ optimizer: Optimizer
+ loss_func: loss function
+
+ Returns: tensor of train loss, with shape (xxx, ensemble-num, batch-size).
+
+ """
+ batch_loss_list = []
+ with tqdm(loader) as pbar:
+ for inputs, targets in loader:
+ outputs = forward(inputs)
+ loss = loss_func(outputs, targets) # ensemble-num, batch-size, output-var-num
+
+ optimizer.zero_grad()
+ loss.mean().backward()
+ optimizer.step()
+ batch_loss_list.append(loss)
+
+ pbar.set_description(f"train loss: {loss.mean().item():.4f}")
+ pbar.update()
+
+ return torch.cat(batch_loss_list, dim=-2).detach().cpu()
+
+
+def eval_func(
+ loader: DataLoader,
+ forward: Callable[[MutableMapping[str, torch.Tensor]], Dict[str, torch.Tensor]],
+ loss_func: Callable[[MutableMapping[str, torch.Tensor], MutableMapping[str, torch.Tensor]], torch.Tensor],
+):
+ """evaluate for data
+
+ Args:
+ forward: forward function.
+ loader: train data-loader.
+ loss_func: loss function
+
+ Returns: tensor of train loss, with shape (xxx, ensemble-num, batch-size).
+
+ """
+ batch_loss_list = []
+ with torch.no_grad():
+ with tqdm(loader) as pbar:
+ for inputs, targets in loader:
+ outputs = forward(inputs)
+ loss = loss_func(outputs, targets) # ensemble-num, batch-size, output-var-num
+ batch_loss_list.append(loss)
+
+ pbar.set_description(f"eval loss: {loss.mean().item():.4f}")
+ pbar.update()
+ return torch.cat(batch_loss_list, dim=-2).detach().cpu()
diff --git a/cmrl/models/constant.py b/cmrl/models/constant.py
new file mode 100644
index 0000000..7003ccd
--- /dev/null
+++ b/cmrl/models/constant.py
@@ -0,0 +1,55 @@
+from omegaconf import DictConfig
+
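+# Default Hydra configs for the mech building blocks. With ``_partial_=True``,
+# ``hydra.utils.instantiate`` returns a partially-applied constructor, so callers
+# can supply runtime arguments later, e.g. ``instantiate(NETWORK_CFG)(input_dim=..., output_dim=...)``.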
+NETWORK_CFG = DictConfig(
+ dict(
+ _target_="cmrl.models.networks.ParallelMLP",
+ _partial_=True,
+ _recursive_=False,
+ hidden_dims=[200, 200],
+ bias=True,
+ activation_fn_cfg=dict(_target_="torch.nn.SiLU"),
+ )
+)
+
+ENCODER_CFG = DictConfig(
+ dict(
+ _target_="cmrl.models.networks.VariableEncoder",
+ _partial_=True,
+ _recursive_=False,
+ output_dim=100,
+ hidden_dims=[100],
+ bias=True,
+ activation_fn_cfg=dict(_target_="torch.nn.SiLU"),
+ )
+)
+
+DECODER_CFG = DictConfig(
+ dict(
+ _target_="cmrl.models.networks.VariableDecoder",
+ _partial_=True,
+ _recursive_=False,
+ input_dim=100,
+ hidden_dims=[100],
+ bias=True,
+ activation_fn_cfg=dict(_target_="torch.nn.SiLU"),
+ )
+)
+
+OPTIMIZER_CFG = DictConfig(
+ dict(
+ _target_="torch.optim.Adam",
+ _partial_=True,
+ lr=1e-4,
+ weight_decay=1e-5,
+ eps=1e-8,
+ )
+)
+
+SCHEDULER_CFG = DictConfig(
+ dict(
+ _target_="torch.optim.lr_scheduler.StepLR",
+ _partial_=True,
+ step_size=1,
+ gamma=1,
+ )
+)
diff --git a/cmrl/models/data_loader.py b/cmrl/models/data_loader.py
new file mode 100644
index 0000000..e9f5843
--- /dev/null
+++ b/cmrl/models/data_loader.py
@@ -0,0 +1,103 @@
+from typing import Optional, MutableMapping
+
+from gym import spaces, Env
+import torch
+from torch.utils.data import Dataset, default_collate
+import numpy as np
+from stable_baselines3.common.buffers import ReplayBuffer, DictReplayBuffer
+
+from cmrl.utils.variables import to_dict_by_space
+
+
+def buffer_to_dict(state_space, action_space, obs2state_fn, replay_buffer: ReplayBuffer, mech: str, device: str = "cpu"):
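+ """Convert an SB3 replay buffer into per-variable input/output dicts for one mech.
+
+ For ``transition`` the outputs are the next-state variables; ``reward_mech``
+ and ``termination_mech`` additionally take the next state as input and output
+ the reward / terminal flag respectively.
+ """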
+ assert mech in ["transition", "reward_mech", "termination_mech"]
+ # dict actions are not supported by SB3 (and therefore not by cmrl)
+ assert not isinstance(action_space, spaces.Dict)
+
+ real_buffer_size = replay_buffer.buffer_size if replay_buffer.full else replay_buffer.pos
+
+ if hasattr(replay_buffer, "extra_obs"):
+ states = obs2state_fn(replay_buffer.observations[:real_buffer_size, 0], replay_buffer.extra_obs[:real_buffer_size, 0])
+ else:
+ states = replay_buffer.observations[:real_buffer_size, 0]
+ state_dict = to_dict_by_space(states, state_space, prefix="obs", to_tensor=True)
+ act_dict = to_dict_by_space(replay_buffer.actions[:real_buffer_size, 0], action_space, prefix="act", to_tensor=True)
+
+ if hasattr(replay_buffer, "next_extra_obs"):
+ next_states = obs2state_fn(
+ replay_buffer.next_observations[:real_buffer_size, 0], replay_buffer.next_extra_obs[:real_buffer_size, 0]
+ )
+ else:
+ next_states = replay_buffer.next_observations[:real_buffer_size, 0]
+ next_state_dict = to_dict_by_space(next_states, state_space, prefix="next_obs", to_tensor=True)
+
+ inputs = {}
+ inputs.update(state_dict)
+ inputs.update(act_dict)
+
+ if mech == "transition":
+ outputs = next_state_dict
+ elif mech == "reward_mech":
+ rewards = replay_buffer.rewards[:real_buffer_size, 0]
+ rewards_dict = {"reward": torch.from_numpy(rewards[:, None])}
+ inputs.update(next_state_dict)
+ outputs = rewards_dict
+ elif mech == "termination_mech":
+ terminals = replay_buffer.dones[:real_buffer_size, 0] * (1 - replay_buffer.timeouts[:real_buffer_size, 0])
+ terminals_dict = {"terminal": torch.from_numpy(terminals[:, None])}
+ inputs.update(next_state_dict)
+ outputs = terminals_dict
+ else:
+ raise NotImplementedError("support mechs in [transition, reward_mech, termination_mech] only")
+
+ return inputs, outputs
+
+
+class EnsembleBufferDataset(Dataset):
+ def __init__(
+ self,
+ inputs: MutableMapping,
+ outputs: MutableMapping,
+ training: bool = False,
+ train_ratio: float = 0.8,
+ ensemble_num: int = 7,
+ seed: int = 10086,
+ ):
+ self.inputs = inputs
+ self.outputs = outputs
+ self.training = training
+ self.train_ratio = train_ratio
+ self.ensemble_num = ensemble_num
+ self.seed = seed
+ self.indexes = None
+
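+ # deterministic split: the first train_ratio of a seeded permutation forms
+ # the training set and the remainder the validation set shared by all
+ # members; in training mode each ensemble member gets its own reshuffling
+ # of the training indices, so self.indexes has shape (num-indices, ensemble-num)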
+ size = next(iter(inputs.values())).shape[0]
+
+ # use a local generator so seeding does not leak into global numpy state
+ rng = np.random.default_rng(self.seed)
+ permutation = rng.permutation(size)
+ if self.training:
+ train_indexes = permutation[: int(size * self.train_ratio)]
+ indexes = [rng.permutation(train_indexes) for _ in range(self.ensemble_num)]
+ else:
+ valid_indexes = permutation[int(size * self.train_ratio) :]
+ indexes = [valid_indexes for _ in range(self.ensemble_num)]
+ self.indexes = np.array(indexes).T
+
+ def __getitem__(self, item):
+ index = self.indexes[item]
+
+ inputs = dict([(key, self.inputs[key][index]) for key in self.inputs])
+ outputs = dict([(key, self.outputs[key][index]) for key in self.outputs])
+ return inputs, outputs
+
+ def __len__(self):
+ return len(self.indexes)
+
+
+def collate_fn(data):
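+ # default_collate stacks samples into (batch, ensemble, ...); transpose the
+ # first two dims so every tensor is (ensemble, batch, ...) as the mechs expect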
+ inputs, outputs = default_collate(data)
+ inputs = dict([(key, value.transpose(0, 1)) for key, value in inputs.items()])
+ outputs = dict([(key, value.transpose(0, 1)) for key, value in outputs.items()])
+ return [inputs, outputs]
diff --git a/cmrl/models/dynamics.py b/cmrl/models/dynamics.py
new file mode 100644
index 0000000..10367dd
--- /dev/null
+++ b/cmrl/models/dynamics.py
@@ -0,0 +1,84 @@
+import abc
+from collections import ChainMap
+import pathlib
+from typing import Dict, List, Optional, Tuple, Union
+from functools import partial
+
+import numpy as np
+import torch
+from gym import spaces
+from torch.utils.data import DataLoader
+from stable_baselines3.common.logger import Logger
+from stable_baselines3.common.buffers import ReplayBuffer
+
+from cmrl.utils.variables import to_dict_by_space
+from cmrl.models.causal_mech.base import BaseCausalMech
+from cmrl.models.data_loader import buffer_to_dict
+from cmrl.types import Obs2StateFnType, State2ObsFnType
+
+
+class Dynamics:
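+ """Learned dynamics model bundling a transition mech with optional reward and
+ termination mechs, exposed through a gym-style ``step``."""
+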
+ def __init__(
+ self,
+ transition: BaseCausalMech,
+ state_space: spaces.Space,
+ action_space: spaces.Space,
+ obs2state_fn: Obs2StateFnType,
+ state2obs_fn: State2ObsFnType,
+ reward_mech: Optional[BaseCausalMech] = None,
+ termination_mech: Optional[BaseCausalMech] = None,
+ seed: int = 7,
+ logger: Optional[Logger] = None,
+ ):
+ self.transition = transition
+ self.state_space = state_space
+ self.action_space = action_space
+ self.obs2state_fn = obs2state_fn
+ self.state2obs_fn = state2obs_fn
+ self.reward_mech = reward_mech
+ self.termination_mech = termination_mech
+ self.seed = seed
+ self.logger = logger
+
+ self.learn_reward = reward_mech is not None
+ self.learn_termination = termination_mech is not None
+
+ self.device = self.transition.device
+
+ def learn(self, real_replay_buffer: ReplayBuffer, work_dir: Optional[Union[str, pathlib.Path]] = None, **kwargs):
+ get_dataset = partial(
+ buffer_to_dict,
+ state_space=self.state_space,
+ action_space=self.action_space,
+ obs2state_fn=self.obs2state_fn,
+ replay_buffer=real_replay_buffer,
+ device=self.device
+ )
+
+ # transition
+ self.transition.learn(*get_dataset(mech="transition"), work_dir=work_dir)
+ # reward-mech
+ if self.learn_reward:
+ self.reward_mech.learn(*get_dataset(mech="reward_mech"), work_dir=work_dir)
+ # termination-mech
+ if self.learn_termination:
+ self.termination_mech.learn(*get_dataset(mech="termination_mech"), work_dir=work_dir)
+
+ def step(self, batch_obs, batch_action):
+ with torch.no_grad():
+ # repeat the batch once per ensemble member (the mech is assumed to be an ensemble)
+ ensemble_num = self.transition.ensemble_num
+ obs_dict = to_dict_by_space(batch_obs, self.state_space, "obs",
+ repeat=ensemble_num, to_tensor=True, device=self.device)
+ act_dict = to_dict_by_space(batch_action, self.action_space, "act",
+ repeat=ensemble_num, to_tensor=True, device=self.device)
+
+ inputs = ChainMap(obs_dict, act_dict)
+ outputs = self.transition.forward(inputs)
+
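+ # average over the ensemble and keep only the mean head of each output
+ # variable ([:, :1]; this assumes one-dimensional variables whose decoder
+ # outputs mean and log-variance concatenated)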
+ batch_next_state = torch.concat([tensor.mean(dim=0)[:, :1] for tensor in outputs.values()],
+ dim=-1).cpu().numpy()
+ batch_next_obs = self.state2obs_fn(batch_next_state)
+ info = {
+ "origin-next_obs": torch.concat([tensor[:, :, :1] for tensor in outputs.values()], dim=-1).cpu().numpy()}
+
+ return batch_next_obs, None, None, info
diff --git a/cmrl/models/dynamics/__init__.py b/cmrl/models/dynamics/__init__.py
deleted file mode 100644
index 3b6c6f1..0000000
--- a/cmrl/models/dynamics/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .base_dynamics import BaseDynamics
-from .constraint_based_dynamics import ConstraintBasedDynamics
-from .plain_dynamics import PlainEnsembleDynamics
diff --git a/cmrl/models/dynamics/base_dynamics.py b/cmrl/models/dynamics/base_dynamics.py
deleted file mode 100644
index 8bc05cd..0000000
--- a/cmrl/models/dynamics/base_dynamics.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import abc
-import collections
-import pathlib
-from typing import Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from stable_baselines3.common.logger import Logger
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.models.reward_mech.base_reward_mech import BaseRewardMech
-from cmrl.models.termination_mech.base_termination_mech import BaseTerminationMech
-from cmrl.models.transition.base_transition import BaseTransition
-from cmrl.types import InteractionBatch
-from cmrl.util.transition_iterator import BootstrapIterator, TransitionIterator
-
-
-def split_dict(old_dict: Dict, need_keys: List[str]):
- return dict([(key, old_dict[key]) for key in need_keys])
-
-
-class BaseDynamics:
- _MECH_TO_VARIABLE = {
- "transition": "batch_next_obs",
- "reward_mech": "batch_reward",
- "termination_mech": "batch_terminal",
- }
- _VARIABLE_TO_MECH = dict([(value, key) for key, value in _MECH_TO_VARIABLE.items()])
-
- def __init__(
- self,
- transition: BaseTransition,
- learned_reward: bool = True,
- reward_mech: Optional[BaseRewardMech] = None,
- learned_termination: bool = False,
- termination_mech: Optional[BaseTerminationMech] = None,
- optim_lr: float = 1e-4,
- weight_decay: float = 1e-5,
- optim_eps: float = 1e-8,
- logger: Optional[Logger] = None,
- ):
- super(BaseDynamics, self).__init__()
- self.transition = transition
- self.learned_reward = learned_reward
- self.reward_mech = reward_mech
- self.learned_termination = learned_termination
- self.termination_mech = termination_mech
-
- self.optim_lr = optim_lr
- self.weight_decay = weight_decay
- self.optim_eps = optim_eps
- self.logger = logger
-
- self.device = self.transition.device
- self.ensemble_num = self.transition.ensemble_num
-
- self.learn_mech = ["transition"]
- self.transition_optimizer = torch.optim.Adam(
- self.transition.parameters(),
- lr=optim_lr,
- weight_decay=weight_decay,
- eps=optim_eps,
- )
- if self.learned_reward:
- self.reward_mech_optimizer = torch.optim.Adam(
- self.reward_mech.parameters(),
- lr=optim_lr,
- weight_decay=weight_decay,
- eps=optim_eps,
- )
- self.learn_mech.append("reward_mech")
- if self.learned_termination:
- self.termination_mech_optimizer = torch.optim.Adam(
- self.termination_mech.parameters(),
- lr=optim_lr,
- weight_decay=weight_decay,
- eps=optim_eps,
- )
- self.learn_mech.append("termination_mech")
-
- self.total_epoch = {}
- for mech in self.learn_mech:
- self.total_epoch[mech] = 0
-
- @abc.abstractmethod
- def learn(self, replay_buffer: ReplayBuffer, **kwargs):
- pass
-
- # auxiliary method for "single batch data"
- def get_3d_tensor(self, data: Union[np.ndarray, torch.Tensor], is_ensemble: bool):
- if isinstance(data, np.ndarray):
- data = torch.from_numpy(data)
- if is_ensemble:
- if data.ndim == 2: # reward or terminal
- data = data.unsqueeze(data.ndim)
- return data.to(self.device)
- else:
- if data.ndim == 1: # reward or terminal
- data = data.unsqueeze(data.ndim)
- return data.repeat([self.ensemble_num, 1, 1]).to(self.device)
-
- # auxiliary method for "interaction batch data"
- def get_mech_loss(
- self,
- batch: InteractionBatch,
- mech: str = "transition",
- loss_type: str = "default",
- is_ensemble: bool = False,
- ):
- data = {}
- for attr in batch.attrs:
- data[attr] = self.get_3d_tensor(getattr(batch, attr).copy(), is_ensemble=is_ensemble)
- model_in = split_dict(data, ["batch_obs", "batch_action"])
-
- if loss_type == "default":
- loss_type = "mse" if getattr(self, mech).deterministic else "nll"
-
- variable = self._MECH_TO_VARIABLE[mech]
- get_loss = getattr(getattr(self, mech), "get_{}_loss".format(loss_type))
- return get_loss(model_in, data[variable])
-
- # auxiliary method for "replay buffer"
- def dataset_split(
- self,
- replay_buffer: ReplayBuffer,
- validation_ratio: float = 0.2,
- batch_size: int = 256,
- shuffle_each_epoch: bool = True,
- bootstrap_permutes: bool = False,
- ) -> Tuple[TransitionIterator, Optional[TransitionIterator]]:
- size = replay_buffer.buffer_size if replay_buffer.full else replay_buffer.pos
- data = InteractionBatch(
- replay_buffer.observations[:size, 0].astype(np.float32),
- replay_buffer.actions[:size, 0],
- replay_buffer.next_observations[:size, 0].astype(np.float32),
- replay_buffer.rewards[:size, 0],
- replay_buffer.dones[:size, 0],
- )
-
- val_size = int(len(data) * validation_ratio)
- train_size = len(data) - val_size
- train_data = data[:train_size]
- train_iter = BootstrapIterator(
- train_data,
- batch_size,
- self.ensemble_num,
- shuffle_each_epoch=shuffle_each_epoch,
- permute_indices=bootstrap_permutes,
- )
-
- val_iter = None
- if val_size > 0:
- val_data = data[train_size:]
- val_iter = TransitionIterator(val_data, batch_size, shuffle_each_epoch=False)
-
- return train_iter, val_iter
-
- # auxiliary method for "dataset"
- def evaluate(
- self,
- dataset: TransitionIterator,
- mech: str = "transition",
- ):
- assert not isinstance(dataset, BootstrapIterator)
-
- batch_loss_list = []
- with torch.no_grad():
- for batch in dataset:
- val_loss = self.get_mech_loss(batch, mech=mech, loss_type="mse", is_ensemble=False)
- batch_loss_list.append(val_loss)
- return torch.cat(batch_loss_list, dim=batch_loss_list[0].ndim - 2).cpu()
-
- def train(
- self,
- dataset: TransitionIterator,
- mech: str = "transition",
- ):
- assert isinstance(dataset, BootstrapIterator)
-
- batch_loss_list = []
- for batch in dataset:
- train_loss = self.get_mech_loss(batch, mech=mech, is_ensemble=True)
- optim = getattr(self, "{}_optimizer".format(mech))
- optim.zero_grad()
- train_loss.mean().backward()
- optim.step()
- batch_loss_list.append(train_loss)
- return torch.cat(batch_loss_list, dim=batch_loss_list[0].ndim - 2).detach().cpu()
-
- def query(self, obs, action, return_as_np=True):
- result = collections.defaultdict(dict)
- obs = self.get_3d_tensor(obs, is_ensemble=False)
- action = self.get_3d_tensor(action, is_ensemble=False)
- for mech in self.learn_mech:
- with torch.no_grad():
- mean, logvar = getattr(self, "{}".format(mech)).forward(obs, action)
- variable = self.get_variable_by_mech(mech)
- if return_as_np:
- result[variable]["mean"] = mean.cpu().numpy()
- result[variable]["logvar"] = logvar.cpu().numpy()
- else:
- result[variable]["mean"] = mean.cpu()
- result[variable]["logvar"] = logvar.cpu()
- return result
-
- # other auxiliary method
- def save(self, save_dir: Union[str, pathlib.Path]):
- for mech in self.learn_mech:
- getattr(self, mech).save(save_dir=save_dir)
-
- def load(self, load_dir: Union[str, pathlib.Path], load_device: Optional[str] = None):
- for mech in self.learn_mech:
- getattr(self, mech).load(load_dir=load_dir, load_device=load_device)
-
- def get_variable_by_mech(self, mech: str) -> str:
- assert mech in self._MECH_TO_VARIABLE
- return self._MECH_TO_VARIABLE[mech]
-
- def get_mach_by_variable(self, variable: str) -> str:
- assert variable in self._VARIABLE_TO_MECH
- return self._VARIABLE_TO_MECH[variable]
diff --git a/cmrl/models/dynamics/constraint_based_dynamics.py b/cmrl/models/dynamics/constraint_based_dynamics.py
deleted file mode 100644
index 5463566..0000000
--- a/cmrl/models/dynamics/constraint_based_dynamics.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import copy
-import itertools
-import pathlib
-from typing import Callable, Dict, List, Optional, Tuple, Union, cast
-
-import numpy as np
-import torch
-from stable_baselines3.common.logger import Logger
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.models.dynamics.base_dynamics import BaseDynamics
-from cmrl.models.networks.mlp import EnsembleMLP
-from cmrl.models.reward_mech.base_reward_mech import BaseRewardMech
-from cmrl.models.termination_mech.base_termination_mech import BaseTerminationMech
-from cmrl.models.transition.base_transition import BaseTransition
-from cmrl.models.causal_discovery.CMI_test import TransitionConditionalMutualInformationTest
-from cmrl.util.transition_iterator import BootstrapIterator, TransitionIterator
-from cmrl.models.util import to_tensor
-from cmrl.types import TensorType
-
-
-class ConstraintBasedDynamics(BaseDynamics):
- def __init__(
- self,
- transition: BaseTransition,
- learned_reward: bool = True,
- reward_mech: Optional[BaseRewardMech] = None,
- learned_termination: bool = False,
- termination_mech: Optional[BaseTerminationMech] = None,
- # trainer
- optim_lr: float = 1e-4,
- weight_decay: float = 1e-5,
- optim_eps: float = 1e-8,
- logger: Optional[Logger] = None,
- ):
- super(ConstraintBasedDynamics, self).__init__(
- transition=transition,
- learned_reward=learned_reward,
- reward_mech=reward_mech,
- learned_termination=learned_termination,
- termination_mech=termination_mech,
- optim_lr=optim_lr,
- weight_decay=weight_decay,
- optim_eps=optim_eps,
- logger=logger,
- )
- # self.cmi_test: Optional[EnsembleMLP] = None
- # self.build_cmi_test()
- #
- # self.cmi_test_optimizer = torch.optim.Adam(
- # self.cmi_test.parameters(),
- # lr=optim_lr,
- # weight_decay=weight_decay,
- # eps=optim_eps,
- # )
- # self.learn_mech.append("cmi_test")
- # self.total_epoch["cmi_test"] = 0
- # self._MECH_TO_VARIABLE["cmi_test"] = self._MECH_TO_VARIABLE["transition"]
-
- for mech in self.learn_mech:
- if hasattr(getattr(self, mech), "input_mask"):
- setattr(self, "{}_oracle_mask".format(mech), None)
- setattr(self, "{}_history_mask".format(mech), torch.ones(getattr(self, mech).input_mask.shape).to(self.device))
-
- def build_cmi_test(self):
- self.cmi_test = TransitionConditionalMutualInformationTest(
- obs_size=self.transition.obs_size,
- action_size=self.transition.action_size,
- ensemble_num=1,
- elite_num=1,
- residual=self.transition.residual,
- learn_logvar_bounds=self.transition.learn_logvar_bounds,
- num_layers=4,
- hid_size=200,
- activation_fn_cfg=self.transition.activation_fn_cfg,
- device=self.transition.device,
- )
-
- def set_oracle_mask(self, mech: str, mask: TensorType):
- assert hasattr(self, "{}_oracle_mask".format(mech))
- setattr(self, "{}_oracle_mask".format(mech), to_tensor(mask))
-
- def learn(
- self,
- # data
- replay_buffer: ReplayBuffer,
- # dataset split
- validation_ratio: float = 0.2,
- batch_size: int = 256,
- shuffle_each_epoch: bool = True,
- bootstrap_permutes: bool = False,
- # model learning
- longest_epoch: Optional[int] = None,
- improvement_threshold: float = 0.1,
- patience: int = 5,
- work_dir: Optional[Union[str, pathlib.Path]] = None,
- # other
- **kwargs
- ):
- train_dataset, val_dataset = self.dataset_split(
- replay_buffer,
- validation_ratio,
- batch_size,
- shuffle_each_epoch,
- bootstrap_permutes,
- )
-
- for mech in self.learn_mech:
-
- if hasattr(self, "{}_oracle_mask".format(mech)):
- getattr(self, mech).set_input_mask(getattr(self, "{}_oracle_mask".format(mech)))
-
- best_weights: Optional[Dict] = None
- epoch_iter = range(longest_epoch) if longest_epoch > 0 else itertools.count()
- epochs_since_update = 0
-
- best_val_loss = self.evaluate(val_dataset, mech=mech).mean(dim=(1, 2))
-
- for epoch in epoch_iter:
- train_loss = self.train(train_dataset, mech=mech)
- val_loss = self.evaluate(val_dataset, mech=mech).mean(dim=(1, 2))
-
- maybe_best_weights = self.maybe_get_best_weights(
- best_val_loss,
- val_loss,
- mech,
- improvement_threshold,
- )
- if maybe_best_weights:
- # best loss
- best_val_loss = torch.minimum(best_val_loss, val_loss)
- best_weights = maybe_best_weights
- epochs_since_update = 0
- else:
- epochs_since_update += 1
-
- # log
- self.total_epoch[mech] += 1
- if self.logger is not None:
- self.logger.record("{}/epoch".format(mech), epoch)
- self.logger.record("{}/train_dataset_size".format(mech), train_dataset.num_stored)
- self.logger.record("{}/val_dataset_size".format(mech), val_dataset.num_stored)
- self.logger.record("{}/train_loss".format(mech), train_loss.mean().item())
- self.logger.record("{}/val_loss".format(mech), val_loss.mean().item())
- self.logger.record("{}/best_val_loss".format(mech), best_val_loss.mean().item())
- self.logger.dump(self.total_epoch[mech])
- if patience and epochs_since_update >= patience:
- break
-
- # saving the best models:
- self.maybe_set_best_weights_and_elite(best_weights, best_val_loss, mech=mech)
- self.save(work_dir)
-
- def maybe_get_best_weights(
- self,
- best_val_loss: torch.Tensor,
- val_loss: torch.Tensor,
- mech: str = "transition",
- threshold: float = 0.01,
- ):
- improvement = (best_val_loss - val_loss) / torch.abs(best_val_loss)
- if (improvement > threshold).any().item():
- model = getattr(self, mech)
- best_weights = copy.deepcopy(model.state_dict())
- else:
- best_weights = None
-
- return best_weights
-
- def maybe_set_best_weights_and_elite(
- self,
- best_weights: Optional[Dict],
- best_val_loss: torch.Tensor,
- mech: str = "transition",
- ):
- model = getattr(self, mech)
- assert isinstance(model, EnsembleMLP)
-
- if best_weights is not None:
- model.load_state_dict(best_weights)
- sorted_indices = np.argsort(best_val_loss.tolist())
- elite_models = sorted_indices[: model.elite_num]
- model.set_elite_members(elite_models)
diff --git a/cmrl/models/dynamics/ncd_dynamics.py b/cmrl/models/dynamics/ncd_dynamics.py
deleted file mode 100644
index fa1e071..0000000
--- a/cmrl/models/dynamics/ncd_dynamics.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import copy
-import itertools
-import pathlib
-from typing import Callable, Dict, List, Optional, Tuple, Union, cast
-
-import numpy as np
-import torch
-from stable_baselines3.common.logger import Logger
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.models.dynamics.base_dynamics import BaseDynamics
-from cmrl.models.networks.mlp import EnsembleMLP
-from cmrl.models.reward_mech.base_reward_mech import BaseRewardMech
-from cmrl.models.termination_mech.base_termination_mech import BaseTerminationMech
-from cmrl.models.transition.base_transition import BaseTransition
-from cmrl.models.causal_discovery.CMI_test import TransitionConditionalMutualInformationTest
-from cmrl.util.transition_iterator import BootstrapIterator, TransitionIterator
-from cmrl.models.util import to_tensor
-from cmrl.types import TensorType
-
-
-class ConstraintBasedDynamics(BaseDynamics):
- def __init__(
- self,
- transition: BaseTransition,
- learned_reward: bool = True,
- reward_mech: Optional[BaseRewardMech] = None,
- learned_termination: bool = False,
- termination_mech: Optional[BaseTerminationMech] = None,
- # trainer
- optim_lr: float = 1e-4,
- weight_decay: float = 1e-5,
- optim_eps: float = 1e-8,
- logger: Optional[Logger] = None,
- ):
- super(ConstraintBasedDynamics, self).__init__(
- transition=transition,
- learned_reward=learned_reward,
- reward_mech=reward_mech,
- learned_termination=learned_termination,
- termination_mech=termination_mech,
- optim_lr=optim_lr,
- weight_decay=weight_decay,
- optim_eps=optim_eps,
- logger=logger,
- )
- # self.cmi_test: Optional[EnsembleMLP] = None
- # self.build_cmi_test()
- #
- # self.cmi_test_optimizer = torch.optim.Adam(
- # self.cmi_test.parameters(),
- # lr=optim_lr,
- # weight_decay=weight_decay,
- # eps=optim_eps,
- # )
- # self.learn_mech.append("cmi_test")
- # self.total_epoch["cmi_test"] = 0
- # self._MECH_TO_VARIABLE["cmi_test"] = self._MECH_TO_VARIABLE["transition"]
-
- for mech in self.learn_mech:
- if hasattr(getattr(self, mech), "input_mask"):
- setattr(self, "{}_oracle_mask".format(mech), None)
- setattr(self, "{}_history_mask".format(mech), torch.ones(getattr(self, mech).input_mask.shape).to(self.device))
-
- def build_cmi_test(self):
- self.cmi_test = TransitionConditionalMutualInformationTest(
- obs_size=self.transition.obs_size,
- action_size=self.transition.action_size,
- ensemble_num=1,
- elite_num=1,
- residual=self.transition.residual,
- learn_logvar_bounds=self.transition.learn_logvar_bounds,
- num_layers=4,
- hid_size=200,
- activation_fn_cfg=self.transition.activation_fn_cfg,
- device=self.transition.device,
- )
-
- def set_oracle_mask(self, mech: str, mask: TensorType):
- assert hasattr(self, "{}_oracle_mask".format(mech))
- setattr(self, "{}_oracle_mask".format(mech), to_tensor(mask))
-
- def learn(
- self,
- # data
- replay_buffer: ReplayBuffer,
- # dataset split
- validation_ratio: float = 0.2,
- batch_size: int = 256,
- shuffle_each_epoch: bool = True,
- bootstrap_permutes: bool = False,
- # model learning
- longest_epoch: Optional[int] = None,
- improvement_threshold: float = 0.1,
- patience: int = 5,
- work_dir: Optional[Union[str, pathlib.Path]] = None,
- # other
- **kwargs
- ):
- train_dataset, val_dataset = self.dataset_split(
- replay_buffer,
- validation_ratio,
- batch_size,
- shuffle_each_epoch,
- bootstrap_permutes,
- )
-
- for mech in self.learn_mech:
- if hasattr(self, "{}_oracle_mask".format(mech)):
- getattr(self, mech).set_input_mask(getattr(self, "{}_oracle_mask".format(mech)))
-
- best_weights: Optional[Dict] = None
- epoch_iter = range(longest_epoch) if longest_epoch > 0 else itertools.count()
- epochs_since_update = 0
-
- best_val_loss = self.evaluate(val_dataset, mech=mech).mean(dim=(1, 2))
-
- for epoch in epoch_iter:
- train_loss = self.train(train_dataset, mech=mech)
- val_loss = self.evaluate(val_dataset, mech=mech).mean(dim=(1, 2))
-
- maybe_best_weights = self.maybe_get_best_weights(
- best_val_loss,
- val_loss,
- mech,
- improvement_threshold,
- )
- if maybe_best_weights:
- # best loss
- best_val_loss = torch.minimum(best_val_loss, val_loss)
- best_weights = maybe_best_weights
- epochs_since_update = 0
- else:
- epochs_since_update += 1
-
- # log
- self.total_epoch[mech] += 1
- if self.logger is not None:
- self.logger.record("{}/epoch".format(mech), epoch)
- self.logger.record("{}/train_dataset_size".format(mech), train_dataset.num_stored)
- self.logger.record("{}/val_dataset_size".format(mech), val_dataset.num_stored)
- self.logger.record("{}/train_loss".format(mech), train_loss.mean().item())
- self.logger.record("{}/val_loss".format(mech), val_loss.mean().item())
- self.logger.record("{}/best_val_loss".format(mech), best_val_loss.mean().item())
- self.logger.dump(self.total_epoch[mech])
- if patience and epochs_since_update >= patience:
- break
-
- # saving the best models:
- self.maybe_set_best_weights_and_elite(best_weights, best_val_loss, mech=mech)
- self.save(work_dir)
-
- def maybe_get_best_weights(
- self,
- best_val_loss: torch.Tensor,
- val_loss: torch.Tensor,
- mech: str = "transition",
- threshold: float = 0.01,
- ):
- improvement = (best_val_loss - val_loss) / torch.abs(best_val_loss)
- if (improvement > threshold).any().item():
- model = getattr(self, mech)
- best_weights = copy.deepcopy(model.state_dict())
- else:
- best_weights = None
-
- return best_weights
-
- def maybe_set_best_weights_and_elite(
- self,
- best_weights: Optional[Dict],
- best_val_loss: torch.Tensor,
- mech: str = "transition",
- ):
- model = getattr(self, mech)
- assert isinstance(model, EnsembleMLP)
-
- if best_weights is not None:
- model.load_state_dict(best_weights)
- sorted_indices = np.argsort(best_val_loss.tolist())
- elite_models = sorted_indices[: model.elite_num]
- model.set_elite_members(elite_models)
diff --git a/cmrl/models/dynamics/plain_dynamics.py b/cmrl/models/dynamics/plain_dynamics.py
deleted file mode 100644
index 317ac46..0000000
--- a/cmrl/models/dynamics/plain_dynamics.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import copy
-import itertools
-import pathlib
-from typing import Callable, Dict, List, Optional, Tuple, Union, cast
-
-import numpy as np
-import torch
-from stable_baselines3.common.logger import Logger
-from stable_baselines3.common.buffers import ReplayBuffer
-
-from cmrl.models.dynamics import BaseDynamics
-from cmrl.models.networks.mlp import EnsembleMLP
-from cmrl.models.reward_mech.base_reward_mech import BaseRewardMech
-from cmrl.models.termination_mech.base_termination_mech import BaseTerminationMech
-from cmrl.models.transition.base_transition import BaseTransition
-
-
-class PlainEnsembleDynamics(BaseDynamics):
- def __init__(
- self,
- transition: BaseTransition,
- learned_reward: bool = True,
- reward_mech: Optional[BaseRewardMech] = None,
- learned_termination: bool = False,
- termination_mech: Optional[BaseTerminationMech] = None,
- # trainer
- optim_lr: float = 1e-4,
- weight_decay: float = 1e-5,
- optim_eps: float = 1e-8,
- logger: Optional[Logger] = None,
- ):
- super(PlainEnsembleDynamics, self).__init__(
- transition=transition,
- learned_reward=learned_reward,
- reward_mech=reward_mech,
- learned_termination=learned_termination,
- termination_mech=termination_mech,
- optim_lr=optim_lr,
- weight_decay=weight_decay,
- optim_eps=optim_eps,
- logger=logger,
- )
-
- def learn(
- self,
- # data
- replay_buffer: ReplayBuffer,
- # dataset split
- validation_ratio: float = 0.2,
- batch_size: int = 256,
- shuffle_each_epoch: bool = True,
- bootstrap_permutes: bool = False,
- # model learning
- longest_epoch: int = -1,
- improvement_threshold: float = 0.1,
- patience: int = 5,
- work_dir: Optional[Union[str, pathlib.Path]] = None,
- # other
- **kwargs
- ):
- train_dataset, val_dataset = self.dataset_split(
- replay_buffer,
- validation_ratio,
- batch_size,
- shuffle_each_epoch,
- bootstrap_permutes,
- )
-
- for mech in self.learn_mech:
- best_weights: Optional[Dict] = None
- epoch_iter = range(longest_epoch) if longest_epoch > 0 else itertools.count()
- epochs_since_update = 0
-
- best_val_loss = self.evaluate(val_dataset, mech=mech).mean(dim=(1, 2))
-
- for epoch in epoch_iter:
- train_loss = self.train(train_dataset, mech=mech)
- val_loss = self.evaluate(val_dataset, mech=mech).mean(dim=(1, 2))
-
- maybe_best_weights = self.maybe_get_best_weights(
- best_val_loss,
- val_loss,
- mech,
- improvement_threshold,
- )
- if maybe_best_weights:
- # best loss
- best_val_loss = torch.minimum(best_val_loss, val_loss)
- best_weights = maybe_best_weights
- epochs_since_update = 0
- else:
- epochs_since_update += 1
-
- # log
- self.total_epoch[mech] += 1
- if self.logger is not None:
- self.logger.record("{}/epoch".format(mech), epoch)
- self.logger.record("{}/train_dataset_size".format(mech), train_dataset.num_stored)
- self.logger.record("{}/val_dataset_size".format(mech), val_dataset.num_stored)
- self.logger.record("{}/train_loss".format(mech), train_loss.mean().item())
- self.logger.record("{}/val_loss".format(mech), val_loss.mean().item())
- self.logger.record("{}/best_val_loss".format(mech), best_val_loss.mean().item())
- self.logger.dump(self.total_epoch[mech])
-
- if patience and epochs_since_update >= patience:
- break
-
- # saving the best models:
- self.maybe_set_best_weights_and_elite(best_weights, best_val_loss, mech=mech)
- if work_dir is not None:
- self.save(work_dir)
-
- def maybe_get_best_weights(
- self,
- best_val_loss: torch.Tensor,
- val_loss: torch.Tensor,
- mech: str = "transition",
- threshold: float = 0.01,
- ):
- improvement = (best_val_loss - val_loss) / torch.abs(best_val_loss)
- if (improvement > threshold).any().item():
- model = getattr(self, mech)
- best_weights = copy.deepcopy(model.state_dict())
- else:
- best_weights = None
-
- return best_weights
-
- def maybe_set_best_weights_and_elite(
- self,
- best_weights: Optional[Dict],
- best_val_loss: torch.Tensor,
- mech: str = "transition",
- ):
- model = getattr(self, mech)
- assert isinstance(model, EnsembleMLP)
-
- if best_weights is not None:
- model.load_state_dict(best_weights)
- sorted_indices = np.argsort(best_val_loss.tolist())
- elite_models = sorted_indices[: model.elite_num]
- model.set_elite_members(elite_models)
diff --git a/cmrl/models/fake_env.py b/cmrl/models/fake_env.py
index 68d195a..0e6aefe 100644
--- a/cmrl/models/fake_env.py
+++ b/cmrl/models/fake_env.py
@@ -1,124 +1,103 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import Any, Dict, List, Optional, Type
import gym
import numpy as np
import torch
-from gym.core import ActType, ObsType
-from stable_baselines3.common.vec_env.base_vec_env import (
- VecEnv,
- VecEnvIndices,
- VecEnvObs,
- VecEnvStepReturn,
-)
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices
+from stable_baselines3.common.logger import Logger
from stable_baselines3.common.buffers import ReplayBuffer
-import cmrl.types
-from cmrl.models.dynamics import BaseDynamics
+from cmrl.types import RewardFnType, TermFnType, InitObsFnType
+from cmrl.models.dynamics import Dynamics
+
+
+def get_penalty(ensemble_batch_next_obs):
+ avg = np.mean(ensemble_batch_next_obs, axis=0) # average predictions over models
+ diffs = ensemble_batch_next_obs - avg
+ dists = np.linalg.norm(diffs, axis=2) # distance in obs space
+ penalty = np.max(dists, axis=0) # max distances over models
+ return penalty
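+
+
+# Minimal worked example (illustrative only, not part of the API): for an
+# ensemble of E models predicting B next observations of dimension D,
+# `ensemble_batch_next_obs` has shape (E, B, D) and the penalty has shape (B,):
+#   preds = np.stack([np.zeros((4, 3)), np.ones((4, 3))])  # E=2, B=4, D=3
+#   get_penalty(preds)  # shape (4,), each entry sqrt(3) * 0.5 ~ 0.866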
class VecFakeEnv(VecEnv):
def __init__(
- self,
- num_envs: int,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
+ self,
+        # required by sb3's agent
+ num_envs: int,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ # for dynamics
+ dynamics: Dynamics,
+ reward_fn: Optional[RewardFnType] = None,
+ termination_fn: Optional[TermFnType] = None,
+ get_init_obs_fn: Optional[InitObsFnType] = None,
+ real_replay_buffer: Optional[ReplayBuffer] = None,
+ # for offline
+ penalty_coeff: float = 0.0,
+ # for behaviour
+ deterministic: bool = False,
+ max_episode_steps: int = 1000,
+ branch_rollout: bool = False,
+ # others
+ logger: Optional[Logger] = None,
+ **kwargs,
):
super(VecFakeEnv, self).__init__(
num_envs=num_envs,
observation_space=observation_space,
action_space=action_space,
)
-
- self.has_set_up = False
-
- self.penalty_coeff = None
- self.deterministic = None
- self.max_episode_steps = None
-
- self.dynamics = None
- self.reward_fn = None
- self.termination_fn = None
- self.learned_reward = None
- self.learned_termination = None
- self.get_init_obs_fn = None
- self.replay_buffer = None
- self.generator = np.random.default_rng()
- self.device = None
- self.logger = None
-
- self._current_batch_obs = None
- self._current_batch_action = None
-
- self._reset_by_buffer = True
-
- self._envs_length = np.zeros(self.num_envs, dtype=int)
-
- def set_up(
- self,
- dynamics: BaseDynamics,
- reward_fn: Optional[cmrl.types.RewardFnType] = None,
- termination_fn: Optional[cmrl.types.TermFnType] = None,
- get_init_obs_fn: Optional[cmrl.types.InitObsFnType] = None,
- real_replay_buffer: Optional[ReplayBuffer] = None,
- penalty_coeff: float = 0.0,
- deterministic=False,
- max_episode_steps=1000,
- logger=None,
- ):
self.dynamics = dynamics
+ self.reward_fn = reward_fn
+ self.termination_fn = termination_fn
+        assert self.dynamics.learn_reward or reward_fn, "you must learn a reward mechanism or provide one"
+        assert self.dynamics.learn_termination or termination_fn, "you must learn a termination mechanism or provide one"
+ self.learn_reward = self.dynamics.learn_reward
+ self.learn_termination = self.dynamics.learn_termination
+ self.get_init_obs_fn = get_init_obs_fn
+ self.replay_buffer = real_replay_buffer
self.penalty_coeff = penalty_coeff
self.deterministic = deterministic
self.max_episode_steps = max_episode_steps
+ self.branch_rollout = branch_rollout
+ if self.branch_rollout:
+ assert self.replay_buffer, "you must provide a replay buffer if using branch-rollout"
+ else:
+            assert self.get_init_obs_fn, "you must provide a get-init-obs function if using fully-virtual rollouts"
- self.reward_fn = reward_fn
- self.termination_fn = termination_fn
- assert self.dynamics.learned_reward or reward_fn
- assert self.dynamics.learned_termination or termination_fn
- self.learned_reward = self.dynamics.learned_reward
- self.learned_termination = self.dynamics.learned_termination
- self.get_init_obs_fn = get_init_obs_fn
- self.replay_buffer = real_replay_buffer
self.logger = logger
- assert self.get_init_obs_fn or self.replay_buffer
- self._reset_by_buffer = self.replay_buffer is not None
-
self.device = dynamics.device
- self.has_set_up = True
+
+ self._current_batch_obs = None
+ self._current_batch_action = None
+ self._envs_length = np.zeros(self.num_envs, dtype=int)
def step_async(self, actions: np.ndarray) -> None:
+ assert len(actions.shape) == 2 # batch, action_dim
self._current_batch_action = actions
def step_wait(self):
- assert self.has_set_up, "fake-env has not set up"
- assert len(self._current_batch_action.shape) == 2 # batch, action_dim
- with torch.no_grad():
- batch_obs_tensor = torch.from_numpy(self._current_batch_obs).to(torch.float32).to(self.device)
- batch_action_tensor = torch.from_numpy(self._current_batch_action).to(torch.float32).to(self.device)
- dynamics_pred = self.dynamics.query(batch_obs_tensor, batch_action_tensor, return_as_np=True)
-
- # transition
- batch_next_obs = self.get_dynamics_predict(dynamics_pred, "transition", deterministic=self.deterministic)
- if self.learned_reward:
- batch_reward = self.get_dynamics_predict(dynamics_pred, "reward_mech", deterministic=self.deterministic)
- else:
- batch_reward = self.reward_fn(batch_next_obs, self._current_batch_obs, self._current_batch_action)
- if self.learned_termination:
- batch_terminal = self.get_dynamics_predict(dynamics_pred, "termination_mech", deterministic=self.deterministic)
- else:
- batch_terminal = self.termination_fn(batch_next_obs, self._current_batch_obs, self._current_batch_action)
-
- if self.penalty_coeff != 0:
- penalty = self.get_penalty(dynamics_pred["batch_next_obs"]["mean"]).reshape(batch_reward.shape)
- batch_reward -= penalty * self.penalty_coeff
-
- if self.logger is not None:
- self.logger.record_mean("rollout/penalty", penalty.mean().item())
+ batch_next_obs, batch_reward, batch_terminal, info = self.dynamics.step(
+ self._current_batch_obs, self._current_batch_action
+ )
+
+ if not self.learn_reward:
+ batch_reward = self.reward_fn(batch_next_obs, self._current_batch_obs, self._current_batch_action)
+ if not self.learn_termination:
+ batch_terminal = self.termination_fn(batch_next_obs, self._current_batch_obs, self._current_batch_action)
+
+ if self.penalty_coeff != 0:
+ penalty = get_penalty(info["origin-next_obs"]).reshape(batch_reward.shape) * self.penalty_coeff
+ batch_reward -= penalty
+
+ if self.logger is not None:
+ self.logger.record_mean("rollout/penalty", penalty.mean().item())
+
+ assert not np.isnan(batch_next_obs).any(), "next obs of fake env should not be nan."
+ assert not np.isnan(batch_reward).any(), "reward of fake env should not be nan."
+ assert not np.isnan(batch_terminal).any(), "terminal of fake env should not be nan."
self._current_batch_obs = batch_next_obs.copy()
batch_reward = batch_reward.reshape(self.num_envs)
@@ -142,21 +121,23 @@ def step_wait(self):
)
def reset(
- self,
- *,
- seed: Optional[int] = None,
- return_info: bool = False,
- options: Optional[dict] = None,
+ self,
+ *,
+ seed: Optional[int] = None,
+ return_info: bool = False,
+ options: Optional[dict] = None,
):
- if self.has_set_up:
- if self._reset_by_buffer:
- upper_bound = self.replay_buffer.buffer_size if self.replay_buffer.full else self.replay_buffer.pos
- batch_inds = np.random.randint(0, upper_bound, size=self.num_envs)
- self._current_batch_obs = self.replay_buffer.observations[batch_inds, 0]
- else:
- self._current_batch_obs = self.get_init_obs_fn(self.num_envs)
- self._envs_length = np.zeros(self.num_envs, dtype=int)
+ if self.branch_rollout:
+ upper_bound = self.replay_buffer.buffer_size if self.replay_buffer.full else self.replay_buffer.pos
+ batch_inds = np.random.randint(0, upper_bound, size=self.num_envs)
+ self._current_batch_obs = self.replay_buffer.observations[batch_inds, 0]
+ else:
+ self._current_batch_obs = self.get_init_obs_fn(self.num_envs)
+ self._envs_length = np.zeros(self.num_envs, dtype=int)
+ if return_info:
+ return self._current_batch_obs.copy(), {}
+ else:
return self._current_batch_obs.copy()
def seed(self, seed: Optional[int] = None):
@@ -169,13 +150,11 @@ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndice
return [False for _ in range(self.num_envs)]
def single_reset(self, idx):
- assert self.has_set_up, "fake-env has not set up"
-
self._envs_length[idx] = 0
- if self._reset_by_buffer:
+ if self.branch_rollout:
upper_bound = self.replay_buffer.buffer_size if self.replay_buffer.full else self.replay_buffer.pos
- batch_inds = np.random.randint(0, upper_bound)
- self._current_batch_obs[idx] = self.replay_buffer.observations[batch_inds, 0]
+ batch_idxs = np.random.randint(0, upper_bound)
+ self._current_batch_obs[idx] = self.replay_buffer.observations[batch_idxs, 0]
else:
assert self.get_init_obs_fn is not None
self._current_batch_obs[idx] = self.get_init_obs_fn(1)
@@ -183,43 +162,12 @@ def single_reset(self, idx):
def render(self, mode="human"):
raise NotImplementedError
- @staticmethod
- def get_penalty(ensemble_batch_next_obs):
- avg = np.mean(ensemble_batch_next_obs, axis=0) # average predictions over models
- diffs = ensemble_batch_next_obs - avg
- dists = np.linalg.norm(diffs, axis=2) # distance in obs space
- penalty = np.max(dists, axis=0) # max distances over models
-
- return penalty
-
- def get_dynamics_predict(
- self,
- origin_predict: Dict,
- mech: str,
- deterministic: bool = False,
- ):
- variable = self.dynamics.get_variable_by_mech(mech)
- ensemble_mean, ensemble_logvar = (
- origin_predict[variable]["mean"],
- origin_predict[variable]["logvar"],
- )
- batch_size = ensemble_mean.shape[1]
- random_index = getattr(self.dynamics, mech).get_random_index(batch_size, self.generator)
- if deterministic:
- pred = ensemble_mean[random_index, np.arange(batch_size)]
- else:
- ensemble_std = np.sqrt(np.exp(ensemble_logvar))
- pred = ensemble_mean[random_index, np.arange(batch_size)] + ensemble_std[
- random_index, np.arange(batch_size)
- ] * self.generator.normal(size=ensemble_mean.shape[1:]).astype(np.float32)
- return pred
-
def env_method(
- self,
- method_name: str,
- *method_args,
- indices: VecEnvIndices = None,
- **method_kwargs,
+ self,
+ method_name: str,
+ *method_args,
+ indices: VecEnvIndices = None,
+ **method_kwargs,
) -> List[Any]:
pass
diff --git a/cmrl/algorithms/offline/__init__.py b/cmrl/models/graphs/__init__.py
similarity index 100%
rename from cmrl/algorithms/offline/__init__.py
rename to cmrl/models/graphs/__init__.py
diff --git a/cmrl/models/graphs/base_graph.py b/cmrl/models/graphs/base_graph.py
index df531c8..a75eec2 100644
--- a/cmrl/models/graphs/base_graph.py
+++ b/cmrl/models/graphs/base_graph.py
@@ -1,123 +1,73 @@
import abc
import pathlib
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
-import torch.nn as nn
-class BaseGraph(nn.Module, abc.ABC):
+class BaseGraph(abc.ABC):
"""Base abstract class for all graph models.
All classes derived from `BaseGraph` must implement the following methods:
- - ``forward``: computes the graph (parameters).
- - ``update``: updates the structural parameters.
- - ``get_binary_graph``: gets the binary graph.
+ - ``parameters``: the graph parameters property.
+ - ``get_adj_matrix``: get the (raw) adjacency matrix.
+ - ``get_binary_adj_matrix``: get the binary format of the adjacency matrix.
+    - ``save``: save the graph data.
+    - ``load``: load the graph data.
Args:
in_dim (int): input dimension.
out_dim (int): output dimension.
- device (str or torch.device): device to use for the structural parameters.
+ extra_dim (int | tuple(int) | None): extra dimensions (multi-graph).
+        include_input (bool): whether to include input variables in the output variables.
"""
- _GRAPH_FNAME = "graph.pth"
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ extra_dim: Optional[Union[int, Tuple[int]]] = None,
+ include_input: bool = False,
+ *args,
+ **kwargs
+ ) -> None:
+ self._in_dim = in_dim
+ self._out_dim = out_dim
+ self._extra_dim = extra_dim
+ self._include_input = include_input
- def __init__(self, in_dim: int, out_dim: int, device: Union[str, torch.device] = "cpu", *args, **kwargs):
- super().__init__()
- self.in_dim = in_dim
- self.out_dim = out_dim
- self.device = device
+        assert not (include_input and out_dim < in_dim), "if input variables are included, the output dimension must be >= the input dimension"
+ @property
@abc.abstractmethod
- def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, ...]:
- """Computes the graph parameters.
+ def parameters(self) -> Tuple[torch.Tensor]:
+ """Get the graph parameters (raw graph).
- Returns:
- (tuple of tensors): all tensors representing the output
- graph (e.g. existence and orientation)
+        Returns: (tuple of tensors) the raw graph parameters
"""
@abc.abstractmethod
- def get_binary_graph(self, *args, **kwargs) -> torch.Tensor:
- """Gets the binary graph.
+ def get_adj_matrix(self, *args, **kwargs) -> torch.Tensor:
+ """Get the raw adjacency matrix.
Returns:
- (tensor): the binary graph tensor, shape [in_dim, out_dim];
- graph[i, j] == 1 represents i causes j
+ (tensor): the raw adjacency matrix tensor, shape [in_dim, out_dim];
"""
- def get_mask(self, *args, **kwargs) -> torch.Tensor:
- # [..., in_dim, out_dim]
- binary_mat = self.get_binary_graph(*args, **kwargs)
- # [..., out_dim, in_dim], mask apply on the input for each output variable
- return binary_mat.transpose(-1, -2)
-
- def save(self, save_dir: Union[str, pathlib.Path]):
- """Saves the model to the given directory."""
- torch.save(self.state_dict(), pathlib.Path(save_dir) / self._GRAPH_FNAME)
-
- def load(self, load_dir: Union[str, pathlib.Path]):
- """Loads the model from the given path."""
- self.load_state_dict(torch.load(pathlib.Path(load_dir) / self._GRAPH_FNAME, map_location=self.device))
-
-
-class BaseEnsembleGraph(BaseGraph, abc.ABC):
- """Base abstract class for all ensemble of bootstrapped 1-D graph models.
-
- Valid propagation options are:
-
- - "random_model": for each output in the batch a model will be chosen at random.
- - "fixed_model": for output j-th in the batch, the model will be chosen according to
- the model index in `propagation_indices[j]`.
- - "expectation": the output for each element in the batch will be the mean across
- models.
- - "majority": the output for each element in the batch will be determined by the
- majority voting with the models (only for binary edge).
-
- The default value of ``None`` indicates that no uncertainty propagation, and the forward
- method returns all outpus of all models.
-
- Args:
- num_members (int): number of models in the ensemble.
- in_dim (int): input dimension.
- out_dim (int): output dimension.
- device (str or torch.device): device to use for the model.
- propagation_method (str, optional): the uncertainty method to use. Defaults to ``None``.
- """
-
- def __init__(
- self,
- num_members: int,
- in_dim: int,
- out_dim: int,
- device: Union[str, torch.device],
- propagation_method: str,
- *args,
- **kwargs
- ):
- super().__init__(in_dim, out_dim, device, *args, **kwargs)
- self.num_members = num_members
- self.propagation_method = propagation_method
- self.device = torch.device(device)
-
- def __len__(self):
- return self.num_members
-
- def set_elite(self, elite_grpahs: Sequence[int]):
- """For ensemble graphs, indicates if some graphs should be considered elite."""
- pass
-
@abc.abstractmethod
- def sample_propagation_indices(self, batch_size: int, rng: torch.Generator) -> torch.Tensor:
- """Samples uncertainty propagation indices.
+ def get_binary_adj_matrix(self, *args, **kwargs) -> torch.Tensor:
+ """Get the binary adjacency matrix.
- Args:
- batch_size (int): the desired batch size.
- rng (`torch.Generator`): a random number generator to use for sampling.
Returns:
- (tensor) with ``batch_size`` integers from [0, ``self.num_members``).
+ (tensor): the binary adjacency matrix tensor, shape [in_dim, out_dim];
+ graph[i, j] == 1 represents i causes j
"""
- def set_propagation_method(self, propagation_method: Optional[str] = None):
- self.propagation_method = propagation_method
+ @abc.abstractmethod
+ def save(self, save_dir: Union[str, pathlib.Path]):
+ """Save the model to the given directory."""
+
+ @abc.abstractmethod
+ def load(self, load_dir: Union[str, pathlib.Path]):
+ """Load the model from the given path."""
diff --git a/cmrl/models/graphs/binary_graph.py b/cmrl/models/graphs/binary_graph.py
new file mode 100644
index 0000000..bcdce56
--- /dev/null
+++ b/cmrl/models/graphs/binary_graph.py
@@ -0,0 +1,81 @@
+import copy
+import pathlib
+from typing import Optional, Union, Tuple
+
+import torch
+import numpy as np
+
+from cmrl.models.graphs.base_graph import BaseGraph
+
+
+class BinaryGraph(BaseGraph):
+ """Binary graph models (binary graph data)
+
+ Args:
+ in_dim (int): input dimension.
+ out_dim (int): output dimension.
+ extra_dim (int | tuple(int) | None): extra dimensions (multi-graph).
+        include_input (bool): whether to include input variables in the output variables.
+ init_param (int | Tensor | ndarray): initial parameter of the binary graph
+ device (str or torch.device): device to use for the graph parameters.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ extra_dim: Optional[Union[int, Tuple[int]]] = None,
+ include_input: bool = False,
+ init_param: Union[int, torch.Tensor, np.ndarray] = 1,
+ device: Union[str, torch.device] = "cpu",
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(in_dim, out_dim, extra_dim, include_input, *args, **kwargs)
+
+ graph_size = (in_dim, out_dim)
+ if extra_dim is not None:
+ if isinstance(extra_dim, int):
+ extra_dim = (extra_dim,)
+ graph_size = extra_dim + graph_size
+
+ if isinstance(init_param, int):
+ self.graph = torch.ones(graph_size, dtype=torch.int, device=device) * int(bool(init_param))
+ else:
+ assert (
+ init_param.shape == graph_size
+ ), f"initial parameters shape mismatch (given {init_param.shape}, while {graph_size} required)"
+ self.graph = torch.as_tensor(init_param, dtype=torch.bool, device=device).int()
+
+ # remove self loop
+ if self._include_input:
+ self.graph[..., torch.arange(self._in_dim), torch.arange(self._in_dim)] = 0
+
+ self.device = device
+
+ @property
+ def parameters(self) -> Tuple[torch.Tensor]:
+ return (self.graph,)
+
+ def get_adj_matrix(self, *args, **kwargs) -> torch.Tensor:
+ return self.graph
+
+ def get_binary_adj_matrix(self, *args, **kwargs) -> torch.Tensor:
+ return self.get_adj_matrix()
+
+ def set_data(self, graph_data: Union[torch.Tensor, np.ndarray]):
+ assert (
+ self.graph.shape == graph_data.shape
+ ), f"graph data shape mismatch (given {graph_data.shape}, while {self.graph.shape} required)"
+ self.graph.data = torch.as_tensor(graph_data, dtype=torch.bool, device=self.device).int()
+
+ # remove self loop
+ if self._include_input:
+ self.graph[..., torch.arange(self._in_dim), torch.arange(self._in_dim)] = 0
+
+ def save(self, save_dir: Union[str, pathlib.Path]):
+ torch.save({"graph_data": self.graph}, pathlib.Path(save_dir) / "graph.pth")
+
+ def load(self, load_dir: Union[str, pathlib.Path]):
+ data_dict = torch.load(pathlib.Path(load_dir) / "graph.pth", map_location=self.device)
+ self.graph = data_dict["graph_data"]
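+
+
+# Minimal usage sketch (illustrative only):
+#   g = BinaryGraph(in_dim=3, out_dim=3, include_input=True)
+#   g.get_binary_adj_matrix()  # all ones except the zeroed diagonal (no self loops)
+#   g.set_data(torch.zeros(3, 3))
+#   g.get_adj_matrix()  # now all zeros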
diff --git a/cmrl/models/graphs/neural_graph.py b/cmrl/models/graphs/neural_graph.py
new file mode 100644
index 0000000..84935cc
--- /dev/null
+++ b/cmrl/models/graphs/neural_graph.py
@@ -0,0 +1,186 @@
+import pathlib
+from typing import Optional, Union, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf import DictConfig
+from hydra.utils import instantiate
+
+from cmrl.models.graphs.base_graph import BaseGraph
+from cmrl.models.graphs.prob_graph import BaseProbGraph
+
+default_network_cfg = DictConfig(
+ dict(
+ _target_="cmrl.models.networks.ParallelMLP",
+ _partial_=True,
+ _recursive_=False,
+ hidden_dims=[200, 200],
+ bias=True,
+ activation_fn_cfg=dict(_target_="torch.nn.ReLU"),
+ )
+)
+
+
+class NeuralGraph(BaseGraph):
+
+ _MASK_VALUE = 0
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ extra_dim: Optional[Union[int, Tuple[int]]] = None,
+ include_input: bool = False,
+ network_cfg: Optional[DictConfig] = default_network_cfg,
+ device: Union[str, torch.device] = "cpu",
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__(in_dim=in_dim, out_dim=out_dim, extra_dim=extra_dim, include_input=include_input, *args, **kwargs)
+
+ self._network_cfg = network_cfg
+ self.device = device
+
+ self._build_graph_network()
+
+ def _build_graph_network(self):
+ """called at the last of ``NeuralGraph.__init__``"""
+ network_extra_dims = self._extra_dim
+ if isinstance(network_extra_dims, int):
+ network_extra_dims = [network_extra_dims]
+
+ self.graph = instantiate(self._network_cfg)(
+ input_dim=self._in_dim,
+ output_dim=self._in_dim * self._out_dim,
+ extra_dims=network_extra_dims,
+ ).to(self.device)
+
+ @property
+ def parameters(self) -> Tuple[torch.Tensor]:
+ return tuple(self.graph.parameters())
+
+ def get_adj_matrix(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ adj_mat = self.graph(inputs)
+ adj_mat = adj_mat.reshape(*adj_mat.shape[:-1], self._in_dim, self._out_dim)
+
+ if self._include_input:
+ adj_mat[..., torch.arange(self._in_dim), torch.arange(self._in_dim)] = self._MASK_VALUE
+
+ return adj_mat
+
+ def get_binary_adj_matrix(self, inputs: torch.Tensor, threshold: float, *args, **kwargs) -> torch.Tensor:
+ return (self.get_adj_matrix(inputs) > threshold).int()
+
+ def save(self, save_dir: Union[str, pathlib.Path]):
+ torch.save({"graph_network": self.graph.state_dict()}, pathlib.Path(save_dir) / "graph.pth")
+
+ def load(self, load_dir: Union[str, pathlib.Path]):
+ data_dict = torch.load(pathlib.Path(load_dir) / "graph.pth", map_location=self.device)
+ self.graph.load_state_dict(data_dict["graph_network"])
+
+
+class NeuralBernoulliGraph(NeuralGraph, BaseProbGraph):
+
+ _MASK_VALUE = -9e15
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ extra_dim: Optional[Union[int, Tuple[int]]] = None,
+ include_input: bool = False,
+ network_cfg: Optional[DictConfig] = default_network_cfg,
+ device: Union[str, torch.device] = "cpu",
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ extra_dim=extra_dim,
+ include_input=include_input,
+ network_cfg=network_cfg,
+ device=device,
+ *args,
+ **kwargs
+ )
+
+ def _build_graph_network(self):
+ super()._build_graph_network()
+
+ def init_weights_zero(layer):
+ for pname, params in layer.named_parameters():
+ if "weight" in pname:
+ nn.init.zeros_(params)
+
+ self.graph.apply(init_weights_zero)
+
+ def get_adj_matrix(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return torch.sigmoid(super().get_adj_matrix(inputs, *args, **kwargs))
+
+ def get_binary_adj_matrix(self, inputs: torch.Tensor, threshold: float = 0.5, *args, **kwargs) -> torch.Tensor:
+ """return the binary adjacency matrices corresponding to the inputs (w/o grad.)"""
+ return super().get_binary_adj_matrix(inputs, threshold, *args, **kwargs)
+
+ def sample(
+ self,
+ prob_matrix: Optional[torch.Tensor],
+ sample_size: Union[Tuple[int], int],
+ reparameterization: Optional[str] = None,
+ *args,
+ **kwargs
+ ) -> torch.Tensor:
+ """sample from given or current graph probability (Bernoulli distribution).
+
+ Args:
+            prob_matrix (tensor), graph probabilities; cannot be None here.
+ sample_size (tuple(int) or int), extra size of sampled graphs.
+
+ Return:
+ (tensor): [*sample_size, *extra_dim, in_dim, out_dim] shaped multiple graphs.
+ """
+ if prob_matrix is None:
+ raise ValueError("Porb. matrix can not be empty")
+
+ if isinstance(sample_size, int):
+ sample_size = (sample_size,)
+
+        sample_prob = prob_matrix[None].expand(*sample_size, *((-1,) * len(prob_matrix.shape)))
+
+ if reparameterization is None:
+ return torch.bernoulli(sample_prob)
+ elif reparameterization == "gumbel-softmax":
+ return F.gumbel_softmax(torch.stack((sample_prob, 1 - sample_prob)), hard=True, dim=0)[0]
+ else:
+ raise NotImplementedError
+
+ def sample_from_inputs(
+ self,
+ inputs: torch.Tensor,
+ sample_size: Union[Tuple[int], int],
+ reparameterization: Optional[str] = "gumbel-softmax",
+ *args,
+ **kwargs
+ ) -> torch.Tensor:
+ """sample adjacency matrix from inputs (genereated Bernoulli distribution given the inputs).
+
+ Args:
+ inputs (tensor), input samples.
+ sample_size (tuple(int) or int), extra size of sampled graphs.
+
+ Return:
+ (tensor): [*sample_size, *extra_dim, in_dim, out_dim] shaped multiple graphs.
+ """
+ if isinstance(sample_size, int):
+ sample_size = (sample_size,)
+
+ inputs = inputs[None].expand(*sample_size, *((-1,) * len(inputs.shape)))
+ sample_prob = self.get_adj_matrix(inputs)
+
+ if reparameterization is None:
+ return torch.bernoulli(sample_prob)
+ elif reparameterization == "gumbel-softmax":
+ return F.gumbel_softmax(torch.stack((sample_prob, 1 - sample_prob)), hard=True, dim=0)[0]
+ else:
+ raise NotImplementedError
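+
+
+# Minimal usage sketch (illustrative, assuming the default network config):
+# weights start at zero, so off-diagonal probabilities begin at sigmoid(0) = 0.5.
+#   g = NeuralBernoulliGraph(in_dim=4, out_dim=4, include_input=True)
+#   x = torch.randn(10, 4)
+#   probs = g.get_adj_matrix(x)  # (10, 4, 4), in (0, 1), diagonal masked to ~0
+#   graphs = g.sample_from_inputs(x, sample_size=8)  # (8, 10, 4, 4), binary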
diff --git a/cmrl/models/graphs/prob_graph.py b/cmrl/models/graphs/prob_graph.py
index ebf7517..6bfc65e 100644
--- a/cmrl/models/graphs/prob_graph.py
+++ b/cmrl/models/graphs/prob_graph.py
@@ -1,11 +1,12 @@
from abc import abstractmethod
-import math
from typing import Union, Tuple, Optional
import torch
-import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
from cmrl.models.graphs.base_graph import BaseGraph
+from cmrl.models.graphs.weight_graph import WeightGraph
class BaseProbGraph(BaseGraph):
@@ -13,12 +14,14 @@ class BaseProbGraph(BaseGraph):
All classes derived from `BaseProbGraph` must implement the following additional methods:
- - ``sample``: sample graphs from given (or current) graph probability.
+ - ``sample``: sample graphs from current (or given) graph probability.
"""
@abstractmethod
- def sample(self, graph: Optional[torch.Tensor], sample_size: Union[Tuple[int], int], *args, **kwargs) -> torch.Tensor:
- """sample from given or current graph probability.
+ def sample(
+ self, prob_matrix: Optional[torch.Tensor], sample_size: Union[Tuple[int], int], *args, **kwargs
+ ) -> torch.Tensor:
+ """sample from given or current probability adjacency matrix.
Args:
graph (tensor), graph probability, use current graph parameter when given `None`.
@@ -30,57 +33,62 @@ def sample(self, graph: Optional[torch.Tensor], sample_size: Union[Tuple[int], i
pass
-class BernoulliGraph(BaseProbGraph):
- """Probability (Bernoulli dist.) modeled graphs, store the graph with the
- probability parameter of the existence/orientation of edges.
+class BernoulliGraph(WeightGraph, BaseProbGraph):
+ """Probability (Bernoulli dist.) graph models, store the graph with the
+ probability parameter of the existence of edges.
Args:
in_dim (int): input dimension.
out_dim (int): output dimension.
- init_param (float or torch.Tensor): initial parameter of the graph
- (sigmoid(init_param) representing the initial edge probabilities).
- device (str or torch.device): device to use for the structural parameters.
+ extra_dim (int | tuple(int) | None): extra dimensions (multi-graph).
+        include_input (bool): whether to include input variables in the output variables.
+        init_param (float | Tensor | ndarray): initial parameter of the Bernoulli graph.
+ requires_grad (bool): whether the graph parameters require gradient computation.
+ device (str or torch.device): device to use for the graph parameters.
"""
+ _MASK_VALUE = -9e15
+
def __init__(
self,
in_dim: int,
out_dim: int,
- init_param: Union[float, torch.Tensor] = 1e-6,
+ extra_dim: Optional[Union[int, Tuple[int]]] = None,
+ include_input: bool = False,
+ init_param: Union[float, torch.Tensor, np.ndarray] = 1e-6,
+ requires_grad: bool = False,
device: Union[str, torch.device] = "cpu",
*args,
**kwargs
- ):
- super().__init__(in_dim, out_dim, device, *args, **kwargs)
-
- if isinstance(init_param, float):
- init_param = torch.ones(in_dim, out_dim) * init_param
- self.graph = nn.Parameter(init_param, requires_grad=True)
-
- self.to(device)
-
- def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, ...]:
- """Computes the graph parameters.
-
- Returns:
- (tuple of tensors): all tensors representing the output
- graph (e.g. existence and orientation)
- """
+ ) -> None:
+ super().__init__(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ extra_dim=extra_dim,
+ include_input=include_input,
+ init_param=init_param,
+ requires_grad=requires_grad,
+ device=device,
+ *args,
+ **kwargs
+ )
+
+ def get_adj_matrix(self, *args, **kwargs) -> torch.Tensor:
return torch.sigmoid(self.graph)
- def get_binary_graph(self, thresh: float = 0.5) -> torch.Tensor:
- """Gets the binary graph.
+ def get_binary_adj_matrix(self, threshold: float = 0.5, *args, **kwargs) -> torch.Tensor:
+ assert 0 <= threshold <= 1, "threshold of bernoulli graph should be in [0, 1]"
- Returns:
- (tensor): the binary graph tensor, shape [in_dim, out_dim];
- graph[i, j] == 1 represents i causes j
- """
- assert 0 <= thresh <= 1
-
- prob_graph = self()
- return prob_graph > thresh
+ return super().get_binary_adj_matrix(threshold, *args, **kwargs)
- def sample(self, graph: Optional[torch.Tensor], sample_size: Union[Tuple[int], int], *args, **kwargs):
+ def sample(
+ self,
+ prob_matrix: Optional[torch.Tensor],
+ sample_size: Union[Tuple[int], int],
+ reparameterization: Optional[str] = None,
+ *args,
+ **kwargs
+ ):
"""sample from given or current graph probability (Bernoulli distribution).
Args:
@@ -88,14 +96,19 @@ def sample(self, graph: Optional[torch.Tensor], sample_size: Union[Tuple[int], i
sample_size (tuple(int) or int), extra size of sampled graphs.
Return:
- (tensor): [*sample_size, in_dim, out_dim] shaped multiple graphs.
+ (tensor): [*sample_size, *extra_dim, in_dim, out_dim] shaped multiple graphs.
"""
- if graph is None:
- graph = self()
+ if prob_matrix is None:
+ prob_matrix = self.get_adj_matrix()
if isinstance(sample_size, int):
sample_size = (sample_size,)
- sample_prob = graph[None].expand(*sample_size, -1, -1)
+ sample_prob = prob_matrix[None].expand(*sample_size, *((-1,) * len(prob_matrix.shape)))
- return torch.bernoulli(sample_prob)
+ if reparameterization is None:
+ return torch.bernoulli(sample_prob)
+ elif reparameterization == "gumbel-softmax":
+ return F.gumbel_softmax(torch.stack((sample_prob, 1 - sample_prob)), hard=True, dim=0)[0]
+ else:
+ raise NotImplementedError
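+
+
+# Minimal usage sketch (illustrative only):
+#   g = BernoulliGraph(in_dim=3, out_dim=3, init_param=0.0)
+#   g.get_adj_matrix()  # all entries sigmoid(0.0) = 0.5
+#   g.sample(None, sample_size=5)  # (5, 3, 3) binary Bernoulli draws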
diff --git a/cmrl/models/graphs/weight_graph.py b/cmrl/models/graphs/weight_graph.py
new file mode 100644
index 0000000..a752551
--- /dev/null
+++ b/cmrl/models/graphs/weight_graph.py
@@ -0,0 +1,94 @@
+import pathlib
+from typing import Optional, Union, Tuple
+
+import torch
+import numpy as np
+
+from cmrl.models.graphs.base_graph import BaseGraph
+
+
+class WeightGraph(BaseGraph):
+ """Weight graph models (real graph data)
+
+ Args:
+ in_dim (int): input dimension.
+ out_dim (int): output dimension.
+ extra_dim (int | tuple(int) | None): extra dimensions (multi-graph).
+        include_input (bool): whether to include input variables in the output variables.
+        init_param (float | Tensor | ndarray): initial parameter of the weight graph.
+ requires_grad (bool): whether the graph parameters require gradient computation.
+ device (str or torch.device): device to use for the graph parameters.
+ """
+
+ _MASK_VALUE = 0
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ extra_dim: Optional[Union[int, Tuple[int]]] = None,
+ include_input: bool = False,
+ init_param: Union[float, torch.Tensor, np.ndarray] = 1.0,
+ requires_grad: bool = False,
+ device: Union[str, torch.device] = "cpu",
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(in_dim, out_dim, extra_dim, include_input, *args, **kwargs)
+ self._requires_grad = requires_grad
+
+ graph_size = (in_dim, out_dim)
+ if extra_dim is not None:
+ if isinstance(extra_dim, int):
+ extra_dim = (extra_dim,)
+ graph_size = extra_dim + graph_size
+
+ if isinstance(init_param, float):
+ self.graph = torch.ones(graph_size, dtype=torch.float32, device=device) * init_param
+ else:
+ assert (
+ init_param.shape == graph_size
+ ), f"initial parameters shape mismatch (given {init_param.shape}, while {graph_size} required)"
+ self.graph = torch.as_tensor(init_param, dtype=torch.float32, device=device)
+
+ if requires_grad:
+ self.graph.requires_grad_()
+
+ # remove self loop
+ if self._include_input:
+ with torch.no_grad():
+ self.graph[..., torch.arange(self._in_dim), torch.arange(self._in_dim)] = self._MASK_VALUE
+
+ self.device = device
+
+ @property
+ def parameters(self) -> Tuple[torch.Tensor]:
+ return (self.graph,)
+
+ @property
+    def requires_grad(self) -> bool:
+ return self._requires_grad
+
+ def get_adj_matrix(self, *args, **kwargs) -> torch.Tensor:
+ return self.graph
+
+ def get_binary_adj_matrix(self, threshold: float, *args, **kwargs) -> torch.Tensor:
+ return (self.get_adj_matrix() > threshold).int()
+
+ @torch.no_grad()
+ def set_data(self, graph_data: Union[torch.Tensor, np.ndarray]):
+ assert (
+ self.graph.shape == graph_data.shape
+ ), f"graph data shape mismatch (given {graph_data.shape}, while {self.graph.shape} required)"
+ self.graph.data = torch.as_tensor(graph_data, dtype=torch.float32, device=self.device)
+
+ # remove self loop
+ if self._include_input:
+ self.graph[..., torch.arange(self._in_dim), torch.arange(self._in_dim)] = self._MASK_VALUE
+
+ def save(self, save_dir: Union[str, pathlib.Path]):
+ torch.save({"graph_data": self.graph}, pathlib.Path(save_dir) / "graph.pth")
+
+ def load(self, load_dir: Union[str, pathlib.Path]):
+ data_dict = torch.load(pathlib.Path(load_dir) / "graph.pth", map_location=self.device)
+ self.graph = data_dict["graph_data"]
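+
+
+# Minimal usage sketch (illustrative only):
+#   g = WeightGraph(in_dim=2, out_dim=3, init_param=0.7)
+#   g.get_adj_matrix()  # (2, 3) tensor filled with 0.7
+#   g.get_binary_adj_matrix(threshold=0.5)  # (2, 3) tensor of ones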
diff --git a/cmrl/models/layers.py b/cmrl/models/layers.py
index 5222dc6..b92dca7 100644
--- a/cmrl/models/layers.py
+++ b/cmrl/models/layers.py
@@ -1,117 +1,92 @@
+from typing import Optional, List
+
import numpy as np
import torch
from torch import nn as nn
+from torch import Tensor
+from itertools import product
+
+from cmrl.models.util import truncated_normal_
-import cmrl.models.util as model_util
-
-
-def truncated_normal_init(m: nn.Module):
- """Initializes the weights of the given module using a truncated normal distribution."""
-
- if isinstance(m, nn.Linear):
- input_dim = m.weight.data.shape[0]
- stddev = 1 / (2 * np.sqrt(input_dim))
- model_util.truncated_normal_(m.weight.data, std=stddev)
- m.bias.data.fill_(0.0)
- elif isinstance(m, EnsembleLinearLayer):
- num_members, input_dim, _ = m.weight.data.shape
- stddev = 1 / (2 * np.sqrt(input_dim))
- for i in range(num_members):
- model_util.truncated_normal_(m.weight.data[i], std=stddev)
- m.bias.data.fill_(0.0)
- elif isinstance(m, ParallelEnsembleLinearLayer):
- num_parallel, num_members, input_dim, _ = m.weight.data.shape
- stddev = 1 / (2 * np.sqrt(input_dim))
- for i in range(num_parallel):
- for j in range(num_members):
- model_util.truncated_normal_(m.weight.data[i, j], std=stddev)
- m.bias.data.fill_(0.0)
-
-
-class EnsembleLinearLayer(nn.Module):
- """Implements an ensemble of layers.
-
- Args:
- in_size (int): the input size of this layer.
- out_size (int): the output size of this layer.
- use_bias (bool): use bias in this layer or not.
- ensemble_num (int): the ensemble dimension of this layer,
- the corresponding part of each dimension is called a "member".
- """
+# partially adapted from https://github.com/phlippe/ENCO/blob/main/causal_discovery/multivariable_mlp.py
+class ParallelLinear(nn.Module):
def __init__(
self,
- in_size: int,
- out_size: int,
- use_bias: bool = True,
- ensemble_num: int = 1,
+ input_dim: int,
+ output_dim: int,
+ extra_dims: Optional[List[int]] = None,
+ bias: bool = True,
+ init_type: str = "truncated_normal",
):
+ """Linear layer with the same properties as Parallel MLP. It effectively applies N independent linear layers
+ in parallel.
+
+ Args:
+ input_dim: Number of input dimensions per layer.
+ output_dim: Number of output dimensions per layer.
+ extra_dims: Number of neural networks to have in parallel (e.g. number of variables). Can have multiple
+ dimensions if needed.
+            bias: Whether to use a bias in this layer.
+ init_type: How to initialize weights and biases.
+ """
super().__init__()
- self.ensemble_num = ensemble_num
- self.in_size = in_size
- self.out_size = out_size
- self.weight = nn.Parameter(torch.rand(self.ensemble_num, self.in_size, self.out_size))
- if use_bias:
- self.bias = nn.Parameter(torch.rand(self.ensemble_num, 1, self.out_size))
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.extra_dims = [] if extra_dims is None else extra_dims
+ self.init_type = init_type
+
+ self.weight = nn.Parameter(torch.zeros(*self.extra_dims, self.input_dim, self.output_dim))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(*self.extra_dims, 1, self.output_dim))
self.use_bias = True
else:
self.use_bias = False
- def forward(self, x):
- xw = x.matmul(self.weight)
- if self.use_bias:
- return xw + self.bias
- else:
- return xw
-
- def __repr__(self) -> str:
- return (
- f"in_size={self.in_size}, out_size={self.out_size}, use_bias={self.use_bias}, " f"ensemble_num={self.ensemble_num}"
- )
+ self.init_params()
+ def init_params(self):
+ """Initialize weights and biases. Currently, only `kaiming_uniform` and `truncated_normal` are supported.
-class ParallelEnsembleLinearLayer(nn.Module):
- """Implements an ensemble of parallel layers.
+ Returns: None
- Args:
- in_size (int): the input size of this layer.
- out_size (int): the output size of this layer.
- use_bias (bool): use bias in this layer or not.
- parallel_num (int): the parallel dimension of this layer,
- the corresponding part of each dimension is called a "sub-network".
- ensemble_num (int): the ensemble dimension of this layer,
- the corresponding part of each dimension is called a "member".
- """
-
- def __init__(
- self,
- in_size: int,
- out_size: int,
- use_bias: bool = True,
- parallel_num: int = 1,
- ensemble_num: int = 1,
- ):
- super().__init__()
- self.parallel_num = parallel_num
- self.ensemble_num = ensemble_num
- self.in_size = in_size
- self.out_size = out_size
- self.weight = nn.Parameter(torch.rand(self.parallel_num, self.ensemble_num, self.in_size, self.out_size))
- if use_bias:
- self.bias = nn.Parameter(torch.rand(self.parallel_num, self.ensemble_num, 1, self.out_size))
- self.use_bias = True
+ """
+ if self.init_type == "kaiming_uniform":
+ nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
+ elif self.init_type == "truncated_normal":
+ stddev = 1 / (2 * np.sqrt(self.input_dim))
+ for dims in product(*map(range, self.extra_dims)):
+ truncated_normal_(self.weight.data[dims], std=stddev)
else:
- self.use_bias = False
+ raise NotImplementedError
- def forward(self, x):
+ def forward(self, x: Tensor) -> Tensor:
xw = x.matmul(self.weight)
if self.use_bias:
return xw + self.bias
else:
return xw
- def __repr__(self) -> str:
- return (
- f"in_size={self.in_size}, out_size={self.out_size}, use_bias={self.use_bias}, "
- f"parallel_num={self.parallel_num}, ensemble_num={self.ensemble_num}"
+ @property
+ def device(self) -> torch.device:
+ """Infer which device this policy lives on by inspecting its parameters.
+ If it has no parameters, the 'cpu' device is used as a fallback.
+
+ Returns: device
+ """
+ for param in self.parameters():
+ return param.device
+ return torch.device("cpu")
+
+ def extra_repr(self):
+ return 'input_dims={}, output_dims={}, extra_dims={}, bias={}, init_type="{}"'.format(
+ self.input_dim, self.output_dim, str(self.extra_dims), self.use_bias, self.init_type
)
+
+
+class RadianLayer(nn.Module):
+ def __init__(self) -> None:
+ super(RadianLayer, self).__init__()
+
+ def forward(self, input: Tensor) -> Tensor:
+ return torch.remainder(input + torch.pi, 2 * torch.pi) - torch.pi
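+
+
+# Minimal usage sketch (illustrative only):
+#   layer = ParallelLinear(input_dim=5, output_dim=2, extra_dims=[3])
+#   layer(torch.randn(3, 10, 5)).shape  # -> (3, 10, 2): 3 independent linear layers
+#   RadianLayer()(torch.tensor([1.5 * torch.pi]))  # -> tensor([-1.5708]), wrapped into [-pi, pi)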
diff --git a/cmrl/models/networks/__init__.py b/cmrl/models/networks/__init__.py
new file mode 100644
index 0000000..8563cbc
--- /dev/null
+++ b/cmrl/models/networks/__init__.py
@@ -0,0 +1,2 @@
+from cmrl.models.networks.coder import VariableEncoder, VariableDecoder
+from cmrl.models.networks.parallel_mlp import ParallelMLP
diff --git a/cmrl/models/networks/base_network.py b/cmrl/models/networks/base_network.py
new file mode 100644
index 0000000..ab32a5e
--- /dev/null
+++ b/cmrl/models/networks/base_network.py
@@ -0,0 +1,71 @@
+import pathlib
+from typing import List, Optional, Sequence, Union
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import hydra
+from omegaconf import DictConfig
+
+
+class BaseNetwork(nn.Module):
+ def __init__(self, **kwargs):
+ """Base class of all neural network.
+
+ Args:
+ network_cfg:
+ """
+ super(BaseNetwork, self).__init__()
+
+ self._model_filename = "base_network.pth"
+ self._save_attrs: List[str] = ["state_dict"]
+ self._layers: Optional[nn.ModuleList] = None
+
+ self.build()
+
+ def save(self, save_dir: Union[str, pathlib.Path]):
+ """Saves the model to the given directory."""
+ model_dict = {}
+ for attr in self._save_attrs:
+ if attr == "state_dict":
+ model_dict["state_dict"] = self.state_dict()
+ else:
+ model_dict[attr] = getattr(self, attr)
+ torch.save(model_dict, pathlib.Path(save_dir) / self._model_filename)
+
+ def load(self, load_dir: Union[str, pathlib.Path]):
+ """Loads the model from the given path."""
+ model_dict = torch.load(pathlib.Path(load_dir) / self._model_filename, map_location=self.device)
+ for attr in model_dict:
+ if attr == "state_dict":
+ self.load_state_dict(model_dict["state_dict"])
+ else:
+ getattr(self, attr)(model_dict[attr])
+
+ def forward(self, x) -> torch.Tensor:
+ for layer in self._layers:
+ x = layer(x)
+ return x
+
+ @abstractmethod
+ def build(self):
+ raise NotImplementedError
+
+ @property
+ def save_attrs(self):
+ return self._save_attrs
+
+ @property
+ def model_filename(self):
+ return self._model_filename
+
+ @property
+ def device(self):
+ return next(iter(self.parameters())).device
+
+
+def create_activation(activation_fn_cfg: DictConfig):
+ if activation_fn_cfg is None:
+ return nn.ReLU()
+ else:
+ return hydra.utils.instantiate(activation_fn_cfg)
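+
+
+# Minimal usage sketch (illustrative only):
+#   create_activation(None)  # -> nn.ReLU()
+#   create_activation(DictConfig({"_target_": "torch.nn.SiLU"}))  # -> nn.SiLU()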
diff --git a/cmrl/models/networks/coder.py b/cmrl/models/networks/coder.py
new file mode 100644
index 0000000..ca0a133
--- /dev/null
+++ b/cmrl/models/networks/coder.py
@@ -0,0 +1,104 @@
+from typing import List, Optional
+
+import torch.nn as nn
+from omegaconf import DictConfig
+
+from cmrl.utils.variables import Variable, DiscreteVariable, ContinuousVariable, BinaryVariable, RadianVariable
+from cmrl.models.networks.base_network import BaseNetwork, create_activation
+from cmrl.models.layers import RadianLayer
+
+
+class VariableEncoder(BaseNetwork):
+ def __init__(
+ self,
+ variable: Variable,
+ output_dim: int = 100,
+ hidden_dims: Optional[List[int]] = None,
+ bias: bool = True,
+ activation_fn_cfg: Optional[DictConfig] = None,
+ ):
+ self.variable = variable
+ self.output_dim = output_dim
+ self.hidden_dims = hidden_dims if hidden_dims is not None else []
+ self.bias = bias
+ self.activation_fn_cfg = activation_fn_cfg
+
+ self.name = "{}_encoder".format(variable.name)
+
+ super(VariableEncoder, self).__init__()
+ self._model_filename = "{}.pth".format(self.name)
+
+ def build(self):
+ layers = []
+ if len(self.hidden_dims) == 0:
+ hidden_dim = self.output_dim
+ else:
+ hidden_dim = self.hidden_dims[0]
+
+ if isinstance(self.variable, ContinuousVariable):
+ layers.append(nn.Linear(self.variable.dim, hidden_dim))
+ elif isinstance(self.variable, RadianVariable):
+ layers.append(RadianLayer())
+ layers.append(nn.Linear(self.variable.dim, hidden_dim))
+ elif isinstance(self.variable, DiscreteVariable):
+ layers.append(nn.Linear(self.variable.n, hidden_dim))
+ elif isinstance(self.variable, BinaryVariable):
+ layers.append(nn.Linear(1, hidden_dim))
+ else:
+ raise NotImplementedError("Type {} is not supported by VariableEncoder".format(type(self.variable)))
+
+ hidden_dims = self.hidden_dims + [self.output_dim]
+ for i in range(len(hidden_dims) - 1):
+ layers += [nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=self.bias)]
+ layers += [create_activation(self.activation_fn_cfg)]
+
+ self._layers = nn.ModuleList(layers)
+
+
+class VariableDecoder(BaseNetwork):
+ def __init__(
+ self,
+ variable: Variable,
+ input_dim: int = 100,
+ hidden_dims: Optional[List[int]] = None,
+ bias: bool = True,
+ activation_fn_cfg: Optional[DictConfig] = None,
+ ):
+ self.variable = variable
+ self.input_dim = input_dim
+ self.hidden_dims = hidden_dims if hidden_dims is not None else []
+ self.bias = bias
+ self.activation_fn_cfg = activation_fn_cfg
+
+ self.name = "{}_decoder".format(variable.name)
+
+ super(VariableDecoder, self).__init__()
+ self._model_filename = "{}.pth".format(self.name)
+
+ def build(self):
+ layers = [create_activation(self.activation_fn_cfg)]
+
+ hidden_dims = [self.input_dim] + self.hidden_dims
+ for i in range(len(hidden_dims) - 1):
+ layers += [nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=self.bias)]
+ layers += [create_activation(self.activation_fn_cfg)]
+
+ if len(self.hidden_dims) == 0:
+ hidden_dim = self.input_dim
+ else:
+ hidden_dim = self.hidden_dims[-1]
+
+ if isinstance(self.variable, ContinuousVariable):
+ layers.append(nn.Linear(hidden_dim, self.variable.dim * 2))
+ elif isinstance(self.variable, RadianVariable):
+ layers.append(nn.Linear(hidden_dim, self.variable.dim * 2))
+ elif isinstance(self.variable, DiscreteVariable):
+ layers.append(nn.Linear(hidden_dim, self.variable.n))
+            layers.append(nn.Softmax(dim=-1))
+ elif isinstance(self.variable, BinaryVariable):
+ layers.append(nn.Linear(hidden_dim, 1))
+ layers.append(nn.Sigmoid())
+ else:
+ raise NotImplementedError("Type {} is not supported by VariableDecoder".format(type(self.variable)))
+
+ self._layers = nn.ModuleList(layers)
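+
+
+# Minimal usage sketch (illustrative; assumes ContinuousVariable exposes the
+# `name` and `dim` attributes read in `build` above):
+#   obs = ContinuousVariable(name="obs_0", dim=4)
+#   enc = VariableEncoder(obs, output_dim=100)
+#   enc(torch.randn(32, 4)).shape  # -> (32, 100)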
diff --git a/cmrl/models/networks/mlp.py b/cmrl/models/networks/mlp.py
deleted file mode 100644
index 9b9e8c3..0000000
--- a/cmrl/models/networks/mlp.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import pathlib
-from typing import Dict, Optional, Sequence, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from cmrl.models.util import gaussian_nll
-from cmrl.models.layers import EnsembleLinearLayer
-
-
-class EnsembleMLP(nn.Module):
- _MODEL_FILENAME = "ensemble_mlp.pth"
-
- def __init__(
- self,
- ensemble_num: int = 7,
- elite_num: int = 5,
- device: Union[str, torch.device] = "cpu",
- ):
- super(EnsembleMLP, self).__init__()
- self.ensemble_num = ensemble_num
- self.elite_num = elite_num
- self.device = device
-
- self._elite_members: Optional[Sequence[int]] = np.random.permutation(ensemble_num)[:elite_num]
-
- self._model_save_attrs = ["elite_members", "state_dict"]
-
- def set_elite_members(self, elite_indices: Sequence[int]):
- if len(elite_indices) != self.ensemble_num:
- assert len(elite_indices) == self.elite_num
- self._elite_members = list(elite_indices)
-
- @property
- def elite_members(self):
- return self._elite_members
-
- def get_random_index(self, batch_size: int, numpy_generator: Optional[np.random.Generator] = None):
- if numpy_generator:
- return numpy_generator.choice(self._elite_members, size=batch_size)
- else:
- return np.random.choice(self._elite_members, size=batch_size)
-
- def save(self, save_dir: Union[str, pathlib.Path]):
- """Saves the model to the given directory."""
- model_dict = {}
- for attr in self._model_save_attrs:
- if attr == "state_dict":
- model_dict["state_dict"] = self.state_dict()
- else:
- model_dict[attr] = getattr(self, attr)
- torch.save(model_dict, pathlib.Path(save_dir) / self._MODEL_FILENAME)
-
- def load(self, load_dir: Union[str, pathlib.Path], load_device: Optional[str] = None):
- """Loads the model from the given path."""
- model_dict = torch.load(pathlib.Path(load_dir) / self._MODEL_FILENAME, map_location=load_device)
- for attr in model_dict:
- if attr == "state_dict":
- self.load_state_dict(model_dict["state_dict"])
- else:
- getattr(self, "set_" + attr)(model_dict[attr])
-
- def create_linear_layer(self, l_in, l_out):
- return EnsembleLinearLayer(l_in, l_out, ensemble_num=self.ensemble_num)
-
- def get_mse_loss(self, model_in: Dict[(str, torch.Tensor)], target: torch.Tensor) -> torch.Tensor:
- pred_mean, pred_logvar = self.forward(**model_in)
- return F.mse_loss(pred_mean, target, reduction="none")
-
- def get_nll_loss(self, model_in: Dict[(str, torch.Tensor)], target: torch.Tensor) -> torch.Tensor:
- pred_mean, pred_logvar = self.forward(**model_in)
- nll_loss = gaussian_nll(pred_mean, pred_logvar, target, reduce=False)
- nll_loss += 0.01 * (self.max_logvar.sum() - self.min_logvar.sum())
- return nll_loss
-
- def add_save_attr(self, attr: str):
- assert hasattr(self, attr), "Class must has attribute {}".format(attr)
- assert attr not in self._model_save_attrs, "Attribute {} has been in model-save-list".format(attr)
- self._model_save_attrs.append(attr)
-
- @property
- def save_attr(self):
- return self._model_save_attrs
-
- @property
- def model_file_name(self):
- return self._MODEL_FILENAME
-
-
-class ExternalMaskEnsembleMLP(EnsembleMLP):
- """Ensemble of multi-layer perceptrons with input mask inside
-
- Args:
- TODO
- """
-
- def __init__(self, ensemble_num: int = 7, elite_num: int = 5, device: Union[str, torch.device] = "cpu"):
- super().__init__(ensemble_num, elite_num, device)
diff --git a/cmrl/models/networks/parallel_mlp.py b/cmrl/models/networks/parallel_mlp.py
new file mode 100644
index 0000000..e322f19
--- /dev/null
+++ b/cmrl/models/networks/parallel_mlp.py
@@ -0,0 +1,52 @@
+import pathlib
+from typing import List, Optional, Sequence, Union
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import hydra
+from omegaconf import DictConfig
+
+from cmrl.models.layers import ParallelLinear
+from cmrl.models.networks.base_network import BaseNetwork, create_activation
+
+
+# partial from https://github.com/phlippe/ENCO/blob/main/causal_discovery/multivariable_mlp.py
+class ParallelMLP(BaseNetwork):
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ extra_dims: Optional[List[int]] = None,
+ hidden_dims: Optional[List[int]] = None,
+ bias: bool = True,
+ init_type: str = "truncated_normal",
+ activation_fn_cfg: Optional[DictConfig] = None,
+ **kwargs
+ ):
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.extra_dims = extra_dims if extra_dims is not None else []
+ self.hidden_dims = hidden_dims if hidden_dims is not None else [200, 200, 200, 200]
+ self.bias = bias
+ self.init_type = init_type
+ self.activation_fn_cfg = activation_fn_cfg
+
+ super().__init__(**kwargs)
+ self._model_filename = "parallel_mlp.pth"
+
+ def build(self):
+ layers = []
+ hidden_dims = [self.input_dim] + self.hidden_dims
+ for i in range(len(hidden_dims) - 1):
+ layers += [
+ ParallelLinear(
+ input_dim=hidden_dims[i], output_dim=hidden_dims[i + 1], extra_dims=self.extra_dims, bias=self.bias
+ )
+ ]
+ layers += [create_activation(self.activation_fn_cfg)]
+ layers += [
+ ParallelLinear(input_dim=hidden_dims[-1], output_dim=self.output_dim, extra_dims=self.extra_dims, bias=self.bias)
+ ]
+
+ self._layers = nn.ModuleList(layers)
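`ParallelLinear` is defined in `cmrl/models/layers` and is not shown in this diff. A plausible minimal version, following the ENCO code credited above, keeps one independent weight matrix per entry of `extra_dims` and applies them all with a single batched matmul; the class name and initialization below are illustrative assumptions, not the repo's implementation:

```python
import torch
import torch.nn as nn


class ParallelLinearSketch(nn.Module):
    """Illustrative stand-in for cmrl.models.layers.ParallelLinear."""

    def __init__(self, input_dim: int, output_dim: int, extra_dims=None, bias: bool = True):
        super().__init__()
        extra_dims = extra_dims or []
        # one independent weight matrix per parallel sub-network
        self.weight = nn.Parameter(0.02 * torch.randn(*extra_dims, input_dim, output_dim))
        self.bias = nn.Parameter(torch.zeros(*extra_dims, 1, output_dim)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [*extra_dims, batch_size, input_dim]
        out = x @ self.weight  # matmul broadcasts over the extra dims
        return out + self.bias if self.bias is not None else out


layer = ParallelLinearSketch(input_dim=5, output_dim=3, extra_dims=[7])
print(layer(torch.randn(7, 32, 5)).shape)  # torch.Size([7, 32, 3])
```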
diff --git a/cmrl/algorithms/online/__init__.py b/cmrl/models/networks/util.py
similarity index 100%
rename from cmrl/algorithms/online/__init__.py
rename to cmrl/models/networks/util.py
diff --git a/cmrl/models/reward_mech/__init__.py b/cmrl/models/reward_mech/__init__.py
deleted file mode 100644
index 77ea159..0000000
--- a/cmrl/models/reward_mech/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from cmrl.models.reward_mech.base_reward_mech import BaseRewardMech
-
-from cmrl.models.reward_mech.plain_reward_mech import PlainRewardMech
diff --git a/cmrl/models/reward_mech/base_reward_mech.py b/cmrl/models/reward_mech/base_reward_mech.py
deleted file mode 100644
index f2addef..0000000
--- a/cmrl/models/reward_mech/base_reward_mech.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import Union
-
-import torch
-
-from cmrl.models.networks.mlp import EnsembleMLP
-
-
-class BaseRewardMech(EnsembleMLP):
- _MODEL_FILENAME = "base_reward_mech.pth"
-
- def __init__(
- self,
- obs_size: int,
- action_size: int,
- deterministic: bool = False,
- ensemble_num: int = 7,
- elite_num: int = 5,
- device: Union[str, torch.device] = "cpu",
- ):
- super(BaseRewardMech, self).__init__(ensemble_num=ensemble_num, elite_num=elite_num, device=device)
- self.obs_size = obs_size
- self.action_size = action_size
- self.deterministic = deterministic
-
- def forward(self, state: torch.Tensor, action: torch.Tensor):
- pass
diff --git a/cmrl/models/reward_mech/plain_reward_mech.py b/cmrl/models/reward_mech/plain_reward_mech.py
deleted file mode 100644
index a111471..0000000
--- a/cmrl/models/reward_mech/plain_reward_mech.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from typing import Dict, Optional, Tuple, Union
-
-import hydra
-import omegaconf
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-from cmrl.models.layers import truncated_normal_init
-from cmrl.models.reward_mech.base_reward_mech import BaseRewardMech
-
-
-class PlainRewardMech(BaseRewardMech):
- _MODEL_FILENAME = "plain_reward_mech.pth"
-
- def __init__(
- self,
- # transition info
- obs_size: int,
- action_size: int,
- deterministic: bool = False,
- # algorithm parameters
- ensemble_num: int = 7,
- elite_num: int = 5,
- learn_logvar_bounds: bool = False,
- # network parameters
- num_layers: int = 4,
- hid_size: int = 200,
- activation_fn_cfg: Optional[Union[Dict, omegaconf.DictConfig]] = None,
- # others
- device: Union[str, torch.device] = "cpu",
- ):
- super(PlainRewardMech, self).__init__(
- obs_size=obs_size,
- action_size=action_size,
- deterministic=deterministic,
- ensemble_num=ensemble_num,
- elite_num=elite_num,
- device=device,
- )
- self.num_layers = num_layers
- self.hid_size = hid_size
-
- def create_activation():
- if activation_fn_cfg is None:
- return nn.ReLU()
- else:
- return hydra.utils.instantiate(activation_fn_cfg)
-
- hidden_layers = [
- nn.Sequential(
- self.create_linear_layer(obs_size + action_size, hid_size),
- create_activation(),
- )
- ]
- for i in range(num_layers - 1):
- hidden_layers.append(
- nn.Sequential(
- self.create_linear_layer(hid_size, hid_size),
- create_activation(),
- )
- )
- self.hidden_layers = nn.Sequential(*hidden_layers)
-
- if deterministic:
- self.mean_and_logvar = self.create_linear_layer(hid_size, 1)
- else:
- self.mean_and_logvar = self.create_linear_layer(hid_size, 2)
- self.min_logvar = nn.Parameter(-10 * torch.ones(1), requires_grad=learn_logvar_bounds)
- self.max_logvar = nn.Parameter(0.5 * torch.ones(1), requires_grad=learn_logvar_bounds)
-
- self.apply(truncated_normal_init)
- self.to(self.device)
-
- def forward(
- self,
- batch_obs: torch.Tensor, # shape: ensemble_num, batch_size, state_size
- batch_action: torch.Tensor, # shape: ensemble_num, batch_size, action_size
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- assert len(batch_obs.shape) == 3 and batch_obs.shape[-1] == self.obs_size
- assert len(batch_action.shape) == 3 and batch_action.shape[-1] == self.action_size
-
- hidden = self.hidden_layers(torch.concat([batch_obs, batch_action], dim=-1))
- mean_and_logvar = self.mean_and_logvar(hidden)
-
- if self.deterministic:
- mean, logvar = mean_and_logvar, None
- else:
- mean = mean_and_logvar[..., :1]
- logvar = mean_and_logvar[..., 1:]
- logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
- logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
-
- return mean, logvar
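The two `softplus` lines above (repeated in the transition models later in this diff) are the soft log-variance clamp used by PETS-style ensembles: any raw network output is squashed into `(min_logvar, max_logvar)` differentiably, so the bounds themselves can be learned when `learn_logvar_bounds=True`. A standalone check:

```python
import torch
import torch.nn.functional as F

min_logvar, max_logvar = torch.tensor(-10.0), torch.tensor(0.5)
raw = torch.tensor([-100.0, -5.0, 0.0, 100.0])

logvar = max_logvar - F.softplus(max_logvar - raw)     # soft upper bound
logvar = min_logvar + F.softplus(logvar - min_logvar)  # soft lower bound
print(logvar)  # every value now lies strictly inside (-10.0, 0.5)
```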
diff --git a/cmrl/models/termination_mech/__init__.py b/cmrl/models/termination_mech/__init__.py
deleted file mode 100644
index 43a4ca3..0000000
--- a/cmrl/models/termination_mech/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from cmrl.models.termination_mech.base_termination_mech import BaseTerminationMech
-
-from cmrl.models.termination_mech.plain_termination_mech import PlainTerminationMech
diff --git a/cmrl/models/termination_mech/base_termination_mech.py b/cmrl/models/termination_mech/base_termination_mech.py
deleted file mode 100644
index cd80284..0000000
--- a/cmrl/models/termination_mech/base_termination_mech.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import Union
-
-import torch
-
-from cmrl.models.networks.mlp import EnsembleMLP
-
-
-class BaseTerminationMech(EnsembleMLP):
- _MODEL_FILENAME = "base_reward_mech.pth"
-
- def __init__(
- self,
- obs_size: int,
- action_size: int,
- deterministic: bool = False,
- ensemble_num: int = 7,
- elite_num: int = 5,
- device: Union[str, torch.device] = "cpu",
- ):
- super(BaseTerminationMech, self).__init__(ensemble_num=ensemble_num, elite_num=elite_num, device=device)
- self.obs_size = obs_size
- self.action_size = action_size
- self.deterministic = deterministic
-
- def forward(self, state: torch.Tensor, action: torch.Tensor):
- pass
diff --git a/cmrl/models/termination_mech/plain_termination_mech.py b/cmrl/models/termination_mech/plain_termination_mech.py
deleted file mode 100644
index 1143690..0000000
--- a/cmrl/models/termination_mech/plain_termination_mech.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from typing import Dict, Optional, Tuple, Union
-
-import hydra
-import omegaconf
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-from cmrl.models.layers import truncated_normal_init
-from cmrl.models.termination_mech.base_termination_mech import BaseTerminationMech
-
-
-class PlainTerminationMech(BaseTerminationMech):
- _MODEL_FILENAME = "plain_termination_mech.pth"
-
- def __init__(
- self,
- # transition info
- obs_size: int,
- action_size: int,
- deterministic: bool = False,
- # algorithm parameters
- ensemble_num: int = 7,
- elite_num: int = 5,
- learn_logvar_bounds: bool = False,
- # network parameters
- num_layers: int = 4,
- hid_size: int = 200,
- activation_fn_cfg: Optional[Union[Dict, omegaconf.DictConfig]] = None,
- # others
- device: Union[str, torch.device] = "cpu",
- ):
- super(PlainTerminationMech, self).__init__(
- obs_size=obs_size,
- action_size=action_size,
- deterministic=deterministic,
- ensemble_num=ensemble_num,
- elite_num=elite_num,
- device=device,
- )
- self.num_layers = num_layers
- self.hid_size = hid_size
-
- def create_activation():
- if activation_fn_cfg is None:
- return nn.ReLU()
- else:
- return hydra.utils.instantiate(activation_fn_cfg)
-
- hidden_layers = [
- nn.Sequential(
- self.create_linear_layer(obs_size + action_size, hid_size),
- create_activation(),
- )
- ]
- for i in range(num_layers - 1):
- hidden_layers.append(
- nn.Sequential(
- self.create_linear_layer(hid_size, hid_size),
- create_activation(),
- )
- )
- self.hidden_layers = nn.Sequential(*hidden_layers)
-
- if deterministic:
- self.mean_and_logvar = self.create_linear_layer(hid_size, 1)
- else:
- self.mean_and_logvar = self.create_linear_layer(hid_size, 2)
- self.min_logvar = nn.Parameter(-10 * torch.ones(1), requires_grad=learn_logvar_bounds)
- self.max_logvar = nn.Parameter(0.5 * torch.ones(1), requires_grad=learn_logvar_bounds)
-
- self.apply(truncated_normal_init)
- self.to(self.device)
-
- def forward(
- self,
- batch_obs: torch.Tensor, # shape: ensemble_num, batch_size, state_size
- batch_action: torch.Tensor, # shape: ensemble_num, batch_size, action_size
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- assert len(batch_obs.shape) == 3 and batch_obs.shape[-1] == self.obs_size
- assert len(batch_action.shape) == 3 and batch_action.shape[-1] == self.action_size
-
- hidden = self.hidden_layers(torch.concat([batch_obs, batch_action], dim=-1))
- mean_and_logvar = self.mean_and_logvar(hidden)
-
- if self.deterministic:
- mean, logvar = mean_and_logvar, None
- else:
- mean = mean_and_logvar[..., :1]
- logvar = mean_and_logvar[..., 1:]
- logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
- logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
-
- return mean, logvar
diff --git a/cmrl/models/transition/__init__.py b/cmrl/models/transition/__init__.py
deleted file mode 100644
index 23295bd..0000000
--- a/cmrl/models/transition/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from cmrl.models.transition.base_transition import BaseTransition
-
-from cmrl.models.transition.one_step.external_mask_transition import ExternalMaskTransition
-from cmrl.models.transition.one_step.plain_transition import PlainTransition
-
-from cmrl.models.transition.multi_step.forward_euler import ForwardEulerTransition
diff --git a/cmrl/models/transition/base_transition.py b/cmrl/models/transition/base_transition.py
deleted file mode 100644
index 7f8f5f7..0000000
--- a/cmrl/models/transition/base_transition.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import Union
-
-import torch
-
-from cmrl.models.networks.mlp import EnsembleMLP
-
-
-class BaseTransition(EnsembleMLP):
- _MODEL_FILENAME = "base_ensemble_transition.pth"
-
- def __init__(
- self,
- obs_size: int,
- action_size: int,
- deterministic: bool,
- ensemble_num: int = 7,
- elite_num: int = 5,
- device: Union[str, torch.device] = "cpu",
- ):
- super(BaseTransition, self).__init__(ensemble_num=ensemble_num, elite_num=elite_num, device=device)
- self.obs_size = obs_size
- self.action_size = action_size
- self.deterministic = deterministic
-
- def forward(self, state: torch.Tensor, action: torch.Tensor):
- pass
diff --git a/cmrl/models/transition/multi_step/forward_euler.py b/cmrl/models/transition/multi_step/forward_euler.py
deleted file mode 100644
index 7b59388..0000000
--- a/cmrl/models/transition/multi_step/forward_euler.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from typing import Dict, Optional, Sequence, Tuple, Union
-
-import torch
-
-from cmrl.models.transition.base_transition import BaseTransition
-
-
-class ForwardEulerTransition(BaseTransition):
- def __init__(self, one_step_transition: BaseTransition, repeat_times: int = 2):
- super().__init__(
- obs_size=one_step_transition.obs_size,
- action_size=one_step_transition.action_size,
- ensemble_num=one_step_transition.ensemble_num,
- deterministic=one_step_transition.deterministic,
- device=one_step_transition.device,
- )
-
- self.one_step_transition = one_step_transition
- self.repeat_times = repeat_times
-
- if hasattr(self.one_step_transition, "max_logvar"):
- self.max_logvar = one_step_transition.max_logvar
- self.min_logvar = one_step_transition.min_logvar
-
- if hasattr(self.one_step_transition, "input_mask"):
- self.input_mask = self.one_step_transition.input_mask
- self.set_input_mask = self.one_step_transition.set_input_mask
-
- def set_elite_members(self, elite_indices: Sequence[int]):
- self.one_step_transition.set_elite_members(elite_indices)
-
- def forward(
- self,
- batch_obs: torch.Tensor,
- batch_action: torch.Tensor,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- logvar = torch.zeros(batch_obs.shape, device=self.device)
- mean = batch_obs
- for t in range(self.repeat_times):
- mean, logvar = self.one_step_transition.forward(mean, batch_action.clone())
- return mean, logvar
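`ForwardEulerTransition` composes a one-step model with itself: the predicted mean is fed back in as the next observation while the action is held fixed, so a model trained at a fine simulation step can serve a coarser control step. Note that only the final step's log-variance is returned. A toy sketch with a hypothetical deterministic `one_step` function:

```python
import torch


def one_step(obs: torch.Tensor, action: torch.Tensor):
    """Toy one-step model: obs' = obs + 0.1 * action, with zero log-variance."""
    return obs + 0.1 * action, torch.zeros_like(obs)


def forward_euler(obs: torch.Tensor, action: torch.Tensor, repeat_times: int = 2):
    mean, logvar = obs, torch.zeros_like(obs)
    for _ in range(repeat_times):
        mean, logvar = one_step(mean, action.clone())
    return mean, logvar


print(forward_euler(torch.zeros(3), torch.ones(3), repeat_times=5)[0])
# tensor([0.5000, 0.5000, 0.5000])
```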
diff --git a/cmrl/models/transition/one_step/external_mask_transition.py b/cmrl/models/transition/one_step/external_mask_transition.py
deleted file mode 100644
index 783236c..0000000
--- a/cmrl/models/transition/one_step/external_mask_transition.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from typing import Dict, Optional, Sequence, Tuple, Union
-
-import hydra
-import omegaconf
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-import cmrl.types
-from cmrl.models.layers import ParallelEnsembleLinearLayer, truncated_normal_init
-from cmrl.models.transition.base_transition import BaseTransition
-from cmrl.models.util import to_tensor
-
-
-class ExternalMaskTransition(BaseTransition):
- """Implements an ensemble of multi-layer perceptrons each modeling a Gaussian distribution
- corresponding to each independent dimension.
-
- Args:
- obs_size (int): size of state.
- action_size (int): size of action.
- device (str or torch.device): the device to use for the model.
- num_layers (int): the number of layers in the model
- (e.g., if ``num_layers == 3``, then model graph looks like
- input -h1-> -h2-> -l3-> output).
-        ensemble_num (int): the number of members in the ensemble. Defaults to 7.
-        hid_size (int): the size of the hidden layers (e.g., size of h1 and h2 in the graph above).
-        deterministic (bool): if ``True``, the model predicts only the mean; otherwise it predicts
-            both the mean and log-variance of the conditional Gaussian distribution. Defaults to ``False``.
- residual (bool): if ``True``, the model predicts the residual of output and input. Defaults to ``True``.
- learn_logvar_bounds (bool): if ``True``, the log-var bounds will be learned, otherwise
- they will be constant. Defaults to ``False``.
- activation_fn_cfg (dict or omegaconf.DictConfig, optional): configuration of the
- desired activation function. Defaults to torch.nn.ReLU when ``None``.
- """
-
- _MODEL_FILENAME = "external_mask_transition.pth"
-
- def __init__(
- self,
- # transition info
- obs_size: int,
- action_size: int,
- deterministic: bool = False,
- # algorithm parameters
- ensemble_num: int = 7,
- elite_num: int = 5,
- residual: bool = True,
- learn_logvar_bounds: bool = False,
- # network parameters
- num_layers: int = 4,
- hid_size: int = 200,
- activation_fn_cfg: Optional[Union[Dict, omegaconf.DictConfig]] = None,
- # others
- device: Union[str, torch.device] = "cpu",
- ):
- super().__init__(
- obs_size=obs_size,
- action_size=action_size,
- deterministic=deterministic,
- ensemble_num=ensemble_num,
- elite_num=elite_num,
- device=device,
- )
- self.residual = residual
- self.learn_logvar_bounds = learn_logvar_bounds
-
- self.num_layers = num_layers
- self.hid_size = hid_size
- self.activation_fn_cfg = activation_fn_cfg
-
- self._input_mask: Optional[torch.Tensor] = torch.ones((obs_size, obs_size + action_size)).to(device)
- self.add_save_attr("input_mask")
-
- def create_activation():
- if activation_fn_cfg is None:
- return nn.ReLU()
- else:
- return hydra.utils.instantiate(activation_fn_cfg)
-
- hidden_layers = [
- nn.Sequential(
- self.create_linear_layer(obs_size + action_size, hid_size),
- create_activation(),
- )
- ]
- for i in range(num_layers - 1):
- hidden_layers.append(
- nn.Sequential(
- self.create_linear_layer(hid_size, hid_size),
- create_activation(),
- )
- )
- self.hidden_layers = nn.Sequential(*hidden_layers)
-
- if deterministic:
- self.mean_and_logvar = self.create_linear_layer(hid_size, 1)
- else:
- self.mean_and_logvar = self.create_linear_layer(hid_size, 2)
- self.min_logvar = nn.Parameter(-10 * torch.ones(obs_size, 1, 1, 1), requires_grad=learn_logvar_bounds)
- self.max_logvar = nn.Parameter(0.5 * torch.ones(obs_size, 1, 1, 1), requires_grad=learn_logvar_bounds)
-
- self.apply(truncated_normal_init)
- self.to(self.device)
-
- def create_linear_layer(self, l_in, l_out):
- return ParallelEnsembleLinearLayer(l_in, l_out, parallel_num=self.obs_size, ensemble_num=self.ensemble_num)
-
- def set_input_mask(self, mask: cmrl.types.TensorType):
- self._input_mask = to_tensor(mask).to(self.device)
-
- @property
- def input_mask(self):
- return self._input_mask
-
- def mask_input(self, x: torch.Tensor) -> torch.Tensor:
- assert x.ndim == 4
- assert self._input_mask is not None
- assert 2 <= self._input_mask.ndim <= 4
- assert x.shape[0] == self._input_mask.shape[0] and x.shape[-1] == self._input_mask.shape[-1]
-
- if self._input_mask.ndim == 2:
- # [parallel_size x in_dim]
- input_mask = self._input_mask[:, None, None, :]
- elif self._input_mask.ndim == 3:
- if self._input_mask.shape[1] == x.shape[1]:
- # [parallel_size x ensemble_size x in_dim]
- input_mask = self._input_mask[:, :, None, :]
- elif self._input_mask.shape[1] == x.shape[2]:
- # [parallel_size x batch_size x in_dim]
- input_mask = self._input_mask[:, None, :, :]
- else:
- raise RuntimeError("input mask shape %a does not match x shape %a" % (self._input_mask.shape, x.shape))
- else:
- assert self._input_mask.shape == x.shape
- input_mask = self._input_mask
-
- x *= input_mask
- return x
-
- def forward(
- self,
- batch_obs: torch.Tensor, # shape: ensemble_num, batch_size, obs_size
- batch_action: torch.Tensor, # shape: ensemble_num, batch_size, action_size
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- assert len(batch_obs.shape) == 3 and batch_obs.shape[-1] == self.obs_size
- assert len(batch_action.shape) == 3 and batch_action.shape[-1] == self.action_size
-
- repeated_input = torch.concat([batch_obs, batch_action], dim=-1).repeat((self.obs_size, 1, 1, 1))
- masked_input = self.mask_input(repeated_input)
- hidden = self.hidden_layers(masked_input)
- mean_and_logvar = self.mean_and_logvar(hidden)
-
- if self.deterministic:
- mean, logvar = mean_and_logvar, None
- else:
- mean = mean_and_logvar[..., :1]
- logvar = mean_and_logvar[..., 1:]
- logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
- logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
-
- mean = torch.transpose(mean, 0, -1)[0]
- if logvar is not None:
- logvar = torch.transpose(logvar, 0, -1)[0]
-
- if self.residual:
- mean += batch_obs
-
- return mean, logvar
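The masking in `mask_input` is pure broadcasting: a 2-D mask of shape `[parallel_size, in_dim]` gains singleton ensemble and batch axes and is multiplied into the replicated input, zeroing out the parents that a given output dimension is not allowed to see. The shape arithmetic in isolation:

```python
import torch

obs_size, action_size = 3, 1          # parallel_size == obs_size
ensemble_num, batch_size = 7, 32
in_dim = obs_size + action_size

x = torch.randn(obs_size, ensemble_num, batch_size, in_dim)
mask = torch.ones(obs_size, in_dim)
mask[0, 1] = 0.0                      # output dim 0 may not depend on input dim 1

masked = x * mask[:, None, None, :]   # broadcasts to [3, 7, 32, 4]
print(masked.shape, masked[0, ..., 1].abs().max().item())  # torch.Size([3, 7, 32, 4]) 0.0
```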
diff --git a/cmrl/models/transition/one_step/internal_mask_transition.py b/cmrl/models/transition/one_step/internal_mask_transition.py
deleted file mode 100644
index e69de29..0000000
diff --git a/cmrl/models/transition/one_step/plain_transition.py b/cmrl/models/transition/one_step/plain_transition.py
deleted file mode 100644
index 3efa810..0000000
--- a/cmrl/models/transition/one_step/plain_transition.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from typing import Dict, Optional, Sequence, Tuple, Union
-
-import hydra
-import omegaconf
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-from cmrl.models.layers import EnsembleLinearLayer, truncated_normal_init
-from cmrl.models.transition.base_transition import BaseTransition
-
-
-class PlainTransition(BaseTransition):
- """Implements an ensemble of multi-layer perceptrons each modeling a Gaussian distribution.
-
- Args:
- obs_size (int): size of state.
- action_size (int): size of action.
- device (str or torch.device): the device to use for the model.
- num_layers (int): the number of layers in the model
- (e.g., if ``num_layers == 3``, then model graph looks like
- input -h1-> -h2-> -l3-> output).
-        ensemble_num (int): the number of members in the ensemble. Defaults to 7.
-        hid_size (int): the size of the hidden layers (e.g., size of h1 and h2 in the graph above).
-        deterministic (bool): if ``True``, the model predicts only the mean; otherwise it predicts
-            both the mean and log-variance of the conditional Gaussian distribution. Defaults to ``False``.
- residual (bool): if ``True``, the model predicts the residual of output and input. Defaults to ``True``.
- learn_logvar_bounds (bool): if ``True``, the log-var bounds will be learned, otherwise
- they will be constant. Defaults to ``False``.
- activation_fn_cfg (dict or omegaconf.DictConfig, optional): configuration of the
- desired activation function. Defaults to torch.nn.ReLU when ``None``.
- """
-
- _MODEL_FILENAME = "plain_transition.pth"
-
- def __init__(
- self,
- # transition info
- obs_size: int,
- action_size: int,
- deterministic: bool = False,
- # algorithm parameters
- ensemble_num: int = 7,
- elite_num: int = 5,
- residual: bool = True,
- learn_logvar_bounds: bool = False,
- # network parameters
- num_layers: int = 4,
- hid_size: int = 200,
- activation_fn_cfg: Optional[Union[Dict, omegaconf.DictConfig]] = None,
- # others
- device: Union[str, torch.device] = "cpu",
- ):
- super().__init__(
- obs_size=obs_size,
- action_size=action_size,
- deterministic=deterministic,
- ensemble_num=ensemble_num,
- elite_num=elite_num,
- device=device,
- )
- self.residual = residual
- self.learn_logvar_bounds = learn_logvar_bounds
-
- self.num_layers = num_layers
- self.hid_size = hid_size
- self.activation_fn_cfg = activation_fn_cfg
-
- def create_activation():
- if activation_fn_cfg is None:
- return nn.ReLU()
- else:
- return hydra.utils.instantiate(activation_fn_cfg)
-
- hidden_layers = [
- nn.Sequential(
- self.create_linear_layer(obs_size + action_size, hid_size),
- create_activation(),
- )
- ]
- for i in range(num_layers - 1):
- hidden_layers.append(
- nn.Sequential(
- self.create_linear_layer(hid_size, hid_size),
- create_activation(),
- )
- )
- self.hidden_layers = nn.Sequential(*hidden_layers)
-
- if deterministic:
- self.mean_and_logvar = self.create_linear_layer(hid_size, obs_size)
- else:
- self.mean_and_logvar = self.create_linear_layer(hid_size, 2 * obs_size)
- self.min_logvar = nn.Parameter(-10 * torch.ones(1, obs_size), requires_grad=learn_logvar_bounds)
- self.max_logvar = nn.Parameter(0.5 * torch.ones(1, obs_size), requires_grad=learn_logvar_bounds)
-
- self.apply(truncated_normal_init)
- self.to(self.device)
-
- def forward(
- self,
- batch_obs: torch.Tensor, # shape: ensemble_num, batch_size, obs_size
- batch_action: torch.Tensor, # shape: ensemble_num, batch_size, action_size
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
- assert len(batch_obs.shape) == 3 and batch_obs.shape[-1] == self.obs_size
- assert len(batch_action.shape) == 3 and batch_action.shape[-1] == self.action_size
-
- hidden = self.hidden_layers(torch.concat([batch_obs, batch_action], dim=-1))
- mean_and_logvar = self.mean_and_logvar(hidden)
-
- if self.deterministic:
- mean, logvar = mean_and_logvar, None
- else:
- mean = mean_and_logvar[..., : self.obs_size]
- logvar = mean_and_logvar[..., self.obs_size :]
- logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
- logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
-
- if self.residual:
- mean += batch_obs
-
- return mean, logvar
diff --git a/cmrl/models/util.py b/cmrl/models/util.py
index efc6849..a2c63ef 100644
--- a/cmrl/models/util.py
+++ b/cmrl/models/util.py
@@ -1,40 +1,10 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-from typing import List, Sequence, Tuple
+from typing import List, Optional, Union, Dict
import numpy as np
import torch
-import torch.nn.functional as F
+from gym import spaces
-import cmrl.types
-
-
-def gaussian_nll(
- pred_mean: torch.Tensor,
- pred_logvar: torch.Tensor,
- target: torch.Tensor,
- reduce: bool = True,
-) -> torch.Tensor:
- """Negative log-likelihood for Gaussian distribution
-
- Args:
- pred_mean (tensor): the predicted mean.
- pred_logvar (tensor): the predicted log variance.
- target (tensor): the target value.
- reduce (bool): if ``False`` the loss is returned w/o reducing.
- Defaults to ``True``.
-
- Returns:
- (tensor): the negative log-likelihood.
- """
- l2 = F.mse_loss(pred_mean, target, reduction="none")
- inv_var = (-pred_logvar).exp()
- losses = l2 * inv_var + pred_logvar
- if reduce:
- return losses.sum(dim=1).mean()
- return losses
+from cmrl.utils.variables import Variable, ContinuousVariable, DiscreteVariable, BinaryVariable
# inplace truncated normal function for pytorch.
@@ -59,11 +29,3 @@ def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1) ->
break
tensor[cond] = torch.normal(mean, std, size=(bound_violations,), device=tensor.device)
return tensor
-
-
-def to_tensor(x: cmrl.types.TensorType):
- if isinstance(x, torch.Tensor):
- return x
- if isinstance(x, np.ndarray):
- return torch.from_numpy(x)
- raise ValueError("Input must be torch.Tensor or np.ndarray.")
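`gaussian_nll` disappears here even though the deleted `EnsembleMLP.get_nll_loss` above referenced it, so a self-contained restatement may help anyone porting old code. Up to the constant `0.5` factor and `log(2*pi)` term, the per-element Gaussian negative log-likelihood is `(mean - target)**2 / exp(logvar) + logvar`:

```python
import torch
import torch.nn.functional as F


def gaussian_nll(pred_mean, pred_logvar, target, reduce=True):
    l2 = F.mse_loss(pred_mean, target, reduction="none")
    inv_var = (-pred_logvar).exp()
    losses = l2 * inv_var + pred_logvar
    return losses.sum(dim=1).mean() if reduce else losses


mean, logvar, target = torch.zeros(4, 2), torch.zeros(4, 2), torch.ones(4, 2)
print(gaussian_nll(mean, logvar, target))  # tensor(2.): (1/1 + 0) summed over 2 dims
```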
diff --git a/cmrl/sb3_extension/online_mb_callback.py b/cmrl/sb3_extension/online_mb_callback.py
index 49cb918..06f2655 100644
--- a/cmrl/sb3_extension/online_mb_callback.py
+++ b/cmrl/sb3_extension/online_mb_callback.py
@@ -1,5 +1,5 @@
import os
-import warnings
+import pathlib
from typing import Any, Callable, Dict, List, Optional, Union
from copy import deepcopy
@@ -8,53 +8,67 @@
from stable_baselines3.common.callbacks import BaseCallback, EventCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.buffers import ReplayBuffer
-from stable_baselines3.common.vec_env import (
- DummyVecEnv,
- VecEnv,
- sync_envs_normalization,
-)
+from stable_baselines3.common.vec_env import DummyVecEnv
from cmrl.models.fake_env import VecFakeEnv
-from cmrl.models.dynamics.base_dynamics import BaseDynamics
+from cmrl.models.dynamics import Dynamics
class OnlineModelBasedCallback(BaseCallback):
def __init__(
self,
env: gym.Env,
- dynamics: BaseDynamics,
+ dynamics: Dynamics,
real_replay_buffer: ReplayBuffer,
- total_num_steps: int = int(1e5),
+ # online RL
+ total_online_timesteps: int = int(1e5),
initial_exploration_steps: int = 1000,
freq_train_model: int = 250,
+ # dynamics learning
+ longest_epoch: int = -1,
+ improvement_threshold: float = 0.01,
+ patience: int = 5,
+ work_dir: Optional[Union[str, pathlib.Path]] = None,
device: str = "cpu",
):
super(OnlineModelBasedCallback, self).__init__(verbose=2)
self.env = DummyVecEnv([lambda: env])
self.dynamics = dynamics
- self.total_num_steps = total_num_steps
+ self.real_replay_buffer = real_replay_buffer
+ # online RL
+ self.total_online_timesteps = total_online_timesteps
self.initial_exploration_steps = initial_exploration_steps
self.freq_train_model = freq_train_model
+ # dynamics learning
+ self.longest_epoch = longest_epoch
+ self.improvement_threshold = improvement_threshold
+ self.patience = patience
+ self.work_dir = work_dir
self.device = device
self.action_space = env.action_space
self.observation_space = env.observation_space
- self.real_replay_buffer = real_replay_buffer
-
- self.now_num_steps = 0
- self.step_times = 0
+ self.now_online_timesteps = 0
self._last_obs = None
def _on_step(self) -> bool:
- if self.step_times % self.freq_train_model == 0:
- self.dynamics.learn(self.real_replay_buffer)
+ if self.n_calls % self.freq_train_model == 0:
+ # dump some residual log before dynamics learn
+ self.model.logger.dump(step=self.num_timesteps)
+
+ self.dynamics.learn(
+ self.real_replay_buffer,
+ longest_epoch=self.longest_epoch,
+ improvement_threshold=self.improvement_threshold,
+ patience=self.patience,
+ work_dir=self.work_dir,
+ )
- self.step_and_add(explore=False)
- self.step_times += 1
+ self.step_and_add(explore=False)
- if self.now_num_steps >= self.total_num_steps:
+ if self.now_online_timesteps >= self.total_online_timesteps:
return False
return True
@@ -62,7 +76,7 @@ def _on_training_start(self):
assert self.env.num_envs == 1
self._last_obs = self.env.reset()
- while self.now_num_steps < self.initial_exploration_steps:
+ while self.now_online_timesteps < self.initial_exploration_steps:
self.step_and_add(explore=True)
def step_and_add(self, explore=True):
@@ -73,7 +87,7 @@ def step_and_add(self, explore=True):
buffer_actions = self.model.policy.scale_action(actions)
new_obs, rewards, dones, infos = self.env.step(actions)
- self.now_num_steps += 1
+ self.now_online_timesteps += 1
next_obs = deepcopy(new_obs)
if dones[0] and infos[0].get("terminal_observation") is not None:
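For context, the callback is designed to be handed to a Stable-Baselines3 algorithm's `learn` call, which then interleaves policy updates on the fake env with periodic dynamics re-training on real data. A hedged wiring sketch; `env`, `fake_env`, and `dynamics` stand for project-specific factories that are not part of this diff:

```python
from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer

from cmrl.sb3_extension.online_mb_callback import OnlineModelBasedCallback

# env, fake_env and dynamics are assumed to come from project-specific setup
real_buffer = ReplayBuffer(100_000, env.observation_space, env.action_space)
callback = OnlineModelBasedCallback(
    env,
    dynamics,
    real_replay_buffer=real_buffer,
    total_online_timesteps=100_000,
    freq_train_model=250,
)

agent = SAC("MlpPolicy", fake_env)  # the policy itself trains on model rollouts
agent.learn(total_timesteps=1_000_000, callback=callback)
```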
diff --git a/cmrl/types.py b/cmrl/types.py
index 1bb28b4..07a1c4b 100644
--- a/cmrl/types.py
+++ b/cmrl/types.py
@@ -1,7 +1,5 @@
-from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union
-import numpy as np
import torch
# (next_obs, pre_obs, action) -> reward
@@ -9,72 +7,5 @@
# (next_obs, pre_obs, action) -> terminal
TermFnType = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
InitObsFnType = Callable[[int], torch.Tensor]
-ObsProcessFnType = Callable[[np.ndarray], np.ndarray]
-
-TensorType = Union[torch.Tensor, np.ndarray]
-TrajectoryEvalFnType = Callable[[TensorType, torch.Tensor], torch.Tensor]
-# obs, action, next_obs, reward, done
-InteractionData = Tuple[TensorType, TensorType, TensorType, TensorType, TensorType]
-
-
-@dataclass
-class InteractionBatch:
- """Represents a batch of transitions"""
-
- batch_obs: Optional[TensorType]
- batch_action: Optional[TensorType]
- batch_next_obs: Optional[TensorType]
- batch_reward: Optional[TensorType]
- batch_done: Optional[TensorType]
-
- @property
- def attrs(self):
- return [
- "batch_obs",
- "batch_action",
- "batch_next_obs",
- "batch_reward",
- "batch_done",
- ]
-
- def __len__(self):
- return self.batch_obs.shape[0]
-
- def as_tuple(self) -> InteractionData:
- return (
- self.batch_obs,
- self.batch_action,
- self.batch_next_obs,
- self.batch_reward,
- self.batch_done,
- )
-
- def __getitem__(self, item):
- return InteractionBatch(
- self.batch_obs[item],
- self.batch_action[item],
- self.batch_next_obs[item],
- self.batch_reward[item],
- self.batch_done[item],
- )
-
- @staticmethod
- def _get_new_shape(old_shape: Tuple[int, ...], batch_size: int):
- new_shape = list((1,) + old_shape)
- new_shape[0] = batch_size
- new_shape[1] = old_shape[0] // batch_size
- return tuple(new_shape)
-
- def add_new_batch_dim(self, batch_size: int):
-        if len(self) % batch_size != 0:
-            raise ValueError("Current batch of transitions size is not a multiple of the new batch size.")
- return InteractionBatch(
- self.batch_obs.reshape(self._get_new_shape(self.batch_obs.shape, batch_size)),
- self.batch_action.reshape(self._get_new_shape(self.batch_action.shape, batch_size)),
- self.batch_next_obs.reshape(self._get_new_shape(self.batch_obs.shape, batch_size)),
- self.batch_reward.reshape(self._get_new_shape(self.batch_reward.shape, batch_size)),
- self.batch_done.reshape(self._get_new_shape(self.batch_done.shape, batch_size)),
- )
-
-
-ModelInput = Union[torch.Tensor, InteractionBatch]
+Obs2StateFnType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+State2ObsFnType = Callable[[torch.Tensor], torch.Tensor]
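The two new aliases describe pure tensor-to-tensor maps between observations and latent states. For a fully observable task both directions are trivial; the meaning of `Obs2StateFnType`'s second argument is not pinned down by this diff, so it is treated as an ignored auxiliary tensor in this sketch:

```python
import torch


def obs2state(obs: torch.Tensor, aux: torch.Tensor) -> torch.Tensor:
    # fully observable: the state is the observation itself
    return obs


def state2obs(state: torch.Tensor) -> torch.Tensor:
    return state
```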
diff --git a/cmrl/util/config.py b/cmrl/util/config.py
deleted file mode 100644
index c07dba2..0000000
--- a/cmrl/util/config.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import pathlib
-from typing import Tuple, Union
-
-import omegaconf
-
-
-def load_hydra_cfg(results_dir: Union[str, pathlib.Path]) -> omegaconf.DictConfig:
- """Loads a Hydra configuration from the given directory path.
-
- Tries to load the configuration from "results_dir/.hydra/config.yaml".
-
- Args:
- results_dir (str or pathlib.Path): the path to the directory containing the config.
-
- Returns:
- (omegaconf.DictConfig): the loaded configuration.
-
- """
- results_dir = pathlib.Path(results_dir)
- cfg_file = results_dir / ".hydra" / "config.yaml"
- cfg = omegaconf.OmegaConf.load(cfg_file)
- if not isinstance(cfg, omegaconf.DictConfig):
- raise RuntimeError("Configuration format not a omegaconf.DictConf")
- return cfg
-
-
-def get_complete_dynamics_cfg(
- dynamics_cfg: omegaconf.DictConfig,
- obs_shape: Tuple[int, ...],
- act_shape: Tuple[int, ...],
-):
- transition_cfg = dynamics_cfg.transition
- transition_cfg.obs_size = obs_shape[0]
- transition_cfg.action_size = act_shape[0]
-
- reward_cfg = dynamics_cfg.reward_mech
- reward_cfg.obs_size = obs_shape[0]
- reward_cfg.action_size = act_shape[0]
-
- termination_cfg = dynamics_cfg.termination_mech
- termination_cfg.obs_size = obs_shape[0]
- termination_cfg.action_size = act_shape[0]
- return dynamics_cfg
diff --git a/cmrl/util/creator.py b/cmrl/util/creator.py
deleted file mode 100644
index f9fc723..0000000
--- a/cmrl/util/creator.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import pathlib
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
-
-import gym.wrappers
-import hydra
-import numpy as np
-import omegaconf
-from stable_baselines3.common.logger import Logger
-
-from cmrl.models.dynamics import ConstraintBasedDynamics, PlainEnsembleDynamics
-from cmrl.models.transition import ForwardEulerTransition
-from cmrl.util.config import get_complete_dynamics_cfg
-
-
-def create_dynamics(
- dynamics_cfg: omegaconf.DictConfig,
- obs_shape: Tuple[int, ...],
- act_shape: Tuple[int, ...],
- logger: Optional[Logger] = None,
- load_dir: Optional[Union[str, pathlib.Path]] = None,
- load_device: Optional[str] = None,
-):
- if dynamics_cfg.name == "plain_dynamics":
- dynamics_class = PlainEnsembleDynamics
- elif dynamics_cfg.name == "constraint_based_dynamics":
- dynamics_class = ConstraintBasedDynamics
- else:
- raise NotImplementedError
-
- dynamics_cfg = get_complete_dynamics_cfg(dynamics_cfg, obs_shape, act_shape)
- transition = hydra.utils.instantiate(dynamics_cfg.transition, _recursive_=False)
- if dynamics_cfg.multi_step == "none":
- pass
- elif dynamics_cfg.multi_step.startswith("forward_euler"):
- repeat_times = int(dynamics_cfg.multi_step[len("forward_euler") + 1 :])
- transition = ForwardEulerTransition(transition, repeat_times)
- else:
- raise NotImplementedError
-
- if dynamics_cfg.learned_reward:
- reward_mech = hydra.utils.instantiate(dynamics_cfg.reward_mech, _recursive_=False)
- else:
- reward_mech = None
-
- if dynamics_cfg.learned_termination:
- termination_mech = hydra.utils.instantiate(dynamics_cfg.termination_mech, _recursive_=False)
- raise NotImplementedError
- else:
- termination_mech = None
-
- dynamics_model = dynamics_class(
- transition=transition,
- learned_reward=dynamics_cfg.learned_reward,
- reward_mech=reward_mech,
- learned_termination=dynamics_cfg.learned_termination,
- termination_mech=termination_mech,
- optim_lr=dynamics_cfg.optim_lr,
- weight_decay=dynamics_cfg.weight_decay,
- logger=logger,
- )
- if load_dir:
- dynamics_model.load(load_dir, load_device)
-
- return dynamics_model
diff --git a/cmrl/util/env.py b/cmrl/util/env.py
deleted file mode 100644
index a7628f6..0000000
--- a/cmrl/util/env.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Dict, Optional, Tuple, Union, cast
-
-import emei
-import gym
-import omegaconf
-import torch
-
-import cmrl.types
-
-
-def to_num(s):
- try:
- return int(s)
- except ValueError:
- return float(s)
-
-
-def get_term_and_reward_fn(
- cfg: omegaconf.DictConfig,
-) -> Tuple[cmrl.types.TermFnType, Optional[cmrl.types.RewardFnType]]:
- return None, None
-
-
-def make_env(
- cfg: omegaconf.DictConfig,
-) -> Tuple[emei.EmeiEnv, cmrl.types.TermFnType, Optional[cmrl.types.RewardFnType], Optional[cmrl.types.InitObsFnType],]:
- if "gym___" in cfg.task.env:
- env = gym.make(cfg.task.env.split("___")[1])
- term_fn, reward_fn = get_term_and_reward_fn(cfg)
- init_obs_fn = None
- elif "emei___" in cfg.task.env:
-        env_name, params = cfg.task.env.split("___")[1:3]
- kwargs = dict([(item.split("=")[0], to_num(item.split("=")[1])) for item in params.split("&")])
- env = cast(emei.EmeiEnv, gym.make(env_name, **kwargs))
- term_fn = env.get_terminal
- reward_fn = env.get_reward
- init_obs_fn = env.get_batch_init_obs
- else:
- raise NotImplementedError
-
- # set seed
- env.reset(seed=cfg.seed)
- env.observation_space.seed(cfg.seed + 1)
- env.action_space.seed(cfg.seed + 2)
- return env, term_fn, reward_fn, init_obs_fn
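For readers of old configs: the removed `make_env` packed the environment id and its keyword arguments into a single `___`- and `&`-delimited string. The parsing reduces to a few lines; the task id below is made up for illustration:

```python
def to_num(s):
    try:
        return int(s)
    except ValueError:
        return float(s)


env_str = "emei___SomePendulum-v0___freq_rate=1&time_step=0.02"
_, env_name, params = env_str.split("___")
kwargs = {k: to_num(v) for k, v in (item.split("=") for item in params.split("&"))}
print(env_name, kwargs)  # SomePendulum-v0 {'freq_rate': 1, 'time_step': 0.02}
```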
diff --git a/cmrl/util/transition_iterator.py b/cmrl/util/transition_iterator.py
deleted file mode 100644
index b5e1928..0000000
--- a/cmrl/util/transition_iterator.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-import pathlib
-import warnings
-from typing import Any, List, Optional, Sequence, Sized, Tuple, Type, Union
-
-import numpy as np
-
-from cmrl.types import InteractionBatch
-
-
-def _consolidate_batches(batches: Sequence[InteractionBatch]) -> InteractionBatch:
- len_batches = len(batches)
- b0 = batches[0]
- obs = np.empty((len_batches,) + b0.batch_obs.shape, dtype=b0.batch_obs.dtype)
- act = np.empty((len_batches,) + b0.batch_action.shape, dtype=b0.batch_action.dtype)
- next_obs = np.empty((len_batches,) + b0.batch_obs.shape, dtype=b0.batch_obs.dtype)
- rewards = np.empty((len_batches,) + b0.batch_reward.shape, dtype=np.float32)
- dones = np.empty((len_batches,) + b0.batch_done.shape, dtype=bool)
- for i, b in enumerate(batches):
- obs[i] = b.batch_obs
- act[i] = b.batch_action
- next_obs[i] = b.batch_next_obs
- rewards[i] = b.batch_reward
- dones[i] = b.batch_done
- return InteractionBatch(obs, act, next_obs, rewards, dones)
-
-
-class TransitionIterator:
- """An iterator for batches of transitions.
-
- The iterator can be used doing:
-
- .. code-block:: python
-
- for batch in batch_iterator:
- do_something_with_batch()
-
- Rather than be constructed directly, the preferred way to use objects of this class
- is for the user to obtain them from :class:`ReplayBuffer`.
-
- Args:
- transitions (:class:`InteractionBatch`): the transition data used to built
- the iterator.
- batch_size (int): the batch size to use when iterating over the stored data.
- shuffle_each_epoch (bool): if ``True`` the iteration order is shuffled everytime a
- loop over the data is completed. Defaults to ``False``.
- rng (np.random.Generator, optional): a random number generator when sampling
- batches. If None (default value), a new default generator will be used.
- """
-
- def __init__(
- self,
- transitions: InteractionBatch,
- batch_size: int,
- shuffle_each_epoch: bool = False,
- rng: Optional[np.random.Generator] = None,
- ):
- self.transitions = transitions
- self.num_stored = len(transitions)
- self._order: np.ndarray = np.arange(self.num_stored)
- self.batch_size = batch_size
- self._current_batch = 0
- self._shuffle_each_epoch = shuffle_each_epoch
- self._rng = rng if rng is not None else np.random.default_rng()
-
- def _get_indices_next_batch(self) -> Sized:
- start_idx = self._current_batch * self.batch_size
- if start_idx >= self.num_stored:
- raise StopIteration
- end_idx = min((self._current_batch + 1) * self.batch_size, self.num_stored)
- order_indices = range(start_idx, end_idx)
- indices = self._order[order_indices]
- self._current_batch += 1
- return indices
-
- def __iter__(self):
- self._current_batch = 0
- if self._shuffle_each_epoch:
- self._order = self._rng.permutation(self.num_stored)
- return self
-
- def __next__(self):
- return self[self._get_indices_next_batch()]
-
- def ensemble_size(self):
- return 0
-
- def __len__(self):
- return (self.num_stored - 1) // self.batch_size + 1
-
- def __getitem__(self, item):
- return self.transitions[item]
-
-
-class BootstrapIterator(TransitionIterator):
- def __init__(
- self,
- transitions: InteractionBatch,
- batch_size: int,
- ensemble_size: int,
- shuffle_each_epoch: bool = False,
- permute_indices: bool = True,
- rng: Optional[np.random.Generator] = None,
- ):
- super().__init__(transitions, batch_size, shuffle_each_epoch=shuffle_each_epoch, rng=rng)
- self._ensemble_size = ensemble_size
- self._permute_indices = permute_indices
- self._bootstrap_iter = ensemble_size > 1
- self.member_indices = self._sample_member_indices()
-
- def _sample_member_indices(self) -> np.ndarray:
- member_indices = np.empty((self.ensemble_size, self.num_stored), dtype=int)
- if self._permute_indices:
- for i in range(self.ensemble_size):
- member_indices[i] = self._rng.permutation(self.num_stored)
- else:
- member_indices = self._rng.choice(
- self.num_stored,
- size=(self.ensemble_size, self.num_stored),
- replace=True,
- )
- return member_indices
-
- def __iter__(self):
- super().__iter__()
- return self
-
- def __next__(self):
- if not self._bootstrap_iter:
- return super().__next__()
- indices = self._get_indices_next_batch()
- batches = []
- for member_idx in self.member_indices:
- content_indices = member_idx[indices]
- batches.append(self[content_indices])
- return _consolidate_batches(batches)
-
- def toggle_bootstrap(self):
- """Toggles whether the iterator returns a batch per model or a single batch."""
- if self.ensemble_size > 1:
- self._bootstrap_iter = not self._bootstrap_iter
-
- @property
- def ensemble_size(self):
- return self._ensemble_size
-
-
-def _sequence_getitem_impl(
- transitions: InteractionBatch,
- batch_size: int,
- sequence_length: int,
- valid_starts: np.ndarray,
- item: Any,
-):
- start_indices = valid_starts[item].repeat(sequence_length)
- increment_array = np.tile(np.arange(sequence_length), len(item))
- full_trajectory_indices = start_indices + increment_array
- return transitions[full_trajectory_indices].add_new_batch_dim(min(batch_size, len(item)))
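For anyone migrating experiment code off these iterators: `BootstrapIterator` yields one independently resampled batch per ensemble member, stacked along a leading member axis by `_consolidate_batches`. A minimal usage sketch against the modules as they existed before this commit:

```python
import numpy as np

# both imports refer to modules removed in this commit
from cmrl.types import InteractionBatch
from cmrl.util.transition_iterator import BootstrapIterator

n = 100
data = InteractionBatch(
    batch_obs=np.random.randn(n, 4).astype(np.float32),
    batch_action=np.random.randn(n, 1).astype(np.float32),
    batch_next_obs=np.random.randn(n, 4).astype(np.float32),
    batch_reward=np.random.randn(n, 1).astype(np.float32),
    batch_done=np.zeros((n, 1), dtype=bool),
)

it = BootstrapIterator(data, batch_size=32, ensemble_size=7, shuffle_each_epoch=True)
for batch in it:
    # batch.batch_obs: [7, <=32, 4], one bootstrap resample per ensemble member
    pass
```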
diff --git a/cmrl/utils/RCIT.py b/cmrl/utils/RCIT.py
new file mode 100644
index 0000000..76c6c23
--- /dev/null
+++ b/cmrl/utils/RCIT.py
@@ -0,0 +1,643 @@
+import numpy as np
+from numpy import sqrt
+from numpy.linalg import eigh, eigvalsh
+from scipy import stats
+from sklearn.gaussian_process import GaussianProcessRegressor
+from sklearn.gaussian_process.kernels import RBF
+from sklearn.gaussian_process.kernels import ConstantKernel as C
+from sklearn.gaussian_process.kernels import WhiteKernel
+
+from causallearn.utils.KCI.GaussianKernel import GaussianKernel
+from causallearn.utils.KCI.Kernel import Kernel
+from causallearn.utils.KCI.LinearKernel import LinearKernel
+from causallearn.utils.KCI.PolynomialKernel import PolynomialKernel
+import random
+import math
+import time
+from numpy.linalg import inv
+
+##################### For Random Feature #####################
+try:
+ import rpy2
+ import rpy2.robjects
+
+ rpy2.robjects.r['options'](warn=-1)
+ from rpy2.robjects.packages import importr
+ import rpy2.robjects.numpy2ri
+
+ rpy2.robjects.numpy2ri.activate()
+except Exception:
+    print("Could not import rpy2 package")
+
+try:
+    importr('RCIT')
+except Exception:
+    print("Could not import r-package RCIT")
+
+
+def set_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+
+
+###############################################################
+
+
+# PyCharm may report "Cannot find reference 'xxx' in '__init__.pyi'"; this is a PyCharm bug and can be ignored.
+class KCI_UInd(object):
+ """
+ Python implementation of Kernel-based Conditional Independence (KCI) test. Unconditional version.
+ The original Matlab implementation can be found in http://people.tuebingen.mpg.de/kzhang/KCI-test.zip
+
+ References
+ ----------
+ [1] K. Zhang, J. Peters, D. Janzing, and B. Schölkopf,
+ "A kernel-based conditional independence test and application in causal discovery," In UAI 2011.
+    [2] A. Gretton, K. Fukumizu, C.-H. Teo, L. Song, B. Schölkopf, and A. Smola, "A kernel
+        statistical test of independence." In NIPS 21, 2007.
+ """
+
+ def __init__(self, kernelX='Gaussian', kernelY='Gaussian', null_ss=1000, approx=True, est_width='empirical',
+ polyd=2, kwidthx=None, kwidthy=None):
+ """
+ Construct the KCI_UInd model.
+
+ Parameters
+ ----------
+ kernelX: kernel function for input data x
+ 'Gaussian': Gaussian kernel
+ 'Polynomial': Polynomial kernel
+ 'Linear': Linear kernel
+ kernelY: kernel function for input data y
+ est_width: set kernel width for Gaussian kernels
+ 'empirical': set kernel width using empirical rules
+ 'median': set kernel width using the median trick
+ 'manual': set by users
+ null_ss: sample size in simulating the null distribution
+ approx: whether to use gamma approximation (default=True)
+        polyd: polynomial kernel degrees (default=2)
+ kwidthx: kernel width for data x (standard deviation sigma)
+ kwidthy: kernel width for data y (standard deviation sigma)
+ """
+
+ self.kernelX = kernelX
+ self.kernelY = kernelY
+ self.est_width = est_width
+ self.polyd = polyd
+ self.kwidthx = kwidthx
+ self.kwidthy = kwidthy
+ self.nullss = null_ss
+ self.thresh = 1e-6
+ self.approx = approx
+
+ def compute_pvalue(self, data_x=None, data_y=None):
+ """
+ Main function: compute the p value and return it together with the test statistic
+
+ Parameters
+ ----------
+ data_x: input data for x (nxd1 array)
+ data_y: input data for y (nxd2 array)
+
+ Returns
+ _________
+ pvalue: p value (scalar)
+ test_stat: test statistic (scalar)
+
+ [Notes for speedup optimization]
+ Kx, Ky are both symmetric with diagonals equal to 1 (no matter what the kernel is)
+ Kxc, Kyc are both symmetric
+ """
+
+ Kx, Ky = self.kernel_matrix(data_x, data_y)
+ test_stat, Kxc, Kyc = self.HSIC_V_statistic(Kx, Ky)
+
+ if self.approx:
+ k_appr, theta_appr = self.get_kappa(Kxc, Kyc)
+ pvalue = 1 - stats.gamma.cdf(test_stat, k_appr, 0, theta_appr)
+ else:
+ null_dstr = self.null_sample_spectral(Kxc, Kyc)
+ pvalue = sum(null_dstr.squeeze() > test_stat) / float(self.nullss)
+ return pvalue, test_stat
+
+ def compute_pvalue_rf(self, data_x=None, data_y=None):
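+        # Random-Fourier-feature variant: delegates to the R package RCIT's RIT
+        # routine through rpy2 (see the README for the R-side installation).
+        # data_x and data_y are nxd numpy arrays, as in compute_pvalue.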
+ rit = rpy2.robjects.r['RIT'](data_x, data_y, approx="lpd4", seed=42)
+ sta = float(rit.rx2('Sta')[0])
+ pval = float(rit.rx2('p')[0])
+ return pval, sta
+
+ def kernel_matrix(self, data_x, data_y):
+ """
+ Compute kernel matrix for data x and data y
+
+ Parameters
+ ----------
+ data_x: input data for x (nxd1 array)
+ data_y: input data for y (nxd2 array)
+
+ Returns
+ _________
+ Kx: kernel matrix for data_x (nxn)
+ Ky: kernel matrix for data_y (nxn)
+ """
+ if self.kernelX == 'Gaussian':
+ if self.est_width == 'manual':
+ if self.kwidthx is not None:
+ kernelX = GaussianKernel(self.kwidthx)
+ else:
+ raise Exception('specify kwidthx')
+ else:
+ kernelX = GaussianKernel()
+ if self.est_width == 'median':
+ kernelX.set_width_median(data_x)
+ elif self.est_width == 'empirical':
+ kernelX.set_width_empirical_hsic(data_x)
+ else:
+ raise Exception('Undefined kernel width estimation method')
+ elif self.kernelX == 'Polynomial':
+ kernelX = PolynomialKernel(self.polyd)
+ elif self.kernelX == 'Linear':
+ kernelX = LinearKernel()
+ else:
+ raise Exception('Undefined kernel function')
+
+ if self.kernelY == 'Gaussian':
+ if self.est_width == 'manual':
+ if self.kwidthy is not None:
+ kernelY = GaussianKernel(self.kwidthy)
+ else:
+ raise Exception('specify kwidthy')
+ else:
+ kernelY = GaussianKernel()
+ if self.est_width == 'median':
+ kernelY.set_width_median(data_y)
+ elif self.est_width == 'empirical':
+ kernelY.set_width_empirical_hsic(data_y)
+ else:
+ raise Exception('Undefined kernel width estimation method')
+ elif self.kernelY == 'Polynomial':
+ kernelY = PolynomialKernel(self.polyd)
+ elif self.kernelY == 'Linear':
+ kernelY = LinearKernel()
+ else:
+ raise Exception('Undefined kernel function')
+
+ data_x = stats.zscore(data_x, ddof=1, axis=0)
+ data_x[np.isnan(data_x)] = 0. # in case some dim of data_x is constant
+ data_y = stats.zscore(data_y, ddof=1, axis=0)
+ data_y[np.isnan(data_y)] = 0.
+ # We set 'ddof=1' to conform to the normalization way in the original Matlab implementation in
+ # http://people.tuebingen.mpg.de/kzhang/KCI-test.zip
+
+ Kx = kernelX.kernel(data_x)
+ Ky = kernelY.kernel(data_y)
+ return Kx, Ky
+
+ def HSIC_V_statistic(self, Kx, Ky):
+ """
+ Compute V test statistic from kernel matrices Kx and Ky
+ Parameters
+ ----------
+ Kx: kernel matrix for data_x (nxn)
+ Ky: kernel matrix for data_y (nxn)
+
+ Returns
+ _________
+ Vstat: HSIC v statistics
+ Kxc: centralized kernel matrix for data_x (nxn)
+ Kyc: centralized kernel matrix for data_y (nxn)
+ """
+ Kxc = Kernel.center_kernel_matrix(Kx)
+ Kyc = Kernel.center_kernel_matrix(Ky)
+ V_stat = np.sum(Kxc * Kyc)
+ return V_stat, Kxc, Kyc
+
+ def null_sample_spectral(self, Kxc, Kyc):
+ """
+ Simulate data from null distribution
+
+ Parameters
+ ----------
+ Kxc: centralized kernel matrix for data_x (nxn)
+ Kyc: centralized kernel matrix for data_y (nxn)
+
+ Returns
+ _________
+ null_dstr: samples from the null distribution
+
+ """
+ T = Kxc.shape[0]
+ if T > 1000:
+            num_eig = int(np.floor(T / 2))
+ else:
+ num_eig = T
+ lambdax = eigvalsh(Kxc)
+ lambday = eigvalsh(Kyc)
+ lambdax = -np.sort(-lambdax)
+ lambday = -np.sort(-lambday)
+ lambdax = lambdax[0:num_eig]
+ lambday = lambday[0:num_eig]
+ lambda_prod = np.dot(lambdax.reshape(num_eig, 1), lambday.reshape(1, num_eig)).reshape(
+ (num_eig ** 2, 1))
+ lambda_prod = lambda_prod[lambda_prod > lambda_prod.max() * self.thresh]
+ f_rand = np.random.chisquare(1, (lambda_prod.shape[0], self.nullss))
+ null_dstr = lambda_prod.T.dot(f_rand) / T
+ return null_dstr
+
+ def get_kappa(self, Kx, Ky):
+ """
+ Get parameters for the approximated gamma distribution
+ Parameters
+ ----------
+ Kx: kernel matrix for data_x (nxn)
+ Ky: kernel matrix for data_y (nxn)
+
+ Returns
+ _________
+ k_appr, theta_appr: approximated parameters of the gamma distribution
+
+ [Updated @Haoyue 06/24/2022]
+ equivalent to:
+ var_appr = 2 * np.trace(Kx.dot(Kx)) * np.trace(Ky.dot(Ky)) / T / T
+ based on the fact that:
+ np.trace(K.dot(K)) == np.sum(K * K.T), where here K is symmetric
+ we can save time on the dot product by only considering the diagonal entries of K.dot(K)
+ time complexity is reduced from O(n^3) (matrix dot) to O(n^2) (traverse each element),
+ where n is usually big (sample size).
+ """
+ T = Kx.shape[0]
+ mean_appr = np.trace(Kx) * np.trace(Ky) / T
+ var_appr = 2 * np.sum(Kx ** 2) * np.sum(Ky ** 2) / T / T # same as np.sum(Kx * Kx.T) ..., here Kx is symmetric
+ k_appr = mean_appr ** 2 / var_appr
+ theta_appr = var_appr / mean_appr
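+        # Moment matching: a Gamma(k, theta) variable has mean k * theta = mean_appr
+        # and variance k * theta ** 2 = var_appr, so the approximation preserves the
+        # null distribution's first two moments.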
+ return k_appr, theta_appr
+
+
+class KCI_CInd(object):
+ """
+ Python implementation of Kernel-based Conditional Independence (KCI) test. Conditional version.
+ The original Matlab implementation can be found in http://people.tuebingen.mpg.de/kzhang/KCI-test.zip
+
+ References
+ ----------
+ [1] K. Zhang, J. Peters, D. Janzing, and B. Schölkopf, "A kernel-based conditional independence test and application in causal discovery," In UAI 2011.
+ """
+
+ def __init__(self, kernelX='Gaussian', kernelY='Gaussian', kernelZ='Gaussian', nullss=5000, est_width='empirical',
+ use_gp=False, approx=True, polyd=2, kwidthx=None, kwidthy=None, kwidthz=None):
+ """
+ Construct the KCI_CInd model.
+ Parameters
+ ----------
+ kernelX: kernel function for input data x
+ 'Gaussian': Gaussian kernel
+ 'Polynomial': Polynomial kernel
+ 'Linear': Linear kernel
+ kernelY: kernel function for input data y
+ kernelZ: kernel function for input data z (conditional variable)
+ est_width: set kernel width for Gaussian kernels
+ 'empirical': set kernel width using empirical rules
+ 'median': set kernel width using the median trick
+ 'manual': set by users
+        nullss: sample size in simulating the null distribution (default=5000)
+        use_gp: whether to use a Gaussian process to determine the kernel width for z
+        approx: whether to use gamma approximation (default=True)
+        polyd: polynomial kernel degrees (default=2)
+ kwidthx: kernel width for data x (standard deviation sigma, default None)
+ kwidthy: kernel width for data y (standard deviation sigma)
+ kwidthz: kernel width for data z (standard deviation sigma)
+ """
+ self.kernelX = kernelX
+ self.kernelY = kernelY
+ self.kernelZ = kernelZ
+ self.est_width = est_width
+ self.polyd = polyd
+ self.kwidthx = kwidthx
+ self.kwidthy = kwidthy
+ self.kwidthz = kwidthz
+ self.nullss = nullss
+ self.epsilon_x = 1e-3 # To conform to the original Matlab implementation.
+ self.epsilon_y = 1e-3
+ self.use_gp = use_gp
+ self.thresh = 1e-5
+ self.approx = approx
+
+    def compute_pvalue_rf(self, data_x=None, data_y=None, data_z=None):
+        rit = rpy2.robjects.r['RCIT'](data_x, data_y, data_z, num_f=10000, num_f2=200, approx="lpd4", seed=42)
+        sta = float(rit.rx2('Sta')[0])
+        pval = float(rit.rx2('p')[0])
+        return pval, sta
+
+ def compute_pvalue(self, data_x=None, data_y=None, data_z=None):
+ """
+ Main function: compute the p value and return it together with the test statistic
+ Parameters
+ ----------
+ data_x: input data for x (nxd1 array)
+ data_y: input data for y (nxd2 array)
+ data_z: input data for z (nxd3 array)
+
+ Returns
+ _________
+ pvalue: p value
+ test_stat: test statistic
+ """
+ Kx, Ky, Kzx, Kzy = self.kernel_matrix(data_x, data_y, data_z)
+ test_stat, KxR, KyR = self.KCI_V_statistic(Kx, Ky, Kzx, Kzy)
+ uu_prod, size_u = self.get_uuprod(KxR, KyR)
+ if self.approx:
+ k_appr, theta_appr = self.get_kappa(uu_prod)
+ pvalue = 1 - stats.gamma.cdf(test_stat, k_appr, 0, theta_appr)
+ else:
+ null_samples = self.null_sample_spectral(uu_prod, size_u, Kx.shape[0])
+ pvalue = sum(null_samples > test_stat) / float(self.nullss)
+ return pvalue, test_stat
+
+ def kernel_matrix(self, data_x, data_y, data_z):
+ """
+ Compute kernel matrix for data x, data y, and data_z
+ Parameters
+ ----------
+ data_x: input data for x (nxd1 array)
+ data_y: input data for y (nxd2 array)
+ data_z: input data for z (nxd3 array)
+
+ Returns
+ _________
+ Kx: kernel matrix for data_x (nxn)
+ Ky: kernel matrix for data_y (nxn)
+ Kzx: centering kernel matrix for data_x (nxn)
+        Kzy: centering kernel matrix for data_y (nxn)
+ """
+ # normalize the data
+ data_x = stats.zscore(data_x, ddof=1, axis=0)
+ data_x[np.isnan(data_x)] = 0.
+
+ data_y = stats.zscore(data_y, ddof=1, axis=0)
+ data_y[np.isnan(data_y)] = 0.
+
+ data_z = stats.zscore(data_z, ddof=1, axis=0)
+ data_z[np.isnan(data_z)] = 0.
+ # We set 'ddof=1' to conform to the normalization way in the original Matlab implementation in
+ # http://people.tuebingen.mpg.de/kzhang/KCI-test.zip
+
+ # concatenate x and z
+ data_x = np.concatenate((data_x, 0.5 * data_z), axis=1)
+ if self.kernelX == 'Gaussian':
+ if self.est_width == 'manual':
+ if self.kwidthx is not None:
+ kernelX = GaussianKernel(self.kwidthx)
+ else:
+ raise Exception('specify kwidthx')
+ else:
+ kernelX = GaussianKernel()
+ if self.est_width == 'median':
+ kernelX.set_width_median(data_x)
+ elif self.est_width == 'empirical':
+ # kernelX's empirical width is determined by data_z's shape, please refer to the original code
+ # (http://people.tuebingen.mpg.de/kzhang/KCI-test.zip) in the file
+ # 'algorithms/CInd_test_new_withGP.m', Line 37 to 52.
+ kernelX.set_width_empirical_kci(data_z)
+ else:
+ raise Exception('Undefined kernel width estimation method')
+ elif self.kernelX == 'Polynomial':
+ kernelX = PolynomialKernel(self.polyd)
+ elif self.kernelX == 'Linear':
+ kernelX = LinearKernel()
+ else:
+ raise Exception('Undefined kernel function')
+
+ if self.kernelY == 'Gaussian':
+ if self.est_width == 'manual':
+ if self.kwidthy is not None:
+ kernelY = GaussianKernel(self.kwidthy)
+ else:
+ raise Exception('specify kwidthy')
+ else:
+ kernelY = GaussianKernel()
+ if self.est_width == 'median':
+ kernelY.set_width_median(data_y)
+ elif self.est_width == 'empirical':
+ # kernelY's empirical width is determined by data_z's shape, please refer to the original code
+ # (http://people.tuebingen.mpg.de/kzhang/KCI-test.zip) in the file
+ # 'algorithms/CInd_test_new_withGP.m', Line 37 to 52.
+ kernelY.set_width_empirical_kci(data_z)
+ else:
+ raise Exception('Undefined kernel width estimation method')
+ elif self.kernelY == 'Polynomial':
+ kernelY = PolynomialKernel(self.polyd)
+ elif self.kernelY == 'Linear':
+ kernelY = LinearKernel()
+ else:
+ raise Exception('Undefined kernel function')
+
+ Kx = kernelX.kernel(data_x)
+ Ky = kernelY.kernel(data_y)
+
+ # centering kernel matrix
+ Kx = Kernel.center_kernel_matrix(Kx)
+ Ky = Kernel.center_kernel_matrix(Ky)
+
+ if self.kernelZ == 'Gaussian':
+ if not self.use_gp:
+ if self.est_width == 'manual':
+ if self.kwidthz is not None:
+ kernelZ = GaussianKernel(self.kwidthz)
+ else:
+ raise Exception('specify kwidthz')
+ else:
+ kernelZ = GaussianKernel()
+ if self.est_width == 'median':
+ kernelZ.set_width_median(data_z)
+ elif self.est_width == 'empirical':
+ kernelZ.set_width_empirical_kci(data_z)
+ Kzx = kernelZ.kernel(data_z)
+ Kzx = Kernel.center_kernel_matrix(Kzx)
+ # centering kernel matrix to conform with the original Matlab implementation,
+ # specifically, Line 100 in the file 'algorithms/CInd_test_new_withGP.m'
+ Kzy = Kzx
+ else:
+ # learning the kernel width of Kz using Gaussian process
+ n, Dz = data_z.shape
+ if self.kernelX == 'Gaussian':
+ widthz = sqrt(1.0 / (kernelX.width * data_x.shape[1]))
+ else:
+ widthz = 1.0
+ # Instantiate a Gaussian Process model for x
+ wx, vx = eigh(0.5 * (Kx + Kx.T))
+ topkx = int(np.min((400, np.floor(n / 4))))
+ idx = np.argsort(-wx)
+ wx = wx[idx]
+ vx = vx[:, idx]
+ wx = wx[0:topkx]
+ vx = vx[:, 0:topkx]
+ vx = vx[:, wx > wx.max() * self.thresh]
+ wx = wx[wx > wx.max() * self.thresh]
+ vx = 2 * sqrt(n) * vx.dot(np.diag(np.sqrt(wx))) / sqrt(wx[0])
+ kernelx = C(1.0, (1e-3, 1e3)) * RBF(widthz * np.ones(Dz), (1e-2, 1e2)) + WhiteKernel(0.1, (1e-10, 1e+1))
+ gpx = GaussianProcessRegressor(kernel=kernelx)
+ # fit Gaussian process, including hyperparameter optimization
+ gpx.fit(data_z, vx)
+
+ # construct Gaussian kernels according to learned hyperparameters
+ Kzx = gpx.kernel_.k1(data_z, data_z)
+ self.epsilon_x = np.exp(gpx.kernel_.theta[-1])
+
+ # Instantiate a Gaussian Process model for y
+ wy, vy = eigh(0.5 * (Ky + Ky.T))
+ topky = int(np.min((400, np.floor(n / 4))))
+ idy = np.argsort(-wy)
+ wy = wy[idy]
+ vy = vy[:, idy]
+ wy = wy[0:topky]
+ vy = vy[:, 0:topky]
+ vy = vy[:, wy > wy.max() * self.thresh]
+ wy = wy[wy > wy.max() * self.thresh]
+ vy = 2 * sqrt(n) * vy.dot(np.diag(np.sqrt(wy))) / sqrt(wy[0])
+ kernely = C(1.0, (1e-3, 1e3)) * RBF(widthz * np.ones(Dz), (1e-2, 1e2)) + WhiteKernel(0.1, (1e-10, 1e+1))
+ gpy = GaussianProcessRegressor(kernel=kernely)
+ # fit Gaussian process, including hyperparameter optimization
+ gpy.fit(data_z, vy)
+
+ # construct Gaussian kernels according to learned hyperparameters
+ Kzy = gpy.kernel_.k1(data_z, data_z)
+ self.epsilon_y = np.exp(gpy.kernel_.theta[-1])
+ elif self.kernelZ == 'Polynomial':
+ kernelZ = PolynomialKernel(self.polyd)
+ Kzx = kernelZ.kernel(data_z)
+ Kzx = Kernel.center_kernel_matrix(Kzx)
+ Kzy = Kzx
+ elif self.kernelZ == 'Linear':
+ kernelZ = LinearKernel()
+ Kzx = kernelZ.kernel(data_z)
+ Kzx = Kernel.center_kernel_matrix(Kzx)
+ Kzy = Kzx
+ else:
+ raise Exception('Undefined kernel function')
+ return Kx, Ky, Kzx, Kzy
+
+ def KCI_V_statistic(self, Kx, Ky, Kzx, Kzy):
+ """
+ Compute V test statistic from kernel matrices Kx and Ky
+ Parameters
+ ----------
+ Kx: kernel matrix for data_x (nxn)
+ Ky: kernel matrix for data_y (nxn)
+ Kzx: centering kernel matrix for data_x (nxn)
+ Kzy: centering kernel matrix for data_y (nxn)
+
+ Returns
+ ----------
+ Vstat: KCI V statistic
+ KxR: centralized kernel matrix for data_x (nxn)
+ KyR: centralized kernel matrix for data_y (nxn)
+
+ [Updated @Haoyue 06/24/2022]
+ 1. Kx, Ky, Kzx, Kzy are all symmetric matrices.
+ - * Kx's diagonal elements are not all the same, because the kernel matrix Kx is centered.
+ * Before centering, all of Kx's diagonal elements equal 1 (because of exp(-0.5 * sq_dists * self.width)).
+ * The same applies to Ky.
+ - * If (self.kernelZ == 'Gaussian' and self.use_gp), then all of Kzx's diagonal elements are the same (though not necessarily 1).
+ * The same applies to Kzy.
+ 2. If not (self.kernelZ == 'Gaussian' and self.use_gp), then (Kzx == Kzy).all() holds.
+ This saves one repeated computation of pinv(Kzy + epsilon * I), which dominates the running time.
+ """
+ KxR, Rzx = Kernel.center_kernel_matrix_regression(Kx, Kzx, self.epsilon_x)
+ if self.epsilon_x != self.epsilon_y or (self.kernelZ == 'Gaussian' and self.use_gp):
+ KyR, _ = Kernel.center_kernel_matrix_regression(Ky, Kzy, self.epsilon_y)
+ else:
+ # assert np.all(Kzx == Kzy), 'Kzx and Kzy are the same'
+ KyR = Rzx.dot(Ky.dot(Rzx))
+ Vstat = np.sum(KxR * KyR)
+ return Vstat, KxR, KyR
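+ # Editor's note (an aside, not from the original source): since KxR and KyR
+ # are symmetric, the elementwise sum above equals a trace:
+ #     Vstat = sum(KxR * KyR) = trace(KxR @ KyR).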
+
+ def get_uuprod(self, Kx, Ky):
+ """
+ Compute eigenvalues for null distribution estimation
+
+ Parameters
+ ----------
+ Kx: centralized kernel matrix for data_x (nxn)
+ Ky: centralized kernel matrix for data_y (nxn)
+
+ Returns
+ ----------
+ uu_prod: product of the (scaled) eigenvectors of Kx and Ky
+ size_u: number of eigenvector products
+
+ """
+ wx, vx = eigh(0.5 * (Kx + Kx.T))
+ wy, vy = eigh(0.5 * (Ky + Ky.T))
+ idx = np.argsort(-wx)
+ idy = np.argsort(-wy)
+ wx = wx[idx]
+ vx = vx[:, idx]
+ wy = wy[idy]
+ vy = vy[:, idy]
+ vx = vx[:, wx > np.max(wx) * self.thresh]
+ wx = wx[wx > np.max(wx) * self.thresh]
+ vy = vy[:, wy > np.max(wy) * self.thresh]
+ wy = wy[wy > np.max(wy) * self.thresh]
+ vx = vx.dot(np.diag(np.sqrt(wx)))
+ vy = vy.dot(np.diag(np.sqrt(wy)))
+
+ # calculate their product
+ T = Kx.shape[0]
+ num_eigx = vx.shape[1]
+ num_eigy = vy.shape[1]
+ size_u = num_eigx * num_eigy
+ uu = np.zeros((T, size_u))
+ for i in range(0, num_eigx):
+ for j in range(0, num_eigy):
+ uu[:, i * num_eigy + j] = vx[:, i] * vy[:, j]
+
+ if size_u > T:
+ uu_prod = uu.dot(uu.T)
+ else:
+ uu_prod = uu.T.dot(uu)
+
+ return uu_prod, size_u
+
+ def null_sample_spectral(self, uu_prod, size_u, T):
+ """
+ Simulate data from null distribution
+
+ Parameters
+ ----------
+ uu_prod: product of the (scaled) eigenvectors of Kx and Ky
+ size_u: number of eigenvector products
+ T: sample size
+
+ Returns
+ ----------
+ null_dstr: samples from the null distribution
+
+ """
+ eig_uu = eigvalsh(uu_prod)
+ eig_uu = -np.sort(-eig_uu)
+ eig_uu = eig_uu[0:np.min((T, size_u))]
+ eig_uu = eig_uu[eig_uu > np.max(eig_uu) * self.thresh]
+
+ f_rand = np.random.chisquare(1, (eig_uu.shape[0], self.nullss))
+ null_dstr = eig_uu.T.dot(f_rand)
+ return null_dstr
+
+ def get_kappa(self, uu_prod):
+ """
+ Get parameters for the approximated gamma distribution
+ Parameters
+ ----------
+ uu_prod: product of the (scaled) eigenvectors of Kx and Ky
+
+ Returns
+ ----------
+ k_appr, theta_appr: approximated parameters of the gamma distribution
+
+ """
+ mean_appr = np.trace(uu_prod)
+ var_appr = 2 * np.trace(uu_prod.dot(uu_prod))
+ k_appr = mean_appr ** 2 / var_appr
+ theta_appr = var_appr / mean_appr
+ return k_appr, theta_appr
diff --git a/cmrl/util/__init__.py b/cmrl/utils/__init__.py
similarity index 92%
rename from cmrl/util/__init__.py
rename to cmrl/utils/__init__.py
index d4085c1..c59acbe 100644
--- a/cmrl/util/__init__.py
+++ b/cmrl/utils/__init__.py
@@ -40,11 +40,11 @@ def create_handler(cfg: Union[Dict, omegaconf.ListConfig, omegaconf.DictConfig])
target = cfg.overrides.env_cfg.get_dynamics_predict("_target_")
if "pybulletgym" in target:
- from cmrl.util.pybullet import PybulletEnvHandler
+ from cmrl.utils.pybullet import PybulletEnvHandler
return PybulletEnvHandler()
elif "mujoco" in target:
- from cmrl.util.mujoco import MujocoEnvHandler
+ from cmrl.utils.mujoco import MujocoEnvHandler
return MujocoEnvHandler()
else:
@@ -74,15 +74,15 @@ def create_handler_from_str(env_name: str):
(EnvHandler): A handler for the associated gym environment
"""
if "dmcontrol___" in env_name:
- from cmrl.util.dmcontrol import DmcontrolEnvHandler
+ from cmrl.utils.dmcontrol import DmcontrolEnvHandler
return DmcontrolEnvHandler()
elif "pybulletgym___" in env_name:
- from cmrl.util.pybullet import PybulletEnvHandler
+ from cmrl.utils.pybullet import PybulletEnvHandler
return PybulletEnvHandler()
elif "gym___" in env_name or env_name == "ideal_inv_pendulum":
- from cmrl.util.mujoco import MujocoEnvHandler
+ from cmrl.utils.mujoco import MujocoEnvHandler
return MujocoEnvHandler()
else:
diff --git a/cmrl/utils/config.py b/cmrl/utils/config.py
new file mode 100644
index 0000000..befe862
--- /dev/null
+++ b/cmrl/utils/config.py
@@ -0,0 +1,64 @@
+import pathlib
+from typing import Dict, Union, Optional
+from collections import defaultdict
+
+import omegaconf
+from omegaconf import DictConfig
+import pandas as pd
+import numpy as np
+
+PACKAGE_PATH = pathlib.Path(__file__).parent.parent.parent
+
+
+def load_hydra_cfg(results_dir: Union[str, pathlib.Path]) -> omegaconf.DictConfig:
+ """Loads a Hydra configuration from the given directory path.
+
+ Tries to load the configuration from "results_dir/.hydra/config.yaml".
+
+ Args:
+ results_dir (str or pathlib.Path): the path to the directory containing the config.
+
+ Returns:
+ (omegaconf.DictConfig): the loaded configuration.
+
+ """
+ results_dir = pathlib.Path(results_dir)
+ cfg_file = results_dir / ".hydra" / "config.yaml"
+ cfg = omegaconf.OmegaConf.load(cfg_file)
+ if not isinstance(cfg, omegaconf.DictConfig):
+ raise RuntimeError("Configuration format not a omegaconf.DictConf")
+ return cfg
+
+
+def exp_collect(cfg_extractor,
+ csv_extractor,
+ env_name="ContinuousCartPoleSwingUp-v0",
+ exp_name="default",
+ exp_path=None):
+ data = defaultdict(list)
+
+ if exp_path is None:
+ exp_path = PACKAGE_PATH / "exp"
+ exp_dir = exp_path / exp_name
+ env_dir = exp_dir / env_name
+
+ for params_dir in env_dir.glob("*"):
+ for dataset_dir in params_dir.glob("*"):
+ for time_dir in dataset_dir.glob("*"):
+ if not (time_dir / ".hydra").exists(): # exp by hydra's MULTIRUN mode, multi exp in this time
+ time_dir_list = list(time_dir.glob("*"))
+ else: # a single experiment under this timestamp
+ time_dir_list = [time_dir]
+
+ for single_dir in time_dir_list:
+ if single_dir.name == "multirun.yaml":
+ continue
+
+ cfg = load_hydra_cfg(single_dir)
+
+ key = cfg_extractor(cfg, params_dir.name, dataset_dir.name, time_dir.name)
+ if not key:
+ continue
+
+ data[key] = csv_extractor(single_dir / "log")
+ return data
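+
+# Hypothetical usage sketch (editor's note): `cfg_extractor` and `csv_extractor`
+# below are illustrative assumptions, not helpers shipped with cmrl.
+#
+#   def cfg_extractor(cfg, params_name, dataset_name, time_name):
+#       return cfg.seed, dataset_name  # group runs by seed and dataset
+#
+#   def csv_extractor(log_dir):
+#       return pd.read_csv(log_dir / "rollout.csv")["ep_rew_mean"].to_numpy()
+#
+#   data = exp_collect(cfg_extractor, csv_extractor, exp_name="default")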
diff --git a/cmrl/utils/creator.py b/cmrl/utils/creator.py
new file mode 100644
index 0000000..8b71dd8
--- /dev/null
+++ b/cmrl/utils/creator.py
@@ -0,0 +1,83 @@
+from typing import Optional, cast, List
+
+from gym import spaces
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+import numpy as np
+from stable_baselines3.common.vec_env import VecMonitor
+from stable_baselines3.common.logger import Logger
+from stable_baselines3.common.base_class import BaseAlgorithm
+
+from cmrl.types import Obs2StateFnType, State2ObsFnType
+from cmrl.models.dynamics import Dynamics
+from cmrl.models.fake_env import VecFakeEnv
+from cmrl.models.causal_mech.base import BaseCausalMech
+from cmrl.utils.variables import ContinuousVariable, BinaryVariable, DiscreteVariable, Variable, parse_space
+
+
+def create_agent(cfg: DictConfig, fake_env: VecFakeEnv, logger: Optional[Logger] = None):
+ agent = instantiate(cfg.algorithm.agent)(env=VecMonitor(fake_env))
+ agent = cast(BaseAlgorithm, agent)
+ agent.set_logger(logger)
+
+ return agent
+
+
+def create_dynamics(
+ cfg: DictConfig,
+ state_space: spaces.Space,
+ action_space: spaces.Space,
+ obs2state_fn: Obs2StateFnType,
+ state2obs_fn: State2ObsFnType,
+ logger: Optional[Logger] = None,
+):
+ extra_info = cfg.task.get("extra_variable_info", {})
+ obs_variables = parse_space(state_space, "obs", extra_info=extra_info)
+ act_variables = parse_space(action_space, "act", extra_info=extra_info)
+ next_obs_variables = parse_space(state_space, "next_obs", extra_info=extra_info)
+
+ # transition
+ assert cfg.transition.learn, "transition must be learned, or you should try model-free RL:)"
+ transition = instantiate(cfg.transition.mech)(
+ input_variables=obs_variables + act_variables,
+ output_variables=next_obs_variables,
+ logger=logger,
+ )
+ transition = cast(BaseCausalMech, transition)
+
+ # reward mech
+ assert cfg.reward_mech.mech.multi_step == "none", "reward-mech must be one-step"
+ if cfg.reward_mech.learn:
+ reward_mech = instantiate(cfg.reward_mech.mech)(
+ input_variables=obs_variables + act_variables + next_obs_variables,
+ output_variables=[ContinuousVariable("reward", dim=1, low=-np.inf, high=np.inf)],
+ logger=logger,
+ )
+ reward_mech = cast(BaseCausalMech, reward_mech)
+ else:
+ reward_mech = None
+
+ # termination mech
+ assert cfg.termination_mech.mech.multi_step == "none", "termination-mech must be one-step"
+ if cfg.termination_mech.learn:
+ termination_mech = instantiate(cfg.termination_mech.mech)(
+ input_variables=obs_variables + act_variables + next_obs_variables,
+ output_variables=[BinaryVariable("terminal")],
+ logger=logger,
+ )
+ termination_mech = cast(BaseCausalMech, termination_mech)
+ else:
+ termination_mech = None
+
+ dynamics = Dynamics(
+ transition=transition,
+ reward_mech=reward_mech,
+ termination_mech=termination_mech,
+ state_space=state_space,
+ action_space=action_space,
+ obs2state_fn=obs2state_fn,
+ state2obs_fn=state2obs_fn,
+ logger=logger,
+ )
+
+ return dynamics
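+
+# Editor's note: `instantiate(cfg.algorithm.agent)(env=...)` in create_agent
+# assumes Hydra's partial instantiation, e.g. a (hypothetical) config such as
+#
+#   algorithm:
+#     agent:
+#       _target_: stable_baselines3.SAC
+#       _partial_: true
+#       policy: MlpPolicy
+#
+# so that `instantiate` returns a functools.partial completed by passing `env`.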
diff --git a/cmrl/utils/env.py b/cmrl/utils/env.py
new file mode 100644
index 0000000..f5d68a0
--- /dev/null
+++ b/cmrl/utils/env.py
@@ -0,0 +1,62 @@
+from typing import Dict, Optional, Tuple, cast
+
+import numpy as np
+import emei
+import gym
+import omegaconf
+from stable_baselines3.common.buffers import ReplayBuffer
+
+import cmrl.utils.variables
+from cmrl.types import TermFnType, RewardFnType, InitObsFnType, Obs2StateFnType
+
+
+def make_env(
+ cfg: omegaconf.DictConfig,
+) -> Tuple[emei.EmeiEnv, tuple]:
+ env = cast(emei.EmeiEnv, gym.make(cfg.task.env_id, **cfg.task.params))
+ fns = (
+ env.get_batch_reward,
+ env.get_batch_terminal,
+ env.get_batch_init_obs,
+ env.obs2state,
+ env.state2obs
+ )
+
+ # set seed
+ env.reset(seed=cfg.seed)
+ env.state_space.seed(cfg.seed + 1)
+ env.action_space.seed(cfg.seed + 2)
+ return env, fns
+
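+# Hypothetical usage sketch (editor's note): `fns` unpacks in the order built above.
+#
+#   env, (reward_fn, terminal_fn, init_obs_fn, obs2state_fn, state2obs_fn) = make_env(cfg)
+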
+
+def load_offline_data(env, replay_buffer: ReplayBuffer, dataset_name: str, use_ratio: float = 1):
+ assert hasattr(env, "get_dataset"), "env must have `get_dataset` method"
+
+ data_dict = env.get_dataset(dataset_name)
+ all_data_num = len(data_dict["observations"])
+ sample_data_num = int(use_ratio * all_data_num)
+ sample_idx = np.random.permutation(all_data_num)[:sample_data_num]
+
+ assert replay_buffer.n_envs == 1
+ assert replay_buffer.buffer_size >= sample_data_num
+
+ if sample_data_num == replay_buffer.buffer_size:
+ replay_buffer.full = True
+ replay_buffer.pos = 0
+ else:
+ replay_buffer.pos = sample_data_num
+
+ # set all data
+ for attr in ["observations", "next_observations", "actions", "rewards", "dones", "timeouts"]:
+ # if attr == "dones" and attr not in data_dict and "terminals" in data_dict:
+ # replay_buffer.dones[:sample_data_num, 0] = data_dict["terminals"][sample_idx]
+ # continue
+ getattr(replay_buffer, attr)[:sample_data_num, 0] = data_dict[attr][sample_idx]
+
+ for attr in ["extra_obs", "next_extra_obs"]:
+ setattr(
+ replay_buffer,
+ attr,
+ np.zeros((replay_buffer.buffer_size, replay_buffer.n_envs) + data_dict[attr].shape[1:], dtype=np.float32)
+ )
+ getattr(replay_buffer, attr)[:sample_data_num, 0] = data_dict[attr][sample_idx]
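+
+# Hypothetical usage sketch (editor's note): the buffer must be able to hold the
+# sampled transitions, and the emei dataset is assumed to expose every attribute
+# copied above, including `extra_obs` / `next_extra_obs`.
+#
+#   buffer = ReplayBuffer(int(1e6), env.observation_space, env.action_space, n_envs=1)
+#   load_offline_data(env, buffer, dataset_name="SAC-expert-replay", use_ratio=0.5)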
diff --git a/cmrl/utils/variables.py b/cmrl/utils/variables.py
new file mode 100644
index 0000000..534c316
--- /dev/null
+++ b/cmrl/utils/variables.py
@@ -0,0 +1,116 @@
+from dataclasses import dataclass
+from typing import Optional, Dict, Union, List
+
+from gym import spaces
+import numpy as np
+import torch
+
+
+@dataclass
+class Variable:
+ name: str
+
+
+@dataclass
+class ContinuousVariable(Variable):
+ dim: int
+ low: Optional[np.ndarray] = None
+ high: Optional[np.ndarray] = None
+
+
+@dataclass
+class RadianVariable(Variable):
+ dim: int
+
+
+@dataclass
+class BinaryVariable(Variable):
+ pass
+
+
+@dataclass
+class DiscreteVariable(Variable):
+ n: int
+
+
+def parse_space(
+ space: spaces.Space,
+ prefix="obs",
+ extra_info=None
+) -> List[Variable]:
+ extra_info = extra_info if extra_info is not None else {}
+
+ variables = []
+ if isinstance(space, spaces.Box):
+ for i, (low, high) in enumerate(zip(space.low, space.high)):
+ name = "{}_{}".format(prefix, i)
+ if "Radian" in extra_info and name in extra_info["Radian"]:
+ variables.append(RadianVariable(dim=1, name=name))
+ else:
+ variables.append(ContinuousVariable(dim=1, low=low, high=high, name=name))
+ elif isinstance(space, spaces.Discrete):
+ variables.append(DiscreteVariable(n=space.n, name="{}_0".format(prefix)))
+ elif isinstance(space, spaces.MultiDiscrete):
+ for i, n in enumerate(space.nvec):
+ variables.append(DiscreteVariable(n=n, name="{}_{}".format(prefix, i)))
+ elif isinstance(space, spaces.MultiBinary):
+ for i in range(space.n):
+ variables.append(BinaryVariable(name="{}_{}".format(prefix, i)))
+ elif isinstance(space, spaces.Dict):
+ # TODO
+ raise NotImplementedError
+
+ return variables
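+
+# Hypothetical usage sketch (editor's note): a two-dimensional Box yields one
+# one-dimensional ContinuousVariable per coordinate.
+#
+#   space = spaces.Box(low=np.array([-1.0, 0.0]), high=np.array([1.0, 2.0]))
+#   parse_space(space, prefix="obs")
+#   # -> [ContinuousVariable(name="obs_0", dim=1, low=-1.0, high=1.0),
+#   #     ContinuousVariable(name="obs_1", dim=1, low=0.0, high=2.0)]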
+
+
+def to_dict_by_space(
+ data: np.ndarray,
+ space: spaces.Space,
+ prefix="obs",
+ repeat: Optional[int] = None,
+ to_tensor: bool = False,
+ device: str = "cpu"
+) -> Dict[str, Union[np.ndarray, torch.Tensor]]:
+ """Transform the interaction data from its own type to python's dict, by the signature of space.
+
+ Args:
+ data: interaction data from replay buffer
+ space: space of gym
+ prefix: prefix of the keys in the dict
+ repeat: tile the data along a new leading dimension this many times
+ to_tensor: convert the data from a numpy ndarray to a torch tensor
+ device: torch device for the converted tensors
+
+
+ Returns: interaction data organized in dictionary form
+
+ """
+ if repeat:
+ assert repeat > 1, "repeat must be an int greater than 1"
+
+ dict_data = {}
+ if isinstance(space, spaces.Box):
+ # shape of data: (batch-size, node-num), every node has exactly one dim
+ for i, (low, high) in enumerate(zip(space.low, space.high)):
+ # shape of dict_data['xxx']: (batch-size, 1)
+ dict_data["{}_{}".format(prefix, i)] = data[:, i, None].astype(np.float32)
+ else:
+ # TODO
+ raise NotImplementedError
+
+ for name in dict_data:
+ if repeat:
+ # shape of dict_data['xxx']: (repeat-dim, batch-size, specific-dim)
+ # specific-dim is 1 for the case of spaces.Box
+ dict_data[name] = np.tile(dict_data[name][None, :, :], [repeat, 1, 1])
+ if to_tensor:
+ dict_data[name] = torch.from_numpy(dict_data[name]).to(device)
+
+ return dict_data
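+
+# Hypothetical usage sketch (editor's note): with repeat=7 and to_tensor=True,
+# each entry is tiled along a new leading dimension and converted to torch
+# (`space` is the two-dimensional Box from the sketch above).
+#
+#   batch = np.random.randn(32, 2).astype(np.float32)  # (batch-size, obs-dim)
+#   d = to_dict_by_space(batch, space, prefix="obs", repeat=7, to_tensor=True)
+#   # d["obs_0"].shape == torch.Size([7, 32, 1])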
+
+
+def dict2space(
+ data: Dict[str, Union[np.ndarray, torch.Tensor]], space: spaces.Space
+) -> Dict[str, Union[np.ndarray, torch.Tensor]]:
+ pass
diff --git a/cmrl/util/video.py b/cmrl/utils/video.py
similarity index 100%
rename from cmrl/util/video.py
rename to cmrl/utils/video.py
diff --git a/docs/about.md b/docs/about.md
new file mode 100644
index 0000000..8e22f96
--- /dev/null
+++ b/docs/about.md
@@ -0,0 +1,43 @@
+# Iphitiden Phoebo caede retiaque solvit genis abdiderat
+
+## Humo utinam
+
+Lorem markdownum illos non, somni et evocet Messeniaque diva *agitatis*
+nocentius. Templum Erymanthidas prius, duris mihi, iuvenum, nec quod acceptus
+una, secuit.
+
+1. Foret sanguine puniceo
+2. Erubuit mittit ipso lenta adspexit arbiter nondum
+3. Insanis sum est oves domus nam pars
+4. Distinxit verba
+
+## Iacet venit
+
+Languore manus est ad prima et caelum sit aristas, ante Styphelumque moris ad
+pulsant: vertitur novat. Latrare **minimam coniunx imbribus**: acceptior, ipso
+verum demit laudibus non peperi operiri, iussae arva. Ferit fibris gradieris
+Dianae, et **dabant dependent** adfixa versa flectit: signumque.
+
+Indigenae talia, ora rari est in inter coetus *protinus summaque mittitur*
+fuerant gravisque agitur et sedibus attulit. Multa capessamus Bacchi: de ut
+saeva funera, certamine Chimaera auxiliumque teloque!
+
+## Imponit requirit armigerae quoque sitimque
+
+Inplevit nimium. Sub loqui innectens in cincta ripis plangens est. Annua et
+stabat Panopeusque naidas audentia quantum videoque quam ipsum.
+
+## Ingreditur totidem illi
+
+Est **corpore referam** est rates leti vertitur ab dictu rex quoque sceptra
+flamma. Ut quod? Tota nil, horruit hoc. Ubi colla sopore vides dixit qua non
+sanguineaque dixerat Iove. Cernis est viribus, **indue** genae verbis solio et
+tantum bibulas *et surgere Caesaris* damno Iolaus umquam; scelerate animam?
+
+1. Testudine tener herosmaxime loco tonitruque
+2. Graium pectora pavet sit cum ianua
+3. Quod inque
+4. De patens ictus spectantia ereptaque constitit falsa
+5. Minis modo minor gravet numerusque duorum mediam
+
+Celsior quid fores, tremore, est quo culpatque terret pati. Anno pietas poples!
diff --git a/img/cmrl_logo.png b/docs/cmrl_logo.png
similarity index 100%
rename from img/cmrl_logo.png
rename to docs/cmrl_logo.png
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..d7703e7
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,19 @@
+# Welcome to MkDocs
+
+For full documentation visit [mkdocs.org](https://www.mkdocs.org).
+
+## Commands
+
+* `mkdocs new [dir-name]` - Create a new project.
+* `mkdocs serve` - Start the live-reloading docs server.
+* `mkdocs build` - Build the documentation site.
+* `mkdocs -h` - Print help message and exit.
+
+## Project layout
+
+ mkdocs.yml # The configuration file.
+ docs/
+ index.md # The documentation homepage.
+ ... # Other markdown pages, images and other files.
+
+::: cmrl.models.layers.ParallelLinear
diff --git a/exp_reader.ipynb b/exp_reader.ipynb
new file mode 100644
index 0000000..cb2b37d
--- /dev/null
+++ b/exp_reader.ipynb
@@ -0,0 +1,263 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import cmrl\n",
+ "from emei.core import get_params_str\n",
+ "from pathlib import Path\n",
+ "import matplotlib.pyplot as plt\n",
+ "import yaml\n",
+ "from collections import defaultdict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# 递归判断a字典中存在的值是否与b字典相等\n",
+ "def dict_equal(a, b):\n",
+ " for k, v in a.items():\n",
+ " if isinstance(v, dict):\n",
+ " if not dict_equal(v, b[k]):\n",
+ " return False\n",
+ " elif v != b[k]:\n",
+ " return False\n",
+ " return True\n",
+ "\n",
+ "\n",
+ "def get_value(d, key):\n",
+ " if isinstance(key, str):\n",
+ " if key not in d:\n",
+ " raise ValueError(f\"{key} not in dict\")\n",
+ " return d[key]\n",
+ " elif isinstance(key, tuple):\n",
+ " if key[0] not in d:\n",
+ " raise ValueError(f\"{key[0]} not in dict\")\n",
+ " if len(key) <= 1:\n",
+ " raise ValueError(\"length of tuple-key must be 2\")\n",
+ " return get_value(d[key[0]], key[1])\n",
+ " else:\n",
+ " raise ValueError(\"key must be str or tuple\")\n",
+ "\n",
+ "\n",
+ "# 返回多个字典中不同的value对应的key, 通过递归的方法\n",
+ "def get_diff_key(dicts):\n",
+ " if len(dicts) <= 1:\n",
+ " return []\n",
+ " keys = set(dicts[0].keys())\n",
+ " for d in dicts[1:]:\n",
+ " keys = keys & set(d.keys())\n",
+ "\n",
+ " diff_keys = []\n",
+ " for k in keys:\n",
+ " if isinstance(dicts[0][k], dict):\n",
+ " diff_keys += [(k, dk) for dk in get_diff_key([d[k] for d in dicts])]\n",
+ " elif not all([dicts[0][k] == d[k] for d in dicts[1:]]):\n",
+ " diff_keys.append(k)\n",
+ " return diff_keys\n",
+ "\n",
+ "\n",
+ "def argmax(l):\n",
+ " return max(l), l.index(max(l))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "default_params = dict(freq_rate=1,\n",
+ " real_time_scale=0.02,\n",
+ " integrator=\"euler\",\n",
+ " parallel_num=3)\n",
+ "default_custom_cfg = {}\n",
+ "default_result_key = [\"seed\"]\n",
+ "\n",
+ "\n",
+ "def load_log(exp_name=\"default\",\n",
+ " task_name=\"ParallelContinuousCartPoleSwingUp-v0\",\n",
+ " params=default_params,\n",
+ " dataset=\"SAC-expert-replay\",\n",
+ " custom_cfg=default_custom_cfg,\n",
+ " log_file=\"rollout.csv\",\n",
+ " log_key=\"ep_rew_mean\"):\n",
+ " path = Path(\"./exp\") / exp_name / task_name / get_params_str(params) / dataset\n",
+ "\n",
+ " result_list = []\n",
+ " cfg_list = []\n",
+ " for time_dir in path.glob(r\"*\"):\n",
+ " if not time_dir.is_dir() or not (time_dir / \".hydra\").exists():\n",
+ " continue\n",
+ "\n",
+ " config_path = time_dir / \".hydra\" / \"config.yaml\"\n",
+ " with open(config_path, \"r\") as f:\n",
+ " cfg = yaml.load(f, Loader=yaml.FullLoader)\n",
+ "\n",
+ " if not dict_equal(custom_cfg, cfg):\n",
+ " print(\"{} is passed cause its inconsistent cfg\".format(time_dir))\n",
+ " continue\n",
+ "\n",
+ " log_path = time_dir / \"log\" / log_file\n",
+ " if not log_path.exists():\n",
+ " continue\n",
+ "\n",
+ " df = pd.read_csv(log_path)\n",
+ " result_list.append(df[log_key].to_numpy())\n",
+ " cfg_list.append(cfg)\n",
+ "\n",
+ " diff_key = get_diff_key(cfg_list)\n",
+ " result_dict = {}\n",
+ " for i, cfg in enumerate(cfg_list):\n",
+ " result_dict[tuple([get_value(cfg, key) for key in diff_key])] = result_list[i]\n",
+ " return diff_key, result_dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "outputs": [],
+ "source": [
+ "def draw_result(exp_name=\"default\",\n",
+ " task_name=\"ParallelContinuousCartPoleSwingUp-v0\",\n",
+ " params=default_params,\n",
+ " dataset=\"SAC-expert-replay\",\n",
+ " custom_cfg=default_custom_cfg,\n",
+ " log_file=\"rollout.csv\",\n",
+ " log_key=\"ep_rew_mean\",\n",
+ " group_key=('transition', 'oracle')):\n",
+ " diff_key, result_dict = load_log(exp_name=exp_name,\n",
+ " task_name=task_name,\n",
+ " params=params,\n",
+ " dataset=dataset,\n",
+ " custom_cfg=custom_cfg,\n",
+ " log_file=log_file,\n",
+ " log_key=log_key)\n",
+ "\n",
+ " idx = diff_key.index(group_key)\n",
+ "\n",
+ " for name in set([key[idx] for key in result_dict.keys()]):\n",
+ " values = [value for key, value in result_dict.items() if key[idx] == name]\n",
+ " longest, longest_idx = argmax([value.shape[0] for value in values])\n",
+ " values_array = np.empty((len(values), longest))\n",
+ " for i, value in enumerate(values):\n",
+ " values_array[i, :len(value)] = value\n",
+ " values_array[i, len(value):] = values[longest_idx][len(value):]\n",
+ " plt.plot(values_array.mean(axis=0), label=name)\n",
+ " plt.fill_between(np.arange(len(values_array.mean(axis=0))),\n",
+ " values_array.mean(axis=0) - values_array.std(axis=0),\n",
+ " values_array.mean(axis=0) + values_array.std(axis=0), alpha=0.5)\n",
+ " plt.legend()"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "