feature(lxy): add dropout layers to dqn #712

Merged: 5 commits, Aug 28, 2023
12 changes: 9 additions & 3 deletions ding/model/common/encoder.py
@@ -141,7 +141,8 @@ def __init__(
hidden_size_list: SequenceType,
res_block: bool = False,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
norm_type: Optional[str] = None,
dropout: Optional[float] = None
) -> None:
"""
Overview:
@@ -153,6 +154,7 @@ def __init__(
- activation (:obj:`nn.Module`): Type of activation to use in ``ResFCBlock``. Default is ``nn.ReLU()``.
- norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResFCBlock`` \
for more details. Default is ``None``.
- dropout (:obj:`float`): The dropout rate of the dropout layer. If ``None``, no dropout layer is added.
"""
super(FCEncoder, self).__init__()
self.obs_shape = obs_shape
@@ -162,17 +164,21 @@ def __init__(
if res_block:
assert len(set(hidden_size_list)) == 1, "Please indicate the same hidden size for res block parts"
if len(hidden_size_list) == 1:
self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type)
self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
else:
layers = []
for i in range(len(hidden_size_list)):
layers.append(ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type))
layers.append(
ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
)
self.main = nn.Sequential(*layers)
else:
layers = []
for i in range(len(hidden_size_list) - 1):
layers.append(nn.Linear(hidden_size_list[i], hidden_size_list[i + 1]))
layers.append(self.act)
if dropout is not None:
layers.append(nn.Dropout(dropout))
self.main = nn.Sequential(*layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
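For reference, a hedged usage sketch of the extended encoder signature; the observation size, hidden sizes, and dropout rate are illustrative, and the import path is assumed from the file location above:

```python
import torch

from ding.model.common.encoder import FCEncoder  # path assumed from the diff above

# 1-D observations take the plain FC branch: Linear -> activation -> optional Dropout per layer.
encoder = FCEncoder(
    obs_shape=8,
    hidden_size_list=[128, 128, 64],
    dropout=0.5,   # new argument; leaving it as None keeps the previous no-dropout behaviour
)
feat = encoder(torch.randn(4, 8))
print(feat.shape)  # expected: torch.Size([4, 64])
```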
10 changes: 10 additions & 0 deletions ding/model/common/head.py
@@ -28,6 +28,7 @@ def __init__(
layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None,
noise: Optional[bool] = False,
) -> None:
"""
@@ -41,6 +42,7 @@
If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
- norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
for more details. Default ``None``.
- dropout (:obj:`float`): The dropout rate, default set to None.
- noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
Default ``False``.
"""
@@ -55,6 +57,8 @@ def __init__(
layer_num,
layer_fn=layer,
activation=activation,
use_dropout=dropout is not None,
dropout_probability=dropout,
norm_type=norm_type
), block(hidden_size, output_size)
)
@@ -800,6 +804,7 @@ def __init__(
v_layer_num: Optional[int] = None,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None,
noise: Optional[bool] = False,
) -> None:
"""
@@ -814,6 +819,7 @@ def __init__(
If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
- norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
for more details. Default ``None``.
- dropout (:obj:`float`): The dropout rate of the dropout layer. Default ``None``.
- noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
Default ``False``.
"""
@@ -832,6 +838,8 @@ def __init__(
a_layer_num,
layer_fn=layer,
activation=activation,
use_dropout=dropout is not None,
dropout_probability=dropout,
norm_type=norm_type
), block(hidden_size, output_size)
)
@@ -843,6 +851,8 @@ def __init__(
v_layer_num,
layer_fn=layer,
activation=activation,
use_dropout=dropout is not None,
dropout_probability=dropout,
norm_type=norm_type
), block(hidden_size, 1)
)
22 changes: 17 additions & 5 deletions ding/model/template/q_learning.py
@@ -21,7 +21,8 @@ def __init__(
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
norm_type: Optional[str] = None,
dropout: Optional[float] = None
) -> None:
"""
Overview:
@@ -35,9 +36,11 @@ def __init__(
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
- head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None`` then default set it to ``nn.ReLU()``
if ``None`` then default set it to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
- dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
If ``None``, no dropout layer is added.
"""
super(DQN, self).__init__()
# Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
@@ -46,9 +49,12 @@ def __init__(
head_hidden_size = encoder_hidden_size_list[-1]
# FC Encoder
if isinstance(obs_shape, int) or len(obs_shape) == 1:
self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
self.encoder = FCEncoder(
obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, dropout=dropout
)
# Conv Encoder
elif len(obs_shape) == 3:
assert dropout is None, "dropout is not supported in ConvEncoder"
self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
else:
raise RuntimeError(
@@ -67,11 +73,17 @@ def __init__(
action_shape,
layer_num=head_layer_num,
activation=activation,
norm_type=norm_type
norm_type=norm_type,
dropout=dropout
)
else:
self.head = head_cls(
head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
head_hidden_size,
action_shape,
head_layer_num,
activation=activation,
norm_type=norm_type,
dropout=dropout
)

def forward(self, x: torch.Tensor) -> Dict:
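An end-to-end hedged sketch of the template model itself, with the new argument flowing into both the FC encoder and the head (shapes and rate are illustrative, import path assumed from the file above):

```python
import torch

from ding.model.template.q_learning import DQN  # path assumed from the diff above

model = DQN(
    obs_shape=4,
    action_shape=2,
    encoder_hidden_size_list=[128, 128, 64],
    dropout=0.5,   # must stay None for 3-D (image) observations, per the assert above
)
out = model(torch.randn(8, 4))
print(out['logit'].shape)  # expected: torch.Size([8, 2])
```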
29 changes: 16 additions & 13 deletions ding/policy/dqn.py
@@ -43,34 +43,37 @@ class DQNPolicy(Policy):
| ``_hidden`` (int) 64, 128] | subsequent conv layers and the | is [8, 4, 3]
| ``_size_list`` | final dense layer. | default stride is
| [4, 2 ,1]
10 | ``learn.update`` int 3 | How many updates(iterations) to train | This args can be vary
10 | ``model.dropout`` float None | Dropout rate for dropout layers. | [0,1]
| If set to ``None``
| means no dropout
11 | ``learn.update`` int 3 | How many updates(iterations) to train | This args can be vary
| ``per_collect`` | after collector's one collection. | from envs. Bigger val
| Only valid in serial training | means more off-policy
11 | ``learn.batch_`` int 64 | The number of samples of an iteration
12 | ``learn.batch_`` int 64 | The number of samples of an iteration
| ``size``
12 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
13 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
| ``_rate``
13 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update
14 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update
| ``update_freq``
14 | ``learn.target_`` float 0.005 | Frequence of target network update. | Soft(assign) update
15 | ``learn.target_`` float 0.005 | Frequence of target network update. | Soft(assign) update
| ``theta`` | Only one of [target_update_freq,
| | target_theta] should be set
15 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
16 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
| ``done`` | calculation. | fake termination env
16 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
17 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
| call of collector. | different envs
17 ``collect.n_episode`` int 8 | The number of training episodes of a | only one of [n_sample
18 ``collect.n_episode`` int 8 | The number of training episodes of a | only one of [n_sample
| call of collector | ,n_episode] should
| | be set
18 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
19 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
| ``_len``
19 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
20 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
| 'linear'].
20 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
21 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
| ``start``
21 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1]
22 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1]
| ``end``
22 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set
23 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set
| ``decay`` | decay=10000 means
| the exploration rate
| decay from start
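A hedged sketch of where the new ``model.dropout`` entry sits in a policy config dict; the surrounding values mirror common CartPole-style examples rather than anything added in this diff:

```python
from easydict import EasyDict

dqn_config = EasyDict(dict(
    policy=dict(
        model=dict(
            obs_shape=4,
            action_shape=2,
            encoder_hidden_size_list=[128, 128, 64],
            dueling=True,
            dropout=0.5,   # row 10 above; None (the default) disables dropout entirely
        ),
        learn=dict(update_per_collect=3, batch_size=64, learning_rate=0.001),
        collect=dict(n_sample=8),
    ),
))
```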
4 changes: 3 additions & 1 deletion ding/torch_utils/network/nn_module.py
@@ -376,7 +376,7 @@ def MLP(
block.append(build_normalization(norm_type, dim=1)(out_channels))
if activation is not None:
block.append(activation)
if use_dropout:
if use_dropout is not None:
block.append(nn.Dropout(dropout_probability))

# The last layer
@@ -396,6 +396,8 @@ def MLP(
# The last layer uses the same activation as front layers.
if activation is not None:
block.append(activation)
if use_dropout is not None:
block.append(nn.Dropout(dropout_probability))

if last_linear_layer_init_zero:
# Locate the last linear layer and initialize its weights and biases to 0.
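A hedged sketch of calling the patched MLP helper directly, mirroring the keyword usage visible in the head.py hunk above; the remaining MLP parameters and their defaults are assumed, and the sizes are illustrative:

```python
import torch
import torch.nn as nn

from ding.torch_utils.network.nn_module import MLP  # path taken from the diff above

dropout = 0.5
mlp = MLP(
    64, 64, 64, 2,                     # in_channels, hidden_channels, out_channels, layer_num
    layer_fn=nn.Linear,
    activation=nn.ReLU(),
    use_dropout=dropout is not None,   # same guard the heads pass in
    dropout_probability=dropout,
)
print(mlp(torch.randn(4, 64)).shape)   # expected: torch.Size([4, 64])
```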
11 changes: 10 additions & 1 deletion ding/torch_utils/network/res_block.py
@@ -111,17 +111,24 @@ class ResFCBlock(nn.Module):
forward
"""

def __init__(self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN'):
def __init__(
self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN', dropout: float = None
):
r"""
Overview:
Init the fully connected layer residual block.
Arguments:
- in_channels (:obj:`int`): The number of channels in the input tensor.
- activation (:obj:`nn.Module`): The optional activation function.
- norm_type (:obj:`str`): The type of the normalization, default set to 'BN'.
- dropout (:obj:`float`): The dropout rate, default set to None.
"""
super(ResFCBlock, self).__init__()
self.act = activation
if dropout is not None:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type)
self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type)

@@ -138,4 +145,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.fc2(x)
x = self.act(x + identity)
if self.dropout is not None:
x = self.dropout(x)
return x
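A hedged sketch of the residual FC block with its new optional trailing dropout; channel count and rate are illustrative, import path taken from the file above:

```python
import torch

from ding.torch_utils.network.res_block import ResFCBlock  # path from the diff above

block = ResFCBlock(in_channels=64, norm_type='BN', dropout=0.5)
y = block(torch.randn(4, 64))   # fc1 -> fc2 -> act(x + identity) -> dropout
print(y.shape)                  # expected: torch.Size([4, 64])
```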
@@ -17,6 +17,7 @@
action_shape=2,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
dropout=0.5,
),
nstep=1,
discount_factor=0.97,
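One behavioural note worth keeping in mind with configs like the one above: `nn.Dropout` is only active in training mode, so evaluation remains deterministic. A minimal, PR-independent PyTorch sketch:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Dropout(0.5))
x = torch.randn(2, 4)

net.train()                  # dropout masks are resampled on every forward pass
a, b = net(x), net(x)        # these generally differ

net.eval()                   # dropout becomes a no-op at evaluation time
c, d = net(x), net(x)
print(torch.equal(c, d))     # True
```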