Skip to content

Commit

Permalink
fix: 🐛 fix incompatible parameters when generating Adam optimizers on…
Browse files Browse the repository at this point in the history
… Torch 1.13.

Fix incompatible parameters:
- fused is only available if cuda is available
- fused is incompatible with differentiable
  • Loading branch information
qthequartermasterman committed May 12, 2024
1 parent f166a9e commit d0ef6bb
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions hypothesis_torch/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def betas(draw: st.DrawFn) -> tuple[float, float]:
"dampening": _ZERO_TO_ONE_FLOATS,
"nesterov": st.booleans(),
"initial_accumulator_value": _ZERO_TO_ONE_FLOATS,
"fused": st.booleans() if torch.cuda.is_available() else st.just(False),
}


Expand Down Expand Up @@ -112,6 +113,10 @@ def optimizer_strategy(
kwargs.pop("self", None) # Remove self if a type was inferred
kwargs.pop("params", None) # Remove params if a type was inferred

# Adam cannot be both fused and differentiable simultaneously
if "differentiable" in kwargs and kwargs["differentiable"] and "fused" in kwargs and kwargs["fused"]:
kwargs.pop("differentiable")

hypothesis.note(f"Chosen optimizer type: {optimizer_type}")
hypothesis.note(f"Chosen optimizer hyperparameters: {kwargs}")

Expand Down

5 comments on commit d0ef6bb

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
1256 1209 💤 0 ❌ 0 🔥 59.452s ⏱️

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
1256 1205 💤 0 ❌ 0 🔥 1m 8s ⏱️

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
1256 1205 💤 0 ❌ 0 🔥 1m 14s ⏱️

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
1256 1205 💤 0 ❌ 0 🔥 1m 11s ⏱️

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
1249 1198 💤 0 ❌ 0 🔥 1m 14s ⏱️

Please sign in to comment.