From c614e029968c98f784bb88499011d1dc4a41b08d Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Thu, 9 Feb 2023 23:20:04 +0000 Subject: [PATCH 1/3] Add from __future__ import annotations Signed-off-by: Walter Hugo Lopez Pinaya --- generative/__init__.py | 2 + generative/engines/__init__.py | 2 + generative/inferers/__init__.py | 2 + generative/inferers/inferer.py | 68 ++++++++-------- generative/losses/__init__.py | 2 + generative/losses/adversarial_loss.py | 11 +-- generative/losses/perceptual.py | 5 +- generative/losses/spectral_loss.py | 7 +- generative/metrics/__init__.py | 2 + generative/metrics/mmd.py | 10 ++- generative/metrics/ms_ssim.py | 11 +-- generative/networks/layers/__init__.py | 2 + .../networks/layers/vector_quantizer.py | 10 ++- generative/networks/nets/__init__.py | 2 + generative/networks/nets/autoencoderkl.py | 10 ++- .../networks/nets/diffusion_model_unet.py | 80 ++++++++++--------- .../networks/nets/patchgan_discriminator.py | 30 +++---- generative/networks/nets/transformer.py | 5 +- generative/networks/nets/vqvae.py | 24 +++--- generative/networks/schedulers/__init__.py | 2 + generative/networks/schedulers/ddim.py | 9 ++- generative/networks/schedulers/ddpm.py | 11 +-- generative/networks/schedulers/pndm.py | 8 +- generative/utils/__init__.py | 2 + generative/utils/enums.py | 2 + generative/utils/ordering.py | 13 +-- setup.cfg | 5 +- setup.py | 2 + tests/test_adversarial.py | 2 + tests/test_autoencoderkl.py | 2 + tests/test_compute_fid_metric.py | 2 + tests/test_compute_mmd_metric.py | 2 + tests/test_compute_ms_ssim_metric.py | 2 + tests/test_diffusion_inferer.py | 2 + tests/test_diffusion_model_unet.py | 2 + .../test_integration_workflows_adversarial.py | 2 + tests/test_latent_diffusion_inferer.py | 2 + tests/test_ordering.py | 6 +- tests/test_patch_gan.py | 2 + tests/test_perceptual_loss.py | 2 + tests/test_scheduler_ddim.py | 2 + tests/test_scheduler_ddpm.py | 2 + tests/test_scheduler_pndm.py | 2 + tests/test_spectral_loss.py | 2 + tests/test_transformer.py | 2 + tests/test_vector_quantizer.py | 2 + tests/test_vqvae.py | 2 + 47 files changed, 233 insertions(+), 148 deletions(-) diff --git a/generative/__init__.py b/generative/__init__.py index 43d8445e..b0822e5e 100644 --- a/generative/__init__.py +++ b/generative/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .version import __version__ diff --git a/generative/engines/__init__.py b/generative/engines/__init__.py index 4c19abbb..f76c669d 100644 --- a/generative/engines/__init__.py +++ b/generative/engines/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .trainer import AdversarialTrainer diff --git a/generative/inferers/__init__.py b/generative/inferers/__init__.py index c77668f1..94775e76 100644 --- a/generative/inferers/__init__.py +++ b/generative/inferers/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .inferer import DiffusionInferer, LatentDiffusionInferer diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index f65bdb20..f758594f 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -10,8 +10,10 @@ # limitations under the License. 
+from __future__ import annotations + import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable import torch import torch.nn as nn @@ -41,7 +43,7 @@ def __call__( diffusion_model: Callable[..., torch.Tensor], noise: torch.Tensor, timesteps: torch.Tensor, - condition: Optional[torch.Tensor] = None, + condition: torch.Tensor | None = None, ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. @@ -62,12 +64,12 @@ def sample( self, input_noise: torch.Tensor, diffusion_model: Callable[..., torch.Tensor], - scheduler: Optional[Callable[..., torch.Tensor]] = None, - save_intermediates: Optional[bool] = False, - intermediate_steps: Optional[int] = 100, - conditioning: Optional[torch.Tensor] = None, - verbose: Optional[bool] = True, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + verbose: bool | None = True, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: input_noise: random noise, of the same shape as the desired sample. @@ -107,13 +109,13 @@ def get_likelihood( self, inputs: torch.Tensor, diffusion_model: Callable[..., torch.Tensor], - scheduler: Optional[Callable[..., torch.Tensor]] = None, - save_intermediates: Optional[bool] = False, - conditioning: Optional[torch.Tensor] = None, - original_input_range: Optional[Tuple] = (0, 255), - scaled_input_range: Optional[Tuple] = (0, 1), - verbose: Optional[bool] = True, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool | None = True, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the likelihoods for an input. @@ -228,8 +230,8 @@ def _get_decoder_log_likelihood( inputs: torch.Tensor, means: torch.Tensor, log_scales: torch.Tensor, - original_input_range: Optional[Tuple] = (0, 255), - scaled_input_range: Optional[Tuple] = (0, 1), + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), ) -> torch.Tensor: """ Compute the log-likelihood of a Gaussian distribution discretizing to a @@ -287,7 +289,7 @@ def __call__( diffusion_model: Callable[..., torch.Tensor], noise: torch.Tensor, timesteps: torch.Tensor, - condition: Optional[torch.Tensor] = None, + condition: torch.Tensor | None = None, ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. 
@@ -314,12 +316,12 @@ def sample( input_noise: torch.Tensor, autoencoder_model: Callable[..., torch.Tensor], diffusion_model: Callable[..., torch.Tensor], - scheduler: Optional[Callable[..., torch.Tensor]] = None, - save_intermediates: Optional[bool] = False, - intermediate_steps: Optional[int] = 100, - conditioning: Optional[torch.Tensor] = None, - verbose: Optional[bool] = True, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + verbose: bool | None = True, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: input_noise: random noise, of the same shape as the desired latent representation. @@ -367,15 +369,15 @@ def get_likelihood( inputs: torch.Tensor, autoencoder_model: Callable[..., torch.Tensor], diffusion_model: Callable[..., torch.Tensor], - scheduler: Optional[Callable[..., torch.Tensor]] = None, - save_intermediates: Optional[bool] = False, - conditioning: Optional[torch.Tensor] = None, - original_input_range: Optional[Tuple] = (0, 255), - scaled_input_range: Optional[Tuple] = (0, 1), - verbose: Optional[bool] = True, - resample_latent_likelihoods: Optional[bool] = False, - resample_interpolation_mode: Optional[str] = "bilinear", - ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + scheduler: Callable[..., torch.Tensor] | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool | None = True, + resample_latent_likelihoods: bool | None = False, + resample_interpolation_mode: str | None = "bilinear", + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the likelihoods of the latent representations of the input. diff --git a/generative/losses/__init__.py b/generative/losses/__init__.py index 211676dc..49fe1192 100644 --- a/generative/losses/__init__.py +++ b/generative/losses/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .spectral_loss import JukeboxLoss diff --git a/generative/losses/adversarial_loss.py b/generative/losses/adversarial_loss.py index 6c8bb07c..9189d89d 100644 --- a/generative/losses/adversarial_loss.py +++ b/generative/losses/adversarial_loss.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import warnings -from typing import List, Optional, Union import torch from monai.networks.layers.utils import get_act_layer @@ -45,7 +46,7 @@ class PatchAdversarialLoss(_Loss): def __init__( self, - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, criterion: str = AdversarialCriterions.LEAST_SQUARE.value, ) -> None: super().__init__(reduction=LossReduction(reduction).value) @@ -101,8 +102,8 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: return zero_label_tensor.expand_as(input) def forward( - self, input: Union[torch.FloatTensor, list], target_is_real: bool, for_discriminator: bool - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool + ) -> torch.Tensor | list[torch.Tensor]: """ @@ -152,7 +153,7 @@ def forward( return loss - def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> Optional[torch.Tensor]: + def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> torch.Tensor | None: if ( self.criterion == AdversarialCriterions.BCE.value or self.criterion == AdversarialCriterions.LEAST_SQUARE.value diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 5a3640b8..9160265b 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from __future__ import annotations + import torch import torch.nn as nn @@ -72,7 +73,7 @@ def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatia spatial_axis: spatial axis to obtain the 2D slices. """ - def batchify_axis(x: torch.Tensor, fake_3d_perm: Tuple) -> torch.Tensor: + def batchify_axis(x: torch.Tensor, fake_3d_perm: tuple) -> torch.Tensor: """ Transform slices from one spatial axis into different instances in the batch. """ diff --git a/generative/losses/spectral_loss.py b/generative/losses/spectral_loss.py index 649a8c86..3224d43d 100644 --- a/generative/losses/spectral_loss.py +++ b/generative/losses/spectral_loss.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from __future__ import annotations + import torch import torch.nn.functional as F @@ -42,9 +43,9 @@ class JukeboxLoss(_Loss): def __init__( self, spatial_dims: int, - fft_signal_size: Optional[Tuple[int]] = None, + fft_signal_size: tuple[int] | None = None, fft_norm: str = "ortho", - reduction: Union[LossReduction, str] = LossReduction.MEAN, + reduction: LossReduction | str = LossReduction.MEAN, ) -> None: super().__init__(reduction=LossReduction(reduction).value) diff --git a/generative/metrics/__init__.py b/generative/metrics/__init__.py index bd4b9acc..e05137d8 100644 --- a/generative/metrics/__init__.py +++ b/generative/metrics/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from .fid import FID from .mmd import MMD from .ms_ssim import MSSSIM diff --git a/generative/metrics/mmd.py b/generative/metrics/mmd.py index 6713ea12..b781b698 100644 --- a/generative/metrics/mmd.py +++ b/generative/metrics/mmd.py @@ -8,7 +8,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union +from __future__ import annotations + +from typing import Callable import torch from monai.metrics.regression import RegressionMetric @@ -40,9 +42,9 @@ class MMD(RegressionMetric): def __init__( self, - y_transform: Optional[Callable] = None, - y_pred_transform: Optional[Callable] = None, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + y_transform: Callable | None = None, + y_pred_transform: Callable | None = None, + reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 769a7ff3..456b05b4 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -8,7 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union +from __future__ import annotations + import torch import torch.nn.functional as F @@ -45,13 +46,13 @@ class MSSSIM(RegressionMetric): def __init__( self, - data_range: Union[torch.Tensor, float], + data_range: torch.Tensor | float, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2, - weights: Optional[List] = None, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + weights: list | None = None, + reduction: MetricReduction | str = MetricReduction.MEAN, ) -> None: super().__init__() @@ -114,7 +115,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ) levels = self.weights.shape[0] - mcs_list: List[torch.Tensor] = [] + mcs_list: list[torch.Tensor] = [] for i in range(levels): ssim, cs = self.SSIM._compute_metric_and_contrast(x, y) diff --git a/generative/networks/layers/__init__.py b/generative/networks/layers/__init__.py index bd6e831f..51b907c1 100644 --- a/generative/networks/layers/__init__.py +++ b/generative/networks/layers/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .vector_quantizer import EMAQuantizer, VectorQuantizer diff --git a/generative/networks/layers/vector_quantizer.py b/generative/networks/layers/vector_quantizer.py index 358b95f3..bcbffa4b 100644 --- a/generative/networks/layers/vector_quantizer.py +++ b/generative/networks/layers/vector_quantizer.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Sequence, Tuple +from __future__ import annotations + +from typing import Sequence import torch from torch import nn @@ -84,7 +86,7 @@ def __init__( ) @torch.cuda.amp.autocast(enabled=False) - def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def quantize(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. @@ -158,7 +160,7 @@ def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Ten else: pass - def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: flat_input, encodings, encoding_indices = self.quantize(inputs) quantized = self.embed(encoding_indices) @@ -205,7 +207,7 @@ def __init__(self, quantizer: torch.nn.Module = None): self.perplexity: torch.Tensor = torch.rand(1) - def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: quantized, loss, encoding_indices = self.quantizer(inputs) # Perplexity calculations diff --git a/generative/networks/nets/__init__.py b/generative/networks/nets/__init__.py index ed15c2f8..593ea54c 100644 --- a/generative/networks/nets/__init__.py +++ b/generative/networks/nets/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .autoencoderkl import AutoencoderKL from .diffusion_model_unet import DiffusionModelUNet from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index d13b26ad..e9ce57f6 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import importlib.util import math -from typing import Optional, Sequence, Tuple +from typing import Sequence import torch import torch.nn as nn @@ -190,7 +192,7 @@ def __init__( self, spatial_dims: int, num_channels: int, - num_head_channels: Optional[int] = None, + num_head_channels: int | None = None, norm_num_groups: int = 32, norm_eps: float = 1e-6, ) -> None: @@ -655,7 +657,7 @@ def __init__( ) self.latent_channels = latent_channels - def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. 
@@ -717,7 +719,7 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: dec = self.decoder(z) return dec - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: z_mu, z_sigma = self.encode(x) z = self.sampling(z_mu, z_sigma) reconstruction = self.decode(z) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 76caa08b..46aaab58 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -29,9 +29,11 @@ # limitations under the License. # ========================================================================= +from __future__ import annotations + import importlib.util import math -from typing import List, Optional, Sequence, Tuple, Union +from typing import Sequence import torch import torch.nn.functional as F @@ -82,7 +84,7 @@ class CrossAttention(nn.Module): def __init__( self, query_dim: int, - cross_attention_dim: Optional[int] = None, + cross_attention_dim: int | None = None, num_attention_heads: int = 8, num_head_channels: int = 64, dropout: float = 0.0, @@ -147,7 +149,7 @@ def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor x = torch.bmm(attention_probs, value) return x - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: query = self.to_q(x) context = context if context is not None else x key = self.to_k(context) @@ -188,7 +190,7 @@ def __init__( num_attention_heads: int, num_head_channels: int, dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, + cross_attention_dim: int | None = None, upcast_attention: bool = False, ) -> None: super().__init__() @@ -212,7 +214,7 @@ def __init__( self.norm2 = nn.LayerNorm(num_channels) self.norm3 = nn.LayerNorm(num_channels) - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: # 1. 
Self-Attention x = self.attn1(self.norm1(x)) + x @@ -252,7 +254,7 @@ def __init__( dropout: float = 0.0, norm_num_groups: int = 32, norm_eps: float = 1e-6, - cross_attention_dim: Optional[int] = None, + cross_attention_dim: int | None = None, upcast_attention: bool = False, ) -> None: super().__init__() @@ -298,7 +300,7 @@ def __init__( ) ) - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: # note: if no context is given, cross-attention defaults to self-attention batch = channel = height = width = depth = -1 if self.spatial_dims == 2: @@ -347,7 +349,7 @@ def __init__( self, spatial_dims: int, num_channels: int, - num_head_channels: Optional[int] = None, + num_head_channels: int | None = None, norm_num_groups: int = 32, norm_eps: float = 1e-6, ) -> None: @@ -482,7 +484,7 @@ class Downsample(nn.Module): """ def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: Optional[int] = None, padding: int = 1 + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 ) -> None: super().__init__() self.num_channels = num_channels @@ -502,7 +504,7 @@ def __init__( assert self.num_channels == self.out_channels self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) - def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: del emb assert x.shape[1] == self.num_channels return self.op(x) @@ -522,7 +524,7 @@ class Upsample(nn.Module): """ def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: Optional[int] = None, padding: int = 1 + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 ) -> None: super().__init__() self.num_channels = num_channels @@ -541,7 +543,7 @@ def __init__( else: self.conv = None - def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: del emb assert x.shape[1] == self.num_channels @@ -582,7 +584,7 @@ def __init__( spatial_dims: int, in_channels: int, temb_channels: int, - out_channels: Optional[int] = None, + out_channels: int | None = None, up: bool = False, down: bool = False, norm_num_groups: int = 32, @@ -745,8 +747,8 @@ def __init__( self.downsampler = None def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: del context output_states = [] @@ -847,8 +849,8 @@ def __init__( self.downsampler = None def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: del context output_states = [] @@ -899,7 +901,7 @@ def __init__( downsample_padding: int = 1, num_head_channels: int = 1, transformer_num_layers: int = 1, - cross_attention_dim: Optional[int] = None, + cross_attention_dim: int | None = None, upcast_attention: bool = False, ) -> None: super().__init__() 
@@ -961,8 +963,8 @@ def __init__( self.downsampler = None def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: output_states = [] for resnet, attn in zip(self.resnets, self.attentions): @@ -1028,7 +1030,7 @@ def __init__( ) def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: Optional[torch.Tensor] = None + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None ) -> torch.Tensor: del context hidden_states = self.resnet_1(hidden_states, temb) @@ -1063,7 +1065,7 @@ def __init__( norm_eps: float = 1e-6, num_head_channels: int = 1, transformer_num_layers: int = 1, - cross_attention_dim: Optional[int] = None, + cross_attention_dim: int | None = None, upcast_attention: bool = False, ) -> None: super().__init__() @@ -1098,7 +1100,7 @@ def __init__( ) def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: Optional[torch.Tensor] = None + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None ) -> torch.Tensor: hidden_states = self.resnet_1(hidden_states, temb) hidden_states = self.attention(hidden_states, context=context) @@ -1179,9 +1181,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - res_hidden_states_list: List[torch.Tensor], + res_hidden_states_list: list[torch.Tensor], temb: torch.Tensor, - context: Optional[torch.Tensor] = None, + context: torch.Tensor | None = None, ) -> torch.Tensor: del context for resnet in self.resnets: @@ -1284,9 +1286,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - res_hidden_states_list: List[torch.Tensor], + res_hidden_states_list: list[torch.Tensor], temb: torch.Tensor, - context: Optional[torch.Tensor] = None, + context: torch.Tensor | None = None, ) -> torch.Tensor: del context for resnet, attn in zip(self.resnets, self.attentions): @@ -1339,7 +1341,7 @@ def __init__( resblock_updown: bool = False, num_head_channels: int = 1, transformer_num_layers: int = 1, - cross_attention_dim: Optional[int] = None, + cross_attention_dim: int | None = None, upcast_attention: bool = False, ) -> None: super().__init__() @@ -1400,9 +1402,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - res_hidden_states_list: List[torch.Tensor], + res_hidden_states_list: list[torch.Tensor], temb: torch.Tensor, - context: Optional[torch.Tensor] = None, + context: torch.Tensor | None = None, ) -> torch.Tensor: for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -1433,7 +1435,7 @@ def get_down_block( with_cross_attn: bool, num_head_channels: int, transformer_num_layers: int, - cross_attention_dim: Optional[int], + cross_attention_dim: int | None, upcast_attention: bool = False, ) -> nn.Module: if with_attn: @@ -1488,7 +1490,7 @@ def get_mid_block( with_conditioning: bool, num_head_channels: int, transformer_num_layers: int, - cross_attention_dim: Optional[int], + cross_attention_dim: int | None, upcast_attention: bool = False, ) -> nn.Module: if with_conditioning: @@ -1529,7 +1531,7 @@ def get_up_block( with_cross_attn: bool, num_head_channels: int, transformer_num_layers: int, - cross_attention_dim: Optional[int], + cross_attention_dim: int | None, upcast_attention: bool = False, ) -> nn.Module: if with_attn: @@ -1614,11 +1616,11 @@ def __init__( 
norm_num_groups: int = 32, norm_eps: float = 1e-6, resblock_updown: bool = False, - num_head_channels: Union[int, Sequence[int]] = 8, + num_head_channels: int | Sequence[int] = 8, with_conditioning: bool = False, transformer_num_layers: int = 1, - cross_attention_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, upcast_attention: bool = False, ) -> None: super().__init__() @@ -1775,8 +1777,8 @@ def forward( self, x: torch.Tensor, timesteps: torch.Tensor, - context: Optional[torch.Tensor] = None, - class_labels: Optional[torch.Tensor] = None, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, ) -> torch.Tensor: """ Args: @@ -1808,7 +1810,7 @@ def forward( # 4. down if context is not None and self.with_conditioning is False: raise ValueError("model should have with_conditioning = True if context is provided") - down_block_res_samples: List[torch.Tensor] = [h] + down_block_res_samples: list[torch.Tensor] = [h] for downsample_block in self.down_blocks: h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) for residual in res_samples: diff --git a/generative/networks/nets/patchgan_discriminator.py b/generative/networks/nets/patchgan_discriminator.py index c77f9fdf..a2eb2ae1 100644 --- a/generative/networks/nets/patchgan_discriminator.py +++ b/generative/networks/nets/patchgan_discriminator.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Sequence import torch import torch.nn as nn @@ -54,10 +56,10 @@ def __init__( in_channels: int, out_channels: int, kernel_size: int, - activation: Union[str, tuple] = "PRELU", - norm: Union[str, tuple] = "INSTANCE", + activation: str | tuple = "PRELU", + norm: str | tuple = "INSTANCE", bias: bool = False, - dropout: Union[float, tuple] = 0.0, + dropout: float | tuple = 0.0, minimum_size_im: int = 256, last_conv_kernel_size: int = 1, ) -> None: @@ -91,7 +93,7 @@ def __init__( self.add_module("discriminator_%d" % i_, subnet_d) - def forward(self, i: torch.Tensor) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: """ Args: @@ -101,10 +103,10 @@ def forward(self, i: torch.Tensor) -> Tuple[List[torch.Tensor], List[List[torch. of each discriminator. 
""" - out: List[torch.Tensor] = [] - intermediate_features: List[List[torch.Tensor]] = [] + out: list[torch.Tensor] = [] + intermediate_features: list[list[torch.Tensor]] = [] for disc in self.children(): - out_d: List[torch.Tensor] = disc(i) + out_d: list[torch.Tensor] = disc(i) out.append(out_d[-1]) intermediate_features.append(out_d[:-1]) @@ -144,12 +146,12 @@ def __init__( in_channels: int, out_channels: int, kernel_size: int, - activation: Union[str, tuple] = "PRELU", - norm: Union[str, tuple] = "INSTANCE", + activation: str | tuple = "PRELU", + norm: str | tuple = "INSTANCE", bias: bool = False, - padding: Union[int, Sequence[int]] = 1, - dropout: Union[float, tuple] = 0.0, - last_conv_kernel_size: Optional[int] = None, + padding: int | Sequence[int] = 1, + dropout: float | tuple = 0.0, + last_conv_kernel_size: int | None = None, ) -> None: super().__init__() @@ -217,7 +219,7 @@ def __init__( self.apply(self.initialise_weights) - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """ Args: diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py index 84476aef..3fbeb6c3 100644 --- a/generative/networks/nets/transformer.py +++ b/generative/networks/nets/transformer.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from __future__ import annotations + import torch import torch.nn as nn @@ -57,5 +58,5 @@ def __init__( ), ) - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: return self.model(x, context=context) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 07c7aeba..4ce00c20 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Sequence import torch import torch.nn as nn @@ -45,9 +47,9 @@ def __init__( num_channels: int, num_res_channels: int, adn_ordering: str = "NDA", - act: Optional[Union[Tuple, str]] = "RELU", - dropout: Optional[Union[Tuple, str, float]] = None, - dropout_dim: Optional[int] = 1, + act: tuple | str | None = "RELU", + dropout: tuple | str | float | None = None, + dropout_dim: int | None = 1, bias: bool = True, ) -> None: super().__init__() @@ -129,8 +131,8 @@ def __init__( in_channels: int, out_channels: int, num_levels: int = 3, - downsample_parameters: Tuple[Tuple[int, int, int, int], ...] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: Tuple[Tuple[int, int, int, int, int], ...] = ( + downsample_parameters: tuple[tuple[int, int, int, int], ...] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: tuple[tuple[int, int, int, int, int], ...] 
= ( (2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0), @@ -145,9 +147,9 @@ def __init__( decay: float = 0.5, epsilon: float = 1e-5, adn_ordering: str = "NDA", - dropout: Optional[Union[Tuple, str, float]] = 0.1, - act: Optional[Union[Tuple, str]] = "RELU", - output_act: Optional[Union[Tuple, str]] = None, + dropout: tuple | str | float | None = 0.1, + act: tuple | str | None = "RELU", + output_act: tuple | str | None = None, ddp_sync: bool = True, ): super().__init__() @@ -340,7 +342,7 @@ def construct_decoder(self) -> torch.nn.Sequential: def encode(self, images: torch.Tensor) -> torch.Tensor: return self.encoder(images) - def quantize(self, encodings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x_loss, x = self.quantizer(encodings) return x, x_loss @@ -353,7 +355,7 @@ def index_quantize(self, images: torch.Tensor) -> torch.Tensor: def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor: return self.decode(self.quantizer.embed(embedding_indices)) - def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: quantizations, quantization_losses = self.quantize(self.encode(images)) reconstruction = self.decode(quantizations) diff --git a/generative/networks/schedulers/__init__.py b/generative/networks/schedulers/__init__.py index ff9e37bb..bb2eb347 100644 --- a/generative/networks/schedulers/__init__.py +++ b/generative/networks/schedulers/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .ddim import DDIMScheduler from .ddpm import DDPMScheduler from .pndm import PNDMScheduler diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index c6c4c2c6..1015692b 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -29,7 +29,8 @@ # limitations under the License. # ========================================================================= -from typing import Optional, Tuple, Union +from __future__ import annotations + import numpy as np import torch @@ -110,7 +111,7 @@ def __init__( self.clip_sample = clip_sample self.steps_offset = steps_offset - def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -149,8 +150,8 @@ def step( timestep: int, sample: torch.Tensor, eta: float = 0.0, - generator: Optional[torch.Generator] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 7db5d7cb..6fbf11c7 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -29,7 +29,8 @@ # limitations under the License. 
# ========================================================================= -from typing import Optional, Tuple, Union +from __future__ import annotations + import numpy as np import torch @@ -98,7 +99,7 @@ def __init__( self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) - def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -145,7 +146,7 @@ def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torc return mean - def _get_variance(self, timestep: int, predicted_variance: Optional[torch.Tensor] = None) -> torch.Tensor: + def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: """ Compute the variance of the posterior at timestep t. @@ -183,8 +184,8 @@ def step( model_output: torch.Tensor, timestep: int, sample: torch.Tensor, - generator: Optional[torch.Generator] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 6e05da21..0eb0fb32 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -29,7 +29,9 @@ # limitations under the License. # ========================================================================= -from typing import Any, Optional, Tuple, Union +from __future__ import annotations + +from typing import Any import numpy as np import torch @@ -122,7 +124,7 @@ def __init__( self.plms_timesteps = torch.Tensor([]) self.timesteps = torch.Tensor([]) - def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -170,7 +172,7 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor - ) -> Tuple[torch.Tensor, Any]: + ) -> tuple[torch.Tensor, Any]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). diff --git a/generative/utils/__init__.py b/generative/utils/__init__.py index 8b6a35ab..be9d721b 100644 --- a/generative/utils/__init__.py +++ b/generative/utils/__init__.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .enums import AdversarialIterationEvents, AdversarialKeys diff --git a/generative/utils/enums.py b/generative/utils/enums.py index 9c510f89..c78a4c16 100644 --- a/generative/utils/enums.py +++ b/generative/utils/enums.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from typing import TYPE_CHECKING from monai.config import IgniteInfo diff --git a/generative/utils/ordering.py b/generative/utils/ordering.py index bb9a6db8..ef328958 100644 --- a/generative/utils/ordering.py +++ b/generative/utils/ordering.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from __future__ import annotations + import numpy as np import torch @@ -46,11 +47,11 @@ def __init__( self, ordering_type: str, spatial_dims: int, - dimensions: Union[Tuple[int, int, int], Tuple[int, int, int, int]], - reflected_spatial_dims: Union[Tuple[bool, bool], Tuple[bool, bool, bool]] = (), - transpositions_axes: Union[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]] = (), - rot90_axes: Union[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]] = (), - transformation_order: Tuple[str, ...] = ( + dimensions: tuple[int, int, int] | tuple[int, int, int, int], + reflected_spatial_dims: tuple[bool, bool] | tuple[bool, bool, bool] = (), + transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] = (), + rot90_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] = (), + transformation_order: tuple[str, ...] = ( OrderingTransformations.TRANSPOSE.value, OrderingTransformations.ROTATE_90.value, OrderingTransformations.REFLECT.value, diff --git a/setup.cfg b/setup.cfg index d97e1c81..3730d9b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,9 +26,10 @@ exclude = *.pyi,.git,.eggs,generative/_version.py,versioneer.py,venv,.venv,_vers known_first_party = generative profile = black line_length = 120 -skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py +skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py, tutorials/ skip_glob = *.pyi - +add_imports = from __future__ import annotations +append_only = true [mypy] # Suppresses error messages about imports that cannot be resolved. diff --git a/setup.py b/setup.py index 37dca295..4ef93b60 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from setuptools import find_packages, setup setup( diff --git a/tests/test_adversarial.py b/tests/test_adversarial.py index 9b012120..d837e8a0 100644 --- a/tests/test_adversarial.py +++ b/tests/test_adversarial.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index bb6af8f8..54bc8dba 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_compute_fid_metric.py b/tests/test_compute_fid_metric.py index 3cbc6180..fd9e2e75 100644 --- a/tests/test_compute_fid_metric.py +++ b/tests/test_compute_fid_metric.py @@ -10,6 +10,8 @@ # limitations under the License. 
+from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_compute_mmd_metric.py b/tests/test_compute_mmd_metric.py index e7e26f3a..452d68e8 100644 --- a/tests/test_compute_mmd_metric.py +++ b/tests/test_compute_mmd_metric.py @@ -10,6 +10,8 @@ # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_compute_ms_ssim_metric.py b/tests/test_compute_ms_ssim_metric.py index 342605e8..da17de8f 100644 --- a/tests/test_compute_ms_ssim_metric.py +++ b/tests/test_compute_ms_ssim_metric.py @@ -10,6 +10,8 @@ # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index ff98a11e..c450ed3d 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index f34b4a3f..1c5c647b 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_integration_workflows_adversarial.py b/tests/test_integration_workflows_adversarial.py index dc176106..4ce46554 100644 --- a/tests/test_integration_workflows_adversarial.py +++ b/tests/test_integration_workflows_adversarial.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import tempfile diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 58394754..8a830027 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_ordering.py b/tests/test_ordering.py index c40b77e9..8aea5447 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np @@ -233,7 +235,7 @@ OrderingTransformations.REFLECT.value, ), } - ], + ] ] TEST_ORDERING_TRANSFORMATION_FAILURE = [ @@ -251,7 +253,7 @@ "flip", ), } - ], + ] ] TEST_REVERT = [ diff --git a/tests/test_patch_gan.py b/tests/test_patch_gan.py index 1c0270ed..7e8df802 100644 --- a/tests/test_patch_gan.py +++ b/tests/test_patch_gan.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 8abb2b0c..4112fd87 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -10,6 +10,8 @@ # limitations under the License. 
+from __future__ import annotations + import unittest import torch diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py index c6b112ba..67d773fe 100644 --- a/tests/test_scheduler_ddim.py +++ b/tests/test_scheduler_ddim.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py index 21f76509..7e07563e 100644 --- a/tests/test_scheduler_ddpm.py +++ b/tests/test_scheduler_ddpm.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py index ff667b86..ff9b5ce0 100644 --- a/tests/test_scheduler_pndm.py +++ b/tests/test_scheduler_pndm.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_spectral_loss.py b/tests/test_spectral_loss.py index 6662f4c3..2bd2d970 100644 --- a/tests/test_spectral_loss.py +++ b/tests/test_spectral_loss.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 9ddb4ca7..5cc18d27 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py index c6d32b6f..d4c9e209 100644 --- a/tests/test_vector_quantizer.py +++ b/tests/test_vector_quantizer.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import unittest import torch diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index 044fd73c..5e40acff 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import unittest import torch From 43dfe8331b9395e0966e88e6dc6044c236610550 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 10 Feb 2023 00:02:37 +0000 Subject: [PATCH 2/3] Remove use of typing Signed-off-by: Walter Hugo Lopez Pinaya --- generative/inferers/inferer.py | 3 +-- generative/losses/perceptual.py | 1 - generative/losses/spectral_loss.py | 1 - generative/metrics/fid.py | 4 ++-- generative/metrics/mmd.py | 3 ++- generative/metrics/ms_ssim.py | 2 +- generative/networks/nets/autoencoderkl.py | 2 +- .../networks/nets/diffusion_model_unet.py | 6 +++--- .../networks/nets/patchgan_discriminator.py | 2 +- generative/networks/nets/transformer.py | 1 - generative/networks/nets/vqvae.py | 2 +- generative/networks/schedulers/ddim.py | 5 ++--- generative/networks/schedulers/ddpm.py | 11 +++-------- generative/networks/schedulers/pndm.py | 4 ++-- generative/utils/ordering.py | 1 - generative/version.py | 2 ++ tests/README.md | 19 ------------------- tests/test_compute_fid_metric.py | 1 - tests/test_compute_mmd_metric.py | 1 - tests/test_compute_ms_ssim_metric.py | 1 - tests/test_perceptual_loss.py | 1 - 21 files changed, 21 insertions(+), 52 deletions(-) delete mode 100644 tests/README.md diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index f758594f..84ca866e 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -9,11 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import math -from typing import Callable +from collections.abc import Callable import torch import torch.nn as nn diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index 9160265b..1d3b7b9d 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -11,7 +11,6 @@ from __future__ import annotations - import torch import torch.nn as nn from lpips import LPIPS diff --git a/generative/losses/spectral_loss.py b/generative/losses/spectral_loss.py index 3224d43d..d881f5dd 100644 --- a/generative/losses/spectral_loss.py +++ b/generative/losses/spectral_loss.py @@ -11,7 +11,6 @@ from __future__ import annotations - import torch import torch.nn.functional as F from monai.utils import LossReduction diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py index 14ff6517..cc9f11e6 100644 --- a/generative/metrics/fid.py +++ b/generative/metrics/fid.py @@ -8,12 +8,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# # ========================================================================= # Adapted from https://github.com/photosynthesis-team/piq # which has the following license: # https://github.com/photosynthesis-team/piq/blob/master/LICENSE - +# # Copyright 2023 photosynthesis-team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/generative/metrics/mmd.py b/generative/metrics/mmd.py index b781b698..bba93141 100644 --- a/generative/metrics/mmd.py +++ b/generative/metrics/mmd.py @@ -8,9 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from __future__ import annotations -from typing import Callable +from collections.abc import Callable import torch from monai.metrics.regression import RegressionMetric diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 456b05b4..9cd6cabb 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -8,8 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +from __future__ import annotations import torch import torch.nn.functional as F diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index e9ce57f6..6a21880e 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -13,7 +13,7 @@ import importlib.util import math -from typing import Sequence +from collections.abc import Sequence import torch import torch.nn as nn diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 46aaab58..68a5f882 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -8,12 +8,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# # ========================================================================= # Adapted from https://github.com/huggingface/diffusers # which has the following license: # https://github.com/huggingface/diffusers/blob/main/LICENSE - +# # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,7 +33,7 @@ import importlib.util import math -from typing import Sequence +from collections.abc import Sequence import torch import torch.nn.functional as F diff --git a/generative/networks/nets/patchgan_discriminator.py b/generative/networks/nets/patchgan_discriminator.py index a2eb2ae1..5648596b 100644 --- a/generative/networks/nets/patchgan_discriminator.py +++ b/generative/networks/nets/patchgan_discriminator.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence import torch import torch.nn as nn diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py index 3fbeb6c3..a6fc81ec 100644 --- a/generative/networks/nets/transformer.py +++ b/generative/networks/nets/transformer.py @@ -11,7 +11,6 @@ from __future__ import annotations - import torch import torch.nn as nn from x_transformers import Decoder, TransformerWrapper diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 4ce00c20..c1f32761 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence import torch import torch.nn as nn diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index 1015692b..e12994b3 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -8,12 +8,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - +# # ========================================================================= # Adapted from https://github.com/huggingface/diffusers # which has the following license: # https://github.com/huggingface/diffusers/blob/main/LICENSE - +# # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,7 +31,6 @@ from __future__ import annotations - import numpy as np import torch import torch.nn as nn diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 6fbf11c7..2f25f9f1 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -8,12 +8,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# # ========================================================================= # Adapted from https://github.com/huggingface/diffusers # which has the following license: # https://github.com/huggingface/diffusers/blob/main/LICENSE - +# # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,7 +31,6 @@ from __future__ import annotations - import numpy as np import torch import torch.nn as nn @@ -180,11 +179,7 @@ def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = return variance def step( - self, - model_output: torch.Tensor, - timestep: int, - sample: torch.Tensor, - generator: torch.Generator | None = None, + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 0eb0fb32..4f5e4f61 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -8,12 +8,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# # ========================================================================= # Adapted from https://github.com/huggingface/diffusers # which has the following license: # https://github.com/huggingface/diffusers/blob/main/LICENSE - +# # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/generative/utils/ordering.py b/generative/utils/ordering.py index ef328958..f00a3716 100644 --- a/generative/utils/ordering.py +++ b/generative/utils/ordering.py @@ -11,7 +11,6 @@ from __future__ import annotations - import numpy as np import torch diff --git a/generative/version.py b/generative/version.py index d498e36c..c3618d87 100644 --- a/generative/version.py +++ b/generative/version.py @@ -9,4 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + __version__ = "0.1.0" diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 9c1e44f7..00000000 --- a/tests/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# UNIT TESTS - -## Set environment -To use the tests already available at MONAI core, first we clone it: -```shell -git clone https://github.com/Project-MONAI/MONAI --branch main -``` - -Then we add it to PYTHONPATH -```shell -export PYTHONPATH="${PYTHONPATH}:./MONAI/" -``` - -## Executing tests -To run tests, use the following command: - -```shell script - python -m unittest discover tests -``` diff --git a/tests/test_compute_fid_metric.py b/tests/test_compute_fid_metric.py index fd9e2e75..3a512916 100644 --- a/tests/test_compute_fid_metric.py +++ b/tests/test_compute_fid_metric.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import unittest diff --git a/tests/test_compute_mmd_metric.py b/tests/test_compute_mmd_metric.py index 452d68e8..ab016653 100644 --- a/tests/test_compute_mmd_metric.py +++ b/tests/test_compute_mmd_metric.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import unittest diff --git a/tests/test_compute_ms_ssim_metric.py b/tests/test_compute_ms_ssim_metric.py index da17de8f..f648086c 100644 --- a/tests/test_compute_ms_ssim_metric.py +++ b/tests/test_compute_ms_ssim_metric.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import unittest diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 4112fd87..8e0c6cb3 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import unittest From 863fa4f1e2bc8c1017353ef90ca5b044cd713398 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 10 Feb 2023 00:14:54 +0000 Subject: [PATCH 3/3] Add work around to avoid errors with vqvae Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/layers/__init__.py | 2 -- generative/networks/layers/vector_quantizer.py | 2 -- setup.cfg | 3 ++- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/generative/networks/layers/__init__.py b/generative/networks/layers/__init__.py index 51b907c1..bd6e831f 100644 --- a/generative/networks/layers/__init__.py +++ b/generative/networks/layers/__init__.py @@ -9,6 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from .vector_quantizer import EMAQuantizer, VectorQuantizer diff --git a/generative/networks/layers/vector_quantizer.py b/generative/networks/layers/vector_quantizer.py index bcbffa4b..661f2129 100644 --- a/generative/networks/layers/vector_quantizer.py +++ b/generative/networks/layers/vector_quantizer.py @@ -9,8 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import annotations - from typing import Sequence import torch diff --git a/setup.cfg b/setup.cfg index 3730d9b8..d7fde472 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,8 @@ exclude = *.pyi,.git,.eggs,generative/_version.py,versioneer.py,venv,.venv,_vers known_first_party = generative profile = black line_length = 120 -skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py, tutorials/ +# generative/networks/layers/ is excluded because it is raising JIT errors +skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py, tutorials/, generative/networks/layers/ skip_glob = *.pyi add_imports = from __future__ import annotations append_only = true
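
The series above converts the codebase to postponed annotation evaluation: every module gains `from __future__ import annotations` (enforced via isort's `add_imports`/`append_only` options in `setup.cfg`), `typing.Optional`/`Union`/`List`/`Tuple` become PEP 604 unions and PEP 585 built-in generics, `Callable`/`Sequence` move to `collections.abc`, and `generative/networks/layers/` is skipped because the future import raises TorchScript JIT errors there. The sketch below is not taken from the patch itself; it is a minimal, hypothetical stand-in (assuming `torch` is installed) that mirrors the signature style the diffs converge on, to show why the future import lets the `X | None` and `list[...]` spellings run on Python versions older than 3.10:

```python
# Minimal illustration of the annotation style adopted in this patch series.
# With the future import, annotations are stored as strings instead of being
# evaluated, so the PEP 604 / PEP 585 syntax below parses on Python 3.7+.
from __future__ import annotations

from collections.abc import Callable  # replaces `from typing import Callable`

import torch


def sample(
    noise: torch.Tensor,
    scheduler: Callable[..., torch.Tensor] | None = None,   # was Optional[Callable[..., torch.Tensor]]
    save_intermediates: bool | None = False,                 # was Optional[bool]
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:  # was Union[..., Tuple[..., List[...]]]
    """Toy stand-in for the inferer-style signatures touched in these diffs."""
    intermediates: list[torch.Tensor] = [noise]
    return (noise, intermediates) if save_intermediates else noise
```

With `append_only = true`, isort only adds the future import to files that already contain imports, which is why the empty-ish `__init__.py` modules receive it explicitly in PATCH 1/3 and why PATCH 3/3 can simply remove it again from the quantizer layer without isort reinserting it (the directory is now in `skip`).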