Skip to content

Commit

Permalink
reshaping obs tensors for map sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
DennisSoemers committed Dec 8, 2023
1 parent da86e48 commit b02fe6f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
4 changes: 3 additions & 1 deletion experiments/ppo_gridnet_variable_mapsizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def parse_args():
help='the list of maps used during training')
parser.add_argument('--eval-maps', nargs='+', default=["maps/16x16/basesWorkers16x16A.xml"],
help='the list of maps used during evaluation')
parser.add_argument('--cycle-maps', nargs='+', default=["maps/EightBasesWorkers16x12.xml", "maps/16x16/basesWorkers16x16A.xml"],
help='list of maps to cycle through after environments complete during training')

args = parser.parse_args()
if not args.seed:
Expand Down Expand Up @@ -359,7 +361,7 @@ def on_evaluation_done(self, future):
+ [microrts_ai.workerRushAI for _ in range(min(args.num_bot_envs, 2))],
map_paths=[args.train_maps[0]],
reward_weight=np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0]),
cycle_maps=args.train_maps,
cycle_maps=args.cycle_maps,
)
envs = MicroRTSStatsRecorder(envs, args.gamma)
envs = VecMonitor(envs)
Expand Down
15 changes: 15 additions & 0 deletions gym_microrts/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,28 @@ def reset(self):
return np.array(obs)

def _encode_obs(self, obs):
num_channels, height, width = obs.shape

# Add padding to obs such that it is as big as we need for our biggest map
pad_width = self.width - width
pad_height = self.height - height
if pad_width > 0 or pad_height > 0:
obs_padded = np.ndarray((num_channels, self.height, self.width), np.int32)
for channel_idx, plane in enumerate(obs):
if channel_idx == 5: # Index for the walls/terrain channel
obs_padded[channel_idx, :, :] = np.pad(plane, ((0, pad_height), (0, pad_width)), constant_values=1)
else:
obs_padded[channel_idx, :, :] = np.pad(plane, ((0, pad_height), (0, pad_width)), constant_values=0)
obs = obs_padded

obs = obs.reshape(len(obs), -1).clip(0, np.array([self.num_planes]).T - 1)
obs_planes = np.zeros((self.height * self.width, self.num_planes_prefix_sum[-1]), dtype=np.int32)
obs_planes_idx = np.arange(len(obs_planes))
obs_planes[obs_planes_idx, obs[0]] = 1

for i in range(1, self.num_planes_len):
obs_planes[obs_planes_idx, obs[i] + self.num_planes_prefix_sum[i]] = 1

return obs_planes.reshape(self.height, self.width, -1)

def step_async(self, actions):
Expand Down
2 changes: 1 addition & 1 deletion gym_microrts/microrts

0 comments on commit b02fe6f

Please sign in to comment.