add python types via pyre infer to miscellaneous files #11956

Status: Closed · 2 commits
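The annotations in this PR come from Pyre's inference mode, presumably an invocation along the lines of `pyre infer -i <paths>`, which writes the inferred types back into the source files; the exact command is not shown in the PR.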
17 changes: 10 additions & 7 deletions docs/source/conf.py
@@ -17,7 +17,10 @@
import shutil
import sys
import warnings
+from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec, spec_from_file_location
+from types import ModuleType
+from typing import Dict, List, Optional, Tuple

import pt_lightning_sphinx_theme

@@ -38,10 +41,10 @@
FOLDER_GENERATED = "generated"
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))

-spec = spec_from_file_location(
+spec: Optional[ModuleSpec] = spec_from_file_location(
"pytorch_lightning/__about__.py", os.path.join(PATH_ROOT, "pytorch_lightning", "__about__.py")
)
-about = module_from_spec(spec)
+about: ModuleType = module_from_spec(spec)
spec.loader.exec_module(about)

# -- Project documents -------------------------------------------------------
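Note on the `Optional[ModuleSpec]` annotation above: `importlib.util.spec_from_file_location` is typed as returning `Optional[ModuleSpec]` because it returns `None` when no loader can be found, which is why Pyre infers the `Optional`. The diff still calls `spec.loader.exec_module(about)` unguarded, so a strict checker would flag both the possibly-`None` spec and its `Optional` loader; a minimal narrowed sketch (not part of this PR, hypothetical paths):

```python
from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec, spec_from_file_location
from types import ModuleType
from typing import Optional

spec: Optional[ModuleSpec] = spec_from_file_location("pkg.__about__", "pkg/__about__.py")
assert spec is not None and spec.loader is not None  # narrow both Optionals for the checker
about: ModuleType = module_from_spec(spec)
spec.loader.exec_module(about)
```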
@@ -205,7 +208,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
-htmlhelp_basename = project + "-doc"
+htmlhelp_basename: str = project + "-doc"

# -- Options for LaTeX output ------------------------------------------------

@@ -251,7 +254,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
-epub_title = project
+epub_title: str = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
@@ -269,7 +272,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:

# -- Options for intersphinx extension ---------------------------------------

-intersphinx_mapping = {
+intersphinx_mapping: Dict[str, Tuple[str, None]] = {
"python": ("https://docs.python.org/3", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
@@ -285,7 +288,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
todo_include_todos = True


-def setup(app):
+def setup(app) -> None:
# this is for hiding doctest decoration,
# see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/
app.add_js_file("copybutton.js")
@@ -303,7 +306,7 @@ def setup(app):

# Ignoring Third-party packages
# https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
-def package_list_from_file(file):
+def package_list_from_file(file) -> List[str]:
"""List up package name (not containing version and extras) from a package list file."""
mocked_packages = []
with open(file) as fp:
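The body of `package_list_from_file` is truncated by the diff view above; a minimal sketch of what a parser with this signature typically looks like (an assumed implementation, not the PR's exact code):

```python
import re
from typing import List


def package_list_from_file(file: str) -> List[str]:
    """Collect bare package names (no version specifiers or extras) from a requirements file."""
    mocked_packages: List[str] = []
    with open(file) as fp:
        for line in fp:
            line = line.strip()
            if not line or line.startswith("#"):
                continue  # skip blanks and comments
            # e.g. "torch>=1.8" -> "torch", "package[extra]" -> "package"
            name = re.split(r"[\[<>=!~; ]", line, maxsplit=1)[0]
            if name:
                mocked_packages.append(name)
    return mocked_packages
```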
21 changes: 11 additions & 10 deletions legacy/simple_classif_training.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from typing import Any, Dict

import torch
import torch.nn.functional as F
@@ -29,7 +30,7 @@


class SklearnDataset(Dataset):
-def __init__(self, x, y, x_type, y_type):
+def __init__(self, x, y, x_type, y_type) -> None:
self.x = x
self.y = y
self._x_type = x_type
@@ -38,20 +39,20 @@ def __init__(self, x, y, x_type, y_type):
def __getitem__(self, idx):
return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type)

-def __len__(self):
+def __len__(self) -> int:
return len(self.y)


class SklearnDataModule(LightningDataModule):
-def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128):
+def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128) -> None:
super().__init__()
self.batch_size = batch_size
self._x, self._y = sklearn_dataset
self._split_data()
self._x_type = x_type
self._y_type = y_type

-def _split_data(self):
+def _split_data(self) -> None:
self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
self._x, self._y, test_size=0.20, random_state=42
)
@@ -86,7 +87,7 @@ def predict_dataloader(self):


class ClassifDataModule(SklearnDataModule):
-def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
+def __init__(self, num_features: int = 24, length: int = 6000, num_classes: int = 3, batch_size: int = 128) -> None:
data = make_classification(
n_samples=length,
n_features=num_features,
@@ -99,7 +100,7 @@ def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):


class ClassificationModel(LightningModule):
-def __init__(self, num_features=24, num_classes=3, lr=0.01):
+def __init__(self, num_features: int = 24, num_classes: int = 3, lr: float = 0.01) -> None:
super().__init__()
self.save_hyperparameters()

@@ -128,28 +129,28 @@ def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return [optimizer], []

-def training_step(self, batch, batch_idx):
+def training_step(self, batch, batch_idx) -> Dict[str, Any]:
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
self.log("train_loss", loss, prog_bar=True)
self.log("train_acc", self.train_acc(logits, y), prog_bar=True)
return {"loss": loss}

-def validation_step(self, batch, batch_idx):
+def validation_step(self, batch, batch_idx) -> None:
x, y = batch
logits = self.forward(x)
self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False)
self.log("val_acc", self.valid_acc(logits, y), prog_bar=True)

-def test_step(self, batch, batch_idx):
+def test_step(self, batch, batch_idx) -> None:
x, y = batch
logits = self.forward(x)
self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False)
self.log("test_acc", self.test_acc(logits, y), prog_bar=True)


-def main_train(dir_path, max_epochs: int = 20):
+def main_train(dir_path, max_epochs: int = 20) -> None:
seed_everything(42)
stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
trainer = pl.Trainer(
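For orientation, `main_train` (truncated above) wires the typed pieces together; a hypothetical minimal driver using the same names from this file:

```python
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

seed_everything(42)
dm = ClassifDataModule(num_features=24, length=6000, num_classes=3, batch_size=128)
model = ClassificationModel(num_features=24, num_classes=3, lr=0.01)
# main_train also attaches EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
```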
13 changes: 8 additions & 5 deletions requirements/adjust_versions.py
@@ -1,10 +1,13 @@
import os
import re
import sys
-from typing import Dict, Optional
+from typing import Dict, List, Optional

+requirements_path: str
+torch_version: Optional[str]

# IMPORTANT: this list needs to be sorted in reverse
-VERSIONS = [
+VERSIONS: List[Dict[str, str]] = [
dict(torch="1.12.0", torchvision="0.12.*", torchtext=""), # nightly
dict(torch="1.11.0", torchvision="0.12.0", torchtext="0.12.0"), # pre-release
dict(torch="1.10.2", torchvision="0.11.3", torchtext="0.11.2"), # stable
@@ -53,7 +56,7 @@ def main(req: str, torch_version: Optional[str] = None) -> str:
return req


-def test():
+def test() -> None:
requirements = """
torch>=1.2.*
torch==1.2.3
@@ -87,8 +90,8 @@ def test():
requirements_path, torch_version = sys.argv[1], None

with open(requirements_path) as fp:
-requirements = fp.read()
-requirements = main(requirements, torch_version)
+requirements: str = fp.read()
+requirements: str = main(requirements, torch_version)
print(requirements) # on purpose - to debug
with open(requirements_path, "w") as fp:
fp.write(requirements)
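The `VERSIONS` table at the top of this file pairs each torch release with compatible torchvision/torchtext pins, and keeping it sorted in reverse means a linear scan stops at the newest compatible row. A hypothetical lookup over the three visible entries (`find_pins` is illustrative, not part of the script):

```python
from typing import Dict, List, Optional

VERSIONS: List[Dict[str, str]] = [
    dict(torch="1.12.0", torchvision="0.12.*", torchtext=""),        # nightly
    dict(torch="1.11.0", torchvision="0.12.0", torchtext="0.12.0"),  # pre-release
    dict(torch="1.10.2", torchvision="0.11.3", torchtext="0.11.2"),  # stable
]


def find_pins(torch_version: str) -> Optional[Dict[str, str]]:
    # Reverse-sorted table: the first prefix match is the newest compatible row.
    for row in VERSIONS:
        if row["torch"].startswith(torch_version):
            return row
    return None


assert find_pins("1.10.2") == VERSIONS[2]
assert find_pins("1.11") == VERSIONS[1]
```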
8 changes: 5 additions & 3 deletions requirements/collect_env_details.py
@@ -19,19 +19,21 @@
import os
import platform
import sys
+from typing import Dict, Tuple, Union

import numpy
import torch
import tqdm

sys.path += [os.path.abspath(".."), os.path.abspath(".")]

import pytorch_lightning # noqa: E402

LEVEL_OFFSET = "\t"
KEY_PADDING = 20


-def info_system():
+def info_system() -> Dict[str, Union[str, Tuple[str, str]]]:
return {
"OS": platform.system(),
"architecture": platform.architecture(),
@@ -60,7 +62,7 @@ def info_packages():
}


-def nice_print(details, level=0):
+def nice_print(details, level: int = 0):
lines = []
for k in sorted(details):
key = f"* {k}:" if level == 0 else f"- {k}:"
@@ -77,7 +79,7 @@ def nice_print(details, level=0):
return lines


-def main():
+def main() -> None:
details = {"System": info_system(), "CUDA": info_cuda(), "Packages": info_packages()}
lines = nice_print(details)
text = os.linesep.join(lines)
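The inferred return type `Dict[str, Union[str, Tuple[str, str]]]` on `info_system` arises because `platform.architecture()` returns a two-element tuple of strings while the other values are plain strings. A quick check of that assumption (output varies by machine):

```python
import platform
from typing import Tuple

arch: Tuple[str, str] = platform.architecture()  # e.g. ('64bit', 'ELF') on Linux
assert all(isinstance(part, str) for part in arch)
print(platform.system(), arch)
```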
6 changes: 4 additions & 2 deletions setup.py
@@ -15,6 +15,8 @@

import os
from importlib.util import module_from_spec, spec_from_file_location
+from types import ModuleType
+from typing import Any, Dict

from setuptools import find_packages, setup

@@ -24,7 +26,7 @@
_PATH_REQUIRE = os.path.join(_PATH_ROOT, "requirements")


-def _load_py_module(fname, pkg="pytorch_lightning"):
+def _load_py_module(fname, pkg: str = "pytorch_lightning") -> ModuleType:
spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname))
py = module_from_spec(spec)
spec.loader.exec_module(py)
@@ -38,7 +40,7 @@ def _load_py_module(fname, pkg="pytorch_lightning"):
# Define package extras. These are only installed if you specify them.
# From remote, use like `pip install pytorch-lightning[dev, docs]`
# From local copy of repo, use like `pip install ".[dev, docs]"`
-extras = {
+extras: Dict[str, Any] = {
# 'docs': load_requirements(file_name='docs.txt'),
"examples": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="examples.txt"),
"loggers": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="loggers.txt"),
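For context, `_load_py_module` is the same import-by-path idiom used in `docs/source/conf.py` above, and the `setup_tools._load_requirements` calls in `extras` suggest that module is loaded with it; a hypothetical usage sketch (module names assumed from the surrounding diff, not shown in it verbatim):

```python
about = _load_py_module("__about__.py")          # -> ModuleType
setup_tools = _load_py_module("setup_tools.py")  # provides _load_requirements(...)
print(about.__version__)  # assumes __about__.py defines __version__
```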