Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixed double precision for PPO_RNN algorithm #172

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 54 additions & 39 deletions skrl/agents/torch/ppo/ppo_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
"rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
"time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation)

"mixed_precision": False, # mixed torch.float32 and torch.float16 precision for higher performance

"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
Expand Down Expand Up @@ -151,6 +153,12 @@ def __init__(self,
self._rewards_shaper = self.cfg["rewards_shaper"]
self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

self._mixed_precision = self.cfg["mixed_precision"]

# set up automatic mixed precision
self._device_type = torch.device(device).type
self._scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision)

# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
if self.policy is self.value:
Expand Down Expand Up @@ -301,9 +309,10 @@ def record_transition(self,
rewards = self._rewards_shaper(rewards, timestep, timesteps)

# compute values
rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {}
values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value")
values = self._value_preprocessor(values, inverse=True)
with torch.autocast(device_type=self._device_type, enabled=(self._mixed_precision)):
rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {}
values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value")
values = self._value_preprocessor(values, inverse=True)

# time-limit (truncation) boostrapping
if self._time_limit_bootstrap:
Expand Down Expand Up @@ -452,63 +461,69 @@ def compute_gae(rewards: torch.Tensor,
# mini-batches loop
for i, (sampled_states, sampled_actions, sampled_dones, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages) in enumerate(sampled_batches):

if self._rnn:
if self.policy is self.value:
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], "terminated": sampled_dones}
rnn_value = rnn_policy
else:
rnn_policy = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "policy" in n], "terminated": sampled_dones}
rnn_value = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "value" in n], "terminated": sampled_dones}
with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):

sampled_states = self._state_preprocessor(sampled_states, train=not epoch)
if self._rnn:
if self.policy is self.value:
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn_batches[i]], "terminated": sampled_dones}
rnn_value = rnn_policy
else:
rnn_policy = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "policy" in n], "terminated": sampled_dones}
rnn_value = {"rnn": [s.transpose(0, 1) for s, n in zip(sampled_rnn_batches[i], self._rnn_tensors_names) if "value" in n], "terminated": sampled_dones}

_, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy")
sampled_states = self._state_preprocessor(sampled_states, train=not epoch)

# compute approximate KL divergence
with torch.no_grad():
ratio = next_log_prob - sampled_log_prob
kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
kl_divergences.append(kl_divergence)
_, next_log_prob, _ = self.policy.act({"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy")

# early stopping with KL divergence
if self._kl_threshold and kl_divergence > self._kl_threshold:
break
# compute approximate KL divergence
with torch.no_grad():
ratio = next_log_prob - sampled_log_prob
kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
kl_divergences.append(kl_divergence)

# compute entropy loss
if self._entropy_loss_scale:
entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
else:
entropy_loss = 0
# early stopping with KL divergence
if self._kl_threshold and kl_divergence > self._kl_threshold:
break

# compute entropy loss
if self._entropy_loss_scale:
entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
else:
entropy_loss = 0

# compute policy loss
ratio = torch.exp(next_log_prob - sampled_log_prob)
surrogate = sampled_advantages * ratio
surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip)
# compute policy loss
ratio = torch.exp(next_log_prob - sampled_log_prob)
surrogate = sampled_advantages * ratio
surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip)

policy_loss = -torch.min(surrogate, surrogate_clipped).mean()
policy_loss = -torch.min(surrogate, surrogate_clipped).mean()

# compute value loss
predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value")
# compute value loss
predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value")

if self._clip_predicted_values:
predicted_values = sampled_values + torch.clip(predicted_values - sampled_values,
min=-self._value_clip,
max=self._value_clip)
value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)
if self._clip_predicted_values:
predicted_values = sampled_values + torch.clip(predicted_values - sampled_values,
min=-self._value_clip,
max=self._value_clip)
value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)

# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
self._scaler.scale(policy_loss + entropy_loss + value_loss).backward()

if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()

if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
else:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip)
self.optimizer.step()

self._scaler.step(self.optimizer)
self._scaler.update()

# update cumulative losses
cumulative_policy_loss += policy_loss.item()
Expand Down
Loading