Skip to content
This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

Commit

Permalink
Update type annotations and imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhijian Liu committed Sep 26, 2021
1 parent 7b06a1e commit 2873f68
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 10 deletions.
4 changes: 3 additions & 1 deletion torchpack/distributed/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ def allreduce(data: Any, reduction: str = 'sum') -> Any:
data = allgather(data)
if reduction == 'sum':
return sum(data)
else:
raise NotImplementedError(reduction)


def allgather(data: Any) -> List:
def allgather(data: Any) -> List[Any]:
world_size = context.size()
if world_size == 1:
return [data]
Expand Down
2 changes: 1 addition & 1 deletion torchpack/launch/launchers/drunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__all__ = ['main']


def is_exportable(v):
def is_exportable(v) -> bool:
IGNORE_REGEXES = ['BASH_FUNC_.*', 'OLDPWD']
return not any(re.match(r, v) for r in IGNORE_REGEXES)

Expand Down
15 changes: 9 additions & 6 deletions torchpack/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from torchpack.callbacks import (Callback, Callbacks, ConsoleWriter,
EstimatedTimeLeft, JSONLWriter, MetaInfoSaver,
ProgressBar, TFEventWriter)
from torchpack.train.exception import StopTraining
from torchpack.train.summary import Summary
from torchpack.utils import humanize
from torchpack.utils.logging import logger

from .exception import StopTraining
from .summary import Summary

__all__ = ['Trainer']


Expand All @@ -35,10 +36,12 @@ def train_with_defaults(
ProgressBar(),
EstimatedTimeLeft()
]
self.train(dataflow=dataflow,
num_epochs=num_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks)
self.train(
dataflow=dataflow,
num_epochs=num_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks,
)

def train(
self,
Expand Down
2 changes: 1 addition & 1 deletion torchpack/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from multimethod import multimethod

from . import io
from torchpack.utils import io

__all__ = ['Config', 'configs']

Expand Down
2 changes: 1 addition & 1 deletion torchpack/utils/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class NameMatcher:

def __init__(self, patterns: Optional[Union[str, List[str]]]):
def __init__(self, patterns: Optional[Union[str, List[str]]]) -> None:
if patterns is None:
patterns = []
elif isinstance(patterns, str):
Expand Down

0 comments on commit 2873f68

Please sign in to comment.