diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py
index 7662bd17dc..2af4762234 100644
--- a/ding/model/common/encoder.py
+++ b/ding/model/common/encoder.py
@@ -141,7 +141,8 @@ def __init__(
         hidden_size_list: SequenceType,
         res_block: bool = False,
         activation: Optional[nn.Module] = nn.ReLU(),
-        norm_type: Optional[str] = None
+        norm_type: Optional[str] = None,
+        dropout: Optional[float] = None
     ) -> None:
         """
         Overview:
@@ -153,6 +154,7 @@ def __init__(
             - activation (:obj:`nn.Module`): Type of activation to use in ``ResFCBlock``. Default is ``nn.ReLU()``.
             - norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResFCBlock`` \
                 for more details. Default is ``None``.
+            - dropout (:obj:`float`): Dropout rate of the dropout layer. If ``None``, no dropout layer is added.
         """
         super(FCEncoder, self).__init__()
         self.obs_shape = obs_shape
@@ -162,17 +164,21 @@ def __init__(
         if res_block:
             assert len(set(hidden_size_list)) == 1, "Please indicate the same hidden size for res block parts"
             if len(hidden_size_list) == 1:
-                self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type)
+                self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
             else:
                 layers = []
                 for i in range(len(hidden_size_list)):
-                    layers.append(ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type))
+                    layers.append(
+                        ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
+                    )
                 self.main = nn.Sequential(*layers)
         else:
             layers = []
             for i in range(len(hidden_size_list) - 1):
                 layers.append(nn.Linear(hidden_size_list[i], hidden_size_list[i + 1]))
                 layers.append(self.act)
+                if dropout is not None:
+                    layers.append(nn.Dropout(dropout))
             self.main = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/ding/model/common/head.py b/ding/model/common/head.py
index def699564b..817726493a 100755
--- a/ding/model/common/head.py
+++ b/ding/model/common/head.py
@@ -28,6 +28,7 @@ def __init__(
         layer_num: int = 1,
         activation: Optional[nn.Module] = nn.ReLU(),
         norm_type: Optional[str] = None,
+        dropout: Optional[float] = None,
         noise: Optional[bool] = False,
     ) -> None:
         """
@@ -41,6 +42,7 @@ def __init__(
                 If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
             - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
                 for more details. Default ``None``.
+            - dropout (:obj:`float`): The dropout rate of the dropout layer. Default ``None`` (no dropout).
             - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
                 Default ``False``.
         """
@@ -55,6 +57,8 @@ def __init__(
                 layer_num,
                 layer_fn=layer,
                 activation=activation,
+                use_dropout=dropout is not None,
+                dropout_probability=dropout,
                 norm_type=norm_type
             ), block(hidden_size, output_size)
         )
@@ -800,6 +804,7 @@ def __init__(
         v_layer_num: Optional[int] = None,
         activation: Optional[nn.Module] = nn.ReLU(),
         norm_type: Optional[str] = None,
+        dropout: Optional[float] = None,
         noise: Optional[bool] = False,
     ) -> None:
         """
@@ -814,6 +819,7 @@ def __init__(
                 If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
             - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
                 for more details. Default ``None``.
+            - dropout (:obj:`float`): The dropout rate of the dropout layer. Default ``None`` (no dropout).
             - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
                 Default ``False``.
         """
@@ -832,6 +838,8 @@ def __init__(
                 a_layer_num,
                 layer_fn=layer,
                 activation=activation,
+                use_dropout=dropout is not None,
+                dropout_probability=dropout,
                 norm_type=norm_type
             ), block(hidden_size, output_size)
         )
@@ -843,6 +851,8 @@ def __init__(
                 v_layer_num,
                 layer_fn=layer,
                 activation=activation,
+                use_dropout=dropout is not None,
+                dropout_probability=dropout,
                 norm_type=norm_type
             ), block(hidden_size, 1)
         )
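For reference, a minimal usage sketch of the two modules touched above (not part of the patch; the import paths follow the files in this diff, and the shapes in the comments are illustrative assumptions):

```python
import torch
from ding.model.common.encoder import FCEncoder
from ding.model.common.head import DuelingHead

# FCEncoder: with ``dropout`` set, an nn.Dropout follows each Linear + activation pair.
encoder = FCEncoder(obs_shape=4, hidden_size_list=[128, 128, 64], dropout=0.5)
# DuelingHead: ``dropout`` is forwarded to MLP as use_dropout / dropout_probability.
head = DuelingHead(hidden_size=64, output_size=2, dropout=0.5)

x = torch.randn(8, 4)
feature = encoder(x)              # assumed shape: (8, 64)
q_value = head(feature)['logit']  # assumed shape: (8, 2)
```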
diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py
index 544b9d10f4..013790cd65 100644
--- a/ding/model/template/q_learning.py
+++ b/ding/model/template/q_learning.py
@@ -21,7 +21,8 @@ def __init__(
         head_hidden_size: Optional[int] = None,
         head_layer_num: int = 1,
         activation: Optional[nn.Module] = nn.ReLU(),
-        norm_type: Optional[str] = None
+        norm_type: Optional[str] = None,
+        dropout: Optional[float] = None
     ) -> None:
         """
         Overview:
@@ -35,9 +36,11 @@ def __init__(
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
            - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output
            - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
-                if ``None`` then default set it to ``nn.ReLU()``
+                if ``None`` then default set it to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
+            - dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
+                If ``None``, no dropout layer is added.
         """
         super(DQN, self).__init__()
         # Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
@@ -46,9 +49,12 @@ def __init__(
             head_hidden_size = encoder_hidden_size_list[-1]
         # FC Encoder
         if isinstance(obs_shape, int) or len(obs_shape) == 1:
-            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+            self.encoder = FCEncoder(
+                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, dropout=dropout
+            )
         # Conv Encoder
         elif len(obs_shape) == 3:
+            assert dropout is None, "dropout is not supported in ConvEncoder"
             self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
         else:
             raise RuntimeError(
@@ -67,11 +73,17 @@ def __init__(
                 action_shape,
                 layer_num=head_layer_num,
                 activation=activation,
-                norm_type=norm_type
+                norm_type=norm_type,
+                dropout=dropout
             )
         else:
             self.head = head_cls(
-                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
+                head_hidden_size,
+                action_shape,
+                head_layer_num,
+                activation=activation,
+                norm_type=norm_type,
+                dropout=dropout
             )

     def forward(self, x: torch.Tensor) -> Dict:
diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py
index a51bdf64b0..37ca5051ce 100644
--- a/ding/policy/dqn.py
+++ b/ding/policy/dqn.py
@@ -43,34 +43,37 @@ class DQNPolicy(Policy):
           | ``_hidden``        (int)    64, 128]      | subsequent conv layers and the         | is [8, 4, 3]
          | ``_size_list``                            | final dense layer.                     | default stride is
                                                      |                                        | [4, 2 ,1]
-      10 | ``learn.update``   int      3             | How many updates(iterations) to train  | This args can be vary
+      10 | ``model.dropout``  float    None          | Dropout rate for dropout layers.       | [0,1]
+                                                     | If set to ``None``
+                                                     | means no dropout
+      11 | ``learn.update``   int      3             | How many updates(iterations) to train  | This args can be vary
          | ``per_collect``                           | after collector's one collection.      | from envs. Bigger val
                                                      | Only valid in serial training          | means more off-policy
-      11 | ``learn.batch_``   int      64            | The number of samples of an iteration
+      12 | ``learn.batch_``   int      64            | The number of samples of an iteration
          | ``size``
-      12 | ``learn.learning`` float    0.001         | Gradient step length of an iteration.
+      13 | ``learn.learning`` float    0.001         | Gradient step length of an iteration.
          | ``_rate``
-      13 | ``learn.target_``  int      100           | Frequence of target network update.    | Hard(assign) update
+      14 | ``learn.target_``  int      100           | Frequence of target network update.    | Hard(assign) update
          | ``update_freq``
-      14 | ``learn.target_``  float    0.005         | Frequence of target network update.    | Soft(assign) update
+      15 | ``learn.target_``  float    0.005         | Frequence of target network update.    | Soft(assign) update
          | ``theta``                                 | Only one of [target_update_freq,
          |                                           | target_theta] should be set
-      15 | ``learn.ignore_``  bool     False         | Whether ignore done for target value   | Enable it for some
+      16 | ``learn.ignore_``  bool     False         | Whether ignore done for target value   | Enable it for some
          | ``done``                                  | calculation.                           | fake termination env
-      16 ``collect.n_sample`` int      [8, 128]      | The number of training samples of a    | It varies from
+      17 ``collect.n_sample`` int      [8, 128]      | The number of training samples of a    | It varies from
                                                      | call of collector.                     | different envs
-      17 ``collect.n_episode`` int     8             | The number of training episodes of a   | only one of [n_sample
+      18 ``collect.n_episode`` int     8             | The number of training episodes of a   | only one of [n_sample
                                                      | call of collector                      | ,n_episode] should
                                                      |                                        | be set
-      18 | ``collect.unroll`` int      1             | unroll length of an iteration          | In RNN, unroll_len>1
+      19 | ``collect.unroll`` int      1             | unroll length of an iteration          | In RNN, unroll_len>1
          | ``_len``
-      19 | ``other.eps.type`` str      exp           | exploration rate decay type            | Support ['exp',
+      20 | ``other.eps.type`` str      exp           | exploration rate decay type            | Support ['exp',
                                                                                               | 'linear'].
-      20 | ``other.eps.``     float    0.95          | start value of exploration rate        | [0,1]
+      21 | ``other.eps.``     float    0.95          | start value of exploration rate        | [0,1]
          | ``start``
-      21 | ``other.eps.``     float    0.1           | end value of exploration rate          | [0,1]
+      22 | ``other.eps.``     float    0.1           | end value of exploration rate          | [0,1]
          | ``end``
-      22 | ``other.eps.``     int      10000         | decay length of exploration            | greater than 0. set
+      23 | ``other.eps.``     int      10000         | decay length of exploration            | greater than 0. set
          | ``decay``                                                                          | decay=10000 means
                                                                                               | the exploration rate
                                                                                               | decay from start
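Correspondingly, the new `model.dropout` field reaches the model roughly as sketched below (hand-written illustration, not part of the patch; the keyword arguments mirror the DQN signature shown above):

```python
import torch
from ding.model.template.q_learning import DQN

# Mirrors the cartpole-style model config; dropout is only supported together
# with the FC encoder (1-D observations), as asserted in the diff above.
model = DQN(
    obs_shape=4,
    action_shape=2,
    encoder_hidden_size_list=[128, 128, 64],
    dueling=True,
    dropout=0.5,
)
model.train()   # dropout layers active during learner updates
logit = model(torch.randn(16, 4))['logit']
model.eval()    # dropout layers disabled for evaluation
```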
diff --git a/ding/torch_utils/network/nn_module.py b/ding/torch_utils/network/nn_module.py
index 3b4a63fc9f..3d387da5bc 100644
--- a/ding/torch_utils/network/nn_module.py
+++ b/ding/torch_utils/network/nn_module.py
@@ -396,6 +396,8 @@ def MLP(
     # The last layer uses the same activation as front layers.
     if activation is not None:
         block.append(activation)
+    if use_dropout:
+        block.append(nn.Dropout(dropout_probability))

     if last_linear_layer_init_zero:
         # Locate the last linear layer and initialize its weights and biases to 0.
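For `MLP` itself, the flag pair works roughly as in this sketch (illustrative only; `use_dropout` and `dropout_probability` are the keywords this patch relies on, the rest of the signature is assumed unchanged):

```python
import torch
import torch.nn as nn
from ding.torch_utils.network.nn_module import MLP

# When the flag is on, each activation is followed by nn.Dropout(dropout_probability);
# with this patch that also covers the activation after the last layer.
net = MLP(
    in_channels=64,
    hidden_channels=64,
    out_channels=64,
    layer_num=3,
    activation=nn.ReLU(),
    use_dropout=True,
    dropout_probability=0.5,
)
y = net(torch.randn(4, 64))
print(net)  # Sequential of Linear -> ReLU -> Dropout blocks
```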
diff --git a/ding/torch_utils/network/res_block.py b/ding/torch_utils/network/res_block.py
index 21dd2b1c2d..136c633da0 100644
--- a/ding/torch_utils/network/res_block.py
+++ b/ding/torch_utils/network/res_block.py
@@ -111,7 +111,9 @@ class ResFCBlock(nn.Module):
         forward
     """

-    def __init__(self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN'):
+    def __init__(
+        self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN', dropout: float = None
+    ):
         r"""
         Overview:
             Init the fully connected layer residual block.
@@ -119,9 +121,14 @@ def __init__(self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_typ
            - in_channels (:obj:`int`): The number of channels in the input tensor.
            - activation (:obj:`nn.Module`): The optional activation function.
            - norm_type (:obj:`str`): The type of the normalization, default set to 'BN'.
+            - dropout (:obj:`float`): The dropout rate, default set to None.
         """
         super(ResFCBlock, self).__init__()
         self.act = activation
+        if dropout is not None:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
         self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type)
         self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type)

@@ -138,4 +145,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
         x = self.fc2(x)
         x = self.act(x + identity)
+        if self.dropout is not None:
+            x = self.dropout(x)
         return x
diff --git a/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py b/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py
index 3e5ca613d0..beff91944b 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py
@@ -17,6 +17,7 @@
             action_shape=2,
             encoder_hidden_size_list=[128, 128, 64],
             dueling=True,
+            dropout=0.5,
         ),
         nstep=1,
         discount_factor=0.97,
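Finally, a quick sanity-check sketch for the `ResFCBlock` change (not part of the patch; dropout is applied once, after the residual addition and the final activation):

```python
import torch
from ding.torch_utils.network.res_block import ResFCBlock

# One residual FC block with BatchNorm and a 0.1 dropout after the residual add.
block = ResFCBlock(in_channels=64, norm_type='BN', dropout=0.1)
x = torch.randn(32, 64)
y = block(x)   # same shape as the input: (32, 64)
```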