BolunDai0216/MinigridMiniworldTransfer

MinigridMiniworldTransfer

Create Custom Feature Extractor in Stable-Baselines3

This section shows how to create a custom feature extractor for the RecurrentActorCriticCnnPolicy class in stable-baselines3-contrib. One of the arguments RecurrentActorCriticCnnPolicy takes is features_extractor_class, which defaults to NatureCNN. The CNN in the NatureCNN class has the architecture

self.cnn = nn.Sequential(
    nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
    nn.ReLU(),
    nn.Flatten(),
)
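To see how these layers treat small inputs, the standard convolution output-size formula can be applied layer by layer. The following is a minimal sketch in plain Python; the helper names `conv_out` and `trace_sizes` are illustrative and not part of stable-baselines3, and the 84x84 case is included only as a typical Atari-sized input for comparison.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride) of the three NatureCNN convolutions listed above
NATURE_CNN_LAYERS = [(8, 4), (4, 2), (3, 1)]


def trace_sizes(size, layers):
    """Return the spatial size after each layer.

    A None entry marks the first layer whose kernel is larger than
    its input; tracing stops there.
    """
    sizes = []
    for kernel, stride in layers:
        if size < kernel:
            sizes.append(None)
            break
        size = conv_out(size, kernel, stride)
        sizes.append(size)
    return sizes


print(trace_sizes(84, NATURE_CNN_LAYERS))  # [20, 9, 7]
print(trace_sizes(9, NATURE_CNN_LAYERS))   # [1, None]
print(trace_sizes(7, NATURE_CNN_LAYERS))   # [None]
```

A 7x7 input is already smaller than the first 8x8 kernel, and a 9x9 input is reduced to 1x1 after the first layer, which the second 4x4 kernel cannot process.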

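This architecture does not work with Minigrid observations, which rarely have a shape larger than 3x9x9: with no padding, the 8x8 first-layer kernel either does not fit the input at all or shrinks it to a size the following 4x4 kernel cannot process. A better option for the CNN architecture is the one used in rl-starter-files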

self.image_conv = nn.Sequential(
    nn.Conv2d(3, 16, (2, 2)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2)),
    nn.Conv2d(16, 32, (2, 2)),
    nn.ReLU(),
    nn.Conv2d(32, 64, (2, 2)),
    nn.ReLU()
)

Train Minigrid & Miniworld Env

To train the Minigrid agent, use the command

python3 minigrid_gotoobj_train.py --train

To train the Miniworld agent, use the command

python3 miniworld_gotoobj_train.py --train

To run all of the transfer learning experiments, use the commands

python3 load_param_gotoobj.py --num 1
python3 load_param_gotoobj.py --num 2
python3 load_param_gotoobj.py --num 3
