diff --git a/torchpack/distributed/comm.py b/torchpack/distributed/comm.py index e5cdc3d..43c523f 100644 --- a/torchpack/distributed/comm.py +++ b/torchpack/distributed/comm.py @@ -13,9 +13,11 @@ def allreduce(data: Any, reduction: str = 'sum') -> Any: data = allgather(data) if reduction == 'sum': return sum(data) + else: + raise NotImplementedError(reduction) -def allgather(data: Any) -> List: +def allgather(data: Any) -> List[Any]: world_size = context.size() if world_size == 1: return [data] diff --git a/torchpack/launch/launchers/drunner.py b/torchpack/launch/launchers/drunner.py index 8d6225c..bb80430 100644 --- a/torchpack/launch/launchers/drunner.py +++ b/torchpack/launch/launchers/drunner.py @@ -9,7 +9,7 @@ __all__ = ['main'] -def is_exportable(v): +def is_exportable(v) -> bool: IGNORE_REGEXES = ['BASH_FUNC_.*', 'OLDPWD'] return not any(re.match(r, v) for r in IGNORE_REGEXES) diff --git a/torchpack/train/trainer.py b/torchpack/train/trainer.py index d5a77c6..5fe2360 100644 --- a/torchpack/train/trainer.py +++ b/torchpack/train/trainer.py @@ -6,11 +6,12 @@ from torchpack.callbacks import (Callback, Callbacks, ConsoleWriter, EstimatedTimeLeft, JSONLWriter, MetaInfoSaver, ProgressBar, TFEventWriter) -from torchpack.train.exception import StopTraining -from torchpack.train.summary import Summary from torchpack.utils import humanize from torchpack.utils.logging import logger +from .exception import StopTraining +from .summary import Summary + __all__ = ['Trainer'] @@ -35,10 +36,12 @@ def train_with_defaults( ProgressBar(), EstimatedTimeLeft() ] - self.train(dataflow=dataflow, - num_epochs=num_epochs, - steps_per_epoch=steps_per_epoch, - callbacks=callbacks) + self.train( + dataflow=dataflow, + num_epochs=num_epochs, + steps_per_epoch=steps_per_epoch, + callbacks=callbacks, + ) def train( self, diff --git a/torchpack/utils/config.py b/torchpack/utils/config.py index 799ea0d..f06d410 100644 --- a/torchpack/utils/config.py +++ b/torchpack/utils/config.py @@ -6,7 +6,7 @@ from multimethod import multimethod -from . import io +from torchpack.utils import io __all__ = ['Config', 'configs'] diff --git a/torchpack/utils/matching.py b/torchpack/utils/matching.py index f06fc41..4474611 100644 --- a/torchpack/utils/matching.py +++ b/torchpack/utils/matching.py @@ -6,7 +6,7 @@ class NameMatcher: - def __init__(self, patterns: Optional[Union[str, List[str]]]): + def __init__(self, patterns: Optional[Union[str, List[str]]]) -> None: if patterns is None: patterns = [] elif isinstance(patterns, str):