From 203d2746ab1bc895f588cf880a63c08439e7d612 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 16 Feb 2023 16:46:11 -0600 Subject: [PATCH] Fixes type --- generative/inferers/inferer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index b44ac11f..56556241 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -69,7 +69,7 @@ def sample( save_intermediates: bool | None = False, intermediate_steps: int | None = 100, conditioning: torch.Tensor | None = None, - verbose: bool | None = True, + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -114,7 +114,7 @@ def get_likelihood( conditioning: torch.Tensor | None = None, original_input_range: tuple | None = (0, 255), scaled_input_range: tuple | None = (0, 1), - verbose: bool | None = True, + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the likelihoods for an input. @@ -321,7 +321,7 @@ def sample( save_intermediates: bool | None = False, intermediate_steps: int | None = 100, conditioning: torch.Tensor | None = None, - verbose: bool | None = True, + verbose: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -371,7 +371,7 @@ def get_likelihood( conditioning: torch.Tensor | None = None, original_input_range: tuple | None = (0, 255), scaled_input_range: tuple | None = (0, 1), - verbose: bool | None = True, + verbose: bool = True, resample_latent_likelihoods: bool | None = False, resample_interpolation_mode: str | None = "bilinear", ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: @@ -474,7 +474,7 @@ def sample( conditioning: torch.Tensor | None = None, temperature: float = 1.0, top_k: int | None = None, - verbose: bool | None = True, + verbose: bool = True, ) -> torch.Tensor: """ Sampling function for the VQVAE + Transformer model.