Skip to content

Commit

Permalink
Run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
noahfarr committed Aug 28, 2024
1 parent 180e302 commit 3f0cd76
Showing 1 changed file with 23 additions and 64 deletions.
87 changes: 23 additions & 64 deletions cleanrl/td3_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class QNetwork(nn.Module):
def __init__(self, env):
super().__init__()
self.fc1 = nn.Linear(
np.array(env.single_observation_space.shape).prod()
+ np.prod(env.single_action_space.shape),
np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape),
256,
)
self.fc2 = nn.Linear(256, 256)
Expand Down Expand Up @@ -159,8 +158,7 @@ def forward(self, x):
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s"
% ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
Expand All @@ -173,14 +171,9 @@ def forward(self, x):

# env setup
envs = gym.vector.SyncVectorEnv(
[
make_env(args.env_id, args.seed + i, i, args.capture_video, run_name)
for i in range(args.num_envs)
]
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(
envs.single_action_space, gym.spaces.Box
), "only continuous action space is supported"
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

actor = Actor(envs).to(device)
qf1 = QNetwork(envs).to(device)
Expand All @@ -191,9 +184,7 @@ def forward(self, x):
target_actor.load_state_dict(actor.state_dict())
qf1_target.load_state_dict(qf1.state_dict())
qf2_target.load_state_dict(qf2.state_dict())
q_optimizer = optim.Adam(
list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate
)
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate)

envs.single_observation_space.dtype = np.float32
Expand All @@ -212,18 +203,12 @@ def forward(self, x):
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
actions = np.array(
[envs.single_action_space.sample() for _ in range(envs.num_envs)]
)
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
with torch.no_grad():
actions = actor(torch.Tensor(obs).to(device))
actions += torch.normal(0, actor.action_scale * args.exploration_noise)
actions = (
actions.cpu()
.numpy()
.clip(envs.single_action_space.low, envs.single_action_space.high)
)
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminations, truncations, infos = envs.step(actions)
Expand All @@ -232,15 +217,9 @@ def forward(self, x):
if "final_info" in infos:
for info in infos["final_info"]:
if info is not None:
print(
f"global_step={global_step}, episodic_return={info['episode']['r']}"
)
writer.add_scalar(
"charts/episodic_return", info["episode"]["r"], global_step
)
writer.add_scalar(
"charts/episodic_length", info["episode"]["l"], global_step
)
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
Expand All @@ -257,21 +236,17 @@ def forward(self, x):
if global_step > args.learning_starts:
data = rb.sample(args.batch_size)
with torch.no_grad():
clipped_noise = (
torch.randn_like(data.actions, device=device) * args.policy_noise
).clamp(-args.noise_clip, args.noise_clip) * target_actor.action_scale
clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp(
-args.noise_clip, args.noise_clip
) * target_actor.action_scale

next_state_actions = (
target_actor(data.next_observations) + clipped_noise
).clamp(
next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp(
envs.single_action_space.low[0], envs.single_action_space.high[0]
)
qf1_next_target = qf1_target(data.next_observations, next_state_actions)
qf2_next_target = qf2_target(data.next_observations, next_state_actions)
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
next_q_value = data.rewards.flatten() + (
1 - data.dones.flatten()
) * args.gamma * (min_qf_next_target).view(-1)
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

qf1_a_values = qf1(data.observations, data.actions).view(-1)
qf2_a_values = qf2(data.observations, data.actions).view(-1)
Expand All @@ -291,32 +266,16 @@ def forward(self, x):
actor_optimizer.step()

# update the target network
for param, target_param in zip(
actor.parameters(), target_actor.parameters()
):
target_param.data.copy_(
args.tau * param.data + (1 - args.tau) * target_param.data
)
for param, target_param in zip(
qf1.parameters(), qf1_target.parameters()
):
target_param.data.copy_(
args.tau * param.data + (1 - args.tau) * target_param.data
)
for param, target_param in zip(
qf2.parameters(), qf2_target.parameters()
):
target_param.data.copy_(
args.tau * param.data + (1 - args.tau) * target_param.data
)
for param, target_param in zip(actor.parameters(), target_actor.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

if global_step % 100 == 0:
writer.add_scalar(
"losses/qf1_values", qf1_a_values.mean().item(), global_step
)
writer.add_scalar(
"losses/qf2_values", qf2_a_values.mean().item(), global_step
)
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
Expand Down

0 comments on commit 3f0cd76

Please sign in to comment.