-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cea48af
commit c754f49
Showing
8 changed files
with
427 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
|
||
class ImmortalAction: | ||
def __init__(self): | ||
super().__init__() | ||
self._lookup_table = self._make_lookup_table() | ||
|
||
@staticmethod | ||
def _make_lookup_table(): | ||
actions = [] | ||
# Ground | ||
for throttle in (-1, 0, 1): | ||
for steer in (-1, 0, 1): | ||
for boost in (0, 1): | ||
for handbrake in (0, 1): | ||
if boost == 1 and throttle != 1: | ||
continue | ||
actions.append([throttle or boost, steer, 0, steer, 0, 0, boost, handbrake]) | ||
# Aerial | ||
for pitch in (-1, 0, 1): | ||
for yaw in (-1, 0, 1): | ||
for roll in (-1, 0, 1): | ||
for jump in (0, 1): | ||
for boost in (0, 1): | ||
if pitch == roll == jump == 0: | ||
continue | ||
actions.append([boost, yaw, pitch, yaw, roll, jump, boost, 1]) | ||
actions = np.array(actions) | ||
return actions | ||
|
||
def parse_actions(self, actions: Any) -> np.ndarray: | ||
return self._lookup_table[np.array(actions, dtype=np.float32).squeeze().astype(int)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import math | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
from torch.distributions import Categorical | ||
|
||
from action.actionparser import ImmortalAction | ||
|
||
|
||
def get_action_distribution(obs, actor): | ||
if isinstance(obs, np.ndarray): | ||
obs = torch.from_numpy(obs).float() | ||
elif isinstance(obs, tuple): | ||
obs = tuple(o if isinstance(o, torch.Tensor) else torch.from_numpy(o).float() for o in obs) | ||
|
||
out = actor(obs) | ||
|
||
if isinstance(out, torch.Tensor): | ||
out = (out,) | ||
|
||
max_shape = max(o.shape[-1] for o in out) | ||
logits = torch.stack( | ||
[ | ||
l | ||
if l.shape[-1] == max_shape | ||
else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf")) | ||
for l in out | ||
], | ||
dim=1 | ||
) | ||
|
||
return Categorical(logits=logits) | ||
|
||
|
||
def sample_action( | ||
distribution: Categorical, | ||
deterministic=None | ||
): | ||
if deterministic: | ||
action_indices = torch.argmax(distribution.logits, dim=-1) | ||
else: | ||
action_indices = distribution.sample() | ||
|
||
return action_indices | ||
|
||
|
||
|
||
def env_compatible(action): | ||
if isinstance(action, torch.Tensor): | ||
action = action.numpy() | ||
return action | ||
|
||
|
||
class Agent: | ||
def __init__(self): | ||
cur_dir = os.path.dirname(os.path.realpath(__file__)) | ||
self.actor = torch.jit.load(os.path.join(cur_dir, "jit.pt")) | ||
torch.set_num_threads(1) | ||
self.action_parser = ImmortalAction() | ||
|
||
def act(self, state): | ||
with torch.no_grad(): | ||
all_actions = [] | ||
dist = get_action_distribution(state, self.actor) | ||
action_indices = sample_action(dist, deterministic=True)[0] | ||
actions = env_compatible(action_indices) | ||
|
||
all_actions.append(actions) | ||
all_actions = np.array(all_actions) | ||
|
||
return all_actions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
[Bot Loadout] | ||
# Primary Color selection | ||
team_color_id = 9 | ||
# Secondary Color selection | ||
custom_color_id = 104 | ||
# Car type (Octane, Merc, etc) | ||
car_id = 5713 | ||
# Type of decal | ||
decal_id = 5127 | ||
# Wheel selection | ||
wheels_id = 5183 | ||
# Boost selection | ||
boost_id = 5860 | ||
# Antenna Selection | ||
antenna_id = 0 | ||
# Hat Selection | ||
hat_id = 0 | ||
# Paint Type (for first color) | ||
paint_finish_id = 272 | ||
# Paint Type (for secondary color) | ||
custom_finish_id = 2973 | ||
# Engine Audio Selection | ||
engine_audio_id = 0 | ||
# Car trail Selection | ||
trails_id = 3575 | ||
# Goal Explosion Selection | ||
goal_explosion_id = 3131 | ||
# Finds the closest primary color swatch based on the provided RGB value like [34, 255, 60] | ||
primary_color_lookup = None | ||
# Finds the closest secondary color swatch based on the provided RGB value like [34, 255, 60] | ||
secondary_color_lookup = None | ||
|
||
[Bot Loadout Orange] | ||
# Primary Color selection | ||
team_color_id = 67 | ||
# Secondary Color selection | ||
custom_color_id = 90 | ||
# Car type (Octane, Merc, etc) | ||
car_id = 5713 | ||
# Type of decal | ||
decal_id = 5127 | ||
# Wheel selection | ||
wheels_id = 5183 | ||
# Boost selection | ||
boost_id = 5860 | ||
# Antenna Selection | ||
antenna_id = 0 | ||
# Hat Selection | ||
hat_id = 0 | ||
# Paint Type (for first color) | ||
paint_finish_id = 272 | ||
# Paint Type (for secondary color) | ||
custom_finish_id = 2973 | ||
# Engine Audio Selection | ||
engine_audio_id = 0 | ||
# Car trail Selection | ||
trails_id = 3575 | ||
# Goal Explosion Selection | ||
goal_explosion_id = 3131 | ||
# Finds the closest primary color swatch based on the provided RGB value like [34, 255, 60] | ||
primary_color_lookup = None | ||
# Finds the closest secondary color swatch based on the provided RGB value like [34, 255, 60] | ||
secondary_color_lookup = None | ||
|
||
[Bot Paint Blue] | ||
# car_paint_id | ||
car_paint_id = 3 | ||
# decal_paint_id | ||
decal_paint_id = 12 | ||
# wheels_paint_id | ||
wheels_paint_id = 12 | ||
# boost_paint_id | ||
boost_paint_id = 3 | ||
# antenna_paint_id | ||
antenna_paint_id = 0 | ||
# hat_paint_id | ||
hat_paint_id = 0 | ||
# trails_paint_id | ||
trails_paint_id = 0 | ||
# goal_explosion_paint_id | ||
goal_explosion_paint_id = 12 | ||
|
||
[Bot Paint Orange] | ||
# car_paint_id | ||
car_paint_id = 3 | ||
# decal_paint_id | ||
decal_paint_id = 12 | ||
# wheels_paint_id | ||
wheels_paint_id = 12 | ||
# boost_paint_id | ||
boost_paint_id = 3 | ||
# antenna_paint_id | ||
antenna_paint_id = 0 | ||
# hat_paint_id | ||
hat_paint_id = 0 | ||
# trails_paint_id | ||
trails_paint_id = 0 | ||
# goal_explosion_paint_id | ||
goal_explosion_paint_id = 12 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
[Locations] | ||
# Path to loadout config. Can use relative path from here. | ||
looks_config = ./appearance.cfg | ||
|
||
# Path to python file. Can use relative path from here. | ||
python_file = ./bot.py | ||
requirements_file = ./requirements.txt | ||
|
||
# Name of the bot in-game | ||
name = ImmortalETSET | ||
|
||
# The maximum number of ticks per second that your bot wishes to receive. | ||
maximum_tick_rate_preference = 120 | ||
|
||
[Details] | ||
# These values are optional but useful metadata for helper programs | ||
# Name of the bot's creator/developer | ||
developer = CosmicVivacity | ||
|
||
# Short description of the bot | ||
description = Prepare to get decimated from the air | ||
|
||
# Fun fact about the bot | ||
fun_fact = This is a bot | ||
|
||
# Programming language | ||
language = rlgym |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import numpy as np | ||
import torch | ||
from rlbot.agents.base_agent import BaseAgent, SimpleControllerState | ||
from rlbot.utils.structures.game_data_struct import GameTickPacket | ||
from rlgym_compat import GameState | ||
from obs.advanced_obs import ExpandAdvancedObs | ||
from action.actionparser import ImmortalAction | ||
|
||
from agent import Agent | ||
|
||
|
||
class RLGymExampleBot(BaseAgent): | ||
def __init__(self, name, team, index): | ||
super().__init__(name, team, index) | ||
|
||
# Swap the obs builder if you are using a different one, RLGym's AdvancedObs is also available | ||
self.obs_builder = ExpandAdvancedObs() | ||
# Swap the action parser if you are using a different one, RLGym's Discrete and Continuous are also available | ||
self.act_parser = ImmortalAction() | ||
# Your neural network logic goes inside the Agent class, go take a look inside src/agent.py | ||
self.agent = Agent() | ||
# Adjust the tickskip if your agent was trained with a different value | ||
self.tick_skip = 6 | ||
|
||
self.game_state: GameState = None | ||
self.controls = None | ||
self.action = None | ||
self.update_action = True | ||
self.ticks = 0 | ||
self.prev_time = 0 | ||
print('Immortal - Index:', index) | ||
|
||
def initialize_agent(self): | ||
# Initialize the rlgym GameState object now that the game is active and the info is available | ||
self.game_state = GameState(self.get_field_info()) | ||
self.ticks = self.tick_skip # So we take an action the first tick | ||
self.prev_time = 0 | ||
self.controls = SimpleControllerState() | ||
self.action = np.zeros(8) | ||
self.update_action = True | ||
|
||
def get_output(self, packet: GameTickPacket) -> SimpleControllerState: | ||
cur_time = packet.game_info.seconds_elapsed | ||
delta = cur_time - self.prev_time | ||
self.prev_time = cur_time | ||
|
||
ticks_elapsed = round(delta * 120) | ||
self.ticks += ticks_elapsed | ||
self.game_state.decode(packet, ticks_elapsed) | ||
|
||
if self.update_action: | ||
self.update_action = False | ||
|
||
# By default we treat every match as a 1v1 against a fixed opponent, | ||
# by doing this your bot can participate in 2v2 or 3v3 matches. Feel free to change this | ||
player = self.game_state.players[self.index] | ||
teammates = [p for p in self.game_state.players if p.team_num == self.team] | ||
opponents = [p for p in self.game_state.players if p.team_num != self.team] | ||
|
||
if len(opponents) == 0: | ||
# There's no opponent, we assume this model is 1v0 | ||
self.game_state.players = [player] | ||
else: | ||
# Sort by distance to ball | ||
teammates.sort(key=lambda p: np.linalg.norm(self.game_state.ball.position - p.car_data.position)) | ||
opponents.sort(key=lambda p: np.linalg.norm(self.game_state.ball.position - p.car_data.position)) | ||
|
||
# Grab opponent in same "position" relative to it's teammates | ||
opponent = opponents[min(teammates.index(player), len(opponents) - 1)] | ||
|
||
self.game_state.players = [player, opponent] | ||
|
||
obs = self.obs_builder.build_obs(player, self.game_state, self.action) | ||
self.action = self.act_parser.parse_actions(self.agent.act(obs)) # Dim is (N, 8) | ||
|
||
if self.ticks >= self.tick_skip - 1: | ||
self.update_controls(self.action) | ||
|
||
if self.ticks >= self.tick_skip: | ||
self.ticks = 0 | ||
self.update_action = True | ||
|
||
return self.controls | ||
|
||
def update_controls(self, action): | ||
self.controls.throttle = action[0] | ||
self.controls.steer = action[1] | ||
self.controls.pitch = action[2] | ||
self.controls.yaw = 0 if action[5] > 0 else action[3] | ||
self.controls.roll = action[4] | ||
self.controls.jump = action[5] > 0 | ||
self.controls.boost = action[6] > 0 | ||
self.controls.handbrake = action[7] > 0 |
Binary file not shown.
Oops, something went wrong.