Skip to content

Commit

Permalink
Fix autocast in processing step.
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Pfreundschuh committed Oct 8, 2024
1 parent 3b8a6f2 commit 531697c
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions chimp/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,24 +153,13 @@ def predict_fun(x_t):
x_t["lead_time"] = lead_time

with torch.no_grad():
if device != "cpu":
with torch.autocast(device_type="cuda", dtype=float_type):
y_pred = model(x_t)
for key, y_pred_k in y_pred.items():
for step, y_pred_k_s in enumerate(iter_tensors(y_pred_k)):
results_step = results.setdefault(step, {})
y_mean_k_s = y_pred_k_s.expected_value()[0, 0]
results_step[key + "_mean"] = y_mean_k_s.cpu().numpy()
if key in quantile_outputs:
results_step[key + "_cdf"] = y_pred_k_s.cpu().float().numpy()[0, :, 0]
else:
with torch.autocast(device_type=device.type, dtype=float_type):
y_pred = model(x_t)
for key, y_pred_k in y_pred.items():
for step, y_pred_k_s in enumerate(iter_tensors(y_pred_k)):
results_step = results.setdefault(step, {})
y_mean_k_s = y_pred_k_s.expected_value()[0, 0]
results_step[key + "_mean"] = y_mean_k.cpu().float().numpy()

results_step[key + "_mean"] = y_mean_k_s.cpu().numpy()
if key in quantile_outputs:
results_step[key + "_cdf"] = y_pred_k_s.cpu().float().numpy()[0, :, 0]
return results
Expand Down

0 comments on commit 531697c

Please sign in to comment.