polish(gry): polish reward model and td error (#624)
* polish gcl

* polish gail irl

* polish icm api doc

* add api comment for mdqn td error

* add config table for pdeil reward model

* add config table for pwil reward model

* add config table for red reward model

* add config table for rnd reward model

* add config table for trex reward model

* add config table for drex reward model

* add config table for drex reward model

* add comment for td error

fix style for reward model

* fix typo for reward model and td

* fix typo for clear buffer
ruoyuGao authored Apr 3, 2023
1 parent cd10e58 commit a580019
Showing 11 changed files with 455 additions and 136 deletions.
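
The TD-error half of the commit (the ``add api comment for mdqn td error`` and ``add comment for td error`` bullets) sits in files that are not expanded below. Purely for orientation, a Munchausen-DQN 1-step target can be sketched as follows; this follows the original paper, not DI-engine's TD utilities, and the function name, defaults, and tensor layout are illustrative assumptions.

import torch
import torch.nn.functional as F

def mdqn_one_step_target(q_t_target, q_tp1_target, act, rew, done,
                         gamma=0.99, tau=0.03, alpha=0.9, l0=-1.0):
    # q_t_target / q_tp1_target: target-network Q values at s_t / s_{t+1}, shape [B, A]
    # act: [B] long; rew, done: [B] float
    log_pi_t = F.log_softmax(q_t_target / tau, dim=1)
    # Munchausen bonus: alpha * clip(tau * log pi(a_t | s_t), l0, 0)
    bonus = alpha * torch.clamp(tau * log_pi_t.gather(1, act.unsqueeze(1)).squeeze(1), min=l0, max=0.0)
    pi_tp1 = F.softmax(q_tp1_target / tau, dim=1)
    log_pi_tp1 = F.log_softmax(q_tp1_target / tau, dim=1)
    # Soft (entropy-regularized) value of s_{t+1}
    soft_v_tp1 = (pi_tp1 * (q_tp1_target - tau * log_pi_tp1)).sum(dim=1)
    return rew + bonus + gamma * (1.0 - done) * soft_v_tp1
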
41 changes: 34 additions & 7 deletions ding/reward_model/drex_reward_model.py
@@ -1,27 +1,54 @@
import copy
from easydict import EasyDict
import numpy as np
import pickle

import torch
import torch.nn as nn

from ding.utils import REWARD_MODEL_REGISTRY

from .trex_reward_model import TrexRewardModel


@REWARD_MODEL_REGISTRY.register('drex')
class DrexRewardModel(TrexRewardModel):
"""
Overview:
The Drex reward model class (https://arxiv.org/pdf/1907.03976.pdf)
Interface:
``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
``__init__``, ``_train``
Config:
== ==================== ====== ============= ======================================= ===============
ID Symbol Type Default Value Description Other(Shape)
== ==================== ====== ============= ======================================= ===============
1 ``type`` str drex | Reward model register name, refer |
| to registry ``REWARD_MODEL_REGISTRY`` |
2 | ``learning_rate`` float 0.00001 | Learning rate for optimizer |
3 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
4 | ``batch_size`` int 64 | How many samples in a training batch |
5 | ``hidden_size`` int 128 | Linear model hidden size |
6 | ``num_trajs`` int 0 | Number of downsampled full |
| trajectories |
7 | ``num_snippets`` int 6000 | Number of short subtrajectories |
| to sample |
== ==================== ====== ============= ======================================= ===============
"""
config = dict(
# (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
type='drex',
# (float) The step size of gradient descent.
learning_rate=1e-5,
# (int) How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=100,
# (int) How many samples in a training batch.
batch_size=64,
target_new_data_count=64,
# (int) Linear model hidden size.
hidden_size=128,
num_trajs=0, # number of downsampled full trajectories
num_snippets=6000, # number of short subtrajectories to sample
# (int) Number of downsampled full trajectories.
num_trajs=0,
# (int) Number of short subtrajectories to sample.
num_snippets=6000,
)

bc_cfg = None
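
For context on how the ``num_trajs`` / ``num_snippets`` entries above are typically consumed in TREX/DREX-style training: ranked demonstrations (ranked, in DREX, by the amount of noise injected into a behavioral-cloning policy) are cut into short snippet pairs whose preference labels come from the ranking. The sketch below is illustrative only, not DI-engine's ``collect_data``/``_train`` code; the trajectory format and snippet-length bounds are assumptions.

import random
import numpy as np

def sample_snippet_pairs(trajs, num_snippets=6000, min_len=30, max_len=60):
    # trajs: list of (observations, rank) tuples; a higher rank means the trajectory
    # came from a less-noisy (better) policy. Both the format and the length bounds
    # are assumptions made for this sketch.
    pairs, labels = [], []
    for _ in range(num_snippets):
        (obs_i, rank_i), (obs_j, rank_j) = random.sample(trajs, 2)
        length = random.randint(min_len, max_len)
        si = random.randint(0, max(0, len(obs_i) - length))
        sj = random.randint(0, max(0, len(obs_j) - length))
        pairs.append((np.asarray(obs_i[si:si + length]), np.asarray(obs_j[sj:sj + length])))
        labels.append(int(rank_j > rank_i))  # 1 if the second snippet should be preferred
    return pairs, labels
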
56 changes: 35 additions & 21 deletions ding/reward_model/gail_irl_model.py
@@ -3,7 +3,6 @@
import random
from collections.abc import Iterable
from easydict import EasyDict
import numpy as np

import torch
import torch.nn as nn
@@ -113,35 +112,50 @@ class GailRewardModel(BaseRewardModel):
``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
``__init__``, ``state_dict``, ``load_state_dict``, ``learn``
Config:
== ==================== ======== ============= ================================= =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============= ================================= =======================
1 ``type`` str gail | RL policy register name, refer | this arg is optional,
| to registry ``POLICY_REGISTRY`` | a placeholder
2 | ``expert_data_`` str expert_data. | Path to the expert dataset | Should be a '.pkl'
| ``path`` .pkl | | file
3 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
4 | ``batch_size`` int 64 | Training batch size |
5 | ``input_size`` int | Size of the input: |
| | obs_dim + act_dim |
6 | ``target_new_`` int 64 | Collect steps per iteration |
| ``data_count`` | |
7 | ``hidden_size`` int 128 | Linear model hidden size |
8 | ``collect_count`` int 100000 | Expert dataset size | One entry is a (s,a)
| | | tuple
== ==================== ======== ============= ================================= =======================
"""
== ==================== ======== ============= =================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============= =================================== =======================
1 ``type`` str gail | Reward model register name, refer | this arg is optional,
| to registry ``REWARD_MODEL_REGISTRY`` | a placeholder
2 | ``expert_data_`` str expert_data. | Path to the expert dataset | Should be a '.pkl'
| ``path`` .pkl | | file
3 | ``learning_rate`` float 0.001 | The step size of gradient descent |
4 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
5 | ``batch_size`` int 64 | Training batch size |
6 | ``input_size`` int | Size of the input: |
| | obs_dim + act_dim |
7 | ``target_new_`` int 64 | Collect steps per iteration |
| ``data_count`` | |
8 | ``hidden_size`` int 128 | Linear model hidden size |
9 | ``collect_count`` int 100000 | Expert dataset size | One entry is a (s,a)
| | | tuple
10 | ``clear_buffer_`` int 1 | Clear buffer per fixed iters | make sure replay
| ``per_iters`` | buffer's data count
| | isn't too small
| | (handled in entry)
== ==================== ======== ============= =================================== =======================
"""
config = dict(
# (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
type='gail',
# (float) The step size of gradient descent.
learning_rate=1e-3,
# (int) How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=100,
# (int) How many samples in a training batch.
batch_size=64,
# (int) Size of the input: obs_dim + act_dim.
input_size=4,
# (int) Collect steps per iteration.
target_new_data_count=64,
# (int) Linear model hidden size.
hidden_size=128,
# (int) Expert dataset size.
collect_count=100000,
# (int) Clear buffer per fixed iters.
clear_buffer_per_iters=1,
)

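To make the ``input_size`` and ``hidden_size`` entries concrete: a GAIL discriminator scores concatenated (obs, action) pairs, so ``input_size`` has to equal obs_dim + act_dim for the target environment. A minimal sketch of one common discriminator/reward variant follows; the layer layout and the -log(1 - D) reward are assumptions, not the exact network or estimate rule used by ``GailRewardModel``.

import torch
import torch.nn as nn

class GailDiscriminator(nn.Module):
    def __init__(self, input_size: int = 4, hidden_size: int = 128):
        super().__init__()
        # input_size must equal obs_dim + act_dim of the target environment.
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, 1), nn.Sigmoid(),
        )

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        # D(s, a) is pushed towards 1 on expert pairs and 0 on policy pairs.
        return self.net(torch.cat([obs, act], dim=-1))

def gail_reward(disc: GailDiscriminator, obs: torch.Tensor, act: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # One common GAIL reward variant: r(s, a) = -log(1 - D(s, a)).
    with torch.no_grad():
        d = disc(obs, act)
    return -torch.log(1.0 - d + eps)
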
48 changes: 38 additions & 10 deletions ding/reward_model/guided_cost_reward_model.py
@@ -1,17 +1,14 @@
from typing import List, Dict, Any, Tuple, Union, Optional
from typing import List, Dict, Any
from easydict import EasyDict

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Independent, Normal
import copy

from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from ding.utils.data import default_collate, default_decollate
from ding.model import FCEncoder, ConvEncoder
from ding.utils import REWARD_MODEL_REGISTRY
from ding.utils.data import default_collate
from .base_reward_model import BaseRewardModel


@@ -38,23 +35,54 @@ def forward(self, x):

@REWARD_MODEL_REGISTRY.register('guided_cost')
class GuidedCostRewardModel(BaseRewardModel):
r"""
"""
Overview:
Policy class of Guided cost algorithm.
https://arxiv.org/pdf/1603.00448.pdf
Policy class of Guided cost algorithm. (https://arxiv.org/pdf/1603.00448.pdf)
Interface:
``estimate``, ``train``, ``collect_data``, ``clear_data``, \
``__init__``, ``state_dict``, ``load_state_dict``, ``learn``, \
``state_dict_reward_model``, ``load_state_dict_reward_model``
Config:
== ==================== ======== ============= ======================================== ================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============= ======================================== ================
1 ``type`` str guided_cost | Reward model register name, refer |
| to registry ``REWARD_MODEL_REGISTRY`` |
2 | ``continuous`` bool True | Whether action is continuous |
3 | ``learning_rate`` float 0.001 | Learning rate for optimizer |
4 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
5 | ``batch_size`` int 64 | Training batch size |
6 | ``hidden_size`` int 128 | Linear model hidden size |
7 | ``action_shape`` int 1 | Action space shape |
8 | ``log_every_n`` int 50 | Add loss to log every n iterations |
| ``_train`` | |
9 | ``store_model_`` int 100 | Save model every n iterations |
| ``every_n_train`` | |
== ==================== ======== ============= ======================================== ================
"""

config = dict(
# (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
type='guided_cost',
# (float) The step size of gradient descent.
learning_rate=1e-3,
# (int) Action space shape, such as 1.
action_shape=1,
# (bool) Whether action is continuous.
continuous=True,
# (int) How many samples in a training batch.
batch_size=64,
# (int) Linear model hidden size.
hidden_size=128,
# (int) How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=100,
# (int) Add loss to log every n iterations.
log_every_n_train=50,
# (int) Save model every n iterations.
store_model_every_n_train=100,
)

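The config above drives an IOC-style objective on the learned cost network. A heavily simplified sketch of that objective, written from the guided cost learning paper rather than from this class's training code (the paper's importance weights are omitted):

import math
import torch

def ioc_loss(cost_demo: torch.Tensor, cost_samp: torch.Tensor) -> torch.Tensor:
    # cost_demo: c_theta over expert demonstrations, shape [N]
    # cost_samp: c_theta over policy/background samples, shape [M]
    demo_term = cost_demo.mean()
    # log(1/M * sum_j exp(-c_j)), computed stably with logsumexp
    partition_term = torch.logsumexp(-cost_samp, dim=0) - math.log(cost_samp.shape[0])
    return demo_term + partition_term
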
63 changes: 51 additions & 12 deletions ding/reward_model/icm_reward_model.py
@@ -5,7 +5,6 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from ding.model import FCEncoder, ConvEncoder
@@ -28,7 +27,7 @@ def collect_states(iterator: list) -> Tuple[list, list, list]:


class ICMNetwork(nn.Module):
r"""
"""
Intrinsic Curiosity Model (ICM Module)
Implementation of:
[1] Curiosity-driven Exploration by Self-supervised Prediction
@@ -130,30 +129,70 @@ class ICMRewardModel(BaseRewardModel):
The ICM reward model class (https://arxiv.org/pdf/1705.05363.pdf)
Interface:
``estimate``, ``train``, ``collect_data``, ``clear_data``, \
``__init__``, ``_train``,
``__init__``, ``_train``, ``load_state_dict``, ``state_dict``
Config:
== ==================== ======== ============= ==================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============= ==================================== =======================
1 ``type`` str icm | Reward model register name, |
| refer to registry |
| ``REWARD_MODEL_REGISTRY`` |
2 | ``intrinsic_`` str add | The intrinsic reward type | including add, new,
| ``reward_type`` | | or assign
3 | ``learning_rate`` float 0.001 | The step size of gradient descent |
4 | ``obs_shape`` Tuple( 6 | The observation shape |
[int,
list])
5 | ``action_shape`` int 7 | The action space shape |
6 | ``batch_size`` int 64 | Training batch size |
7 | ``hidden`` list [64, 64, | The MLP layer shape |
| ``_size_list`` (int) 128] | |
8 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
9 | ``reverse_scale`` float 1 | The importance weight of the |
| forward and reverse loss |
10 | ``intrinsic_`` float 0.003 | The weight of intrinsic reward | r = w*r_i + r_e
``reward_weight``
11 | ``extrinsic_`` bool True | Whether to normalize
``reward_norm`` | extrinsic reward
12 | ``extrinsic_`` int 1 | The upper bound of the reward
``reward_norm_max`` | normalization
13 | ``clear_buffer`` int 1 | Clear buffer per fixed iters | make sure replay
``_per_iters`` | buffer's data count
| isn't too small
| (handled in entry)
== ==================== ======== ============= ==================================== =======================
"""
config = dict(
# (str) the type of the exploration method
# (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
type='icm',
# (str) the intrinsic reward type, including add, new, or assign
# (str) The intrinsic reward type, including add, new, or assign.
intrinsic_reward_type='add',
# (float) learning rate of the optimizer
# (float) The step size of gradient descent.
learning_rate=1e-3,
# (Tuple[int, list]), the observation shape,
# (Tuple[int, list]) The observation shape.
obs_shape=6,
# (int) the action shape, support discrete action only in this version
# (int) The action shape, support discrete action only in this version.
action_shape=7,
# (float) batch size
# (int) Batch size.
batch_size=64,
# (list) the MLP layer shape
# (list) The MLP layer shape.
hidden_size_list=[64, 64, 128],
# (int) update how many times after each collect
# (int) How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=100,
# (float) the importance weight of the forward and reverse loss
# (float) The importance weight of the forward and reverse loss.
reverse_scale=1,
# (float) The weight of intrinsic reward.
# r = intrinsic_reward_weight * r_i + r_e.
intrinsic_reward_weight=0.003, # 1/300
# (bool) Whether to normalize extrinsic reward.
# Normalize the reward to [0, extrinsic_reward_norm_max].
extrinsic_reward_norm=True,
# (int) The upper bound of the reward normalization.
extrinsic_reward_norm_max=1,
# (int) Clear buffer per fixed iters.
clear_buffer_per_iters=100,
)

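To show how the shaping fields above interact for the default ``add`` type: the extrinsic reward is optionally squashed into [0, ``extrinsic_reward_norm_max``] and then combined as r = intrinsic_reward_weight * r_i + r_e. In the sketch below the min-max normalization rule is an assumption and only the ``add`` branch is implemented, so it is not a copy of the model's ``estimate``.

import torch
from easydict import EasyDict

def combine_rewards(r_intrinsic: torch.Tensor, r_extrinsic: torch.Tensor, cfg: EasyDict) -> torch.Tensor:
    if cfg.extrinsic_reward_norm:
        # Assumed min-max squashing of the extrinsic reward into [0, extrinsic_reward_norm_max].
        lo, hi = r_extrinsic.min(), r_extrinsic.max()
        r_extrinsic = (r_extrinsic - lo) / (hi - lo + 1e-8) * cfg.extrinsic_reward_norm_max
    # Documented combination for the 'add' type: r = w * r_i + r_e.
    assert cfg.intrinsic_reward_type == 'add', "only the 'add' branch is sketched here"
    return cfg.intrinsic_reward_weight * r_intrinsic + r_extrinsic

cfg = EasyDict(intrinsic_reward_type='add', intrinsic_reward_weight=0.003,
               extrinsic_reward_norm=True, extrinsic_reward_norm_max=1)
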
2 changes: 1 addition & 1 deletion ding/reward_model/ngu_reward_model.py
@@ -1,6 +1,6 @@
import copy
import random
from typing import Union, Tuple, Any, Dict, List
from typing import Union, Tuple, Dict, List

import numpy as np
import torch