Skip to content

Commit

Permalink
Moved utils into separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Dec 5, 2024
1 parent b711ee6 commit d48461a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
18 changes: 1 addition & 17 deletions bs_scheduler/batch_size_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,7 @@
'CosineAnnealingBSWithWarmRestarts', 'OneCycleBS', 'BSScheduler']

from .batch_size_manager import BatchSizeManager, DefaultBatchSizeManager, CustomBatchSizeManager
from .utils import check_isinstance


def rint(x: float) -> int:
""" Rounds to the nearest int and returns the value as int.
"""
return int(round(x))


def clip(x: int, min_x: int, max_x: int) -> int:
""" Clips x to [min, max] interval.
"""
if x < min_x:
return min_x
if x > max_x:
return max_x
return x
from .utils import check_isinstance, clip, rint


class BSScheduler:
Expand Down
16 changes: 16 additions & 0 deletions bs_scheduler/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
def check_isinstance(x, instance: type):
if not isinstance(x, instance):
raise TypeError(f"{type(x).__name__} is not a {instance.__name__}.")


def rint(x: float) -> int:
""" Rounds to the nearest int and returns the value as int.
"""
return int(round(x))


def clip(x: int, min_x: int, max_x: int) -> int:
""" Clips x to [min, max] interval.
"""
if x < min_x:
return min_x
if x > max_x:
return max_x
return x

0 comments on commit d48461a

Please sign in to comment.