The goal of this section is to look at how we can create a custom feature extractor for the RecurrentActorCriticCnnPolicy class in stable-baselines3-contrib. One of the inputs RecurrentActorCriticCnnPolicy takes is features_extractor_class, which defaults to NatureCNN. The CNN in the NatureCNN class has the following architecture:
self.cnn = nn.Sequential(
    nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
    nn.ReLU(),
    nn.Flatten(),
)
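This architecture does not work with Minigrid observations, which rarely have a shape larger than 3x9x9. The first convolution uses an 8x8 kernel with stride 4, so a 9x9 input is reduced to 1x1 and the 4x4 kernel of the second convolution no longer fits; inputs smaller than 8x8 cannot even pass the first layer. A better option for the CNN architecture is the one used in rl-starter-files: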
self.image_conv = nn.Sequential(
    nn.Conv2d(3, 16, (2, 2)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2)),
    nn.Conv2d(16, 32, (2, 2)),
    nn.ReLU(),
    nn.Conv2d(32, 64, (2, 2)),
    nn.ReLU()
)
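The size arithmetic behind both architectures can be checked by hand. The sketch below is plain Python and does not call PyTorch or stable-baselines3; it only applies the standard output-size formula for convolution and pooling layers with no padding and no dilation. The helper names (conv_out, trace, fits, flattened_features) are made up for this illustration, and the 7x7 input is an assumed example of a small Minigrid view, not a value taken from the scripts in this repository.

```python
# Illustrative shape arithmetic only; helper names are hypothetical.

def conv_out(size, kernel, stride=1):
    """Spatial output size of a conv/pool layer (no padding, no dilation)."""
    if size < kernel:
        raise ValueError(f"kernel {kernel} does not fit an input of size {size}")
    return (size - kernel) // stride + 1


def trace(size, layers):
    """Apply each (kernel, stride) layer in turn; return every intermediate size."""
    sizes = [size]
    for kernel, stride in layers:
        sizes.append(conv_out(sizes[-1], kernel, stride))
    return sizes


def fits(size, layers):
    """True if a square input of this side length survives every layer."""
    try:
        trace(size, layers)
        return True
    except ValueError:
        return False


def flattened_features(size, layers, out_channels=64):
    """Number of values left after flattening the last feature map."""
    return out_channels * trace(size, layers)[-1] ** 2


# (kernel, stride) pairs read off the two architectures above.
NATURE_CNN = [(8, 4), (4, 2), (3, 1)]
# Conv 2x2, MaxPool 2x2 (stride 2, the MaxPool2d default), Conv 2x2, Conv 2x2.
SMALL_CNN = [(2, 1), (2, 2), (2, 1), (2, 1)]

for side in (7, 9, 84):
    print(side, "NatureCNN fits:", fits(side, NATURE_CNN))

print(trace(7, SMALL_CNN))               # [7, 6, 3, 2, 1]
print(flattened_features(7, SMALL_CNN))  # 64
print(flattened_features(9, SMALL_CNN))  # 256
```

Under these assumptions the smaller network leaves 64 values for a 7x7 input and 256 for a 9x9 input once its last feature map is flattened. That is the input size any layer placed after the convolutions in a custom feature extractor would have to accept.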
To train the Minigrid agent, use the command:
    python3 minigrid_gotoobj_train.py --train
To train the Miniworld agent, use the command:
    python3 miniworld_gotoobj_train.py --train
To run all of the transfer learning experiments, use the commands:
    python3 load_param_gotoobj.py --num 1
    python3 load_param_gotoobj.py --num 2
    python3 load_param_gotoobj.py --num 3