diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8be6efe334..6239a2115e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,7 +32,8 @@ repos:
rev: v1.7.3
hooks:
- id: docformatter
- args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
+ additional_dependencies: [tomli]
+ args: ["--in-place"]
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
@@ -44,11 +45,6 @@ repos:
- mdformat_frontmatter
exclude: CHANGELOG.md
- #- repo: https://github.com/PyCQA/isort
- # rev: 5.12.0
- # hooks:
- # - id: isort
-
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
diff --git a/pyproject.toml b/pyproject.toml
index adca3a93b1..e4e0cabbdc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,11 +50,11 @@ relative_files = true
line-length = 120
exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"
-[tool.isort]
-known_first_party = ["pl_bolts", "tests", "notebooks"]
-skip_glob = []
-profile = "black"
-line_length = 120
+[tool.docformatter]
+recursive = true
+wrap-summaries = 120
+wrap-descriptions = 120
+blank = true
[tool.ruff]
diff --git a/setup.py b/setup.py
index 7e88ff3063..e4b6b74349 100755
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,7 @@ def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow", unfreeze=True)
'arrow'
+
"""
# filer all comments
if comment_char in ln:
@@ -61,6 +62,7 @@ def _load_requirements(path_dir: str, file_name: str, unfreeze: bool = not _FREE
>>> path_req = os.path.join(_PATH_ROOT, "requirements")
>>> _load_requirements(path_req, "docs.txt") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx>=4.0', ...]
+
"""
with open(os.path.join(path_dir, file_name)) as file:
lines = [ln.strip() for ln in file.readlines()]
@@ -77,6 +79,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
>>> _load_readme_description(_PATH_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'
...'
+
"""
path_readme = os.path.join(path_dir, "README.md")
with open(path_readme, encoding="utf-8") as fo:
diff --git a/src/pl_bolts/callbacks/byol_updates.py b/src/pl_bolts/callbacks/byol_updates.py
index b7fdea1c1c..c3e2bd2378 100644
--- a/src/pl_bolts/callbacks/byol_updates.py
+++ b/src/pl_bolts/callbacks/byol_updates.py
@@ -30,6 +30,7 @@ class BYOLMAWeightUpdate(Callback):
model.target_network = ...
trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
+
"""
def __init__(self, initial_tau: float = 0.996) -> None:
diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py
index 62ec3b7f7c..7a39a1e709 100644
--- a/src/pl_bolts/callbacks/data_monitor.py
+++ b/src/pl_bolts/callbacks/data_monitor.py
@@ -35,12 +35,13 @@ class DataMonitorBase(Callback):
def __init__(self, log_every_n_steps: int = None) -> None:
"""Base class for monitoring data histograms in a LightningModule. This requires a logger configured in the
- Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data
- gets collected.
+ Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data gets
+ collected.
Args:
log_every_n_steps: The interval at which histograms should be logged. This defaults to the
interval defined in the Trainer. Use this to override the Trainer default.
+
"""
super().__init__()
self._log_every_n_steps: Optional[int] = log_every_n_steps
@@ -84,12 +85,13 @@ def log_histograms(self, batch: Any, group: str = "") -> None:
self.log_histogram(tensor, name)
def log_histogram(self, tensor: Tensor, name: str) -> None:
- """Override this method to customize the logging of histograms. Detaches the tensor from the graph and
- moves it to the CPU for logging.
+ """Override this method to customize the logging of histograms. Detaches the tensor from the graph and moves it
+ to the CPU for logging.
Args:
tensor: The tensor for which to log a histogram
name: The name of the tensor as determined by the callback. Example: ``ìnput/0/[64, 1, 28, 28]``
+
"""
logger = self._trainer.logger
tensor = tensor.detach().cpu()
@@ -234,9 +236,9 @@ def on_train_batch_start(
def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: str = "input") -> None:
- """Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data
- in dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer.
- The shape of the tensor gets appended to the name as well.
+ """Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data in
+ dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer. The
+ shape of the tensor gets appended to the name as well.
Args:
data: A collection of data (potentially nested).
@@ -249,6 +251,7 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name:
>>> collect_and_name_tensors(data, output)
>>> output # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
{'input/x/[2, 3]': ..., 'input/y/z/[5]': ...}
+
"""
assert isinstance(output, dict)
if isinstance(data, Tensor):
@@ -273,5 +276,6 @@ def shape2str(tensor: Tensor) -> str:
'[1, 2, 3]'
>>> shape2str(torch.rand(4))
'[4]'
+
"""
return "[" + ", ".join(map(str, tensor.shape)) + "]"
diff --git a/src/pl_bolts/callbacks/printing.py b/src/pl_bolts/callbacks/printing.py
index ada4cee3cc..1ba298aa12 100644
--- a/src/pl_bolts/callbacks/printing.py
+++ b/src/pl_bolts/callbacks/printing.py
@@ -33,6 +33,7 @@ class PrintTableMetricsCallback(Callback):
# loss│train_loss│val_loss│epoch
# ──────────────────────────────
# 2.2541470527648926│2.2541470527648926│2.2158432006835938│0
+
"""
def __init__(self) -> None:
diff --git a/src/pl_bolts/callbacks/sparseml.py b/src/pl_bolts/callbacks/sparseml.py
index 182c2791f0..3a5d48b471 100644
--- a/src/pl_bolts/callbacks/sparseml.py
+++ b/src/pl_bolts/callbacks/sparseml.py
@@ -33,6 +33,7 @@ class SparseMLCallback(Callback):
Args:
recipe_path: Path to a SparseML compatible yaml recipe.
More information at https://docs.neuralmagic.com/sparseml/source/recipes.html
+
"""
def __init__(self, recipe_path: str) -> None:
diff --git a/src/pl_bolts/callbacks/ssl_online.py b/src/pl_bolts/callbacks/ssl_online.py
index 5f5d01b765..943a2a28f7 100644
--- a/src/pl_bolts/callbacks/ssl_online.py
+++ b/src/pl_bolts/callbacks/ssl_online.py
@@ -31,6 +31,7 @@ class SSLOnlineEvaluator(Callback): # pragma: no cover
online_eval = SSLOnlineEvaluator(
z_dim=model.z_dim
)
+
"""
def __init__(
@@ -182,6 +183,7 @@ def set_training(module: nn.Module, mode: bool):
Args:
module: module to set training mode
mode: whether to set training mode (True) or evaluation mode (False).
+
"""
original_mode = module.training
diff --git a/src/pl_bolts/callbacks/variational.py b/src/pl_bolts/callbacks/variational.py
index 5b709f0742..66f74871a2 100644
--- a/src/pl_bolts/callbacks/variational.py
+++ b/src/pl_bolts/callbacks/variational.py
@@ -18,8 +18,8 @@
@under_review()
class LatentDimInterpolator(Callback):
- """Interpolates the latent space for a model by setting all dims to zero and stepping through the first two
- dims increasing one unit at a time.
+ """Interpolates the latent space for a model by setting all dims to zero and stepping through the first two dims
+ increasing one unit at a time.
Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5)
@@ -28,6 +28,7 @@ class LatentDimInterpolator(Callback):
from pl_bolts.callbacks import LatentDimInterpolator
Trainer(callbacks=[LatentDimInterpolator()])
+
"""
def __init__(
diff --git a/src/pl_bolts/callbacks/verification/base.py b/src/pl_bolts/callbacks/verification/base.py
index 9f81922067..49e2e3593a 100644
--- a/src/pl_bolts/callbacks/verification/base.py
+++ b/src/pl_bolts/callbacks/verification/base.py
@@ -17,6 +17,7 @@ class VerificationBase:
All verifications should run with any
:class: `torch.nn.Module` unless otherwise stated.
+
"""
def __init__(self, model: nn.Module) -> None:
@@ -39,14 +40,16 @@ def check(self, *args: Any, **kwargs: Any) -> bool:
`True` if the test passes, and `False` otherwise. Some verifications can only be performed
with a heuristic accuracy, thus the return value may not always reflect the true state of
the system in these cases.
+
"""
def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
- """Returns a deep copy of the example input array in cases where it is expected that the input changes
- during the verification process.
+ """Returns a deep copy of the example input array in cases where it is expected that the input changes during
+ the verification process.
Arguments:
input_array: The input to clone.
+
"""
if input_array is None and isinstance(self.model, LightningModule):
input_array = self.model.example_input_array
@@ -89,6 +92,7 @@ class VerificationCallbackBase(Callback):
This type of verification is expected to only work with
:class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array
from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed.
+
"""
def __init__(self, warn: bool = True, error: bool = False) -> None:
diff --git a/src/pl_bolts/callbacks/verification/batch_gradient.py b/src/pl_bolts/callbacks/verification/batch_gradient.py
index a7fef82548..834184c152 100644
--- a/src/pl_bolts/callbacks/verification/batch_gradient.py
+++ b/src/pl_bolts/callbacks/verification/batch_gradient.py
@@ -19,6 +19,7 @@ class BatchGradientVerification(VerificationBase):
This can happen if reshape- and/or permutation operations are carried out in the wrong order or on the wrong tensor
dimensions.
+
"""
NORM_LAYER_CLASSES = (
@@ -57,6 +58,7 @@ def check(
Returns:
``True`` if the data in the batch does not mix during the forward pass, and ``False`` otherwise.
+
"""
input_mapping = input_mapping or default_input_mapping
output_mapping = output_mapping or default_output_mapping
@@ -151,6 +153,7 @@ def default_input_mapping(data: Any) -> List[Tensor]:
torch.Size([3, 1])
>>> result[1].shape
torch.Size([3, 2])
+
"""
tensors = collect_tensors(data)
batches: List[Tensor] = []
@@ -181,6 +184,7 @@ def default_output_mapping(data: Any) -> Tensor:
>>> result = default_output_mapping(data)
>>> result.shape
torch.Size([3, 7])
+
"""
if isinstance(data, Tensor):
return data
diff --git a/src/pl_bolts/callbacks/vision/confused_logit.py b/src/pl_bolts/callbacks/vision/confused_logit.py
index 1098cb9227..e66a088a35 100644
--- a/src/pl_bolts/callbacks/vision/confused_logit.py
+++ b/src/pl_bolts/callbacks/vision/confused_logit.py
@@ -46,6 +46,7 @@ class ConfusedLogitCallback(Callback): # pragma: no cover
Authored by:
- Alfredo Canziani
+
"""
def __init__(
diff --git a/src/pl_bolts/datamodules/async_dataloader.py b/src/pl_bolts/datamodules/async_dataloader.py
index 0767ad4a1a..98f373f85c 100644
--- a/src/pl_bolts/datamodules/async_dataloader.py
+++ b/src/pl_bolts/datamodules/async_dataloader.py
@@ -28,6 +28,7 @@ class AsynchronousLoader:
if set and DataLoader has a __len__. Otherwise it can be left as None
**kwargs: Any additional arguments to pass to the dataloader if we're
constructing one here
+
"""
def __init__(
diff --git a/src/pl_bolts/datamodules/cifar10_datamodule.py b/src/pl_bolts/datamodules/cifar10_datamodule.py
index 0fa39f44c3..47779c1e4b 100644
--- a/src/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/src/pl_bolts/datamodules/cifar10_datamodule.py
@@ -152,6 +152,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
dm = CIFAR10DataModule(PATH)
model = LitModel(datamodule=dm)
+
"""
dataset_cls = TrialCIFAR10
diff --git a/src/pl_bolts/datamodules/emnist_datamodule.py b/src/pl_bolts/datamodules/emnist_datamodule.py
index 1c76cd2050..75d4af155b 100644
--- a/src/pl_bolts/datamodules/emnist_datamodule.py
+++ b/src/pl_bolts/datamodules/emnist_datamodule.py
@@ -176,6 +176,7 @@ def num_classes(self) -> int:
"""Returns the number of classes.
See the table above.
+
"""
return len(self.dataset_cls.classes_split_dict[self.split])
diff --git a/src/pl_bolts/datamodules/experience_source.py b/src/pl_bolts/datamodules/experience_source.py
index b0ceee1522..2a0d4467e4 100644
--- a/src/pl_bolts/datamodules/experience_source.py
+++ b/src/pl_bolts/datamodules/experience_source.py
@@ -30,6 +30,7 @@ class ExperienceSourceDataset(IterableDataset):
Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is
generated is defined the Lightning model itself
+
"""
def __init__(self, generate_batch: Callable) -> None:
@@ -95,6 +96,7 @@ def runner(self, device: torch.device) -> Tuple[Experience]:
Returns:
Tuple of Experiences
+
"""
while True:
# get actions for all envs
@@ -116,14 +118,15 @@ def runner(self, device: torch.device) -> Tuple[Experience]:
self.iter_idx += 1
def update_history_queue(self, env_idx, exp, history) -> None:
- """Updates the experience history queue with the lastest experiences. In the event of an experience step is
- in the done state, the history will be incrementally appended to the queue, removing the tail of the
- history each time.
+ """Updates the experience history queue with the lastest experiences. In the event of an experience step is in
+ the done state, the history will be incrementally appended to the queue, removing the tail of the history each
+ time.
Args:
env_idx: index of the environment
exp: the current experience
history: history of experience steps for this environment
+
"""
# If there is a full history of step, append history to queue
if len(history) == self.n_steps:
@@ -184,6 +187,7 @@ def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
Returns:
Experience tuple
+
"""
next_state, r, is_done, _ = env.step(action[0])
@@ -198,6 +202,7 @@ def update_env_stats(self, env_idx: int) -> None:
Args:
env_idx: index of the environment used to update stats
+
"""
self._total_rewards.append(self.cur_rewards[env_idx])
self.total_steps.append(self.cur_steps[env_idx])
@@ -248,6 +253,7 @@ def runner(self, device: torch.device) -> Experience:
Yields:
Discounted Experience
+
"""
for experiences in super().runner(device):
last_exp_state, tail_experiences = self.split_head_tail_exp(experiences)
@@ -263,14 +269,15 @@ def runner(self, device: torch.device) -> Experience:
)
def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]:
- """Takes in a tuple of experiences and returns the last state and tail experiences based on if the last
- state is the end of an episode.
+ """Takes in a tuple of experiences and returns the last state and tail experiences based on if the last state is
+ the end of an episode.
Args:
experiences: Tuple of N Experience
Returns:
last state (Array or None) and remaining Experience
+
"""
if experiences[-1].done and len(experiences) <= self.steps:
last_exp_state = experiences[-1].new_state
@@ -288,6 +295,7 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
Returns:
total discounted reward
+
"""
total_reward = 0.0
for exp in reversed(experiences):
diff --git a/src/pl_bolts/datamodules/imagenet_datamodule.py b/src/pl_bolts/datamodules/imagenet_datamodule.py
index 5d1d5a563f..90f2bb641d 100644
--- a/src/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/src/pl_bolts/datamodules/imagenet_datamodule.py
@@ -119,6 +119,7 @@ def prepare_data(self) -> None:
"""This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin.
.. warning:: Please download imagenet on your own first.
+
"""
self._verify_splits(self.data_dir, "train")
self._verify_splits(self.data_dir, "val")
@@ -223,6 +224,7 @@ def train_transform(self) -> Callable:
std=[0.229, 0.224, 0.225]
),
])
+
"""
return transform_lib.Compose(
[
@@ -247,6 +249,7 @@ def val_transform(self) -> Callable:
std=[0.229, 0.224, 0.225]
),
])
+
"""
return transform_lib.Compose(
diff --git a/src/pl_bolts/datamodules/kitti_datamodule.py b/src/pl_bolts/datamodules/kitti_datamodule.py
index 30fb18c22e..f30f99689b 100644
--- a/src/pl_bolts/datamodules/kitti_datamodule.py
+++ b/src/pl_bolts/datamodules/kitti_datamodule.py
@@ -69,6 +69,7 @@ def __init__(
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
+
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
diff --git a/src/pl_bolts/datamodules/sklearn_datamodule.py b/src/pl_bolts/datamodules/sklearn_datamodule.py
index fedd464dc8..1c93613690 100644
--- a/src/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/src/pl_bolts/datamodules/sklearn_datamodule.py
@@ -33,6 +33,7 @@ class SklearnDataset(Dataset):
>>> dataset = SklearnDataset(X, y)
>>> len(dataset)
442
+
"""
def __init__(
@@ -70,8 +71,8 @@ def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
@under_review()
class SklearnDataModule(LightningDataModule):
- """Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as
- dataloaders for convenience. Optionally, you can pass in your own validation and test splits.
+ """Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as dataloaders
+ for convenience. Optionally, you can pass in your own validation and test splits.
Example:
@@ -99,6 +100,7 @@ class SklearnDataModule(LightningDataModule):
44
>>> len(test_loader)
2
+
"""
name = "sklearn"
diff --git a/src/pl_bolts/datamodules/sr_datamodule.py b/src/pl_bolts/datamodules/sr_datamodule.py
index 3f6af04ea4..b6d30b3c80 100644
--- a/src/pl_bolts/datamodules/sr_datamodule.py
+++ b/src/pl_bolts/datamodules/sr_datamodule.py
@@ -18,6 +18,7 @@ class TVTDataModule(LightningDataModule):
dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000])
dataset_test = SRMNIST(scale_factor=4, root=".", train=True)
dm = TVTDataModule(dataset_train, dataset_val, dataset_test)
+
"""
def __init__(
diff --git a/src/pl_bolts/datamodules/stl10_datamodule.py b/src/pl_bolts/datamodules/stl10_datamodule.py
index 522649645b..158baee0bc 100644
--- a/src/pl_bolts/datamodules/stl10_datamodule.py
+++ b/src/pl_bolts/datamodules/stl10_datamodule.py
@@ -177,6 +177,7 @@ def val_dataloader(self) -> DataLoader:
batch_size: the batch size
transforms: a sequence of transforms
+
"""
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
@@ -210,6 +211,7 @@ def val_dataloader_mixed(self) -> DataLoader:
batch_size: the batch size
transforms: a sequence of transforms
+
"""
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
unlabeled_dataset = STL10(self.data_dir, split="unlabeled", download=False, transform=transforms)
@@ -244,6 +246,7 @@ def test_dataloader(self) -> DataLoader:
Args:
batch_size: the batch size
transforms: the transforms
+
"""
transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
diff --git a/src/pl_bolts/datamodules/vocdetection_datamodule.py b/src/pl_bolts/datamodules/vocdetection_datamodule.py
index df07fc19d9..de8a84b0f3 100644
--- a/src/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/src/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -67,6 +67,7 @@ def _prepare_voc_instance(image: Any, target: Dict[str, Any]):
"""Prepares VOC dataset into appropriate target for fasterrcnn.
https://github.com/pytorch/vision/issues/1097#issuecomment-508917489
+
"""
anno = target["annotation"]
boxes = []
diff --git a/src/pl_bolts/datasets/array_dataset.py b/src/pl_bolts/datasets/array_dataset.py
index 8251d3d9a5..ab17d3b7e2 100644
--- a/src/pl_bolts/datasets/array_dataset.py
+++ b/src/pl_bolts/datasets/array_dataset.py
@@ -28,6 +28,7 @@ class ArrayDataset(Dataset):
>>> ds = ArrayDataset(features, target)
>>> len(ds)
3
+
"""
def __init__(self, *data_models: DataModel) -> None:
@@ -48,5 +49,6 @@ def _equal_size(self) -> bool:
Returns:
bool: True if size of data_models are equal in the first dimension. False, if not.
+
"""
return len({len(data_model.data) for data_model in self.data_models}) == 1
diff --git a/src/pl_bolts/datasets/base_dataset.py b/src/pl_bolts/datasets/base_dataset.py
index 4fcc8e93df..57767fcbe1 100644
--- a/src/pl_bolts/datasets/base_dataset.py
+++ b/src/pl_bolts/datasets/base_dataset.py
@@ -71,6 +71,7 @@ class DataModel:
Attributes:
data: Sequence of indexables.
transform: Callable to transform data. The transform is called on a subset of data.
+
"""
data: TArrays
@@ -84,6 +85,7 @@ def process(self, subset: Union[TArrays, float]) -> Union[TArrays, float]:
Returns:
data: Transformed data if transform is not None.
+
"""
if self.transform is not None:
subset = self.transform(subset)
diff --git a/src/pl_bolts/datasets/dummy_dataset.py b/src/pl_bolts/datasets/dummy_dataset.py
index 4c65b2c19a..71a850b196 100644
--- a/src/pl_bolts/datasets/dummy_dataset.py
+++ b/src/pl_bolts/datasets/dummy_dataset.py
@@ -18,6 +18,7 @@ class DummyDataset(Dataset):
torch.Size([7, 1, 28, 28])
>>> y.size()
torch.Size([7, 1])
+
"""
def __init__(self, *shapes, num_samples: int = 10000) -> None:
@@ -62,6 +63,7 @@ class DummyDetectionDataset(Dataset):
torch.Size([7, 1, 4])
>>> y['labels'].size()
torch.Size([7, 1])
+
"""
def __init__(
@@ -110,6 +112,7 @@ class RandomDictDataset(Dataset):
(7, 10)
>>> len(batch['b']),len(batch['b'][0])
(7, 10)
+
"""
def __init__(self, size: int, num_samples: int = 250) -> None:
@@ -150,6 +153,7 @@ class RandomDictStringDataset(Dataset):
['0', '1', '2', '3', '4', '5', '6']
>>> len(batch['x'])
7
+
"""
def __init__(self, size: int, num_samples: int = 250) -> None:
@@ -182,6 +186,7 @@ class RandomDataset(Dataset):
>>> batch = next(iter(dl))
>>> len(batch),len(batch[0])
(7, 10)
+
"""
def __init__(self, size: int, num_samples: int = 250) -> None:
diff --git a/src/pl_bolts/datasets/emnist_dataset.py b/src/pl_bolts/datasets/emnist_dataset.py
index 7256d21ab8..a3403faaa7 100644
--- a/src/pl_bolts/datasets/emnist_dataset.py
+++ b/src/pl_bolts/datasets/emnist_dataset.py
@@ -41,6 +41,7 @@ class BinaryEMNIST(EMNIST):
Note:
Documentation is based on https://pytorch.org/vision/main/generated/torchvision.datasets.EMNIST.html
+
"""
def __init__(self, root: str, split: str, threshold: Union[int, float] = 127.0, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/datasets/imagenet_dataset.py b/src/pl_bolts/datasets/imagenet_dataset.py
index 5ab373cdef..b5d0c275d2 100644
--- a/src/pl_bolts/datasets/imagenet_dataset.py
+++ b/src/pl_bolts/datasets/imagenet_dataset.py
@@ -22,10 +22,11 @@
@under_review()
class UnlabeledImagenet(ImageNet):
- """Official train set gets split into train, val. (using num_imgs_per_val_class for each class). Official
- validation becomes test set.
+ """Official train set gets split into train, val. (using num_imgs_per_val_class for each class). Official validation
+ becomes test set.
Within each class, we further allow limiting the number of samples per class (for semi-sup lng)
+
"""
def __init__(
@@ -188,13 +189,14 @@ def _calculate_md5(fpath, chunk_size=1024 * 1024):
@under_review()
def parse_devkit_archive(root, file=None):
- """Parse the devkit archive of the ImageNet2012 classification dataset and save the meta information in a
- binary file.
+ """Parse the devkit archive of the ImageNet2012 classification dataset and save the meta information in a binary
+ file.
Args:
root (str): Root directory containing the devkit archive
file (str, optional): Name of devkit archive. Defaults to
'ILSVRC2012_devkit_t12.tar.gz'
+
"""
from scipy import io as sio
diff --git a/src/pl_bolts/datasets/kitti_dataset.py b/src/pl_bolts/datasets/kitti_dataset.py
index de0b6c7a27..0e6674224d 100644
--- a/src/pl_bolts/datasets/kitti_dataset.py
+++ b/src/pl_bolts/datasets/kitti_dataset.py
@@ -34,6 +34,7 @@ class KittiDataset(Dataset):
img_size (tuple): image dimensions (width, height)
valid_labels (tuple): useful classes to include
transform (callable, optional): A function/transform that takes in the numpy array and transforms it.
+
"""
IMAGE_PATH = os.path.join("training", "image_2")
diff --git a/src/pl_bolts/datasets/mnist_dataset.py b/src/pl_bolts/datasets/mnist_dataset.py
index 851dfd0bd8..18d716481a 100644
--- a/src/pl_bolts/datasets/mnist_dataset.py
+++ b/src/pl_bolts/datasets/mnist_dataset.py
@@ -38,6 +38,7 @@ class BinaryMNIST(MNIST):
Note:
Documentation is based on https://pytorch.org/vision/main/generated/torchvision.datasets.EMNIST.html
+
"""
def __init__(self, root: str, threshold: Union[int, float] = 127.0, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/datasets/sr_celeba_dataset.py b/src/pl_bolts/datasets/sr_celeba_dataset.py
index e776eef64d..ff4127131b 100644
--- a/src/pl_bolts/datasets/sr_celeba_dataset.py
+++ b/src/pl_bolts/datasets/sr_celeba_dataset.py
@@ -23,6 +23,7 @@ class SRCelebA(SRDatasetMixin, CelebA):
"""CelebA dataset that can be used to train Super Resolution models.
Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.
+
"""
def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/datasets/sr_dataset_mixin.py b/src/pl_bolts/datasets/sr_dataset_mixin.py
index 4a22717134..cdeddce054 100644
--- a/src/pl_bolts/datasets/sr_dataset_mixin.py
+++ b/src/pl_bolts/datasets/sr_dataset_mixin.py
@@ -23,6 +23,7 @@ class SRDatasetMixin:
"""Mixin for Super Resolution datasets.
Scales range of high resolution images to [-1, 1] and range or low resolution images to [0, 1].
+
"""
def __init__(self, hr_image_size: int, lr_image_size: int, image_channels: int, *args: Any, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/datasets/sr_mnist_dataset.py b/src/pl_bolts/datasets/sr_mnist_dataset.py
index e950a80054..62521fb4a4 100644
--- a/src/pl_bolts/datasets/sr_mnist_dataset.py
+++ b/src/pl_bolts/datasets/sr_mnist_dataset.py
@@ -17,6 +17,7 @@ class SRMNIST(SRDatasetMixin, MNIST):
"""MNIST dataset that can be used to train Super Resolution models.
Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.
+
"""
def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/datasets/sr_stl10_dataset.py b/src/pl_bolts/datasets/sr_stl10_dataset.py
index a331da0aef..b13c9e7be2 100644
--- a/src/pl_bolts/datasets/sr_stl10_dataset.py
+++ b/src/pl_bolts/datasets/sr_stl10_dataset.py
@@ -24,6 +24,7 @@ class SRSTL10(SRDatasetMixin, STL10):
"""STL10 dataset that can be used to train Super Resolution models.
Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.
+
"""
def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/datasets/ssl_amdim_datasets.py b/src/pl_bolts/datasets/ssl_amdim_datasets.py
index 4d013d0560..241c2c1178 100644
--- a/src/pl_bolts/datasets/ssl_amdim_datasets.py
+++ b/src/pl_bolts/datasets/ssl_amdim_datasets.py
@@ -48,6 +48,7 @@ def select_num_imgs_per_class(cls, examples, labels, num_imgs_in_val):
"""Splits a dataset into two parts.
The labeled split has num_imgs_in_val per class
+
"""
num_classes = len(set(labels))
diff --git a/src/pl_bolts/datasets/utils.py b/src/pl_bolts/datasets/utils.py
index f6fcf1a7f2..e3b085fbc1 100644
--- a/src/pl_bolts/datasets/utils.py
+++ b/src/pl_bolts/datasets/utils.py
@@ -25,6 +25,7 @@ def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str):
Returns:
sr_datasets: tuple containing train, val, and test dataset.
+
"""
assert dataset in ["celeba", "mnist", "stl10"]
diff --git a/src/pl_bolts/losses/rl.py b/src/pl_bolts/losses/rl.py
index f5da25393d..52702975a7 100644
--- a/src/pl_bolts/losses/rl.py
+++ b/src/pl_bolts/losses/rl.py
@@ -21,6 +21,7 @@ def dqn_loss(batch: Tuple[Tensor, Tensor], net: nn.Module, target_net: nn.Module
Returns:
loss
+
"""
states, actions, rewards, dones, next_states = batch
@@ -45,9 +46,9 @@ def double_dqn_loss(
target_net: nn.Module,
gamma: float = 0.99,
) -> Tensor:
- """Calculates the mse loss using a mini batch from the replay buffer. This uses an improvement to the original
- DQN loss by using the double dqn. This is shown by using the actions of the train network to pick the value
- from the target network. This code is heavily commented in order to explain the process clearly.
+ """Calculates the mse loss using a mini batch from the replay buffer. This uses an improvement to the original DQN
+ loss by using the double dqn. This is shown by using the actions of the train network to pick the value from the
+ target network. This code is heavily commented in order to explain the process clearly.
Args:
batch: current mini batch of replay data
@@ -57,6 +58,7 @@ def double_dqn_loss(
Returns:
loss
+
"""
states, actions, rewards, dones, next_states = batch # batch of experiences, batch_size = 16
@@ -103,6 +105,7 @@ def per_dqn_loss(
Returns:
loss and batch_weights
+
"""
states, actions, rewards, dones, next_states = batch
diff --git a/src/pl_bolts/losses/self_supervised_learning.py b/src/pl_bolts/losses/self_supervised_learning.py
index 04571121f4..2588fc6073 100644
--- a/src/pl_bolts/losses/self_supervised_learning.py
+++ b/src/pl_bolts/losses/self_supervised_learning.py
@@ -209,6 +209,7 @@ class FeatureMapContrastiveTask(nn.Module):
# will compare the following:
# 01: (pos_0, anc_1), (anc_0, pos_1)
# 02: (pos_0, anc_2), (anc_0, pos_2)
+
"""
def __init__(self, comparisons: str = "00, 11", tclip: float = 10.0, bidirectional: bool = True) -> None:
@@ -321,6 +322,7 @@ def forward(self, anchor_maps, positive_maps):
tensor([2.2351, 2.1902])
>>> regularizer
tensor(0.0324)
+
"""
assert len(anchor_maps) == len(self.map_indexes), f"expected each input to have {len(self.map_indexes)} tensors"
diff --git a/src/pl_bolts/models/autoencoders/basic_ae/__init__.py b/src/pl_bolts/models/autoencoders/basic_ae/__init__.py
index fae8e6d631..04fa50ef69 100644
--- a/src/pl_bolts/models/autoencoders/basic_ae/__init__.py
+++ b/src/pl_bolts/models/autoencoders/basic_ae/__init__.py
@@ -18,4 +18,5 @@
model = AE()
trainer = pl.Trainer()
trainer.fit(model)
+
"""
diff --git a/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
index 3c1ba25599..300d459492 100644
--- a/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
+++ b/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
@@ -29,6 +29,7 @@ class AE(LightningModule):
# pretrained on cifar10
ae = AE(input_height=32).from_pretrained('cifar10-resnet18')
+
"""
pretrained_urls = {
diff --git a/src/pl_bolts/models/autoencoders/basic_vae/__init__.py b/src/pl_bolts/models/autoencoders/basic_vae/__init__.py
index bb58b6b05b..678b93c6ad 100644
--- a/src/pl_bolts/models/autoencoders/basic_vae/__init__.py
+++ b/src/pl_bolts/models/autoencoders/basic_vae/__init__.py
@@ -11,4 +11,5 @@
The default encoder is a resnet18 backbone followed by linear layers which map representations to mu and var. The
default decoder mirrors the encoder architecture and is similar to an inverted resnet18. The model also assumes a
Gaussian prior and a Gaussian approximate posterior distribution.
+
"""
diff --git a/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
index 125cc2727d..0cf6729df2 100644
--- a/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
+++ b/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
@@ -32,6 +32,7 @@ class VAE(LightningModule):
# pretrained on stl10
vae = VAE(input_height=32).from_pretrained('stl10-resnet18')
+
"""
pretrained_urls = {
diff --git a/src/pl_bolts/models/detection/components/torchvision_backbones.py b/src/pl_bolts/models/detection/components/torchvision_backbones.py
index 835e1301d2..ce8ecc0066 100644
--- a/src/pl_bolts/models/detection/components/torchvision_backbones.py
+++ b/src/pl_bolts/models/detection/components/torchvision_backbones.py
@@ -15,6 +15,7 @@ def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module:
Args:
model: torch.nn model
out_channels: Number of out_channels in last layer.
+
"""
modules_total = list(model.children())
modules = modules_total[:-1]
@@ -32,6 +33,7 @@ def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = No
Args:
model: torch.nn model with adaptive pooling layer
out_channels: Number of out_channels in last layer
+
"""
if out_channels is None:
modules_total = list(model.children())
@@ -46,6 +48,7 @@ def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module:
Args:
model: torch.nn model with features as sequential block.
out_channels: Number of out_channels in last layer.
+
"""
ft_backbone = model.features
ft_backbone.out_channels = out_channels
@@ -59,6 +62,7 @@ def create_torchvision_backbone(model_name: str, pretrained: bool = True) -> Tup
Args:
model_name: Name of the model. E.g. resnet18
pretrained: Pretrained weights dataset "imagenet", etc
+
"""
model_selected = TORCHVISION_MODEL_ZOO[model_name]
diff --git a/src/pl_bolts/models/detection/retinanet/retinanet_module.py b/src/pl_bolts/models/detection/retinanet/retinanet_module.py
index 67583fb818..78c140340e 100644
--- a/src/pl_bolts/models/detection/retinanet/retinanet_module.py
+++ b/src/pl_bolts/models/detection/retinanet/retinanet_module.py
@@ -35,6 +35,7 @@ class RetinaNet(LightningModule):
# PascalVOC using LightningCLI
python retinanet_module.py --trainer.gpus 1 --model.pretrained True
+
"""
def __init__(
diff --git a/src/pl_bolts/models/detection/yolo/darknet_network.py b/src/pl_bolts/models/detection/yolo/darknet_network.py
index 7fdfcde532..7a38a0f5d2 100644
--- a/src/pl_bolts/models/detection/yolo/darknet_network.py
+++ b/src/pl_bolts/models/detection/yolo/darknet_network.py
@@ -57,6 +57,7 @@ class DarknetNetwork(nn.Module):
overlap_loss_multiplier: Overlap loss will be scaled by this value.
confidence_loss_multiplier: Confidence loss will be scaled by this value.
class_loss_multiplier: Classification loss will be scaled by this value.
+
"""
def __init__(
@@ -128,6 +129,7 @@ def load_weights(self, weight_file: io.IOBase) -> None:
Args:
weight_file: A file-like object containing model weights in the Darknet binary format.
+
"""
if not isinstance(weight_file, io.IOBase):
raise ValueError("weight_file must be a file-like object.")
@@ -183,6 +185,7 @@ def _read_config(self, config_file: Iterable[str]) -> List[Dict[str, Any]]:
Returns:
A list of configuration sections.
+
"""
section_re = re.compile(r"\[([^]]+)\]")
list_variables = ("layers", "anchors", "mask", "scales")
@@ -304,6 +307,7 @@ def _create_convolutional(config: CONFIG, num_inputs: List[int], **kwargs: Any)
Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
its output.
+
"""
batch_normalize = config.get("batch_normalize", False)
padding = (config["size"] - 1) // 2 if config["pad"] else 0
@@ -333,6 +337,7 @@ def _create_maxpool(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CRE
Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
its output.
+
"""
layer = MaxPool(config["size"], config["stride"])
return layer, num_inputs[-1]
@@ -351,6 +356,7 @@ def _create_route(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREAT
Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
its output.
+
"""
num_chunks = config.get("groups", 1)
chunk_idx = config.get("group_id", 0)
@@ -379,6 +385,7 @@ def _create_shortcut(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CR
Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
its output.
+
"""
layer = ShortcutLayer(config["from"])
return layer, num_inputs[-1]
@@ -394,6 +401,7 @@ def _create_upsample(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CR
Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
its output.
+
"""
layer = nn.Upsample(scale_factor=config["stride"], mode="nearest")
return layer, num_inputs[-1]
@@ -453,6 +461,7 @@ def _create_yolo(
Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
its output (always 0 for a detection layer).
+
"""
if prior_shapes is None:
# The "anchors" list alternates width and height.
diff --git a/src/pl_bolts/models/detection/yolo/layers.py b/src/pl_bolts/models/detection/yolo/layers.py
index 1be1fe10e1..041c640f08 100644
--- a/src/pl_bolts/models/detection/yolo/layers.py
+++ b/src/pl_bolts/models/detection/yolo/layers.py
@@ -33,6 +33,7 @@ def _get_padding(kernel_size: int, stride: int) -> Tuple[int, nn.Module]:
Returns:
padding, pad_op: The amount of padding to be added to all sides of the input and an ``nn.Identity`` or
``nn.ZeroPad2d`` operation to add one more column and row of padding if necessary.
+
"""
# The output size is generally (input_size + padding - max(kernel_size, stride)) / stride + 1 and we want to
# make it equal to input_size / stride.
@@ -62,6 +63,7 @@ class DetectionLayer(nn.Module):
detection layer will not take the sigmoid of the coordinate and probability predictions, and the width and
height are scaled up so that the maximum value is four times the anchor dimension. This is used by the
Darknet configurations of Scaled-YOLOv4.
+
"""
def __init__(
@@ -106,6 +108,7 @@ def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, PREDS]:
The layer output, with normalized probabilities, in a tensor sized
``[batch_size, anchors_per_cell * height * width, num_classes + 5]`` and a list of dictionaries, containing
the same predictions, but with unnormalized probabilities (for loss calculation).
+
"""
batch_size, num_features, height, width = x.shape
num_attrs = self.num_classes + 5
@@ -169,6 +172,7 @@ def match_targets(
Returns:
Two dictionaries, the matched predictions and targets.
+
"""
batch_size = len(preds)
if (len(targets) != batch_size) or (len(return_preds) != batch_size):
@@ -236,6 +240,7 @@ def calculate_losses(
Returns:
A vector of the overlap, confidence, and classification loss, normalized by batch size, and the number of
targets that were matched to this layer.
+
"""
if loss_preds is None:
loss_preds = preds
@@ -266,6 +271,7 @@ class Conv(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -302,6 +308,7 @@ class MaxPool(nn.Module):
The module tries to add padding so much that the output size will be the input size divided by the stride. If the
input size is not divisible by the stride, the output size will be rounded upwards.
+
"""
def __init__(self, kernel_size: int, stride: int):
@@ -321,6 +328,7 @@ class RouteLayer(nn.Module):
source_layers: Indices of the layers whose output will be concatenated.
num_chunks: Layer outputs will be split into this number of chunks.
chunk_idx: Only the chunks with this index will be concatenated.
+
"""
def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> None:
@@ -339,6 +347,7 @@ class ShortcutLayer(nn.Module):
Args:
source_layer: Index of the layer whose output will be added to the output of the previous layer.
+
"""
def __init__(self, source_layer: int) -> None:
@@ -360,6 +369,7 @@ class ReOrg(nn.Module):
"""Re-organizes the tensor so that every square region of four cells is placed into four different channels.
The result is a tensor with half the width and height, and four times as many channels.
+
"""
def forward(self, x: Tensor) -> Tensor:
@@ -376,6 +386,7 @@ def create_activation_module(name: Optional[str]) -> nn.Module:
Args:
name: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", "linear",
or "none".
+
"""
if name == "relu":
return nn.ReLU(inplace=True)
@@ -400,6 +411,7 @@ def create_normalization_module(name: Optional[str], num_channels: int) -> nn.Mo
Args:
name: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
num_channels: The number of input channels that the module expects.
+
"""
if name == "batchnorm":
return nn.BatchNorm2d(num_channels, eps=0.001)
@@ -464,6 +476,7 @@ def create_detection_layer(
detection layer will not take the sigmoid of the coordinate and probability predictions, and the width and
height are scaled up so that the maximum value is four times the anchor dimension. This is used by the
Darknet configurations of Scaled-YOLOv4.
+
"""
matching_func: Union[ShapeMatching, SimOTAMatching]
if matching_algorithm == "simota":
diff --git a/src/pl_bolts/models/detection/yolo/loss.py b/src/pl_bolts/models/detection/yolo/loss.py
index cb119436db..44ac5b0f11 100644
--- a/src/pl_bolts/models/detection/yolo/loss.py
+++ b/src/pl_bolts/models/detection/yolo/loss.py
@@ -78,6 +78,7 @@ def _get_iou_and_loss_functions(name: str) -> Tuple[Callable, Callable]:
Returns:
A tuple of two functions. The first function calculates the pairwise IoU and the second function calculates the
elementwise loss.
+
"""
if name not in _iou_and_loss_functions:
raise ValueError(f"Unknown IoU function '{name}'.")
@@ -101,6 +102,7 @@ def _size_compensation(targets: Tensor, image_size: Tensor) -> Tuple[Tensor, Ten
Returns:
The size compensation factor.
+
"""
unit_wh = targets[:, 2:] / image_size
return 2 - (unit_wh[:, 0] * unit_wh[:, 1])
@@ -124,6 +126,7 @@ def _pairwise_confidence_loss(
Returns:
An ``[N, M]`` matrix of confidence losses between all predictions and targets.
+
"""
if predict_overlap is not None:
# When predicting overlap, target confidence is different for each pair of a prediction and a target. The
@@ -158,6 +161,7 @@ def _foreground_confidence_loss(
Returns:
The sum of the confidence losses for foreground anchors.
+
"""
targets = torch.ones_like(preds)
if predict_overlap is not None:
@@ -176,6 +180,7 @@ def _background_confidence_loss(preds: Tensor, bce_func: Callable) -> Tensor:
Returns:
The sum of the background confidence losses.
+
"""
targets = torch.zeros_like(preds)
return bce_func(preds, targets, reduction="sum")
@@ -244,6 +249,7 @@ class YOLOLoss:
overlap_loss_multiplier: Overlap loss will be scaled by this value.
confidence_loss_multiplier: Confidence loss will be scaled by this value.
class_loss_multiplier: Classification loss will be scaled by this value.
+
"""
def __init__(
@@ -285,6 +291,7 @@ def pairwise(
Returns:
Loss matrices and an overlap matrix. Each matrix is shaped ``[N, M]``.
+
"""
loss_shape = torch.Size([len(preds["boxes"]), len(targets["boxes"])])
@@ -325,8 +332,8 @@ def elementwise_sums(
input_is_normalized: bool,
image_size: Tensor,
) -> YOLOLosses:
- """Calculates the sums of the losses for optimization, over prediction/target pairs, assuming the
- predictions and targets have been matched (there are as many predictions and targets).
+ """Calculates the sums of the losses for optimization, over prediction/target pairs, assuming the predictions
+ and targets have been matched (there are as many predictions and targets).
Args:
preds: A dictionary of predictions, containing "boxes", "confidences", and "classprobs".
@@ -336,6 +343,7 @@ def elementwise_sums(
Returns:
The final losses.
+
"""
bce_func: Callable[..., Tensor] = (
binary_cross_entropy if input_is_normalized else binary_cross_entropy_with_logits # type: ignore
diff --git a/src/pl_bolts/models/detection/yolo/target_matching.py b/src/pl_bolts/models/detection/yolo/target_matching.py
index 20951dc8f1..baca8020e3 100644
--- a/src/pl_bolts/models/detection/yolo/target_matching.py
+++ b/src/pl_bolts/models/detection/yolo/target_matching.py
@@ -17,8 +17,8 @@
class ShapeMatching(ABC):
- """Selects which anchors are used to predict each target, by comparing the shape of the target box to a set of
- prior shapes.
+ """Selects which anchors are used to predict each target, by comparing the shape of the target box to a set of prior
+ shapes.
Most YOLO variants match targets to anchors based on prior shapes that are assigned to the anchors in the model
configuration. The subclasses of ``ShapeMatching`` implement matching rules that compare the width and height of
@@ -30,6 +30,7 @@ class ShapeMatching(ABC):
ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU
with some target greater than this threshold, the predictor will not be taken into account when calculating
the confidence loss.
+
"""
def __init__(self, ignore_bg_threshold: float = 0.7) -> None:
@@ -53,6 +54,7 @@ def __call__(
Returns:
The indices of the matched predictions, background mask, and a mask for selecting the matched targets.
+
"""
height, width = preds["boxes"].shape[:2]
device = preds["boxes"].device
@@ -92,6 +94,7 @@ def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]:
Returns:
matched_targets, matched_anchors: Two vectors or a `2xN` matrix. The first vector is used to select the
targets that this layer matched and the second one lists the matching anchors within the grid cell.
+
"""
pass
@@ -109,6 +112,7 @@ class HighestIoUMatching(ShapeMatching):
ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU
with some target greater than this threshold, the predictor will not be taken into account when calculating
the confidence loss.
+
"""
def __init__(
@@ -146,6 +150,7 @@ class IoUThresholdMatching(ShapeMatching):
ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor
has IoU with some target greater than this threshold, the predictor will not be taken into account when
calculating the confidence loss.
+
"""
def __init__(
@@ -168,8 +173,7 @@ def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]:
class SizeRatioMatching(ShapeMatching):
- """For each target, select those prior shapes, whose width and height relative to the target is below given
- ratio.
+ """For each target, select those prior shapes, whose width and height relative to the target is below given ratio.
This is the matching rule used by Ultralytics YOLOv5 implementation.
@@ -182,6 +186,7 @@ class SizeRatioMatching(ShapeMatching):
ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor
has IoU with some target greater than this threshold, the predictor will not be taken into account when
calculating the confidence loss.
+
"""
def __init__(
@@ -214,6 +219,7 @@ def _sim_ota_match(costs: Tensor, ious: Tensor) -> Tuple[Tensor, Tensor]:
Returns:
A mask of predictions that were matched, and the indices of the matched targets. The latter contains as many
elements as there are ``True`` values in the mask.
+
"""
num_preds, num_targets = ious.shape
@@ -258,6 +264,7 @@ class SimOTAMatching:
target, where `N` is the value of this parameter.
size_range: For each target, restrict to the anchors whose prior dimensions are not larger than the target
dimensions multiplied by this value and not smaller than the target dimensions divided by this value.
+
"""
def __init__(
@@ -289,6 +296,7 @@ def __call__(
Returns:
A mask of predictions that were matched, background mask (inverse of the first mask), and the indices of the
matched targets. The last tensor contains as many elements as there are ``True`` values in the first mask.
+
"""
height, width, boxes_per_cell, _ = preds["boxes"].shape
prior_mask, anchor_inside_target = self._get_prior_mask(targets, image_size, width, height, boxes_per_cell)
@@ -334,6 +342,7 @@ def _get_prior_mask(
Two masks, a ``[grid_height, grid_width, boxes_per_cell]`` mask for selecting anchors that are close and
similar in shape to a target, and an ``[anchors, targets]`` matrix that indicates which targets are inside
those anchors.
+
"""
# A multiplier for scaling feature map coordinates to image coordinates
grid_size = torch.tensor([grid_width, grid_height], device=targets["boxes"].device)
diff --git a/src/pl_bolts/models/detection/yolo/torch_networks.py b/src/pl_bolts/models/detection/yolo/torch_networks.py
index 480df94974..ee5358ac7f 100644
--- a/src/pl_bolts/models/detection/yolo/torch_networks.py
+++ b/src/pl_bolts/models/detection/yolo/torch_networks.py
@@ -100,6 +100,7 @@ class BottleneckBlock(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -136,6 +137,7 @@ class TinyStage(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -175,6 +177,7 @@ class CSPStage(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -228,6 +231,7 @@ class ELANStage(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -290,6 +294,7 @@ class CSPSPP(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -340,6 +345,7 @@ class FastSPP(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -373,6 +379,7 @@ class YOLOV4TinyBackbone(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -439,6 +446,7 @@ class YOLOV4Backbone(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -507,6 +515,7 @@ class YOLOV5Backbone(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -572,6 +581,7 @@ class YOLOV7Backbone(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
@@ -664,6 +674,7 @@ class YOLOV4TinyNetwork(nn.Module):
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
+
"""
def __init__(
@@ -800,6 +811,7 @@ class YOLOV4Network(nn.Module):
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
+
"""
def __init__(
@@ -967,6 +979,7 @@ class YOLOV4P6Network(nn.Module):
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
+
"""
def __init__(
@@ -1326,6 +1339,7 @@ class YOLOV7Network(nn.Module):
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
+
"""
def __init__(
@@ -1552,6 +1566,7 @@ class YOLOXHead(nn.Module):
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
+
"""
def __init__(
diff --git a/src/pl_bolts/models/detection/yolo/utils.py b/src/pl_bolts/models/detection/yolo/utils.py
index 0aab13d94e..d981fadceb 100644
--- a/src/pl_bolts/models/detection/yolo/utils.py
+++ b/src/pl_bolts/models/detection/yolo/utils.py
@@ -30,6 +30,7 @@ def grid_offsets(grid_size: Tensor) -> Tensor:
Returns:
A ``[height, width, 2]`` tensor containing the grid cell `(x, y)` offsets.
+
"""
x_range = torch.arange(grid_size[0].item(), device=grid_size.device)
y_range = torch.arange(grid_size[1].item(), device=grid_size.device)
@@ -42,6 +43,7 @@ def grid_centers(grid_size: Tensor) -> Tensor:
Returns:
A ``[height, width, 2]`` tensor containing coordinates to the centers of the grid cells.
+
"""
return grid_offsets(grid_size) + 0.5
@@ -66,6 +68,7 @@ def global_xy(xy: Tensor, image_size: Tensor) -> Tensor:
Returns:
Global coordinates scaled to the size of the network input image, in a tensor with the same shape as the input
tensor.
+
"""
height = xy.shape[1]
width = xy.shape[2]
@@ -77,8 +80,8 @@ def global_xy(xy: Tensor, image_size: Tensor) -> Tensor:
def aligned_iou(wh1: Tensor, wh2: Tensor) -> Tensor:
- """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at
- the same coordinates.
+ """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at the
+ same coordinates.
Args:
wh1: An ``[N, 2]`` matrix of box shapes (width and height).
@@ -86,6 +89,7 @@ def aligned_iou(wh1: Tensor, wh2: Tensor) -> Tensor:
Returns:
An ``[N, M]`` matrix of pairwise IoU values for every element in ``wh1`` and ``wh2``
+
"""
area1 = wh1[:, 0] * wh1[:, 1] # [N]
area2 = wh2[:, 0] * wh2[:, 1] # [M]
@@ -126,6 +130,7 @@ def is_inside_box(points: Tensor, boxes: Tensor) -> Tensor:
Returns:
A tensor shaped ``[points, boxes]`` containing pairwise truth values of whether the points are inside the boxes.
+
"""
lt = points[:, None, :] - boxes[None, :, :2] # [boxes, points, 2]
rb = boxes[None, :, 2:] - points[:, None, :] # [boxes, points, 2]
@@ -145,6 +150,7 @@ def box_size_ratio(wh1: Tensor, wh2: Tensor) -> Tensor:
Returns:
An ``[N, M]`` matrix of ratios of width or height dimensions, whichever is larger.
+
"""
wh_ratio = wh1[:, None, :] / wh2[None, :, :] # [M, N, 2]
wh_ratio = torch.max(wh_ratio, 1.0 / wh_ratio)
@@ -164,6 +170,7 @@ def get_image_size(images: Tensor) -> Tensor:
Returns:
A tensor that contains the image width and height.
+
"""
height = images.shape[2]
width = images.shape[3]
diff --git a/src/pl_bolts/models/detection/yolo/yolo_module.py b/src/pl_bolts/models/detection/yolo/yolo_module.py
index cd29409163..429ed087b7 100644
--- a/src/pl_bolts/models/detection/yolo/yolo_module.py
+++ b/src/pl_bolts/models/detection/yolo/yolo_module.py
@@ -42,8 +42,8 @@
class YOLO(LightningModule):
- """PyTorch Lightning implementation of YOLO that supports the most important features of YOLOv3, YOLOv4,
- YOLOv5, YOLOv7, Scaled-YOLOv4, and YOLOX.
+ """PyTorch Lightning implementation of YOLO that supports the most important features of YOLOv3, YOLOv4, YOLOv5,
+ YOLOv7, Scaled-YOLOv4, and YOLOX.
*YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `__
@@ -103,6 +103,7 @@ class labels. *Each target is a dictionary containing the following tensors*:
nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is
higher than this threshold, if the predicted categories are equal.
detections_per_image: Keep at most this number of highest-confidence detections per image.
+
"""
def __init__(
@@ -219,6 +220,7 @@ def training_step(self, batch: BATCH, batch_idx: int) -> STEP_OUTPUT:
Returns:
A dictionary that includes the training loss in 'loss'.
+
"""
images, targets = batch
_, losses = self(images, targets)
@@ -237,6 +239,7 @@ def validation_step(self, batch: BATCH, batch_idx: int) -> Optional[STEP_OUTPUT]
batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. Targets is a list of target
dictionaries.
batch_idx: Index of the current batch.
+
"""
images, targets = batch
detections, losses = self(images, targets)
@@ -269,6 +272,7 @@ def test_step(self, batch: BATCH, batch_idx: int) -> Optional[STEP_OUTPUT]:
batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. Targets is a list of target
dictionaries.
batch_idx: Index of the current batch.
+
"""
images, targets = batch
detections, losses = self(images, targets)
@@ -295,8 +299,8 @@ def on_test_epoch_end(self) -> None:
self._test_map.reset()
def predict_step(self, batch: BATCH, batch_idx: int, dataloader_idx: int = 0) -> List[PRED]:
- """Feeds a batch of images to the network and returns the detected bounding boxes, confidence scores, and
- class labels.
+ """Feeds a batch of images to the network and returns the detected bounding boxes, confidence scores, and class
+ labels.
If a prediction has a high score for more than one class, it will be duplicated.
@@ -310,14 +314,14 @@ class labels.
A list of dictionaries containing tensors "boxes", "scores", and "labels". "boxes" is a matrix of detected
bounding box `(x1, y1, x2, y2)` coordinates. "scores" is a vector of confidence scores for the bounding box
detections. "labels" is a vector of predicted class labels.
+
"""
images, _ = batch
detections = self(images)
return self.process_detections(detections)
def infer(self, image: Tensor) -> PRED:
- """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class
- labels.
+ """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class labels.
If a prediction has a high score for more than one class, it will be duplicated.
@@ -328,6 +332,7 @@ def infer(self, image: Tensor) -> PRED:
A dictionary containing tensors "boxes", "scores", and "labels". "boxes" is a matrix of detected bounding
box `(x1, y1, x2, y2)` coordinates. "scores" is a vector of confidence scores for the bounding box
detections. "labels" is a vector of predicted class labels.
+
"""
if not isinstance(image, Tensor):
image = T.to_tensor(image)
@@ -344,9 +349,8 @@ def infer(self, image: Tensor) -> PRED:
return detections
def process_detections(self, preds: Tensor) -> List[PRED]:
- """Splits the detection tensor returned by a forward pass into a list of prediction dictionaries, and
- filters them based on confidence threshold, non-maximum suppression (NMS), and maximum number of
- predictions.
+ """Splits the detection tensor returned by a forward pass into a list of prediction dictionaries, and filters
+ them based on confidence threshold, non-maximum suppression (NMS), and maximum number of predictions.
If for any single detection there are multiple categories whose score is above the confidence threshold, the
detection will be duplicated to create one detection for each category. NMS processes one category at a time,
@@ -363,6 +367,7 @@ def process_detections(self, preds: Tensor) -> List[PRED]:
Returns:
Filtered detections. A list of prediction dictionaries, one for each image.
+
"""
def process(boxes: Tensor, confidences: Tensor, classprobs: Tensor) -> Dict[str, Any]:
@@ -389,6 +394,7 @@ def process_targets(self, targets: TARGETS) -> List[TARGET]:
Returns:
Single-label targets. A list of target dictionaries, one for each image.
+
"""
def process(boxes: Tensor, labels: Tensor, **other: Any) -> Dict[str, Any]:
@@ -406,6 +412,7 @@ def validate_batch(self, images: Union[Tensor, IMAGES], targets: Optional[TARGET
images: A tensor containing a batch of images or a list of image tensors.
targets: A list of target dictionaries or ``None``. If a list is provided, there should be as many target
dictionaries as there are images.
+
"""
if not isinstance(images, Tensor):
if not isinstance(images, (tuple, list)):
@@ -503,6 +510,7 @@ class CLIYOLO(YOLO):
overlap_loss_multiplier: Overlap loss will be scaled by this value.
confidence_loss_multiplier: Confidence loss will be scaled by this value.
class_loss_multiplier: Classification loss will be scaled by this value.
+
"""
def __init__(
diff --git a/src/pl_bolts/models/gans/basic/basic_gan_module.py b/src/pl_bolts/models/gans/basic/basic_gan_module.py
index 6080b10d71..965577b6bd 100644
--- a/src/pl_bolts/models/gans/basic/basic_gan_module.py
+++ b/src/pl_bolts/models/gans/basic/basic_gan_module.py
@@ -27,6 +27,7 @@ class GAN(LightningModule):
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012'
--data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder
--batch_size 256 --learning_rate 0.0001
+
"""
def __init__(
@@ -70,6 +71,7 @@ def forward(self, z):
z = torch.rand(batch_size, latent_dim)
gan = GAN.load_from_checkpoint(PATH)
img = gan(z)
+
"""
return self.generator(z)
diff --git a/src/pl_bolts/models/gans/dcgan/dcgan_module.py b/src/pl_bolts/models/gans/dcgan/dcgan_module.py
index b5a0dc2e11..01202a25b0 100644
--- a/src/pl_bolts/models/gans/dcgan/dcgan_module.py
+++ b/src/pl_bolts/models/gans/dcgan/dcgan_module.py
@@ -36,6 +36,7 @@ class DCGAN(LightningModule):
# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
+
"""
def __init__(
@@ -99,6 +100,7 @@ def forward(self, noise: Tensor) -> Tensor:
noise = torch.rand(batch_size, latent_dim)
gan = GAN.load_from_checkpoint(PATH)
img = gan(noise)
+
"""
noise = noise.view(*noise.shape, 1, 1)
return self.generator(noise)
diff --git a/src/pl_bolts/models/gans/srgan/srgan_module.py b/src/pl_bolts/models/gans/srgan/srgan_module.py
index 3f9d9e6b43..ef11f10dc2 100644
--- a/src/pl_bolts/models/gans/srgan/srgan_module.py
+++ b/src/pl_bolts/models/gans/srgan/srgan_module.py
@@ -98,6 +98,7 @@ def forward(self, lr_image: torch.Tensor) -> torch.Tensor:
srgan = SRGAN.load_from_checkpoint(PATH)
hr_image = srgan(lr_image)
+
"""
return self.generator(lr_image)
diff --git a/src/pl_bolts/models/gans/srgan/srresnet_module.py b/src/pl_bolts/models/gans/srgan/srresnet_module.py
index b21eafc6a9..fc6ba2498b 100644
--- a/src/pl_bolts/models/gans/srgan/srresnet_module.py
+++ b/src/pl_bolts/models/gans/srgan/srresnet_module.py
@@ -77,6 +77,7 @@ def forward(self, lr_image: torch.Tensor) -> torch.Tensor:
srresnet = SRResNet.load_from_checkpoint(PATH)
hr_image = srresnet(lr_image)
+
"""
return self.srresnet(lr_image)
diff --git a/src/pl_bolts/models/mnist_module.py b/src/pl_bolts/models/mnist_module.py
index 502617740f..83af348c13 100644
--- a/src/pl_bolts/models/mnist_module.py
+++ b/src/pl_bolts/models/mnist_module.py
@@ -24,6 +24,7 @@ class LitMNIST(LightningModule):
trainer = Trainer()
trainer.fit(model, datamodule=datamodule)
+
"""
def __init__(self, hidden_dim: int = 128, learning_rate: float = 1e-3, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/models/regression/logistic_regression.py b/src/pl_bolts/models/regression/logistic_regression.py
index 44c7a5f1aa..26ec01a198 100644
--- a/src/pl_bolts/models/regression/logistic_regression.py
+++ b/src/pl_bolts/models/regression/logistic_regression.py
@@ -45,6 +45,7 @@ def __init__(
linear: Linear layer.
criterion: Cross-Entropy loss function.
optimizer: Model optimizer to use.
+
"""
super().__init__()
self.save_hyperparameters()
@@ -62,6 +63,7 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
Output tensor.
+
"""
return self.linear(x)
@@ -74,6 +76,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[st
Returns:
Loss tensor.
+
"""
return self._shared_step(batch, "train")
@@ -86,6 +89,7 @@ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[
Returns:
Loss tensor.
+
"""
return self._shared_step(batch, "val")
@@ -98,6 +102,7 @@ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, T
Returns:
Loss tensor.
+
"""
return self._shared_step(batch, "test")
@@ -109,6 +114,7 @@ def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Te
Returns:
Loss tensor.
+
"""
return self._shared_epoch_end(outputs, "val")
@@ -120,6 +126,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
Returns:
Loss tensor.
+
"""
return self._shared_epoch_end(outputs, "test")
@@ -128,6 +135,7 @@ def configure_optimizers(self) -> Optimizer:
Returns:
Optimizer.
+
"""
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate)
@@ -185,6 +193,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
Returns:
ArgumentParser with the added arguments.
+
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--learning_rate", type=float, default=0.0001)
diff --git a/src/pl_bolts/models/rl/advantage_actor_critic_model.py b/src/pl_bolts/models/rl/advantage_actor_critic_model.py
index 43418378b1..e4863e32fe 100644
--- a/src/pl_bolts/models/rl/advantage_actor_critic_model.py
+++ b/src/pl_bolts/models/rl/advantage_actor_critic_model.py
@@ -93,14 +93,15 @@ def __init__(
self.state = self.env.reset()
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Passes in a state x through the network and gets the log prob of each action and the value for the state
- as an output.
+ """Passes in a state x through the network and gets the log prob of each action and the value for the state as
+ an output.
Args:
x: environment state
Returns:
action log probabilities, values
+
"""
if not isinstance(x, list):
x = [x]
@@ -123,6 +124,7 @@ def train_batch(self) -> Iterator[Tuple[np.ndarray, int, Tensor]]:
states: a list of numpy array
actions: a list of list of int
returns: a torch tensor
+
"""
while True:
for _ in range(self.hparams.batch_size):
@@ -170,6 +172,7 @@ def compute_returns(
Returns:
tensor of discounted rewards
+
"""
g = last_value
returns = []
@@ -187,13 +190,14 @@ def loss(
actions: Tensor,
returns: Tensor,
) -> Tensor:
- """Calculates the loss for A2C which is a weighted sum of actor loss (MSE), critic loss (PG), and entropy
- (for exploration)
+ """Calculates the loss for A2C which is a weighted sum of actor loss (MSE), critic loss (PG), and entropy (for
+ exploration)
Args:
states: tensor of shape (batch_size, state dimension)
actions: tensor of shape (batch_size, )
returns: tensor of shape (batch_size, )
+
"""
logprobs, values = self.net(states)
@@ -226,6 +230,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Ordered
Args:
batch: a batch of (states, actions, returns)
+
"""
states, actions, returns = batch
loss = self.loss(states, actions, returns)
@@ -271,6 +276,7 @@ def add_model_specific_args(arg_parser: ArgumentParser) -> ArgumentParser:
Returns:
arg_parser with model specific cargs added
+
"""
arg_parser.add_argument("--entropy_beta", type=float, default=0.01, help="entropy coefficient")
diff --git a/src/pl_bolts/models/rl/common/agents.py b/src/pl_bolts/models/rl/common/agents.py
index a762cbfc99..116b0b89dd 100644
--- a/src/pl_bolts/models/rl/common/agents.py
+++ b/src/pl_bolts/models/rl/common/agents.py
@@ -29,6 +29,7 @@ def __call__(self, state: Tensor, device: str, *args, **kwargs) -> List[int]:
Returns:
action
+
"""
return [0]
@@ -62,6 +63,7 @@ def __call__(self, state: Tensor, device: str) -> List[int]:
Returns:
action defined by policy
+
"""
if not isinstance(state, list):
state = [state]
@@ -81,6 +83,7 @@ def get_action(self, state: Tensor, device: torch.device):
Returns:
action defined by Q values
+
"""
if not isinstance(state, Tensor):
state = torch.tensor(state, device=device)
@@ -94,6 +97,7 @@ def update_epsilon(self, step: int) -> None:
Args:
step: current global step
+
"""
self.epsilon = max(self.eps_end, self.eps_start - (step + 1) / self.eps_frames)
@@ -112,6 +116,7 @@ def __call__(self, states: Tensor, device: str) -> List[int]:
Returns:
action defined by policy
+
"""
if not isinstance(states, list):
states = [states]
@@ -139,6 +144,7 @@ def __call__(self, states: Tensor, device: str) -> List[int]:
Returns:
action defined by policy
+
"""
if not isinstance(states, list):
states = [states]
@@ -167,6 +173,7 @@ def __call__(self, states: Tensor, device: str) -> List[float]:
Returns:
action defined by policy
+
"""
if not isinstance(states, list):
states = [states]
@@ -186,6 +193,7 @@ def get_action(self, states: Tensor, device: str) -> List[float]:
Returns:
action defined by policy
+
"""
if not isinstance(states, list):
states = [states]
diff --git a/src/pl_bolts/models/rl/common/cli.py b/src/pl_bolts/models/rl/common/cli.py
index 3ce60fc3f6..951a299a0c 100644
--- a/src/pl_bolts/models/rl/common/cli.py
+++ b/src/pl_bolts/models/rl/common/cli.py
@@ -14,6 +14,7 @@ def add_base_args(parent) -> argparse.ArgumentParser:
Args:
parent
+
"""
arg_parser = argparse.ArgumentParser(parents=[parent])
diff --git a/src/pl_bolts/models/rl/common/distributions.py b/src/pl_bolts/models/rl/common/distributions.py
index 9a5d812eab..c589c2db3a 100644
--- a/src/pl_bolts/models/rl/common/distributions.py
+++ b/src/pl_bolts/models/rl/common/distributions.py
@@ -24,6 +24,7 @@ def rsample_with_z(self, sample_shape=torch.Size()):
Returns:
Sampled X and Z
+
"""
z = super().rsample()
return self.action_scale * torch.tanh(z) + self.action_bias, z
@@ -38,6 +39,7 @@ def log_prob_with_z(self, value, z):
z: the value of Z
Returns:
Log probability of the sample
+
"""
value = (value - self.action_bias) / self.action_scale
z_logprob = super().log_prob(z)
@@ -49,6 +51,7 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()):
Returns:
Sampled X and log probability
+
"""
z = super().rsample()
z_logprob = super().log_prob(z)
diff --git a/src/pl_bolts/models/rl/common/gym_wrappers.py b/src/pl_bolts/models/rl/common/gym_wrappers.py
index b9f4e0bea7..605b498a7a 100644
--- a/src/pl_bolts/models/rl/common/gym_wrappers.py
+++ b/src/pl_bolts/models/rl/common/gym_wrappers.py
@@ -102,6 +102,7 @@ def reset(self):
"""Clear past frame buffer and init.
to first obs. from inner env.
+
"""
self._obs_buffer.clear()
obs = self.env.reset()
@@ -202,6 +203,7 @@ class DataAugmentation(ObservationWrapper):
- ToTensor
- GrayScale
- RandomCrop
+
"""
def __init__(self, env=None) -> None:
diff --git a/src/pl_bolts/models/rl/common/memory.py b/src/pl_bolts/models/rl/common/memory.py
index 47f5f6b927..3e63a1b437 100644
--- a/src/pl_bolts/models/rl/common/memory.py
+++ b/src/pl_bolts/models/rl/common/memory.py
@@ -31,6 +31,7 @@ def append(self, experience: Experience) -> None:
Args:
experience: tuple (state, action, reward, done, new_state)
+
"""
self.buffer.append(experience)
@@ -66,6 +67,7 @@ def sample(self, batch_size: int) -> Tuple:
Returns:
a batch of tuple np arrays of state, action, reward, done, next_state
+
"""
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
@@ -103,6 +105,7 @@ def append(self, exp: Experience) -> None:
Args:
exp: tuple (state, action, reward, done, new_state)
+
"""
self.update_history_queue(exp) # add single step experience to history
while self.exp_history_queue: # go through all the n_steps that have been queued
@@ -123,14 +126,15 @@ def append(self, exp: Experience) -> None:
self.buffer.append(n_step_exp) # add n_step experience to buffer
def update_history_queue(self, exp) -> None:
- """Updates the experience history queue with the lastest experiences. In the event of an experience step is
- in the done state, the history will be incrementally appended to the queue, removing the tail of the
- history each time.
+ """Updates the experience history queue with the lastest experiences. In the event of an experience step is in
+ the done state, the history will be incrementally appended to the queue, removing the tail of the history each
+ time.
Args:
env_idx: index of the environment
exp: the current experience
history: history of experience steps for this environment
+
"""
self.history.append(exp)
@@ -157,14 +161,15 @@ def update_history_queue(self, exp) -> None:
self.history.clear()
def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]:
- """Takes in a tuple of experiences and returns the last state and tail experiences based on if the last
- state is the end of an episode.
+ """Takes in a tuple of experiences and returns the last state and tail experiences based on if the last state is
+ the end of an episode.
Args:
experiences: Tuple of N Experience
Returns:
last state (Array or None) and remaining Experience
+
"""
last_exp_state = experiences[-1].new_state
tail_experiences = experiences
@@ -182,6 +187,7 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
Returns:
total discounted reward
+
"""
total_reward = 0.0
for exp in reversed(experiences):
@@ -238,6 +244,7 @@ def update_beta(self, step) -> float:
Returns:
beta value for this indexed experience
+
"""
beta_val = self.beta_start + step * (1.0 - self.beta_start) / self.beta_frames
self.beta = min(1.0, beta_val)
@@ -249,6 +256,7 @@ def append(self, exp) -> None:
Args:
exp: experience tuple being added to the buffer
+
"""
# what is the max priority for new sample
max_prio = self.priorities.max() if self.buffer else 1.0
@@ -272,6 +280,7 @@ def sample(self, batch_size=32) -> Tuple:
Returns:
sample of experiences chosen with ranked probability
+
"""
# get list of priority rankings
prios = self.priorities if len(self.buffer) == self.capacity else self.priorities[: self.pos]
@@ -308,6 +317,7 @@ def update_priorities(self, batch_indices: List, batch_priorities: List) -> None
Args:
batch_indices: index of each datum in the batch
batch_priorities: priority of each datum in the batch
+
"""
for idx, prio in zip(batch_indices, batch_priorities):
self.priorities[idx] = prio
diff --git a/src/pl_bolts/models/rl/common/networks.py b/src/pl_bolts/models/rl/common/networks.py
index 2e8ba80799..63aad43a11 100644
--- a/src/pl_bolts/models/rl/common/networks.py
+++ b/src/pl_bolts/models/rl/common/networks.py
@@ -43,6 +43,7 @@ def _get_conv_out(self, shape) -> int:
shape: input dimensions
Returns:
size of the conv output
+
"""
conv_out = self.conv(torch.zeros(1, *shape))
return int(np.prod(conv_out.size()))
@@ -54,6 +55,7 @@ def forward(self, input_x) -> Tensor:
x: input to network
Returns:
output of network
+
"""
conv_out = self.conv(input_x).view(input_x.size()[0], -1)
return self.head(conv_out)
@@ -85,6 +87,7 @@ def forward(self, input_x):
Returns:
output of network
+
"""
return self.net(input_x.float())
@@ -126,6 +129,7 @@ def forward(self, x: FloatTensor) -> TanhMultivariateNormal:
x: input to network
Returns:
action distribution
+
"""
x = self.shared_net(x.float())
batch_mean = self.mean_layer(x)
@@ -142,6 +146,7 @@ def get_action(self, x: FloatTensor) -> Tensor:
x: input to network
Returns:
mean action
+
"""
x = self.shared_net(x.float())
batch_mean = self.mean_layer(x)
@@ -173,6 +178,7 @@ def forward(self, x) -> Tuple[Tensor, Tensor]:
Returns:
action log probs (logits), value
+
"""
x = F.relu(self.fc1(x.float()))
a = F.log_softmax(self.actor_head(x), dim=-1)
@@ -214,6 +220,7 @@ def forward(self, input_x):
Returns:
Q value
+
"""
adv, val = self.adv_val(input_x)
return val + (adv - adv.mean(dim=1, keepdim=True))
@@ -226,6 +233,7 @@ def adv_val(self, input_x) -> Tuple[Tensor, Tensor]:
Returns:
advantage, value
+
"""
float_x = input_x.float()
base_out = self.net(float_x)
@@ -270,6 +278,7 @@ def _get_conv_out(self, shape) -> int:
Returns:
size of the conv output
+
"""
conv_out = self.conv(torch.zeros(1, *shape))
return int(np.prod(conv_out.size()))
@@ -282,6 +291,7 @@ def forward(self, input_x):
Returns:
Q value
+
"""
adv, val = self.adv_val(input_x)
return val + (adv - adv.mean(dim=1, keepdim=True))
@@ -294,6 +304,7 @@ def adv_val(self, input_x):
Returns:
advantage, value
+
"""
float_x = input_x.float()
base_out = self.conv(input_x).view(float_x.size()[0], -1)
@@ -332,6 +343,7 @@ def _get_conv_out(self, shape) -> int:
Returns:
size of the conv output
+
"""
conv_out = self.conv(torch.zeros(1, *shape))
return int(np.prod(conv_out.size()))
@@ -344,6 +356,7 @@ def forward(self, input_x) -> Tensor:
Returns:
output of network
+
"""
conv_out = self.conv(input_x).view(input_x.size()[0], -1)
return self.head(conv_out)
@@ -360,6 +373,7 @@ class NoisyLinear(nn.Linear):
based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/
Chapter08/lib/dqn_extra.py#L19
+
"""
def __init__(self, in_features: int, out_features: int, sigma_init: float = 0.017, bias: bool = True) -> None:
@@ -399,6 +413,7 @@ def forward(self, input_x: Tensor) -> Tensor:
Returns:
output of the layer
+
"""
self.epsilon_weight.normal_()
bias = self.bias
@@ -413,8 +428,7 @@ def forward(self, input_x: Tensor) -> Tensor:
@under_review()
class ActorCategorical(nn.Module):
- """Policy network, for discrete action spaces, which returns a distribution and an action given an
- observation."""
+ """Policy network, for discrete action spaces, which returns a distribution and an action given an observation."""
def __init__(self, actor_net: nn.Module) -> None:
"""
@@ -441,14 +455,14 @@ def get_log_prob(self, pi: Categorical, actions: Tensor):
Returns:
log probability of the acition under pi
+
"""
return pi.log_prob(actions)
@under_review()
class ActorContinous(nn.Module):
- """Policy network, for continous action spaces, which returns a distribution and an action given an
- observation."""
+ """Policy network, for continous action spaces, which returns a distribution and an action given an observation."""
def __init__(self, actor_net: nn.Module, act_dim: int) -> None:
"""
@@ -478,5 +492,6 @@ def get_log_prob(self, pi: Normal, actions: Tensor):
Returns:
log probability of the acition under pi
+
"""
return pi.log_prob(actions).sum(axis=-1)
diff --git a/src/pl_bolts/models/rl/double_dqn_model.py b/src/pl_bolts/models/rl/double_dqn_model.py
index 7f0937f915..2d76279c87 100644
--- a/src/pl_bolts/models/rl/double_dqn_model.py
+++ b/src/pl_bolts/models/rl/double_dqn_model.py
@@ -40,11 +40,12 @@ class DoubleDQN(DQN):
Currently only supports CPU and single GPU training with `accelerator=dp`
.. _`Double DQN`: https://arxiv.org/pdf/1509.06461.pdf
+
"""
def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
- """Carries out a single step through the environment to update the replay buffer. Then calculates loss
- based on the minibatch recieved.
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
+ the minibatch recieved.
Args:
batch: current mini batch of replay data
@@ -52,6 +53,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
Returns:
Training loss and log metrics
+
"""
# calculates training loss
diff --git a/src/pl_bolts/models/rl/dqn_model.py b/src/pl_bolts/models/rl/dqn_model.py
index 1aea71a11c..567aa8d185 100644
--- a/src/pl_bolts/models/rl/dqn_model.py
+++ b/src/pl_bolts/models/rl/dqn_model.py
@@ -56,6 +56,7 @@ class DQN(LightningModule):
Note:
Currently only supports CPU and single GPU training with `accelerator=dp`
+
"""
def __init__(
@@ -158,6 +159,7 @@ def run_n_episodes(self, env, n_epsiodes: int = 1, epsilon: float = 1.0) -> List
env: environment to use, either train environment or test environment
n_epsiodes: number of episodes to run
epsilon: epsilon value for DQN agent
+
"""
total_rewards = []
@@ -206,6 +208,7 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
q values
+
"""
return self.net(x)
@@ -216,6 +219,7 @@ def train_batch(
Returns:
yields a Experience tuple containing the state, action, reward, done and next_state.
+
"""
episode_reward = 0
episode_steps = 0
@@ -254,8 +258,8 @@ def train_batch(
break
def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
- """Carries out a single step through the environment to update the replay buffer. Then calculates loss
- based on the minibatch recieved.
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
+ the minibatch recieved.
Args:
batch: current mini batch of replay data
@@ -263,6 +267,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
Returns:
Training loss and log metrics
+
"""
# calculates training loss
@@ -336,6 +341,7 @@ def make_environment(env_name: str, seed: Optional[int] = None) -> Env:
Returns:
gym environment
+
"""
env = make_environment(env_name)
@@ -355,6 +361,7 @@ def add_model_specific_args(
Args:
arg_parser: parent parser
+
"""
arg_parser.add_argument(
"--sync_rate",
diff --git a/src/pl_bolts/models/rl/noisy_dqn_model.py b/src/pl_bolts/models/rl/noisy_dqn_model.py
index 904abff2ec..76b4531c5b 100644
--- a/src/pl_bolts/models/rl/noisy_dqn_model.py
+++ b/src/pl_bolts/models/rl/noisy_dqn_model.py
@@ -49,11 +49,12 @@ def train_batch(
self,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Contains the logic for generating a new batch of data to be passed to the DataLoader. This is the same
- function as the standard DQN except that we dont update epsilon as it is always 0. The exploration comes
- from the noisy network.
+ function as the standard DQN except that we dont update epsilon as it is always 0. The exploration comes from
+ the noisy network.
Returns:
yields a Experience tuple containing the state, action, reward, done and next_state.
+
"""
episode_reward = 0
episode_steps = 0
diff --git a/src/pl_bolts/models/rl/per_dqn_model.py b/src/pl_bolts/models/rl/per_dqn_model.py
index dbf30a04cc..a864afb51b 100644
--- a/src/pl_bolts/models/rl/per_dqn_model.py
+++ b/src/pl_bolts/models/rl/per_dqn_model.py
@@ -43,6 +43,7 @@ class PERDQN(DQN):
.. note:: Currently only supports CPU and single GPU training with `accelerator=dp`
.. _`DQN With Prioritized Experience Replay`: https://arxiv.org/abs/1511.05952
+
"""
def train_batch(
@@ -52,6 +53,7 @@ def train_batch(
Returns:
yields a Experience tuple containing the state, action, reward, done and next_state.
+
"""
episode_reward = 0
episode_steps = 0
@@ -102,8 +104,8 @@ def train_batch(
], weights[idx]
def training_step(self, batch, _) -> OrderedDict:
- """Carries out a single step through the environment to update the replay buffer. Then calculates loss
- based on the minibatch recieved.
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
+ the minibatch recieved.
Args:
batch: current mini batch of replay data
@@ -111,6 +113,7 @@ def training_step(self, batch, _) -> OrderedDict:
Returns:
Training loss and log metrics
+
"""
samples, indices, weights = batch
indices = indices.cpu().numpy()
diff --git a/src/pl_bolts/models/rl/ppo_model.py b/src/pl_bolts/models/rl/ppo_model.py
index b0935aa9b6..21bc0873c0 100644
--- a/src/pl_bolts/models/rl/ppo_model.py
+++ b/src/pl_bolts/models/rl/ppo_model.py
@@ -41,6 +41,7 @@ class PPO(LightningModule):
Note:
Currently only supports CPU and single GPU training with ``accelerator=dp``
+
"""
def __init__(
@@ -129,6 +130,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
Returns:
Tuple of policy and action
+
"""
pi, action = self.actor(x)
value = self.critic(x)
@@ -144,6 +146,7 @@ def discount_rewards(self, rewards: List[float], discount: float) -> List[float]
Returns:
list of discounted rewards/advantages
+
"""
assert isinstance(rewards[0], float)
@@ -166,6 +169,7 @@ def calc_advantage(self, rewards: List[float], values: List[float], last_value:
Returns:
list of advantages
+
"""
rews = rewards + [last_value]
vals = values + [last_value]
@@ -178,6 +182,7 @@ def generate_trajectory_samples(self) -> Tuple[List[Tensor], List[Tensor], List[
Yield:
Tuple of Lists containing tensors for states, actions, log probs, qvals and advantage
+
"""
for step in range(self.steps_per_epoch):
@@ -278,6 +283,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx, optimizer_idx):
Returns:
loss
+
"""
state, action, old_logp, qval, adv = batch
diff --git a/src/pl_bolts/models/rl/reinforce_model.py b/src/pl_bolts/models/rl/reinforce_model.py
index c554e68383..876686fab2 100644
--- a/src/pl_bolts/models/rl/reinforce_model.py
+++ b/src/pl_bolts/models/rl/reinforce_model.py
@@ -53,6 +53,7 @@ class Reinforce(LightningModule):
.. _REINFORCE:
https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf
+
"""
def __init__(
@@ -125,6 +126,7 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
q values
+
"""
return self.net(x)
@@ -136,6 +138,7 @@ def calc_qvals(self, rewards: List[float]) -> List[float]:
Returns:
list of discounted rewards
+
"""
assert isinstance(rewards[0], float)
@@ -156,6 +159,7 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
Returns:
total discounted reward
+
"""
total_reward = 0.0
for exp in reversed(experiences):
@@ -169,6 +173,7 @@ def train_batch(
Yield:
yields a tuple of Lists containing tensors for states, actions and rewards of the batch.
+
"""
while True:
@@ -215,8 +220,8 @@ def loss(self, states, actions, scaled_rewards) -> Tensor:
return -log_prob_actions.mean()
def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
- """Carries out a single step through the environment to update the replay buffer. Then calculates loss
- based on the minibatch recieved.
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
+ the minibatch recieved.
Args:
batch: current mini batch of replay data
@@ -224,6 +229,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
Returns:
Training loss and log metrics
+
"""
states, actions, scaled_rewards = batch
@@ -274,6 +280,7 @@ def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
Returns:
arg_parser with model specific cargs added
+
"""
arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch")
arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
diff --git a/src/pl_bolts/models/rl/sac_model.py b/src/pl_bolts/models/rl/sac_model.py
index d0142ce8ed..8c0bb2b712 100644
--- a/src/pl_bolts/models/rl/sac_model.py
+++ b/src/pl_bolts/models/rl/sac_model.py
@@ -98,6 +98,7 @@ def run_n_episodes(self, env, n_epsiodes: int = 1) -> List[int]:
Args:
env: environment to use, either train environment or test environment
n_epsiodes: number of episodes to run
+
"""
total_rewards = []
@@ -153,6 +154,7 @@ def soft_update_target(self, q_net, target_net):
Args:
q_net: the critic (q) network
target_net: the target (q) network
+
"""
for q_param, target_param in zip(q_net.parameters(), target_net.parameters()):
target_param.data.copy_(
@@ -167,6 +169,7 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
q values
+
"""
return self.policy(x).sample()
@@ -177,6 +180,7 @@ def train_batch(
Returns:
yields a Experience tuple containing the state, action, reward, done and next_state.
+
"""
episode_reward = 0
episode_steps = 0
@@ -218,6 +222,7 @@ def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Te
Args:
batch: a batch of states, actions, rewards, dones, and next states
+
"""
states, actions, rewards, dones, next_states = batch
rewards = rewards.unsqueeze(-1)
@@ -257,12 +262,13 @@ def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Te
return policy_loss, q1_loss, q2_loss
def training_step(self, batch: Tuple[Tensor, Tensor], _):
- """Carries out a single step through the environment to update the replay buffer. Then calculates loss
- based on the minibatch recieved.
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
+ the minibatch recieved.
Args:
batch: current mini batch of replay data
_: batch number, not used
+
"""
policy_optim, q1_optim, q2_optim = self.optimizers()
policy_loss, q1_loss, q2_loss = self.loss(batch)
@@ -343,6 +349,7 @@ def add_model_specific_args(
Args:
arg_parser: parent parser
+
"""
arg_parser.add_argument(
"--sync_rate",
diff --git a/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py
index de463d3515..bebfbc7637 100644
--- a/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py
+++ b/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py
@@ -53,6 +53,7 @@ class VanillaPolicyGradient(LightningModule):
.. _`Vanilla Policy Gradient`:
https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf
+
"""
def __init__(
@@ -118,6 +119,7 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
q values
+
"""
return self.net(x)
@@ -128,6 +130,7 @@ def train_batch(
Returns:
yields a tuple of Lists containing tensors for states, actions and rewards of the batch.
+
"""
while True:
@@ -163,6 +166,7 @@ def compute_returns(self, rewards):
Returns:
list of discounted rewards
+
"""
reward = 0
returns = []
@@ -184,6 +188,7 @@ def loss(self, states, actions, scaled_rewards) -> Tensor:
Returns:
loss for the current batch
+
"""
logits = self.net(states)
@@ -202,8 +207,8 @@ def loss(self, states, actions, scaled_rewards) -> Tensor:
return policy_loss + entropy_loss
def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
- """Carries out a single step through the environment to update the replay buffer. Then calculates loss
- based on the minibatch recieved.
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
+ the minibatch recieved.
Args:
batch: current mini batch of replay data
@@ -211,6 +216,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict:
Returns:
Training loss and log metrics
+
"""
states, actions, scaled_rewards = batch
@@ -260,6 +266,7 @@ def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
Returns:
arg_parser with model specific cargs added
+
"""
arg_parser.add_argument("--entropy_beta", type=float, default=0.01, help="entropy value")
diff --git a/src/pl_bolts/models/self_supervised/__init__.py b/src/pl_bolts/models/self_supervised/__init__.py
index 7286fe9dc8..f501ee73a0 100644
--- a/src/pl_bolts/models/self_supervised/__init__.py
+++ b/src/pl_bolts/models/self_supervised/__init__.py
@@ -1,5 +1,5 @@
-"""These models have been pre-trained using self-supervised learning. The models can also be used without pre-
-training and overwritten for your own research.
+"""These models have been pre-trained using self-supervised learning. The models can also be used without pre- training
+and overwritten for your own research.
Here's an example for using these as pretrained models.
@@ -15,6 +15,7 @@
# use these in classification or any downstream task
classifications = classifier(representations)
+
"""
from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
from pl_bolts.models.self_supervised.byol.byol_module import BYOL
diff --git a/src/pl_bolts/models/self_supervised/amdim/amdim_module.py b/src/pl_bolts/models/self_supervised/amdim/amdim_module.py
index ae781b583a..295b2fe093 100644
--- a/src/pl_bolts/models/self_supervised/amdim/amdim_module.py
+++ b/src/pl_bolts/models/self_supervised/amdim/amdim_module.py
@@ -83,6 +83,7 @@ class AMDIM(LightningModule):
trainer.fit(model)
.. _AMDIM: https://arxiv.org/abs/1906.00910
+
"""
def __init__(
diff --git a/src/pl_bolts/models/self_supervised/byol/byol_module.py b/src/pl_bolts/models/self_supervised/byol/byol_module.py
index f438de2239..93bcd2742a 100644
--- a/src/pl_bolts/models/self_supervised/byol/byol_module.py
+++ b/src/pl_bolts/models/self_supervised/byol/byol_module.py
@@ -59,6 +59,7 @@ class BYOL(LightningModule):
--batch_size 32
.. _BYOL: https://arxiv.org/pdf/2006.07733.pdf
+
"""
def __init__(
@@ -92,6 +93,7 @@ def forward(self, x: Tensor) -> Tensor:
Args:
x (Tensor): sample to be encoded
+
"""
return self.online_network.encode(x)
@@ -131,6 +133,7 @@ def calculate_loss(self, v_online: Tensor, v_target: Tensor) -> Tensor:
Args:
v_online (Tensor): Online network view
v_target (Tensor): Target network view
+
"""
_, z1 = self.online_network(v_online)
h1 = self.predictor(z1)
diff --git a/src/pl_bolts/models/self_supervised/byol/models.py b/src/pl_bolts/models/self_supervised/byol/models.py
index 497fa112bb..4c9e901c2c 100644
--- a/src/pl_bolts/models/self_supervised/byol/models.py
+++ b/src/pl_bolts/models/self_supervised/byol/models.py
@@ -15,6 +15,7 @@ class MLP(nn.Module):
Note:
Default values for input, hidden, and output dimensions are based on values used in BYOL.
+
"""
def __init__(self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256) -> None:
@@ -32,8 +33,7 @@ def forward(self, x: Tensor) -> Tensor:
class SiameseArm(nn.Module):
- """SiameseArm consolidates the encoder and projector networks of BYOL's symmetric architecture into a single
- class.
+ """SiameseArm consolidates the encoder and projector networks of BYOL's symmetric architecture into a single class.
Args:
encoder (Union[str, nn.Module], optional): Online and target network encoder architecture.
@@ -43,6 +43,7 @@ class SiameseArm(nn.Module):
Defaults to 4096.
projector_out_dim (int, optional): Online and target network projector network output dimension.
Defaults to 256.
+
"""
def __init__(
@@ -67,10 +68,11 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
return y, z
def encode(self, x: Tensor) -> Tensor:
- """Returns the encoded representation of a view. This method does not calculate the projection as in the
- forward method.
+ """Returns the encoded representation of a view. This method does not calculate the projection as in the forward
+ method.
Args:
x (Tensor): sample to be encoded
+
"""
return self.encoder(x)[0]
diff --git a/src/pl_bolts/models/self_supervised/moco/moco_module.py b/src/pl_bolts/models/self_supervised/moco/moco_module.py
index cb9ddbdde1..86d01147c3 100644
--- a/src/pl_bolts/models/self_supervised/moco/moco_module.py
+++ b/src/pl_bolts/models/self_supervised/moco/moco_module.py
@@ -63,6 +63,7 @@ def dequeue_and_enqueue(self, x: Tensor) -> None:
Args:
x: A mini-batch of representations. The queue size has to be a multiple of the total number of
representations across all devices.
+
"""
# Gather representations from all GPUs into a [batch_size * world_size, num_features] tensor, in case of
# distributed training.
@@ -133,6 +134,7 @@ def __init__(
optimizer_params: Parameters to pass to the optimizer constructor.
lr_scheduler: Which learning rate scheduler class to use for training.
lr_scheduler_params: Parameters to pass to the learning rate scheduler constructor.
+
"""
super().__init__()
@@ -182,6 +184,7 @@ def forward(self, query_images: Tensor, key_images: Tensor) -> Tuple[Tensor, Ten
Returns:
A tuple of query and key representations.
+
"""
q = self.encoder_q(query_images)
if self.head_q is not None:
@@ -268,6 +271,7 @@ def _calculate_loss(self, images: Tensor, queue: RepresentationQueue) -> Tuple[T
images: A mini-batch of image pairs in a ``[batch_size, 2, num_channels, height, width]`` tensor.
queue: The queue that the query representations will be compared against. The key representations will be
added to the queue.
+
"""
if images.size(1) != 2:
raise ValueError(
diff --git a/src/pl_bolts/models/self_supervised/moco/utils.py b/src/pl_bolts/models/self_supervised/moco/utils.py
index ef12e633e1..116b52f979 100644
--- a/src/pl_bolts/models/self_supervised/moco/utils.py
+++ b/src/pl_bolts/models/self_supervised/moco/utils.py
@@ -17,6 +17,7 @@ def validate_batch(batch: Tuple[List[List[Tensor]], List[Any]]) -> Tensor:
Returns:
The input batch with images stacked into a single ``[N, 2, channels, height, width]`` tensor.
+
"""
images, targets = batch
@@ -119,6 +120,7 @@ def shuffle_batch(x: Tensor) -> Tuple[Tensor, Tensor]:
Returns:
The output tensor and a list of indices that gives the original order of the combined mini-batch. The output
tensor is the same size as the input tensor, but contains a random subset of the combined mini-batch.
+
"""
all_x = concatenate_all(x)
@@ -151,6 +153,7 @@ def sort_batch(x: Tensor, order: Tensor) -> Tensor:
Returns:
The subset of the combined mini-batch that corresponds to this device.
+
"""
all_x = concatenate_all(x)
diff --git a/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py
index ec90fb528a..9bd79906f6 100644
--- a/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py
+++ b/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py
@@ -58,6 +58,7 @@ class SimSiam(LightningModule):
--batch_size 32
.. _SimSiam: https://arxiv.org/pdf/2011.10566v1.pdf
+
"""
def __init__(
@@ -122,6 +123,7 @@ def calculate_loss(self, v_online: Tensor, v_target: Tensor) -> Tensor:
Args:
v_online (Tensor): Online network view
v_target (Tensor): Target network view
+
"""
_, z1 = self.online_network(v_online)
h1 = self.predictor(z1)
diff --git a/src/pl_bolts/models/self_supervised/ssl_finetuner.py b/src/pl_bolts/models/self_supervised/ssl_finetuner.py
index b745796d51..b5f74e7ab5 100644
--- a/src/pl_bolts/models/self_supervised/ssl_finetuner.py
+++ b/src/pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -9,8 +9,8 @@
class SSLFineTuner(LightningModule):
- """Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP
- with 1024 units.
+ """Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP with
+ 1024 units.
Example::
@@ -37,6 +37,7 @@ class SSLFineTuner(LightningModule):
# test
trainer.test(datamodule=dm)
+
"""
def __init__(
diff --git a/src/pl_bolts/models/self_supervised/swav/loss.py b/src/pl_bolts/models/self_supervised/swav/loss.py
index cd46997a11..f0e70a49b0 100644
--- a/src/pl_bolts/models/self_supervised/swav/loss.py
+++ b/src/pl_bolts/models/self_supervised/swav/loss.py
@@ -28,6 +28,7 @@ def __init__(
gpus: number of gpus per node used in training, passed to SwAV module
to manage the queue and select distributed sinkhorn
num_nodes: num_nodes: number of nodes to train on
+
"""
super().__init__()
self.temperature = temperature
diff --git a/src/pl_bolts/models/vision/pixel_cnn.py b/src/pl_bolts/models/vision/pixel_cnn.py
index 89237195f4..70be9dcd19 100644
--- a/src/pl_bolts/models/vision/pixel_cnn.py
+++ b/src/pl_bolts/models/vision/pixel_cnn.py
@@ -2,6 +2,7 @@
Implemented by: William Falcon Reference
: https: //arxiv.org/pdf/1905.09272.pdf (page 15 Accessed: May 14, 2020.
+
"""
from torch import nn
from torch.nn import functional as F # noqa: N812
diff --git a/src/pl_bolts/models/vision/segmentation.py b/src/pl_bolts/models/vision/segmentation.py
index 06f079da3a..550b6af8b7 100644
--- a/src/pl_bolts/models/vision/segmentation.py
+++ b/src/pl_bolts/models/vision/segmentation.py
@@ -33,6 +33,7 @@ class SemSegment(LightningModule):
# KITTI
python segmentation.py --data_dir /path/to/kitti/ --accelerator=gpu
+
"""
def __init__(
diff --git a/src/pl_bolts/models/vision/unet.py b/src/pl_bolts/models/vision/unet.py
index 0bc7c09d67..0ce0f73408 100644
--- a/src/pl_bolts/models/vision/unet.py
+++ b/src/pl_bolts/models/vision/unet.py
@@ -22,6 +22,7 @@ class UNet(nn.Module):
num_layers: Number of layers in each side of U-net (default 5)
features_start: Number of features in first layer (default 64)
bilinear: Whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
+
"""
def __init__(
@@ -94,8 +95,8 @@ def forward(self, x: Tensor) -> Tensor:
class Up(nn.Module):
- """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature
- map from contracting path, followed by DoubleConv."""
+ """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map
+ from contracting path, followed by DoubleConv."""
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False) -> None:
super().__init__()
diff --git a/src/pl_bolts/optimizers/lars.py b/src/pl_bolts/optimizers/lars.py
index 06a2e1d8aa..58ed202f24 100644
--- a/src/pl_bolts/optimizers/lars.py
+++ b/src/pl_bolts/optimizers/lars.py
@@ -103,6 +103,7 @@ def step(self, closure=None):
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
+
"""
loss = None
if closure is not None:
diff --git a/src/pl_bolts/optimizers/lr_scheduler.py b/src/pl_bolts/optimizers/lr_scheduler.py
index e094d0312f..158e4a618c 100644
--- a/src/pl_bolts/optimizers/lr_scheduler.py
+++ b/src/pl_bolts/optimizers/lr_scheduler.py
@@ -10,8 +10,8 @@
@under_review()
class LinearWarmupCosineAnnealingLR(_LRScheduler):
- """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
- and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
+ """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and
+ base_lr followed by a cosine annealing schedule between base_lr and eta_min.
.. warning::
It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
@@ -42,6 +42,7 @@ class LinearWarmupCosineAnnealingLR(_LRScheduler):
... scheduler.step(epoch)
... # train(...)
... # validate(...)
+
"""
def __init__(
diff --git a/src/pl_bolts/transforms/self_supervised/amdim_transforms.py b/src/pl_bolts/transforms/self_supervised/amdim_transforms.py
index b10f647b90..b3b72d641f 100644
--- a/src/pl_bolts/transforms/self_supervised/amdim_transforms.py
+++ b/src/pl_bolts/transforms/self_supervised/amdim_transforms.py
@@ -32,6 +32,7 @@ class AMDIMTrainTransformsCIFAR10:
transform = AMDIMTrainTransformsCIFAR10()
(view1, view2) = transform(x)
+
"""
def __init__(self) -> None:
@@ -74,6 +75,7 @@ class AMDIMEvalTransformsCIFAR10:
transform = AMDIMEvalTransformsCIFAR10()
(view1, view2) = transform(x)
+
"""
def __init__(self) -> None:
@@ -113,6 +115,7 @@ class AMDIMTrainTransformsSTL10:
transform = AMDIMTrainTransformsSTL10()
(view1, view2) = transform(x)
+
"""
def __init__(self, height: int = 64) -> None:
@@ -155,6 +158,7 @@ class AMDIMEvalTransformsSTL10:
transform = AMDIMTrainTransformsSTL10()
view1 = transform(x)
+
"""
def __init__(self, height: int = 64) -> None:
@@ -200,6 +204,7 @@ class AMDIMTrainTransformsImageNet128:
transform = AMDIMTrainTransformsSTL10()
(view1, view2) = transform(x)
+
"""
def __init__(self, height: int = 128) -> None:
@@ -245,6 +250,7 @@ class AMDIMEvalTransformsImageNet128:
transform = AMDIMEvalTransformsImageNet128()
view1 = transform(x)
+
"""
def __init__(self, height: int = 128) -> None:
diff --git a/src/pl_bolts/transforms/self_supervised/moco_transforms.py b/src/pl_bolts/transforms/self_supervised/moco_transforms.py
index f6117c39d8..c1e3654342 100644
--- a/src/pl_bolts/transforms/self_supervised/moco_transforms.py
+++ b/src/pl_bolts/transforms/self_supervised/moco_transforms.py
@@ -31,6 +31,7 @@ class MoCoTrainTransforms:
Args:
Example::
+
"""
def __init__(self, size: int, normalize: Union[str, Callable]) -> None:
@@ -79,6 +80,7 @@ class MoCo2TrainCIFAR10Transforms:
MoCo 2 augmentation:
https://arxiv.org/pdf/2003.04297.pdf
+
"""
def __init__(self, size: int = 32) -> None:
@@ -165,6 +167,7 @@ class MoCo2EvalImagenetTransforms:
"""Transforms for MoCo during training step.
https://arxiv.org/pdf/2003.04297.pdf
+
"""
def __init__(self, size: int = 128) -> None:
diff --git a/src/pl_bolts/transforms/self_supervised/simclr_transforms.py b/src/pl_bolts/transforms/self_supervised/simclr_transforms.py
index 672e2a94a1..7d670f422b 100644
--- a/src/pl_bolts/transforms/self_supervised/simclr_transforms.py
+++ b/src/pl_bolts/transforms/self_supervised/simclr_transforms.py
@@ -36,6 +36,7 @@ class SimCLRTrainDataTransform:
transform = SimCLRTrainDataTransform(input_height=32)
x = sample()
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used
+
"""
def __init__(
@@ -116,6 +117,7 @@ class SimCLREvalDataTransform(SimCLRTrainDataTransform):
transform = SimCLREvalDataTransform(input_height=32)
x = sample()
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used
+
"""
def __init__(
@@ -162,6 +164,7 @@ class SimCLRFinetuneTransform(SimCLRTrainDataTransform):
transform = SimCLREvalDataTransform(input_height=32)
x = sample()
xk = transform(x)
+
"""
def __init__(
diff --git a/src/pl_bolts/transforms/self_supervised/ssl_transforms.py b/src/pl_bolts/transforms/self_supervised/ssl_transforms.py
index 8db1069c32..5b3322eba7 100644
--- a/src/pl_bolts/transforms/self_supervised/ssl_transforms.py
+++ b/src/pl_bolts/transforms/self_supervised/ssl_transforms.py
@@ -13,10 +13,11 @@
@under_review()
class RandomTranslateWithReflect:
- """Translate image randomly Translate vertically and horizontally by n pixels where n is integer drawn
- uniformly independently for each axis from [-max_translation, max_translation].
+ """Translate image randomly Translate vertically and horizontally by n pixels where n is integer drawn uniformly
+ independently for each axis from [-max_translation, max_translation].
Fill the uncovered blank area with reflect padding.
+
"""
def __init__(self, max_translation) -> None:
diff --git a/src/pl_bolts/utils/arguments.py b/src/pl_bolts/utils/arguments.py
index 818fb6b410..0325d6575b 100644
--- a/src/pl_bolts/utils/arguments.py
+++ b/src/pl_bolts/utils/arguments.py
@@ -34,6 +34,7 @@ class LightningArgumentParser(ArgumentParser):
# args.data -> data args
# args.model -> model args
+
"""
def __init__(self, *args: Any, ignore_required_init_args: bool = True, **kwargs: Any) -> None:
diff --git a/src/pl_bolts/utils/semi_supervised.py b/src/pl_bolts/utils/semi_supervised.py
index 6448eca7e6..3fbcfae0e4 100644
--- a/src/pl_bolts/utils/semi_supervised.py
+++ b/src/pl_bolts/utils/semi_supervised.py
@@ -25,6 +25,7 @@ class Identity(torch.nn.Module):
model = resnet18()
model.fc = Identity()
+
"""
def __init__(self) -> None:
@@ -44,6 +45,7 @@ def balance_classes(
X: input features
y: mixed labels (ints)
batch_size: the ultimate batch size
+
"""
if not _SKLEARN_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `shuffle` function from `scikit-learn` which is not installed yet.")
@@ -100,8 +102,8 @@ def generate_half_labeled_batches(
larger_set_y: np.ndarray,
batch_size: int,
) -> Tuple[np.ndarray, np.ndarray]:
- """Given a labeled dataset and an unlabeled dataset, this function generates a joint pair where half the
- batches are labeled and the other half is not."""
+ """Given a labeled dataset and an unlabeled dataset, this function generates a joint pair where half the batches are
+ labeled and the other half is not."""
x = []
y = []
half_batch = batch_size // 2
diff --git a/src/pl_bolts/utils/stability.py b/src/pl_bolts/utils/stability.py
index 486542cbdf..3f0e4e6b08 100644
--- a/src/pl_bolts/utils/stability.py
+++ b/src/pl_bolts/utils/stability.py
@@ -48,8 +48,7 @@ def _raise_review_warning(message: str, stacklevel: int = 6) -> None:
def under_review():
- """The under_review decorator is used to indicate that a particular feature is not properly reviewed and tested
- yet.
+ """The under_review decorator is used to indicate that a particular feature is not properly reviewed and tested yet.
A callable or type that has been marked as under_review will give a ``UnderReviewWarning`` when it is called or
instantiated. This designation should be used following the description given in :ref:`stability`.
@@ -67,6 +66,7 @@ def under_review():
... MyExperimentalFeature() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
...
<...>
+
"""
def decorator(cls_or_callable: Union[Callable, Type], feature_name: Optional[str] = None, was_class: bool = False):
diff --git a/src/pl_bolts/utils/warnings.py b/src/pl_bolts/utils/warnings.py
index 116ef2173a..ef642ccc87 100644
--- a/src/pl_bolts/utils/warnings.py
+++ b/src/pl_bolts/utils/warnings.py
@@ -23,6 +23,7 @@ def warn_missing_pkg(
Returns:
Number of warning calls
+
"""
if not WARN_MISSING_PACKAGE:
return -1
diff --git a/tests/callbacks/verification/test_batch_gradient.py b/tests/callbacks/verification/test_batch_gradient.py
index e1a99bf037..bc25bf3b49 100644
--- a/tests/callbacks/verification/test_batch_gradient.py
+++ b/tests/callbacks/verification/test_batch_gradient.py
@@ -17,6 +17,7 @@ def __init__(self, mix_data=False) -> None:
"""Base model for testing.
The setting ``mix_data=True`` simulates a wrong implementation.
+
"""
super().__init__()
self.mix_data = mix_data
@@ -142,8 +143,7 @@ def test_batch_verification_raises_on_batch_size_1():
def test_batch_verification_calls_custom_input_output_mappings():
- """Test that batch gradient verification can support different input and outputs with user-provided
- mappings."""
+ """Test that batch gradient verification can support different input and outputs with user-provided mappings."""
model = MultipleInputModel()
def input_mapping(inputs):
diff --git a/tests/conftest.py b/tests/conftest.py
index 7637c059b2..b1340ac335 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -79,6 +79,7 @@ def restore_signal_handlers(): # noqa: PT004
"""Ensures that signal handlers get restored before the next test runs.
This is a safety net for tests that don't run Trainer's teardown.
+
"""
valid_signals = SignalConnector._valid_signals()
if not _IS_WINDOWS:
diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py
index a981cac7e5..7dce4e3f01 100644
--- a/tests/datamodules/test_datamodules.py
+++ b/tests/datamodules/test_datamodules.py
@@ -26,8 +26,7 @@ def test_dev_datasets(datadir):
def _create_synth_cityscapes_dataset(path_dir):
- """Create synthetic dataset with random images, just to simulate that the dataset have been already
- downloaded."""
+ """Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
non_existing_citites = ["dummy_city_1", "dummy_city_2"]
fine_labels_dir = Path(path_dir) / "gtFine"
images_dir = Path(path_dir) / "leftImg8bit"
@@ -156,6 +155,7 @@ def test_emnist_datamodules_with_strict_val_split(datadir, catch_warnings, dm_cl
"""Test EMNIST datamodules when strict_val_split is specified to use the validation set defined in the paper.
Refer to https://arxiv.org/abs/1702.05373 for `expected_val_split` values.
+
"""
if expected_val_split is None:
diff --git a/tests/datamodules/test_experience_sources.py b/tests/datamodules/test_experience_sources.py
index 22f71c1c69..78aeed9045 100644
--- a/tests/datamodules/test_experience_sources.py
+++ b/tests/datamodules/test_experience_sources.py
@@ -173,8 +173,8 @@ def test_source_is_done_short_episode(self):
break
def test_source_is_done_2step_episode(self):
- """Test that when done and the history is full, return the full history, then start to return the tail of
- the history."""
+ """Test that when done and the history is full, return the full history, then start to return the tail of the
+ history."""
self.env = [self.mock_env]
self.source = ExperienceSource(self.env, self.agent, n_steps=2)
@@ -289,6 +289,7 @@ def test_source_discounted_return(self):
"""Tests that the source returns a single experience with discounted rewards.
discounted returns: G(t) = R(t+1) + γ*R(t+2) + γ^2*R(t+3) ... + γ^N-1*R(t+N)
+
"""
self.source = DiscountedExperienceSource(self.env1, self.agent, n_steps=self.n_steps)
diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py
index ca730899bb..cb98a01906 100644
--- a/tests/helpers/boring_model.py
+++ b/tests/helpers/boring_model.py
@@ -83,6 +83,7 @@ def training_step(...):
model = BaseTestModel()
model.training_epoch_end = None
+
"""
super().__init__()
self.layer = torch.nn.Linear(32, 2)
diff --git a/tests/models/rl/unit/test_memory.py b/tests/models/rl/unit/test_memory.py
index 970495363c..c199acde39 100644
--- a/tests/models/rl/unit/test_memory.py
+++ b/tests/models/rl/unit/test_memory.py
@@ -174,8 +174,8 @@ def setUp(self) -> None:
self.experience03 = Experience(self.state_02, self.action_02, self.reward_02, self.done_02, self.next_state_02)
def test_append_single_experience_less_than_n(self):
- """If a single experience is added and n > 1 nothing should be added to the buffer as it is waiting
- experiences to equal n."""
+ """If a single experience is added and n > 1 nothing should be added to the buffer as it is waiting experiences
+ to equal n."""
assert len(self.buffer) == 0
self.buffer.append(self.experience01)
@@ -183,8 +183,8 @@ def test_append_single_experience_less_than_n(self):
assert len(self.buffer) == 0
def test_append_single_experience(self):
- """If a single experience is added and n > 1 nothing should be added to the buffer as it is waiting
- experiences to equal n."""
+ """If a single experience is added and n > 1 nothing should be added to the buffer as it is waiting experiences
+ to equal n."""
assert len(self.buffer) == 0
self.buffer.append(self.experience01)