diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index dd6ca0a61..b1a8fbea2 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -177,7 +177,7 @@ def load_state_dict(self, state_dict, move_to_device=True): raise ValueError("loaded state dict has a different number of parameter groups") param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) - if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)): + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): raise ValueError( "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", ) @@ -188,7 +188,6 @@ def load_state_dict(self, state_dict, move_to_device=True): for old_id, p in zip( chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups), - strict=True, ) } @@ -230,7 +229,7 @@ def update_group(group, new_group): new_group["params"] = group["params"] return new_group - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)] + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self):