Training Minigrid Environments
Agents can be trained on the environments in the Minigrid library using StableBaselines3. In this tutorial we show how a PPO agent can be trained on the MiniGrid-Empty-16x16-v0 environment.
Create Custom Feature Extractor
Although StableBaselines3 is fully compatible with Gymnasium-based environments, including Minigrid, the default CNN architecture does not directly support the Minigrid observation space: its first convolution uses an 8x8 kernel, which is larger than the 7x7 image that Minigrid produces by default. To train an agent on Minigrid environments, we therefore need to create a custom feature extractor. This can be done by creating a feature extractor class that inherits from BaseFeaturesExtractor.
import gymnasium as gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class MinigridFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
        super().__init__(observation_space, features_dim)
        # The observation space is channel-first here, so shape[0] is the channel count
        n_input_channels = observation_space.shape[0]
        # Small 2x2 kernels fit the small Minigrid image
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with torch.no_grad():
            n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(observations))
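The extractor above finds the flattened feature size with a dummy forward pass. As a rough cross-check, the same number can be worked out by hand. The sketch below assumes a 7x7 input (the default Minigrid agent view) and treats each convolution as 2x2 with stride 1 and no padding, as in the extractor. The helper names `conv2d_out` and `flattened_features` are hypothetical and are not part of StableBaselines3 or Minigrid.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis of a 2D convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(height: int, width: int, kernels=(2, 2, 2), out_channels: int = 64) -> int:
    """Flattened size after a stack of square, stride-1, unpadded convolutions."""
    for k in kernels:
        height = conv2d_out(height, k)
        width = conv2d_out(width, k)
    return out_channels * height * width


# Assumption: a 7x7 agent view, the Minigrid default.
n_flatten = flattened_features(7, 7)
print(n_flatten)  # each axis shrinks 7 -> 6 -> 5 -> 4, so 64 * 4 * 4 = 1024
```

Under these assumptions the linear layer receives 1024 inputs. For a different view size the dummy forward pass in the extractor remains the more robust option, since it adapts automatically.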
Train a PPO Agent
Using the custom feature extractor, we can train a PPO agent on the
MiniGrid-Empty-16x16-v0 environment. The following code snippet shows how this can be done.
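import gymnasium as gym
import minigrid  # importing minigrid registers the MiniGrid environments
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO

# Tell the policy to use the custom feature extractor defined above
policy_kwargs = dict(
    features_extractor_class=MinigridFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=128),
)

env = gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array")
env = ImgObsWrapper(env)  # keep only the image part of the observation

model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=200_000)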
By default, the observations of Minigrid environments are dictionaries. Since the
CnnPolicy from StableBaselines3 takes image observations by default, we need to wrap the environment using the
ImgObsWrapper from the Minigrid library. This wrapper converts the dictionary observation to an image observation.
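To make the effect of the wrapper concrete, here is a minimal plain-Python sketch of the conversion. It is not the library's implementation. `make_fake_observation` and `image_only` are hypothetical helpers, and the fake observation assumes the usual Minigrid keys ("image", "direction", "mission") with a 7x7x3 image.

```python
from typing import Any, Dict, List


def make_fake_observation(view_size: int = 7) -> Dict[str, Any]:
    """Build a stand-in for a Minigrid dictionary observation (illustrative only)."""
    image = [[[0, 0, 0] for _ in range(view_size)] for _ in range(view_size)]
    return {"image": image, "direction": 0, "mission": "get to the green goal square"}


def image_only(observation: Dict[str, Any]) -> List[List[List[int]]]:
    """Mimic the effect of ImgObsWrapper: keep the "image" entry, drop the rest."""
    return observation["image"]


obs = make_fake_observation()
img = image_only(obs)
print(sorted(obs))                            # ['direction', 'image', 'mission']
print(len(img), len(img[0]), len(img[0][0]))  # 7 7 3
```

The direction and mission entries are discarded in this sketch, which is acceptable for MiniGrid-Empty-16x16-v0 because the mission text is the same in every episode.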