Skip to content

Commit

Permalink
Add class mapping of categorical model (#216)
Browse files Browse the repository at this point in the history
* Add class mapping for Categorical model

* Apply format

---------

Co-authored-by: Toni-SM <[email protected]>
  • Loading branch information
Telios and Toni-SM authored Nov 3, 2024
1 parent 9252ec9 commit eff7295
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion skrl/utils/runner/jax/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa
from skrl.trainers.jax import SequentialTrainer, Trainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model


class Runner:
Expand All @@ -35,6 +35,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str,
self._class_mapping = {
# model
"gaussianmixin": gaussian_model,
"categoricalmixin": categorical_model,
"deterministicmixin": deterministic_model,
"shared": None,
# memory
Expand Down
3 changes: 2 additions & 1 deletion skrl/utils/runner/torch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa
from skrl.trainers.torch import SequentialTrainer, Trainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
from skrl.utils.model_instantiators.torch import categorical_model, deterministic_model, gaussian_model, shared_model


class Runner:
Expand All @@ -35,6 +35,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str,
self._class_mapping = {
# model
"gaussianmixin": gaussian_model,
"categoricalmixin": categorical_model,
"deterministicmixin": deterministic_model,
"shared": shared_model,
# memory
Expand Down

0 comments on commit eff7295

Please sign in to comment.