Skip to content

Commit

Permalink
yapf -> black (#406)
Browse files Browse the repository at this point in the history
* yapf >> black
* fmt randn

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Aug 2, 2021
1 parent 4fd18b3 commit d8b89e0
Show file tree
Hide file tree
Showing 147 changed files with 1,329 additions and 1,344 deletions.
2 changes: 1 addition & 1 deletion .github/prune-packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def main(req_file: str, *pkgs):
lines = [ln for ln in lines if not ln.startswith(pkg)]
pprint(lines)

with open(req_file, 'w') as fp:
with open(req_file, "w") as fp:
fp.writelines(lines)


Expand Down
20 changes: 10 additions & 10 deletions .github/set-minimal-versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@
import sys

LUT_PYTHON_TORCH = {
'3.8': '1.4',
'3.9': '1.7.1',
"3.8": "1.4",
"3.9": "1.7.1",
}
REQUIREMENTS_FILES = ('requirements.txt', ) + tuple(glob.glob(os.path.join('requirements', '*.txt')))
REQUIREMENTS_FILES = ("requirements.txt",) + tuple(glob.glob(os.path.join("requirements", "*.txt")))


def set_min_torch_by_python(fpath: str = 'requirements.txt') -> None:
def set_min_torch_by_python(fpath: str = "requirements.txt") -> None:
"""set minimal torch version"""
py_ver = f'{sys.version_info.major}.{sys.version_info.minor}'
py_ver = f"{sys.version_info.major}.{sys.version_info.minor}"
if py_ver not in LUT_PYTHON_TORCH:
return
with open(fpath) as fp:
req = fp.read()
req = re.sub(r'torch>=[\d\.]+', f'torch>={LUT_PYTHON_TORCH[py_ver]}', req)
with open(fpath, 'w') as fp:
req = re.sub(r"torch>=[\d\.]+", f"torch>={LUT_PYTHON_TORCH[py_ver]}", req)
with open(fpath, "w") as fp:
fp.write(req)


Expand All @@ -28,12 +28,12 @@ def replace_min_requirements(fpath: str) -> None:
logging.info(f"processing: {fpath}")
with open(fpath) as fp:
req = fp.read()
req = req.replace('>=', '==')
with open(fpath, 'w') as fp:
req = req.replace(">=", "==")
with open(fpath, "w") as fp:
fp.write(req)


if __name__ == '__main__':
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
set_min_torch_by_python()
for fpath in REQUIREMENTS_FILES:
Expand Down
30 changes: 0 additions & 30 deletions .github/workflows/code-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,33 +38,3 @@ jobs:
- name: mypy
run: |
mypy --show-error-codes --warn-unused-configs
# imports-check-isort:
# runs-on: ubuntu-20.04
# steps:
# - uses: actions/checkout@master
# - uses: actions/setup-python@v2
# with:
# python-version: 3.8
# - name: Install isort
# run: |
# pip install "isort==5.6.4"
# pip list
# - name: isort
# run: |
# isort --settings-path=./pyproject.toml . --check --diff

# format-check-yapf:
# runs-on: ubuntu-20.04
# steps:
# - uses: actions/checkout@master
# - uses: actions/setup-python@v2
# with:
# python-version: 3.8
# - name: Install dependencies
# run: |
# pip install "yapf==0.30"
# pip list
# shell: bash
# - name: yapf
# run: yapf --diff --parallel --recursive .
24 changes: 11 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,25 @@ repos:
args: ['--maxkb=250', '--enforce-all']
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v2.23.0
hooks:
- id: pyupgrade
args: [--py36-plus]
name: Upgrade code

- repo: https://github.com/PyCQA/isort
rev: 5.9.2
hooks:
- id: isort
name: imports
require_serial: false

- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.31.0
- repo: https://github.com/psf/black
rev: 21.7b0
hooks:
- id: yapf
name: formatting
language: python
require_serial: false

- repo: https://github.com/asottile/pyupgrade
rev: v2.23.0
hooks:
- id: pyupgrade
args: [--py36-plus]
name: Upgrade code
- id: black
name: Format code

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.7
Expand Down
41 changes: 20 additions & 21 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
sys.path.insert(0, os.path.abspath(_PATH_ROOT))

FOLDER_GENERATED = 'generated'
FOLDER_GENERATED = "generated"
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))

# alternative https://stackoverflow.com/a/67692/4521646
spec = spec_from_file_location("torchmetrics/__about__.py", os.path.join(_PATH_ROOT, "torchmetrics", "__about__.py"))
about = module_from_spec(spec)
spec.loader.exec_module(about)

html_favicon = '_static/images/icon.svg'
html_favicon = "_static/images/icon.svg"

# -- Project information -----------------------------------------------------

Expand All @@ -58,25 +58,25 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
with open(path_in) as fp:
chlog_lines = fp.readlines()
# enrich short subsub-titles to be unique
chlog_ver = ''
chlog_ver = ""
for i, ln in enumerate(chlog_lines):
if ln.startswith('## '):
chlog_ver = ln[2:].split('-')[0].strip()
elif ln.startswith('### '):
ln = ln.replace('###', f'### {chlog_ver} -')
if ln.startswith("## "):
chlog_ver = ln[2:].split("-")[0].strip()
elif ln.startswith("### "):
ln = ln.replace("###", f"### {chlog_ver} -")
chlog_lines[i] = ln
with open(path_out, 'w') as fp:
with open(path_out, "w") as fp:
fp.writelines(chlog_lines)


os.makedirs(os.path.join(_PATH_HERE, FOLDER_GENERATED), exist_ok=True)
# copy all documents from GH templates like contribution guide
for md in glob.glob(os.path.join(_PATH_ROOT, '.github', '*.md')):
for md in glob.glob(os.path.join(_PATH_ROOT, ".github", "*.md")):
shutil.copy(md, os.path.join(_PATH_HERE, FOLDER_GENERATED, os.path.basename(md)))
# copy also the changelog
_transform_changelog(
os.path.join(_PATH_ROOT, 'CHANGELOG.md'),
os.path.join(_PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'),
os.path.join(_PATH_ROOT, "CHANGELOG.md"),
os.path.join(_PATH_HERE, FOLDER_GENERATED, "CHANGELOG.md"),
)

# -- General configuration ---------------------------------------------------
Expand Down Expand Up @@ -161,14 +161,14 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
# documentation.

html_theme_options = {
'pytorch_project': 'https://pytorchlightning.ai',
'canonical_url': about.__docs_url__,
"pytorch_project": "https://pytorchlightning.ai",
"canonical_url": about.__docs_url__,
"collapse_navigation": False,
"display_version": True,
"logo_only": False,
}

html_logo = '_static/images/logo.svg'
html_logo = "_static/images/logo.svg"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
Expand Down Expand Up @@ -292,7 +292,7 @@ def package_list_from_file(file):
with open(file) as fp:
for ln in fp.readlines():
found = [ln.index(ch) for ch in list(",=<>#") if ch in ln]
pkg = ln[:min(found)] if found else ln
pkg = ln[: min(found)] if found else ln
if pkg.rstrip():
mocked_packages.append(pkg.rstrip())
return mocked_packages
Expand All @@ -315,7 +315,6 @@ def package_list_from_file(file):
# Resolve function
# This function is used to populate the (source) links in the API
def linkcode_resolve(domain, info):

def find_source():
# try to find the file and line number, based on code from numpy:
# https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
Expand Down Expand Up @@ -353,15 +352,15 @@ def find_source():

autosummary_generate = True

autodoc_member_order = 'groupwise'
autodoc_member_order = "groupwise"

autoclass_content = 'both'
autoclass_content = "both"

autodoc_default_options = {
'members': True,
"members": True,
# 'methods': True,
'special-members': '__call__',
'exclude-members': '_abc_impl',
"special-members": "__call__",
"exclude-members": "_abc_impl",
# 'show-inheritance': True,
}

Expand Down
2 changes: 1 addition & 1 deletion integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

_INTEGRATION_ROOT = os.path.realpath(os.path.dirname(__file__))
_PACKAGE_ROOT = os.path.dirname(_INTEGRATION_ROOT)
_PATH_DATASETS = os.path.join(_PACKAGE_ROOT, 'datasets')
_PATH_DATASETS = os.path.join(_PACKAGE_ROOT, "datasets")

_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
5 changes: 1 addition & 4 deletions integrations/lightning/boring_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


class RandomDictStringDataset(Dataset):

def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
Expand All @@ -31,7 +30,6 @@ def __len__(self):


class RandomDataset(Dataset):

def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
Expand All @@ -44,7 +42,6 @@ def __len__(self):


class BoringModel(LightningModule):

def __init__(self):
"""
Testing PL Module
Expand Down Expand Up @@ -99,7 +96,7 @@ def validation_step(self, batch, batch_idx):

@staticmethod
def validation_epoch_end(outputs) -> None:
torch.stack([x['x'] for x in outputs]).mean()
torch.stack([x["x"] for x in outputs]).mean()

def test_step(self, batch, batch_idx):
output = self.layer(batch)
Expand Down
31 changes: 13 additions & 18 deletions integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


class SumMetric(Metric):

def __init__(self):
super().__init__()
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
Expand All @@ -38,7 +37,6 @@ def compute(self):


class DiffMetric(Metric):

def __init__(self):
super().__init__()
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
Expand All @@ -51,9 +49,7 @@ def compute(self):


def test_metric_lightning(tmpdir):

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.metric = SumMetric()
Expand All @@ -68,7 +64,7 @@ def training_step(self, batch, batch_idx):

def training_epoch_end(self, outs):
if not torch.allclose(self.sum, self.metric.compute()):
raise ValueError('Sum and computed value must be equal')
raise ValueError("Sum and computed value must be equal")
self.sum = 0.0
self.metric.reset()

Expand All @@ -86,20 +82,19 @@ def training_epoch_end(self, outs):
trainer.fit(model)


@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason='test requires lightning v1.3 or higher')
@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason="test requires lightning v1.3 or higher")
def test_metrics_reset(tmpdir):
"""Tests that metrics are reset correctly after the end of the train/val/test epoch.
Taken from:
https://github.com/PyTorchLightning/pytorch-lightning/pull/7055
"""

class TestModel(LightningModule):

def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 1)

for stage in ['train', 'val', 'test']:
for stage in ["train", "val", "test"]:
acc = Accuracy()
acc.reset = mock.Mock(side_effect=acc.reset)
ap = AveragePrecision(num_classes=1, pos_label=1)
Expand Down Expand Up @@ -134,13 +129,13 @@ def _step(self, stage, batch):
return loss

def training_step(self, batch, batch_idx, *args, **kwargs):
return self._step('train', batch)
return self._step("train", batch)

def validation_step(self, batch, batch_idx, *args, **kwargs):
return self._step('val', batch)
return self._step("val", batch)

def test_step(self, batch, batch_idx, *args, **kwargs):
return self._step('test', batch)
return self._step("test", batch)

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
Expand All @@ -167,13 +162,13 @@ def _assert_epoch_end(self, stage):
ap.reset.assert_not_called()

def train_epoch_end(self, outputs):
self._assert_epoch_end('train')
self._assert_epoch_end("train")

def validation_epoch_end(self, outputs):
self._assert_epoch_end('val')
self._assert_epoch_end("val")

def test_epoch_end(self, outputs):
self._assert_epoch_end('test')
self._assert_epoch_end("test")

def _assert_called(model, stage):
acc = model._modules[f"acc_{stage}"]
Expand All @@ -196,14 +191,14 @@ def _assert_called(model, stage):
)

trainer.fit(model)
_assert_called(model, 'train')
_assert_called(model, 'val')
_assert_called(model, "train")
_assert_called(model, "val")

trainer.validate(model)
_assert_called(model, 'val')
_assert_called(model, "val")

trainer.test(model)
_assert_called(model, 'test')
_assert_called(model, "test")


# todo: reconsider if it make sense to keep here
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ requires = [
[tool.black]
# https://github.com/psf/black
line-length = 120
target-version = ["py38"]
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)"
exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"

[tool.isort]
known_first_party = [
Expand Down
Loading

0 comments on commit d8b89e0

Please sign in to comment.