Skip to content

Commit

Permalink
Merge pull request #67 from BDonnot/zrg-1.0.0
Browse files Browse the repository at this point in the history
MultiMixEnvironment Iterable
  • Loading branch information
BDonnot authored Jun 18, 2020
2 parents fb69ff5 + 576af52 commit 0394a64
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
60 changes: 46 additions & 14 deletions grid2op/Environment/MultiMixEnv.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self,

self.current_env = None
self.env_index = None
self._envs = []
self.mix_envs = []

# Inline import to prevent cyclical import
from grid2op.MakeEnv.Make import make
Expand All @@ -41,17 +41,17 @@ def __init__(self,
if not os.path.isdir(env_path):
continue
env = make(env_path, **kwargs)
self._envs.append(env)
self.mix_envs.append(env)
except Exception as e:
err_msg = "MultiMix environment creation failed: {}".format(e)
raise EnvError(err_msg)

if len(self._envs) == 0:
if len(self.mix_envs) == 0:
err_msg = "MultiMix envs_dir did not contain any valid env"
raise EnvError(err_msg)

self.env_index = 0
self.current_env = self._envs[self.env_index]
self.current_env = self.mix_envs[self.env_index]
# Make sure GridObject class attributes are set from first env
# Should be fine since the grid is the same for all envs
self.__class__ = self.init_grid(self.current_env)
Expand All @@ -60,16 +60,48 @@ def __init__(self,
def current_index(self):
return self.env_index

def __len__(self):
return len(self.mix_envs)

def __iter__(self):
"""
Operator __iter__ overload to make a ``MultiMixEnvironment`` iterable
.. code-block:: python
import grid2op
from grid2op.Environment import MultiMixEnvironment
from grid2op.Runner import Runner
mm_env = MultiMixEnvironment("/path/to/multi/dataset/folder")
for env in mm_env:
run_p = env.get_params_for_runner()
runner = Runner(**run_p)
runner.run(nb_episode=1, max_iter=-1)
"""
self.env_index = 0
return self

def __next__(self):
if self.env_index < len(self.mix_envs):
r = self.mix_envs[self.env_index]
self.env_index = self.env_index + 1
return r
else:
self.env_index = 0
raise StopIteration

def __getattr__(self, name):
return getattr(self.current_env, name)

def reset(self, random=False):
if random:
self.env_index = self.space_prng.randint(len(self._envs))
self.env_index = self.space_prng.randint(len(self.mix_envs))
else:
self.env_index = (self.env_index + 1) % len(self._envs)
self.current_env = self._envs[self.env_index]
self.env_index = (self.env_index + 1) % len(self.mix_envs)

self.current_env = self.mix_envs[self.env_index]
self.current_env.reset()
return self.get_obs()

Expand Down Expand Up @@ -100,26 +132,26 @@ def seed(self, seed=None):
s = super().seed(seed)
seeds = [s]
max_dt_int = np.iinfo(dt_int).max
for env in self._envs:
for env in self.mix_envs:
env_seed = self.space_prng.randint(max_dt_int)
env_seeds = env.seed(env_seed)
seeds.append(env_seeds)
return seeds

def deactivate_forecast(self):
for e in self._envs:
for e in self.mix_envs:
e.deactivate_forecast()

def reactivate_forecast(self):
for e in self._envs:
for e in self.mix_envs:
e.reactivate_forecast()

def set_thermal_limit(self, thermal_limit):
"""
Set the thermal limit effectively.
Will propagate to all underlying environments
"""
for e in self._envs:
for e in self.mix_envs:
e.set_thermal_limit(thermal_limit)

def __enter__(self):
Expand All @@ -139,7 +171,7 @@ def __exit__(self, *args):
return False

def close(self):
for e in self._envs:
for e in self.mix_envs:
e.close()

def attach_layout(self, grid_layout):
Expand All @@ -155,5 +187,5 @@ def attach_layout(self, grid_layout):
-------
"""
for e in self._envs:
for e in self.mix_envs:
e.attach_layout(grid_layout)
6 changes: 3 additions & 3 deletions grid2op/tests/test_MultiMix.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def dummy(self):
backend=DummyBackend())
assert mme.current_obs is not None
assert mme.current_env is not None
for env in mme._envs:
for env in mme:
assert env.backend.dummy() == True

def test_creation_with_opponent(self):
Expand All @@ -77,7 +77,7 @@ def test_creation_with_opponent(self):
opponent_budget_per_ts=0.42)
assert mme.current_obs is not None
assert mme.current_env is not None
for env in mme._envs:
for env in mme:
assert env.opponent_class == BaseOpponent
assert env.opponent_init_budget == dt_float(42.0)
assert env.opponent_budget_per_ts == dt_float(0.42)
Expand Down Expand Up @@ -130,7 +130,7 @@ def dummy(self):
mme = MultiMixEnvironment(PATH_DATA_MULTIMIX,
backend=DummyBackend())
mme.reset()
for env in mme._envs:
for env in mme:
assert env.backend.dummy() == 1

def test_reset_with_opponent(self):
Expand Down

0 comments on commit 0394a64

Please sign in to comment.