
Add annotations from __future__ #239

Merged · 3 commits · Feb 10, 2023
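In short, this PR adds `from __future__ import annotations` (PEP 563, postponed evaluation of annotations) to each module and rewrites the type hints using the PEP 604 union operator (`X | None` instead of `Optional[X]`) and PEP 585 builtin generics (`tuple[...]`, `list[...]` instead of `typing.Tuple`, `typing.List`). A minimal sketch of why the future import matters — the `sample` function below is hypothetical, not taken from this diff:

```python
from __future__ import annotations  # must be the first statement in the module

import torch

# Under PEP 563, annotations are stored as strings and never evaluated at
# definition time, so the PEP 604/585 spellings below only need to parse.
# This keeps the module importable on Python 3.7/3.8, where evaluating
# `bool | None` or `tuple[...]` would raise a TypeError.
def sample(
    noise: torch.Tensor, verbose: bool | None = True
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
    return noise
```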
2 changes: 2 additions & 0 deletions generative/__init__.py
@@ -9,4 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from .version import __version__
2 changes: 2 additions & 0 deletions generative/engines/__init__.py
@@ -9,4 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from .trainer import AdversarialTrainer
2 changes: 2 additions & 0 deletions generative/inferers/__init__.py
@@ -9,4 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from .inferer import DiffusionInferer, LatentDiffusionInferer
67 changes: 34 additions & 33 deletions generative/inferers/inferer.py
@@ -9,9 +9,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import math
-from typing import Callable, List, Optional, Tuple, Union
+from collections.abc import Callable
 
 import torch
 import torch.nn as nn
@@ -41,7 +42,7 @@ def __call__(
         diffusion_model: Callable[..., torch.Tensor],
         noise: torch.Tensor,
         timesteps: torch.Tensor,
-        condition: Optional[torch.Tensor] = None,
+        condition: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Implements the forward pass for a supervised training iteration.
@@ -62,12 +63,12 @@ def sample(
         self,
         input_noise: torch.Tensor,
         diffusion_model: Callable[..., torch.Tensor],
-        scheduler: Optional[Callable[..., torch.Tensor]] = None,
-        save_intermediates: Optional[bool] = False,
-        intermediate_steps: Optional[int] = 100,
-        conditioning: Optional[torch.Tensor] = None,
-        verbose: Optional[bool] = True,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        scheduler: Callable[..., torch.Tensor] | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        verbose: bool | None = True,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            input_noise: random noise, of the same shape as the desired sample.
@@ -107,13 +108,13 @@ def get_likelihood(
         self,
         inputs: torch.Tensor,
         diffusion_model: Callable[..., torch.Tensor],
-        scheduler: Optional[Callable[..., torch.Tensor]] = None,
-        save_intermediates: Optional[bool] = False,
-        conditioning: Optional[torch.Tensor] = None,
-        original_input_range: Optional[Tuple] = (0, 255),
-        scaled_input_range: Optional[Tuple] = (0, 1),
-        verbose: Optional[bool] = True,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        scheduler: Callable[..., torch.Tensor] | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        original_input_range: tuple | None = (0, 255),
+        scaled_input_range: tuple | None = (0, 1),
+        verbose: bool | None = True,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Computes the likelihoods for an input.
 
@@ -228,8 +229,8 @@ def _get_decoder_log_likelihood(
         inputs: torch.Tensor,
         means: torch.Tensor,
         log_scales: torch.Tensor,
-        original_input_range: Optional[Tuple] = (0, 255),
-        scaled_input_range: Optional[Tuple] = (0, 1),
+        original_input_range: tuple | None = (0, 255),
+        scaled_input_range: tuple | None = (0, 1),
    ) -> torch.Tensor:
        """
        Compute the log-likelihood of a Gaussian distribution discretizing to a
@@ -287,7 +288,7 @@ def __call__(
         diffusion_model: Callable[..., torch.Tensor],
         noise: torch.Tensor,
         timesteps: torch.Tensor,
-        condition: Optional[torch.Tensor] = None,
+        condition: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Implements the forward pass for a supervised training iteration.
@@ -314,12 +315,12 @@ def sample(
         input_noise: torch.Tensor,
         autoencoder_model: Callable[..., torch.Tensor],
         diffusion_model: Callable[..., torch.Tensor],
-        scheduler: Optional[Callable[..., torch.Tensor]] = None,
-        save_intermediates: Optional[bool] = False,
-        intermediate_steps: Optional[int] = 100,
-        conditioning: Optional[torch.Tensor] = None,
-        verbose: Optional[bool] = True,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        scheduler: Callable[..., torch.Tensor] | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        verbose: bool | None = True,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            input_noise: random noise, of the same shape as the desired latent representation.
@@ -367,15 +368,15 @@ def get_likelihood(
         inputs: torch.Tensor,
         autoencoder_model: Callable[..., torch.Tensor],
         diffusion_model: Callable[..., torch.Tensor],
-        scheduler: Optional[Callable[..., torch.Tensor]] = None,
-        save_intermediates: Optional[bool] = False,
-        conditioning: Optional[torch.Tensor] = None,
-        original_input_range: Optional[Tuple] = (0, 255),
-        scaled_input_range: Optional[Tuple] = (0, 1),
-        verbose: Optional[bool] = True,
-        resample_latent_likelihoods: Optional[bool] = False,
-        resample_interpolation_mode: Optional[str] = "bilinear",
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        scheduler: Callable[..., torch.Tensor] | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        original_input_range: tuple | None = (0, 255),
+        scaled_input_range: tuple | None = (0, 1),
+        verbose: bool | None = True,
+        resample_latent_likelihoods: bool | None = False,
+        resample_interpolation_mode: str | None = "bilinear",
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Computes the likelihoods of the latent representations of the input.
 
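Beyond the `Optional`/`Union` rewrites, this file also swaps `typing.Callable` for `collections.abc.Callable`; the `typing` aliases for the container ABCs are deprecated from Python 3.9 onwards (PEP 585). A small sketch of the resulting style, using a hypothetical `run` helper rather than anything in this diff:

```python
from __future__ import annotations

from collections.abc import Callable

import torch

# `Callable[..., torch.Tensor]` reads: any argument list, returns a Tensor.
# With postponed evaluation the annotation is never executed at runtime, so
# subscripting collections.abc.Callable is safe even on interpreters < 3.9.
def run(model: Callable[..., torch.Tensor], x: torch.Tensor) -> torch.Tensor:
    return model(x)
```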
2 changes: 2 additions & 0 deletions generative/losses/__init__.py
@@ -9,4 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from .spectral_loss import JukeboxLoss
11 changes: 6 additions & 5 deletions generative/losses/adversarial_loss.py
@@ -9,8 +9,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import warnings
-from typing import List, Optional, Union
 
 import torch
 from monai.networks.layers.utils import get_act_layer
@@ -45,7 +46,7 @@ class PatchAdversarialLoss(_Loss):
 
     def __init__(
         self,
-        reduction: Union[LossReduction, str] = LossReduction.MEAN,
+        reduction: LossReduction | str = LossReduction.MEAN,
         criterion: str = AdversarialCriterions.LEAST_SQUARE.value,
     ) -> None:
         super().__init__(reduction=LossReduction(reduction).value)
@@ -101,8 +102,8 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor:
         return zero_label_tensor.expand_as(input)
 
     def forward(
-        self, input: Union[torch.FloatTensor, list], target_is_real: bool, for_discriminator: bool
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool
+    ) -> torch.Tensor | list[torch.Tensor]:
 
         """
 
@@ -152,7 +153,7 @@ def forward(
 
         return loss
 
-    def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> Optional[torch.Tensor]:
+    def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> torch.Tensor | None:
        if (
            self.criterion == AdversarialCriterions.BCE.value
            or self.criterion == AdversarialCriterions.LEAST_SQUARE.value
4 changes: 2 additions & 2 deletions generative/losses/perceptual.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from __future__ import annotations
 
 import torch
 import torch.nn as nn
@@ -72,7 +72,7 @@ def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatia
             spatial_axis: spatial axis to obtain the 2D slices.
         """
 
-        def batchify_axis(x: torch.Tensor, fake_3d_perm: Tuple) -> torch.Tensor:
+        def batchify_axis(x: torch.Tensor, fake_3d_perm: tuple) -> torch.Tensor:
            """
            Transform slices from one spatial axis into different instances in the batch.
            """
6 changes: 3 additions & 3 deletions generative/losses/spectral_loss.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from __future__ import annotations
 
 import torch
 import torch.nn.functional as F
@@ -42,9 +42,9 @@ class JukeboxLoss(_Loss):
     def __init__(
         self,
         spatial_dims: int,
-        fft_signal_size: Optional[Tuple[int]] = None,
+        fft_signal_size: tuple[int] | None = None,
         fft_norm: str = "ortho",
-        reduction: Union[LossReduction, str] = LossReduction.MEAN,
+        reduction: LossReduction | str = LossReduction.MEAN,
     ) -> None:
         super().__init__(reduction=LossReduction(reduction).value)
 
2 changes: 2 additions & 0 deletions generative/metrics/__init__.py
@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from .fid import FID
 from .mmd import MMD
 from .ms_ssim import MSSSIM
4 changes: 2 additions & 2 deletions generative/metrics/fid.py
@@ -8,12 +8,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+#
 # =========================================================================
 # Adapted from https://github.com/photosynthesis-team/piq
 # which has the following license:
 # https://github.com/photosynthesis-team/piq/blob/master/LICENSE
-
+#
 # Copyright 2023 photosynthesis-team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
11 changes: 7 additions & 4 deletions generative/metrics/mmd.py
@@ -8,7 +8,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Optional, Union
+
+from __future__ import annotations
+
+from collections.abc import Callable
 
 import torch
 from monai.metrics.regression import RegressionMetric
@@ -40,9 +43,9 @@ class MMD(RegressionMetric):
 
     def __init__(
         self,
-        y_transform: Optional[Callable] = None,
-        y_pred_transform: Optional[Callable] = None,
-        reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
+        y_transform: Callable | None = None,
+        y_pred_transform: Callable | None = None,
+        reduction: MetricReduction | str = MetricReduction.MEAN,
         get_not_nans: bool = False,
     ) -> None:
         super().__init__(reduction=reduction, get_not_nans=get_not_nans)
11 changes: 6 additions & 5 deletions generative/metrics/ms_ssim.py
@@ -8,7 +8,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Union
+
+from __future__ import annotations
 
 import torch
 import torch.nn.functional as F
@@ -45,13 +46,13 @@ class MSSSIM(RegressionMetric):
 
     def __init__(
         self,
-        data_range: Union[torch.Tensor, float],
+        data_range: torch.Tensor | float,
         win_size: int = 7,
         k1: float = 0.01,
         k2: float = 0.03,
         spatial_dims: int = 2,
-        weights: Optional[List] = None,
-        reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
+        weights: list | None = None,
+        reduction: MetricReduction | str = MetricReduction.MEAN,
     ) -> None:
         super().__init__()
 
@@ -114,7 +115,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         )
 
         levels = self.weights.shape[0]
-        mcs_list: List[torch.Tensor] = []
+        mcs_list: list[torch.Tensor] = []
         for i in range(levels):
             ssim, cs = self.SSIM._compute_metric_and_contrast(x, y)
 
8 changes: 4 additions & 4 deletions generative/networks/layers/vector_quantizer.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Sequence, Tuple
+from typing import Sequence
 
 import torch
 from torch import nn
@@ -84,7 +84,7 @@ def __init__(
         )
 
     @torch.cuda.amp.autocast(enabled=False)
-    def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def quantize(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.
 
@@ -158,7 +158,7 @@ def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Ten
         else:
             pass
 
-    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         flat_input, encodings, encoding_indices = self.quantize(inputs)
         quantized = self.embed(encoding_indices)
 
@@ -205,7 +205,7 @@ def __init__(self, quantizer: torch.nn.Module = None):
 
         self.perplexity: torch.Tensor = torch.rand(1)
 
-    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         quantized, loss, encoding_indices = self.quantizer(inputs)
 
         # Perplexity calculations
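Note that this file keeps `from typing import Sequence` and, unlike the other modules, gains no `__future__` import in this diff, so the `tuple[...]` return annotations above are evaluated when each `def` executes. A short sketch of the distinction (hypothetical module, not from this PR):

```python
# Without postponed evaluation, this annotation runs at definition time and
# needs Python >= 3.9, where builtin types became subscriptable (PEP 585):
def quantize_shapes() -> tuple[int, int, int]:
    return 1, 2, 3

# With `from __future__ import annotations` as the module's first statement,
# the same annotation would be stored as the string "tuple[int, int, int]"
# and the module would also import cleanly on Python 3.7/3.8.
```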
2 changes: 2 additions & 0 deletions generative/networks/nets/__init__.py
@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from .autoencoderkl import AutoencoderKL
 from .diffusion_model_unet import DiffusionModelUNet
 from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator