From d0ef6bb44b935879864915c15496986afafc9375 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Sat, 11 May 2024 21:15:13 -0500
Subject: [PATCH] fix: :bug: fix incompatible parameters when generating Adam
 optimizers on Torch 1.13.

Fix incompatible parameters:
- fused is only available if cuda is available
- fused is incompatible with differentiable
---
 hypothesis_torch/optim.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/hypothesis_torch/optim.py b/hypothesis_torch/optim.py
index 6b83cd1..daf4d59 100644
--- a/hypothesis_torch/optim.py
+++ b/hypothesis_torch/optim.py
@@ -57,6 +57,7 @@ def betas(draw: st.DrawFn) -> tuple[float, float]:
     "dampening": _ZERO_TO_ONE_FLOATS,
     "nesterov": st.booleans(),
     "initial_accumulator_value": _ZERO_TO_ONE_FLOATS,
+    "fused": st.booleans() if torch.cuda.is_available() else st.just(False),
 }


@@ -112,6 +113,10 @@ def optimizer_strategy(
     kwargs.pop("self", None) # Remove self if a type was inferred
     kwargs.pop("params", None) # Remove params if a type was inferred

+    # Adam cannot be both fused and differentiable simultaneously
+    if "differentiable" in kwargs and kwargs["differentiable"] and "fused" in kwargs and kwargs["fused"]:
+        kwargs.pop("differentiable")
+
     hypothesis.note(f"Chosen optimizer type: {optimizer_type}")
     hypothesis.note(f"Chosen optimizer hyperparameters: {kwargs}")

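Illustration (not part of the patch): a minimal sketch of the two torch constraints the generated kwargs must respect, which motivate the guards above. Behaviour is as observed on Torch 1.13 with no CUDA device; exact exception types and messages may differ on other versions.

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]

    # Constraint 1: on Torch 1.13, fused=True expects all parameters to be
    # CUDA tensors, so on a CPU-only machine construction fails.
    try:
        torch.optim.Adam(params, fused=True)
    except (RuntimeError, ValueError) as exc:
        print(f"fused on CPU rejected: {exc}")

    # Constraint 2: fused=True and differentiable=True are mutually
    # exclusive, regardless of device.
    try:
        torch.optim.Adam(params, fused=True, differentiable=True)
    except (RuntimeError, ValueError) as exc:
        print(f"fused + differentiable rejected: {exc}")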