Skip to content

Commit

Permalink
feat: added typing in Ratio class
Browse files Browse the repository at this point in the history
  • Loading branch information
michele-milesi committed Mar 29, 2024
1 parent beff471 commit 9be5304
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions sheeprl/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,12 @@ def __init__(self, ratio: float, pretrain_steps: int = 0):
self._ratio = ratio
self._prev = None

def __call__(self, step) -> int:
step = int(step)
def __call__(self, step: int) -> int:
if self._ratio == 0:
return 0
if self._prev is None:
self._prev = step
repeats = 1
if self._pretrain_steps > 0:
if step < self._pretrain_steps:
warnings.warn(
Expand All @@ -226,17 +226,17 @@ def __call__(self, step) -> int:
"the number of current steps."
)
self._pretrain_steps = step
return round(self._pretrain_steps * self._ratio)
else:
return 1
repeats = round(self._pretrain_steps * self._ratio)
return repeats
repeats = round((step - self._prev) * self._ratio)
self._prev += repeats / self._ratio
return repeats

def state_dict(self) -> Dict[str, Any]:
return {"_ratio": self._ratio, "_prev": self._prev}
return {"_ratio": self._ratio, "_prev": self._prev, "_pretrain_steps": self._pretrain_steps}

def load_state_dict(self, state_dict: Mapping[str, Any]):
self._ratio = state_dict["_ratio"]
self._prev = state_dict["_prev"]
self._pretrain_steps = state_dict["_pretrain_steps"]
return self

0 comments on commit 9be5304

Please sign in to comment.