From 68d16078efa5ac4d206779aadc6e4e3cba54d859 Mon Sep 17 00:00:00 2001
From: Alphonsce
Date: Thu, 9 May 2024 23:07:51 +0300
Subject: [PATCH] made small fixes to files and run black and isort

---
setup.py | 21 +-
src/metr/__init__.py | 38 +-
src/metr/_check.py | 4 -
src/metr/_detect.py | 90 --
src/metr/_get_noise.py | 94 --
src/metr/finetune_ldm_decoder.py | 393 +++---
src/metr/guided_diffusion/dist_util.py | 4 +-
src/metr/guided_diffusion/fp16_util.py | 40 +-
.../guided_diffusion/gaussian_diffusion.py | 158 +--
src/metr/guided_diffusion/image_datasets.py | 28 +-
src/metr/guided_diffusion/logger.py | 51 +-
src/metr/guided_diffusion/losses.py | 14 +-
src/metr/guided_diffusion/nn.py | 6 +-
src/metr/guided_diffusion/resample.py | 15 +-
src/metr/guided_diffusion/respace.py | 20 +-
src/metr/guided_diffusion/script_util.py | 12 +-
src/metr/guided_diffusion/train_util.py | 60 +-
src/metr/guided_diffusion/unet.py | 63 +-
src/metr/inverse_stable_diffusion.py | 79 +-
src/metr/io_utils.py | 49 +-
src/metr/ldm/data/base.py | 12 +-
src/metr/ldm/data/imagenet.py | 123 +-
src/metr/ldm/data/lsun.py | 49 +-
src/metr/ldm/lr_scheduler.py | 30 +-
src/metr/ldm/models/autoencoder.py | 105 +-
src/metr/ldm/models/diffusion/ddim.py | 343 ++++--
src/metr/ldm/models/diffusion/ddpm.py | 1057 ++++++++++-------
.../models/diffusion/dpm_solver/__init__.py | 2 +-
.../models/diffusion/dpm_solver/dpm_solver.py | 563 +++++----
.../models/diffusion/dpm_solver/sampler.py | 70 +-
src/metr/ldm/models/diffusion/plms.py | 258 ++--
.../ldm/models/diffusion/sampling_util.py | 6 +-
src/metr/ldm/modules/attention.py | 189 +--
.../ldm/modules/diffusionmodules/model.py | 632 +++++-----
.../modules/diffusionmodules/openaimodel.py | 183 +--
.../ldm/modules/diffusionmodules/upscaling.py | 49 +-
src/metr/ldm/modules/diffusionmodules/util.py | 54 +-
.../modules/distributions/distributions.py | 37 +-
src/metr/ldm/modules/ema.py | 21 +-
src/metr/ldm/modules/encoders/modules.py | 103 +-
.../ldm/modules/image_degradation/bsrgan.py | 204 ++--
.../modules/image_degradation/bsrgan_light.py | 169 +--
.../modules/image_degradation/utils_image.py | 272 +++--
src/metr/ldm/modules/losses/__init__.py | 2 +-
src/metr/ldm/modules/losses/contperceptual.py | 71 +-
src/metr/ldm/modules/losses/vqperceptual.py | 94 +-
src/metr/ldm/modules/midas/api.py | 33 +-
.../ldm/modules/midas/midas/base_model.py | 2 +-
src/metr/ldm/modules/midas/midas/blocks.py | 130 +-
src/metr/ldm/modules/midas/midas/dpt_depth.py | 14 +-
src/metr/ldm/modules/midas/midas/midas_net.py | 8 +-
.../modules/midas/midas/midas_net_custom.py | 86 +-
.../ldm/modules/midas/midas/transforms.py | 50 +-
src/metr/ldm/modules/midas/midas/vit.py | 67 +-
src/metr/ldm/modules/midas/utils.py | 21 +-
src/metr/ldm/modules/x_transformer.py | 306 +++--
src/metr/ldm/util.py | 16 +-
src/metr/loss/color_wrapper.py | 51 +-
src/metr/loss/dct2d.py | 68 +-
src/metr/loss/deep_loss.py | 159 ++-
src/metr/loss/loss_provider.py | 139 ++-
src/metr/loss/rfft2d.py | 55 +-
src/metr/loss/shift_wrapper.py | 22 +-
src/metr/loss/ssim.py | 43 +-
src/metr/loss/watson.py | 85 +-
src/metr/loss/watson_fft.py | 96 +-
src/metr/loss/watson_vgg.py | 97 +-
src/metr/metr_pp_eval_stable_sig.py | 267 +++--
src/metr/modified_stable_diffusion.py | 32 +-
src/metr/open_clip/__init__.py | 45 +-
src/metr/open_clip/coca_model.py | 132 +-
src/metr/open_clip/factory.py | 206 ++--
src/metr/open_clip/hf_configs.py | 6 +-
src/metr/open_clip/hf_model.py | 64 +-
src/metr/open_clip/loss.py | 98 +-
src/metr/open_clip/model.py | 138 ++-
src/metr/open_clip/modified_resnet.py | 31 +-
src/metr/open_clip/openai.py | 26 +-
src/metr/open_clip/pretrained.py | 201 ++--
src/metr/open_clip/push_to_hf_hub.py | 91 +-
src/metr/open_clip/timm_model.py | 65 +-
src/metr/open_clip/tokenizer.py | 60 +-
src/metr/open_clip/transform.py | 89 +-
src/metr/open_clip/transformer.py | 306 ++--
src/metr/open_clip/utils.py | 7 +-
src/metr/open_clip/version.py | 2 +-
src/metr/optim_utils.py | 183 +--
src/metr/pytorch_fid/__init__.py | 2 +-
src/metr/pytorch_fid/fid_score.py | 134 +--
src/metr/pytorch_fid/inception.py | 77 +-
src/metr/run_metr.py | 287 +++--
src/metr/run_metr_fid.py | 262 ++--
src/metr/stable_sig/__init__.py | 2 +-
src/metr/stable_sig/utils.py | 177 +--
src/metr/stable_sig/utils_img.py | 92 +-
src/metr/stable_sig/utils_model.py | 83 +-
src/metr/taming/data/ade20k.py | 64 +-
.../taming/data/annotated_objects_coco.py | 105 +-
.../taming/data/annotated_objects_dataset.py | 139 ++-
.../data/annotated_objects_open_images.py | 102 +-
src/metr/taming/data/base.py | 18 +-
src/metr/taming/data/coco.py | 138 ++-
.../data/conditional_builder/objects_bbox.py | 50 +-
.../objects_center_points.py | 100 +-
.../taming/data/conditional_builder/utils.py | 41 +-
src/metr/taming/data/custom.py | 9 +-
src/metr/taming/data/faceshq.py | 24 +-
src/metr/taming/data/helper_types.py | 11 +-
src/metr/taming/data/image_transforms.py | 12 +-
src/metr/taming/data/imagenet.py | 183 ++-
src/metr/taming/data/open_images_helper.py | 739 ++++++------
src/metr/taming/data/sflckr.py | 66 +-
src/metr/taming/data/utils.py | 39 +-
src/metr/taming/lr_scheduler.py | 10 +-
src/metr/taming/models/cond_transformer.py | 147 +--
src/metr/taming/models/vqgan.py | 323 ++---
.../taming/modules/diffusionmodules/model.py | 519 ++++----
.../taming/modules/discriminator/model.py | 21 +-
src/metr/taming/modules/losses/__init__.py | 1 -
src/metr/taming/modules/losses/lpips.py | 36 +-
.../taming/modules/losses/segmentation.py | 17 +-
.../taming/modules/losses/vqperceptual.py | 83 +-
src/metr/taming/modules/misc/coord.py | 13 +-
src/metr/taming/modules/transformer/mingpt.py | 167 ++-
.../taming/modules/transformer/permuter.py | 90 +-
src/metr/taming/modules/util.py | 32 +-
src/metr/taming/modules/vqvae/quantize.py | 238 ++--
src/metr/taming/util.py | 41 +-
src/metr/utils.py | 12 +-
129 files changed, 7494 insertions(+), 6746 deletions(-)
delete mode 100644 src/metr/_check.py
delete mode 100644 src/metr/_detect.py
delete mode 100644 src/metr/_get_noise.py

diff --git a/setup.py b/setup.py
index a9c462e..ce0c384 100644
--- a/setup.py
+++ b/setup.py
@@ -1,27 +1,15 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
import re
from distutils.core import Command
from setuptools import find_packages, setup
-
# IMPORTANT:
# 1. all dependencies should be listed here with their version requirements if any
# 2.
once modified, run: `make deps_table_update` to update src/tree-ring-watermark/dependency_versions_table.py _deps = [ -# "torch==1.13.0", - "transformers==4.31.0", - "diffusers==0.14.0", + # "torch==1.13.0", + "transformers==4.31.0", + "diffusers==0.14.0", ] # this is a lookup table with items like: @@ -108,12 +96,10 @@ def run(self): package_dir={"": "src"}, packages=find_packages("src"), include_package_data=True, - package_data={ # 'tree_ring_watermark': ['src/tree_ring_watermark/open_clip/bpe_simple_vocab_16e6.txt.gz', 'src/tree_ring_watermark/open_clip/model_configs/*.json'] "": ["*.gz", "*.json", "*.pth"] }, - python_requires=">=3.7.0", install_requires=install_requires, extras_require=extras, @@ -152,4 +138,3 @@ def run(self): # twine upload dist/* -r pypi # 8. Add release notes to the tag in github once everything is looking hunky-dory. # 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master - diff --git a/src/metr/__init__.py b/src/metr/__init__.py index 7a23317..e0ee907 100644 --- a/src/metr/__init__.py +++ b/src/metr/__init__.py @@ -1,29 +1,29 @@ -# I call it a 0.1.0 version -import sys import os +import sys sys.path.append(os.path.dirname(os.path.abspath(__file__))) __version__ = "0.1.0" -from ._check import check -from ._detect import detect -from ._get_noise import get_noise -from .utils import set_org, get_org - -from .modified_stable_diffusion import ModifiedStableDiffusionPipeline -from .inverse_stable_diffusion import InversableStableDiffusionPipeline - -from .optim_utils import * -from .io_utils import * - -# To run run_tree_watermarking scripts: -from .open_clip import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss -from .pytorch_fid.fid_score import * from .guided_diffusion.script_util import ( NUM_CLASSES, - model_and_diffusion_defaults, - create_model_and_diffusion, add_dict_to_argparser, args_to_dict, -) \ No newline at end of file + create_model_and_diffusion, + model_and_diffusion_defaults, +) +from .inverse_stable_diffusion import InversableStableDiffusionPipeline +from .io_utils import * +from .modified_stable_diffusion import ModifiedStableDiffusionPipeline + +# To run metr scripts: +from .open_clip import ( + create_loss, + create_model, + create_model_and_transforms, + create_model_from_pretrained, + get_tokenizer, +) +from .optim_utils import * +from .pytorch_fid.fid_score import * +from .utils import get_org, set_org diff --git a/src/metr/_check.py b/src/metr/_check.py deleted file mode 100644 index c758150..0000000 --- a/src/metr/_check.py +++ /dev/null @@ -1,4 +0,0 @@ -from diffusers import DiffusionPipeline - -def check(pipeline: DiffusionPipeline, model_hash: str): - pass diff --git a/src/metr/_detect.py b/src/metr/_detect.py deleted file mode 100644 index e817220..0000000 --- a/src/metr/_detect.py +++ /dev/null @@ -1,90 +0,0 @@ -from huggingface_hub import snapshot_download -import numpy as np -import torch -from torchvision import transforms -import PIL -from typing import Union -from huggingface_hub import snapshot_download -from diffusers import DDIMInverseScheduler -from .utils import get_org -from ._get_noise import _circle_mask -import os - -def _transform_img(image, target_size=512): - tform = transforms.Compose( - [ - transforms.Resize(target_size), - transforms.CenterCrop(target_size), - transforms.ToTensor(), - ] - ) - image = tform(image) - return 2.0 * image - 1.0 - - -def load_keys(cache_dir): - # Initialize an empty dictionary to store the numpy 
arrays - arrays = {} - - # List all files in the directory - for file_name in os.listdir(cache_dir): - # Check if the file is a .npy file - if file_name.endswith('.npy'): - # Define the file path - file_path = os.path.join(cache_dir, file_name) - - # Load the numpy array and store it in the dictionary - arrays[file_name] = np.load(file_path) - - # Return the 'arrays' dictionary - return arrays - - -# def detect(image: Union[PIL.Image.Image, torch.Tensor, np.ndarray], model_hash: str): -def detect(image: Union[PIL.Image.Image, torch.Tensor, np.ndarray], pipe, model_hash, org): - ''' - pipe: Inverse Diffusion process pipeline - ''' - detection_time_num_inference = 50 - threshold = 77 - - # org = get_org() - repo_id = os.path.join(org, model_hash) - - cache_dir = snapshot_download(repo_id, repo_type="dataset") - keys = load_keys(cache_dir) - - # ddim inversion - curr_scheduler = pipe.scheduler - pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) - img = _transform_img(image).unsqueeze(0).to(pipe.unet.dtype).to(pipe.device) - image_latents = pipe.vae.encode(img).latent_dist.mode() * 0.18215 - # Here latents are z in stable diff - inverted_latents = pipe( - prompt='', - latents=image_latents, - guidance_scale=1, - num_inference_steps=detection_time_num_inference, - output_type='latent', - ) - inverted_latents = inverted_latents.images.float().cpu() - - # check if one key matches - shape = image_latents.shape - for filename, w_key in keys.items(): - w_channel, w_radius = filename.split(".npy")[0].split("_")[1:3] - - np_mask = _circle_mask(shape[-1], r=int(w_radius)) - torch_mask = torch.tensor(np_mask) - w_mask = torch.zeros(shape, dtype=torch.bool) - w_mask[:, int(w_channel)] = torch_mask - - # calculate the distance - inverted_latents_fft = torch.fft.fftshift(torch.fft.fft2(inverted_latents), dim=(-1, -2)) - dist = torch.abs(inverted_latents_fft[w_mask] - w_key[w_mask]).mean().item() - - if dist <= threshold: - pipe.scheduler = curr_scheduler - return True - - return False diff --git a/src/metr/_get_noise.py b/src/metr/_get_noise.py deleted file mode 100644 index e964f16..0000000 --- a/src/metr/_get_noise.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from typing import Union, List, Tuple -import numpy as np -import hashlib -import os -import tempfile -from huggingface_hub import hf_api -from .utils import get_org - -api = hf_api.HfApi() - -def _circle_mask(size=64, r=10, x_offset=0, y_offset=0): - # reference: https://stackoverflow.com/questions/69687798/generating-a-soft-circluar-mask-using-numpy-python-3 - x0 = y0 = size // 2 - x0 += x_offset - y0 += y_offset - y, x = np.ogrid[:size, :size] - y = y[::-1] - - return ((x - x0)**2 + (y-y0)**2)<= r**2 - -def _get_pattern(shape, w_pattern='ring', generator=None): - gt_init = torch.randn(shape, generator=generator) - - if 'rand' in w_pattern: - gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) - gt_patch[:] = gt_patch[0] - elif 'zeros' in w_pattern: - gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0 - elif 'ring' in w_pattern: - gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) - - gt_patch_tmp = gt_patch.clone().detach() - for i in range(shape[-1] // 2, 0, -1): - tmp_mask = _circle_mask(gt_init.shape[-1], r=i) - tmp_mask = torch.tensor(tmp_mask) - - for j in range(gt_patch.shape[1]): - gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item() - - return gt_patch - - -# def get_noise(shape: Union[torch.Size, List, Tuple], model_hash: str) -> torch.Tensor: -def 
get_noise(shape: Union[torch.Size, List, Tuple], model_hash: str, org, generator=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # for now we hard code all hyperparameters - w_channel = 0 # id for watermarked channel - w_radius = 10 # watermark radius - w_pattern = 'rand' # watermark pattern - - # get watermark key and mask - np_mask = _circle_mask(shape[-1], r=w_radius) - torch_mask = torch.tensor(np_mask) - w_mask = torch.zeros(shape, dtype=torch.bool) - w_mask[:, w_channel] = torch_mask - - w_key = _get_pattern(shape, w_pattern=w_pattern, generator=generator) - - # inject watermark - assert len(shape) == 4, f"Make sure you pass a `shape` tuple/list of length 4 not {len(shape)}" - assert shape[0] == 1, f"For now only batch_size=1 is supported, not {shape[0]}." - - init_latents = torch.randn(shape, generator=generator) - - init_latents_fft = torch.fft.fftshift(torch.fft.fft2(init_latents), dim=(-1, -2)) - init_latents_fft[w_mask] = w_key[w_mask].clone() - init_latents = torch.fft.ifft2(torch.fft.ifftshift(init_latents_fft, dim=(-1, -2))).real - - # convert the tensor to bytes - tensor_bytes = init_latents.numpy().tobytes() - - # generate a hash from the bytes - hash_object = hashlib.sha256(tensor_bytes) - hex_dig = hash_object.hexdigest() - - file_name = "_".join([hex_dig, str(w_channel), str(w_radius), w_pattern]) + ".npy" - temp_dir = tempfile.gettempdir() - file_path = os.path.join(temp_dir, file_name) - np.save(file_path, w_key) - - # org = get_org() - - repo_id = os.path.join(org, model_hash) - - api.create_repo(repo_id=repo_id, exist_ok=True, repo_type="dataset") - - api.upload_file( - path_or_fileobj=file_path, - path_in_repo=file_name, - repo_id=repo_id, - repo_type="dataset", - ) - - return init_latents diff --git a/src/metr/finetune_ldm_decoder.py b/src/metr/finetune_ldm_decoder.py index 4423e48..74f8eb8 100644 --- a/src/metr/finetune_ldm_decoder.py +++ b/src/metr/finetune_ldm_decoder.py @@ -5,11 +5,11 @@ # LICENSE file in the root directory of this source tree. import argparse +import importlib import json import os import sys from copy import deepcopy -from omegaconf import OmegaConf from pathlib import Path from typing import Callable, Iterable @@ -18,6 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from omegaconf import OmegaConf from torch.utils.data import DataLoader from torchvision import transforms from torchvision.utils import save_image @@ -26,25 +27,27 @@ # import .stable_sig.utils_img # import .stable_sig.utils_model -import importlib + def import_from_stable_sig(name): module = importlib.import_module(".stable_sig." 
+ name, package=__package__) return module + utils = import_from_stable_sig("utils") utils_img = import_from_stable_sig("utils_img") utils_model = import_from_stable_sig("utils_model") +from tqdm import tqdm + +import wandb + # sys.path.append('src') from .ldm.models.autoencoder import AutoencoderKL from .ldm.models.diffusion.ddpm import LatentDiffusion from .loss.loss_provider import LossProvider -import wandb -from tqdm import tqdm - -device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") def get_parser(): @@ -53,48 +56,58 @@ def get_parser(): def aa(*args, **kwargs): group.add_argument(*args, **kwargs) - group = parser.add_argument_group('Data parameters') + group = parser.add_argument_group("Data parameters") aa("--train_dir", type=str, help="Path to the training data directory", required=True) aa("--val_dir", type=str, help="Path to the validation data directory", required=True) - group = parser.add_argument_group('Model parameters') - aa("--ldm_config", type=str, default="v2-inference.yaml", help="Path to the configuration file for the LDM model") - aa("--ldm_ckpt", type=str, default="v2-1_512-ema-pruned.ckpt", help="Path to the checkpoint file for the LDM model") - aa("--msg_decoder_path", type=str, default= "dec_48b_whit.torchscript.pt", help="Path to the hidden decoder for the watermarking model") + group = parser.add_argument_group("Model parameters") + aa("--ldm_config", type=str, default="v2-inference.yaml", help="Path to the configuration file for the LDM model") + aa("--ldm_ckpt", type=str, default="v2-1_512-ema-pruned.ckpt", help="Path to the checkpoint file for the LDM model") + aa( + "--msg_decoder_path", + type=str, + default="dec_48b_whit.torchscript.pt", + help="Path to the hidden decoder for the watermarking model", + ) aa("--num_bits", type=int, default=48, help="Number of bits in the watermark") aa("--redundancy", type=int, default=1, help="Number of times the watermark is repeated to increase robustness") aa("--decoder_depth", type=int, default=8, help="Depth of the decoder in the watermarking model") aa("--decoder_channels", type=int, default=64, help="Number of channels in the decoder of the watermarking model") - group = parser.add_argument_group('Training parameters') + group = parser.add_argument_group("Training parameters") aa("--batch_size", type=int, default=4, help="Batch size for training") aa("--img_size", type=int, default=256, help="Resize images to this size") - aa("--loss_i", type=str, default="watson-vgg", help="Type of loss for the image loss. Can be watson-vgg, mse, watson-dft, etc.") + aa( + "--loss_i", + type=str, + default="watson-vgg", + help="Type of loss for the image loss. Can be watson-vgg, mse, watson-dft, etc.", + ) aa("--loss_w", type=str, default="bce", help="Type of loss for the watermark loss. 
Can be mse or bce") aa("--lambda_i", type=float, default=0.2, help="Weight of the image loss in the total loss") aa("--lambda_w", type=float, default=1.0, help="Weight of the watermark loss in the total loss") aa("--optimizer", type=str, default="AdamW,lr=5e-4", help="Optimizer and learning rate for training") aa("--steps", type=int, default=100, help="Number of steps to train the model for") aa("--warmup_steps", type=int, default=20, help="Number of warmup steps for the optimizer") - aa("--num_val_imgs", type=int, default=200, help="Number of images for validation") + aa("--num_val_imgs", type=int, default=200, help="Number of images for validation") - group = parser.add_argument_group('Logging and saving freq. parameters') + group = parser.add_argument_group("Logging and saving freq. parameters") aa("--log_freq", type=int, default=10, help="Logging frequency (in steps)") aa("--save_img_freq", type=int, default=1000, help="Frequency of saving generated images (in steps)") - aa('--with_tracking', action='store_true') - aa('--project_name', default='watermark_attacks') - aa('--run_name', default='test') + aa("--with_tracking", action="store_true") + aa("--project_name", default="watermark_attacks") + aa("--run_name", default="test") - group = parser.add_argument_group('Experiments parameters') + group = parser.add_argument_group("Experiments parameters") aa("--num_keys", type=int, default=1, help="Number of fine-tuned checkpoints to generate") aa("--output_dir", type=str, default="output/", help="Output directory for logs and images (Default: /output)") aa("--checkpoint_name", default=None) aa("--seed", type=int, default=0) aa("--debug", type=utils.bool_inst, default=False, help="Debug mode") - group = parser.add_argument_group('Additional parameters') + group = parser.add_argument_group("Additional parameters") aa("--not_rand_key", action="store_true") - aa("--key_str", type=str, default='111010110101000001010111010011010100010000100111') + aa("--key_str", type=str, default="111010110101000001010111010011010100010000100111") aa("--no_attacks", action="store_true") return parser @@ -102,15 +115,15 @@ def aa(*args, **kwargs): def main(params): if params.with_tracking: - wandb_run = wandb.init(project=params.project_name, name=params.run_name, tags=['tree_ring_watermark']) + wandb_run = wandb.init(project=params.project_name, name=params.run_name, tags=["tree_ring_watermark"]) else: wandb_run = None - # Set seeds for reproductibility + # Set seeds for reproductibility torch.manual_seed(params.seed) torch.cuda.manual_seed_all(params.seed) np.random.seed(params.seed) - + # Print the arguments print("__git__:{}".format(utils.get_sha())) print("__log__:{}".format(json.dumps(vars(params)))) @@ -118,13 +131,13 @@ def main(params): # Create the directories if not os.path.exists(params.output_dir): os.makedirs(params.output_dir) - imgs_dir = os.path.join(params.output_dir, 'imgs') + imgs_dir = os.path.join(params.output_dir, "imgs") params.imgs_dir = imgs_dir if not os.path.exists(imgs_dir): os.makedirs(imgs_dir, exist_ok=True) # Loads LDM auto-encoder models - print(f'>>> Building LDM model with config {params.ldm_config} and weights from {params.ldm_ckpt}...') + print(f">>> Building LDM model with config {params.ldm_config} and weights from {params.ldm_ckpt}...") config = OmegaConf.load(f"{params.ldm_config}") ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, params.ldm_ckpt) ldm_ae: AutoencoderKL = ldm_ae.first_stage_model @@ -138,41 +151,48 @@ def main(params): ldm_ae.eval() 
ldm_ae.to(device) - + # Loads hidden decoder - print(f'>>> Building hidden decoder with weights from {params.msg_decoder_path}...') - if 'torchscript' in params.msg_decoder_path: + print(f">>> Building hidden decoder with weights from {params.msg_decoder_path}...") + if "torchscript" in params.msg_decoder_path: msg_decoder = torch.jit.load(params.msg_decoder_path).to(device) # already whitened - + else: - msg_decoder = utils_model.get_hidden_decoder(num_bits=params.num_bits, redundancy=params.redundancy, num_blocks=params.decoder_depth, channels=params.decoder_channels).to(device) + msg_decoder = utils_model.get_hidden_decoder( + num_bits=params.num_bits, + redundancy=params.redundancy, + num_blocks=params.decoder_depth, + channels=params.decoder_channels, + ).to(device) ckpt = utils_model.get_hidden_decoder_ckpt(params.msg_decoder_path) print(msg_decoder.load_state_dict(ckpt, strict=False)) msg_decoder.eval() # whitening - print(f'>>> Whitening...') + print(f">>> Whitening...") with torch.no_grad(): # features from the dataset - transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(256), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(256), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) loader = utils.get_dataloader(params.train_dir, transform, batch_size=16, collate_fn=None) ys = [] for i, x in enumerate(loader): x = x.to(device) y = msg_decoder(x) - ys.append(y.to('cpu')) + ys.append(y.to("cpu")) ys = torch.cat(ys, dim=0) nbit = ys.shape[1] - + # whitening - mean = ys.mean(dim=0, keepdim=True) # NxD -> 1xD - ys_centered = ys - mean # NxD + mean = ys.mean(dim=0, keepdim=True) # NxD -> 1xD + ys_centered = ys - mean # NxD cov = ys_centered.T @ ys_centered e, v = torch.linalg.eigh(cov) L = torch.diag(1.0 / torch.pow(e, exponent=0.5)) @@ -184,9 +204,9 @@ def main(params): msg_decoder = nn.Sequential(msg_decoder, linear.to(device)) torchscript_m = torch.jit.script(msg_decoder) params.msg_decoder_path = params.msg_decoder_path.replace(".pth", "_whit.pth") - print(f'>>> Creating torchscript at {params.msg_decoder_path}...') + print(f">>> Creating torchscript at {params.msg_decoder_path}...") torch.jit.save(torchscript_m, params.msg_decoder_path) - + msg_decoder.eval() nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(device)).shape[-1] @@ -195,58 +215,78 @@ def main(params): param.requires_grad = False # Loads the data - print(f'>>> Loading data from {params.train_dir} and {params.val_dir}...') - vqgan_transform = transforms.Compose([ - transforms.Resize(params.img_size), - transforms.CenterCrop(params.img_size), - transforms.ToTensor(), - utils_img.normalize_vqgan, - ]) - train_loader = utils.get_dataloader(params.train_dir, vqgan_transform, params.batch_size, num_imgs=params.batch_size*params.steps, shuffle=True, num_workers=4, collate_fn=None) - val_loader = utils.get_dataloader(params.val_dir, vqgan_transform, params.batch_size*4, num_imgs=params.num_val_imgs, shuffle=False, num_workers=4, collate_fn=None) + print(f">>> Loading data from {params.train_dir} and {params.val_dir}...") + vqgan_transform = transforms.Compose( + [ + transforms.Resize(params.img_size), + transforms.CenterCrop(params.img_size), + transforms.ToTensor(), + utils_img.normalize_vqgan, + ] + ) + train_loader = utils.get_dataloader( + params.train_dir, + vqgan_transform, + 
params.batch_size, + num_imgs=params.batch_size * params.steps, + shuffle=True, + num_workers=4, + collate_fn=None, + ) + val_loader = utils.get_dataloader( + params.val_dir, + vqgan_transform, + params.batch_size * 4, + num_imgs=params.num_val_imgs, + shuffle=False, + num_workers=4, + collate_fn=None, + ) vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img]) - + # Create losses - print(f'>>> Creating losses...') - print(f'Losses: {params.loss_w} and {params.loss_i}...') - if params.loss_w == 'mse': - loss_w = lambda decoded, keys, temp=10.0: torch.mean((decoded*temp - (2*keys-1))**2) # b k - b k - elif params.loss_w == 'bce': - loss_w = lambda decoded, keys, temp=10.0: F.binary_cross_entropy_with_logits(decoded*temp, keys, reduction='mean') + print(f">>> Creating losses...") + print(f"Losses: {params.loss_w} and {params.loss_i}...") + if params.loss_w == "mse": + loss_w = lambda decoded, keys, temp=10.0: torch.mean((decoded * temp - (2 * keys - 1)) ** 2) # b k - b k + elif params.loss_w == "bce": + loss_w = lambda decoded, keys, temp=10.0: F.binary_cross_entropy_with_logits( + decoded * temp, keys, reduction="mean" + ) else: raise NotImplementedError - - if params.loss_i == 'mse': - loss_i = lambda imgs_w, imgs: torch.mean((imgs_w - imgs)**2) - elif params.loss_i == 'watson-dft': + + if params.loss_i == "mse": + loss_i = lambda imgs_w, imgs: torch.mean((imgs_w - imgs) ** 2) + elif params.loss_i == "watson-dft": provider = LossProvider() - loss_percep = provider.get_loss_function('Watson-DFT', colorspace='RGB', pretrained=True, reduction='sum') + loss_percep = provider.get_loss_function("Watson-DFT", colorspace="RGB", pretrained=True, reduction="sum") loss_percep = loss_percep.to(device) - loss_i = lambda imgs_w, imgs: loss_percep((1+imgs_w)/2.0, (1+imgs)/2.0)/ imgs_w.shape[0] - elif params.loss_i == 'watson-vgg': + loss_i = lambda imgs_w, imgs: loss_percep((1 + imgs_w) / 2.0, (1 + imgs) / 2.0) / imgs_w.shape[0] + elif params.loss_i == "watson-vgg": provider = LossProvider() - loss_percep = provider.get_loss_function('Watson-VGG', colorspace='RGB', pretrained=True, reduction='sum') + loss_percep = provider.get_loss_function("Watson-VGG", colorspace="RGB", pretrained=True, reduction="sum") loss_percep = loss_percep.to(device) - loss_i = lambda imgs_w, imgs: loss_percep((1+imgs_w)/2.0, (1+imgs)/2.0)/ imgs_w.shape[0] - elif params.loss_i == 'ssim': + loss_i = lambda imgs_w, imgs: loss_percep((1 + imgs_w) / 2.0, (1 + imgs) / 2.0) / imgs_w.shape[0] + elif params.loss_i == "ssim": provider = LossProvider() - loss_percep = provider.get_loss_function('SSIM', colorspace='RGB', pretrained=True, reduction='sum') + loss_percep = provider.get_loss_function("SSIM", colorspace="RGB", pretrained=True, reduction="sum") loss_percep = loss_percep.to(device) - loss_i = lambda imgs_w, imgs: loss_percep((1+imgs_w)/2.0, (1+imgs)/2.0)/ imgs_w.shape[0] + loss_i = lambda imgs_w, imgs: loss_percep((1 + imgs_w) / 2.0, (1 + imgs) / 2.0) / imgs_w.shape[0] else: raise NotImplementedError for ii_key in range(params.num_keys): # Creating key - print(f'\n>>> Creating key with {nbit} bits...') + print(f"\n>>> Creating key with {nbit} bits...") key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=device) - key_str = "".join([ str(int(ii)) for ii in key.tolist()[0]]) + key_str = "".join([str(int(ii)) for ii in key.tolist()[0]]) if params.not_rand_key: key_str = params.key_str bit_list = [int(char) for char in key_str] key = torch.tensor(bit_list, dtype=torch.float32, 
device=device) - print(f'Key: {key_str}') + print(f"Key: {key_str}") # Copy the LDM decoder and finetune the copy ldm_decoder = deepcopy(ldm_ae) @@ -259,24 +299,39 @@ def main(params): optimizer = utils.build_optimizer(model_params=ldm_decoder.parameters(), **optim_params) # Training loop - print(f'>>> Training...') - - train_stats = train(train_loader, optimizer, loss_w, loss_i, ldm_ae, ldm_decoder, msg_decoder, vqgan_to_imnet, key, params, wandb_run) + print(f">>> Training...") + + train_stats = train( + train_loader, + optimizer, + loss_w, + loss_i, + ldm_ae, + ldm_decoder, + msg_decoder, + vqgan_to_imnet, + key, + params, + wandb_run, + ) val_stats = val(val_loader, ldm_ae, ldm_decoder, msg_decoder, vqgan_to_imnet, key, params, wandb_run) - log_stats = {'key': key_str, - **{f'train_{k}': v for k, v in train_stats.items()}, - **{f'val_{k}': v for k, v in val_stats.items()}, - } + log_stats = { + "key": key_str, + **{f"train_{k}": v for k, v in train_stats.items()}, + **{f"val_{k}": v for k, v in val_stats.items()}, + } save_dict = { - 'ldm_decoder': ldm_decoder.state_dict(), - 'optimizer': optimizer.state_dict(), - 'params': params, + "ldm_decoder": ldm_decoder.state_dict(), + "optimizer": optimizer.state_dict(), + "params": params, } # Save checkpoint # torch.save(save_dict, os.path.join(params.output_dir, f"checkpoint_{ii_key:03d}.pth")) if not params.checkpoint_name: - torch.save(ldm_decoder.state_dict(), os.path.join(params.output_dir, f"ldm_decoder_checkpoint_{ii_key:03d}.pth")) + torch.save( + ldm_decoder.state_dict(), os.path.join(params.output_dir, f"ldm_decoder_checkpoint_{ii_key:03d}.pth") + ) else: torch.save(ldm_decoder.state_dict(), os.path.join(params.output_dir, f"{params.checkpoint_name}.pth")) torch.save(optimizer.state_dict(), os.path.join(params.output_dir, f"optimizer_checkpoint_{ii_key:03d}.pth")) @@ -286,33 +341,46 @@ def main(params): f.write(json.dumps(log_stats) + "\n") with (Path(params.output_dir) / "keys.txt").open("a") as f: f.write(os.path.join(params.output_dir, f"checkpoint_{ii_key:03d}.pth") + "\t" + key_str + "\n") - print('\n') - -def train(data_loader: Iterable, optimizer: torch.optim.Optimizer, loss_w: Callable, loss_i: Callable, ldm_ae: AutoencoderKL, ldm_decoder:AutoencoderKL, msg_decoder: nn.Module, vqgan_to_imnet:nn.Module, key: torch.Tensor, params: argparse.Namespace, wandb_run: wandb.run): - header = 'Train' + print("\n") + + +def train( + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + loss_w: Callable, + loss_i: Callable, + ldm_ae: AutoencoderKL, + ldm_decoder: AutoencoderKL, + msg_decoder: nn.Module, + vqgan_to_imnet: nn.Module, + key: torch.Tensor, + params: argparse.Namespace, + wandb_run: wandb.run, +): + header = "Train" metric_logger = utils.MetricLogger(delimiter=" ") ldm_decoder.decoder.train() base_lr = optimizer.param_groups[0]["lr"] for ii, imgs in enumerate(metric_logger.log_every(data_loader, params.log_freq, header)): imgs = imgs.to(device) keys = key.repeat(imgs.shape[0], 1) - + utils.adjust_learning_rate(optimizer, ii, params.steps, params.warmup_steps, base_lr) # encode images - imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f + imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f imgs_z = imgs_z.mode() # decode latents with original and finetuned decoder - imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w - imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w + imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w + imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w # extract 
watermark - decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k + decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k print("-----------------") print_decoded = (decoded > 0).to(int) - print_decoded = "".join([ str(int(ii)) for ii in print_decoded.tolist()[0]]) - print(f'decoded: {print_decoded}') + print_decoded = "".join([str(int(ii)) for ii in print_decoded.tolist()[0]]) + print(f"decoded: {print_decoded}") # compute loss lossw = loss_w(decoded, keys) @@ -325,9 +393,9 @@ def train(data_loader: Iterable, optimizer: torch.optim.Optimizer, loss_w: Calla optimizer.zero_grad() # log stats - diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k - bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b - word_accs = (bit_accs == 1) # b + diff = ~torch.logical_xor(decoded > 0, keys > 0) # b k -> b k + bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b + word_accs = bit_accs == 1 # b log_stats = { "iteration": ii, "loss": loss.item(), @@ -340,37 +408,59 @@ def train(data_loader: Iterable, optimizer: torch.optim.Optimizer, loss_w: Calla "lr": optimizer.param_groups[0]["lr"], } for name, loss in log_stats.items(): - metric_logger.update(**{name:loss}) + metric_logger.update(**{name: loss}) if ii % params.log_freq == 0: print(json.dumps(log_stats)) - + # save images during training if ii % params.save_img_freq == 0: - save_image(torch.clamp(utils_img.unnormalize_vqgan(imgs),0,1), os.path.join(params.imgs_dir, f'{ii:03}_train_orig.png'), nrow=8) - save_image(torch.clamp(utils_img.unnormalize_vqgan(imgs_d0),0,1), os.path.join(params.imgs_dir, f'{ii:03}_train_d0.png'), nrow=8) - save_image(torch.clamp(utils_img.unnormalize_vqgan(imgs_w),0,1), os.path.join(params.imgs_dir, f'{ii:03}_train_w.png'), nrow=8) + save_image( + torch.clamp(utils_img.unnormalize_vqgan(imgs), 0, 1), + os.path.join(params.imgs_dir, f"{ii:03}_train_orig.png"), + nrow=8, + ) + save_image( + torch.clamp(utils_img.unnormalize_vqgan(imgs_d0), 0, 1), + os.path.join(params.imgs_dir, f"{ii:03}_train_d0.png"), + nrow=8, + ) + save_image( + torch.clamp(utils_img.unnormalize_vqgan(imgs_w), 0, 1), + os.path.join(params.imgs_dir, f"{ii:03}_train_w.png"), + nrow=8, + ) if params.with_tracking: - wandb_run.log({'Train_cycle': log_stats}) - - print("Averaged {} stats:".format('train'), metric_logger) + wandb_run.log({"Train_cycle": log_stats}) + + print("Averaged {} stats:".format("train"), metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + @torch.no_grad() -def val(data_loader: Iterable, ldm_ae: AutoencoderKL, ldm_decoder: AutoencoderKL, msg_decoder: nn.Module, vqgan_to_imnet:nn.Module, key: torch.Tensor, params: argparse.Namespace, wandb_run: wandb.run): - header = 'Eval' +def val( + data_loader: Iterable, + ldm_ae: AutoencoderKL, + ldm_decoder: AutoencoderKL, + msg_decoder: nn.Module, + vqgan_to_imnet: nn.Module, + key: torch.Tensor, + params: argparse.Namespace, + wandb_run: wandb.run, +): + header = "Eval" metric_logger = utils.MetricLogger(delimiter=" ") ldm_decoder.decoder.eval() for ii, imgs in enumerate(metric_logger.log_every(data_loader, params.log_freq, header)): - + imgs = imgs.to(device) - imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f + imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f imgs_z = imgs_z.mode() - imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w - imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w - + imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w + imgs_w = ldm_decoder.decode(imgs_z) # b z 
h/f w/f -> b c h w + keys = key.repeat(imgs.shape[0], 1) log_stats = { @@ -379,57 +469,68 @@ def val(data_loader: Iterable, ldm_ae: AutoencoderKL, ldm_decoder: AutoencoderKL # "psnr_ori": utils_img.psnr(imgs_w, imgs).mean().item(), } attacks = { - 'none': lambda x: x, - 'crop_01': lambda x: utils_img.center_crop(x, 0.1), - 'crop_05': lambda x: utils_img.center_crop(x, 0.5), - 'rot_25': lambda x: utils_img.rotate(x, 25), - 'rot_90': lambda x: utils_img.rotate(x, 90), - 'resize_03': lambda x: utils_img.resize(x, 0.3), - 'resize_07': lambda x: utils_img.resize(x, 0.7), - 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5), - 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2), - 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80), - 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50), + "none": lambda x: x, + "crop_01": lambda x: utils_img.center_crop(x, 0.1), + "crop_05": lambda x: utils_img.center_crop(x, 0.5), + "rot_25": lambda x: utils_img.rotate(x, 25), + "rot_90": lambda x: utils_img.rotate(x, 90), + "resize_03": lambda x: utils_img.resize(x, 0.3), + "resize_07": lambda x: utils_img.resize(x, 0.7), + "brightness_1p5": lambda x: utils_img.adjust_brightness(x, 1.5), + "brightness_2": lambda x: utils_img.adjust_brightness(x, 2), + "jpeg_80": lambda x: utils_img.jpeg_compress(x, 80), + "jpeg_50": lambda x: utils_img.jpeg_compress(x, 50), } if params.no_attacks: - attacks = { - 'none': lambda x: x - } + attacks = {"none": lambda x: x} for name, attack in attacks.items(): imgs_aug = attack(vqgan_to_imnet(imgs_w)) - decoded = msg_decoder(imgs_aug) # b c h w -> b k - + decoded = msg_decoder(imgs_aug) # b c h w -> b k + print("-----------------") print_decoded = (decoded > 0).to(int) - print_decoded = "".join([ str(int(ii)) for ii in print_decoded.tolist()[0]]) - print(f'decoded: {print_decoded}') + print_decoded = "".join([str(int(ii)) for ii in print_decoded.tolist()[0]]) + print(f"decoded: {print_decoded}") - diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k - bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b - word_accs = (bit_accs == 1) # b - log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item() - log_stats[f'word_acc_{name}'] = torch.mean(word_accs.type(torch.float)).item() + diff = ~torch.logical_xor(decoded > 0, keys > 0) # b k -> b k + bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b + word_accs = bit_accs == 1 # b + log_stats[f"bit_acc_{name}"] = torch.mean(bit_accs).item() + log_stats[f"word_acc_{name}"] = torch.mean(word_accs.type(torch.float)).item() for name, loss in log_stats.items(): - metric_logger.update(**{name:loss}) + metric_logger.update(**{name: loss}) if params.with_tracking: - wandb_run.log({'Val_cycle': log_stats}) + wandb_run.log({"Val_cycle": log_stats}) if ii % params.save_img_freq == 0: - save_image(torch.clamp(utils_img.unnormalize_vqgan(imgs),0,1), os.path.join(params.imgs_dir, f'{ii:03}_val_orig.png'), nrow=8) - save_image(torch.clamp(utils_img.unnormalize_vqgan(imgs_d0),0,1), os.path.join(params.imgs_dir, f'{ii:03}_val_d0.png'), nrow=8) - save_image(torch.clamp(utils_img.unnormalize_vqgan(imgs_w),0,1), os.path.join(params.imgs_dir, f'{ii:03}_val_w.png'), nrow=8) - + save_image( + torch.clamp(utils_img.unnormalize_vqgan(imgs), 0, 1), + os.path.join(params.imgs_dir, f"{ii:03}_val_orig.png"), + nrow=8, + ) + save_image( + torch.clamp(utils_img.unnormalize_vqgan(imgs_d0), 0, 1), + os.path.join(params.imgs_dir, f"{ii:03}_val_d0.png"), + nrow=8, + ) + save_image( + 
torch.clamp(utils_img.unnormalize_vqgan(imgs_w), 0, 1), + os.path.join(params.imgs_dir, f"{ii:03}_val_w.png"), + nrow=8, + ) + # if params.with_tracking: # wandb_run.log({'Val_final': wandb.Table(dataframe=pd.Dataframe.from_dict(log_stats))}) - print("Averaged {} stats:".format('eval'), metric_logger) + print("Averaged {} stats:".format("eval"), metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -if __name__ == '__main__': + +if __name__ == "__main__": # generate parser / parse parameters parser = get_parser() diff --git a/src/metr/guided_diffusion/dist_util.py b/src/metr/guided_diffusion/dist_util.py index 7acb48b..df40bc9 100644 --- a/src/metr/guided_diffusion/dist_util.py +++ b/src/metr/guided_diffusion/dist_util.py @@ -7,9 +7,9 @@ import socket import blobfile as bf -from mpi4py import MPI import torch as th import torch.distributed as dist +from mpi4py import MPI # Change this to reflect your cluster layout. # The GPU for a given rank is (rank % GPUS_PER_NODE). @@ -55,7 +55,7 @@ def load_state_dict(path, **kwargs): """ Load a PyTorch file without redundant fetches across MPI ranks. """ - chunk_size = 2 ** 30 # MPI has a relatively small size limit + chunk_size = 2**30 # MPI has a relatively small size limit if MPI.COMM_WORLD.Get_rank() == 0: with bf.BlobFile(path, "rb") as f: data = f.read() diff --git a/src/metr/guided_diffusion/fp16_util.py b/src/metr/guided_diffusion/fp16_util.py index df3882d..66ce892 100644 --- a/src/metr/guided_diffusion/fp16_util.py +++ b/src/metr/guided_diffusion/fp16_util.py @@ -40,9 +40,7 @@ def make_master_params(param_groups_and_shapes): master_params = [] for param_group, shape in param_groups_and_shapes: master_param = nn.Parameter( - _flatten_dense_tensors( - [param.detach().float() for (_, param) in param_group] - ).view(shape) + _flatten_dense_tensors([param.detach().float() for (_, param) in param_group]).view(shape) ) master_param.requires_grad = True master_params.append(master_param) @@ -54,12 +52,10 @@ def model_grads_to_master_grads(param_groups_and_shapes, master_params): Copy the gradients from the model parameters into the master parameters from make_master_params(). 
""" - for master_param, (param_group, shape) in zip( - master_params, param_groups_and_shapes - ): - master_param.grad = _flatten_dense_tensors( - [param_grad_or_zeros(param) for (_, param) in param_group] - ).view(shape) + for master_param, (param_group, shape) in zip(master_params, param_groups_and_shapes): + master_param.grad = _flatten_dense_tensors([param_grad_or_zeros(param) for (_, param) in param_group]).view( + shape + ) def master_params_to_model_params(param_groups_and_shapes, master_params): @@ -92,14 +88,10 @@ def get_param_groups_and_shapes(named_model_params): return [scalar_vector_named_params, matrix_named_params] -def master_params_to_state_dict( - model, param_groups_and_shapes, master_params, use_fp16 -): +def master_params_to_state_dict(model, param_groups_and_shapes, master_params, use_fp16): if use_fp16: state_dict = model.state_dict() - for master_param, (param_group, _) in zip( - master_params, param_groups_and_shapes - ): + for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): for (name, _), unflat_master_param in zip( param_group, unflatten_master_params(param_group, master_param.view(-1)) ): @@ -115,9 +107,7 @@ def master_params_to_state_dict( def state_dict_to_master_params(model, state_dict, use_fp16): if use_fp16: - named_model_params = [ - (name, state_dict[name]) for name, _ in model.named_parameters() - ] + named_model_params = [(name, state_dict[name]) for name, _ in model.named_parameters()] param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) master_params = make_master_params(param_groups_and_shapes) else: @@ -164,9 +154,7 @@ def __init__( self.lg_loss_scale = initial_lg_loss_scale if self.use_fp16: - self.param_groups_and_shapes = get_param_groups_and_shapes( - self.model.named_parameters() - ) + self.param_groups_and_shapes = get_param_groups_and_shapes(self.model.named_parameters()) self.master_params = make_master_params(self.param_groups_and_shapes) self.model.convert_to_fp16() @@ -175,7 +163,7 @@ def zero_grad(self): def backward(self, loss: th.Tensor): if self.use_fp16: - loss_scale = 2 ** self.lg_loss_scale + loss_scale = 2**self.lg_loss_scale (loss * loss_scale).backward() else: loss.backward() @@ -189,7 +177,7 @@ def optimize(self, opt: th.optim.Optimizer): def _optimize_fp16(self, opt: th.optim.Optimizer): logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) - grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) + grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale) if check_overflow(grad_norm): self.lg_loss_scale -= 1 logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") @@ -200,7 +188,7 @@ def _optimize_fp16(self, opt: th.optim.Optimizer): logger.logkv_mean("param_norm", param_norm) for p in self.master_params: - p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + p.grad.mul_(1.0 / (2**self.lg_loss_scale)) opt.step() zero_master_grads(self.master_params) master_params_to_model_params(self.param_groups_and_shapes, self.master_params) @@ -225,9 +213,7 @@ def _compute_norms(self, grad_scale=1.0): return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) def master_params_to_state_dict(self, master_params): - return master_params_to_state_dict( - self.model, self.param_groups_and_shapes, master_params, self.use_fp16 - ) + return master_params_to_state_dict(self.model, self.param_groups_and_shapes, master_params, self.use_fp16) def 
state_dict_to_master_params(self, state_dict): return state_dict_to_master_params(self.model, state_dict, self.use_fp16) diff --git a/src/metr/guided_diffusion/gaussian_diffusion.py b/src/metr/guided_diffusion/gaussian_diffusion.py index 132d1d6..550e7c8 100644 --- a/src/metr/guided_diffusion/gaussian_diffusion.py +++ b/src/metr/guided_diffusion/gaussian_diffusion.py @@ -7,13 +7,13 @@ import enum import math -from PIL import Image import numpy as np import torch as th +from PIL import Image +from .losses import discretized_gaussian_log_likelihood, normal_kl from .nn import mean_flat -from .losses import normal_kl, discretized_gaussian_log_likelihood def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): @@ -31,9 +31,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): scale = 1000 / num_diffusion_timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 - return np.linspace( - beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 - ) + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) elif schedule_name == "cosine": return betas_for_alpha_bar( num_diffusion_timesteps, @@ -89,9 +87,7 @@ class ModelVarType(enum.Enum): class LossType(enum.Enum): MSE = enum.auto() # use raw MSE loss (and KL when learning variances) - RESCALED_MSE = ( - enum.auto() - ) # use raw MSE loss (with RESCALED_KL when learning variances) + RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances) KL = enum.auto() # use the variational lower-bound RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB @@ -152,22 +148,12 @@ def __init__( self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = ( - betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - ) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. - self.posterior_log_variance_clipped = np.log( - np.append(self.posterior_variance[1], self.posterior_variance[1:]) - ) - self.posterior_mean_coef1 = ( - betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - ) - self.posterior_mean_coef2 = ( - (1.0 - self.alphas_cumprod_prev) - * np.sqrt(alphas) - / (1.0 - self.alphas_cumprod) - ) + self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) def q_mean_variance(self, x_start, t): """ @@ -177,13 +163,9 @@ def q_mean_variance(self, x_start, t): :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
""" - mean = ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - ) + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = _extract_into_tensor( - self.log_one_minus_alphas_cumprod, t, x_start.shape - ) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): @@ -202,8 +184,7 @@ def q_sample(self, x_start, t, noise=None): assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def q_posterior_mean_variance(self, x_start, x_t, t): @@ -219,9 +200,7 @@ def q_posterior_mean_variance(self, x_start, x_t, t): + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = _extract_into_tensor( - self.posterior_log_variance_clipped, t, x_t.shape - ) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] @@ -230,9 +209,7 @@ def q_posterior_mean_variance(self, x_start, x_t, t): ) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance( - self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None - ): + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. @@ -267,9 +244,7 @@ def p_mean_variance( model_log_variance = model_var_values model_variance = th.exp(model_log_variance) else: - min_log = _extract_into_tensor( - self.posterior_log_variance_clipped, t, x.shape - ) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. 
frac = (model_var_values + 1) / 2 @@ -299,26 +274,18 @@ def process_xstart(x): return x if self.model_mean_type == ModelMeanType.PREVIOUS_X: - pred_xstart = process_xstart( - self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) - ) + pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) model_mean = model_output elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: if self.model_mean_type == ModelMeanType.START_X: pred_xstart = process_xstart(model_output) else: - pred_xstart = process_xstart( - self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) - ) - model_mean, _, _ = self.q_posterior_mean_variance( - x_start=pred_xstart, x_t=x, t=t - ) + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) else: raise NotImplementedError(self.model_mean_type) - assert ( - model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape - ) + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape return { "mean": model_mean, "variance": model_variance, @@ -337,16 +304,12 @@ def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - - _extract_into_tensor( - self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape - ) - * x_t + - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - pred_xstart + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): @@ -364,9 +327,7 @@ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
""" gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) - new_mean = ( - p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() - ) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): @@ -382,15 +343,11 @@ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn( - x, self._scale_timesteps(t), **model_kwargs - ) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) - out["mean"], _, _ = self.q_posterior_mean_variance( - x_start=out["pred_xstart"], x_t=x, t=t - ) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) return out def p_sample( @@ -429,13 +386,9 @@ def p_sample( model_kwargs=model_kwargs, ) noise = th.randn_like(x) - nonzero_mask = ( - (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 if cond_fn is not None: - out["mean"] = self.condition_mean( - cond_fn, out, x, t, model_kwargs=model_kwargs - ) + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} @@ -484,17 +437,17 @@ def p_sample_loop( progress=progress, ): final = sample - + if return_image: final = final["sample"] final = ((final + 1) * 127.5).clamp(0, 255).to(th.uint8) final = final.permute(0, 2, 3, 1) final = final.contiguous() - + outputs = [] for i in range(final.shape[0]): final_i = final[i].detach().cpu().numpy() - final_i = Image.fromarray(final_i, 'RGB') + final_i = Image.fromarray(final_i, "RGB") outputs.append(final_i) return outputs else: @@ -583,20 +536,11 @@ def ddim_sample( alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) - sigma = ( - eta - * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) - * th.sqrt(1 - alpha_bar / alpha_bar_prev) - ) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) # Equation 12. noise = th.randn_like(x) - mean_pred = ( - out["pred_xstart"] * th.sqrt(alpha_bar_prev) - + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps - ) - nonzero_mask = ( - (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} @@ -625,16 +569,12 @@ def ddim_reverse_sample( # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. 
eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - - out["pred_xstart"] + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed - mean_pred = ( - out["pred_xstart"] * th.sqrt(alpha_bar_next) - + th.sqrt(1 - alpha_bar_next) * eps - ) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} @@ -672,7 +612,7 @@ def ddim_reverse_sample_loop( eta=eta, ): final = sample - + return final["sample"] def ddim_reverse_sample_loop_progressive( @@ -698,7 +638,7 @@ def ddim_reverse_sample_loop_progressive( if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) - + img = th.tensor(np.array(image), device=device, dtype=th.float32).unsqueeze(0).permute(0, 3, 1, 2) / 127.5 - 1 indices = list(range(self.num_timesteps)) @@ -763,11 +703,11 @@ def ddim_sample_loop( final = ((final + 1) * 127.5).clamp(0, 255).to(th.uint8) final = final.permute(0, 2, 3, 1) final = final.contiguous() - + outputs = [] for i in range(final.shape[0]): final_i = final[i].detach().cpu().numpy() - final_i = Image.fromarray(final_i, 'RGB') + final_i = Image.fromarray(final_i, "RGB") outputs.append(final_i) return outputs else: @@ -823,9 +763,7 @@ def ddim_sample_loop_progressive( yield out img = out["sample"] - def _vb_terms_bpd( - self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None - ): + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): """ Get a term for the variational lower-bound. @@ -836,15 +774,9 @@ def _vb_terms_bpd( - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. 
""" - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - ) - out = self.p_mean_variance( - model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs - ) - kl = normal_kl( - true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] - ) + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( @@ -916,9 +848,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): terms["vb"] *= self.num_timesteps / 1000.0 target = { - ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - )[0], + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] @@ -946,9 +876,7 @@ def _prior_bpd(self, x_start): batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl( - mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 - ) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): diff --git a/src/metr/guided_diffusion/image_datasets.py b/src/metr/guided_diffusion/image_datasets.py index 93022ae..05b118d 100644 --- a/src/metr/guided_diffusion/image_datasets.py +++ b/src/metr/guided_diffusion/image_datasets.py @@ -1,10 +1,10 @@ import math import random -from PIL import Image import blobfile as bf -from mpi4py import MPI import numpy as np +from mpi4py import MPI +from PIL import Image from torch.utils.data import DataLoader, Dataset @@ -56,13 +56,9 @@ def load_data( random_flip=random_flip, ) if deterministic: - loader = DataLoader( - dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True - ) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True) else: - loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True - ) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True) while True: yield from loader @@ -128,14 +124,10 @@ def center_crop_arr(pil_image, image_size): # argument, which uses BOX downsampling at powers of two first. # Thus, we do it by hand to improve downsample quality. while min(*pil_image.size) >= 2 * image_size: - pil_image = pil_image.resize( - tuple(x // 2 for x in pil_image.size), resample=Image.BOX - ) + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) scale = image_size / min(*pil_image.size) - pil_image = pil_image.resize( - tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC - ) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 @@ -152,14 +144,10 @@ def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0) # argument, which uses BOX downsampling at powers of two first. 
# Thus, we do it by hand to improve downsample quality. while min(*pil_image.size) >= 2 * smaller_dim_size: - pil_image = pil_image.resize( - tuple(x // 2 for x in pil_image.size), resample=Image.BOX - ) + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) scale = smaller_dim_size / min(*pil_image.size) - pil_image = pil_image.resize( - tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC - ) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) arr = np.array(pil_image) crop_y = random.randrange(arr.shape[0] - image_size + 1) diff --git a/src/metr/guided_diffusion/logger.py b/src/metr/guided_diffusion/logger.py index b1d856d..1c8b4d4 100644 --- a/src/metr/guided_diffusion/logger.py +++ b/src/metr/guided_diffusion/logger.py @@ -3,14 +3,14 @@ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py """ +import datetime +import json import os -import sys -import shutil import os.path as osp -import json -import time -import datetime +import shutil +import sys import tempfile +import time import warnings from collections import defaultdict from contextlib import contextmanager @@ -39,16 +39,14 @@ def __init__(self, filename_or_file): self.file = open(filename_or_file, "wt") self.own_file = True else: - assert hasattr(filename_or_file, "read"), ( - "expected file or str, got %s" % filename_or_file - ) + assert hasattr(filename_or_file, "read"), "expected file or str, got %s" % filename_or_file self.file = filename_or_file self.own_file = False def writekvs(self, kvs): # Create strings for printing key2str = {} - for (key, val) in sorted(kvs.items()): + for key, val in sorted(kvs.items()): if hasattr(val, "__float__"): valstr = "%-8.3g" % val else: @@ -66,11 +64,8 @@ def writekvs(self, kvs): # Write out the data dashes = "-" * (keywidth + valwidth + 7) lines = [dashes] - for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): - lines.append( - "| %s%s | %s%s |" - % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) - ) + for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append("| %s%s | %s%s |" % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))) lines.append(dashes) self.file.write("\n".join(lines) + "\n") @@ -83,7 +78,7 @@ def _truncate(self, s): def writeseq(self, seq): seq = list(seq) - for (i, elem) in enumerate(seq): + for i, elem in enumerate(seq): self.file.write(elem) if i < len(seq) - 1: # add space unless this is the last one self.file.write(" ") @@ -125,7 +120,7 @@ def writekvs(self, kvs): self.file.seek(0) lines = self.file.readlines() self.file.seek(0) - for (i, k) in enumerate(self.keys): + for i, k in enumerate(self.keys): if i > 0: self.file.write(",") self.file.write(k) @@ -134,7 +129,7 @@ def writekvs(self, kvs): self.file.write(line[:-1]) self.file.write(self.sep * len(extra_keys)) self.file.write("\n") - for (i, k) in enumerate(self.keys): + for i, k in enumerate(self.keys): if i > 0: self.file.write(",") v = kvs.get(k) @@ -159,8 +154,8 @@ def __init__(self, dir): prefix = "events" path = osp.join(osp.abspath(dir), prefix) import tensorflow as tf - from tensorflow.python import pywrap_tensorflow from tensorflow.core.util import event_pb2 + from tensorflow.python import pywrap_tensorflow from tensorflow.python.util import compat self.tf = tf @@ -175,9 +170,7 @@ def summary_val(k, v): summary = self.tf.Summary(value=[summary_val(k, v) for 
k, v in kvs.items()]) event = self.event_pb2.Event(wall_time=time.time(), summary=summary) - event.step = ( - self.step - ) # is there any reason why you'd want to specify the step? + event.step = self.step # is there any reason why you'd want to specify the step? self.writer.WriteEvent(event) self.writer.Flush() self.step += 1 @@ -229,7 +222,7 @@ def logkvs(d): """ Log a dictionary of key-value pairs """ - for (k, v) in d.items(): + for k, v in d.items(): logkv(k, v) @@ -358,10 +351,7 @@ def dumpkvs(self): else: d = mpi_weighted_mean( self.comm, - { - name: (val, self.name2cnt.get(name, 1)) - for (name, val) in self.name2val.items() - }, + {name: (val, self.name2cnt.get(name, 1)) for (name, val) in self.name2val.items()}, ) if self.comm.rank != 0: d["dummy"] = 1 # so we don't get a warning about empty dict @@ -421,16 +411,12 @@ def mpi_weighted_mean(comm, local_name2valcount): name2sum = defaultdict(float) name2count = defaultdict(float) for n2vc in all_name2valcount: - for (name, (val, count)) in n2vc.items(): + for name, (val, count) in n2vc.items(): try: val = float(val) except ValueError: if comm.rank == 0: - warnings.warn( - "WARNING: tried to compute mean on non-float {}={}".format( - name, val - ) - ) + warnings.warn("WARNING: tried to compute mean on non-float {}={}".format(name, val)) else: name2sum[name] += val * count name2count[name] += count @@ -492,4 +478,3 @@ def scoped_configure(dir=None, format_strs=None, comm=None): finally: Logger.CURRENT.close() Logger.CURRENT = prevlogger - diff --git a/src/metr/guided_diffusion/losses.py b/src/metr/guided_diffusion/losses.py index 251e42e..187f31b 100644 --- a/src/metr/guided_diffusion/losses.py +++ b/src/metr/guided_diffusion/losses.py @@ -5,7 +5,6 @@ """ import numpy as np - import torch as th @@ -25,18 +24,9 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for th.exp(). - logvar1, logvar2 = [ - x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + th.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * th.exp(-logvar2) - ) + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) def approx_standard_normal_cdf(x): diff --git a/src/metr/guided_diffusion/nn.py b/src/metr/guided_diffusion/nn.py index a4cd59c..c168b00 100644 --- a/src/metr/guided_diffusion/nn.py +++ b/src/metr/guided_diffusion/nn.py @@ -111,9 +111,9 @@ def timestep_embedding(timesteps, dim, max_period=10000): :return: an [N x dim] Tensor of positional embeddings. """ half = dim // 2 - freqs = th.exp( - -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half - ).to(device=timesteps.device) + freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to( + device=timesteps.device + ) args = timesteps[:, None].float() * freqs[None] embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) if dim % 2: diff --git a/src/metr/guided_diffusion/resample.py b/src/metr/guided_diffusion/resample.py index c82eccd..4a99c14 100644 --- a/src/metr/guided_diffusion/resample.py +++ b/src/metr/guided_diffusion/resample.py @@ -80,10 +80,7 @@ def update_with_local_losses(self, local_ts, local_losses): :param local_ts: an integer Tensor of timesteps. 
:param local_losses: a 1D Tensor of losses. """ - batch_sizes = [ - th.tensor([0], dtype=th.int32, device=local_ts.device) - for _ in range(dist.get_world_size()) - ] + batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())] dist.all_gather( batch_sizes, th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), @@ -97,9 +94,7 @@ def update_with_local_losses(self, local_ts, local_losses): loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] dist.all_gather(timestep_batches, local_ts) dist.all_gather(loss_batches, local_losses) - timesteps = [ - x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] - ] + timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] self.update_with_all_losses(timesteps, losses) @@ -126,15 +121,13 @@ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): self.diffusion = diffusion self.history_per_term = history_per_term self.uniform_prob = uniform_prob - self._loss_history = np.zeros( - [diffusion.num_timesteps, history_per_term], dtype=np.float64 - ) + self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) def weights(self): if not self._warmed_up(): return np.ones([self.diffusion.num_timesteps], dtype=np.float64) - weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) weights /= np.sum(weights) weights *= 1 - self.uniform_prob weights += self.uniform_prob / len(weights) diff --git a/src/metr/guided_diffusion/respace.py b/src/metr/guided_diffusion/respace.py index b568817..95bdd7c 100644 --- a/src/metr/guided_diffusion/respace.py +++ b/src/metr/guided_diffusion/respace.py @@ -32,9 +32,7 @@ def space_timesteps(num_timesteps, section_counts): for i in range(1, num_timesteps): if len(range(0, num_timesteps, i)) == desired_count: return set(range(0, num_timesteps, i)) - raise ValueError( - f"cannot create exactly {num_timesteps} steps with an integer stride" - ) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") section_counts = [int(x) for x in section_counts.split(",")] size_per = num_timesteps // len(section_counts) extra = num_timesteps % len(section_counts) @@ -43,9 +41,7 @@ def space_timesteps(num_timesteps, section_counts): for i, section_count in enumerate(section_counts): size = size_per + (1 if i < extra else 0) if size < section_count: - raise ValueError( - f"cannot divide section of {size} steps into {section_count}" - ) + raise ValueError(f"cannot divide section of {size} steps into {section_count}") if section_count <= 1: frac_stride = 1 else: @@ -85,14 +81,10 @@ def __init__(self, use_timesteps, **kwargs): kwargs["betas"] = np.array(new_betas) super().__init__(**kwargs) - def p_mean_variance( - self, model, *args, **kwargs - ): # pylint: disable=signature-differs + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) - def training_losses( - self, model, *args, **kwargs - ): # pylint: disable=signature-differs + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs return super().training_losses(self._wrap_model(model), *args, **kwargs) def condition_mean(self, cond_fn, *args, 
**kwargs): @@ -104,9 +96,7 @@ def condition_score(self, cond_fn, *args, **kwargs): def _wrap_model(self, model): if isinstance(model, _WrappedModel): return model - return _WrappedModel( - model, self.timestep_map, self.rescale_timesteps, self.original_num_steps - ) + return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) def _scale_timesteps(self, t): # Scaling is done by the wrapped model. diff --git a/src/metr/guided_diffusion/script_util.py b/src/metr/guided_diffusion/script_util.py index 2bfdad9..3ded493 100644 --- a/src/metr/guided_diffusion/script_util.py +++ b/src/metr/guided_diffusion/script_util.py @@ -3,7 +3,7 @@ from . import gaussian_diffusion as gd from .respace import SpacedDiffusion, space_timesteps -from .unet import SuperResModel, UNetModel, EncoderUNetModel +from .unet import EncoderUNetModel, SuperResModel, UNetModel NUM_CLASSES = 1000 @@ -407,15 +407,9 @@ def create_gaussian_diffusion( return SpacedDiffusion( use_timesteps=space_timesteps(steps, timestep_respacing), betas=betas, - model_mean_type=( - gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X - ), + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), model_var_type=( - ( - gd.ModelVarType.FIXED_LARGE - if not sigma_small - else gd.ModelVarType.FIXED_SMALL - ) + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) if not learn_sigma else gd.ModelVarType.LEARNED_RANGE ), diff --git a/src/metr/guided_diffusion/train_util.py b/src/metr/guided_diffusion/train_util.py index 97c7db3..7dec4ff 100644 --- a/src/metr/guided_diffusion/train_util.py +++ b/src/metr/guided_diffusion/train_util.py @@ -45,11 +45,7 @@ def __init__( self.batch_size = batch_size self.microbatch = microbatch if microbatch > 0 else batch_size self.lr = lr - self.ema_rate = ( - [ema_rate] - if isinstance(ema_rate, float) - else [float(x) for x in ema_rate.split(",")] - ) + self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")] self.log_interval = log_interval self.save_interval = save_interval self.resume_checkpoint = resume_checkpoint @@ -72,21 +68,14 @@ def __init__( fp16_scale_growth=fp16_scale_growth, ) - self.opt = AdamW( - self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay - ) + self.opt = AdamW(self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay) if self.resume_step: self._load_optimizer_state() # Model was resumed, either due to a restart or a checkpoint # being specified at the command line. - self.ema_params = [ - self._load_ema_parameters(rate) for rate in self.ema_rate - ] + self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate] else: - self.ema_params = [ - copy.deepcopy(self.mp_trainer.master_params) - for _ in range(len(self.ema_rate)) - ] + self.ema_params = [copy.deepcopy(self.mp_trainer.master_params) for _ in range(len(self.ema_rate))] if th.cuda.is_available(): self.use_ddp = True @@ -100,10 +89,7 @@ def __init__( ) else: if dist.get_world_size() > 1: - logger.warn( - "Distributed training requires CUDA. " - "Gradients will not be synchronized properly!" - ) + logger.warn("Distributed training requires CUDA. 
" "Gradients will not be synchronized properly!") self.use_ddp = False self.ddp_model = self.model @@ -114,11 +100,7 @@ def _load_and_sync_parameters(self): self.resume_step = parse_resume_step_from_filename(resume_checkpoint) if dist.get_rank() == 0: logger.log(f"loading model from checkpoint: {resume_checkpoint}...") - self.model.load_state_dict( - dist_util.load_state_dict( - resume_checkpoint, map_location=dist_util.dev() - ) - ) + self.model.load_state_dict(dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())) dist_util.sync_params(self.model.parameters()) @@ -130,9 +112,7 @@ def _load_ema_parameters(self, rate): if ema_checkpoint: if dist.get_rank() == 0: logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") - state_dict = dist_util.load_state_dict( - ema_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(ema_checkpoint, map_location=dist_util.dev()) ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) dist_util.sync_params(ema_params) @@ -140,21 +120,14 @@ def _load_ema_parameters(self, rate): def _load_optimizer_state(self): main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint - opt_checkpoint = bf.join( - bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" - ) + opt_checkpoint = bf.join(bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt") if bf.exists(opt_checkpoint): logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") - state_dict = dist_util.load_state_dict( - opt_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev()) self.opt.load_state_dict(state_dict) def run_loop(self): - while ( - not self.lr_anneal_steps - or self.step + self.resume_step < self.lr_anneal_steps - ): + while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: batch, cond = next(self.data) self.run_step(batch, cond) if self.step % self.log_interval == 0: @@ -181,10 +154,7 @@ def forward_backward(self, batch, cond): self.mp_trainer.zero_grad() for i in range(0, batch.shape[0], self.microbatch): micro = batch[i : i + self.microbatch].to(dist_util.dev()) - micro_cond = { - k: v[i : i + self.microbatch].to(dist_util.dev()) - for k, v in cond.items() - } + micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) @@ -203,14 +173,10 @@ def forward_backward(self, batch, cond): losses = compute_losses() if isinstance(self.schedule_sampler, LossAwareSampler): - self.schedule_sampler.update_with_local_losses( - t, losses["loss"].detach() - ) + self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach()) loss = (losses["loss"] * weights).mean() - log_loss_dict( - self.diffusion, t, {k: v * weights for k, v in losses.items()} - ) + log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()}) self.mp_trainer.backward(loss) def _update_ema(self): diff --git a/src/metr/guided_diffusion/unet.py b/src/metr/guided_diffusion/unet.py index 96b4693..acc415d 100644 --- a/src/metr/guided_diffusion/unet.py +++ b/src/metr/guided_diffusion/unet.py @@ -1,6 +1,5 @@ -from abc import abstractmethod - import math +from abc import abstractmethod import numpy as np import torch as th @@ -8,15 +7,7 @@ import torch.nn.functional as F from .fp16_util import convert_module_to_f16, convert_module_to_f32 -from .nn import ( 
- checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) +from .nn import avg_pool_nd, checkpoint, conv_nd, linear, normalization, timestep_embedding, zero_module class AttentionPool2d(nn.Module): @@ -32,9 +23,7 @@ def __init__( output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 - ) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -100,9 +89,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None): def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: @@ -128,9 +115,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=1 - ) + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -207,17 +192,13 @@ def __init__( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -229,9 +210,7 @@ def forward(self, x, emb): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) def _forward(self, x, emb): if self.updown: @@ -321,7 +300,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. 
- matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -346,9 +325,7 @@ def forward(self, qkv): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -478,9 +455,7 @@ def __init__( self.label_emb = nn.Embedding(num_classes, time_embed_dim) ch = input_ch = int(channel_mult[0] * model_channels) - self.input_blocks = nn.ModuleList( - [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] - ) + self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]) self._feature_size = ch input_block_chans = [ch] ds = 1 @@ -526,9 +501,7 @@ def __init__( down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -736,9 +709,7 @@ def __init__( ) ch = int(channel_mult[0] * model_channels) - self.input_blocks = nn.ModuleList( - [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] - ) + self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]) self._feature_size = ch input_block_chans = [ch] ds = 1 @@ -784,9 +755,7 @@ def __init__( down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -834,9 +803,7 @@ def __init__( self.out = nn.Sequential( normalization(ch), nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), ) elif pool == "spatial": self.out = nn.Sequential( diff --git a/src/metr/inverse_stable_diffusion.py b/src/metr/inverse_stable_diffusion.py index 2a02974..87a68ff 100644 --- a/src/metr/inverse_stable_diffusion.py +++ b/src/metr/inverse_stable_diffusion.py @@ -1,38 +1,36 @@ from functools import partial -from typing import Callable, List, Optional, Union, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - from diffusers.models import AutoencoderKL, UNet2DConditionModel + # from diffusers import StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import \ - StableDiffusionSafetyChecker -from diffusers.schedulers import DDIMScheduler,PNDMScheduler, LMSDiscreteScheduler +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from .modified_stable_diffusion import ModifiedStableDiffusionPipeline - ### credit to: https://github.com/cccntu/efficient-prompt-to-prompt + def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt): - """ from noise to image""" + """from noise to image""" return ( alpha_tm1**0.5 - * ( - (alpha_t**-0.5 - 
alpha_tm1**-0.5) * x_t - + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt - ) + * ((alpha_t**-0.5 - alpha_tm1**-0.5) * x_t + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt) + x_t ) + def forward_ddim(x_t, alpha_t, alpha_tp1, eps_xt): - """ from image to noise, it's the same as backward_ddim""" + """from image to noise, it's the same as backward_ddim""" return backward_ddim(x_t, alpha_t, alpha_tp1, eps_xt) class InversableStableDiffusionPipeline(ModifiedStableDiffusionPipeline): - def __init__(self, + def __init__( + self, vae, text_encoder, tokenizer, @@ -42,17 +40,12 @@ def __init__(self, feature_extractor, requires_safety_checker: bool = True, ): - super(InversableStableDiffusionPipeline, self).__init__(vae, - text_encoder, - tokenizer, - unet, - scheduler, - safety_checker, - feature_extractor, - requires_safety_checker) + super(InversableStableDiffusionPipeline, self).__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True) - + def get_random_latents(self, latents=None, height=512, width=512, generator=None): height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -86,7 +79,7 @@ def get_text_embedding(self, prompt): ).input_ids text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] return text_embeddings - + @torch.inference_mode() def get_image_latents(self, image, sample=True, rng_generator=None): encoding_dist = self.vae.encode(image).latent_dist @@ -97,7 +90,6 @@ def get_image_latents(self, image, sample=True, rng_generator=None): latents = encoding * 0.18215 return latents - @torch.inference_mode() def backward_diffusion( self, @@ -113,8 +105,7 @@ def backward_diffusion( reverse_process: True = False, **kwargs, ): - """ Generate image from text prompt and latents - """ + """Generate image from text prompt and latents""" # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
@@ -132,8 +123,9 @@ def backward_diffusion( else: prompt_to_prompt = False - - for i, t in enumerate(self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))): + for i, t in enumerate( + self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor)) + ): if prompt_to_prompt: if i < use_old_emb_i: text_embeddings = old_text_embeddings @@ -141,33 +133,23 @@ def backward_diffusion( text_embeddings = new_text_embeddings # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings - ).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - prev_timestep = ( - t - - self.scheduler.config.num_train_timesteps - // self.scheduler.num_inference_steps - ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) - - # ddim + + # ddim alpha_prod_t = self.scheduler.alphas_cumprod[t] alpha_prod_t_prev = ( self.scheduler.alphas_cumprod[prev_timestep] @@ -184,13 +166,10 @@ def backward_diffusion( ) return latents - @torch.inference_mode() def decode_image(self, latents: torch.FloatTensor, **kwargs): scaled_latents = 1 / 0.18215 * latents - image = [ - self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents)) - ] + image = [self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents))] image = torch.cat(image, dim=0) return image diff --git a/src/metr/io_utils.py b/src/metr/io_utils.py index 612b5c5..9e1e314 100644 --- a/src/metr/io_utils.py +++ b/src/metr/io_utils.py @@ -1,8 +1,8 @@ -import os import glob import json import logging -from typing import Any, Mapping, Iterable, Union, List, Callable, Optional +import os +from typing import Any, Callable, Iterable, List, Mapping, Optional, Union from tqdm.auto import tqdm @@ -23,7 +23,7 @@ def read_jsonlines(filename: str) -> Iterable[Mapping[str, Any]]: """Yields an iterable of Python dicts after reading jsonlines from the input file.""" file_size = os.path.getsize(filename) with open(filename) as fp: - for line in tqdm(fp.readlines(), desc=f'Reading JSON lines from {filename}', unit='lines'): + for line in tqdm(fp.readlines(), desc=f"Reading JSON lines from {filename}", unit="lines"): try: example = json.loads(line) yield example @@ -33,17 +33,19 @@ def read_jsonlines(filename: str) -> Iterable[Mapping[str, Any]]: raise ex -def hf_read_jsonlines(filename: str, - n: Optional[int]=None, - minimal_questions: Optional[bool]=False, - unique_questions: Optional[bool] = False) -> Iterable[Mapping[str, Any]]: +def hf_read_jsonlines( + filename: str, + n: Optional[int] = None, + minimal_questions: Optional[bool] = False, + unique_questions: Optional[bool] = False, +) -> Iterable[Mapping[str, Any]]: 
"""Yields an iterable of Python dicts after reading jsonlines from the input file. - Optionally reads only first n lines from file.""" + Optionally reads only first n lines from file.""" file_size = os.path.getsize(filename) # O(n) but no memory with open(filename) as f: - num_lines= sum(1 for _ in f) - if n is None: + num_lines = sum(1 for _ in f) + if n is None: n = num_lines # returning a generator with the scope stmt seemed to be the issue, but I am not 100% sure @@ -54,7 +56,9 @@ def line_generator(): unique_qc_ids = set() # note, I am p sure that readlines is not lazy, returns a list, thus really only the # object conversion is lazy - for i, line in tqdm(enumerate(open(filename).readlines()[:n]), desc=f'Reading JSON lines from {filename}', unit='lines'): + for i, line in tqdm( + enumerate(open(filename).readlines()[:n]), desc=f"Reading JSON lines from {filename}", unit="lines" + ): try: full_example = json.loads(line) @@ -71,12 +75,12 @@ def line_generator(): full_example = full_example q_object = full_example["object"] q_object.pop("question_info") - example= {} + example = {} example["object"] = { - "answer":q_object["answer"], - "clue_spans":q_object["clue_spans"], - "qc_id":q_object["qc_id"], - "question_text":q_object["question_text"], + "answer": q_object["answer"], + "clue_spans": q_object["clue_spans"], + "qc_id": q_object["qc_id"], + "question_text": q_object["question_text"], } yield example @@ -84,6 +88,7 @@ def line_generator(): logging.error(f'Input text: "{line}"') logging.error(ex.args) raise ex + return line_generator @@ -94,10 +99,10 @@ def load_jsonlines(filename: str) -> List[Mapping[str, Any]]: def write_jsonlines(objs: Iterable[Mapping[str, Any]], filename: str, to_dict: Callable = lambda x: x): """Writes a list of Python Mappings as jsonlines at the input file.""" - with open(filename, 'w') as fp: - for obj in tqdm(objs, desc=f'Writing JSON lines at {filename}'): + with open(filename, "w") as fp: + for obj in tqdm(objs, desc=f"Writing JSON lines at {filename}"): fp.write(json.dumps(to_dict(obj))) - fp.write('\n') + fp.write("\n") def read_json(filename: str) -> Mapping[str, Any]: @@ -106,11 +111,11 @@ def read_json(filename: str) -> Mapping[str, Any]: return json.load(fp) -def write_json(obj: Mapping[str, Any], filename: str, indent:int=None): +def write_json(obj: Mapping[str, Any], filename: str, indent: int = None): """Writes a Python Mapping at the input file in JSON format.""" - with open(filename, 'w') as fp: + with open(filename, "w") as fp: json.dump(obj, fp, indent=indent) def print_json(d, indent=4): - print(json.dumps(d, indent=indent)) \ No newline at end of file + print(json.dumps(d, indent=indent)) diff --git a/src/metr/ldm/data/base.py b/src/metr/ldm/data/base.py index b196c2f..1b6a138 100644 --- a/src/metr/ldm/data/base.py +++ b/src/metr/ldm/data/base.py @@ -1,11 +1,13 @@ from abc import abstractmethod -from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + +from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset class Txt2ImgIterableBaseDataset(IterableDataset): - ''' + """ Define an interface to make the IterableDatasets for text2img data chainable - ''' + """ + def __init__(self, num_records=0, valid_ids=None, size=256): super().__init__() self.num_records = num_records @@ -13,11 +15,11 @@ def __init__(self, num_records=0, valid_ids=None, size=256): self.sample_ids = valid_ids self.size = size - print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + 
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.") def __len__(self): return self.num_records @abstractmethod def __iter__(self): - pass \ No newline at end of file + pass diff --git a/src/metr/ldm/data/imagenet.py b/src/metr/ldm/data/imagenet.py index 1c473f9..8483e16 100644 --- a/src/metr/ldm/data/imagenet.py +++ b/src/metr/ldm/data/imagenet.py @@ -1,32 +1,35 @@ -import os, yaml, pickle, shutil, tarfile, glob -import cv2 +import glob +import os +import pickle +import shutil +import tarfile +from functools import partial + import albumentations -import PIL +import cv2 import numpy as np +import PIL +import taming.data.utils as tdu import torchvision.transforms.functional as TF +import yaml +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light from omegaconf import OmegaConf -from functools import partial from PIL import Image -from tqdm import tqdm +from taming.data.imagenet import ImagePaths, download, give_synsets_from_indices, retrieve, str_to_indices from torch.utils.data import Dataset, Subset - -import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve -from taming.data.imagenet import ImagePaths - -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light +from tqdm import tqdm def synset2idx(path_to_yaml="data/index_synset.yaml"): with open(path_to_yaml) as f: di2s = yaml.load(f) - return dict((v,k) for k,v in di2s.items()) + return dict((v, k) for k, v in di2s.items()) class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config)==dict: + if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) self.process_images = True # if False we skip loading & processing images and self.data contains filepaths @@ -46,9 +49,11 @@ def _prepare(self): raise NotImplementedError() def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) + ignore = set( + [ + "n06596364_9591.JPEG", + ] + ) relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] if "sub_indices" in self.config: indices = str_to_indices(self.config["sub_indices"]) @@ -67,20 +72,19 @@ def _prepare_synset_to_human(self): SIZE = 2655750 URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): + if not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE: download(URL, self.human_dict) def _prepare_idx_to_synset(self): URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): + if not os.path.exists(self.idx2syn): download(URL, self.idx2syn) def _prepare_human_to_integer_label(self): URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") - if (not os.path.exists(self.human2integer)): + if not os.path.exists(self.human2integer): download(URL, self.human2integer) with open(self.human2integer, "r") as f: lines = f.read().splitlines() @@ -122,11 +126,12 @@ def _load(self): if self.process_images: self.size = retrieve(self.config, "size", default=256) - self.data = 
ImagePaths(self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) + self.data = ImagePaths( + self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) else: self.data = self.abspaths @@ -157,8 +162,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 1281167 - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -166,8 +170,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -179,7 +184,7 @@ def _prepare(self): print("Extracting sub-tars.") subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] + subdir = subpath[: -len(".tar")] os.makedirs(subdir, exist_ok=True) with tarfile.open(subpath, "r:") as tar: tar.extractall(path=subdir) @@ -187,7 +192,7 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) @@ -222,8 +227,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 50000 - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -231,8 +235,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -242,7 +247,7 @@ def _prepare(self): tar.extractall(path=datadir) vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]: download(self.VS_URL, vspath) with open(vspath, "r") as f: @@ -261,18 +266,15 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) tdu.mark_prepared(self.root) - class ImageNetSR(Dataset): - def __init__(self, size=None, - degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., - random_crop=True): + def __init__(self, size=None, degradation=None, downscale_f=4, 
min_crop_f=0.5, max_crop_f=1.0, random_crop=True): """ Imagenet Superresolution Dataloader Performs following ops in order: @@ -296,12 +298,12 @@ def __init__(self, size=None, self.LR_size = int(size / downscale_f) self.min_crop_f = min_crop_f self.max_crop_f = max_crop_f - assert(max_crop_f <= 1.) + assert max_crop_f <= 1.0 self.center_crop = not random_crop self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) - self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow if degradation == "bsrgan": self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) @@ -311,17 +313,17 @@ def __init__(self, size=None, else: interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, }[degradation] self.pil_interpolation = degradation.startswith("pil_") @@ -330,8 +332,9 @@ def __init__(self, size=None, self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) else: - self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, - interpolation=interpolation_fn) + self.degradation_process = albumentations.SmallestMaxSize( + max_size=self.LR_size, interpolation=interpolation_fn + ) def __len__(self): return len(self.base) @@ -366,8 +369,8 @@ def __getitem__(self, i): else: LR_image = self.degradation_process(image=image)["image"] - example["image"] = (image/127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32) return example @@ -379,7 +382,9 @@ def __init__(self, **kwargs): def get_base(self): with open("data/imagenet_train_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetTrain(process_images=False,) + dset = ImageNetTrain( + process_images=False, + ) return Subset(dset, indices) @@ -390,5 +395,7 @@ def __init__(self, **kwargs): def get_base(self): with open("data/imagenet_val_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetValidation(process_images=False,) + dset = ImageNetValidation( + process_images=False, + ) return Subset(dset, indices) diff --git a/src/metr/ldm/data/lsun.py b/src/metr/ldm/data/lsun.py index 6256e45..799f630 100644 --- a/src/metr/ldm/data/lsun.py +++ b/src/metr/ldm/data/lsun.py @@ -1,4 +1,5 @@ import os + import numpy as np import PIL from PIL import Image @@ -7,13 +8,7 @@ class LSUNBase(Dataset): - def __init__(self, - txt_file, - data_root, - size=None, - interpolation="bicubic", - flip_p=0.5 - ): + def __init__(self, txt_file, data_root, size=None, interpolation="bicubic", flip_p=0.5): self.data_paths = txt_file self.data_root = data_root 
with open(self.data_paths, "r") as f: @@ -21,16 +16,16 @@ def __init__(self, self._length = len(self.image_paths) self.labels = { "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], } self.size = size - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): @@ -45,9 +40,14 @@ def __getitem__(self, i): # default to score-sde preprocessing img = np.array(image).astype(np.uint8) crop = min(img.shape[0], img.shape[1]) - h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] image = Image.fromarray(img) if self.size is not None: @@ -65,9 +65,10 @@ def __init__(self, **kwargs): class LSUNChurchesValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__( + txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", flip_p=flip_p, **kwargs + ) class LSUNBedroomsTrain(LSUNBase): @@ -77,8 +78,7 @@ def __init__(self, **kwargs): class LSUNBedroomsValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", - flip_p=flip_p, **kwargs) + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", flip_p=flip_p, **kwargs) class LSUNCatsTrain(LSUNBase): @@ -87,6 +87,5 @@ def __init__(self, **kwargs): class LSUNCatsValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", flip_p=flip_p, **kwargs) diff --git a/src/metr/ldm/lr_scheduler.py b/src/metr/ldm/lr_scheduler.py index be39da9..d30c6d8 100644 --- a/src/metr/ldm/lr_scheduler.py +++ b/src/metr/ldm/lr_scheduler.py @@ -5,18 +5,20 @@ class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. 
+ self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr @@ -24,13 +26,12 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) self.last_lr = lr return lr def __call__(self, n, **kwargs): - return self.schedule(n,**kwargs) + return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: @@ -38,6 +39,7 @@ class LambdaWarmUpCosineScheduler2: supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) self.lr_warm_up_steps = warm_up_steps @@ -46,7 +48,7 @@ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosit self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0. + self.last_f = 0.0 self.verbosity_interval = verbosity_interval def find_in_interval(self, n): @@ -60,8 +62,8 @@ def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f @@ -69,8 +71,7 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi)) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) self.last_f = f return f @@ -84,15 +85,16 @@ def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] + ) self.last_f = f return f - diff --git a/src/metr/ldm/models/autoencoder.py 
b/src/metr/ldm/models/autoencoder.py index d122549..7b04047 100644 --- a/src/metr/ldm/models/autoencoder.py +++ b/src/metr/ldm/models/autoencoder.py @@ -1,28 +1,28 @@ -import torch -import pytorch_lightning as pl -import torch.nn.functional as F from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from ldm.modules.diffusionmodules.model import Decoder, Encoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution - -from ldm.util import instantiate_from_config from ldm.modules.ema import LitEma +from ldm.util import instantiate_from_config class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False - ): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + ): super().__init__() self.learn_logvar = learn_logvar self.image_key = image_key @@ -30,11 +30,11 @@ def __init__(self, self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels)==int + assert type(colorize_nlabels) == int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -42,7 +42,7 @@ def __init__(self, self.use_ema = ema_decay is not None if self.use_ema: self.ema_decay = ema_decay - assert 0. < ema_decay < 1. 
+ assert 0.0 < ema_decay < 1.0 self.model_ema = LitEma(self, decay=ema_decay) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") @@ -112,16 +112,30 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 0: # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) return aeloss if optimizer_idx == 1: # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) @@ -136,11 +150,25 @@ def validation_step(self, batch, batch_idx): def _validation_step(self, batch, batch_idx, postfix=""): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) self.log_dict(log_dict_ae) @@ -149,15 +177,17 @@ def _validation_step(self, batch, batch_idx, postfix=""): def configure_optimizers(self): lr = self.learning_rate - ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + ae_params_list = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()) + ) if self.learn_logvar: print(f"{self.__class__.__name__}: Learning logvar") ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) return [opt_ae, opt_disc], [] def get_last_layer(self): @@ -194,7 +224,7 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -216,4 +246,3 @@ def quantize(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs): return x - diff --git a/src/metr/ldm/models/diffusion/ddim.py b/src/metr/ldm/models/diffusion/ddim.py index 27ead0e..93c5373 100644 --- a/src/metr/ldm/models/diffusion/ddim.py +++ b/src/metr/ldm/models/diffusion/ddim.py @@ -1,11 +1,15 @@ """SAMPLING ONLY.""" -import torch import numpy as np +import torch +from ldm.modules.diffusionmodules.util import ( + extract_into_tensor, + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - class DDIMSampler(object): def __init__(self, model, schedule="linear", **kwargs): @@ -20,67 +24,75 @@ def register_buffer(self, name, attr): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ dynamic_threshold=None, + ucg_schedule=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") @@ -98,35 +110,53 @@ def sample(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) + print(f"Data shape for DDIM sampling is {size}, eta {eta}") + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) return samples, intermediates @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -140,12 +170,12 @@ def ddim_sampling(self, cond, shape, subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -154,37 +184,60 @@ def 
ddim_sampling(self, cond, shape, if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if ucg_schedule is not None: assert len(ucg_schedule) == len(time_range) unconditional_guidance_scale = ucg_schedule[i] - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: model_output = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -194,13 +247,9 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F c_in = dict() for k in c: if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in range(len(c[k]))] + c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) elif isinstance(c, list): c_in = list() assert isinstance(unconditional_conditioning, list) @@ -217,18 +266,20 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F e_t = model_output if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' + assert self.model.parameterization == "eps", "not implemented" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 if self.model.parameterization != "v": @@ -243,16 +294,25 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F raise NotImplementedError() # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None, + ): num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] assert t_enc <= num_reference_steps @@ -268,33 +328,37 @@ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=No x_next = x0 intermediates = [] inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): + for i in tqdm(range(num_steps), desc="Encoding Image"): t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: + if unconditional_guidance_scale == 1.0: noise_pred = self.model.apply_model(x_next, t, c) else: assert unconditional_conditioning is not None e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) + self.model.apply_model( + torch.cat((x_next, x_next)), torch.cat((t, t)), torch.cat((unconditional_conditioning, c)) + ), + 2, + ) noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + weighted_noise_pred = ( + alphas_next[i].sqrt() * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + ) x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: + if return_intermediates and i % (num_steps // return_intermediates) == 0 and i < num_steps - 1: intermediates.append(x_next) inter_steps.append(i) elif return_intermediates 
and i >= num_steps - 2: intermediates.append(x_next) inter_steps.append(i) - if callback: callback(i) + if callback: + callback(i) - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} if return_intermediates: - out.update({'intermediates': intermediates}) + out.update({"intermediates": intermediates}) return x_next, out @torch.no_grad() @@ -310,12 +374,22 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): if noise is None: noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = timesteps[:t_start] @@ -324,13 +398,20 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco total_steps = timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) x_dec = x_latent for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec diff --git a/src/metr/ldm/models/diffusion/ddpm.py b/src/metr/ldm/models/diffusion/ddpm.py index fbbfeca..8169ecf 100644 --- a/src/metr/ldm/models/diffusion/ddpm.py +++ b/src/metr/ldm/models/diffusion/ddpm.py @@ -6,31 +6,28 @@ -- merci """ -import torch -import torch.nn as nn -import numpy as np -import pytorch_lightning as pl -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange, repeat +import itertools from contextlib import contextmanager, nullcontext from functools import partial -import itertools -from tqdm import tqdm -from torchvision.utils import make_grid -from pytorch_lightning.utilities.rank_zero import rank_zero_only -from omegaconf import ListConfig -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +import numpy as np +import 
pytorch_lightning as pl +import torch +import torch.nn as nn +from einops import rearrange, repeat +from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl +from ldm.modules.ema import LitEma +from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat +from omegaconf import ListConfig +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from torch.optim.lr_scheduler import LambdaLR +from torchvision.utils import make_grid +from tqdm import tqdm - -__conditioning_keys__ = {'concat': 'c_concat', - 'crossattn': 'c_crossattn', - 'adm': 'y'} +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} def disabled_train(self, mode=True): @@ -45,39 +42,40 @@ def uniform_on_device(r1, r2, shape, device): class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space - def __init__(self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0., - make_it_fit=False, - ucg_training=None, - reset_ema=False, - reset_num_ema_updates=False, - ): + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): super().__init__() assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' self.parameterization = parameterization @@ -107,20 +105,29 @@ def __init__(self, if monitor is not None: self.monitor = monitor self.make_it_fit = make_it_fit - if reset_ema: assert exists(ckpt_path) + if reset_ema: + assert exists(ckpt_path) if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) if reset_ema: assert self.use_ema - print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." 
+ ) self.model_ema = LitEma(self.model) if reset_num_ema_updates: print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") assert self.use_ema self.model_ema.reset_num_updates() - self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) self.loss_type = loss_type @@ -133,60 +140,71 @@ def __init__(self, if self.ucg_training: self.ucg_prng = np.random.RandomState() - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): if exists(given_betas): betas = given_betas else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. 
- alphas_cumprod) + self.v_posterior * betas + posterior_variance = (1 - self.v_posterior) * betas * (1.0 - alphas_cumprod_prev) / ( + 1.0 - alphas_cumprod + ) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + "posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + ) + self.register_buffer( + "posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)) + ) if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": - lvlb_weights = torch.ones_like(self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + lvlb_weights = torch.ones_like( + self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + ) else: raise NotImplementedError("mu not supported") lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager @@ -216,14 +234,11 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): print("Deleting key {} from state_dict.".format(k)) del sd[k] if self.make_it_fit: - n_params = len([name for name, _ in - itertools.chain(self.named_parameters(), - self.named_buffers())]) + n_params = len([name for name, _ in itertools.chain(self.named_parameters(), self.named_buffers())]) for name, param in tqdm( - itertools.chain(self.named_parameters(), - self.named_buffers()), - desc="Fitting old weights to new weights", - total=n_params + itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params, ): if not name in sd: continue @@ -259,8 +274,9 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd[name] = new_param - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 
0: print(f"Missing Keys:\n {missing}") @@ -274,35 +290,35 @@ def q_mean_variance(self, x_start, t): :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def predict_start_from_z_and_v(self, x_t, t, v): # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v ) def predict_eps_from_z_and_v(self, x_t, t, v): return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) @@ -315,7 +331,7 @@ def p_mean_variance(self, x, t, clip_denoised: bool): elif self.parameterization == "x0": x_recon = model_out if clip_denoised: - x_recon.clamp_(-1., 1.) 
+ x_recon.clamp_(-1.0, 1.0) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @@ -335,9 +351,10 @@ def p_sample_loop(self, shape, return_intermediates=False): b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) + for i in tqdm(reversed(range(0, self.num_timesteps)), desc="Sampling t", total=self.num_timesteps): + img = self.p_sample( + img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised + ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: @@ -348,30 +365,33 @@ def p_sample_loop(self, shape, return_intermediates=False): def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), return_intermediates=return_intermediates + ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def get_v(self, x, noise, t): return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x ) def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': + if self.loss_type == "l1": loss = (target - pred).abs() if mean: loss = loss.mean() - elif self.loss_type == 'l2': + elif self.loss_type == "l2": if mean: loss = torch.nn.functional.mse_loss(target, pred) else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") else: raise NotImplementedError("unknown loss type '{loss_type}'") @@ -394,17 +414,17 @@ def p_losses(self, x_start, t, noise=None): loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - log_prefix = 'train' if self.training else 'val' + log_prefix = "train" if self.training else "val" - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb - loss_dict.update({f'{log_prefix}/loss': loss}) + loss_dict.update({f"{log_prefix}/loss": loss}) return loss, loss_dict @@ -418,7 +438,7 @@ def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') + x = rearrange(x, "b h w c -> b c 
h w") x = x.to(memory_format=torch.contiguous_format).float() return x @@ -439,15 +459,13 @@ def training_step(self, batch, batch_idx): loss, loss_dict = self.shared_step(batch) - self.log_dict(loss_dict, prog_bar=True, - logger=True, on_step=True, on_epoch=True) + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log("global_step", self.global_step, - prog_bar=True, logger=True, on_step=True, on_epoch=False) + self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) return loss @@ -456,7 +474,7 @@ def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) @@ -466,8 +484,8 @@ def on_train_batch_end(self, *args, **kwargs): def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -486,7 +504,7 @@ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwarg for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) @@ -521,27 +539,30 @@ def configure_optimizers(self): class LatentDiffusion(DDPM): """main class""" - def __init__(self, - first_stage_config, - cond_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - force_null_conditioning=False, - *args, **kwargs): + def __init__( + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + force_null_conditioning=False, + *args, + **kwargs, + ): self.force_null_conditioning = force_null_conditioning self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] + assert self.num_timesteps_cond <= kwargs["timesteps"] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = "concat" if concat_mode else 
"crossattn" + if cond_stage_config == "__is_unconditional__" and not self.force_null_conditioning: conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) reset_ema = kwargs.pop("reset_ema", False) @@ -558,7 +579,7 @@ def __init__(self, if not scale_by_std: self.scale_factor = scale_factor else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.register_buffer("scale_factor", torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward @@ -572,24 +593,33 @@ def __init__(self, if reset_ema: assert self.use_ema print( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) self.model_ema = LitEma(self.model) if reset_num_ema_updates: print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") assert self.use_ema self.model_ema.reset_num_updates() - def make_cond_schedule(self, ): + def make_cond_schedule( + self, + ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids + self.cond_ids[: self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx, dataloader_idx): # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + if ( + self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + assert self.scale_factor == 1.0, "rather not use custom rescaling and std-rescaling simultaneously" # set rescale weight to 1./std of encodings print("### USING STD-RESCALING ###") x = super().get_input(batch, self.first_stage_key) @@ -597,13 +627,19 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor - self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + self.register_buffer("scale_factor", 1.0 / z.flatten().std()) print(f"setting self.scale_factor to {self.scale_factor}") print("### USING STD-RESCALING ###") - def register_schedule(self, - given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 @@ -633,20 +669,21 @@ def instantiate_cond_stage(self, config): for param in self.cond_stage_model.parameters(): param.requires_grad = False else: - assert config != '__is_first_stage__' - assert config != '__is_unconditional__' + assert config != "__is_first_stage__" + assert config != "__is_unconditional__" model = instantiate_from_config(config) self.cond_stage_model = model - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + def _get_denoise_row_from_list(self, samples, desc="", force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): - denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) + denoise_row.append( + self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization) + ) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -661,7 +698,7 @@ def get_first_stage_encoding(self, encoder_posterior): def get_learned_conditioning(self, c): if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + if hasattr(self.cond_stage_model, "encode") and callable(self.cond_stage_model.encode): c = self.cond_stage_model.encode(c) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() @@ -695,15 +732,20 @@ def delta_border(self, h, w): def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) - weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], - self.split_input_params["clip_max_weight"], ) + weighting = torch.clip( + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], + ) weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) if self.split_input_params["tie_braker"]: L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, - self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) + L_weighting = torch.clip( + L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"], + ) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting @@ -734,9 +776,12 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - 
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, padding=0, - stride=(stride[0] * uf, stride[1] * uf)) + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) @@ -747,9 +792,12 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, padding=0, - stride=(stride[0] // df, stride[1] // df)) + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) @@ -762,8 +810,17 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once return fold, unfold, normalization, weighting @torch.no_grad() - def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, - cond_key=None, return_original_cond=False, bs=None, return_x=False): + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_x=False, + ): x = super().get_input(batch, k) if bs is not None: x = x[:bs] @@ -775,9 +832,9 @@ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=F if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox', "txt"]: + if cond_key in ["caption", "coordinates_bbox", "txt"]: xc = batch[cond_key] - elif cond_key in ['class_label', 'cls']: + elif cond_key in ["class_label", "cls"]: xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) @@ -796,14 +853,14 @@ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=F if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} else: c = None xc = None if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} + c = {"pos_x": pos_x, "pos_y": pos_y} out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z) @@ -820,9 +877,9 @@ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() + z = rearrange(z, "b h w c -> b c h w").contiguous() - z = 1. 
/ self.scale_factor * z + z = 1.0 / self.scale_factor * z return self.first_stage_model.decode(z) @torch.no_grad() @@ -852,7 +909,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): else: if not isinstance(cond, list): cond = [cond] - key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + key = "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" cond = {key: cond} x_recon = self.model(x_noisy, t, **cond) @@ -863,8 +920,9 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): return x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _prior_bpd(self, x_start): """ @@ -886,7 +944,7 @@ def p_losses(self, x_start, cond, t, noise=None): model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" if self.parameterization == "x0": target = x_start @@ -898,27 +956,37 @@ def p_losses(self, x_start, cond, t, noise=None): raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) + loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) + loss_dict.update({"logvar": self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f"{prefix}/loss": loss}) return loss, loss_dict - def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, - return_x0=False, score_corrector=None, corrector_kwargs=None): + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) @@ -937,7 +1005,7 @@ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=Fals raise NotImplementedError() if clip_denoised: - x_recon.clamp_(-1., 1.) 
+ x_recon.clamp_(-1.0, 1.0) if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) @@ -949,15 +1017,33 @@ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=Fals return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, - return_codebook_ids=False, quantize_denoised=False, return_x0=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs @@ -967,7 +1053,7 @@ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) @@ -980,10 +1066,25 @@ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, - img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., - score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, - log_every_t=None): + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps @@ -999,47 +1100,76 @@ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quanti intermediates = [] if cond is not None: if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + cond = { + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) + for key in cond + } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - 
total=timesteps) if verbose else reversed( - range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Progressive Generation", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - img, x0_partial = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, return_x0=True, - temperature=temperature[i], noise_dropout=noise_dropout, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() - def p_sample_loop(self, cond, shape, return_intermediates=False, - x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, start_T=None, - log_every_t=None): + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t @@ -1056,8 +1186,11 @@ def p_sample_loop(self, cond, shape, return_intermediates=False, if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if mask is not None: assert x0 is not None @@ -1066,55 +1199,76 @@ def p_sample_loop(self, cond, shape, return_intermediates=False, for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - img = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised) + img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) if mask is not None: img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. 
- mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() - def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, - verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None, **kwargs): + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + cond = { + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) + for key in cond + } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, x_T=x_T, - verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, - mask=mask, x0=x0) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) @torch.no_grad() def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = (self.channels, self.image_size, self.image_size) - samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, - shape, cond, verbose=False, **kwargs) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs) else: - samples, intermediates = self.sample(cond=cond, batch_size=batch_size, - return_intermediates=True, **kwargs) + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs) return samples, intermediates @@ -1138,26 +1292,43 @@ def get_unconditional_conditioning(self, batch_size, null_label=None): raise NotImplementedError("todo") if isinstance(c, list): # in case the encoder gives us a list for i in range(len(c)): - c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device) else: - c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + c = repeat(c, "1 ... 
-> b ...", b=batch_size).to(self.device) return c @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=50, + ddim_eta=0.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1169,10 +1340,10 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0 elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', "cls"]: + elif self.cond_stage_key in ["class_label", "cls"]: try: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc except KeyError: # probably no "human_label" in batch pass @@ -1187,23 +1358,24 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0 z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1211,13 +1383,16 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0 denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid - if quantize_denoised and not isinstance(self.first_stage_model, 
AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): # also display when quantizing x0 while sampling with ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - quantize_denoised=True) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, quantize_denoised=True + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) @@ -1228,11 +1403,15 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0 if self.model.conditioning_key == "crossattn-adm": uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - ) + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg @@ -1241,28 +1420,30 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0 b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask = mask[:, None, ...] with ema_scope("Plotting Inpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask # outpaint - mask = 1. 
- mask + mask = 1.0 - mask with ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1280,20 +1461,15 @@ def configure_optimizers(self): print(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: - print('Diffusion model optimizing logvar') + print("Diffusion model optimizing logvar") params.append(self.logvar) opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: - assert 'target' in self.scheduler_config + assert "target" in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] + scheduler = [{"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1}] return [opt], scheduler return opt @@ -1303,7 +1479,7 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
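# Illustrative sketch (not a definitive implementation): the mask convention used by the
# inpaint/outpaint branches above. Zeros in the mask mark regions to be generated, ones keep
# the (noised) original latents, and outpainting simply flips the mask. The tensor names and
# shapes below are hypothetical stand-ins, not taken from the patch.
import torch

def blend_with_mask(img_orig: torch.Tensor, img_denoised: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # keep the original latents where mask == 1, take the freshly denoised latents where mask == 0
    return img_orig * mask + (1.0 - mask) * img_denoised

b, c, h, w = 2, 4, 64, 64
z_orig = torch.randn(b, c, h, w)     # stands in for q_sample(x0, ts) in p_sample_loop
denoised = torch.randn(b, c, h, w)   # stands in for the current sample returned by p_sample
mask = torch.ones(b, 1, h, w)
mask[:, :, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0  # zeros will be filled in
inpainted = blend_with_mask(z_orig, denoised, mask)
outpainted = blend_with_mask(z_orig, denoised, 1.0 - mask)  # outpainting inverts the mask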
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -1313,34 +1489,34 @@ def __init__(self, diff_model_config, conditioning_key): self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm", "hybrid-adm", "crossattn-adm"] def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): if self.conditioning_key is None: out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': + elif self.conditioning_key == "concat": xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': + elif self.conditioning_key == "crossattn": if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) else: cc = c_crossattn out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': + elif self.conditioning_key == "hybrid": xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'hybrid-adm': + elif self.conditioning_key == "hybrid-adm": assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc, y=c_adm) - elif self.conditioning_key == 'crossattn-adm': + elif self.conditioning_key == "crossattn-adm": assert c_adm is not None cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc, y=c_adm) - elif self.conditioning_key == 'adm': + elif self.conditioning_key == "adm": cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) else: @@ -1370,15 +1546,21 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): if not log_mode: z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) else: - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) x_low = batch[self.low_scale_key][:bs] - x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = rearrange(x_low, "b h w c -> b c h w") x_low = x_low.to(memory_format=torch.contiguous_format).float() zx, noise_level = self.low_scale_model(x_low) if self.noise_level_key is not None: # get noise level from batch instead, e.g. 
when extracting a custom noise level for bsr - raise NotImplementedError('TODO') + raise NotImplementedError("TODO") all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} if log_mode: @@ -1388,16 +1570,30 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): return z, all_conds @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, - unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, - log_mode=True) + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1411,9 +1607,9 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): @@ -1425,23 +1621,24 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1460,7 +1657,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= uc[k] = [uc_tmp] elif k == "c_adm": # todo: only run with text-based 
guidance? assert isinstance(c[k], torch.Tensor) - #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level uc[k] = c[k] elif isinstance(c[k], list): uc[k] = [c[k][i] for i in range(len(c[k]))] @@ -1468,19 +1665,23 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= uc[k] = c[k] with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - ) + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1489,21 +1690,24 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= class LatentFinetuneDiffusion(LatentDiffusion): """ - Basis for different finetunas, such as inpainting or depth2image - To disable finetuning mode, set finetune_keys to None + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None """ - def __init__(self, - concat_keys: tuple, - finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", - "model_ema.diffusion_modelinput_blocks00weight" - ), - keep_finetune_dims=4, - # if model was trained without concat mode before and we would like to keep these channels - c_concat_log_start=None, # to log reconstruction of c_concat codes - c_concat_log_end=None, - *args, **kwargs - ): + def __init__( + self, + concat_keys: tuple, + finetune_keys=( + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs, + ): ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", list()) super().__init__(*args, **kwargs) @@ -1512,7 +1716,8 @@ def __init__(self, self.keep_dims = keep_finetune_dims self.c_concat_log_start = c_concat_log_start self.c_concat_log_end = c_concat_log_end - if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(self.finetune_keys): + assert exists(ckpt_path), "can only finetune from a given checkpoint" if exists(ckpt_path): self.init_from_ckpt(ckpt_path, ignore_keys) @@ -1533,14 +1738,16 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): for name, param in self.named_parameters(): if name in self.finetune_keys: print( - f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + f"modifying key '{name}' and keeping its 
original {self.keep_dims} (channels) dimensions only" + ) new_entry = torch.zeros_like(param) # zero init - assert exists(new_entry), 'did not find matching parameter to modify' - new_entry[:, :self.keep_dims, ...] = sd[k] + assert exists(new_entry), "did not find matching parameter to modify" + new_entry[:, : self.keep_dims, ...] = sd[k] sd[k] = new_entry - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: print(f"Missing Keys: {missing}") @@ -1548,11 +1755,25 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): print(f"Unexpected Keys: {unexpected}") @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None @@ -1570,16 +1791,16 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if not (self.c_concat_log_start is None and self.c_concat_log_end is None): - log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start : self.c_concat_log_end]) if plot_diffusion_rows: # get diffusion row @@ -1587,24 +1808,28 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") 
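# Illustrative sketch of the checkpoint surgery performed by init_from_ckpt above, assuming a
# finetuned UNet whose first conv takes more input channels than the pretrained one: the new
# weight is zero-initialized and only the first keep_dims input channels are copied over, so
# the extra channels start silent. Shapes below are hypothetical.
import torch

keep_dims = 4
new_weight = torch.zeros(320, 9, 3, 3)           # e.g. latent (4) + mask (1) + masked-image (4) input channels
ckpt_weight = torch.randn(320, keep_dims, 3, 3)  # pretrained weight with latent channels only
new_weight[:, :keep_dims, ...] = ckpt_weight     # remaining input channels stay zero at the start of finetuning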
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1617,12 +1842,15 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= uc_cat = c_cat uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc_full, - ) + samples_cfg, _ = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg @@ -1634,13 +1862,9 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): can either run as pure inpainting model (only concat mode) or with mixed conditionings, e.g. mask as concat and text via cross-attn. To disable finetuning mode, set finetune_keys to None - """ + """ - def __init__(self, - concat_keys=("mask", "masked_image"), - masked_image_key="masked_image", - *args, **kwargs - ): + def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs): super().__init__(concat_keys, *args, **kwargs) self.masked_image_key = masked_image_key assert self.masked_image_key in concat_keys @@ -1648,14 +1872,20 @@ def __init__(self, @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) c_cat = list() for ck in self.concat_keys: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + cc = rearrange(batch[ck], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -1674,8 +1904,9 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs @torch.no_grad() def log_images(self, *args, **kwargs): log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) - log["masked_image"] = rearrange(args[0]["masked_image"], - 'b h w c -> b c h 
w').to(memory_format=torch.contiguous_format).float() + log["masked_image"] = ( + rearrange(args[0]["masked_image"], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() + ) return log @@ -1692,9 +1923,15 @@ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwarg @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for depth2img" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -1712,9 +1949,10 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs align_corners=False, ) - depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], - keepdim=True) - cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax( + cc, dim=[1, 2, 3], keepdim=True + ) + cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0 c_cat.append(cc) c_cat = torch.cat(c_cat, dim=1) all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} @@ -1726,18 +1964,21 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) depth = self.depth_model(args[0][self.depth_stage_key]) - depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ - torch.amax(depth, dim=[1, 2, 3], keepdim=True) - log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. 
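# Illustrative sketch of the per-sample depth normalization used for the depth conditioning
# above: each depth map is rescaled to [-1, 1] from its own min/max (get_input additionally adds
# 0.001 to the denominator to guard against a constant map). The input here is random stand-in
# data, not a real MiDaS prediction.
import torch

depth = torch.rand(2, 1, 64, 64) * 10.0
depth_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
depth_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
depth_norm = 2.0 * (depth - depth_min) / (depth_max - depth_min + 0.001) - 1.0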
+ depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), torch.amax( + depth, dim=[1, 2, 3], keepdim=True + ) + log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0 return log class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): """ - condition on low-res image (and optionally on some spatial noise augmentation) + condition on low-res image (and optionally on some spatial noise augmentation) """ - def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, - low_scale_config=None, low_scale_key=None, *args, **kwargs): + + def __init__( + self, concat_keys=("lr",), reshuffle_patch_size=None, low_scale_config=None, low_scale_key=None, *args, **kwargs + ): super().__init__(concat_keys=concat_keys, *args, **kwargs) self.reshuffle_patch_size = reshuffle_patch_size self.low_scale_model = None @@ -1757,9 +1998,15 @@ def instantiate_low_stage(self, config): @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for upscaling-ft" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -1768,11 +2015,15 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs noise_level = None for ck in self.concat_keys: cc = batch[ck] - cc = rearrange(cc, 'b h w c -> b c h w') + cc = rearrange(cc, "b h w c -> b c h w") if exists(self.reshuffle_patch_size): assert isinstance(self.reshuffle_patch_size, int) - cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', - p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + cc = rearrange( + cc, + "b c (p1 h) (p2 w) -> b (p1 p2 c) h w", + p1=self.reshuffle_patch_size, + p2=self.reshuffle_patch_size, + ) if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -1791,5 +2042,5 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs @torch.no_grad() def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) - log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w") return log diff --git a/src/metr/ldm/models/diffusion/dpm_solver/__init__.py b/src/metr/ldm/models/diffusion/dpm_solver/__init__.py index 7427f38..f56611c 100644 --- a/src/metr/ldm/models/diffusion/dpm_solver/__init__.py +++ b/src/metr/ldm/models/diffusion/dpm_solver/__init__.py @@ -1 +1 @@ -from .sampler import DPMSolverSampler \ No newline at end of file +from .sampler import DPMSolverSampler diff --git a/src/metr/ldm/models/diffusion/dpm_solver/dpm_solver.py b/src/metr/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba..a3995f7 100644 --- a/src/metr/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/src/metr/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -1,17 +1,18 @@ +import math + import torch import torch.nn.functional as F -import math from tqdm import tqdm class NoiseScheduleVP: def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - 
continuous_beta_0=0.1, - continuous_beta_1=20., + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, ): """Create a wrapper class for the forward SDE (VP type). *** @@ -70,50 +71,63 @@ def __init__( >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) """ - if schedule not in ['discrete', 'linear', 'cosine']: + if schedule not in ["discrete", "linear", "cosine"]: raise ValueError( "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) + schedule + ) + ) self.schedule = schedule - if schedule == 'discrete': + if schedule == "discrete": if betas is not None: log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) else: assert alphas_cumprod is not None log_alphas = 0.5 * torch.log(alphas_cumprod) self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ) else: self.total_N = 1000 self.beta_0 = continuous_beta_0 self.beta_1 = continuous_beta_1 self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) self.schedule = schedule - if schedule == 'cosine': + if schedule == "cosine": # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. self.T = 0.9946 else: - self.T = 1. + self.T = 1.0 def marginal_log_mean_coeff(self, t): """ Compute log(alpha_t) of a given continuous-time label t in [0, T]. """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device) + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 return log_alpha_t @@ -127,48 +141,56 @@ def marginal_std(self, t): """ Compute sigma_t of a given continuous-time label t in [0, T]. """ - return torch.sqrt(1. - torch.exp(2. 
* self.marginal_log_mean_coeff(t))) + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) def marginal_lambda(self, t): """ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. """ log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) return log_mean_coeff - log_std def inverse_lambda(self, lamb): """ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) return t.reshape((-1,)) else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + t_fn = ( + lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) t = t_fn(log_alpha) return t def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, ): """Create a wrapper function for the noise prediction model. DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to @@ -249,8 +271,8 @@ def get_model_input_time(t_continuous): For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. For continuous-time DPMs, we just use `t_continuous`. """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 else: return t_continuous @@ -302,7 +324,7 @@ def model_fn(x, t_continuous): noise = noise_pred_fn(x, t_continuous) return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad elif guidance_type == "classifier-free": - if guidance_scale == 1. 
or unconditional_condition is None: + if guidance_scale == 1.0 or unconditional_condition is None: return noise_pred_fn(x, t_continuous, cond=condition) else: x_in = torch.cat([x] * 2) @@ -317,7 +339,7 @@ def model_fn(x, t_continuous): class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.0): """Construct a DPM-Solver. We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). @@ -387,20 +409,21 @@ def get_time_steps(self, skip_type, t_T, t_0, N, device): Returns: A pytorch tensor of the time steps, with the shape (N + 1,). """ - if skip_type == 'logSNR': + if skip_type == "logSNR": lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': + elif skip_type == "time_uniform": return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': + elif skip_type == "time_quadratic": t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) return t else: raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ @@ -435,29 +458,57 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] + orders = [ + 3, + ] * ( + K - 1 + ) + [1] else: - orders = [3, ] * (K - 1) + [2] + orders = [ + 3, + ] * ( + K - 1 + ) + [2] elif order == 2: if steps % 2 == 0: K = steps // 2 - orders = [2, ] * K + orders = [ + 2, + ] * K else: K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] + orders = [ + 2, + ] * ( + K - 1 + ) + [1] elif order == 1: K = 1 - orders = [1, ] * steps + orders = [ + 1, + ] * steps else: raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': + if skip_type == "logSNR": # To reproduce the results in DPM-Solver paper timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) else: timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ) + ).to(device) + ] return timesteps_outer, orders def denoise_to_zero_fn(self, x, s): @@ -491,12 +542,9 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal phi_1 = torch.expm1(-h) if model_s is None: model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) + x_t = expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s 
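# Illustrative sketch of the first-order (DDIM-like) DPM-Solver step in data-prediction mode
# shown above: with h = lambda_t - lambda_s and phi_1 = expm1(-h),
#     x_t = (sigma_t / sigma_s) * x - alpha_t * phi_1 * x0_pred.
# `ns` is assumed to be a NoiseScheduleVP instance from this file; x0_pred is a hypothetical
# data-prediction model output; _bcast mirrors the expand_dims helper used throughout the class.
import torch

def _bcast(v: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # reshape a per-sample scalar of shape (b,) so it broadcasts over x of shape (b, c, h, w)
    return v.view(-1, *([1] * (x.dim() - 1)))

def first_order_step_x0(ns, x, x0_pred, s, t):
    lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
    sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
    alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))
    phi_1 = torch.expm1(-(lambda_t - lambda_s))
    return _bcast(sigma_t / sigma_s, x) * x - _bcast(alpha_t * phi_1, x) * x0_pred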
if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t else: @@ -504,16 +552,17 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal if model_s is None: model_s = self.model_fn(x, s) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s ) if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpm_solver" + ): """ Singlestep solver DPM-Solver-2 from time `s` to time `t`. Args: @@ -529,7 +578,7 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 0.5 @@ -539,8 +588,11 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret h = lambda_t - lambda_s lambda_s1 = lambda_s + r1 * h s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) @@ -550,23 +602,19 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret if model_s is None: model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) 
/ h + 1.), dims) * ( - model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * (model_s1 - model_s) ) else: phi_11 = torch.expm1(r1 * h) @@ -575,29 +623,39 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret if model_s is None: model_s = self.model_fn(x, s) x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * (model_s1 - model_s) ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} + return x_t, {"model_s": model_s, "model_s1": model_s1} else: return x_t - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpm_solver", + ): """ Singlestep solver DPM-Solver-3 from time `s` to time `t`. Args: @@ -616,12 +674,12 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., mo Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: - r1 = 1. / 3. + r1 = 1.0 / 3.0 if r2 is None: - r2 = 2. / 3. + r2 = 2.0 / 3.0 ns = self.noise_schedule dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) @@ -630,93 +688,98 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. 
/ 3., mo lambda_s2 = lambda_s + r2 * h s1 = ns.inverse_lambda(lambda_s1) s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) if self.predict_x0: phi_11 = torch.expm1(-r1 * h) phi_12 = torch.expm1(-r2 * h) phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 ) else: phi_11 = torch.expm1(r1 * h) phi_12 = torch.expm1(r2 * h) phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. 
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} else: return x_t @@ -733,14 +796,17 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) ns = self.noise_schedule dims = x.dim() model_prev_1, model_prev_0 = model_prev_list t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -748,36 +814,36 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0 = h_0 / h - D1_0 = expand_dims(1. 
/ r0, dims) * (model_prev_0 - model_prev_1) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) if self.predict_x0: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1_0 ) else: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1_0 ) return x_t - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): """ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. Args: @@ -794,8 +860,12 @@ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, dims = x.dim() model_prev_2, model_prev_1, model_prev_0 = model_prev_list t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -804,28 +874,29 @@ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. 
/ r1, dims) * (model_prev_1 - model_prev_2) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2) D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1) if self.predict_x0: x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims) * D2 ) else: x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims) * D2 ) return x_t - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type="dpm_solver", r1=None, r2=None + ): """ Singlestep DPM-Solver with the order `order` from time `s` to time `t`. Args: @@ -844,15 +915,17 @@ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False if order == 1: return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"): """ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. 
Args: @@ -875,8 +948,9 @@ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpm_solver" + ): """ The adaptive step size solver based on singlestep DPM-Solver. Args: @@ -906,17 +980,17 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol if order == 2: r1 = 0.5 lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) elif order == 3: - r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) else: raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) while torch.abs((s - t_0)).mean() > t_err: @@ -926,20 +1000,31 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): + if torch.all(E <= 1.0): x = x_higher s = t x_prev = x_lower lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) nfe += order - print('adaptive solver nfe', nfe) + print("adaptive solver nfe", nfe) return x - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, - ): + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=3, + skip_type="time_uniform", + method="singlestep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpm_solver", + atol=0.0078, + rtol=0.05, + ): """ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. ===================================================== @@ -1034,14 +1119,15 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time Returns: x_end: A pytorch tensor. The approximated solution at time `t_end`. """ - t_0 = 1. 
/ self.noise_schedule.total_N if t_end is None else t_end + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start device = x.device - if method == 'adaptive': + if method == "adaptive": with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == "multistep": assert steps >= order timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps @@ -1052,8 +1138,9 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # Init the first `order` values by lower order multistep DPM-Solver. for init_order in tqdm(range(1, order), desc="DPM init order"): vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type + ) model_prev_list.append(self.model_fn(x, vec_t)) t_prev_list.append(vec_t) # Compute the remaining values by `order`-th order multistep DPM-Solver. @@ -1063,8 +1150,9 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time step_order = min(order, steps + 1 - step) else: step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type + ) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] @@ -1072,20 +1160,22 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # We do not need to evaluate the final model value. 
if step < steps: model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device + ) + elif method == "singlestep_fixed": K = steps // order - orders = [order, ] * K + orders = [ + order, + ] * K timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) for i, order in enumerate(orders): t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) + timesteps_inner = self.get_time_steps( + skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device + ) lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) h = lambda_inner[-1] - lambda_inner[0] @@ -1101,6 +1191,7 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # other utility functions ############################################################# + def interpolate_fn(x, xp, yp): """ A piecewise linear function y = f(x), using xp and yp as keypoints. @@ -1122,7 +1213,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(1, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) @@ -1132,7 +1225,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(0, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) @@ -1151,4 +1246,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 
""" - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file + return v[(...,) + (None,) * (dims - 1)] diff --git a/src/metr/ldm/models/diffusion/dpm_solver/sampler.py b/src/metr/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8..4104fe3 100644 --- a/src/metr/ldm/models/diffusion/dpm_solver/sampler.py +++ b/src/metr/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,13 +1,10 @@ """SAMPLING ONLY.""" -import torch -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver +import torch +from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} +MODEL_TYPES = {"eps": "noise", "v": "v"} class DPMSolverSampler(object): @@ -15,7 +12,7 @@ def __init__(self, model, **kwargs): super().__init__() self.model = model to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: @@ -24,30 +21,31 @@ def register_buffer(self, name, attr): setattr(self, name, attr) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -61,7 +59,7 @@ def sample(self, C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}") device = self.model.betas.device if x_T is None: @@ -69,7 +67,7 @@ def sample(self, else: img = x_T - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) model_fn = model_wrapper( lambda x, t, c: self.model.apply_model(x, t, c), @@ -82,6 +80,8 @@ def sample(self, ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample( + img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True + ) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/src/metr/ldm/models/diffusion/plms.py b/src/metr/ldm/models/diffusion/plms.py index 7002a36..3e73b5a 100644 --- a/src/metr/ldm/models/diffusion/plms.py +++ b/src/metr/ldm/models/diffusion/plms.py @@ -1,12 +1,12 @@ """SAMPLING ONLY.""" -import torch -import numpy as np -from tqdm import tqdm from functools import partial -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import numpy as np +import torch from ldm.models.diffusion.sampling_util import norm_thresholding +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from tqdm import tqdm class PLMSSampler(object): @@ -22,65 +22,72 @@ def register_buffer(self, name, attr): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - 
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ dynamic_threshold=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -94,34 +101,51 @@ def sample(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) return samples, intermediates @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -135,12 +159,12 @@ def plms_sampling(self, cond, shape, subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running PLMS Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) old_eps = [] for i, step in enumerate(iterator): @@ -151,38 +175,64 @@ def plms_sampling(self, cond, shape, if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. 
- mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, - dynamic_threshold=dynamic_threshold) + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, - dynamic_threshold=None): + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -199,7 +249,9 @@ def get_model_output(x, t): alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas def get_x_prev_and_pred_x0(e_t, index): @@ -207,7 +259,7 @@ def get_x_prev_and_pred_x0(e_t, index): a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 pred_x0 = (x - 
sqrt_one_minus_at * e_t) / a_t.sqrt() @@ -216,9 +268,9 @@ def get_x_prev_and_pred_x0(e_t, index): if dynamic_threshold is not None: pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 diff --git a/src/metr/ldm/models/diffusion/sampling_util.py b/src/metr/ldm/models/diffusion/sampling_util.py index 7eff02b..e8b1ec6 100644 --- a/src/metr/ldm/models/diffusion/sampling_util.py +++ b/src/metr/ldm/models/diffusion/sampling_util.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch def append_dims(x, target_dims): @@ -7,7 +7,7 @@ def append_dims(x, target_dims): From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] @@ -19,4 +19,4 @@ def norm_thresholding(x0, value): def spatial_norm_thresholding(x0, value): # b c h w s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file + return x0 * (value / s) diff --git a/src/metr/ldm/modules/attention.py b/src/metr/ldm/modules/attention.py index 509cd87..0cbcc0d 100644 --- a/src/metr/ldm/modules/attention.py +++ b/src/metr/ldm/modules/attention.py @@ -1,31 +1,33 @@ -from inspect import isfunction import math +from inspect import isfunction +from typing import Any, Optional + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional, Any - from ldm.modules.diffusionmodules.util import checkpoint - +from torch import einsum, nn try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False # CrossAttn precision handling import os + _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + def exists(val): return val is not None def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -57,20 +59,13 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -95,26 +90,10 @@ def __init__(self, in_channels): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - 
kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -124,41 +103,38 @@ def forward(self, x): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) - w_ = w_ * (int(c)**(-0.5)) + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def forward(self, x, context=None, mask=None): h = self.heads @@ -168,29 +144,29 @@ def forward(self, x, context=None, mask=None): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): + if _ATTN_PRECISION == "fp32": + with torch.autocast(enabled=False, device_type="cuda"): q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale else: - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + del q, k - + if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') + mask = rearrange(mask, "b ... 
-> b (...)") max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) + mask = repeat(mask, "b j -> (b h) () j", h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = einsum("b i j, b j d -> b i d", sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) @@ -198,8 +174,10 @@ class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() - print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads.") + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads." + ) inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -246,20 +224,36 @@ def forward(self, x, context=None, mask=None): class BasicTransformerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention + "softmax-xformers": MemoryEfficientCrossAttention, } - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False): + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + ): super().__init__() attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" assert attn_mode in self.ATTENTION_MODES attn_cls = self.ATTENTION_MODES[attn_mode] self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = attn_cls( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) @@ -284,10 +278,19 @@ class SpatialTransformer(nn.Module): Finally, reshape to image NEW: use_linear for more efficiency instead of the 1x1 convs """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, - use_checkpoint=True): + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + ): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] @@ -295,25 +298,26 @@ def __init__(self, in_channels, n_heads, d_head, inner_dim = n_heads * d_head self.norm = Normalize(in_channels) if 
not use_linear: - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) else: self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + ) + for d in range(depth) + ] ) if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) else: self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear @@ -327,15 +331,14 @@ def forward(self, x, context=None): x = self.norm(x) if not self.use_linear: x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + x = rearrange(x, "b c h w -> b (h w) c").contiguous() if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): x = block(x, context=context[i]) if self.use_linear: x = self.proj_out(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() if not self.use_linear: x = self.proj_out(x) return x + x_in - diff --git a/src/metr/ldm/modules/diffusionmodules/model.py b/src/metr/ldm/modules/diffusionmodules/model.py index b089eeb..8b0f4de 100644 --- a/src/metr/ldm/modules/diffusionmodules/model.py +++ b/src/metr/ldm/modules/diffusionmodules/model.py @@ -1,16 +1,17 @@ # pytorch_diffusion + derived encoder decoder import math +from typing import Any, Optional + +import numpy as np import torch import torch.nn as nn -import numpy as np from einops import rearrange -from typing import Optional, Any - from ldm.modules.attention import MemoryEfficientCrossAttention try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False @@ -34,13 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish - return x*torch.sigmoid(x) + return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): @@ -52,11 +53,7 @@ def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -71,15 +68,11 @@ def __init__(self, in_channels, with_conv): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): if 
self.with_conv: - pad = (0,1,0,1) + pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: @@ -88,8 +81,7 @@ def forward(self, x): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -97,34 +89,17 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): h = x @@ -133,7 +108,7 @@ def forward(self, x, temb): h = self.conv1(h) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] h = self.norm2(h) h = nonlinearity(h) @@ -146,7 +121,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + return x + h class AttnBlock(nn.Module): @@ -155,26 +130,10 @@ def __init__(self, in_channels): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -184,56 +143,42 @@ def forward(self, x): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * 
w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) - return x+h_ + return x + h_ + class MemoryEfficientAttnBlock(nn.Module): """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation """ + # def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.attention_op: Optional[Any] = None def forward(self, x): @@ -245,7 +190,7 @@ def forward(self, x): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) q, k, v = map( lambda t: t.unsqueeze(3) @@ -257,28 +202,29 @@ def forward(self, x): ) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) + out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) out = self.proj_out(out) - return x+out + return x + out class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): def forward(self, x, context=None, mask=None): b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') + x = rearrange(x, "b c h w -> b (h w) c") out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) return x + out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + 
assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": attn_type = "vanilla-xformers" print(f"making attention of type '{attn_type}' with {in_channels} in_channels") @@ -298,13 +244,27 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch - self.temb_ch = self.ch*4 + self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -314,70 +274,69 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = 
nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -387,18 +346,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -420,7 +375,7 @@ def forward(self, x, t=None, context=None): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -431,9 +386,8 @@ def forward(self, x, t=None, context=None): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -450,12 +404,27 @@ def get_last_layer(self): class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", - **ignore_kwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -464,56 +433,49 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for 
i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): # timestep embedding @@ -527,7 +489,7 @@ def forward(self, x): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -544,12 +506,28 @@ def forward(self, x): class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -560,43 +538,37 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, 
z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -606,18 +578,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] + # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding @@ -633,7 +601,7 @@ def forward(self, z): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): + for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) @@ -655,29 +623,23 @@ def forward(self, z): class SimpleDecoder(nn.Module): def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + 
ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) # end self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): for i, layer in enumerate(self.model): - if i in [1,2,3]: + if i in [1, 2, 3]: x = layer(x, None) else: x = layer(x) @@ -689,8 +651,7 @@ def forward(self, x): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): super().__init__() # upsampling self.temb_ch = 0 @@ -704,10 +665,11 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + res_block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -716,11 +678,7 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): # upsampling @@ -741,31 +699,34 @@ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = torch.nn.functional.interpolate( + x, size=(int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor))) + ) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -774,17 +735,42 @@ def forward(self, x): class 
MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) def forward(self, x): x = self.encoder(x) @@ -793,15 +779,41 @@ def forward(self, x): class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) def forward(self, x): x = self.rescaler(x) @@ -813,14 +825,24 @@ class Upsampler(nn.Module): def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, 
num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) def forward(self, x): x = self.rescaler(x) @@ -838,14 +860,10 @@ def __init__(self, in_channels=None, learned=False, mode="bilinear"): raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1) def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: + if scale_factor == 1.0: return x else: x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) diff --git a/src/metr/ldm/modules/diffusionmodules/openaimodel.py b/src/metr/ldm/modules/diffusionmodules/openaimodel.py index 7df6b5a..2ff9f69 100644 --- a/src/metr/ldm/modules/diffusionmodules/openaimodel.py +++ b/src/metr/ldm/modules/diffusionmodules/openaimodel.py @@ -1,21 +1,20 @@ -from abc import abstractmethod import math +from abc import abstractmethod import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F - +from ldm.modules.attention import SpatialTransformer from ldm.modules.diffusionmodules.util import ( + avg_pool_nd, checkpoint, conv_nd, linear, - avg_pool_nd, - zero_module, normalization, timestep_embedding, + zero_module, ) -from ldm.modules.attention import SpatialTransformer from ldm.util import exists @@ -23,6 +22,7 @@ def convert_module_to_f16(x): pass + def convert_module_to_f32(x): pass @@ -41,7 +41,7 @@ def __init__( output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -108,25 +108,25 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x + class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' + "Learned 2x upsampling without padding" + def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, 
kernel_size=ks, stride=2) - def forward(self,x): + def forward(self, x): return self.up(x) @@ -139,7 +139,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -147,9 +147,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -225,17 +223,13 @@ def __init__( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -246,10 +240,7 @@ def forward(self, x, emb): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) def _forward(self, x, emb): if self.updown: @@ -311,8 +302,10 @@ def __init__( self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): b, c, *spatial = x.shape @@ -339,7 +332,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. 
- matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -363,9 +356,7 @@ def forward(self, qkv): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -460,10 +451,10 @@ def __init__( use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, disable_self_attentions=None, num_attention_blocks=None, @@ -472,11 +463,16 @@ def __init__( ): super().__init__() if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: context_dim = list(context_dim) @@ -484,10 +480,10 @@ def __init__( num_heads_upsample = num_heads if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" self.image_size = image_size self.in_channels = in_channels @@ -497,19 +493,25 @@ def __init__( self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") + assert all( + map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) self.attention_resolutions = attention_resolutions self.dropout = dropout @@ -540,11 +542,7 @@ def __init__( raise ValueError() self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) self._feature_size = model_channels input_block_chans = [model_channels] @@ -571,7 +569,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -586,10 +584,17 @@ def __init__( num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -610,9 +615,7 @@ def __init__( down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -626,7 +629,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( @@ -637,17 +640,26 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ), + ( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ) + ), ResBlock( ch, time_embed_dim, @@ -682,7 +694,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -697,10 +709,17 @@ def __init__( num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, 
context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) if level and i == self.num_res_blocks[level]: @@ -730,10 +749,10 @@ def __init__( ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) def convert_to_fp16(self): """ @@ -751,7 +770,7 @@ def convert_to_fp32(self): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. diff --git a/src/metr/ldm/modules/diffusionmodules/upscaling.py b/src/metr/ldm/modules/diffusionmodules/upscaling.py index 0381666..82cc215 100644 --- a/src/metr/ldm/modules/diffusionmodules/upscaling.py +++ b/src/metr/ldm/modules/diffusionmodules/upscaling.py @@ -1,8 +1,8 @@ -import torch -import torch.nn as nn -import numpy as np from functools import partial +import numpy as np +import torch +import torch.nn as nn from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule from ldm.util import default @@ -14,37 +14,41 @@ def __init__(self, noise_schedule_config=None): if noise_schedule_config is not None: self.register_schedule(**noise_schedule_config) - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. 
- betas + def register_schedule( + self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 + ): + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def forward(self, x): return x, None @@ -76,6 +80,3 @@ def forward(self, x, noise_level=None): assert isinstance(noise_level, torch.Tensor) z = self.q_sample(x, noise_level) return z, noise_level - - - diff --git a/src/metr/ldm/modules/diffusionmodules/util.py b/src/metr/ldm/modules/diffusionmodules/util.py index 637363d..996d2da 100644 --- a/src/metr/ldm/modules/diffusionmodules/util.py +++ b/src/metr/ldm/modules/diffusionmodules/util.py @@ -8,26 +8,22 @@ # thanks! 
-import os import math +import os + +import numpy as np import torch import torch.nn as nn -import numpy as np from einops import repeat - from ldm.util import instantiate_from_config def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": - betas = ( - torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 - ) + betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) + timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] @@ -44,11 +40,11 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': + if ddim_discr_method == "uniform": c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + elif ddim_discr_method == "quad": + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int) else: raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') @@ -56,7 +52,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep # add one to get the final alpha values right (the ones from first scale to data during sampling) steps_out = ddim_timesteps + 1 if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') + print(f"Selected timesteps for ddim sampler: {steps_out}") return steps_out @@ -68,9 +64,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): # according the the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}") + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) return sigmas, alphas, alphas_prev @@ -122,9 +120,11 @@ def forward(ctx, run_function, length, *args): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) - ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled()} + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @@ -132,8 +132,7 @@ def forward(ctx, run_function, length, *args): @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in 
ctx.input_tensors] - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. @@ -162,15 +161,15 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ if not repeat_only: half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: - embedding = repeat(timesteps, 'b -> b d', d=dim) + embedding = repeat(timesteps, "b -> b d", d=dim) return embedding @@ -218,6 +217,7 @@ class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) + def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. @@ -261,10 +261,10 @@ def __init__(self, c_concat_config, c_crossattn_config): def forward(self, c_concat, c_crossattn): c_concat = self.concat_conditioner(c_concat) c_crossattn = self.crossattn_conditioner(c_crossattn) - return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise() diff --git a/src/metr/ldm/modules/distributions/distributions.py b/src/metr/ldm/modules/distributions/distributions.py index f2b8ef9..b5f3b1a 100644 --- a/src/metr/ldm/modules/distributions/distributions.py +++ b/src/metr/ldm/modules/distributions/distributions.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch class AbstractDistribution: @@ -38,25 +38,25 @@ def sample(self): def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) - - def nll(self, sample, dims=[1,2,3]): + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean @@ -78,15 +78,8 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). 
- logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) diff --git a/src/metr/ldm/modules/ema.py b/src/metr/ldm/modules/ema.py index c8c75af..7a0c970 100644 --- a/src/metr/ldm/modules/ema.py +++ b/src/metr/ldm/modules/ema.py @@ -6,28 +6,29 @@ class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') + raise ValueError("Decay must be between 0 and 1") self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates - else torch.tensor(-1,dtype=torch.int)) + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int) + ) for name, p in model.named_parameters(): if p.requires_grad: - #remove as '.'-character is not allowed in buffers - s_name = name.replace('.','') - self.m_name2s_name.update({name:s_name}) - self.register_buffer(s_name,p.clone().detach().data) + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) self.collected_params = [] - def forward(self,model): + def forward(self, model): decay = self.decay if self.num_updates >= 0: self.num_updates += 1 - decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay diff --git a/src/metr/ldm/modules/encoders/modules.py b/src/metr/ldm/modules/encoders/modules.py index 4edd549..250f445 100644 --- a/src/metr/ldm/modules/encoders/modules.py +++ b/src/metr/ldm/modules/encoders/modules.py @@ -1,11 +1,9 @@ +import open_clip import torch import torch.nn as nn +from ldm.util import count_params, default from torch.utils.checkpoint import checkpoint - -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel - -import open_clip -from ldm.util import default, count_params +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class AbstractEncoder(nn.Module): @@ -23,7 +21,7 @@ def encode(self, x): class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): super().__init__() self.key = key self.embedding = nn.Embedding(n_classes, embed_dim) @@ -35,9 +33,9 @@ def forward(self, batch, key=None, disable_dropout=False): key = self.key # this is for use in crossattn c = batch[key][:, None] - if self.ucg_rate > 0. and not disable_dropout: - mask = 1. 
- torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) c = c.long() c = self.embedding(c) return c @@ -57,24 +55,34 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) @@ -87,13 +95,18 @@ def encode(self, text): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" - LAYERS = [ - "last", - "pooled", - "hidden" - ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) @@ -110,15 +123,22 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": @@ -135,16 +155,19 @@ class 
FrozenOpenCLIPEmbedder(AbstractEncoder): """ Uses the OpenCLIP transformer encoder for text """ + LAYERS = [ - #"pooled", + # "pooled", "last", - "penultimate" + "penultimate", ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, - freeze=True, layer="last"): + + def __init__( + self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last" + ): super().__init__() assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device("cpu"), pretrained=version) del model.visual self.model = model @@ -179,7 +202,7 @@ def encode_with_transformer(self, text): x = self.model.ln_final(x) return x - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break @@ -194,13 +217,21 @@ def encode(self, text): class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", - clip_max_length=77, t5_max_length=77): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + ) def encode(self, text): return self(text) @@ -209,5 +240,3 @@ def forward(self, text): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] - - diff --git a/src/metr/ldm/modules/image_degradation/bsrgan.py b/src/metr/ldm/modules/image_degradation/bsrgan.py index 32ef561..de58ccb 100644 --- a/src/metr/ldm/modules/image_degradation/bsrgan.py +++ b/src/metr/ldm/modules/image_degradation/bsrgan.py @@ -10,33 +10,32 @@ # -------------------------------------------- """ -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] 
""" @@ -54,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -63,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -74,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -126,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -142,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -157,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] @@ -208,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == "laplacian": return fspecial_laplacian(*args, **kwargs) @@ -226,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -253,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = 
ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -275,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -314,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -330,8 +328,8 @@ def add_blur(img, sf=4): l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -366,6 +364,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -374,11 +373,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -392,23 +391,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. 
+ L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -418,7 +417,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(30, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -428,10 +427,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -452,18 +451,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -487,13 +487,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -544,15 +547,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = image.shape[:2] hq = image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -576,13 +582,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] 
# nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -609,7 +618,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image":image} + example = {"image": image} return example @@ -630,11 +639,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc """ h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") if use_sharp: img = add_sharpening(img) @@ -686,11 +695,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) else: - print('check the shuffle!') + print("check the shuffle!") # resize to desired size - img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), interpolation=random.choice([1, 2, 3]) + ) # add final JPEG compression noise img = add_JPEG_noise(img) @@ -701,30 +711,30 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc return img, hq -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - +if __name__ == "__main__": + print("hey") + img = util.imread_uint("utils/test.png", 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + ".png") diff --git a/src/metr/ldm/modules/image_degradation/bsrgan_light.py b/src/metr/ldm/modules/image_degradation/bsrgan_light.py index 
808c7f8..1b8a6b5 100644 --- a/src/metr/ldm/modules/image_degradation/bsrgan_light.py +++ b/src/metr/ldm/modules/image_degradation/bsrgan_light.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util """ # -------------------------------------------- @@ -25,17 +24,18 @@ # -------------------------------------------- """ + def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] """ @@ -53,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -62,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -73,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -125,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -141,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -156,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = 
np.linalg.inv(SIGMA)[None, None, :, :] @@ -207,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == "laplacian": return fspecial_laplacian(*args, **kwargs) @@ -225,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -252,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -274,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] 
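
The hunk above only reflows the three classical degradation helpers, but their difference is easy to miss in the reformatting noise: they apply the same blur kernel and the same downscale factor in a different order. Below is a minimal, illustrative sketch of the three orderings; the wrapper name is ours, while `x` (HxWxC float image in [0, 1]), `k` (hxw kernel) and `sf` follow the docstrings above, and `util` is the module's own `utils_image` helper imported at the top of these files.

    import numpy as np
    from scipy import ndimage

    import ldm.modules.image_degradation.utils_image as util  # same helper module used by bsrgan_light.py

    def demo_degradation_orderings(x, k, sf=3):
        """Illustrative only: contrast the blur/downsample orderings of the helpers above."""

        def _blur(im):
            # 2-D kernel broadcast over the channel axis, wrap-around padding (as in the code above)
            return ndimage.convolve(im, np.expand_dims(k, axis=2), mode="wrap")

        srmd = util.imresize_np(_blur(x), scale=1 / sf)   # srmd_degradation: blur, then bicubic downsample
        dpsr = _blur(util.imresize_np(x, scale=1 / sf))   # dpsr_degradation: bicubic downsample, then blur
        classical = _blur(x)[0::sf, 0::sf, ...]           # classical_degradation: blur, then strided downsample
        return srmd, dpsr, classical

Note that this light variant calls `ndimage.convolve` directly, whereas the bsrgan.py hunks earlier in the patch still go through the deprecated `scipy.ndimage.filters` alias; both spellings resolve to the same routine.
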
@@ -313,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -325,16 +324,16 @@ def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf - wd2 = wd2/4 - wd = wd/4 + wd2 = wd2 / 4 + wd = wd / 4 if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -369,6 +368,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -377,11 +377,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -395,23 +395,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. 
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -421,7 +421,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(80, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -431,10 +431,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -455,18 +455,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -490,13 +491,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -547,15 +551,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] hq = image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -582,13 +589,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): # downsample2 if random.random() < 0.8: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -617,16 +627,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): image = add_JPEG_noise(image) image = util.single2uint(image) if up: - image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + image = cv2.resize( + image, (w1, h1), interpolation=cv2.INTER_CUBIC + ) # todo: random, as above? 
want to condition on it then example = {"image": image} return example - - -if __name__ == '__main__': +if __name__ == "__main__": print("hey") - img = util.imread_uint('utils/test.png', 3) + img = util.imread_uint("utils/test.png", 3) img = img[:448, :448] h = img.shape[0] // 4 print("resizing to", h) @@ -638,14 +648,17 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): img_lq = deg_fn(img)["image"] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[ + "image" + ] print(img_lq.shape) print("bicubic", img_lq_bicubic.shape) print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), - (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') + util.imsave(img_concat, str(i) + ".png") diff --git a/src/metr/ldm/modules/image_degradation/utils_image.py b/src/metr/ldm/modules/image_degradation/utils_image.py index 0175f15..f9b2960 100644 --- a/src/metr/ldm/modules/image_degradation/utils_image.py +++ b/src/metr/ldm/modules/image_degradation/utils_image.py @@ -1,18 +1,20 @@ -import os import math +import os import random +from datetime import datetime + +import cv2 import numpy as np import torch -import cv2 from torchvision.utils import make_grid -from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py +# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -''' + +""" # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 @@ -20,10 +22,10 @@ # https://github.com/twhui/SRGAN-pyTorch # https://github.com/xinntao/BasicSR # -------------------------------------------- -''' +""" -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] +IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP", ".tif"] def is_image_file(filename): @@ -31,12 +33,12 @@ def is_image_file(filename): def get_timestamp(): - return datetime.now().strftime('%y%m%d-%H%M%S') + return datetime.now().strftime("%y%m%d-%H%M%S") def imshow(x, title=None, cbar=False, figsize=None): plt.figure(figsize=figsize) - plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray") if title: plt.title(title) if cbar: @@ -44,24 +46,24 @@ def imshow(x, title=None, cbar=False, figsize=None): plt.show() -def surf(Z, cmap='rainbow', figsize=None): +def surf(Z, cmap="rainbow", figsize=None): plt.figure(figsize=figsize) - ax3 = plt.axes(projection='3d') + ax3 = 
plt.axes(projection="3d") w, h = Z.shape[:2] - xx = np.arange(0,w,1) - yy = np.arange(0,h,1) + xx = np.arange(0, w, 1) + yy = np.arange(0, h, 1) X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + ax3.plot_surface(X, Y, Z, cmap=cmap) + # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) plt.show() -''' +""" # -------------------------------------------- # get image pathes # -------------------------------------------- -''' +""" def get_image_paths(dataroot): @@ -72,37 +74,37 @@ def get_image_paths(dataroot): def _get_paths_from_images(path): - assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + assert os.path.isdir(path), "{:s} is not a valid directory".format(path) images = [] for dirpath, _, fnames in sorted(os.walk(path)): for fname in sorted(fnames): if is_image_file(fname): img_path = os.path.join(dirpath, fname) images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) + assert images, "{:s} has no valid image file".format(path) return images -''' +""" # -------------------------------------------- # split large images into small images # -------------------------------------------- -''' +""" def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): w, h = img.shape[:2] patches = [] if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) + w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) + w1.append(w - p_size) + h1.append(h - p_size) + # print(w1) + # print(h1) for i in w1: for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) + patches.append(img[i : i + p_size, j : j + p_size, :]) else: patches.append(img) @@ -118,7 +120,7 @@ def imssave(imgs, img_path): for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + new_path = os.path.join(os.path.dirname(img_path), img_name + str("_s{:04d}".format(i)) + ".png") cv2.imwrite(new_path, img) @@ -139,15 +141,16 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, # img_name, ext = os.path.splitext(os.path.basename(img_path)) img = imread_uint(img_path, n_channels=n_channels) patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + # if original_dataroot == taget_dataroot: + # del img_path + -''' +""" # -------------------------------------------- # makedir # -------------------------------------------- -''' +""" def mkdir(path): @@ -165,18 +168,18 @@ def mkdirs(paths): def mkdir_and_rename(path): if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() - print('Path already exists. Rename it to [{:s}]'.format(new_name)) + new_name = path + "_archived_" + get_timestamp() + print("Path already exists. 
Rename it to [{:s}]".format(new_name)) os.rename(path, new_name) os.makedirs(path) -''' +""" # -------------------------------------------- # read image from path # opencv is fast, but read BGR numpy image # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -206,6 +209,7 @@ def imsave(img, img_path): img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) + def imwrite(img, img_path): img = np.squeeze(img) if img.ndim == 3: @@ -213,7 +217,6 @@ def imwrite(img, img_path): cv2.imwrite(img_path, img) - # -------------------------------------------- # get single image of size HxWxn_channles (BGR) # -------------------------------------------- @@ -221,7 +224,7 @@ def read_img(path): # read image by cv2 # return: Numpy float32, HWC, BGR, [0,1] img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. + img = img.astype(np.float32) / 255.0 if img.ndim == 2: img = np.expand_dims(img, axis=2) # some images have 4 channels @@ -230,7 +233,7 @@ def read_img(path): return img -''' +""" # -------------------------------------------- # image format conversion # -------------------------------------------- @@ -238,7 +241,7 @@ def read_img(path): # numpy(single) <---> tensor # numpy(unit) <---> tensor # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -248,22 +251,22 @@ def read_img(path): def uint2single(img): - return np.float32(img/255.) + return np.float32(img / 255.0) def single2uint(img): - return np.uint8((img.clip(0, 1)*255.).round()) + return np.uint8((img.clip(0, 1) * 255.0).round()) def uint162single(img): - return np.float32(img/65535.) + return np.float32(img / 65535.0) def single2uint16(img): - return np.uint16((img.clip(0, 1)*65535.).round()) + return np.uint16((img.clip(0, 1) * 65535.0).round()) # -------------------------------------------- @@ -275,14 +278,14 @@ def single2uint16(img): def uint2tensor4(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0) # convert uint to 3-dimensional torch tensor def uint2tensor3(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) 
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0) # convert 2/3/4-dimensional torch tensor to uint @@ -290,7 +293,7 @@ def tensor2uint(img): img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() if img.ndim == 3: img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) + return np.uint8((img * 255.0).round()) # -------------------------------------------- @@ -316,6 +319,7 @@ def tensor2single(img): return img + # convert torch tensor to single def tensor2single3(img): img = img.data.squeeze().float().cpu().numpy() @@ -340,11 +344,11 @@ def single42tensor4(img): # from skimage.io import imread, imsave def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' + """ Converts a torch Tensor into an image Numpy array of BGR channel order Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' + """ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] n_dim = tensor.dim() @@ -358,15 +362,14 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): elif n_dim == 2: img_np = tensor.numpy() else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim)) if out_type == np.uint8: img_np = (img_np * 255.0).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. return img_np.astype(out_type) -''' +""" # -------------------------------------------- # Augmentation, flipe and/or rotate # -------------------------------------------- @@ -374,12 +377,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): # (1) augmet_img: numpy image of WxHxC or WxH # (2) augment_img_tensor4: tensor image 1xCxWxH # -------------------------------------------- -''' +""" def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -399,8 +401,7 @@ def augment_img(img, mode=0): def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -420,8 +421,7 @@ def augment_img_tensor4(img, mode=0): def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" img_size = img.size() img_np = img.data.cpu().numpy() if len(img_size) == 3: @@ -484,11 +484,11 @@ def _augment(img): return [_augment(img) for img in img_list] -''' +""" # -------------------------------------------- # modcrop and shave # -------------------------------------------- -''' +""" def modcrop(img_in, scale): @@ -497,13 +497,13 @@ def modcrop(img_in, scale): if img.ndim == 2: H, W = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] + img = img[: H - H_r, : W - W_r] elif img.ndim == 3: H, W, C = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] + img = img[: H - H_r, : W - W_r, :] else: - raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim)) return img @@ -511,11 +511,11 @@ def shave(img_in, border=0): # img_in: Numpy, HWC or HW img = 
np.copy(img_in) h, w = img.shape[:2] - img = img[border:h-border, border:w-border] + img = img[border : h - border, border : w - border] return img -''' +""" # -------------------------------------------- # image processing process on numpy image # channel_convert(in_c, tar_type, img_list): @@ -523,96 +523,99 @@ def shave(img_in, border=0): # bgr2ycbcr(img, only_y=True): # ycbcr2rgb(img): # -------------------------------------------- -''' +""" def rgb2ycbcr(img, only_y=True): - '''same as matlab rgb2ycbcr + """same as matlab rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def ycbcr2rgb(img): - '''same as matlab ycbcr2rgb + """same as matlab ycbcr2rgb Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert - rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + rlt = np.matmul( + img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]] + ) * 255.0 + [-222.921, 135.576, -276.836] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def bgr2ycbcr(img, only_y=True): - '''bgr version of rgb2ycbcr + """bgr version of rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. 
+ rlt /= 255.0 return rlt.astype(in_img_type) def channel_convert(in_c, tar_type, img_list): # conversion among BGR, gray and y - if in_c == 3 and tar_type == 'gray': # BGR to gray + if in_c == 3 and tar_type == "gray": # BGR to gray gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] return [np.expand_dims(img, axis=2) for img in gray_list] - elif in_c == 3 and tar_type == 'y': # BGR to y + elif in_c == 3 and tar_type == "y": # BGR to y y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] return [np.expand_dims(img, axis=2) for img in y_list] - elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + elif in_c == 1 and tar_type == "RGB": # gray/y to BGR return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] else: return img_list -''' +""" # -------------------------------------------- # metric, PSNR and SSIM # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -620,19 +623,19 @@ def channel_convert(in_c, tar_type, img_list): # -------------------------------------------- def calculate_psnr(img1, img2, border=0): # img1 and img2 have range [0, 255] - #img1 = img1.squeeze() - #img2 = img2.squeeze() + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) - mse = np.mean((img1 - img2)**2) + mse = np.mean((img1 - img2) ** 2) if mse == 0: - return float('inf') + return float("inf") return 20 * math.log10(255.0 / math.sqrt(mse)) @@ -640,17 +643,17 @@ def calculate_psnr(img1, img2, border=0): # SSIM # -------------------------------------------- def calculate_ssim(img1, img2, border=0): - '''calculate SSIM + """calculate SSIM the same outputs as MATLAB's img1, img2: [0, 255] - ''' - #img1 = img1.squeeze() - #img2 = img2.squeeze() + """ + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] if img1.ndim == 2: return ssim(img1, img2) @@ -658,17 +661,17 @@ def calculate_ssim(img1, img2, border=0): if img1.shape[2] == 3: ssims = [] for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) else: - raise ValueError('Wrong input image dimensions.') + raise ValueError("Wrong input image dimensions.") def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -684,16 +687,15 @@ def ssim(img1, img2): sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq 
+ C1) * - (sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) return ssim_map.mean() -''' +""" # -------------------------------------------- # matlab's bicubic imresize (numpy and torch) [0, 1] # -------------------------------------------- -''' +""" # matlab 'imresize' function, now only support 'bicubic' @@ -701,8 +703,9 @@ def cubic(x): absx = torch.abs(x) absx2 = absx**2 absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): @@ -729,8 +732,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand( + out_length, P + ) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. @@ -773,7 +777,7 @@ def imresize(img, scale, antialiasing=True): in_C, in_H, in_W = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the @@ -782,9 +786,11 @@ def imresize(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) @@ -805,7 +811,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -827,7 +833,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2 @@ -848,7 +854,7 @@ def imresize_np(img, scale, antialiasing=True): in_H, in_W, in_C = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. 
The # strategy is to perform the resize first along the dimension with the @@ -857,9 +863,11 @@ def imresize_np(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) @@ -880,7 +888,7 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -902,15 +910,15 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2.numpy() -if __name__ == '__main__': - print('---') +if __name__ == "__main__": + print("---") # img = imread_uint('test.bmp', 3) # img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file +# img_bicubic = imresize_np(img, 1/4) diff --git a/src/metr/ldm/modules/losses/__init__.py b/src/metr/ldm/modules/losses/__init__.py index 876d7c5..d862942 100644 --- a/src/metr/ldm/modules/losses/__init__.py +++ b/src/metr/ldm/modules/losses/__init__.py @@ -1 +1 @@ -from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator diff --git a/src/metr/ldm/modules/losses/contperceptual.py b/src/metr/ldm/modules/losses/contperceptual.py index 672c1e3..4dfe08b 100644 --- a/src/metr/ldm/modules/losses/contperceptual.py +++ b/src/metr/ldm/modules/losses/contperceptual.py @@ -1,14 +1,24 @@ import torch import torch.nn as nn - from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
class LPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_loss="hinge"): + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ): super().__init__() assert disc_loss in ["hinge", "vanilla"] @@ -19,10 +29,9 @@ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight= # output log variance self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm - ).apply(weights_init) + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm + ).apply(weights_init) self.discriminator_iter_start = disc_start self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss self.disc_factor = disc_factor @@ -42,9 +51,18 @@ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): d_weight = d_weight * self.discriminator_weight return d_weight - def forward(self, inputs, reconstructions, posteriors, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", - weights=None): + def forward( + self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split="train", + weights=None, + ): rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) if self.perceptual_weight > 0: p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) @@ -53,7 +71,7 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: - weighted_nll_loss = weights*nll_loss + weighted_nll_loss = weights * nll_loss weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] kl_loss = posteriors.kl() @@ -82,13 +100,16 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } return loss, log 
if optimizer_idx == 1: @@ -103,9 +124,9 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } return d_loss, log - diff --git a/src/metr/ldm/modules/losses/vqperceptual.py b/src/metr/ldm/modules/losses/vqperceptual.py index f699817..e79a5c3 100644 --- a/src/metr/ldm/modules/losses/vqperceptual.py +++ b/src/metr/ldm/modules/losses/vqperceptual.py @@ -1,23 +1,23 @@ import torch -from torch import nn import torch.nn.functional as F from einops import repeat - from taming.modules.discriminator.model import NLayerDiscriminator, weights_init from taming.modules.losses.lpips import LPIPS from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss +from torch import nn def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] - loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) - loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3]) + loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() d_loss = 0.5 * (loss_real + loss_fake) return d_loss -def adopt_weight(weight, global_step, threshold=0, value=0.): + +def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value return weight @@ -32,20 +32,34 @@ def measure_perplexity(predicted_indices, n_embed): cluster_use = torch.sum(avg_probs > 0) return perplexity, cluster_use + def l1(x, y): - return torch.abs(x-y) + return torch.abs(x - y) def l2(x, y): - return torch.pow((x-y), 2) + return torch.pow((x - y), 2) class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", - pixel_loss="l1"): + def __init__( + self, + disc_start, + codebook_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_ndf=64, + disc_loss="hinge", + n_classes=None, + perceptual_loss="lpips", + pixel_loss="l1", + ): super().__init__() assert disc_loss in ["hinge", "vanilla"] assert perceptual_loss in ["lpips", "clips", "dists"] @@ -64,11 +78,9 @@ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, else: self.pixel_loss = l2 - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, ndf=disc_ndf + 
).apply(weights_init) self.discriminator_iter_start = disc_start if disc_loss == "hinge": self.disc_loss = hinge_d_loss @@ -95,11 +107,21 @@ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): d_weight = d_weight * self.discriminator_weight return d_weight - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + def forward( + self, + codebook_loss, + inputs, + reconstructions, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split="train", + predicted_indices=None, + ): if not exists(codebook_loss): - codebook_loss = torch.tensor([0.]).to(inputs.device) - #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + codebook_loss = torch.tensor([0.0]).to(inputs.device) + # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) if self.perceptual_weight > 0: p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) @@ -108,7 +130,7 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, p_loss = torch.tensor([0.0]) nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.mean(nll_loss) # now the GAN part @@ -131,15 +153,16 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } if predicted_indices is not None: assert self.n_classes is not None with torch.no_grad(): @@ -160,8 +183,9 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } return d_loss, log diff --git a/src/metr/ldm/modules/midas/api.py b/src/metr/ldm/modules/midas/api.py index b58ebbf..6619f51 100644 --- a/src/metr/ldm/modules/midas/api.py 
+++ b/src/metr/ldm/modules/midas/api.py @@ -3,13 +3,11 @@ import cv2 import torch import torch.nn as nn -from torchvision.transforms import Compose - from ldm.modules.midas.midas.dpt_depth import DPTDepthModel from ldm.modules.midas.midas.midas_net import MidasNet from ldm.modules.midas.midas.midas_net_custom import MidasNet_small -from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet - +from ldm.modules.midas.midas.transforms import NormalizeImage, PrepareForNet, Resize +from torchvision.transforms import Compose ISL_PATHS = { "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", @@ -98,18 +96,20 @@ def load_model(model_type): model = MidasNet(model_path, non_negative=True) net_w, net_h = 384, 384 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) elif model_type == "midas_v21_small": - model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, - non_negative=True, blocks={'expand': True}) + model = MidasNet_small( + model_path, + features=64, + backbone="efficientnet_lite3", + exportable=True, + non_negative=True, + blocks={"expand": True}, + ) net_w, net_h = 256, 256 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) else: print(f"model_type '{model_type}' not implemented, use: --model_type large") @@ -135,11 +135,7 @@ def load_model(model_type): class MiDaSInference(nn.Module): - MODEL_TYPES_TORCH_HUB = [ - "DPT_Large", - "DPT_Hybrid", - "MiDaS_small" - ] + MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"] MODEL_TYPES_ISL = [ "dpt_large", "dpt_hybrid", @@ -149,7 +145,7 @@ class MiDaSInference(nn.Module): def __init__(self, model_type): super().__init__() - assert (model_type in self.MODEL_TYPES_ISL) + assert model_type in self.MODEL_TYPES_ISL model, _ = load_model(model_type) self.model = model self.model.train = disabled_train @@ -167,4 +163,3 @@ def forward(self, x): ) assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) return prediction - diff --git a/src/metr/ldm/modules/midas/midas/base_model.py b/src/metr/ldm/modules/midas/midas/base_model.py index 5cf4302..5c2e0e9 100644 --- a/src/metr/ldm/modules/midas/midas/base_model.py +++ b/src/metr/ldm/modules/midas/midas/base_model.py @@ -8,7 +8,7 @@ def load(self, path): Args: path (str): file path """ - parameters = torch.load(path, map_location=torch.device('cpu')) + parameters = torch.load(path, map_location=torch.device("cpu")) if "optimizer" in parameters: parameters = parameters["model"] diff --git a/src/metr/ldm/modules/midas/midas/blocks.py b/src/metr/ldm/modules/midas/midas/blocks.py index 2145d18..0739f60 100644 --- a/src/metr/ldm/modules/midas/midas/blocks.py +++ b/src/metr/ldm/modules/midas/midas/blocks.py @@ -1,18 +1,22 @@ import torch import torch.nn as nn -from .vit import ( - _make_pretrained_vitb_rn50_384, - _make_pretrained_vitl16_384, - _make_pretrained_vitb16_384, - forward_vit, -) - -def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): +from .vit import _make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384, _make_pretrained_vitl16_384, forward_vit + + +def _make_encoder( + backbone, + features, + 
use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", +): if backbone == "vitl16_384": - pretrained = _make_pretrained_vitl16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitl16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # ViT-L/16 - 85.0% Top1 (backbone) @@ -27,22 +31,20 @@ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, ex [256, 512, 768, 768], features, groups=groups, expand=expand ) # ViT-H/16 - 85.0% Top1 (backbone) elif backbone == "vitb16_384": - pretrained = _make_pretrained_vitb16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitb16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # ViT-B/16 - 84.6% Top1 (backbone) elif backbone == "resnext101_wsl": pretrained = _make_pretrained_resnext101_wsl(use_pretrained) - scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 elif backbone == "efficientnet_lite3": pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) - scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 else: print(f"Backbone '{backbone}' not implemented") assert False - + return pretrained, scratch @@ -53,11 +55,11 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): out_shape2 = out_shape out_shape3 = out_shape out_shape4 = out_shape - if expand==True: + if expand == True: out_shape1 = out_shape - out_shape2 = out_shape*2 - out_shape3 = out_shape*4 - out_shape4 = out_shape*8 + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups @@ -77,10 +79,7 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): efficientnet = torch.hub.load( - "rwightman/gen-efficientnet-pytorch", - "tf_efficientnet_lite3", - pretrained=use_pretrained, - exportable=exportable + "rwightman/gen-efficientnet-pytorch", "tf_efficientnet_lite3", pretrained=use_pretrained, exportable=exportable ) return _make_efficientnet_backbone(efficientnet) @@ -88,21 +87,17 @@ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): def _make_efficientnet_backbone(effnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] - ) + pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]) pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) return pretrained - + def _make_resnet_backbone(resnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 - ) + pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, 
resnet.relu, resnet.maxpool, resnet.layer1) pretrained.layer2 = resnet.layer2 pretrained.layer3 = resnet.layer3 @@ -116,10 +111,8 @@ def _make_pretrained_resnext101_wsl(use_pretrained): return _make_resnet_backbone(resnet) - class Interpolate(nn.Module): - """Interpolation module. - """ + """Interpolation module.""" def __init__(self, scale_factor, mode, align_corners=False): """Init. @@ -145,16 +138,13 @@ def forward(self, x): tensor: interpolated data """ - x = self.interp( - x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners - ) + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) return x class ResidualConvUnit(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features): """Init. @@ -164,13 +154,9 @@ def __init__(self, features): """ super().__init__() - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) self.relu = nn.ReLU(inplace=True) @@ -192,8 +178,7 @@ def forward(self, x): class FeatureFusionBlock(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features): """Init. @@ -219,18 +204,13 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=True - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True) return output - - class ResidualConvUnit_custom(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features, activation, bn): """Init. @@ -242,17 +222,13 @@ def __init__(self, features, activation, bn): self.bn = bn - self.groups=1 + self.groups = 1 - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) - if self.bn==True: + if self.bn == True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) @@ -269,15 +245,15 @@ def forward(self, x): Returns: tensor: output """ - + out = self.activation(x) out = self.conv1(out) - if self.bn==True: + if self.bn == True: out = self.bn1(out) - + out = self.activation(out) out = self.conv2(out) - if self.bn==True: + if self.bn == True: out = self.bn2(out) if self.groups > 1: @@ -289,8 +265,7 @@ def forward(self, x): class FeatureFusionBlock_custom(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): """Init. 
@@ -303,18 +278,18 @@ def __init__(self, features, activation, deconv=False, bn=False, expand=False, a self.deconv = deconv self.align_corners = align_corners - self.groups=1 + self.groups = 1 self.expand = expand out_features = features - if self.expand==True: - out_features = features//2 - + if self.expand == True: + out_features = features // 2 + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) - + self.skip_add = nn.quantized.FloatFunctional() def forward(self, *xs): @@ -332,11 +307,8 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=self.align_corners - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=self.align_corners) output = self.out_conv(output) return output - diff --git a/src/metr/ldm/modules/midas/midas/dpt_depth.py b/src/metr/ldm/modules/midas/midas/dpt_depth.py index 4e9aab5..e6cfdd1 100644 --- a/src/metr/ldm/modules/midas/midas/dpt_depth.py +++ b/src/metr/ldm/modules/midas/midas/dpt_depth.py @@ -3,13 +3,7 @@ import torch.nn.functional as F from .base_model import BaseModel -from .blocks import ( - FeatureFusionBlock, - FeatureFusionBlock_custom, - Interpolate, - _make_encoder, - forward_vit, -) +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder, forward_vit def _make_fusion_block(features, use_bn): @@ -48,7 +42,7 @@ def __init__( self.pretrained, self.scratch = _make_encoder( backbone, features, - False, # Set to true of you want to train from scratch, uses ImageNet weights + False, # Set to true of you want to train from scratch, uses ImageNet weights groups=1, expand=False, exportable=False, @@ -63,7 +57,6 @@ def __init__( self.scratch.output_conv = head - def forward(self, x): if self.channels_last == True: x.contiguous(memory_format=torch.channels_last) @@ -102,8 +95,7 @@ def __init__(self, path=None, non_negative=True, **kwargs): super().__init__(head, **kwargs) if path is not None: - self.load(path) + self.load(path) def forward(self, x): return super().forward(x).squeeze(dim=1) - diff --git a/src/metr/ldm/modules/midas/midas/midas_net.py b/src/metr/ldm/modules/midas/midas/midas_net.py index 8a95497..8c13f39 100644 --- a/src/metr/ldm/modules/midas/midas/midas_net.py +++ b/src/metr/ldm/modules/midas/midas/midas_net.py @@ -2,6 +2,7 @@ This file contains code that is adapted from https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py """ + import torch import torch.nn as nn @@ -10,8 +11,7 @@ class MidasNet(BaseModel): - """Network for monocular depth estimation. - """ + """Network for monocular depth estimation.""" def __init__(self, path=None, features=256, non_negative=True): """Init. 
@@ -27,7 +27,9 @@ def __init__(self, path=None, features=256, non_negative=True): use_pretrained = False if path is None else True - self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + self.pretrained, self.scratch = _make_encoder( + backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained + ) self.scratch.refinenet4 = FeatureFusionBlock(features) self.scratch.refinenet3 = FeatureFusionBlock(features) diff --git a/src/metr/ldm/modules/midas/midas/midas_net_custom.py b/src/metr/ldm/modules/midas/midas/midas_net_custom.py index 50e4acb..c1e167d 100644 --- a/src/metr/ldm/modules/midas/midas/midas_net_custom.py +++ b/src/metr/ldm/modules/midas/midas/midas_net_custom.py @@ -2,6 +2,7 @@ This file contains code that is adapted from https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py """ + import torch import torch.nn as nn @@ -10,11 +11,19 @@ class MidasNet_small(BaseModel): - """Network for monocular depth estimation. - """ - - def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, - blocks={'expand': True}): + """Network for monocular depth estimation.""" + + def __init__( + self, + path=None, + features=64, + backbone="efficientnet_lite3", + non_negative=True, + exportable=True, + channels_last=False, + align_corners=True, + blocks={"expand": True}, + ): """Init. Args: @@ -27,49 +36,57 @@ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_ne super(MidasNet_small, self).__init__() use_pretrained = False if path else True - + self.channels_last = channels_last self.blocks = blocks self.backbone = backbone self.groups = 1 - features1=features - features2=features - features3=features - features4=features + features1 = features + features2 = features + features3 = features + features4 = features self.expand = False - if "expand" in self.blocks and self.blocks['expand'] == True: + if "expand" in self.blocks and self.blocks["expand"] == True: self.expand = True - features1=features - features2=features*2 - features3=features*4 - features4=features*8 + features1 = features + features2 = features * 2 + features3 = features * 4 + features4 = features * 8 + + self.pretrained, self.scratch = _make_encoder( + self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable + ) - self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) - - self.scratch.activation = nn.ReLU(False) + self.scratch.activation = nn.ReLU(False) - self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + self.scratch.refinenet4 = FeatureFusionBlock_custom( + features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, 
align_corners=align_corners + ) + self.scratch.refinenet3 = FeatureFusionBlock_custom( + features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet2 = FeatureFusionBlock_custom( + features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet1 = FeatureFusionBlock_custom( + features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners + ) - self.scratch.output_conv = nn.Sequential( - nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups), Interpolate(scale_factor=2, mode="bilinear"), - nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), self.scratch.activation, nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), nn.Identity(), ) - + if path: self.load(path) - def forward(self, x): """Forward pass. @@ -79,38 +96,35 @@ def forward(self, x): Returns: tensor: depth """ - if self.channels_last==True: + if self.channels_last == True: print("self.channels_last = ", self.channels_last) x.contiguous(memory_format=torch.channels_last) - layer_1 = self.pretrained.layer1(x) layer_2 = self.pretrained.layer2(layer_1) layer_3 = self.pretrained.layer3(layer_2) layer_4 = self.pretrained.layer4(layer_3) - + layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) - path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - + out = self.scratch.output_conv(path_1) return torch.squeeze(out, dim=1) - def fuse_model(m): prev_previous_type = nn.Identity() - prev_previous_name = '' + prev_previous_name = "" previous_type = nn.Identity() - previous_name = '' + previous_name = "" for name, module in m.named_modules(): if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: # print("FUSED ", prev_previous_name, previous_name, name) @@ -125,4 +139,4 @@ def fuse_model(m): prev_previous_type = previous_type prev_previous_name = previous_name previous_type = type(module) - previous_name = name \ No newline at end of file + previous_name = name diff --git a/src/metr/ldm/modules/midas/midas/transforms.py b/src/metr/ldm/modules/midas/midas/transforms.py index 350cbc1..aede0fa 100644 --- a/src/metr/ldm/modules/midas/midas/transforms.py +++ b/src/metr/ldm/modules/midas/midas/transforms.py @@ -1,7 +1,8 @@ -import numpy as np -import cv2 import math +import cv2 +import numpy as np + def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): """Rezise the sample to ensure the given size. Keeps aspect ratio. 
@@ -28,13 +29,9 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): shape[1] = math.ceil(scale * shape[1]) # resize - sample["image"] = cv2.resize( - sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method - ) + sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method) - sample["disparity"] = cv2.resize( - sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST - ) + sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), tuple(shape[::-1]), @@ -46,8 +43,7 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): class Resize(object): - """Resize sample to given size (width, height). - """ + """Resize sample to given size (width, height).""" def __init__( self, @@ -133,24 +129,14 @@ def get_size(self, width, height): # fit height scale_width = scale_height else: - raise ValueError( - f"resize_method {self.__resize_method} not implemented" - ) + raise ValueError(f"resize_method {self.__resize_method} not implemented") if self.__resize_method == "lower_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, min_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, min_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) elif self.__resize_method == "upper_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, max_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, max_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) elif self.__resize_method == "minimal": new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) @@ -160,9 +146,7 @@ def get_size(self, width, height): return (new_width, new_height) def __call__(self, sample): - width, height = self.get_size( - sample["image"].shape[1], sample["image"].shape[0] - ) + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) # resize sample sample["image"] = cv2.resize( @@ -180,9 +164,7 @@ def __call__(self, sample): ) if "depth" in sample: - sample["depth"] = cv2.resize( - sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST - ) + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), @@ -195,8 +177,7 @@ def __call__(self, sample): class NormalizeImage(object): - """Normlize image by given mean and std. - """ + """Normlize image by given mean and std.""" def __init__(self, mean, std): self.__mean = mean @@ -209,8 +190,7 @@ def __call__(self, sample): class PrepareForNet(object): - """Prepare sample for usage as network input. 
- """ + """Prepare sample for usage as network input.""" def __init__(self): pass diff --git a/src/metr/ldm/modules/midas/midas/vit.py b/src/metr/ldm/modules/midas/midas/vit.py index ea46b1b..13bd1e7 100644 --- a/src/metr/ldm/modules/midas/midas/vit.py +++ b/src/metr/ldm/modules/midas/midas/vit.py @@ -1,8 +1,9 @@ +import math +import types + +import timm import torch import torch.nn as nn -import timm -import types -import math import torch.nn.functional as F @@ -117,9 +118,7 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w): def forward_flex(self, x): b, c, h, w = x.shape - pos_embed = self._resize_pos_embed( - self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] - ) + pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]) B = x.shape[0] @@ -131,15 +130,11 @@ def forward_flex(self, x): x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) if getattr(self, "dist_token", None) is not None: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) else: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + pos_embed @@ -169,13 +164,9 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1): elif use_readout == "add": readout_oper = [AddReadout(start_index)] * len(features) elif use_readout == "project": - readout_oper = [ - ProjectReadout(vit_features, start_index) for out_feat in features - ] + readout_oper = [ProjectReadout(vit_features, start_index) for out_feat in features] else: - assert ( - False - ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + assert False, "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" return readout_oper @@ -287,9 +278,7 @@ def _make_vit_b16_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. 
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained @@ -311,24 +300,18 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model( - "vit_deit_base_distilled_patch16_384", pretrained=pretrained - ) + model = timm.create_model("vit_deit_base_distilled_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks return _make_vit_b16_backbone( @@ -358,12 +341,8 @@ def _make_vit_b_rn50_backbone( pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) else: - pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( - get_activation("1") - ) - pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( - get_activation("2") - ) + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation("1")) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation("2")) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) @@ -419,12 +398,8 @@ def _make_vit_b_rn50_backbone( ), ) else: - pretrained.act_postprocess1 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) - pretrained.act_postprocess2 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) + pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], @@ -468,16 +443,12 @@ def _make_vit_b_rn50_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. 
- pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained -def _make_pretrained_vitb_rn50_384( - pretrained, use_readout="ignore", hooks=None, use_vit_only=False -): +def _make_pretrained_vitb_rn50_384(pretrained, use_readout="ignore", hooks=None, use_vit_only=False): model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) hooks = [0, 1, 8, 11] if hooks == None else hooks diff --git a/src/metr/ldm/modules/midas/utils.py b/src/metr/ldm/modules/midas/utils.py index 9a9d3b5..d8de2b3 100644 --- a/src/metr/ldm/modules/midas/utils.py +++ b/src/metr/ldm/modules/midas/utils.py @@ -1,8 +1,10 @@ """Utils for monoDepth.""" -import sys + import re -import numpy as np +import sys + import cv2 +import numpy as np import torch @@ -74,9 +76,7 @@ def write_pfm(path, image, scale=1): if len(image.shape) == 3 and image.shape[2] == 3: # color image color = True - elif ( - len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 - ): # greyscale + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale color = False else: raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") @@ -135,9 +135,7 @@ def resize_image(img): img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) - img_resized = ( - torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() - ) + img_resized = torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() img_resized = img_resized.unsqueeze(0) return img_resized @@ -156,12 +154,11 @@ def resize_depth(depth, width, height): """ depth = torch.squeeze(depth[0, :, :, :]).to("cpu") - depth_resized = cv2.resize( - depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC - ) + depth_resized = cv2.resize(depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC) return depth_resized + def write_depth(path, depth, bits=1): """Write depth map to pfm and png file. 
@@ -174,7 +171,7 @@ def write_depth(path, depth, bits=1): depth_min = depth.min() depth_max = depth.max() - max_val = (2**(8*bits))-1 + max_val = (2 ** (8 * bits)) - 1 if depth_max - depth_min > np.finfo("float").eps: out = max_val * (depth - depth_min) / (depth_max - depth_min) diff --git a/src/metr/ldm/modules/x_transformer.py b/src/metr/ldm/modules/x_transformer.py index 5fc15bf..8de312c 100644 --- a/src/metr/ldm/modules/x_transformer.py +++ b/src/metr/ldm/modules/x_transformer.py @@ -1,25 +1,21 @@ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" -import torch -from torch import nn, einsum -import torch.nn.functional as F + +from collections import namedtuple from functools import partial from inspect import isfunction -from collections import namedtuple -from einops import rearrange, repeat, reduce + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from torch import einsum, nn # constants DEFAULT_DIM_HEAD = 64 -Intermediates = namedtuple('Intermediates', [ - 'pre_softmax_attn', - 'post_softmax_attn' -]) +Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) -LayerIntermediates = namedtuple('Intermediates', [ - 'hiddens', - 'attn_intermediates' -]) +LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"]) class AbsolutePositionalEmbedding(nn.Module): @@ -39,18 +35,19 @@ def forward(self, x): class FixedPositionalEmbedding(nn.Module): def __init__(self, dim): super().__init__() - inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) def forward(self, x, seq_dim=1, offset=0): t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset - sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) return emb[None, :, :] # helpers + def exists(val): return val is not None @@ -64,18 +61,21 @@ def default(val, d): def always(val): def inner(*args, **kwargs): return val + return inner def not_equals(val): def inner(x): return x != val + return inner def equals(val): def inner(x): return x == val + return inner @@ -85,6 +85,7 @@ def max_neg_value(tensor): # keyword argument helpers + def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) @@ -109,7 +110,7 @@ def group_by_key_prefix(prefix, d): def groupby_prefix_and_trim(prefix, d): kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs @@ -139,7 +140,7 @@ def forward(self, x, **kwargs): class ScaleNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) @@ -151,7 +152,7 @@ def forward(self, x): class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-8): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) @@ -171,16 +172,14 @@ def __init__(self, dim): 
self.gru = nn.GRUCell(dim, dim) def forward(self, x, residual): - gated_output = self.gru( - rearrange(x, 'b n d -> (b n) d'), - rearrange(residual, 'b n d -> (b n) d') - ) + gated_output = self.gru(rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")) return gated_output.reshape_as(x) # feedforward + class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() @@ -192,20 +191,13 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -214,23 +206,23 @@ def forward(self, x): # attention. class Attention(nn.Module): def __init__( - self, - dim, - dim_head=DEFAULT_DIM_HEAD, - heads=8, - causal=False, - mask=None, - talking_heads=False, - sparse_topk=None, - use_entmax15=False, - num_mem_kv=0, - dropout=0., - on_attn=False + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, ): super().__init__() if use_entmax15: raise NotImplementedError("Check out entmax activation instead of softmax activation!") - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.causal = causal self.mask = mask @@ -252,7 +244,7 @@ def __init__( self.sparse_topk = sparse_topk # entmax - #self.attn_fn = entmax15 if use_entmax15 else F.softmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax self.attn_fn = F.softmax # add memory key / values @@ -266,15 +258,7 @@ def __init__( self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - rel_pos=None, - sinusoidal_emb=None, - prev_attn=None, - mem=None + self, x, context=None, mask=None, context_mask=None, rel_pos=None, sinusoidal_emb=None, prev_attn=None, mem=None ): b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device kv_input = default(context, x) @@ -297,25 +281,25 @@ def forward( k = self.to_k(k_input) v = self.to_v(v_input) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) input_mask = None if any(map(exists, (mask, context_mask))): q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) k_mask = q_mask if not exists(context) else context_mask k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) - q_mask = rearrange(q_mask, 'b i -> b () i ()') - k_mask = rearrange(k_mask, 'b j -> b () () j') + q_mask = rearrange(q_mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + mem_k, mem_v = map(lambda t: 
repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) - dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale mask_value = max_neg_value(dots) if exists(prev_attn): @@ -324,7 +308,7 @@ def forward( pre_softmax_attn = dots if talking_heads: - dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + dots = einsum("b h i j, h k -> b k i j", dots, self.pre_softmax_proj).contiguous() if exists(rel_pos): dots = rel_pos(dots) @@ -336,7 +320,7 @@ def forward( if self.causal: i, j = dots.shape[-2:] r = torch.arange(i, device=device) - mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") mask = F.pad(mask, (j - i, 0), value=False) dots.masked_fill_(mask, mask_value) del mask @@ -354,49 +338,46 @@ def forward( attn = self.dropout(attn) if talking_heads: - attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + attn = einsum("b h i j, h k -> b k i j", attn, self.post_softmax_proj).contiguous() - out = einsum('b h i j, b h j d -> b h i d', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') + out = einsum("b h i j, b h j d -> b h i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") - intermediates = Intermediates( - pre_softmax_attn=pre_softmax_attn, - post_softmax_attn=post_softmax_attn - ) + intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn) return self.to_out(out), intermediates class AttentionLayers(nn.Module): def __init__( - self, - dim, - depth, - heads=8, - causal=False, - cross_attend=False, - only_cross=False, - use_scalenorm=False, - use_rmsnorm=False, - use_rezero=False, - rel_pos_num_buckets=32, - rel_pos_max_distance=128, - position_infused_attn=False, - custom_layers=None, - sandwich_coef=None, - par_ratio=None, - residual_attn=False, - cross_residual_attn=False, - macaron=False, - pre_norm=True, - gate_residual=False, - **kwargs + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs, ): super().__init__() - ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) - attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) - dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) self.dim = dim self.depth = depth @@ -406,7 +387,9 @@ def __init__( self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None self.rotary_pos_emb = always(None) - assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" self.rel_pos = None self.pre_norm = 
pre_norm @@ -422,47 +405,47 @@ def __init__( branch_fn = Rezero if use_rezero else None if cross_attend and not only_cross: - default_block = ('a', 'c', 'f') + default_block = ("a", "c", "f") elif cross_attend and only_cross: - default_block = ('c', 'f') + default_block = ("c", "f") else: - default_block = ('a', 'f') + default_block = ("a", "f") if macaron: - default_block = ('f',) + default_block + default_block = ("f",) + default_block if exists(custom_layers): layer_types = custom_layers elif exists(par_ratio): par_depth = depth * len(default_block) - assert 1 < par_ratio <= par_depth, 'par ratio out of range' - default_block = tuple(filter(not_equals('f'), default_block)) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) par_attn = par_depth // par_ratio depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) + assert len(default_block) <= par_width, "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) par_head = par_block * par_attn - layer_types = par_head + ('f',) * (par_depth - len(par_head)) + layer_types = par_head + ("f",) * (par_depth - len(par_head)) elif exists(sandwich_coef): - assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + assert sandwich_coef > 0 and sandwich_coef <= depth, "sandwich coefficient should be less than the depth" + layer_types = ("a",) * sandwich_coef + default_block * (depth - sandwich_coef) + ("f",) * sandwich_coef else: layer_types = default_block * depth self.layer_types = layer_types - self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) for layer_type in self.layer_types: - if layer_type == 'a': + if layer_type == "a": layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) - elif layer_type == 'c': + elif layer_type == "c": layer = Attention(dim, heads=heads, **attn_kwargs) - elif layer_type == 'f': + elif layer_type == "f": layer = FeedForward(dim, **ff_kwargs) layer = layer if not macaron else Scale(0.5, layer) else: - raise Exception(f'invalid layer type {layer_type}') + raise Exception(f"invalid layer type {layer_type}") if isinstance(layer, Attention) and exists(branch_fn): layer = branch_fn(layer) @@ -472,21 +455,9 @@ def __init__( else: residual_fn = Residual() - self.layers.append(nn.ModuleList([ - norm_fn(), - layer, - residual_fn - ])) + self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - mems=None, - return_hiddens=False - ): + def forward(self, x, context=None, mask=None, context_mask=None, mems=None, return_hiddens=False): hiddens = [] intermediates = [] prev_attn = None @@ -497,7 +468,7 @@ def forward( for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): is_last = ind == (len(self.layers) - 1) - if layer_type == 'a': + if layer_type == "a": hiddens.append(x) layer_mem = mems.pop(0) @@ -506,32 +477,35 @@ def forward( if self.pre_norm: x = norm(x) - if layer_type == 'a': 
- out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, - prev_attn=prev_attn, mem=layer_mem) - elif layer_type == 'c': + if layer_type == "a": + out, inter = block( + x, + mask=mask, + sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, + prev_attn=prev_attn, + mem=layer_mem, + ) + elif layer_type == "c": out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) - elif layer_type == 'f': + elif layer_type == "f": out = block(x) x = residual_fn(out, residual) - if layer_type in ('a', 'c'): + if layer_type in ("a", "c"): intermediates.append(inter) - if layer_type == 'a' and self.residual_attn: + if layer_type == "a" and self.residual_attn: prev_attn = inter.pre_softmax_attn - elif layer_type == 'c' and self.cross_residual_attn: + elif layer_type == "c" and self.cross_residual_attn: prev_cross_attn = inter.pre_softmax_attn if not self.pre_norm and not is_last: x = norm(x) if return_hiddens: - intermediates = LayerIntermediates( - hiddens=hiddens, - attn_intermediates=intermediates - ) + intermediates = LayerIntermediates(hiddens=hiddens, attn_intermediates=intermediates) return x, intermediates @@ -540,27 +514,26 @@ def forward( class Encoder(AttentionLayers): def __init__(self, **kwargs): - assert 'causal' not in kwargs, 'cannot set causality on encoder' + assert "causal" not in kwargs, "cannot set causality on encoder" super().__init__(causal=False, **kwargs) - class TransformerWrapper(nn.Module): def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim=None, - max_mem_len=0., - emb_dropout=0., - num_memory_tokens=None, - tie_embedding=False, - use_pos_emb=True + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, ): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder" dim = attn_layers.dim emb_dim = default(emb_dim, dim) @@ -570,8 +543,11 @@ def __init__( self.num_tokens = num_tokens self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) self.emb_dropout = nn.Dropout(emb_dropout) self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() @@ -589,22 +565,13 @@ def __init__( self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) # let funnel encoder know number of memory tokens, if specified - if hasattr(attn_layers, 'num_memory_tokens'): + if hasattr(attn_layers, "num_memory_tokens"): attn_layers.num_memory_tokens = num_memory_tokens def init_(self): nn.init.normal_(self.token_emb.weight, std=0.02) - def forward( - self, - x, - return_embeddings=False, - mask=None, - return_mems=False, - return_attn=False, - mems=None, - **kwargs - ): + def forward(self, x, return_embeddings=False, mask=None, return_mems=False, return_attn=False, mems=None, **kwargs): b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens x = self.token_emb(x) x += self.pos_emb(x) @@ -613,7 +580,7 @@ def forward( x = self.project_emb(x) if num_mem > 0: - mem = repeat(self.memory_tokens, 'n d -> b n d', 
b=b) + mem = repeat(self.memory_tokens, "n d -> b n d", b=b) x = torch.cat((mem, x), dim=1) # auto-handle masking after appending memory tokens @@ -630,7 +597,7 @@ def forward( if return_mems: hiddens = intermediates.hiddens new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens - new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + new_mems = list(map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)) return out, new_mems if return_attn: @@ -638,4 +605,3 @@ def forward( return out, attn_maps return out - diff --git a/src/metr/ldm/util.py b/src/metr/ldm/util.py index 38f1481..78aa938 100644 --- a/src/metr/ldm/util.py +++ b/src/metr/ldm/util.py @@ -1,11 +1,11 @@ import importlib +from inspect import isfunction -import torch import numpy as np - -from inspect import isfunction +import torch from PIL import Image, ImageDraw, ImageFont + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot @@ -14,9 +14,9 @@ def log_txt_as_img(wh, xc, size=10): for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)) try: draw.text((0, 0), lines, fill="black", font=font) @@ -37,7 +37,7 @@ def ismap(x): def isimage(x): - if not isinstance(x,torch.Tensor): + if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -69,7 +69,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): if not "target" in config: - if config == '__is_first_stage__': + if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None @@ -82,4 +82,4 @@ def get_obj_from_str(string, reload=False): if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) \ No newline at end of file + return getattr(importlib.import_module(module, package=None), cls) diff --git a/src/metr/loss/color_wrapper.py b/src/metr/loss/color_wrapper.py index 5cba086..89e3e3a 100644 --- a/src/metr/loss/color_wrapper.py +++ b/src/metr/loss/color_wrapper.py @@ -2,26 +2,29 @@ import torch.nn as nn import torch.nn.functional as F + class RGB2YCbCr(nn.Module): def __init__(self): super().__init__() - transf = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).transpose(0, 1) + transf = torch.tensor([[0.299, 0.587, 0.114], [-0.1687, -0.3313, 0.5], [0.5, -0.4187, -0.0813]]).transpose(0, 1) self.transform = nn.Parameter(transf, requires_grad=False) bias = torch.tensor([0, 0.5, 0.5]) self.bias = nn.Parameter(bias, requires_grad=False) - + def forward(self, rgb): N, C, H, W = rgb.shape assert C == 3 - rgb = rgb.transpose(1,3) + rgb = rgb.transpose(1, 3) cbcr = torch.matmul(rgb, self.transform) cbcr += self.bias - return cbcr.transpose(1,3) + return cbcr.transpose(1, 3) + class ColorWrapper(nn.Module): """ Extension for single-channel loss to work on color images """ + def __init__(self, lossclass, args, kwargs, trainable=False): """ Parameters: @@ -31,37 +34,39 @@ def __init__(self, lossclass, args, kwargs, trainable=False): kwargs: dict, key word 
arguments for instantiation of loss fun """ super().__init__() - + # submodules - self.add_module('to_YCbCr', RGB2YCbCr()) - self.add_module('ly', lossclass(*args, **kwargs)) - self.add_module('lcb', lossclass(*args, **kwargs)) - self.add_module('lcr', lossclass(*args, **kwargs)) - + self.add_module("to_YCbCr", RGB2YCbCr()) + self.add_module("ly", lossclass(*args, **kwargs)) + self.add_module("lcb", lossclass(*args, **kwargs)) + self.add_module("lcr", lossclass(*args, **kwargs)) + # weights self.w_tild = nn.Parameter(torch.zeros(3), requires_grad=trainable) - - @property + + @property def w(self): return F.softmax(self.w_tild, dim=0) - + def forward(self, input, target): # convert color space input = self.to_YCbCr(input) target = self.to_YCbCr(target) - - ly = self.ly(input[:,[0],:,:], target[:,[0],:,:]) - lcb = self.lcb(input[:,[1],:,:], target[:,[1],:,:]) - lcr = self.lcr(input[:,[2],:,:], target[:,[2],:,:]) - + + ly = self.ly(input[:, [0], :, :], target[:, [0], :, :]) + lcb = self.lcb(input[:, [1], :, :], target[:, [1], :, :]) + lcr = self.lcr(input[:, [2], :, :], target[:, [2], :, :]) + w = self.w - + return ly * w[0] + lcb * w[1] + lcr * w[2] + class GreyscaleWrapper(nn.Module): """ Maps 3 channel RGB or 1 channel greyscale input to 3 greyscale channels """ + def __init__(self, lossclass, args, kwargs): """ Parameters: @@ -70,15 +75,15 @@ def __init__(self, lossclass, args, kwargs): kwargs: dict, key word arguments for instantiation of loss fun """ super().__init__() - + # submodules - self.add_module('loss', lossclass(*args, **kwargs)) + self.add_module("loss", lossclass(*args, **kwargs)) def to_greyscale(self, tensor): - return tensor[:,[0],:,:] * 0.3 + tensor[:,[1],:,:] * 0.59 + tensor[:,[2],:,:] * 0.11 + return tensor[:, [0], :, :] * 0.3 + tensor[:, [1], :, :] * 0.59 + tensor[:, [2], :, :] * 0.11 def forward(self, input, target): - (N,C,X,Y) = input.size() + (N, C, X, Y) = input.size() if N == 3: # convert input to greyscale diff --git a/src/metr/loss/dct2d.py b/src/metr/loss/dct2d.py index 65a59f1..921fdb6 100644 --- a/src/metr/loss/dct2d.py +++ b/src/metr/loss/dct2d.py @@ -3,85 +3,95 @@ import torch.nn as nn import torch.nn.functional as F + class Dct2d(nn.Module): """ Blockwhise 2D DCT """ + def __init__(self, blocksize=8, interleaving=False): """ Parameters: - blocksize: int, size of the Blocks for discrete cosine transform + blocksize: int, size of the Blocks for discrete cosine transform interleaving: bool, should the blocks interleave? """ - super().__init__() # call super constructor - + super().__init__() # call super constructor + self.blocksize = blocksize self.interleaving = interleaving - + if interleaving: self.stride = self.blocksize // 2 else: self.stride = self.blocksize - + # precompute DCT weight matrix - A = np.zeros((blocksize,blocksize)) + A = np.zeros((blocksize, blocksize)) for i in range(blocksize): - c_i = 1/np.sqrt(2) if i == 0 else 1. 
+ c_i = 1 / np.sqrt(2) if i == 0 else 1.0 for n in range(blocksize): - A[i,n] = np.sqrt(2/blocksize) * c_i * np.cos((2*n+ 1)/(blocksize*2) * i * np.pi) - + A[i, n] = np.sqrt(2 / blocksize) * c_i * np.cos((2 * n + 1) / (blocksize * 2) * i * np.pi) + # set up conv layer self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32), requires_grad=False) self.unfold = torch.nn.Unfold(kernel_size=blocksize, padding=0, stride=self.stride) return - + def forward(self, x): """ performs 2D blockwhise DCT - + Parameters: x: tensor of dimension (N, 1, h, w) - + Return: tensor of dimension (N, k, blocksize, blocksize) where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients """ - + (N, C, H, W) = x.shape - assert (C == 1), "DCT is only implemented for a single channel" - assert (H >= self.blocksize), "Input too small for blocksize" - assert (W >= self.blocksize), "Input too small for blocksize" - assert (H % self.stride == 0) and (W % self.stride == 0), "FFT is only for dimensions divisible by the blocksize" - + assert C == 1, "DCT is only implemented for a single channel" + assert H >= self.blocksize, "Input too small for blocksize" + assert W >= self.blocksize, "Input too small for blocksize" + assert (H % self.stride == 0) and ( + W % self.stride == 0 + ), "FFT is only for dimensions divisible by the blocksize" + # unfold to blocks x = self.unfold(x) # now shape (N, blocksize**2, k) (N, _, k) = x.shape - x = x.view(-1,self.blocksize,self.blocksize,k).permute(0,3,1,2) + x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2) # now shape (N, #k, blocksize, blocksize) # perform DCT - coeff = self.A.matmul(x).matmul(self.A.transpose(0,1)) - + coeff = self.A.matmul(x).matmul(self.A.transpose(0, 1)) + return coeff - + def inverse(self, coeff, output_shape): """ performs 2D blockwhise iDCT - + Parameters: coeff: tensor of dimension (N, k, blocksize, blocksize) where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients output_shape: (h, w) dimensions of the reconstructed image - + Return: tensor of dimension (N, 1, h, w) """ if self.interleaving: - raise Exception('Inverse block DCT is not implemented for interleaving blocks!') - + raise Exception("Inverse block DCT is not implemented for interleaving blocks!") + # perform iDCT - x = self.A.transpose(0,1).matmul(coeff).matmul(self.A) + x = self.A.transpose(0, 1).matmul(coeff).matmul(self.A) (N, k, _, _) = x.shape - x = x.permute(0,2,3,1).view(-1, self.blocksize**2, k) - x = F.fold(x, output_size=(output_shape[-2], output_shape[-1]), kernel_size=self.blocksize, padding=0, stride=self.blocksize) + x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k) + x = F.fold( + x, + output_size=(output_shape[-2], output_shape[-1]), + kernel_size=self.blocksize, + padding=0, + stride=self.blocksize, + ) return x diff --git a/src/metr/loss/deep_loss.py b/src/metr/loss/deep_loss.py index f9962c6..8db3b90 100644 --- a/src/metr/loss/deep_loss.py +++ b/src/metr/loss/deep_loss.py @@ -1,26 +1,41 @@ # Deeploss function from Zhang et al. 
(2018) +from collections import namedtuple + +import numpy as np import torch -import torchvision import torch.nn as nn -import numpy as np +import torchvision from torchvision import models -from collections import namedtuple + class NetLinLayer(nn.Module): - ''' A single linear layer which does a 1x1 conv ''' + """A single linear layer which does a 1x1 conv""" + def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() - layers = [nn.Dropout(),] if(use_dropout) else [nn.Dropout(p=0.0),] - layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [ + nn.Dropout(p=0.0), + ] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] self.model = nn.Sequential(*layers) -def normalize_tensor(in_feat,eps=1e-10): +def normalize_tensor(in_feat, eps=1e-10): # norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1)).view(in_feat.size()[0],1,in_feat.size()[2],in_feat.size()[3]).repeat(1,in_feat.size()[1],1,1) - norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1)).view(in_feat.size()[0],1,in_feat.size()[2],in_feat.size()[3]) - return in_feat/(norm_factor.expand_as(in_feat)+eps) + norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view( + in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3] + ) + return in_feat / (norm_factor.expand_as(in_feat) + eps) class vgg16(torch.nn.Module): @@ -58,7 +73,7 @@ def forward(self, X): h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h - vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out @@ -77,7 +92,7 @@ def __init__(self, requires_grad=False, pretrained=True): self.N_slices = 7 for x in range(2): self.slice1.add_module(str(x), pretrained_features[x]) - for x in range(2,5): + for x in range(2, 5): self.slice2.add_module(str(x), pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), pretrained_features[x]) @@ -108,8 +123,8 @@ def forward(self, X): h_relu6 = h h = self.slice7(h) h_relu7 = h - vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) - out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) + vgg_outputs = namedtuple("SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"]) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) return out @@ -149,13 +164,25 @@ def forward(self, X): h_relu4 = h h = self.slice5(h) h_relu5 = h - alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + alexnet_outputs = namedtuple("AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]) out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) return out + class PNetLin(nn.Module): - def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, use_gpu=True, spatial=False, version='0.1', colorspace='RGB', reduction='none'): + def __init__( + self, + pnet_type="vgg", + pnet_rand=False, + pnet_tune=False, + use_dropout=True, + use_gpu=True, + spatial=False, + version="0.1", + colorspace="RGB", + reduction="none", + ): super(PNetLin, self).__init__() self.use_gpu = use_gpu @@ -167,37 +194,39 @@ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropou self.colorspace 
= colorspace self.reduction = reduction - if(self.pnet_type in ['vgg','vgg16']): + if self.pnet_type in ["vgg", "vgg16"]: net_type = vgg16 - self.chns = [64,128,256,512,512] - elif(self.pnet_type=='alex'): + self.chns = [64, 128, 256, 512, 512] + elif self.pnet_type == "alex": net_type = alexnet - self.chns = [64,192,384,256,256] - elif(self.pnet_type=='squeeze'): + self.chns = [64, 192, 384, 256, 256] + elif self.pnet_type == "squeeze": net_type = squeezenet - self.chns = [64,128,256,384,384,512,512] + self.chns = [64, 128, 256, 384, 384, 512, 512] - if(self.pnet_tune): - self.net = net_type(pretrained=not self.pnet_rand,requires_grad=True) + if self.pnet_tune: + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=True) else: - self.net = [net_type(pretrained=not self.pnet_rand,requires_grad=False),] - - self.lin0 = NetLinLayer(self.chns[0],use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1],use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2],use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3],use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4],use_dropout=use_dropout) - self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] - if(self.pnet_type=='squeeze'): # 7 layers for squeezenet - self.lin5 = NetLinLayer(self.chns[5],use_dropout=use_dropout) - self.lin6 = NetLinLayer(self.chns[6],use_dropout=use_dropout) - self.lins+=[self.lin5,self.lin6] - - self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1)) - self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1)) - - if(use_gpu): - if(self.pnet_tune): + self.net = [ + net_type(pretrained=not self.pnet_rand, requires_grad=False), + ] + + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if self.pnet_type == "squeeze": # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins += [self.lin5, self.lin6] + + self.shift = torch.autograd.Variable(torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)) + self.scale = torch.autograd.Variable(torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)) + + if use_gpu: + if self.pnet_tune: self.net.cuda() else: self.net[0].cuda() @@ -208,19 +237,19 @@ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropou self.lin2.cuda() self.lin3.cuda() self.lin4.cuda() - if(self.pnet_type=='squeeze'): + if self.pnet_type == "squeeze": self.lin5.cuda() self.lin6.cuda() def forward(self, in0, in1): - in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0) - in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0) + in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) + in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) - if self.colorspace == 'Gray': + if self.colorspace == "Gray": in0_sc = util.tensor2tensorGrayscaleLazy(in0_sc) in1_sc = util.tensor2tensorGrayscaleLazy(in1_sc) - if(self.version=='0.0'): + if self.version == "0.0": # v0.0 - original release had a bug, where input was not scaled in0_input = in0 in1_input = in1 @@ -229,7 +258,7 @@ def forward(self, 
in0, in1): in0_input = in0_sc in1_input = in1_sc - if(self.pnet_tune): + if self.pnet_tune: outs0 = self.net.forward(in0_input) outs1 = self.net.forward(in1_input) else: @@ -238,32 +267,32 @@ def forward(self, in0, in1): feats0 = {} feats1 = {} - diffs = [0]*len(outs0) + diffs = [0] * len(outs0) - for (kk,out0) in enumerate(outs0): + for kk, out0 in enumerate(outs0): feats0[kk] = normalize_tensor(outs0[kk]) # norm NN outputs - feats1[kk] = normalize_tensor(outs1[kk]) - diffs[kk] = (feats0[kk]-feats1[kk])**2 # squared diff + feats1[kk] = normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 # squared diff if self.spatial: lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] - if(self.pnet_type=='squeeze'): + if self.pnet_type == "squeeze": lin_models.extend([self.lin5, self.lin6]) res = [lin_models[kk].model(diffs[kk]) for kk in range(len(diffs))] return res - - val = torch.mean(torch.mean(self.lin0.model(diffs[0]),dim=3),dim=2) # sum means over H, W - val = val + torch.mean(torch.mean(self.lin1.model(diffs[1]),dim=3),dim=2) - val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]),dim=3),dim=2) - val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]),dim=3),dim=2) - val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]),dim=3),dim=2) - if(self.pnet_type=='squeeze'): - val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]),dim=3),dim=2) - val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]),dim=3),dim=2) - - val = val.view(val.size()[0],val.size()[1],1,1) - - if self.reduction == 'sum': + + val = torch.mean(torch.mean(self.lin0.model(diffs[0]), dim=3), dim=2) # sum means over H, W + val = val + torch.mean(torch.mean(self.lin1.model(diffs[1]), dim=3), dim=2) + val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]), dim=3), dim=2) + val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]), dim=3), dim=2) + val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]), dim=3), dim=2) + if self.pnet_type == "squeeze": + val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]), dim=3), dim=2) + val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]), dim=3), dim=2) + + val = val.view(val.size()[0], val.size()[1], 1, 1) + + if self.reduction == "sum": val = torch.sum(val) return val diff --git a/src/metr/loss/loss_provider.py b/src/metr/loss/loss_provider.py index fbb002a..965378e 100644 --- a/src/metr/loss/loss_provider.py +++ b/src/metr/loss/loss_provider.py @@ -1,28 +1,40 @@ -import torch -import torch.nn as nn import os from collections import OrderedDict +import torch +import torch.nn as nn from loss.color_wrapper import ColorWrapper, GreyscaleWrapper +from loss.deep_loss import PNetLin from loss.shift_wrapper import ShiftWrapper +from loss.ssim import SSIM from loss.watson import WatsonDistance from loss.watson_fft import WatsonDistanceFft from loss.watson_vgg import WatsonDistanceVgg -from loss.deep_loss import PNetLin -from loss.ssim import SSIM -class LossProvider(): +class LossProvider: def __init__(self): - self.loss_functions = ['L1', 'L2', 'SSIM', 'Watson-dct', 'Watson-fft', 'Watson-vgg', 'Deeploss-vgg', 'Deeploss-squeeze', 'Adaptive'] - self.color_models = ['LA', 'RGB'] + self.loss_functions = [ + "L1", + "L2", + "SSIM", + "Watson-dct", + "Watson-fft", + "Watson-vgg", + "Deeploss-vgg", + "Deeploss-squeeze", + "Adaptive", + ] + self.color_models = ["LA", "RGB"] def load_state_dict(self, filename): current_dir = os.path.dirname(__file__) - path = os.path.join(current_dir, 'losses', filename) - return 
torch.load(path, map_location='cpu') - - def get_loss_function(self, model, colorspace='RGB', reduction='sum', deterministic=False, pretrained=True, image_size=None): + path = os.path.join(current_dir, "losses", filename) + return torch.load(path, map_location="cpu") + + def get_loss_function( + self, model, colorspace="RGB", reduction="sum", deterministic=False, pretrained=True, image_size=None + ): """ returns a trained loss class. model: one of the values returned by self.loss_functions @@ -30,86 +42,87 @@ def get_loss_function(self, model, colorspace='RGB', reduction='sum', determinis deterministic: bool, if false (default) uses shifting of image blocks for watson-fft image_size: tuple, size of input images. Only required for adaptive loss. Eg: [3, 64, 64] """ - is_greyscale = colorspace in ['grey', 'Grey', 'LA', 'greyscale', 'grey-scale'] - + is_greyscale = colorspace in ["grey", "Grey", "LA", "greyscale", "grey-scale"] - if model.lower() in ['l2']: + if model.lower() in ["l2"]: loss = nn.MSELoss(reduction=reduction) - elif model.lower() in ['l1']: + elif model.lower() in ["l1"]: loss = nn.L1Loss(reduction=reduction) - elif model.lower() in ['ssim']: - loss = SSIM(size_average=(reduction in ['sum', 'mean'])) - elif model.lower() in ['watson', 'watson-dct']: + elif model.lower() in ["ssim"]: + loss = SSIM(size_average=(reduction in ["sum", "mean"])) + elif model.lower() in ["watson", "watson-dct"]: if is_greyscale: if deterministic: loss = WatsonDistance(reduction=reduction) - if pretrained: - loss.load_state_dict(self.load_state_dict('gray_watson_dct_trial0.pth')) + if pretrained: + loss.load_state_dict(self.load_state_dict("gray_watson_dct_trial0.pth")) else: - loss = ShiftWrapper(WatsonDistance, (), {'reduction': reduction}) - if pretrained: - loss.loss.load_state_dict(self.load_state_dict('gray_watson_dct_trial0.pth')) + loss = ShiftWrapper(WatsonDistance, (), {"reduction": reduction}) + if pretrained: + loss.loss.load_state_dict(self.load_state_dict("gray_watson_dct_trial0.pth")) else: if deterministic: - loss = ColorWrapper(WatsonDistance, (), {'reduction': reduction}) - if pretrained: - loss.load_state_dict(self.load_state_dict('rgb_watson_dct_trial0.pth')) + loss = ColorWrapper(WatsonDistance, (), {"reduction": reduction}) + if pretrained: + loss.load_state_dict(self.load_state_dict("rgb_watson_dct_trial0.pth")) else: - loss = ShiftWrapper(ColorWrapper, (WatsonDistance, (), {'reduction': reduction}), {}) - if pretrained: - loss.loss.load_state_dict(self.load_state_dict('rgb_watson_dct_trial0.pth')) - elif model.lower() in ['watson-fft', 'watson-dft']: + loss = ShiftWrapper(ColorWrapper, (WatsonDistance, (), {"reduction": reduction}), {}) + if pretrained: + loss.loss.load_state_dict(self.load_state_dict("rgb_watson_dct_trial0.pth")) + elif model.lower() in ["watson-fft", "watson-dft"]: if is_greyscale: if deterministic: loss = WatsonDistanceFft(reduction=reduction) - if pretrained: - loss.load_state_dict(self.load_state_dict('gray_watson_fft_trial0.pth')) + if pretrained: + loss.load_state_dict(self.load_state_dict("gray_watson_fft_trial0.pth")) else: - loss = ShiftWrapper(WatsonDistanceFft, (), {'reduction': reduction}) - if pretrained: - loss.loss.load_state_dict(self.load_state_dict('gray_watson_fft_trial0.pth')) + loss = ShiftWrapper(WatsonDistanceFft, (), {"reduction": reduction}) + if pretrained: + loss.loss.load_state_dict(self.load_state_dict("gray_watson_fft_trial0.pth")) else: if deterministic: - loss = ColorWrapper(WatsonDistanceFft, (), {'reduction': reduction}) - if 
pretrained: - loss.load_state_dict(self.load_state_dict('rgb_watson_fft_trial0.pth')) + loss = ColorWrapper(WatsonDistanceFft, (), {"reduction": reduction}) + if pretrained: + loss.load_state_dict(self.load_state_dict("rgb_watson_fft_trial0.pth")) else: - loss = ShiftWrapper(ColorWrapper, (WatsonDistanceFft, (), {'reduction': reduction}), {}) - if pretrained: - loss.loss.load_state_dict(self.load_state_dict('rgb_watson_fft_trial0.pth')) - elif model.lower() in ['watson-vgg', 'watson-deep']: + loss = ShiftWrapper(ColorWrapper, (WatsonDistanceFft, (), {"reduction": reduction}), {}) + if pretrained: + loss.loss.load_state_dict(self.load_state_dict("rgb_watson_fft_trial0.pth")) + elif model.lower() in ["watson-vgg", "watson-deep"]: if is_greyscale: - loss = GreyscaleWrapper(WatsonDistanceVgg, (), {'reduction': reduction}) - if pretrained: - loss.loss.load_state_dict(self.load_state_dict('gray_watson_vgg_trial0.pth')) + loss = GreyscaleWrapper(WatsonDistanceVgg, (), {"reduction": reduction}) + if pretrained: + loss.loss.load_state_dict(self.load_state_dict("gray_watson_vgg_trial0.pth")) else: loss = WatsonDistanceVgg(reduction=reduction) - if pretrained: - loss.load_state_dict(self.load_state_dict('rgb_watson_vgg_trial0.pth')) - elif model.lower() in ['deeploss-vgg']: + if pretrained: + loss.load_state_dict(self.load_state_dict("rgb_watson_vgg_trial0.pth")) + elif model.lower() in ["deeploss-vgg"]: if is_greyscale: - loss = GreyscaleWrapper(PNetLin, (), {'pnet_type': 'vgg', 'reduction': reduction, 'use_dropout': False}) - if pretrained: - loss.loss.load_state_dict(self.load_state_dict('gray_pnet_lin_vgg_trial0.pth')) + loss = GreyscaleWrapper(PNetLin, (), {"pnet_type": "vgg", "reduction": reduction, "use_dropout": False}) + if pretrained: + loss.loss.load_state_dict(self.load_state_dict("gray_pnet_lin_vgg_trial0.pth")) else: - loss = PNetLin(pnet_type='vgg', reduction=reduction, use_dropout=False) - if pretrained: - loss.load_state_dict(self.load_state_dict('rgb_pnet_lin_vgg_trial0.pth')) - elif model.lower() in ['deeploss-squeeze']: + loss = PNetLin(pnet_type="vgg", reduction=reduction, use_dropout=False) + if pretrained: + loss.load_state_dict(self.load_state_dict("rgb_pnet_lin_vgg_trial0.pth")) + elif model.lower() in ["deeploss-squeeze"]: if is_greyscale: - loss = GreyscaleWrapper(PNetLin, (), {'pnet_type': 'squeeze', 'reduction': reduction, 'use_dropout': False}) - if pretrained: - loss.loss.load_state_dict(self.load_state_dict('gray_pnet_lin_squeeze_trial0.pth')) + loss = GreyscaleWrapper( + PNetLin, (), {"pnet_type": "squeeze", "reduction": reduction, "use_dropout": False} + ) + if pretrained: + loss.loss.load_state_dict(self.load_state_dict("gray_pnet_lin_squeeze_trial0.pth")) else: - loss = PNetLin(pnet_type='squeeze', reduction=reduction, use_dropout=False) - if pretrained: - loss.load_state_dict(self.load_state_dict('rgb_pnet_lin_squeeze_trial0.pth')) + loss = PNetLin(pnet_type="squeeze", reduction=reduction, use_dropout=False) + if pretrained: + loss.load_state_dict(self.load_state_dict("rgb_pnet_lin_squeeze_trial0.pth")) else: raise Exception('Metric "{}" not implemented'.format(model)) # freeze all training of the loss functions - if pretrained: + if pretrained: for param in loss.parameters(): param.requires_grad = False - + return loss diff --git a/src/metr/loss/rfft2d.py b/src/metr/loss/rfft2d.py index 4a5dcc3..fe315f0 100644 --- a/src/metr/loss/rfft2d.py +++ b/src/metr/loss/rfft2d.py @@ -1,8 +1,8 @@ +import numpy as np import torch -import torch.nn as nn import torch.fft as fft 
+import torch.nn as nn import torch.nn.functional as F -import numpy as np class Rfft2d(nn.Module): @@ -10,66 +10,75 @@ class Rfft2d(nn.Module): Blockwhise 2D FFT for fixed blocksize of 8x8 """ + def __init__(self, blocksize=8, interleaving=False): """ Parameters: """ - super().__init__() # call super constructor - + super().__init__() # call super constructor + self.blocksize = blocksize self.interleaving = interleaving if interleaving: self.stride = self.blocksize // 2 else: self.stride = self.blocksize - + self.unfold = torch.nn.Unfold(kernel_size=self.blocksize, padding=0, stride=self.stride) return - + def forward(self, x): """ performs 2D blockwhise DCT - + Parameters: x: tensor of dimension (N, 1, h, w) - + Return: tensor of dimension (N, k, b, b/2, 2) - where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block real FFT coefficients. + where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block real FFT coefficients. The last dimension is pytorches representation of complex values """ - + (N, C, H, W) = x.shape - assert (C == 1), "FFT is only implemented for a single channel" - assert (H >= self.blocksize), "Input too small for blocksize" - assert (W >= self.blocksize), "Input too small for blocksize" - assert (H % self.stride == 0) and (W % self.stride == 0), "FFT is only for dimensions divisible by the blocksize" - + assert C == 1, "FFT is only implemented for a single channel" + assert H >= self.blocksize, "Input too small for blocksize" + assert W >= self.blocksize, "Input too small for blocksize" + assert (H % self.stride == 0) and ( + W % self.stride == 0 + ), "FFT is only for dimensions divisible by the blocksize" + # unfold to blocks x = self.unfold(x) # now shape (N, 64, k) (N, _, k) = x.shape - x = x.view(-1,self.blocksize,self.blocksize,k).permute(0,3,1,2) + x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2) # now shape (N, #k, b, b) # perform DCT coeff = fft.rfft(x) coeff = torch.view_as_real(coeff) - + return coeff / self.blocksize**2 - + def inverse(self, coeff, output_shape): """ performs 2D blockwhise inverse rFFT - + Parameters: output_shape: Tuple, dimensions of the outpus sample """ if self.interleaving: - raise Exception('Inverse block FFT is not implemented for interleaving blocks!') - + raise Exception("Inverse block FFT is not implemented for interleaving blocks!") + # perform iRFFT x = fft.irfft(coeff, dim=2, signal_sizes=(self.blocksize, self.blocksize)) (N, k, _, _) = x.shape - x = x.permute(0,2,3,1).view(-1, self.blocksize**2, k) - x = F.fold(x, output_size=(output_shape[-2], output_shape[-1]), kernel_size=self.blocksize, padding=0, stride=self.blocksize) + x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k) + x = F.fold( + x, + output_size=(output_shape[-2], output_shape[-1]), + kernel_size=self.blocksize, + padding=0, + stride=self.blocksize, + ) return x * (self.blocksize**2) diff --git a/src/metr/loss/shift_wrapper.py b/src/metr/loss/shift_wrapper.py index 5c4394a..2955f54 100644 --- a/src/metr/loss/shift_wrapper.py +++ b/src/metr/loss/shift_wrapper.py @@ -1,12 +1,14 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np + class ShiftWrapper(nn.Module): """ - Extension for 2-dimensional inout loss functions. - Shifts the inputs by up to 4 pixels. Uses replication padding. + Extension for 2-dimensional inout loss functions. + Shifts the inputs by up to 4 pixels. Uses replication padding. 
""" + def __init__(self, lossclass, args, kwargs): """ Parameters: @@ -16,16 +18,16 @@ def __init__(self, lossclass, args, kwargs): kwargs: dict, key word arguments for instantiation of loss fun """ super().__init__() - + # submodules - self.add_module('loss', lossclass(*args, **kwargs)) + self.add_module("loss", lossclass(*args, **kwargs)) # shift amount self.max_shift = 8 - + # padding self.pad = nn.ReplicationPad2d(self.max_shift // 2) - + def forward(self, input, target): # convert color space input = self.pad(input) @@ -34,7 +36,7 @@ def forward(self, input, target): shift_x = np.random.randint(self.max_shift) shift_y = np.random.randint(self.max_shift) - input = input[:,:,shift_x:-(self.max_shift - shift_x),shift_y:-(self.max_shift - shift_y)] - target = target[:,:,shift_x:-(self.max_shift - shift_x),shift_y:-(self.max_shift - shift_y)] - + input = input[:, :, shift_x : -(self.max_shift - shift_x), shift_y : -(self.max_shift - shift_y)] + target = target[:, :, shift_x : -(self.max_shift - shift_x), shift_y : -(self.max_shift - shift_y)] + return self.loss(input, target) diff --git a/src/metr/loss/ssim.py b/src/metr/loss/ssim.py index 0bd0614..899a08a 100644 --- a/src/metr/loss/ssim.py +++ b/src/metr/loss/ssim.py @@ -1,14 +1,17 @@ # SSIM implementation from https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py +from math import exp + +import numpy as np import torch import torch.nn.functional as F from torch.autograd import Variable -import numpy as np -from math import exp + def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) - return gauss/gauss.sum() + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) + return gauss / gauss.sum() + def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) @@ -16,30 +19,32 @@ def create_window(window_size, channel): window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) return window -def _ssim(img1, img2, window, window_size, channel, size_average = True): - mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) - mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) - mu1_mu2 = mu1*mu2 + mu1_mu2 = mu1 * mu2 - sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq - sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq - sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 C1 = 0.01**2 C2 = 0.03**2 - ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() else: return ssim_map.mean(1).mean(1).mean(1) + class 
SSIM(torch.nn.Module): - def __init__(self, window_size = 11, size_average = True): + def __init__(self, window_size=11, size_average=True): super(SSIM, self).__init__() self.window_size = window_size self.size_average = size_average @@ -53,23 +58,23 @@ def forward(self, img1, img2): window = self.window else: window = create_window(self.window_size, channel) - + if img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) - + self.window = window self.channel = channel + return 1 - _ssim(img1, img2, window, self.window_size, channel, self.size_average) - return 1 - _ssim(img1, img2, window, self.window_size, channel, self.size_average) -def ssim(img1, img2, window_size = 11, size_average = True): +def ssim(img1, img2, window_size=11, size_average=True): (_, channel, _, _) = img1.size() window = create_window(window_size, channel) - + if img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) - + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/src/metr/loss/watson.py b/src/metr/loss/watson.py index 4e666b4..79c3810 100644 --- a/src/metr/loss/watson.py +++ b/src/metr/loss/watson.py @@ -5,61 +5,71 @@ EPS = 1e-10 + def softmax(a, b, factor=1): concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1) softmax_factors = F.softmax(concat * factor, dim=-1) - return a * softmax_factors[:,:,:,:,0] + b * softmax_factors[:,:,:,:,1] + return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1] + class WatsonDistance(nn.Module): """ Loss function based on Watsons perceptual distance. Based on DCT quantization """ - def __init__(self, blocksize=8, trainable=False, reduction='sum'): + + def __init__(self, blocksize=8, trainable=False, reduction="sum"): """ Parameters: - blocksize: int, size of the Blocks for discrete cosine transform + blocksize: int, size of the Blocks for discrete cosine transform trainable: bool, if True parameters of the loss are trained and dropout is enabled. 
reduction: 'sum' or 'none', determines return format """ super().__init__() - + # input mapping blocksize = torch.as_tensor(blocksize) - + # module to perform 2D blockwise DCT - self.add_module('dct', Dct2d(blocksize=blocksize.item(), interleaving=False)) - + self.add_module("dct", Dct2d(blocksize=blocksize.item(), interleaving=False)) + # parameters, initialized with values from watson paper self.blocksize = nn.Parameter(blocksize, requires_grad=False) if self.blocksize == 8: # init with Jpeg QM - self.t_tild = nn.Parameter(torch.log(torch.tensor( # log-scaled weights - [[1.40, 1.01, 1.16, 1.66, 2.40, 3.43, 4.79, 6.56], - [1.01, 1.45, 1.32, 1.52, 2.00, 2.71, 3.67, 4.93], - [1.16, 1.32, 2.24, 2.59, 2.98, 3.64, 4.60, 5.88], - [1.66, 1.52, 2.59, 3.77, 4.55, 5.30, 6.28, 7.60], - [2.40, 2.00, 2.98, 4.55, 6.15, 7.46, 8.71, 10.17], - [3.43, 2.71, 3.64, 5.30, 7.46, 9.62, 11.58, 13.51], - [4.79, 3.67, 4.60, 6.28, 8.71, 11.58, 14.50, 17.29], - [6.56, 4.93, 5.88, 7.60, 10.17, 13.51, 17.29, 21.15]] - )), requires_grad=trainable) + self.t_tild = nn.Parameter( + torch.log( + torch.tensor( # log-scaled weights + [ + [1.40, 1.01, 1.16, 1.66, 2.40, 3.43, 4.79, 6.56], + [1.01, 1.45, 1.32, 1.52, 2.00, 2.71, 3.67, 4.93], + [1.16, 1.32, 2.24, 2.59, 2.98, 3.64, 4.60, 5.88], + [1.66, 1.52, 2.59, 3.77, 4.55, 5.30, 6.28, 7.60], + [2.40, 2.00, 2.98, 4.55, 6.15, 7.46, 8.71, 10.17], + [3.43, 2.71, 3.64, 5.30, 7.46, 9.62, 11.58, 13.51], + [4.79, 3.67, 4.60, 6.28, 8.71, 11.58, 14.50, 17.29], + [6.56, 4.93, 5.88, 7.60, 10.17, 13.51, 17.29, 21.15], + ] + ) + ), + requires_grad=trainable, + ) else: # init with uniform QM self.t_tild = nn.Parameter(torch.zeros((self.blocksize, self.blocksize)), requires_grad=trainable) - + # other default parameters - self.alpha = nn.Parameter(torch.tensor(0.649), requires_grad=trainable) # luminance masking - w = torch.tensor(0.7) # contrast masking - self.w_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) # inverse of sigmoid - self.beta = nn.Parameter(torch.tensor(4.), requires_grad=trainable) # pooling - + self.alpha = nn.Parameter(torch.tensor(0.649), requires_grad=trainable) # luminance masking + w = torch.tensor(0.7) # contrast masking + self.w_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable) # inverse of sigmoid + self.beta = nn.Parameter(torch.tensor(4.0), requires_grad=trainable) # pooling + # dropout for training self.dropout = nn.Dropout(0.5 if trainable else 0) - + # reduction self.reduction = reduction - if reduction not in ['sum', 'none']: + if reduction not in ["sum", "none"]: raise Exception('Reduction "{}" not supported. 
Valid values are: "sum", "none".'.format(reduction)) @property @@ -67,36 +77,35 @@ def t(self): # returns QM qm = torch.exp(self.t_tild) return qm - + @property def w(self): # return luminance masking parameter return torch.sigmoid(self.w_tild) - + def forward(self, input, target): # dct c0 = self.dct(target) c1 = self.dct(input) - + N, K, B, B = c0.shape - + # luminance masking - avg_lum = torch.mean(c0[:,:,0,0]) + avg_lum = torch.mean(c0[:, :, 0, 0]) t_l = self.t.view(1, 1, B, B).expand(N, K, B, B) - t_l = t_l * (((c0[:,:,0,0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(N, K, 1, 1) - + t_l = t_l * (((c0[:, :, 0, 0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(N, K, 1, 1) + # contrast masking - s = softmax(t_l, (c0.abs() + EPS)**self.w * t_l**(1 - self.w)) - + s = softmax(t_l, (c0.abs() + EPS) ** self.w * t_l ** (1 - self.w)) + # pooling watson_dist = (((c0 - c1) / s).abs() + EPS) ** self.beta watson_dist = self.dropout(watson_dist) + EPS - watson_dist = torch.sum(watson_dist, dim=(1,2,3)) + watson_dist = torch.sum(watson_dist, dim=(1, 2, 3)) watson_dist = watson_dist ** (1 / self.beta) # reduction - if self.reduction == 'sum': + if self.reduction == "sum": watson_dist = torch.sum(watson_dist) - + return watson_dist - diff --git a/src/metr/loss/watson_fft.py b/src/metr/loss/watson_fft.py index e613c50..45a4ca2 100644 --- a/src/metr/loss/watson_fft.py +++ b/src/metr/loss/watson_fft.py @@ -5,51 +5,54 @@ EPS = 1e-10 + def softmax(a, b, factor=1): concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1) softmax_factors = F.softmax(concat * factor, dim=-1) - return a * softmax_factors[:,:,:,:,0] + b * softmax_factors[:,:,:,:,1] + return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1] + class WatsonDistanceFft(nn.Module): """ Loss function based on Watsons perceptual distance. Based on FFT quantization """ - def __init__(self, blocksize=8, trainable=False, reduction='sum'): + + def __init__(self, blocksize=8, trainable=False, reduction="sum"): """ Parameters: - blocksize: int, size of the Blocks for discrete cosine transform + blocksize: int, size of the Blocks for discrete cosine transform trainable: bool, if True parameters of the loss are trained and dropout is enabled. 
reduction: 'sum' or 'none', determines return format """ super().__init__() self.trainable = trainable - + # input mapping blocksize = torch.as_tensor(blocksize) - + # module to perform 2D blockwise rFFT - self.add_module('fft', Rfft2d(blocksize=blocksize.item(), interleaving=False)) - + self.add_module("fft", Rfft2d(blocksize=blocksize.item(), interleaving=False)) + # parameters self.weight_size = (blocksize, blocksize // 2 + 1) self.blocksize = nn.Parameter(blocksize, requires_grad=False) # init with uniform QM self.t_tild = nn.Parameter(torch.zeros(self.weight_size), requires_grad=trainable) - self.alpha = nn.Parameter(torch.tensor(0.1), requires_grad=trainable) # luminance masking - w = torch.tensor(0.2) # contrast masking - self.w_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) # inverse of sigmoid - self.beta = nn.Parameter(torch.tensor(1.), requires_grad=trainable) # pooling - + self.alpha = nn.Parameter(torch.tensor(0.1), requires_grad=trainable) # luminance masking + w = torch.tensor(0.2) # contrast masking + self.w_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable) # inverse of sigmoid + self.beta = nn.Parameter(torch.tensor(1.0), requires_grad=trainable) # pooling + # phase weights - self.w_phase_tild = nn.Parameter(torch.zeros(self.weight_size) -2., requires_grad=trainable) - + self.w_phase_tild = nn.Parameter(torch.zeros(self.weight_size) - 2.0, requires_grad=trainable) + # dropout for training self.dropout = nn.Dropout(0.5 if trainable else 0) - + # reduction self.reduction = reduction - if reduction not in ['sum', 'none']: + if reduction not in ["sum", "none"]: raise Exception('Reduction "{}" not supported. Valid values are: "sum", "none".'.format(reduction)) @property @@ -57,64 +60,65 @@ def t(self): # returns QM qm = torch.exp(self.t_tild) return qm - + @property def w(self): # return luminance masking parameter return torch.sigmoid(self.w_tild) - + @property def w_phase(self): # return weights for phase - w_phase = torch.exp(self.w_phase_tild) + w_phase = torch.exp(self.w_phase_tild) # set weights of non-phases to 0 if not self.trainable: - w_phase[0,0] = 0. - w_phase[0,self.weight_size[1] - 1] = 0. - w_phase[self.weight_size[1] - 1,self.weight_size[1] - 1] = 0. - w_phase[self.weight_size[1] - 1, 0] = 0. 
+ w_phase[0, 0] = 0.0 + w_phase[0, self.weight_size[1] - 1] = 0.0 + w_phase[self.weight_size[1] - 1, self.weight_size[1] - 1] = 0.0 + w_phase[self.weight_size[1] - 1, 0] = 0.0 return w_phase - + def forward(self, input, target): # fft c0 = self.fft(target) c1 = self.fft(input) - + N, K, H, W, _ = c0.shape - + # get amplitudes - c0_amp = torch.norm(c0 + EPS, p='fro', dim=4) - c1_amp = torch.norm(c1 + EPS, p='fro', dim=4) - + c0_amp = torch.norm(c0 + EPS, p="fro", dim=4) + c1_amp = torch.norm(c1 + EPS, p="fro", dim=4) + # luminance masking - avg_lum = torch.mean(c0_amp[:,:,0,0]) + avg_lum = torch.mean(c0_amp[:, :, 0, 0]) t_l = self.t.view(1, 1, H, W).expand(N, K, H, W) - t_l = t_l * (((c0_amp[:,:,0,0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(N, K, 1, 1) - + t_l = t_l * (((c0_amp[:, :, 0, 0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(N, K, 1, 1) + # contrast masking - s = softmax(t_l, (c0_amp.abs() + EPS)**self.w * t_l**(1 - self.w)) - + s = softmax(t_l, (c0_amp.abs() + EPS) ** self.w * t_l ** (1 - self.w)) + # pooling watson_dist = (((c0_amp - c1_amp) / s).abs() + EPS) ** self.beta watson_dist = self.dropout(watson_dist) + EPS - watson_dist = torch.sum(watson_dist, dim=(1,2,3)) + watson_dist = torch.sum(watson_dist, dim=(1, 2, 3)) watson_dist = watson_dist ** (1 / self.beta) - + # get phases - c0_phase = torch.atan2( c0[:,:,:,:,1], c0[:,:,:,:,0] + EPS) - c1_phase = torch.atan2( c1[:,:,:,:,1], c1[:,:,:,:,0] + EPS) - + c0_phase = torch.atan2(c0[:, :, :, :, 1], c0[:, :, :, :, 0] + EPS) + c1_phase = torch.atan2(c1[:, :, :, :, 1], c1[:, :, :, :, 0] + EPS) + # angular distance - phase_dist = torch.acos(torch.cos(c0_phase - c1_phase)*(1 - EPS*10**3)) * self.w_phase # we multiply with a factor ->1 to prevent taking the gradient of acos(-1) or acos(1). The gradient in this case would be -/+ inf + phase_dist = ( + torch.acos(torch.cos(c0_phase - c1_phase) * (1 - EPS * 10**3)) * self.w_phase + ) # we multiply with a factor ->1 to prevent taking the gradient of acos(-1) or acos(1). 
The gradient in this case would be -/+ inf phase_dist = self.dropout(phase_dist) - phase_dist = torch.sum(phase_dist, dim=(1,2,3)) - + phase_dist = torch.sum(phase_dist, dim=(1, 2, 3)) + # perceptual distance distance = watson_dist + phase_dist - + # reduce - if self.reduction == 'sum': + if self.reduction == "sum": distance = torch.sum(distance) - + return distance - diff --git a/src/metr/loss/watson_vgg.py b/src/metr/loss/watson_vgg.py index 620201a..f79ae12 100644 --- a/src/metr/loss/watson_vgg.py +++ b/src/metr/loss/watson_vgg.py @@ -5,33 +5,34 @@ EPS = 1e-10 + class VggFeatureExtractor(nn.Module): def __init__(self): super(VggFeatureExtractor, self).__init__() - + # download vgg vgg16 = torchvision.models.vgg16(pretrained=True).features - + # set non trainable for param in vgg16.parameters(): param.requires_grad = False - + # slice model self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() - - for x in range(4): # conv relu conv relu + + for x in range(4): # conv relu conv relu self.slice1.add_module(str(x), vgg16[x]) - for x in range(4, 9): # max conv relu conv relu + for x in range(4, 9): # max conv relu conv relu self.slice2.add_module(str(x), vgg16[x]) - for x in range(9, 16): # max cov relu conv relu conv relu + for x in range(9, 16): # max cov relu conv relu conv relu self.slice3.add_module(str(x), vgg16[x]) - for x in range(16, 23): # conv relu max conv relu conv relu + for x in range(16, 23): # conv relu max conv relu conv relu self.slice4.add_module(str(x), vgg16[x]) - for x in range(23, 30): # conv relu conv relu max conv relu + for x in range(23, 30): # conv relu conv relu max conv relu self.slice5.add_module(str(x), vgg16[x]) def forward(self, X): @@ -52,108 +53,110 @@ def forward(self, X): def normalize_tensor(t): # norms a tensor over the channel dimension to an euclidean length of 1. N, C, H, W = t.shape - norm_factor = torch.sqrt(torch.sum(t**2,dim=1)).view(N,1,H,W) - return t/(norm_factor.expand_as(t)+EPS) + norm_factor = torch.sqrt(torch.sum(t**2, dim=1)).view(N, 1, H, W) + return t / (norm_factor.expand_as(t) + EPS) + def softmax(a, b, factor=1): concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1) softmax_factors = F.softmax(concat * factor, dim=-1) - return a * softmax_factors[:,:,:,:,0] + b * softmax_factors[:,:,:,:,1] + return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1] + class WatsonDistanceVgg(nn.Module): """ Loss function based on Watsons perceptual distance. Based on deep feature extraction """ - def __init__(self, trainable=False, reduction='sum'): + + def __init__(self, trainable=False, reduction="sum"): """ Parameters: trainable: bool, if True parameters of the loss are trained and dropout is enabled. 
reduction: 'sum' or 'none', determines return format """ super().__init__() - + # module to perform feature extraction - self.add_module('vgg', VggFeatureExtractor()) - + self.add_module("vgg", VggFeatureExtractor()) + # imagenet-normalization - self.shift = nn.Parameter(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1), requires_grad=False) - self.scale = nn.Parameter(torch.Tensor([.458, .448, .450]).view(1,3,1,1), requires_grad=False) - + self.shift = nn.Parameter(torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1), requires_grad=False) + self.scale = nn.Parameter(torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1), requires_grad=False) + # channel dimensions self.L = 5 - self.channels = [64,128,256,512,512] - + self.channels = [64, 128, 256, 512, 512] + # sensitivity parameters self.t0_tild = nn.Parameter(torch.zeros((self.channels[0])), requires_grad=trainable) self.t1_tild = nn.Parameter(torch.zeros((self.channels[1])), requires_grad=trainable) self.t2_tild = nn.Parameter(torch.zeros((self.channels[2])), requires_grad=trainable) self.t3_tild = nn.Parameter(torch.zeros((self.channels[3])), requires_grad=trainable) self.t4_tild = nn.Parameter(torch.zeros((self.channels[4])), requires_grad=trainable) - + # other default parameters - w = torch.tensor(0.2) # contrast masking - self.w0_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) # inverse of sigmoid - self.w1_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) - self.w2_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) - self.w3_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) - self.w4_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) - self.beta = nn.Parameter(torch.tensor(1.), requires_grad=trainable) # pooling - + w = torch.tensor(0.2) # contrast masking + self.w0_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable) # inverse of sigmoid + self.w1_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable) + self.w2_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable) + self.w3_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable) + self.w4_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable) + self.beta = nn.Parameter(torch.tensor(1.0), requires_grad=trainable) # pooling + # dropout for training self.dropout = nn.Dropout(0.5 if trainable else 0) - + # reduction self.reduction = reduction - if reduction not in ['sum', 'none']: + if reduction not in ["sum", "none"]: raise Exception('Reduction "{}" not supported. 
Valid values are: "sum", "none".'.format(reduction)) @property def t(self): return [torch.exp(t) for t in [self.t0_tild, self.t1_tild, self.t2_tild, self.t3_tild, self.t4_tild]] - + @property def w(self): # return luminance masking parameter return [torch.sigmoid(w) for w in [self.w0_tild, self.w1_tild, self.w2_tild, self.w3_tild, self.w4_tild]] - + def forward(self, input, target): # normalization - input = (input - self.shift.expand_as(input))/self.scale.expand_as(input) - target = (target - self.shift.expand_as(target))/self.scale.expand_as(target) - + input = (input - self.shift.expand_as(input)) / self.scale.expand_as(input) + target = (target - self.shift.expand_as(target)) / self.scale.expand_as(target) + # feature extraction c0 = self.vgg(target) c1 = self.vgg(input) - + # norm over channels for l in range(self.L): c0[l] = normalize_tensor(c0[l]) c1[l] = normalize_tensor(c1[l]) - + # contrast masking t = self.t w = self.w s = [] for l in range(self.L): N, C_l, H_l, W_l = c0[l].shape - t_l = t[l].view(1,C_l,1,1).expand(N, C_l, H_l, W_l) - s.append(softmax(t_l, (c0[l].abs() + EPS)**w[l] * t_l**(1 - w[l]))) - + t_l = t[l].view(1, C_l, 1, 1).expand(N, C_l, H_l, W_l) + s.append(softmax(t_l, (c0[l].abs() + EPS) ** w[l] * t_l ** (1 - w[l]))) + # pooling watson_dist = 0 for l in range(self.L): _, _, H_l, W_l = c0[l].shape layer_dist = (((c0[l] - c1[l]) / s[l]).abs() + EPS) ** self.beta layer_dist = self.dropout(layer_dist) + EPS - layer_dist = torch.sum(layer_dist, dim=(1,2,3)) # sum over dimensions of layer - layer_dist = (1 / (H_l * W_l)) * layer_dist # normalize by layer size + layer_dist = torch.sum(layer_dist, dim=(1, 2, 3)) # sum over dimensions of layer + layer_dist = (1 / (H_l * W_l)) * layer_dist # normalize by layer size watson_dist += layer_dist # sum over layers watson_dist = watson_dist ** (1 / self.beta) # reduction - if self.reduction == 'sum': + if self.reduction == "sum": watson_dist = torch.sum(watson_dist) - + return watson_dist - diff --git a/src/metr/metr_pp_eval_stable_sig.py b/src/metr/metr_pp_eval_stable_sig.py index 37c7c80..7e7a4fd 100644 --- a/src/metr/metr_pp_eval_stable_sig.py +++ b/src/metr/metr_pp_eval_stable_sig.py @@ -5,44 +5,46 @@ # LICENSE file in the root directory of this source tree. import argparse +import importlib import json import os import shutil -import tqdm from pathlib import Path -from PIL import Image import numpy as np import pandas as pd import torch import torch.nn as nn -from torchvision import transforms -from torchvision.transforms import ToPILImage - +import tqdm +from PIL import Image from pytorch_fid.fid_score import InceptionV3, calculate_frechet_distance, compute_statistics_of_path from skimage.metrics import peak_signal_noise_ratio, structural_similarity +from torchvision import transforms +from torchvision.transforms import ToPILImage # import utils # import utils_img # import utils_model -import importlib + def import_from_stable_sig(name): module = importlib.import_module(".stable_sig." 
+ name, package=__package__) return module + utils = import_from_stable_sig("utils") utils_img = import_from_stable_sig("utils_img") utils_model = import_from_stable_sig("utils_model") from wm_attacks import ReSDPipeline -from wm_attacks.wmattacker_no_saving import VAEWMAttacker, DiffWMAttacker +from wm_attacks.wmattacker_no_saving import DiffWMAttacker, VAEWMAttacker import wandb device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + def save_imgs(img_dir, img_dir_nw, save_dir, num_imgs=None, mult=10): filenames = os.listdir(img_dir) filenames.sort() @@ -51,12 +53,13 @@ def save_imgs(img_dir, img_dir_nw, save_dir, num_imgs=None, mult=10): for ii, filename in enumerate(tqdm.tqdm(filenames)): img_1 = Image.open(os.path.join(img_dir_nw, filename)) img_2 = Image.open(os.path.join(img_dir, filename)) - diff = np.abs(np.asarray(img_1).astype(int) - np.asarray(img_2).astype(int)) *10 + diff = np.abs(np.asarray(img_1).astype(int) - np.asarray(img_2).astype(int)) * 10 diff = Image.fromarray(diff.astype(np.uint8)) shutil.copy(os.path.join(img_dir_nw, filename), os.path.join(save_dir, f"{ii:02d}_nw.png")) shutil.copy(os.path.join(img_dir, filename), os.path.join(save_dir, f"{ii:02d}_w.png")) diff.save(os.path.join(save_dir, f"{ii:02d}_diff.png")) + def get_img_metric(img_dir, img_dir_nw, num_imgs=None): filenames = os.listdir(img_dir) filenames.sort() @@ -69,44 +72,48 @@ def get_img_metric(img_dir, img_dir_nw, num_imgs=None): img_ori = np.asarray(pil_img_ori) img = np.asarray(pil_img) log_stat = { - 'filename': filename, - 'ssim': structural_similarity(img_ori, img, channel_axis=2), - 'psnr': peak_signal_noise_ratio(img_ori, img), - 'linf': np.amax(np.abs(img_ori.astype(int)-img.astype(int))) + "filename": filename, + "ssim": structural_similarity(img_ori, img, channel_axis=2), + "psnr": peak_signal_noise_ratio(img_ori, img), + "linf": np.amax(np.abs(img_ori.astype(int) - img.astype(int))), } log_stats.append(log_stat) return log_stats -def cached_fid(path1, path2, batch_size=32, device='cuda:0', dims=2048, num_workers=10): + +def cached_fid(path1, path2, batch_size=32, device="cuda:0", dims=2048, num_workers=10): for p in [path1, path2]: if not os.path.exists(p): - raise RuntimeError('Invalid path: %s' % p) + raise RuntimeError("Invalid path: %s" % p) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] model = InceptionV3([block_idx]).to(device) # cache path2 storage_path = Path.home() / f'.cache/torch/fid/{path2.replace("/", "_")}' - if (storage_path / 'm.pt').exists(): - m2 = torch.load(storage_path / 'm.pt') - s2 = torch.load(storage_path / 's.pt') + if (storage_path / "m.pt").exists(): + m2 = torch.load(storage_path / "m.pt") + s2 = torch.load(storage_path / "s.pt") else: storage_path.mkdir(parents=True) m2, s2 = compute_statistics_of_path(str(path2), model, batch_size, dims, device, num_workers) - torch.save(m2, storage_path / 'm.pt') - torch.save(s2, storage_path / 's.pt') - m1, s1 = compute_statistics_of_path(str(path1), model, batch_size, dims, device, num_workers) + torch.save(m2, storage_path / "m.pt") + torch.save(s2, storage_path / "s.pt") + m1, s1 = compute_statistics_of_path(str(path1), model, batch_size, dims, device, num_workers) fid_value = calculate_frechet_distance(m1, s1, m2, s2) return fid_value + @torch.no_grad() def get_bit_accs(img_dir: str, msg_decoder: nn.Module, key: torch.Tensor, batch_size: int = 16, attacks: dict = {}): # resize crop - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 
0.224, 0.225]), - ]) + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) data_loader = utils.get_dataloader(img_dir, transform, batch_size=batch_size, collate_fn=None) - log_stats = {ii:{} for ii in range(len(data_loader.dataset))} + log_stats = {ii: {} for ii in range(len(data_loader.dataset))} for ii, imgs in enumerate(tqdm.tqdm(data_loader)): imgs = imgs.to(device) @@ -115,54 +122,58 @@ def get_bit_accs(img_dir: str, msg_decoder: nn.Module, key: torch.Tensor, batch_ for name, attack in attacks.items(): imgs_aug = attack(imgs) # print(type(imgs_aug), imgs_aug.shape) - decoded = msg_decoder(imgs_aug) # b c h w -> b k - diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k - bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b - word_accs = (bit_accs == 1) # b + decoded = msg_decoder(imgs_aug) # b c h w -> b k + diff = ~torch.logical_xor(decoded > 0, keys > 0) # b k -> b k + bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b + word_accs = bit_accs == 1 # b for jj in range(bit_accs.shape[0]): - img_num = ii*batch_size+jj + img_num = ii * batch_size + jj log_stat = log_stats[img_num] - log_stat[f'bit_acc_{name}'] = bit_accs[jj].item() - log_stat[f'word_acc_{name}'] = 1.0 if word_accs[jj].item() else 0.0 + log_stat[f"bit_acc_{name}"] = bit_accs[jj].item() + log_stat[f"word_acc_{name}"] = 1.0 if word_accs[jj].item() else 0.0 - log_stats = [{'img': img_num, **log_stats[img_num]} for img_num in range(len(data_loader.dataset))] + log_stats = [{"img": img_num, **log_stats[img_num]} for img_num in range(len(data_loader.dataset))] return log_stats + @torch.no_grad() def get_msgs(img_dir: str, msg_decoder: nn.Module, batch_size: int = 16, attacks: dict = {}): # resize crop - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) data_loader = utils.get_dataloader(img_dir, transform, batch_size=batch_size, collate_fn=None) - log_stats = {ii:{} for ii in range(len(data_loader.dataset))} + log_stats = {ii: {} for ii in range(len(data_loader.dataset))} for ii, imgs in enumerate(tqdm.tqdm(data_loader)): imgs = imgs.to(device) for name, attack in attacks.items(): imgs_aug = attack(imgs) - decoded = msg_decoder(imgs_aug)>0 # b c h w -> b k + decoded = msg_decoder(imgs_aug) > 0 # b c h w -> b k for jj in range(decoded.shape[0]): - img_num = ii*batch_size+jj + img_num = ii * batch_size + jj log_stat = log_stats[img_num] - log_stat[f'decoded_{name}'] = "".join([('1' if el else '0') for el in decoded[jj].detach()]) + log_stat[f"decoded_{name}"] = "".join([("1" if el else "0") for el in decoded[jj].detach()]) - log_stats = [{'img': img_num, **log_stats[img_num]} for img_num in range(len(data_loader.dataset))] + log_stats = [{"img": img_num, **log_stats[img_num]} for img_num in range(len(data_loader.dataset))] return log_stats + def main(params): if params.with_tracking: - wandb_run = wandb.init(project=params.project_name, name=params.run_name, tags=['tree_ring_watermark']) + wandb_run = wandb.init(project=params.project_name, name=params.run_name, tags=["tree_ring_watermark"]) else: wandb_run = None - # Set seeds for reproductibility + # Set seeds for reproductibility np.random.seed(params.seed) - + # Print the arguments 
print("__git__:{}".format(utils.get_sha())) print("__log__:{}".format(json.dumps(vars(params)))) @@ -170,7 +181,7 @@ def main(params): # Create the directories if not os.path.exists(params.output_dir): os.makedirs(params.output_dir) - save_img_dir = os.path.join(params.output_dir, 'imgs') + save_img_dir = os.path.join(params.output_dir, "imgs") params.save_img_dir = save_img_dir if not os.path.exists(save_img_dir): os.makedirs(save_img_dir, exist_ok=True) @@ -181,14 +192,14 @@ def main(params): if params.save_n_imgs > 0: save_imgs(params.img_dir, params.img_dir_nw, save_img_dir, num_imgs=params.save_n_imgs) - print(f'>>> Computing img-2-img stats...') + print(f">>> Computing img-2-img stats...") img_metrics = get_img_metric(params.img_dir, params.img_dir_nw, num_imgs=params.num_imgs) img_df = pd.DataFrame(img_metrics) - img_df.to_csv(os.path.join(params.output_dir, 'img_metrics.csv'), index=False) - ssims = img_df['ssim'].tolist() - psnrs = img_df['psnr'].tolist() - linfs = img_df['linf'].tolist() - ssim_mean, ssim_std, ssim_max, ssim_min = np.mean(ssims), np.std(ssims), np.max(ssims), np.min(ssims) + img_df.to_csv(os.path.join(params.output_dir, "img_metrics.csv"), index=False) + ssims = img_df["ssim"].tolist() + psnrs = img_df["psnr"].tolist() + linfs = img_df["linf"].tolist() + ssim_mean, ssim_std, ssim_max, ssim_min = np.mean(ssims), np.std(ssims), np.max(ssims), np.min(ssims) psnr_mean, psnr_std, psnr_max, psnr_min = np.mean(psnrs), np.std(psnrs), np.max(psnrs), np.min(psnrs) linf_mean, linf_std, linf_max, linf_min = np.mean(linfs), np.std(linfs), np.max(linfs), np.min(linfs) print(f"SSIM: {ssim_mean:.4f}±{ssim_std:.4f} [{ssim_min:.4f}, {ssim_max:.4f}]") @@ -196,7 +207,7 @@ def main(params): print(f"Linf: {linf_mean:.4f}±{linf_std:.4f} [{linf_min:.4f}, {linf_max:.4f}]") if params.img_dir_fid is not None: - print(f'>>> Computing image distribution stats...') + print(f">>> Computing image distribution stats...") fid = cached_fid(params.img_dir, params.img_dir_fid) print(f"FID watermark : {fid:.4f}") fid_nw = cached_fid(params.img_dir_nw, params.img_dir_fid) @@ -205,35 +216,42 @@ def main(params): if params.eval_bits: # Loads hidden decoder - print(f'>>> Building hidden decoder with weights from {params.msg_decoder_path}...') - if 'torchscript' in params.msg_decoder_path: + print(f">>> Building hidden decoder with weights from {params.msg_decoder_path}...") + if "torchscript" in params.msg_decoder_path: msg_decoder = torch.jit.load(params.msg_decoder_path).to(device) else: - msg_decoder = utils_model.get_hidden_decoder(num_bits=params.num_bits, redundancy=params.redundancy, num_blocks=params.decoder_depth, channels=params.decoder_channels).to(device) + msg_decoder = utils_model.get_hidden_decoder( + num_bits=params.num_bits, + redundancy=params.redundancy, + num_blocks=params.decoder_depth, + channels=params.decoder_channels, + ).to(device) ckpt = utils_model.get_hidden_decoder_ckpt(params.msg_decoder_path) print(msg_decoder.load_state_dict(ckpt, strict=False)) msg_decoder.eval() # whitening - print(f'>>> Whitening...') + print(f">>> Whitening...") with torch.no_grad(): data_dir = "/checkpoint/pfz/watermarking/data/coco_10k_orig/0" - transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(256), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(256), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 
0.406], std=[0.229, 0.224, 0.225]), + ] + ) loader = utils.get_dataloader(data_dir, transform, batch_size=16, collate_fn=None) ys = [] for i, x in enumerate(loader): x = x.to(device) y = msg_decoder(x) - ys.append(y.to('cpu')) + ys.append(y.to("cpu")) ys = torch.cat(ys, dim=0) nbit = ys.shape[1] - mean = ys.mean(dim=0, keepdim=True) # NxD -> 1xD - ys_centered = ys - mean # NxD + mean = ys.mean(dim=0, keepdim=True) # NxD -> 1xD + ys_centered = ys - mean # NxD cov = ys_centered.T @ ys_centered e, v = torch.linalg.eigh(cov) L = torch.diag(1.0 / torch.pow(e, exponent=0.5)) @@ -245,7 +263,7 @@ def main(params): msg_decoder = nn.Sequential(msg_decoder, linear.to(device)) torchscript_m = torch.jit.script(msg_decoder) torch.jit.save(torchscript_m, params.msg_decoder_path.replace(".pth", "_whit.torchscript.pt")) - + msg_decoder.eval() nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(device)).shape[-1] @@ -257,64 +275,74 @@ def main(params): # vae_2020_attacker = VAEWMAttacker("cheng2020-anchor", quality=1, metric='mse', device=device) # piller = ToPILImage() - if params.attack_mode == 'all': + if params.attack_mode == "all": attacks = { - 'none': lambda x: x, - 'crop_05': lambda x: utils_img.center_crop(x, 0.5), - 'crop_01': lambda x: utils_img.center_crop(x, 0.1), - 'rot_25': lambda x: utils_img.rotate(x, 25), - 'rot_90': lambda x: utils_img.rotate(x, 90), - 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80), - 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50), - 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5), - 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2), - 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5), - 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2), - 'saturation_1p5': lambda x: utils_img.adjust_saturation(x, 1.5), - 'saturation_2': lambda x: utils_img.adjust_saturation(x, 2), - 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5), - 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2), - 'resize_05': lambda x: utils_img.resize(x, 0.5), - 'resize_01': lambda x: utils_img.resize(x, 0.1), - 'overlay_text': lambda x: utils_img.overlay_text(x, [76,111,114,101,109,32,73,112,115,117,109]), - 'comb': lambda x: utils_img.jpeg_compress(utils_img.adjust_brightness(utils_img.center_crop(x, 0.5), 1.5), 80), + "none": lambda x: x, + "crop_05": lambda x: utils_img.center_crop(x, 0.5), + "crop_01": lambda x: utils_img.center_crop(x, 0.1), + "rot_25": lambda x: utils_img.rotate(x, 25), + "rot_90": lambda x: utils_img.rotate(x, 90), + "jpeg_80": lambda x: utils_img.jpeg_compress(x, 80), + "jpeg_50": lambda x: utils_img.jpeg_compress(x, 50), + "brightness_1p5": lambda x: utils_img.adjust_brightness(x, 1.5), + "brightness_2": lambda x: utils_img.adjust_brightness(x, 2), + "contrast_1p5": lambda x: utils_img.adjust_contrast(x, 1.5), + "contrast_2": lambda x: utils_img.adjust_contrast(x, 2), + "saturation_1p5": lambda x: utils_img.adjust_saturation(x, 1.5), + "saturation_2": lambda x: utils_img.adjust_saturation(x, 2), + "sharpness_1p5": lambda x: utils_img.adjust_sharpness(x, 1.5), + "sharpness_2": lambda x: utils_img.adjust_sharpness(x, 2), + "resize_05": lambda x: utils_img.resize(x, 0.5), + "resize_01": lambda x: utils_img.resize(x, 0.1), + "overlay_text": lambda x: utils_img.overlay_text( + x, [76, 111, 114, 101, 109, 32, 73, 112, 115, 117, 109] + ), + "comb": lambda x: utils_img.jpeg_compress( + utils_img.adjust_brightness(utils_img.center_crop(x, 0.5), 1.5), 80 + ), # 'diff_150': lambda x: diff_attacker.attack(x.cpu()), # 
'vae_2018_1': lambda x: vae_2018_attacker.attack(piller(x.cpu())), # 'vae_2020_1': lambda x: vae_2020_attacker.attack(piller(x.cpu())) } - elif params.attack_mode == 'few': + elif params.attack_mode == "few": attacks = { - 'none': lambda x: x, - 'crop_01': lambda x: utils_img.center_crop(x, 0.1), - 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2), - 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2), - 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50), - 'comb': lambda x: utils_img.jpeg_compress(utils_img.adjust_brightness(utils_img.center_crop(x, 0.5), 1.5), 80), + "none": lambda x: x, + "crop_01": lambda x: utils_img.center_crop(x, 0.1), + "brightness_2": lambda x: utils_img.adjust_brightness(x, 2), + "contrast_2": lambda x: utils_img.adjust_contrast(x, 2), + "jpeg_50": lambda x: utils_img.jpeg_compress(x, 50), + "comb": lambda x: utils_img.jpeg_compress( + utils_img.adjust_brightness(utils_img.center_crop(x, 0.5), 1.5), 80 + ), # 'diff_150': lambda x: diff_attacker.attack(x.cpu()), # 'vae_2018_1': lambda x: vae_2018_attacker.attack(x.cpu()), # 'vae_2020_1': lambda x: vae_2020_attacker.attack(x.cpu()) } else: - attacks = {'none': lambda x: x} + attacks = {"none": lambda x: x} if params.decode_only: log_stats = get_msgs(params.img_dir, msg_decoder, batch_size=params.batch_size, attacks=attacks) - else: + else: # Creating key - key = torch.tensor([k=='1' for k in params.key_str]).to(device) + key = torch.tensor([k == "1" for k in params.key_str]).to(device) log_stats = get_bit_accs(params.img_dir, msg_decoder, key, batch_size=params.batch_size, attacks=attacks) - print(f'>>> Saving log stats to {params.output_dir}...') + print(f">>> Saving log stats to {params.output_dir}...") df = pd.DataFrame(log_stats) - df.to_csv(os.path.join(params.output_dir, 'log_stats.csv'), index=False) + df.to_csv(os.path.join(params.output_dir, "log_stats.csv"), index=False) df_mean = pd.DataFrame(df.mean()).transpose() - df_mean.to_csv(os.path.join(params.output_dir, 'mean_log_stats.csv'), index=False) + df_mean.to_csv(os.path.join(params.output_dir, "mean_log_stats.csv"), index=False) if params.with_tracking: # mean_table = wandb.Table(dataframe=df_mean) # wandb_run.log({"mean_values": df_mean.drop(columns=["img"])}) - log_dict = df_mean.drop(columns=["img"]).rename(columns={"bit_acc_none": "Bit_acc", "word_acc_none": "Word_acc"}).to_dict("records")[0] + log_dict = ( + df_mean.drop(columns=["img"]) + .rename(columns={"bit_acc_none": "Bit_acc", "word_acc_none": "Word_acc"}) + .to_dict("records")[0] + ) wandb.log(log_dict) @@ -324,22 +352,27 @@ def get_parser(): def aa(*args, **kwargs): group.add_argument(*args, **kwargs) - group = parser.add_argument_group('Data parameters') + group = parser.add_argument_group("Data parameters") aa("--img_dir", type=str, default="", help="") aa("--num_imgs", type=int, default=None) - group = parser.add_argument_group('Eval imgs') + group = parser.add_argument_group("Eval imgs") aa("--eval_imgs", type=utils.bool_inst, default=True, help="") - aa("--img_dir_nw", type=str, default="/checkpoint/pfz/2023_logs/0104_aisign_sd_txt2img/_ldm_decoder_ckpt=0_config=0_ckpt=0/samples", help="") + aa( + "--img_dir_nw", + type=str, + default="/checkpoint/pfz/2023_logs/0104_aisign_sd_txt2img/_ldm_decoder_ckpt=0_config=0_ckpt=0/samples", + help="", + ) aa("--img_dir_fid", type=str, default=None, help="") aa("--save_n_imgs", type=int, default=10) - group = parser.add_argument_group('Eval bits') + group = parser.add_argument_group("Eval bits") aa("--eval_bits", 
type=utils.bool_inst, default=True, help="") aa("--decode_only", type=utils.bool_inst, default=False, help="") aa("--key_str", type=str, default="111010110101000001010111010011010100010000100111") - aa("--msg_decoder_path", type=str, default= "models/dec_48b_whit.torchscript.pt") - aa("--attack_mode", type=str, default= "all") + aa("--msg_decoder_path", type=str, default="models/dec_48b_whit.torchscript.pt") + aa("--attack_mode", type=str, default="all") aa("--num_bits", type=int, default=48) aa("--redundancy", type=int, default=1) aa("--decoder_depth", type=int, default=8) @@ -347,21 +380,25 @@ def aa(*args, **kwargs): aa("--img_size", type=int, default=512) aa("--batch_size", type=int, default=32) - group = parser.add_argument_group('Experiments parameters') - aa("--output_dir", type=str, default="output/", help="Output directory for logs and images, when doing eval images (Default: /output)") + group = parser.add_argument_group("Experiments parameters") + aa( + "--output_dir", + type=str, + default="output/", + help="Output directory for logs and images, when doing eval images (Default: /output)", + ) aa("--seed", type=int, default=0) aa("--debug", type=utils.bool_inst, default=False, help="Debug mode") - group = parser.add_argument_group('Logging parameters') - aa('--with_tracking', action='store_true') - aa('--project_name', default='eval_stable_tree') - aa('--run_name', default='test') - + group = parser.add_argument_group("Logging parameters") + aa("--with_tracking", action="store_true") + aa("--project_name", default="eval_stable_tree") + aa("--run_name", default="test") return parser -if __name__ == '__main__': +if __name__ == "__main__": # generate parser / parse parameters parser = get_parser() diff --git a/src/metr/modified_stable_diffusion.py b/src/metr/modified_stable_diffusion.py index 693c3a1..1abd396 100644 --- a/src/metr/modified_stable_diffusion.py +++ b/src/metr/modified_stable_diffusion.py @@ -1,12 +1,11 @@ - -from typing import Callable, List, Optional, Union, Any, Dict import copy +from typing import Any, Callable, Dict, List, Optional, Union + import numpy as np import PIL - import torch from diffusers import StableDiffusionPipeline -from diffusers.utils import logging, BaseOutput +from diffusers.utils import BaseOutput, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -18,7 +17,8 @@ class ModifiedStableDiffusionPipelineOutput(BaseOutput): class ModifiedStableDiffusionPipeline(StableDiffusionPipeline): - def __init__(self, + def __init__( + self, vae, text_encoder, tokenizer, @@ -28,14 +28,9 @@ def __init__(self, feature_extractor, requires_safety_checker: bool = True, ): - super(ModifiedStableDiffusionPipeline, self).__init__(vae, - text_encoder, - tokenizer, - unet, - scheduler, - safety_checker, - feature_extractor, - requires_safety_checker) + super(ModifiedStableDiffusionPipeline, self).__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) @torch.no_grad() def __call__( @@ -201,15 +196,14 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return ModifiedStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents) - + return ModifiedStableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents + ) @torch.inference_mode() def decode_image(self, latents: torch.FloatTensor, **kwargs): scaled_latents = 1 / 0.18215 * latents - image = [ - 
self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents)) - ] + image = [self.vae.decode(scaled_latents[i : i + 1]).sample for i in range(len(latents))] image = torch.cat(image, dim=0) return image @@ -227,4 +221,4 @@ def get_image_latents(self, image, sample=True, rng_generator=None): else: encoding = encoding_dist.mode() latents = encoding * 0.18215 - return latents \ No newline at end of file + return latents diff --git a/src/metr/open_clip/__init__.py b/src/metr/open_clip/__init__.py index 088c864..2c2d3b2 100644 --- a/src/metr/open_clip/__init__.py +++ b/src/metr/open_clip/__init__.py @@ -1,13 +1,38 @@ from .coca_model import CoCa from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss -from .factory import list_models, add_model_config, get_model_config, load_checkpoint -from .loss import ClipLoss, DistillClipLoss, CoCaLoss -from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ - convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype -from .openai import load_openai_model, list_openai_models -from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ - get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .factory import ( + add_model_config, + create_loss, + create_model, + create_model_and_transforms, + create_model_from_pretrained, + get_model_config, + get_tokenizer, + list_models, + load_checkpoint, +) +from .loss import ClipLoss, CoCaLoss, DistillClipLoss +from .model import ( + CLIP, + CLIPTextCfg, + CLIPVisionCfg, + CustomTextCLIP, + convert_weights_to_fp16, + convert_weights_to_lp, + get_cast_dtype, + trace_model, +) +from .openai import list_openai_models, load_openai_model +from .pretrained import ( + download_pretrained, + download_pretrained_from_url, + get_pretrained_cfg, + get_pretrained_url, + is_pretrained_cfg, + list_pretrained, + list_pretrained_models_by_tag, + list_pretrained_tags_by_model, +) from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub -from .tokenizer import SimpleTokenizer, tokenize, decode -from .transform import image_transform, AugmentationCfg +from .tokenizer import SimpleTokenizer, decode, tokenize +from .transform import AugmentationCfg, image_transform diff --git a/src/metr/open_clip/coca_model.py b/src/metr/open_clip/coca_model.py index 039453a..6ef4fe1 100644 --- a/src/metr/open_clip/coca_model.py +++ b/src/metr/open_clip/coca_model.py @@ -1,43 +1,30 @@ +from dataclasses import dataclass from typing import Optional +import numpy as np import torch from torch import nn from torch.nn import functional as F -import numpy as np -from dataclasses import dataclass -from .transformer import ( - LayerNormFp32, - LayerNorm, - QuickGELU, - MultimodalTransformer, -) -from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower +from .model import CLIPTextCfg, CLIPVisionCfg, _build_text_tower, _build_vision_tower +from .transformer import LayerNorm, LayerNormFp32, MultimodalTransformer, QuickGELU try: from transformers import ( BeamSearchScorer, LogitsProcessorList, - TopPLogitsWarper, - TopKLogitsWarper, - RepetitionPenaltyLogitsProcessor, - MinLengthLogitsProcessor, MaxLengthCriteria, - StoppingCriteriaList + MinLengthLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + StoppingCriteriaList, + TopKLogitsWarper, + 
TopPLogitsWarper, ) - GENERATION_TYPES = { - "top_k": TopKLogitsWarper, - "top_p": TopPLogitsWarper, - "beam_search": "beam_search" - } + GENERATION_TYPES = {"top_k": TopKLogitsWarper, "top_p": TopPLogitsWarper, "beam_search": "beam_search"} _has_transformers = True except ImportError as e: - GENERATION_TYPES = { - "top_k": None, - "top_p": None, - "beam_search": "beam_search" - } + GENERATION_TYPES = {"top_k": None, "top_p": None, "beam_search": "beam_search"} _has_transformers = False @@ -51,16 +38,14 @@ class MultimodalCfg(CLIPTextCfg): def _build_text_decoder_tower( - embed_dim, - multimodal_cfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, ): multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg act_layer = QuickGELU if quick_gelu else nn.GELU - norm_layer = ( - LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm - ) + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm decoder = MultimodalTransformer( context_length=multimodal_cfg.context_length, @@ -78,14 +63,14 @@ def _build_text_decoder_tower( class CoCa(nn.Module): def __init__( - self, - embed_dim, - multimodal_cfg: MultimodalCfg, - text_cfg: CLIPTextCfg, - vision_cfg: CLIPVisionCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, - pad_id: int = 0, + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, ): super().__init__() multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg @@ -134,7 +119,7 @@ def _encode_image(self, images, normalize=True): return image_latent, tokens_embs def _encode_text(self, text, normalize=True, embed_cls=True): - text = text[:, :-1] if embed_cls else text # make space for CLS token + text = text[:, :-1] if embed_cls else text # make space for CLS token text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb @@ -153,7 +138,7 @@ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=Non image_latent, image_embs = self._encode_image(image) # TODO: add assertion to avoid bugs? 
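        # Explanatory note (illustrative, assumed shapes): the multimodal decoder emits one
        # logit per token embedding, so the caption targets below are the last
        # token_embs.shape[1] positions of the original `text` tensor, i.e. the decoder
        # inputs shifted by one for next-token prediction.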
- labels = text[:, -token_embs.shape[1]:] + labels = text[:, -token_embs.shape[1] :] logits = self.text_decoder(image_embs, token_embs) return { @@ -161,7 +146,7 @@ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=Non "text_features": text_latent, "logits": logits, "labels": labels, - "logit_scale": self.logit_scale.exp() + "logit_scale": self.logit_scale.exp(), } def generate( @@ -170,7 +155,7 @@ def generate( text=None, seq_len=30, max_seq_len=77, - temperature=1., + temperature=1.0, generation_type="beam_search", top_p=0.1, # keep tokens in the 1 - top_p quantile top_k=1, # keeps the top_k most probable tokens @@ -182,7 +167,7 @@ def generate( min_seq_len=5, stopping_criteria=None, repetition_penalty=1.0, - fixed_output_length=False # if True output.shape == (batch_size, seq_len) + fixed_output_length=False, # if True output.shape == (batch_size, seq_len) ): # taking many ideas and components from HuggingFace GenerationMixin # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation @@ -203,15 +188,13 @@ def generate( if stopping_criteria is None: stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] - stopping_criteria = StoppingCriteriaList( - stopping_criteria - ) + stopping_criteria = StoppingCriteriaList(stopping_criteria) device = image.device if generation_type == "beam_search": output = self._generate_beamsearch( - image_inputs = image, + image_inputs=image, pad_token_id=pad_token_id, eos_token_id=eos_token_id, sot_token_id=sot_token_id, @@ -223,8 +206,12 @@ def generate( ) if fixed_output_length and output.shape[1] < seq_len: return torch.cat( - (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), - dim=1 + ( + output, + torch.ones(output.shape[0], seq_len - output.shape[1], device=device, dtype=output.dtype) + * self.pad_id, + ), + dim=1, ) return output @@ -234,8 +221,7 @@ def generate( logit_warper = GENERATION_TYPES[generation_type](top_k) else: raise ValueError( - f"generation_type has to be one of " - f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + f"generation_type has to be one of " f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." 
) image_latent, image_embs = self._encode_image(image) @@ -256,7 +242,9 @@ def generate( while True: x = out[:, -max_seq_len:] cur_len = x.shape[1] - logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][ + :, -1 + ] mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id @@ -269,7 +257,7 @@ def generate( filtered_logits = logit_warper(x[~mask, :], filtered_logits) probs = F.softmax(filtered_logits / temperature, dim=-1) - if (cur_len + 1 == seq_len): + if cur_len + 1 == seq_len: sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id else: sample[~mask, :] = torch.multinomial(probs, 1) @@ -288,17 +276,17 @@ def generate( return out def _generate_beamsearch( - self, - image_inputs, - pad_token_id=None, - eos_token_id=None, - sot_token_id=None, - num_beams=6, - num_beam_groups=3, - min_seq_len=5, - stopping_criteria=None, - logit_processor=None, - logit_warper=None, + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, ): device = image_inputs.device batch_size = image_inputs.shape[0] @@ -349,11 +337,11 @@ def _generate_beamsearch( # do one decoder step on all beams of all sentences in batch model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) outputs = self( - model_inputs['images'], - model_inputs['text'], + model_inputs["images"], + model_inputs["text"], embed_cls=False, image_latent=image_latent, - image_embs=image_embs + image_embs=image_embs, ) for beam_group_idx in range(num_beam_groups): @@ -371,7 +359,7 @@ def _generate_beamsearch( group_input_ids = input_ids[batch_group_indices] # select outputs of beams of currentg group only - next_token_logits = outputs['logits'][batch_group_indices, -1, :] + next_token_logits = outputs["logits"][batch_group_indices, -1, :] vocab_size = next_token_logits.shape[-1] next_token_scores_processed = logits_processor( @@ -412,7 +400,9 @@ def _generate_beamsearch( # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group reordering_indices[batch_group_indices] = ( - num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + + group_start_idx + + (beam_idx % group_size) ) input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) @@ -433,7 +423,7 @@ def _generate_beamsearch( max_length=stopping_criteria.max_length, beam_indices=final_beam_indices, ) - return sequence_outputs['sequences'] + return sequence_outputs["sequences"] def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): diff --git a/src/metr/open_clip/factory.py b/src/metr/open_clip/factory.py index 14011f9..8ace372 100644 --- a/src/metr/open_clip/factory.py +++ b/src/metr/open_clip/factory.py @@ -9,42 +9,53 @@ import torch -from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ - resize_pos_embed, get_cast_dtype from .coca_model import CoCa -from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from 
.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .loss import ClipLoss, CoCaLoss, DistillClipLoss +from .model import ( + CLIP, + CustomTextCLIP, + convert_to_custom_text_state_dict, + convert_weights_to_lp, + get_cast_dtype, + resize_pos_embed, +) from .openai import load_openai_model -from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf -from .transform import image_transform, AugmentationCfg +from .pretrained import ( + download_pretrained, + download_pretrained_from_hf, + get_pretrained_cfg, + is_pretrained_cfg, + list_pretrained_tags_by_model, +) from .tokenizer import HFTokenizer, tokenize +from .transform import AugmentationCfg, image_transform - -HF_HUB_PREFIX = 'hf-hub:' +HF_HUB_PREFIX = "hf-hub:" _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs def _natural_key(string_): - return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] def _rescan_model_configs(): global _MODEL_CONFIGS - config_ext = ('.json',) + config_ext = (".json",) config_files = [] for config_path in _MODEL_CONFIG_PATHS: if config_path.is_file() and config_path.suffix in config_ext: config_files.append(config_path) elif config_path.is_dir(): for ext in config_ext: - config_files.extend(config_path.glob(f'*{ext}')) + config_files.extend(config_path.glob(f"*{ext}")) for cf in config_files: - with open(cf, 'r') as f: + with open(cf, "r") as f: model_cfg = json.load(f) - if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): _MODEL_CONFIGS[cf.stem] = model_cfg _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} @@ -54,12 +65,12 @@ def _rescan_model_configs(): def list_models(): - """ enumerate available model architectures based on config files """ + """enumerate available model architectures based on config files""" return list(_MODEL_CONFIGS.keys()) def add_model_config(path): - """ add model config path or file and update registry """ + """add model config path or file and update registry""" if not isinstance(path, Path): path = Path(path) _MODEL_CONFIG_PATHS.append(path) @@ -75,21 +86,24 @@ def get_model_config(model_name): def get_tokenizer(model_name): if model_name.startswith(HF_HUB_PREFIX): - tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX) :]) else: config = get_model_config(model_name) - tokenizer = HFTokenizer( - config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + tokenizer = ( + HFTokenizer(config["text_cfg"]["hf_tokenizer_name"]) + if "hf_tokenizer_name" in config["text_cfg"] + else tokenize + ) return tokenizer -def load_state_dict(checkpoint_path: str, map_location='cpu'): +def load_state_dict(checkpoint_path: str, map_location="cpu"): checkpoint = torch.load(checkpoint_path, map_location=map_location) - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint - if next(iter(state_dict.items()))[0].startswith('module'): + if 
next(iter(state_dict.items()))[0].startswith("module"): state_dict = {k[7:]: v for k, v in state_dict.items()} return state_dict @@ -97,7 +111,7 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): def load_checkpoint(model, checkpoint_path, strict=True): state_dict = load_state_dict(checkpoint_path) # detect old format and make compatible with new format - if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"): state_dict = convert_to_custom_text_state_dict(state_dict) resize_pos_embed(state_dict, model) incompatible_keys = model.load_state_dict(state_dict, strict=strict) @@ -105,33 +119,33 @@ def load_checkpoint(model, checkpoint_path, strict=True): def create_model( - model_name: str, - pretrained: Optional[str] = None, - precision: str = 'fp32', - device: Union[str, torch.device] = 'cpu', - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_patch_dropout: Optional[float] = None, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - pretrained_image: bool = False, - pretrained_hf: bool = True, - cache_dir: Optional[str] = None, - output_dict: Optional[bool] = None, - require_pretrained: bool = False, + model_name: str, + pretrained: Optional[str] = None, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, ): has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: - model_id = model_name[len(HF_HUB_PREFIX):] + model_id = model_name[len(HF_HUB_PREFIX) :] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) - config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename="open_clip_config.json", cache_dir=cache_dir) - with open(config_path, 'r', encoding='utf-8') as f: + with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) - pretrained_cfg = config['preprocess_cfg'] - model_cfg = config['model_cfg'] + pretrained_cfg = config["preprocess_cfg"] + model_cfg = config["model_cfg"] else: - model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names checkpoint_path = None pretrained_cfg = {} model_cfg = None @@ -139,8 +153,8 @@ def create_model( if isinstance(device, str): device = torch.device(device) - if pretrained and pretrained.lower() == 'openai': - logging.info(f'Loading pretrained {model_name} from OpenAI.') + if pretrained and pretrained.lower() == "openai": + logging.info(f"Loading pretrained {model_name} from OpenAI.") model = load_openai_model( model_name, precision=precision, @@ -155,10 +169,10 @@ def create_model( else: model_cfg = model_cfg or get_model_config(model_name) if model_cfg is not None: - logging.info(f'Loaded {model_name} model config.') + logging.info(f"Loaded {model_name} model config.") else: - logging.error(f'Model config for {model_name} not found; available models {list_models()}.') - 
raise RuntimeError(f'Model config for {model_name} not found.') + logging.error(f"Model config for {model_name} not found; available models {list_models()}.") + raise RuntimeError(f"Model config for {model_name} not found.") if force_quick_gelu: # override for use of QuickGELU on non-OpenAI transformer models @@ -173,19 +187,19 @@ def create_model( model_cfg["vision_cfg"]["image_size"] = force_image_size if pretrained_image: - if 'timm_model_name' in model_cfg.get('vision_cfg', {}): + if "timm_model_name" in model_cfg.get("vision_cfg", {}): # pretrained weight loading for timm models set via vision_cfg - model_cfg['vision_cfg']['timm_model_pretrained'] = True + model_cfg["vision_cfg"]["timm_model_pretrained"] = True else: - assert False, 'pretrained image towers currently only supported for timm models' + assert False, "pretrained image towers currently only supported for timm models" cast_dtype = get_cast_dtype(precision) - is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) - custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + is_hf_model = "hf_model_name" in model_cfg.get("text_cfg", {}) + custom_text = model_cfg.pop("custom_text", False) or force_custom_text or is_hf_model if custom_text: if is_hf_model: - model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf if "coca" in model_name: model = CoCa(**model_cfg, cast_dtype=cast_dtype) else: @@ -195,7 +209,7 @@ def create_model( pretrained_loaded = False if pretrained: - checkpoint_path = '' + checkpoint_path = "" pretrained_cfg = get_pretrained_cfg(model_name, pretrained) if pretrained_cfg: checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) @@ -203,32 +217,34 @@ def create_model( checkpoint_path = pretrained if checkpoint_path: - logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + logging.info(f"Loading pretrained {model_name} weights ({pretrained}).") load_checkpoint(model, checkpoint_path) else: error_str = ( - f'Pretrained weights ({pretrained}) not found for model {model_name}.' - f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + f"Pretrained weights ({pretrained}) not found for model {model_name}." + f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}." + ) logging.warning(error_str) raise RuntimeError(error_str) pretrained_loaded = True elif has_hf_hub_prefix: - logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + logging.info(f"Loading pretrained {model_name} weights ({pretrained}).") load_checkpoint(model, checkpoint_path) pretrained_loaded = True if require_pretrained and not pretrained_loaded: # callers of create_model_from_pretrained always expect pretrained weights raise RuntimeError( - f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + f"Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded." 
+ ) model.to(device=device) if precision in ("fp16", "bf16"): - convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) + convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == "bf16" else torch.float16) # set image / mean metadata from pretrained_cfg if available, or use default - model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN - model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD # to always output dict even if it is clip if output_dict and hasattr(model, "output_dict"): @@ -272,22 +288,22 @@ def create_loss(args): def create_model_and_transforms( - model_name: str, - pretrained: Optional[str] = None, - precision: str = 'fp32', - device: Union[str, torch.device] = 'cpu', - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_patch_dropout: Optional[float] = None, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - pretrained_image: bool = False, - pretrained_hf: bool = True, - image_mean: Optional[Tuple[float, ...]] = None, - image_std: Optional[Tuple[float, ...]] = None, - aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, - cache_dir: Optional[str] = None, - output_dict: Optional[bool] = None, + model_name: str, + pretrained: Optional[str] = None, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, ): model = create_model( model_name, @@ -305,8 +321,8 @@ def create_model_and_transforms( output_dict=output_dict, ) - image_mean = image_mean or getattr(model.visual, 'image_mean', None) - image_std = image_std or getattr(model.visual, 'image_std', None) + image_mean = image_mean or getattr(model.visual, "image_mean", None) + image_std = image_std or getattr(model.visual, "image_std", None) preprocess_train = image_transform( model.visual.image_size, is_train=True, @@ -325,18 +341,18 @@ def create_model_and_transforms( def create_model_from_pretrained( - model_name: str, - pretrained: Optional[str] = None, - precision: str = 'fp32', - device: Union[str, torch.device] = 'cpu', - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - return_transform: bool = True, - image_mean: Optional[Tuple[float, ...]] = None, - image_std: Optional[Tuple[float, ...]] = None, - cache_dir: Optional[str] = None, + model_name: str, + pretrained: Optional[str] = None, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, 
): model = create_model( model_name, @@ -354,8 +370,8 @@ def create_model_from_pretrained( if not return_transform: return model - image_mean = image_mean or getattr(model.visual, 'image_mean', None) - image_std = image_std or getattr(model.visual, 'image_std', None) + image_mean = image_mean or getattr(model.visual, "image_mean", None) + image_std = image_std or getattr(model.visual, "image_std", None) preprocess = image_transform( model.visual.image_size, is_train=False, diff --git a/src/metr/open_clip/hf_configs.py b/src/metr/open_clip/hf_configs.py index e236222..b22278c 100644 --- a/src/metr/open_clip/hf_configs.py +++ b/src/metr/open_clip/hf_configs.py @@ -9,7 +9,7 @@ "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", - "token_embeddings_attr": "embeddings" + "token_embeddings_attr": "embeddings", }, "pooler": "mean_pooler", }, @@ -22,7 +22,7 @@ "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", - "token_embeddings_attr": "embeddings" + "token_embeddings_attr": "embeddings", }, "pooler": "mean_pooler", }, @@ -38,7 +38,7 @@ "heads": "num_heads", "layers": "num_layers", "layer_attr": "block", - "token_embeddings_attr": "embed_tokens" + "token_embeddings_attr": "embed_tokens", }, "pooler": "mean_pooler", }, diff --git a/src/metr/open_clip/hf_model.py b/src/metr/open_clip/hf_model.py index fbccc81..2edf7e9 100644 --- a/src/metr/open_clip/hf_model.py +++ b/src/metr/open_clip/hf_model.py @@ -11,26 +11,28 @@ try: import transformers - from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig - from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ - BaseModelOutputWithPoolingAndCrossAttentions + from transformers import AutoConfig, AutoModel, AutoTokenizer, PretrainedConfig + from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, + ) except ImportError as e: transformers = None - class BaseModelOutput: pass - class PretrainedConfig: pass + from .hf_configs import arch_dict # utils def _camel2snake(s): - return re.sub(r'(? 
torch.Tensor: def get_logits(self, image_features, text_features, logit_scale): if self.world_size > 1: all_image_features, all_text_features = gather_features( - image_features, text_features, - self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + image_features, + text_features, + self.local_loss, + self.gather_with_grad, + self.rank, + self.world_size, + self.use_horovod, + ) if self.local_loss: logits_per_image = logit_scale * image_features @ all_text_features.T @@ -114,7 +114,7 @@ def get_logits(self, image_features, text_features, logit_scale): else: logits_per_image = logit_scale * image_features @ text_features.T logits_per_text = logit_scale * text_features @ image_features.T - + return logits_per_image, logits_per_text def forward(self, image_features, text_features, logit_scale, output_dict=False): @@ -123,26 +123,23 @@ def forward(self, image_features, text_features, logit_scale, output_dict=False) labels = self.get_ground_truth(device, logits_per_image.shape[0]) - total_loss = ( - F.cross_entropy(logits_per_image, labels) + - F.cross_entropy(logits_per_text, labels) - ) / 2 + total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 return {"contrastive_loss": total_loss} if output_dict else total_loss class CoCaLoss(ClipLoss): def __init__( - self, - caption_loss_weight, - clip_loss_weight, - pad_id=0, # pad_token for open_clip custom tokenizer - local_loss=False, - gather_with_grad=False, - cache_labels=False, - rank=0, - world_size=1, - use_horovod=False, + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, ): super().__init__( local_loss=local_loss, @@ -150,7 +147,7 @@ def __init__( cache_labels=cache_labels, rank=rank, world_size=world_size, - use_horovod=use_horovod + use_horovod=use_horovod, ) self.clip_loss_weight = clip_loss_weight @@ -179,31 +176,28 @@ def dist_loss(self, teacher_logits, student_logits): return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) def forward( - self, - image_features, - text_features, - logit_scale, - dist_image_features, - dist_text_features, - dist_logit_scale, - output_dict=False, + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, ): - logits_per_image, logits_per_text = \ - self.get_logits(image_features, text_features, logit_scale) + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) - dist_logits_per_image, dist_logits_per_text = \ - self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + dist_logits_per_image, dist_logits_per_text = self.get_logits( + dist_image_features, dist_text_features, dist_logit_scale + ) labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) - contrastive_loss = ( - F.cross_entropy(logits_per_image, labels) + - F.cross_entropy(logits_per_text, labels) - ) / 2 + contrastive_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 distill_loss = ( - self.dist_loss(dist_logits_per_image, logits_per_image) + - self.dist_loss(dist_logits_per_text, logits_per_text) + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) ) / 2 if output_dict: 
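For context on the loss.py hunks above: with world_size == 1 (no feature gathering), ClipLoss reduces to the standard symmetric cross-entropy over the image-text similarity matrix. A minimal standalone sketch with illustrative names, assuming L2-normalized, row-aligned feature batches:

import torch
import torch.nn.functional as F

def symmetric_clip_loss(image_features, text_features, logit_scale):
    # Pairwise similarity logits; the i-th image is the positive for the i-th text.
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T
    labels = torch.arange(image_features.shape[0], device=image_features.device)
    # Same averaging of the two cross-entropies as ClipLoss.forward in the hunk above.
    return (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2

CoCaLoss adds a weighted caption cross-entropy on top of this contrastive term, and DistillClipLoss adds the soft-target term computed by dist_loss.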
diff --git a/src/metr/open_clip/model.py b/src/metr/open_clip/model.py index 4f5e775..0f935fb 100644 --- a/src/metr/open_clip/model.py +++ b/src/metr/open_clip/model.py @@ -2,9 +2,10 @@ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ -from dataclasses import dataclass + import logging import math +from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np @@ -16,7 +17,7 @@ from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet from .timm_model import TimmModel -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .transformer import Attention, LayerNorm, LayerNormFp32, QuickGELU, TextTransformer, VisionTransformer from .utils import to_2tuple @@ -29,18 +30,24 @@ class CLIPVisionCfg: patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value - patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results - input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design - global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) - attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer - n_queries: int = 256 # n_queries for attentional pooler - attn_pooler_heads: int = 8 # n heads for attentional_pooling + patch_dropout: float = ( + 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + ) + input_patchnorm: bool = ( + False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + ) + global_average_pool: bool = ( + False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + ) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling timm_model_name: str = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model - timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') - timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection - timm_drop: float = 0. 
# head dropout + timm_drop: float = 0.0 # head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth output_tokens: bool = False @@ -56,8 +63,8 @@ class CLIPTextCfg: hf_model_name: str = None hf_tokenizer_name: str = None hf_model_pretrained: bool = True - proj: str = 'mlp' - pooler_type: str = 'mean_pooler' + proj: str = "mlp" + pooler_type: str = "mean_pooler" embed_cls: bool = False pad_id: int = 0 output_tokens: bool = False @@ -65,18 +72,15 @@ class CLIPTextCfg: def get_cast_dtype(precision: str): cast_dtype = None - if precision == 'bf16': + if precision == "bf16": cast_dtype = torch.bfloat16 - elif precision == 'fp16': + elif precision == "fp16": cast_dtype = torch.float16 return cast_dtype def _build_vision_tower( - embed_dim: int, - vision_cfg: CLIPVisionCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None + embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) @@ -135,10 +139,10 @@ def _build_vision_tower( def _build_text_tower( - embed_dim: int, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) @@ -177,13 +181,13 @@ class CLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False, + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, ): super().__init__() self.output_dict = output_dict @@ -196,7 +200,7 @@ def __init__( self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection - self.register_buffer('attn_mask', text.attn_mask, persistent=False) + self.register_buffer("attn_mask", text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) @@ -234,7 +238,7 @@ def forward(self, image, text): return { "image_features": image_features, "text_features": text_features, - "logit_scale": self.logit_scale.exp() + "logit_scale": self.logit_scale.exp(), } return image_features, text_features, self.logit_scale.exp() @@ -243,13 +247,13 @@ class CustomTextCLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False, + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, ): super().__init__() self.output_dict = output_dict @@ -284,7 +288,7 @@ def forward(self, image, text): return { "image_features": image_features, "text_features": text_features, - "logit_scale": self.logit_scale.exp() + "logit_scale": self.logit_scale.exp(), } return image_features, text_features, self.logit_scale.exp() @@ -318,45 +322,50 @@ def _convert_weights(l): # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): - if 'text_projection' in state_dict: 
+ if "text_projection" in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): - if any(k.startswith(p) for p in ( - 'text_projection', - 'positional_embedding', - 'token_embedding', - 'transformer', - 'ln_final', - )): - k = 'text.' + k + if any( + k.startswith(p) + for p in ( + "text_projection", + "positional_embedding", + "token_embedding", + "transformer", + "ln_final", + ) + ): + k = "text." + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( - state_dict: dict, - quick_gelu=True, - cast_dtype=torch.float16, + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( - [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")] + ) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ - len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4] + ] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] @@ -395,7 +404,7 @@ def build_model_from_openai_state_dict( return model.eval() -def trace_model(model, batch_size=256, device=torch.device('cpu')): +def trace_model(model, batch_size=256, device=torch.device("cpu")): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) @@ -403,18 +412,17 @@ def trace_model(model, batch_size=256, device=torch.device('cpu')): model = torch.jit.trace_module( model, inputs=dict( - forward=(example_images, example_text), - encode_text=(example_text,), - encode_image=(example_images,) - )) + forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,) + ), + ) model.visual.image_size = image_size return model -def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): +def resize_pos_embed(state_dict, model, interpolation: str = "bicubic", antialias: bool = True): # Rescale the grid of position embeddings when loading from state_dict - old_pos_embed = state_dict.get('visual.positional_embedding', None) - if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + old_pos_embed = state_dict.get("visual.positional_embedding", None) + if old_pos_embed is None or not hasattr(model.visual, "grid_size"): return grid_size = to_2tuple(model.visual.grid_size) extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) @@ -428,7 +436,7 @@ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialia pos_emb_tok, pos_emb_img = None, old_pos_embed 
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) - logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size) pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, @@ -442,4 +450,4 @@ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialia new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img - state_dict['visual.positional_embedding'] = new_pos_embed + state_dict["visual.positional_embedding"] = new_pos_embed diff --git a/src/metr/open_clip/modified_resnet.py b/src/metr/open_clip/modified_resnet.py index f7c0b03..9f7d46a 100644 --- a/src/metr/open_clip/modified_resnet.py +++ b/src/metr/open_clip/modified_resnet.py @@ -1,11 +1,10 @@ from collections import OrderedDict import torch +from open_clip.utils import freeze_batch_norm_2d from torch import nn from torch.nn import functional as F -from open_clip.utils import freeze_batch_norm_2d - class Bottleneck(nn.Module): expansion = 4 @@ -33,11 +32,15 @@ def __init__(self, inplanes, planes, stride=1): if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) def forward(self, x: torch.Tensor): identity = x @@ -58,7 +61,7 @@ def forward(self, x: torch.Tensor): class AttentionPool2d(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) @@ -70,7 +73,9 @@ def forward(self, x): x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( - query=x, key=x, value=x, + query=x, + key=x, + value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, @@ -81,12 +86,12 @@ def forward(self, x): bias_k=None, bias_v=None, add_zero_attn=False, - dropout_p=0., + dropout_p=0.0, out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, - need_weights=False + need_weights=False, ) return x[0] @@ -140,7 +145,7 @@ def _make_layer(self, planes, blocks, stride=1): def init_parameters(self): if self.attnpool is not None: - std = self.attnpool.c_proj.in_features ** -0.5 + std = self.attnpool.c_proj.in_features**-0.5 nn.init.normal_(self.attnpool.q_proj.weight, std=std) nn.init.normal_(self.attnpool.k_proj.weight, std=std) nn.init.normal_(self.attnpool.v_proj.weight, std=std) @@ -152,7 +157,7 @@ def init_parameters(self): 
nn.init.zeros_(param) def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert unlocked_groups == 0, 'partial locking not currently supported for this model' + assert unlocked_groups == 0, "partial locking not currently supported for this model" for param in self.parameters(): param.requires_grad = False if freeze_bn_stats: diff --git a/src/metr/open_clip/openai.py b/src/metr/open_clip/openai.py index cc4e13e..cb63ad1 100644 --- a/src/metr/open_clip/openai.py +++ b/src/metr/open_clip/openai.py @@ -10,22 +10,22 @@ import torch from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype -from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url +from .pretrained import download_pretrained_from_url, get_pretrained_url, list_pretrained_models_by_tag __all__ = ["list_openai_models", "load_openai_model"] def list_openai_models() -> List[str]: """Returns the names of available CLIP models""" - return list_pretrained_models_by_tag('openai') + return list_pretrained_models_by_tag("openai") def load_openai_model( - name: str, - precision: Optional[str] = None, - device: Optional[Union[str, torch.device]] = None, - jit: bool = True, - cache_dir: Optional[str] = None, + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, ): """Load a CLIP model @@ -52,10 +52,10 @@ def load_openai_model( if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if precision is None: - precision = 'fp32' if device == 'cpu' else 'fp16' + precision = "fp32" if device == "cpu" else "fp16" - if get_pretrained_url(name, 'openai'): - model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + if get_pretrained_url(name, "openai"): + model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) elif os.path.isfile(name): model_path = name else: @@ -83,9 +83,9 @@ def load_openai_model( # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) - if precision.startswith('amp') or precision == 'fp32': + if precision.startswith("amp") or precision == "fp32": model.float() - elif precision == 'bf16': + elif precision == "bf16": convert_weights_to_lp(model, dtype=torch.bfloat16) return model @@ -113,7 +113,7 @@ def patch_device(module): patch_device(model.encode_text) # patch dtype to float32 (typically for CPU) - if precision == 'fp32': + if precision == "fp32": float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node() diff --git a/src/metr/open_clip/pretrained.py b/src/metr/open_clip/pretrained.py index 87e7e52..6c19e03 100644 --- a/src/metr/open_clip/pretrained.py +++ b/src/metr/open_clip/pretrained.py @@ -11,6 +11,7 @@ try: from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) _has_hf_hub = True except ImportError: @@ -18,7 +19,7 @@ _has_hf_hub = False -def _pcfg(url='', hf_hub='', mean=None, std=None): +def _pcfg(url="", hf_hub="", mean=None, std=None): return dict( url=url, hf_hub=hf_hub, @@ -29,175 +30,202 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): _RN50 = dict( openai=_pcfg( - 
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt" + ), yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt" + ), cc12m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" + ), ) _RN50_quickgelu = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt" + ), yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt" + ), cc12m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" + ), ) _RN101 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt" + ), yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" + ), ) _RN101_quickgelu = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt" + ), yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" + ), ) _RN50x4 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt" + ), ) _RN50x16 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt" + ), ) _RN50x64 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), + 
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt" + ), ) _VITB32 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt" + ), laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt" + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt" + ), laion2b_e16=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), - laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth" + ), + laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"), ) _VITB32_quickgelu = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt" + ), laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt" + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt" + ), ) _VITB16 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt" + ), laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt" + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt" + ), # laion400m_32k=_pcfg( # url="", # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # laion400m_64k=_pcfg( # url="", # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"), ) _VITB16_PLUS_240 = dict( laion400m_e31=_pcfg( - 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt" + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt" + ), ) _VITL14 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt" + ), laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt" + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), - laion2b_s32b_b82k=_pcfg( - hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt" + ), + laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), ) _VITL14_336 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" + ), ) _VITH14 = dict( - laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), + laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"), ) _VITg14 = dict( - laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), - laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), + laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"), + laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"), ) _VITbigG14 = dict( - laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), + laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"), ) _robertaViTB32 = dict( - laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), + laion2b_s12b_b32k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/"), ) _xlmRobertaBaseViTB32 = dict( - laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), + laion5b_s13b_b90k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/"), ) _xlmRobertaLargeFrozenViTH14 = dict( - frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), + frozen_laion5b_s13b_b90k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/"), ) _convnext_base = dict( - laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), + laion400m_s13b_b51k=_pcfg(hf_hub="laion/CLIP-convnext_base-laion400M-s13B-b51K/"), ) _convnext_base_w = dict( - 
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), - laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), - laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), + laion2b_s13b_b82k=_pcfg(hf_hub="laion/CLIP-convnext_base_w-laion2B-s13B-b82K/"), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub="laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/"), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub="laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/"), ) _convnext_base_w_320 = dict( - laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), - laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub="laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/"), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub="laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/"), ) _convnext_large_d = dict( - laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), + laion2b_s26b_b102k_augreg=_pcfg(hf_hub="laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/"), ) _convnext_large_d_320 = dict( - laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), - laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), + laion2b_s29b_b131k_ft=_pcfg(hf_hub="laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/"), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub="laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/"), ) _convnext_xxlarge = dict( - laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), - laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), - laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), + laion2b_s34b_b82k_augreg=_pcfg(hf_hub="laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/"), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub="laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/"), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub="laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/"), ) _coca_VITB32 = dict( - laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), - mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') + laion2b_s13b_b90k=_pcfg(hf_hub="laion/CoCa-ViT-B-32-laion2B-s13B-b90k/"), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub="laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/"), ) _coca_VITL14 = dict( - laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), - mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') + laion2b_s13b_b90k=_pcfg(hf_hub="laion/CoCa-ViT-L-14-laion2B-s13B-b90k/"), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub="laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/"), ) @@ -234,18 +262,18 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): def _clean_tag(tag: str): # normalize pretrained tags - return tag.lower().replace('-', '_') + return tag.lower().replace("-", "_") def list_pretrained(as_str: bool = False): - """ returns list of pretrained models + """returns list of pretrained models Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == 
True """ - return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + return [":".join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] def list_pretrained_models_by_tag(tag: str): - """ return all models having the specified pretrain tag """ + """return all models having the specified pretrain tag""" models = [] tag = _clean_tag(tag) for k in _PRETRAINED.keys(): @@ -255,7 +283,7 @@ def list_pretrained_models_by_tag(tag: str): def list_pretrained_tags_by_model(model: str): - """ return all pretrain tags for the specified model architecture """ + """return all pretrain tags for the specified model architecture""" tags = [] if model in _PRETRAINED: tags.extend(_PRETRAINED[model].keys()) @@ -277,24 +305,24 @@ def get_pretrained_cfg(model: str, tag: str): def get_pretrained_url(model: str, tag: str): cfg = get_pretrained_cfg(model, _clean_tag(tag)) - return cfg.get('url', '') + return cfg.get("url", "") def download_pretrained_from_url( - url: str, - cache_dir: Union[str, None] = None, + url: str, + cache_dir: Union[str, None] = None, ): if not cache_dir: cache_dir = os.path.expanduser("~/.cache/clip") os.makedirs(cache_dir, exist_ok=True) filename = os.path.basename(url) - if 'openaipublic' in url: + if "openaipublic" in url: expected_sha256 = url.split("/")[-2] - elif 'mlfoundations' in url: + elif "mlfoundations" in url: expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] else: - expected_sha256 = '' + expected_sha256 = "" download_target = os.path.join(cache_dir, filename) @@ -306,12 +334,14 @@ def download_pretrained_from_url( if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): return download_target else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) else: return download_target with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop: while True: buffer = source.read(8192) if not buffer: @@ -320,7 +350,9 @@ def download_pretrained_from_url( output.write(buffer) loop.update(len(buffer)) - if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith( + expected_sha256 + ): raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") return download_target @@ -330,15 +362,16 @@ def has_hf_hub(necessary=False): if not _has_hf_hub and necessary: # if no HF Hub module installed, and it is necessary to continue, raise error raise RuntimeError( - 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + "Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`." 
+ ) return _has_hf_hub def download_pretrained_from_hf( - model_id: str, - filename: str = 'open_clip_pytorch_model.bin', - revision=None, - cache_dir: Union[str, None] = None, + model_id: str, + filename: str = "open_clip_pytorch_model.bin", + revision=None, + cache_dir: Union[str, None] = None, ): has_hf_hub(True) cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) @@ -346,19 +379,19 @@ def download_pretrained_from_hf( def download_pretrained( - cfg: Dict, - force_hf_hub: bool = False, - cache_dir: Union[str, None] = None, + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, ): - target = '' + target = "" if not cfg: return target - download_url = cfg.get('url', '') - download_hf_hub = cfg.get('hf_hub', '') + download_url = cfg.get("url", "") + download_hf_hub = cfg.get("hf_hub", "") if download_hf_hub and force_hf_hub: # use HF hub even if url exists - download_url = '' + download_url = "" if download_url: target = download_pretrained_from_url(download_url, cache_dir=cache_dir) diff --git a/src/metr/open_clip/push_to_hf_hub.py b/src/metr/open_clip/push_to_hf_hub.py index 23c0631..b2c12d9 100644 --- a/src/metr/open_clip/push_to_hf_hub.py +++ b/src/metr/open_clip/push_to_hf_hub.py @@ -16,6 +16,7 @@ upload_folder, ) from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True except ImportError: _has_hf_hub = False @@ -24,21 +25,17 @@ from .tokenizer import HFTokenizer -def save_config_for_hf( - model, - config_path: str, - model_config: Optional[dict] -): +def save_config_for_hf(model, config_path: str, model_config: Optional[dict]): preprocess_cfg = { - 'mean': model.visual.image_mean, - 'std': model.visual.image_std, + "mean": model.visual.image_mean, + "std": model.visual.image_std, } hf_config = { - 'model_cfg': model_config, - 'preprocess_cfg': preprocess_cfg, + "model_cfg": model_config, + "preprocess_cfg": preprocess_cfg, } - with config_path.open('w') as f: + with config_path.open("w") as f: json.dump(hf_config, f, indent=2) @@ -47,8 +44,8 @@ def save_for_hf( tokenizer: HFTokenizer, model_config: dict, save_directory: str, - weights_filename='open_clip_pytorch_model.bin', - config_filename='open_clip_config.json', + weights_filename="open_clip_pytorch_model.bin", + config_filename="open_clip_config.json", ): save_directory = Path(save_directory) save_directory.mkdir(exist_ok=True, parents=True) @@ -67,7 +64,7 @@ def push_to_hf_hub( tokenizer, model_config: Optional[dict], repo_id: str, - commit_message: str = 'Add model', + commit_message: str = "Add model", token: Optional[str] = None, revision: Optional[str] = None, private: bool = False, @@ -76,7 +73,7 @@ def push_to_hf_hub( ): if not isinstance(tokenizer, HFTokenizer): # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 - tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + tokenizer = HFTokenizer("openai/clip-vit-large-patch14") # Create repo if it doesn't exist yet repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) @@ -106,7 +103,7 @@ def push_to_hf_hub( # Add readme if it does not exist if not has_readme: model_card = model_card or {} - model_name = repo_id.split('/')[-1] + model_name = repo_id.split("/")[-1] readme_path = Path(tmpdir) / "README.md" readme_text = generate_readme(model_card, model_name) readme_path.write_text(readme_text) @@ -127,7 +124,7 @@ def push_pretrained_to_hf_hub( repo_id: str, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, 
...]] = None, - commit_message: str = 'Add model', + commit_message: str = "Add model", token: Optional[str] = None, revision: Optional[str] = None, private: bool = False, @@ -165,16 +162,16 @@ def generate_readme(model_card: dict, model_name: str): readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" readme_text += "library_tag: open_clip\n" readme_text += f"license: {model_card.get('license', 'mit')}\n" - if 'details' in model_card and 'Dataset' in model_card['details']: - readme_text += 'datasets:\n' + if "details" in model_card and "Dataset" in model_card["details"]: + readme_text += "datasets:\n" readme_text += f"- {model_card['details']['Dataset'].lower()}\n" readme_text += "---\n" readme_text += f"# Model card for {model_name}\n" - if 'description' in model_card: + if "description" in model_card: readme_text += f"\n{model_card['description']}\n" - if 'details' in model_card: + if "details" in model_card: readme_text += f"\n## Model Details\n" - for k, v in model_card['details'].items(): + for k, v in model_card["details"].items(): if isinstance(v, (list, tuple)): readme_text += f"- **{k}:**\n" for vi in v: @@ -185,22 +182,22 @@ def generate_readme(model_card: dict, model_name: str): readme_text += f" - {ki}: {vi}\n" else: readme_text += f"- **{k}:** {v}\n" - if 'usage' in model_card: + if "usage" in model_card: readme_text += f"\n## Model Usage\n" - readme_text += model_card['usage'] - readme_text += '\n' + readme_text += model_card["usage"] + readme_text += "\n" - if 'comparison' in model_card: + if "comparison" in model_card: readme_text += f"\n## Model Comparison\n" - readme_text += model_card['comparison'] - readme_text += '\n' + readme_text += model_card["comparison"] + readme_text += "\n" - if 'citation' in model_card: + if "citation" in model_card: readme_text += f"\n## Citation\n" - if not isinstance(model_card['citation'], (list, tuple)): - citations = [model_card['citation']] + if not isinstance(model_card["citation"], (list, tuple)): + citations = [model_card["citation"]] else: - citations = model_card['citation'] + citations = model_card["citation"] for c in citations: readme_text += f"```bibtex\n{c}\n```\n" @@ -210,25 +207,39 @@ def generate_readme(model_card: dict, model_name: str): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") parser.add_argument( - "--model", type=str, help="Name of the model to use.", + "--model", + type=str, + help="Name of the model to use.", ) parser.add_argument( - "--pretrained", type=str, + "--pretrained", + type=str, help="Use a pretrained CLIP model weights with the specified tag or file path.", ) parser.add_argument( - "--repo-id", type=str, + "--repo-id", + type=str, help="Destination HF Hub repo-id ie 'organization/model_id'.", ) parser.add_argument( - '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', - help='Override default image mean value of dataset') + "--image-mean", + type=float, + nargs="+", + default=None, + metavar="MEAN", + help="Override default image mean value of dataset", + ) parser.add_argument( - '--image-std', type=float, nargs='+', default=None, metavar='STD', - help='Override default image std deviation of of dataset') + "--image-std", + type=float, + nargs="+", + default=None, + metavar="STD", + help="Override default image std deviation of of dataset", + ) args = parser.parse_args() - print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + print(f"Saving model 
{args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}") # FIXME add support to pass model_card json / template from file via cmd line @@ -240,4 +251,4 @@ def generate_readme(model_card: dict, model_name: str): image_std=args.image_std, ) - print(f'{args.model} saved.') + print(f"{args.model} saved.") diff --git a/src/metr/open_clip/timm_model.py b/src/metr/open_clip/timm_model.py index dc71a69..ca6c5ec 100644 --- a/src/metr/open_clip/timm_model.py +++ b/src/metr/open_clip/timm_model.py @@ -2,6 +2,7 @@ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. """ + import logging from collections import OrderedDict @@ -11,14 +12,15 @@ try: import timm from timm.models.layers import Mlp, to_2tuple + try: # old timm imports < 0.8.1 - from timm.models.layers.attention_pool2d import RotAttentionPool2d from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + from timm.models.layers.attention_pool2d import RotAttentionPool2d except ImportError: # new timm imports >= 0.8.1 - from timm.layers import RotAttentionPool2d from timm.layers import AttentionPool2d as AbsAttentionPool2d + from timm.layers import RotAttentionPool2d except ImportError: timm = None @@ -26,21 +28,21 @@ class TimmModel(nn.Module): - """ timm model adapter + """timm model adapter # FIXME this adapter is a work in progress, may change in ways that break weight compat """ def __init__( - self, - model_name, - embed_dim, - image_size=224, - pool='avg', - proj='linear', - proj_bias=False, - drop=0., - drop_path=None, - pretrained=False, + self, + model_name, + embed_dim, + image_size=224, + pool="avg", + proj="linear", + proj_bias=False, + drop=0.0, + drop_path=None, + pretrained=False, ): super().__init__() if timm is None: @@ -49,14 +51,14 @@ def __init__( self.image_size = to_2tuple(image_size) timm_kwargs = {} if drop_path is not None: - timm_kwargs['drop_path_rate'] = drop_path + timm_kwargs["drop_path_rate"] = drop_path self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) - feat_size = self.trunk.default_cfg.get('pool_size', None) + feat_size = self.trunk.default_cfg.get("pool_size", None) feature_ndim = 1 if not feat_size else 2 - if pool in ('abs_attn', 'rot_attn'): + if pool in ("abs_attn", "rot_attn"): assert feature_ndim == 2 # if attn pooling used, remove both classifier and default pool - self.trunk.reset_classifier(0, global_pool='') + self.trunk.reset_classifier(0, global_pool="") else: # reset global pool if pool config set, otherwise leave as network default reset_kwargs = dict(global_pool=pool) if pool else {} @@ -64,26 +66,26 @@ def __init__( prev_chs = self.trunk.num_features head_layers = OrderedDict() - if pool == 'abs_attn': - head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) prev_chs = embed_dim - elif pool == 'rot_attn': - head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) prev_chs = embed_dim else: - assert proj, 'projection layer needed if non-attention pooling is used.' + assert proj, "projection layer needed if non-attention pooling is used." 
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used - if proj == 'linear': - head_layers['drop'] = nn.Dropout(drop) - head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) - elif proj == 'mlp': - head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) self.head = nn.Sequential(head_layers) def lock(self, unlocked_groups=0, freeze_bn_stats=False): - """ lock modules + """lock modules Args: unlocked_groups (int): leave last n layer groups unlocked (default: 0) """ @@ -97,10 +99,11 @@ def lock(self, unlocked_groups=0, freeze_bn_stats=False): # NOTE: partial freeze requires latest timm (master) branch and is subject to change try: # FIXME import here until API stable and in an official release - from timm.models.helpers import group_parameters, group_modules + from timm.models.helpers import group_modules, group_parameters except ImportError: raise RuntimeError( - 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" + ) matcher = self.trunk.group_matcher() gparams = group_parameters(self.trunk, matcher) max_layer_id = max(gparams.keys()) @@ -119,7 +122,7 @@ def set_grad_checkpointing(self, enable=True): try: self.trunk.set_grad_checkpointing(enable) except Exception as e: - logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") def forward(self, x): x = self.trunk(x) diff --git a/src/metr/open_clip/tokenizer.py b/src/metr/open_clip/tokenizer.py index 23fcfcb..52bc7ca 100644 --- a/src/metr/open_clip/tokenizer.py +++ b/src/metr/open_clip/tokenizer.py @@ -2,18 +2,19 @@ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ + import gzip import html + +# https://stackoverflow.com/q/62691279 import os from functools import lru_cache -from typing import Union, List +from typing import List, Union import ftfy import regex as re import torch -# https://stackoverflow.com/q/62691279 -import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -33,13 +34,13 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
""" - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) - cs.append(2**8+n) + cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) @@ -64,7 +65,7 @@ def basic_clean(text): def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.strip() return text @@ -73,24 +74,26 @@ class SimpleTokenizer(object): def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] + vocab = vocab + [v + "" for v in vocab] for merge in merges: - vocab.append(''.join(merge)) + vocab.append("".join(merge)) if not special_tokens: - special_tokens = ['', ''] + special_tokens = ["", ""] else: - special_tokens = ['', ''] + special_tokens + special_tokens = ["", ""] + special_tokens vocab.extend(special_tokens) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {t:t for t in special_tokens} + self.cache = {t: t for t in special_tokens} special = "|".join(special_tokens) - self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE + ) self.vocab_size = len(self.encoder) self.all_special_ids = [self.encoder[t] for t in special_tokens] @@ -98,14 +101,14 @@ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): def bpe(self, token): if token in self.cache: return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) + word = tuple(token[:-1]) + (token[-1] + "",) pairs = get_pairs(word) if not pairs: - return token+'' + return token + "" while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -120,8 +123,8 @@ def bpe(self, token): new_word.extend(word[i:]) break - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) i += 2 else: new_word.append(word[i]) @@ -132,7 +135,7 @@ def bpe(self, token): break else: pairs = get_pairs(word) - word = ' '.join(word) + word = " ".join(word) self.cache[token] = word return word @@ -140,22 +143,24 @@ def encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + 
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) return bpe_tokens def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("", " ") return text _tokenizer = SimpleTokenizer() + def decode(output_ids: torch.Tensor): output_ids = output_ids.cpu().numpy() return _tokenizer.decode(output_ids) + def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) @@ -183,7 +188,7 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate tokens[-1] = eot_token - result[i, :len(tokens)] = torch.tensor(tokens) + result[i, : len(tokens)] = torch.tensor(tokens) return result @@ -193,6 +198,7 @@ class HFTokenizer: def __init__(self, tokenizer_name: str): from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) def save_pretrained(self, dest): @@ -206,9 +212,9 @@ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> to texts = [whitespace_clean(basic_clean(text)) for text in texts] input_ids = self.tokenizer( texts, - return_tensors='pt', + return_tensors="pt", max_length=context_length, - padding='max_length', + padding="max_length", truncation=True, ).input_ids return input_ids diff --git a/src/metr/open_clip/transform.py b/src/metr/open_clip/transform.py index 748884a..5a037a3 100644 --- a/src/metr/open_clip/transform.py +++ b/src/metr/open_clip/transform.py @@ -1,13 +1,19 @@ import warnings -from dataclasses import dataclass, asdict +from dataclasses import asdict, dataclass from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn import torchvision.transforms.functional as F - -from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ - CenterCrop +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomResizedCrop, + Resize, + ToTensor, +) from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD @@ -25,13 +31,13 @@ class AugmentationCfg: class ResizeMaxSize(nn.Module): - def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): super().__init__() if not isinstance(max_size, int): raise TypeError(f"Size should be int. 
Got {type(max_size)}") self.max_size = max_size self.interpolation = interpolation - self.fn = min if fn == 'min' else min + self.fn = min if fn == "min" else min self.fill = fill def forward(self, img): @@ -45,22 +51,22 @@ def forward(self, img): img = F.resize(img, new_size, self.interpolation) pad_h = self.max_size - new_size[0] pad_w = self.max_size - new_size[1] - img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) return img def _convert_to_rgb(image): - return image.convert('RGB') + return image.convert("RGB") def image_transform( - image_size: int, - is_train: bool, - mean: Optional[Tuple[float, ...]] = None, - std: Optional[Tuple[float, ...]] = None, - resize_longest_max: bool = False, - fill_color: int = 0, - aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): @@ -81,53 +87,58 @@ def image_transform( normalize = Normalize(mean=mean, std=std) if is_train: aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} - use_timm = aug_cfg_dict.pop('use_timm', False) + use_timm = aug_cfg_dict.pop("use_timm", False) if use_timm: from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): assert len(image_size) >= 2 input_size = (3,) + image_size[-2:] else: input_size = (3, image_size, image_size) # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time - aug_cfg_dict.setdefault('interpolation', 'random') - aug_cfg_dict.setdefault('color_jitter', None) # disable by default + aug_cfg_dict.setdefault("interpolation", "random") + aug_cfg_dict.setdefault("color_jitter", None) # disable by default train_transform = create_transform( input_size=input_size, is_training=True, - hflip=0., + hflip=0.0, mean=mean, std=std, - re_mode='pixel', + re_mode="pixel", **aug_cfg_dict, ) else: - train_transform = Compose([ - RandomResizedCrop( - image_size, - scale=aug_cfg_dict.pop('scale'), - interpolation=InterpolationMode.BICUBIC, - ), - _convert_to_rgb, - ToTensor(), - normalize, - ]) + train_transform = Compose( + [ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop("scale"), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) if aug_cfg_dict: - warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + warnings.warn( + f"Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())})." 
+ ) return train_transform else: if resize_longest_max: - transforms = [ - ResizeMaxSize(image_size, fill=fill_color) - ] + transforms = [ResizeMaxSize(image_size, fill=fill_color)] else: transforms = [ Resize(image_size, interpolation=InterpolationMode.BICUBIC), CenterCrop(image_size), ] - transforms.extend([ - _convert_to_rgb, - ToTensor(), - normalize, - ]) + transforms.extend( + [ + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) return Compose(transforms) diff --git a/src/metr/open_clip/transformer.py b/src/metr/open_clip/transformer.py index 4e01510..fb576a9 100644 --- a/src/metr/open_clip/transformer.py +++ b/src/metr/open_clip/transformer.py @@ -1,5 +1,5 @@ -from collections import OrderedDict import math +from collections import OrderedDict from typing import Callable, Optional, Sequence, Tuple import torch @@ -51,12 +51,12 @@ class PatchDropout(nn.Module): def __init__(self, prob, exclude_first_token=True): super().__init__() - assert 0 <= prob < 1. + assert 0 <= prob < 1.0 self.prob = prob self.exclude_first_token = exclude_first_token # exclude CLS token def forward(self, x): - if not self.training or self.prob == 0.: + if not self.training or self.prob == 0.0: return x if self.exclude_first_token: @@ -86,23 +86,23 @@ def forward(self, x): class Attention(nn.Module): def __init__( - self, - dim, - num_heads=8, - qkv_bias=True, - scaled_cosine=False, - scale_heads=False, - logit_scale_max=math.log(1. / 0.01), - attn_drop=0., - proj_drop=0. + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1.0 / 0.01), + attn_drop=0.0, + proj_drop=0.0, ): super().__init__() self.scaled_cosine = scaled_cosine self.scale_heads = scale_heads - assert dim % num_heads == 0, 'dim should be divisible by num_heads' + assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 self.logit_scale_max = logit_scale_max # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original @@ -162,12 +162,7 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None): class AttentionalPooler(nn.Module): def __init__( - self, - d_model: int, - context_dim: int, - n_head: int = 8, - n_queries: int = 256, - norm_layer: Callable = LayerNorm + self, d_model: int, context_dim: int, n_head: int = 8, n_queries: int = 256, norm_layer: Callable = LayerNorm ): super().__init__() self.query = nn.Parameter(torch.randn(n_queries, d_model)) @@ -188,14 +183,14 @@ def _repeat(self, query, N: int): class ResidualAttentionBlock(nn.Module): def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - is_cross_attention: bool = False, + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, ): super().__init__() @@ -207,34 +202,36 @@ def __init__( self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) 
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def attention( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None - return self.attn( - q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask - )[0] + return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0] def forward( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None @@ -246,23 +243,24 @@ def forward( class CustomResidualAttentionBlock(nn.Module): def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - scale_cosine_attn: bool = False, - scale_heads: bool = False, - scale_attn: bool = False, - scale_fc: bool = False, + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = Attention( - d_model, n_head, + d_model, + n_head, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, ) @@ -271,12 +269,16 @@ def __init__( self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): @@ -287,25 +289,28 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class Transformer(nn.Module): def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, ): super().__init__() self.width = width self.layers = layers self.grad_checkpointing = False - self.resblocks = nn.ModuleList([ - ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) - 
]) + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer + ) + for _ in range(layers) + ] + ) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype @@ -324,24 +329,24 @@ class VisionTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - ls_init_value: float = None, - global_average_pool: bool = False, - attentional_pool: bool = False, - n_queries: int = 256, - attn_pooler_heads: int = 8, - output_dim: int = 512, - patch_dropout: float = 0., - input_patchnorm: bool = False, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - output_tokens: bool = False + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0.0, + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, ): super().__init__() self.output_tokens = output_tokens @@ -359,15 +364,17 @@ def __init__( self.conv1 = nn.Linear(patch_input_dim, width) else: self.patchnorm_pre_ln = nn.Identity() - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False + ) # class embeddings and positional embeddings - scale = width ** -0.5 + scale = width**-0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn - self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity() self.ln_pre = norm_layer(width) self.transformer = Transformer( @@ -460,7 +467,9 @@ def forward(self, x: torch.Tensor): # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 if self.input_patchnorm: # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') - x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.reshape( + x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1] + ) x = x.permute(0, 2, 4, 1, 3, 5) x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) x = self.patchnorm_pre_ln(x) @@ -472,8 +481,13 @@ def forward(self, x: torch.Tensor): # class embeddings and positional embeddings x = torch.cat( - [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), - x], dim=1) # shape = [*, grid ** 2 + 1, width] + [ + self.class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in @@ -497,7 +511,7 @@ def forward(self, x: torch.Tensor): if self.output_tokens: return pooled, tokens - + return pooled @@ -505,19 +519,19 @@ class TextTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( - self, - context_length: int = 77, - vocab_size: int = 49408, - width: int = 512, - heads: int = 8, - layers: int = 12, - ls_init_value: float = None, - output_dim: int = 512, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - embed_cls: bool = False, - pad_id: int = 0, - output_tokens: bool = False, + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, ): super().__init__() self.output_tokens = output_tokens @@ -548,7 +562,7 @@ def __init__( ) self.ln_final = norm_layer(width) - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) self.init_parameters() @@ -558,8 +572,8 @@ def init_parameters(self): if self.cls_emb is not None: nn.init.normal_(self.cls_emb, std=0.01) - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) @@ -568,7 +582,7 @@ def init_parameters(self): nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): @@ -631,16 +645,16 @@ def forward(self, text): class MultimodalTransformer(Transformer): def __init__( - self, - width: int, - layers: int, - heads: int, - context_length: int = 77, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - output_dim: int = 512, + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, ): super().__init__( @@ -653,27 +667,29 @@ def __init__( norm_layer=norm_layer, ) self.context_length = context_length - self.cross_attn = nn.ModuleList([ - ResidualAttentionBlock( - width, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - is_cross_attention=True, - ) - for _ in range(layers) - ]) + self.cross_attn = nn.ModuleList( + [ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ] + ) - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) self.ln_final = norm_layer(width) self.text_projection = 
nn.Parameter(torch.empty(width, output_dim)) def init_parameters(self): - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) @@ -687,7 +703,7 @@ def init_parameters(self): nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) def build_attention_mask(self): # lazily create causal attention mask, with full attention between the tokens diff --git a/src/metr/open_clip/utils.py b/src/metr/open_clip/utils.py index 51e80c5..4266ba7 100644 --- a/src/metr/open_clip/utils.py +++ b/src/metr/open_clip/utils.py @@ -1,11 +1,11 @@ -from itertools import repeat import collections.abc +from itertools import repeat from torch import nn as nn from torchvision.ops.misc import FrozenBatchNorm2d -def freeze_batch_norm_2d(module, module_match={}, name=''): +def freeze_batch_norm_2d(module, module_match={}, name=""): """ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and @@ -37,7 +37,7 @@ def freeze_batch_norm_2d(module, module_match={}, name=''): res.eps = module.eps else: for child_name, child in module.named_children(): - full_child_name = '.'.join([name, child_name]) if name else child_name + full_child_name = ".".join([name, child_name]) if name else child_name new_child = freeze_batch_norm_2d(child, module_match, full_child_name) if new_child is not child: res.add_module(child_name, new_child) @@ -50,6 +50,7 @@ def parse(x): if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) + return parse diff --git a/src/metr/open_clip/version.py b/src/metr/open_clip/version.py index 48aa744..8f4a351 100644 --- a/src/metr/open_clip/version.py +++ b/src/metr/open_clip/version.py @@ -1 +1 @@ -__version__ = '2.16.0' +__version__ = "2.16.0" diff --git a/src/metr/optim_utils.py b/src/metr/optim_utils.py index 06b065d..b3f876c 100644 --- a/src/metr/optim_utils.py +++ b/src/metr/optim_utils.py @@ -1,21 +1,21 @@ -import torch -from torchvision import transforms -from datasets import load_dataset - -from PIL import Image, ImageFilter -import random -import numpy as np import copy -from typing import Any, Mapping import json +import random +from typing import Any, Mapping + +import numpy as np import scipy +import torch +from datasets import load_dataset +from PIL import Image, ImageFilter +from torchvision import transforms def read_json(filename: str) -> Mapping[str, Any]: """Returns a Python dict representation of JSON object at input file.""" with open(filename) as fp: return json.load(fp) - + def set_random_seed(seed=0): torch.manual_seed(seed + 0) @@ -58,10 +58,14 @@ def image_distortion(img1, img2, seed, args): if args.crop_scale is not None and args.crop_ratio is not None: set_random_seed(seed) - img1 = transforms.RandomResizedCrop(img1.size, scale=(args.crop_scale, args.crop_scale), ratio=(args.crop_ratio, args.crop_ratio))(img1) + img1 = transforms.RandomResizedCrop( + img1.size, 
scale=(args.crop_scale, args.crop_scale), ratio=(args.crop_ratio, args.crop_ratio) + )(img1) set_random_seed(seed) - img2 = transforms.RandomResizedCrop(img2.size, scale=(args.crop_scale, args.crop_scale), ratio=(args.crop_ratio, args.crop_ratio))(img2) - + img2 = transforms.RandomResizedCrop( + img2.size, scale=(args.crop_scale, args.crop_scale), ratio=(args.crop_ratio, args.crop_ratio) + )(img2) + if args.gaussian_blur_r is not None: img1 = img1.filter(ImageFilter.GaussianBlur(radius=args.gaussian_blur_r)) img2 = img2.filter(ImageFilter.GaussianBlur(radius=args.gaussian_blur_r)) @@ -89,25 +93,25 @@ def measure_similarity(images, prompt, model, clip_preprocess, tokenizer, device text = tokenizer([prompt]).to(device) text_features = model.encode_text(text) - + image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) - + return (image_features @ text_features.T).mean(-1) def get_dataset(args): - if 'laion' in args.dataset: - dataset = load_dataset(args.dataset)['train'] - prompt_key = 'TEXT' - elif 'coco' in args.dataset: - with open('fid_outputs/coco/meta_data.json') as f: + if "laion" in args.dataset: + dataset = load_dataset(args.dataset)["train"] + prompt_key = "TEXT" + elif "coco" in args.dataset: + with open("fid_outputs/coco/meta_data.json") as f: dataset = json.load(f) - dataset = dataset['annotations'] - prompt_key = 'caption' + dataset = dataset["annotations"] + prompt_key = "caption" else: - dataset = load_dataset(args.dataset)['test'] - prompt_key = 'Prompt' + dataset = load_dataset(args.dataset)["test"] + prompt_key = "Prompt" return dataset, prompt_key @@ -120,13 +124,13 @@ def circle_mask(size=64, r=10, x_offset=0, y_offset=0): y, x = np.ogrid[:size, :size] y = y[::-1] - return ((x - x0)**2 + (y-y0)**2)<= r**2 + return ((x - x0) ** 2 + (y - y0) ** 2) <= r**2 def get_watermarking_mask(init_latents_w, args, device): watermarking_mask = torch.zeros(init_latents_w.shape, dtype=torch.bool).to(device) - if args.w_mask_shape == 'circle': + if args.w_mask_shape == "circle": np_mask = circle_mask(init_latents_w.shape[-1], r=args.w_radius) torch_mask = torch.tensor(np_mask).to(device) @@ -135,28 +139,40 @@ def get_watermarking_mask(init_latents_w, args, device): watermarking_mask[:, :] = torch_mask else: watermarking_mask[:, args.w_channel] = torch_mask - elif args.w_mask_shape == 'square': + elif args.w_mask_shape == "square": anchor_p = init_latents_w.shape[-1] // 2 if args.w_channel == -1: # all channels - watermarking_mask[:, :, anchor_p-args.w_radius:anchor_p+args.w_radius, anchor_p-args.w_radius:anchor_p+args.w_radius] = True + watermarking_mask[ + :, + :, + anchor_p - args.w_radius : anchor_p + args.w_radius, + anchor_p - args.w_radius : anchor_p + args.w_radius, + ] = True else: - watermarking_mask[:, args.w_channel, anchor_p-args.w_radius:anchor_p+args.w_radius, anchor_p-args.w_radius:anchor_p+args.w_radius] = True - elif args.w_mask_shape == 'no': + watermarking_mask[ + :, + args.w_channel, + anchor_p - args.w_radius : anchor_p + args.w_radius, + anchor_p - args.w_radius : anchor_p + args.w_radius, + ] = True + elif args.w_mask_shape == "no": pass else: - raise NotImplementedError(f'w_mask_shape: {args.w_mask_shape}') + raise NotImplementedError(f"w_mask_shape: {args.w_mask_shape}") return watermarking_mask + class MsgError(Exception): "Raised, when len(args.msg) != args.w_radius" pass + def encrypt_message(gt_init, args, device, message): - ''' + """ Inserts given message into Fourier space of gaussian noise - ''' 
+ """ if args.use_random_msgs and (not message or len(message) != args.w_radius): raise MsgError("Message argument not passed or its length is not equal to radius ") @@ -171,22 +187,26 @@ def encrypt_message(gt_init, args, device, message): for i in range(args.w_radius, 0, -1): tmp_mask = circle_mask(gt_init.shape[-1], r=i) tmp_mask = torch.tensor(tmp_mask).to(device) - + for j in range(gt_patch.shape[1]): gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item() - + elif args.msg_type == "binary": if not args.use_random_msgs: print(f"NOT USING RANDOM MESSAGE, INSERTING: {args.msg}") - message_mat[args.w_channel] = list(map(lambda x: args.msg_scaler if x == "1" else -args.msg_scaler, list(args.msg))) + message_mat[args.w_channel] = list( + map(lambda x: args.msg_scaler if x == "1" else -args.msg_scaler, list(args.msg)) + ) else: print(f"USING RANDOM MESSAGE, INSERTING: {message}") - message_mat[args.w_channel] = list(map(lambda x: args.msg_scaler if x == "1" else -args.msg_scaler, list(message))) + message_mat[args.w_channel] = list( + map(lambda x: args.msg_scaler if x == "1" else -args.msg_scaler, list(message)) + ) gt_patch_tmp = copy.deepcopy(gt_patch) for i in range(args.w_radius, 0, -1): tmp_mask = circle_mask(gt_init.shape[-1], r=i) tmp_mask = torch.tensor(tmp_mask).to(device) - + for j in range(gt_patch.shape[1]): # iterate over channels gt_patch[:, j, tmp_mask] = message_mat[j][i - 1] @@ -197,55 +217,55 @@ def get_watermarking_pattern(pipe, args, device, shape=None, message=None): - ''' + """ Creates elements of gt_patch array - ''' + """ set_random_seed(args.w_seed) if shape is not None: gt_init = torch.randn(*shape, device=device) else: gt_init = pipe.get_random_latents() - if 'seed_ring' in args.w_pattern: + if "seed_ring" in args.w_pattern: gt_patch = gt_init gt_patch_tmp = copy.deepcopy(gt_patch) for i in range(args.w_radius, 0, -1): tmp_mask = circle_mask(gt_init.shape[-1], r=i) tmp_mask = torch.tensor(tmp_mask).to(device) - + for j in range(gt_patch.shape[1]): gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item() - elif 'seed_zeros' in args.w_pattern: + elif "seed_zeros" in args.w_pattern: gt_patch = gt_init * 0 - elif 'seed_rand' in args.w_pattern: + elif "seed_rand" in args.w_pattern: gt_patch = gt_init - elif 'rand' in args.w_pattern: + elif "rand" in args.w_pattern: gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) gt_patch[:] = gt_patch[0] - elif 'zeros' in args.w_pattern: + elif "zeros" in args.w_pattern: gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0 - elif 'const' in args.w_pattern: + elif "const" in args.w_pattern: gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0 gt_patch += args.w_pattern_const - elif 'ring' in args.w_pattern: + elif "ring" in args.w_pattern: gt_patch = encrypt_message(gt_init, args, device, message) return gt_patch def inject_watermark(init_latents_w, watermarking_mask, gt_patch, args): - ''' + """ Injects gt_patch elements into watermarking_mask indexes - ''' + """ init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(init_latents_w), dim=(-1, -2)) - if args.w_injection == 'complex': + if args.w_injection == "complex": init_latents_w_fft[watermarking_mask] = gt_patch[watermarking_mask].clone() - elif args.w_injection == 'seed': + elif args.w_injection == "seed": init_latents_w[watermarking_mask] = gt_patch[watermarking_mask].clone() return init_latents_w else: - NotImplementedError(f'w_injection: {args.w_injection}') +
raise NotImplementedError(f"w_injection: {args.w_injection}") init_latents_w = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real @@ -253,48 +273,53 @@ def eval_watermark(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args): - ''' + """ Compares values on watermarking_mask indexes in fourier space of image with gt_patch - ''' - if 'complex' in args.w_measurement: + """ + if "complex" in args.w_measurement: reversed_latents_no_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_no_w), dim=(-1, -2)) reversed_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_w), dim=(-1, -2)) target_patch = gt_patch - elif 'seed' in args.w_measurement: + elif "seed" in args.w_measurement: reversed_latents_no_w_fft = reversed_latents_no_w reversed_latents_w_fft = reversed_latents_w target_patch = gt_patch else: - NotImplementedError(f'w_measurement: {args.w_measurement}') + raise NotImplementedError(f"w_measurement: {args.w_measurement}") - if 'l1' in args.w_measurement: - no_w_metric = torch.abs(reversed_latents_no_w_fft[watermarking_mask] - target_patch[watermarking_mask]).mean().item() + if "l1" in args.w_measurement: + no_w_metric = ( + torch.abs(reversed_latents_no_w_fft[watermarking_mask] - target_patch[watermarking_mask]).mean().item() + ) w_metric = torch.abs(reversed_latents_w_fft[watermarking_mask] - target_patch[watermarking_mask]).mean().item() else: - NotImplementedError(f'w_measurement: {args.w_measurement}') + raise NotImplementedError(f"w_measurement: {args.w_measurement}") return no_w_metric, w_metric def detect_msg(reversed_latents_w, args): - ''' + """ Get predicted message from reversed_latents - ''' + """ pred_msg = [] r = args.w_radius channel = args.w_channel - if 'complex' in args.w_measurement: + if "complex" in args.w_measurement: reversed_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_w), dim=(-1, -2)) - elif 'seed' in args.w_measurement: + elif "seed" in args.w_measurement: reversed_latents_w_fft = reversed_latents_w else: - NotImplementedError(f'w_measurement: {args.w_measurement}') + raise NotImplementedError(f"w_measurement: {args.w_measurement}") for i in range(r, 0, -1): # Getting the edges of circles: if r > 1: - tmp_mask = (circle_mask(reversed_latents_w.shape[-1], r=i).astype(int) - circle_mask(reversed_latents_w.shape[-1], r=i - 1).astype(int)).astype(bool) + tmp_mask = ( + circle_mask(reversed_latents_w.shape[-1], r=i).astype(int) + - circle_mask(reversed_latents_w.shape[-1], r=i - 1).astype(int) + ).astype(bool) else: tmp_mask = circle_mask(reversed_latents_w.shape[-1], r=i) @@ -302,32 +327,38 @@ def detect_msg(reversed_latents_w, args): pred_msg.append((pred_circle_tmp_value > 0).to(int).item()) - return pred_msg[::-1] # Prediction is done from the biggest cirlce + return pred_msg[::-1] # Prediction is done from the biggest circle + def get_p_value(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args): # assume it's Fourier space wm - reversed_latents_no_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_no_w), dim=(-1, -2))[watermarking_mask].flatten() - reversed_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_w), dim=(-1, -2))[watermarking_mask].flatten() + reversed_latents_no_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_no_w), dim=(-1, -2))[ + watermarking_mask + ].flatten() + reversed_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents_w), dim=(-1,
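Together, circle_mask, get_watermarking_mask, get_watermarking_pattern, inject_watermark, eval_watermark and detect_msg implement the Fourier-domain ring watermark used throughout this package. The following standalone sketch replays the embed-and-detect loop on random latents only; the diffusion pipeline and the DDIM inversion step are deliberately left out, and the 1x4x64x64 latent shape, channel 3 and radius 10 are illustrative choices rather than values taken from this patch.

    import numpy as np
    import torch

    def toy_circle_mask(size=64, r=10):
        # same formula as circle_mask above
        x0 = y0 = size // 2
        y, x = np.ogrid[:size, :size]
        y = y[::-1]
        return ((x - x0) ** 2 + (y - y0) ** 2) <= r**2

    latents = torch.randn(1, 4, 64, 64)             # stand-in for init_latents_w
    clean = torch.randn(1, 4, 64, 64)               # latents of an unwatermarked image
    mask = torch.zeros(1, 4, 64, 64, dtype=torch.bool)
    mask[:, 3] = torch.tensor(toy_circle_mask())    # "circle" mask on a single channel

    # "complex" injection: overwrite the masked Fourier coefficients with a fixed key
    key = torch.fft.fftshift(torch.fft.fft2(torch.randn(1, 4, 64, 64)), dim=(-1, -2))
    fft = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))
    fft[mask] = key[mask]
    watermarked = torch.fft.ifft2(torch.fft.ifftshift(fft, dim=(-1, -2))).real

    # "l1_complex" measurement: the mean distance to the key inside the mask is
    # noticeably smaller for the watermarked latents than for unrelated ones
    def l1_metric(lat):
        f = torch.fft.fftshift(torch.fft.fft2(lat), dim=(-1, -2))
        return torch.abs(f[mask] - key[mask]).mean().item()

    print(l1_metric(watermarked), l1_metric(clean))

get_p_value (continued just below) turns the same comparison into a hypothesis test: under the no-watermark hypothesis the masked Fourier coefficients behave like independent Gaussians, so the scaled squared distance to the key follows a noncentral chi-square distribution. A toy one-dimensional version with synthetic vectors; the length 500 and the 0.1 noise level are arbitrary:

    import scipy.stats
    import torch

    key1d = torch.randn(500)                      # real+imag key coefficients, flattened
    observed_w = key1d + 0.1 * torch.randn(500)   # reversed coefficients, watermarked
    observed_no_w = torch.randn(500)              # reversed coefficients, clean image

    def p_value(observed, key):
        sigma = observed.std()
        nc = (key**2 / sigma**2).sum().item()               # noncentrality under H0
        x = (((observed - key) / sigma) ** 2).sum().item()
        return scipy.stats.ncx2.cdf(x, df=len(key), nc=nc)

    print(p_value(observed_w, key1d))     # ~0: far closer to the key than chance allows
    print(p_value(observed_no_w, key1d))  # O(1): consistent with "no watermark"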
-2))[ + watermarking_mask + ].flatten() target_patch = gt_patch[watermarking_mask].flatten() target_patch = torch.concatenate([target_patch.real, target_patch.imag]) - + # no_w reversed_latents_no_w_fft = torch.concatenate([reversed_latents_no_w_fft.real, reversed_latents_no_w_fft.imag]) sigma_no_w = reversed_latents_no_w_fft.std() - lambda_no_w = (target_patch ** 2 / sigma_no_w ** 2).sum().item() + lambda_no_w = (target_patch**2 / sigma_no_w**2).sum().item() x_no_w = (((reversed_latents_no_w_fft - target_patch) / sigma_no_w) ** 2).sum().item() p_no_w = scipy.stats.ncx2.cdf(x=x_no_w, df=len(target_patch), nc=lambda_no_w) # w reversed_latents_w_fft = torch.concatenate([reversed_latents_w_fft.real, reversed_latents_w_fft.imag]) sigma_w = reversed_latents_w_fft.std() - lambda_w = (target_patch ** 2 / sigma_w ** 2).sum().item() + lambda_w = (target_patch**2 / sigma_w**2).sum().item() x_w = (((reversed_latents_w_fft - target_patch) / sigma_w) ** 2).sum().item() p_w = scipy.stats.ncx2.cdf(x=x_w, df=len(target_patch), nc=lambda_w) return p_no_w, p_w + def compute_psnr(a, b): mse = torch.mean((a - b) ** 2).item() if mse == 0: @@ -336,16 +367,16 @@ def compute_psnr(a, b): def compute_msssim(a, b): - return ms_ssim(a, b, data_range=1.).item() + return ms_ssim(a, b, data_range=1.0).item() def compute_ssim(a, b): - return ssim(a, b, data_range=1.).item() + return ssim(a, b, data_range=1.0).item() def eval_psnr_ssim_msssim(ori_img_path, new_img_path): - ori_img = Image.open(ori_img_path).convert('RGB') - new_img = Image.open(new_img_path).convert('RGB') + ori_img = Image.open(ori_img_path).convert("RGB") + new_img = Image.open(new_img_path).convert("RGB") if ori_img.size != new_img.size: new_img = new_img.resize(ori_img.size) ori_x = transforms.ToTensor()(ori_img).unsqueeze(0) diff --git a/src/metr/pytorch_fid/__init__.py b/src/metr/pytorch_fid/__init__.py index 0404d81..493f741 100644 --- a/src/metr/pytorch_fid/__init__.py +++ b/src/metr/pytorch_fid/__init__.py @@ -1 +1 @@ -__version__ = '0.3.0' +__version__ = "0.3.0" diff --git a/src/metr/pytorch_fid/fid_score.py b/src/metr/pytorch_fid/fid_score.py index 86034e7..d186426 100644 --- a/src/metr/pytorch_fid/fid_score.py +++ b/src/metr/pytorch_fid/fid_score.py @@ -31,6 +31,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import pathlib from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser @@ -49,32 +50,36 @@ def tqdm(x): return x + try: from inception import InceptionV3 except: from pytorch_fid.inception import InceptionV3 parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) -parser.add_argument('--batch-size', type=int, default=50, - help='Batch size to use') -parser.add_argument('--num-workers', type=int, - help=('Number of processes to use for data loading. ' - 'Defaults to `min(8, num_cpus)`')) -parser.add_argument('--device', type=str, default=None, - help='Device to use. Like cuda, cuda:0 or cpu') -parser.add_argument('--dims', type=int, default=2048, - choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), - help=('Dimensionality of Inception features to use. ' - 'By default, uses pool3 features')) -parser.add_argument('--save-stats', action='store_true', - help=('Generate an npz archive from a directory of samples. 
' - 'The first path is used as input and the second as output.')) -parser.add_argument('path', type=str, nargs=2, - help=('Paths to the generated images or ' - 'to .npz statistic files')) - -IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', - 'tif', 'tiff', 'webp'} +parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use") +parser.add_argument( + "--num-workers", type=int, help=("Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`") +) +parser.add_argument("--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu") +parser.add_argument( + "--dims", + type=int, + default=2048, + choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), + help=("Dimensionality of Inception features to use. " "By default, uses pool3 features"), +) +parser.add_argument( + "--save-stats", + action="store_true", + help=( + "Generate an npz archive from a directory of samples. " + "The first path is used as input and the second as output." + ), +) +parser.add_argument("path", type=str, nargs=2, help=("Paths to the generated images or " "to .npz statistic files")) + +IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"} class ImagePathDataset(torch.utils.data.Dataset): @@ -87,14 +92,13 @@ def __len__(self): def __getitem__(self, i): path = self.files[i] - img = Image.open(path).convert('RGB') + img = Image.open(path).convert("RGB") if self.transforms is not None: img = self.transforms(img) return img -def get_activations(files, model, batch_size=50, dims=2048, device='cpu', - num_workers=1): +def get_activations(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1): """Calculates the activations of the pool_3 layer for all images. Params: @@ -117,21 +121,22 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', model.eval() if batch_size > len(files): - print(('Warning: batch size is bigger than the data size. ' - 'Setting batch size to data size')) + print(("Warning: batch size is bigger than the data size. 
" "Setting batch size to data size")) batch_size = len(files) # dataset = ImagePathDataset(files, transforms=TF.ToTensor()) - dataset = ImagePathDataset(files, - transforms=TF.Compose([ - TF.Resize((299,299)), - TF.ToTensor(), - ])) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=batch_size, - shuffle=False, - drop_last=False, - num_workers=num_workers) + dataset = ImagePathDataset( + files, + transforms=TF.Compose( + [ + TF.Resize((299, 299)), + TF.ToTensor(), + ] + ), + ) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers + ) pred_arr = np.empty((len(files), dims)) @@ -150,7 +155,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', pred = pred.squeeze(3).squeeze(2).cpu().numpy() - pred_arr[start_idx:start_idx + pred.shape[0]] = pred + pred_arr[start_idx : start_idx + pred.shape[0]] = pred start_idx = start_idx + pred.shape[0] @@ -185,18 +190,15 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): sigma1 = np.atleast_2d(sigma1) sigma2 = np.atleast_2d(sigma2) - assert mu1.shape == mu2.shape, \ - 'Training and test mean vectors have different lengths' - assert sigma1.shape == sigma2.shape, \ - 'Training and test covariances have different dimensions' + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" diff = mu1 - mu2 # Product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): - msg = ('fid calculation produces singular product; ' - 'adding %s to diagonal of cov estimates') % eps + msg = ("fid calculation produces singular product; " "adding %s to diagonal of cov estimates") % eps print(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) @@ -205,17 +207,15 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): m = np.max(np.abs(covmean.imag)) - raise ValueError('Imaginary component {}'.format(m)) + raise ValueError("Imaginary component {}".format(m)) covmean = covmean.real tr_covmean = np.trace(covmean) - return (diff.dot(diff) + np.trace(sigma1) - + np.trace(sigma2) - 2 * tr_covmean) + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean -def calculate_activation_statistics(files, model, batch_size=50, dims=2048, - device='cpu', num_workers=1): +def calculate_activation_statistics(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1): """Calculation of the statistics used by the FID. 
Params: -- files : List of image files paths @@ -239,17 +239,14 @@ def calculate_activation_statistics(files, model, batch_size=50, dims=2048, return mu, sigma -def compute_statistics_of_path(path, model, batch_size, dims, device, - num_workers=1): - if path.endswith('.npz'): +def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1): + if path.endswith(".npz"): with np.load(path) as f: - m, s = f['mu'][:], f['sigma'][:] + m, s = f["mu"][:], f["sigma"][:] else: path = pathlib.Path(path) - files = sorted([file for ext in IMAGE_EXTENSIONS - for file in path.glob('*.{}'.format(ext))]) - m, s = calculate_activation_statistics(files, model, batch_size, - dims, device, num_workers) + files = sorted([file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))]) + m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers) return m, s @@ -258,16 +255,14 @@ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): """Calculates the FID of two paths""" for p in paths: if not os.path.exists(p): - raise RuntimeError('Invalid path: %s' % p) + raise RuntimeError("Invalid path: %s" % p) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] model = InceptionV3([block_idx]).to(device) - m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, - dims, device, num_workers) - m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, - dims, device, num_workers) + m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers) + m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, dims, device, num_workers) fid_value = calculate_frechet_distance(m1, s1, m2, s2) return fid_value @@ -276,10 +271,10 @@ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): def save_fid_stats(paths, batch_size, device, dims, num_workers=1): """Calculates the FID of two paths""" if not os.path.exists(paths[0]): - raise RuntimeError('Invalid path: %s' % paths[0]) + raise RuntimeError("Invalid path: %s" % paths[0]) if os.path.exists(paths[1]): - raise RuntimeError('Existing output file: %s' % paths[1]) + raise RuntimeError("Existing output file: %s" % paths[1]) block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] @@ -287,8 +282,7 @@ def save_fid_stats(paths, batch_size, device, dims, num_workers=1): print(f"Saving statistics for {paths[0]}") - m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, - dims, device, num_workers) + m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers) np.savez_compressed(paths[1], mu=m1, sigma=s1) @@ -297,7 +291,7 @@ def main(): args = parser.parse_args() if args.device is None: - device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu") else: device = torch.device(args.device) @@ -318,13 +312,9 @@ def main(): save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers) return - fid_value = calculate_fid_given_paths(args.path, - args.batch_size, - device, - args.dims, - num_workers) - print('FID: ', fid_value) + fid_value = calculate_fid_given_paths(args.path, args.batch_size, device, args.dims, num_workers) + print("FID: ", fid_value) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/metr/pytorch_fid/inception.py b/src/metr/pytorch_fid/inception.py index 8898a20..034e96e 100644 --- a/src/metr/pytorch_fid/inception.py +++ b/src/metr/pytorch_fid/inception.py @@ -10,7 
+10,7 @@ # Inception weights ported to Pytorch from # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 +FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501 class InceptionV3(nn.Module): @@ -22,18 +22,20 @@ class InceptionV3(nn.Module): # Maps feature dimensionality to their output blocks indices BLOCK_INDEX_BY_DIM = { - 64: 0, # First max pooling features + 64: 0, # First max pooling features 192: 1, # Second max pooling featurs 768: 2, # Pre-aux classifier features - 2048: 3 # Final average pooling features + 2048: 3, # Final average pooling features } - def __init__(self, - output_blocks=(DEFAULT_BLOCK_INDEX,), - resize_input=True, - normalize_input=True, - requires_grad=False, - use_fid_inception=True): + def __init__( + self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True, + ): """Build pretrained InceptionV3 Parameters @@ -71,32 +73,27 @@ def __init__(self, self.output_blocks = sorted(output_blocks) self.last_needed_block = max(output_blocks) - assert self.last_needed_block <= 3, \ - 'Last possible output block index is 3' + assert self.last_needed_block <= 3, "Last possible output block index is 3" self.blocks = nn.ModuleList() if use_fid_inception: inception = fid_inception_v3() else: - inception = _inception_v3(weights='DEFAULT') + inception = _inception_v3(weights="DEFAULT") # Block 0: input to maxpool1 block0 = [ inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, - nn.MaxPool2d(kernel_size=3, stride=2) + nn.MaxPool2d(kernel_size=3, stride=2), ] self.blocks.append(nn.Sequential(*block0)) # Block 1: maxpool1 to maxpool2 if self.last_needed_block >= 1: - block1 = [ - inception.Conv2d_3b_1x1, - inception.Conv2d_4a_3x3, - nn.MaxPool2d(kernel_size=3, stride=2) - ] + block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)] self.blocks.append(nn.Sequential(*block1)) # Block 2: maxpool2 to aux classifier @@ -119,7 +116,7 @@ def __init__(self, inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, - nn.AdaptiveAvgPool2d(output_size=(1, 1)) + nn.AdaptiveAvgPool2d(output_size=(1, 1)), ] self.blocks.append(nn.Sequential(*block3)) @@ -144,10 +141,7 @@ def forward(self, inp): x = inp if self.resize_input: - x = F.interpolate(x, - size=(299, 299), - mode='bilinear', - align_corners=False) + x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False) if self.normalize_input: x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) @@ -166,7 +160,7 @@ def forward(self, inp): def _inception_v3(*args, **kwargs): """Wraps `torchvision.models.inception_v3`""" try: - version = tuple(map(int, torchvision.__version__.split('.')[:2])) + version = tuple(map(int, torchvision.__version__.split(".")[:2])) except ValueError: # Just a caution against weird version strings version = (0,) @@ -174,22 +168,20 @@ def _inception_v3(*args, **kwargs): # Skips default weight inititialization if supported by torchvision # version. See https://github.com/mseitzer/pytorch-fid/issues/28. 
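For context, the pieces above implement the standard FID recipe: pool3 activations from InceptionV3 are summarised by a per-folder mean and covariance, and calculate_frechet_distance then evaluates ||mu1 - mu2||^2 + Tr(S1 + S2 - 2*(S1*S2)^(1/2)). A minimal sketch of that closed-form distance on toy statistics; the 8-dimensional features and Gaussian samples are placeholders for real 2048-dimensional activations:

    import numpy as np
    from scipy import linalg

    rng = np.random.default_rng(0)
    d = 8                                          # tiny stand-in for 2048-d pool3 features
    feats1 = rng.normal(size=(1000, d))            # activations of the reference images
    feats2 = rng.normal(loc=0.1, size=(1000, d))   # activations of the generated images

    mu1, sigma1 = feats1.mean(axis=0), np.cov(feats1, rowvar=False)
    mu2, sigma2 = feats2.mean(axis=0), np.cov(feats2, rowvar=False)

    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    diff = mu1 - mu2
    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean.real)
    print(fid)

In this module the statistics come from calculate_activation_statistics, which feeds 299x299-resized images through InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]); calculate_fid_given_paths wires the two halves together for a pair of image folders or precomputed .npz statistics.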
if version >= (0, 6): - kwargs['init_weights'] = False + kwargs["init_weights"] = False # Backwards compatibility: `weights` argument was handled by `pretrained` # argument prior to version 0.13. - if version < (0, 13) and 'weights' in kwargs: - if kwargs['weights'] == 'DEFAULT': - kwargs['pretrained'] = True - elif kwargs['weights'] is None: - kwargs['pretrained'] = False + if version < (0, 13) and "weights" in kwargs: + if kwargs["weights"] == "DEFAULT": + kwargs["pretrained"] = True + elif kwargs["weights"] is None: + kwargs["pretrained"] = False else: raise ValueError( - 'weights=={} not supported in torchvision {}'.format( - kwargs['weights'], torchvision.__version__ - ) + "weights=={} not supported in torchvision {}".format(kwargs["weights"], torchvision.__version__) ) - del kwargs['weights'] + del kwargs["weights"] return torchvision.models.inception_v3(*args, **kwargs) @@ -203,9 +195,7 @@ def fid_inception_v3(): This method first constructs torchvision's Inception and then patches the necessary parts that are different in the FID Inception model. """ - inception = _inception_v3(num_classes=1008, - aux_logits=False, - weights=None) + inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None) inception.Mixed_5b = FIDInceptionA(192, pool_features=32) inception.Mixed_5c = FIDInceptionA(256, pool_features=64) inception.Mixed_5d = FIDInceptionA(288, pool_features=64) @@ -223,6 +213,7 @@ def fid_inception_v3(): class FIDInceptionA(torchvision.models.inception.InceptionA): """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): super(FIDInceptionA, self).__init__(in_channels, pool_features) @@ -238,8 +229,7 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] @@ -248,6 +238,7 @@ def forward(self, x): class FIDInceptionC(torchvision.models.inception.InceptionC): """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): super(FIDInceptionC, self).__init__(in_channels, channels_7x7) @@ -266,8 +257,7 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] @@ -276,6 +266,7 @@ def forward(self, x): class FIDInceptionE_1(torchvision.models.inception.InceptionE): """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): super(FIDInceptionE_1, self).__init__(in_channels) @@ -299,8 +290,7 @@ def forward(self, x): # Patch: Tensorflow's average pool does not use the padded zero's in # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, - count_include_pad=False) + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) branch_pool = self.branch_pool(branch_pool) outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] @@ -309,6 +299,7 @@ def forward(self, x): class 
FIDInceptionE_2(torchvision.models.inception.InceptionE): """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): super(FIDInceptionE_2, self).__init__(in_channels) diff --git a/src/metr/run_metr.py b/src/metr/run_metr.py index 8cc8763..aaefcb4 100644 --- a/src/metr/run_metr.py +++ b/src/metr/run_metr.py @@ -1,37 +1,40 @@ from PIL import Image, ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -import PIL -import sys -import torch -import os +import argparse +import copy import glob -import numpy as np +import os +import sys +from statistics import mean, stdev +import numpy as np +import PIL +import torch +from diffusers import DPMSolverMultistepScheduler +from sklearn import metrics +from tqdm import tqdm from wm_attacks import ReSDPipeline - from wm_attacks.wmattacker_no_saving import DiffWMAttacker, VAEWMAttacker -# ------------ - -import argparse import wandb -import copy -from tqdm import tqdm -from statistics import mean, stdev -from sklearn import metrics from .inverse_stable_diffusion import InversableStableDiffusionPipeline - -from diffusers import DPMSolverMultistepScheduler - -from .pytorch_fid.fid_score import * +from .io_utils import * from .open_clip import create_model_and_transforms, get_tokenizer - from .optim_utils import * -from .io_utils import * +from .pytorch_fid.fid_score import * from .stable_sig.utils_model import * +# ------------ + + + + + + + def main(args): if args.save_locally: if not os.path.exists(args.local_path) and not os.path.exists(args.local_path + f"/imgs_no_w/"): @@ -41,27 +44,36 @@ def main(args): table = None if args.with_tracking: - wandb.init(project=args.project_name, name=args.run_name, tags=['tree_ring_watermark']) + wandb.init(project=args.project_name, name=args.run_name, tags=["tree_ring_watermark"]) wandb.config.update(args) if args.use_attack: - columns = ['gen_no_w', 'no_w_clip_score', 'gen_w', 'w_clip_score', 'att_gen_w', 'prompt', 'no_w_metric', 'w_metric'] + columns = [ + "gen_no_w", + "no_w_clip_score", + "gen_w", + "w_clip_score", + "att_gen_w", + "prompt", + "no_w_metric", + "w_metric", + ] else: - columns = ['gen_no_w', 'no_w_clip_score', 'gen_w', 'w_clip_score', 'prompt', 'no_w_metric', 'w_metric'] + columns = ["gen_no_w", "no_w_clip_score", "gen_w", "w_clip_score", "prompt", "no_w_metric", "w_metric"] if args.use_random_msgs: - columns.append('message') + columns.append("message") table = wandb.Table(columns=columns) - + # load diffusion model - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_id, subfolder='scheduler') + device = "cuda" if torch.cuda.is_available() else "cpu" + + scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_id, subfolder="scheduler") pipe = InversableStableDiffusionPipeline.from_pretrained( args.model_id, scheduler=scheduler, torch_dtype=torch.float16, - revision='fp16', - ) + revision="fp16", + ) pipe = pipe.to(device) if not args.no_stable_sig: @@ -70,13 +82,15 @@ def main(args): # reference model if args.reference_model is not None: - ref_model, _, ref_clip_preprocess = create_model_and_transforms(args.reference_model, pretrained=args.reference_model_pretrain, device=device) + ref_model, _, ref_clip_preprocess = create_model_and_transforms( + args.reference_model, pretrained=args.reference_model_pretrain, device=device + ) ref_tokenizer = get_tokenizer(args.reference_model) # dataset dataset, prompt_key = get_dataset(args) - tester_prompt = '' # assume at the detection 
time, the original prompt is unknown + tester_prompt = "" # assume at the detection time, the original prompt is unknown text_embeddings = pipe.get_text_embedding(tester_prompt) # ground-truth patch @@ -94,27 +108,29 @@ def main(args): if args.use_attack: if args.attack_type == "diff": - attack_pipe = ReSDPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, revision="fp16") + attack_pipe = ReSDPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, revision="fp16" + ) attack_pipe.set_progress_bar_config(disable=True) attack_pipe.to(device) attacker = DiffWMAttacker(attack_pipe, noise_step=args.diff_attack_steps) if args.attack_type == "vae": - attacker = VAEWMAttacker(args.vae_attack_name, quality=args.vae_attack_quality, metric='mse', device=device) + attacker = VAEWMAttacker(args.vae_attack_name, quality=args.vae_attack_quality, metric="mse", device=device) for i in tqdm(range(args.start, args.end)): seed = i + args.gen_seed - + current_prompt = dataset[i][prompt_key] if args.given_prompt: current_prompt = args.given_prompt if args.use_random_msgs: msg_key = torch.randint(0, 2, (1, args.w_radius), dtype=torch.float32, device="cpu") - msg_str = "".join([ str(int(ii)) for ii in msg_key.tolist()[0]]) + msg_str = "".join([str(int(ii)) for ii in msg_key.tolist()[0]]) if args.use_random_msgs: gt_patch = get_watermarking_pattern(pipe, args, device, message=msg_str) - + ### generation # generation without watermarking set_random_seed(seed) @@ -127,9 +143,9 @@ def main(args): height=args.image_length, width=args.image_length, latents=init_latents_no_w, - ) + ) orig_image_no_w = outputs_no_w.images[0] - + # generation with watermarking if init_latents_no_w is None: set_random_seed(seed) @@ -151,7 +167,7 @@ def main(args): height=args.image_length, width=args.image_length, latents=init_latents_w, - ) + ) orig_image_w = outputs_w.images[0] ### test watermark @@ -192,8 +208,10 @@ def main(args): ) # eval - no_w_metric, w_metric = eval_watermark(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args) - + no_w_metric, w_metric = eval_watermark( + reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args + ) + if args.save_rev_lat: rev_lat_path = f"{args.path_rev_lat}/s_{args.msg_scaler}_r_{args.w_radius}" if not os.path.exists(rev_lat_path): @@ -219,18 +237,24 @@ def main(args): if correct_bits_tmp == args.w_radius: words_right += 1 - if args.reference_model is not None: - sims = measure_similarity([orig_image_no_w, orig_image_w], current_prompt, ref_model, ref_clip_preprocess, ref_tokenizer, device) + sims = measure_similarity( + [orig_image_no_w, orig_image_w], current_prompt, ref_model, ref_clip_preprocess, ref_tokenizer, device + ) w_no_sim = sims[0].item() w_sim = sims[1].item() else: w_no_sim = 0 w_sim = 0 - results.append({ - 'no_w_metric': no_w_metric, 'w_metric': w_metric, 'w_no_sim': w_no_sim, 'w_sim': w_sim, - }) + results.append( + { + "no_w_metric": no_w_metric, + "w_metric": w_metric, + "w_no_sim": w_no_sim, + "w_sim": w_sim, + } + ) no_w_metrics.append(-no_w_metric) w_metrics.append(-w_metric) @@ -243,9 +267,26 @@ def main(args): if (args.reference_model is not None) and (i < args.max_num_log_image): # log images when we use reference_model if args.use_attack: - data_to_add = [wandb.Image(orig_image_no_w), w_no_sim, wandb.Image(orig_image_w), w_sim, wandb.Image(att_img_w), current_prompt, no_w_metric, w_metric] + data_to_add = [ + wandb.Image(orig_image_no_w), + w_no_sim, + 
wandb.Image(orig_image_w), + w_sim, + wandb.Image(att_img_w), + current_prompt, + no_w_metric, + w_metric, + ] else: - data_to_add = [wandb.Image(orig_image_no_w), w_no_sim, wandb.Image(orig_image_w), w_sim, current_prompt, no_w_metric, w_metric] + data_to_add = [ + wandb.Image(orig_image_no_w), + w_no_sim, + wandb.Image(orig_image_w), + w_sim, + current_prompt, + no_w_metric, + w_metric, + ] else: if args.use_attack: data_to_add = [None, w_no_sim, None, w_sim, None, current_prompt, no_w_metric, w_metric] @@ -267,18 +308,24 @@ def main(args): fpr, tpr, thresholds = metrics.roc_curve(t_labels, preds, pos_label=1) auc = metrics.auc(fpr, tpr) - acc = np.max(1 - (fpr + (1 - tpr))/2) - low = tpr[np.where(fpr<.01)[0][-1]] + acc = np.max(1 - (fpr + (1 - tpr)) / 2) + low = tpr[np.where(fpr < 0.01)[0][-1]] if args.with_tracking: - wandb.log({'Table': table}) + wandb.log({"Table": table}) if (i - args.start) > 0: metrics_dict = { - 'clip_score_mean': mean(clip_scores), 'clip_score_std': stdev(clip_scores), - 'w_clip_score_mean': mean(clip_scores_w), 'w_clip_score_std': stdev(clip_scores_w), - 'auc': auc, 'acc':acc, 'TPR@1%FPR': low, - 'w_det_dist_mean': -mean(w_metrics), 'w_det_dist_std': stdev(w_metrics), - 'no_w_det_dist_mean': -mean(no_w_metrics), 'no_w_det_dist_std': stdev(no_w_metrics), + "clip_score_mean": mean(clip_scores), + "clip_score_std": stdev(clip_scores), + "w_clip_score_mean": mean(clip_scores_w), + "w_clip_score_std": stdev(clip_scores_w), + "auc": auc, + "acc": acc, + "TPR@1%FPR": low, + "w_det_dist_mean": -mean(w_metrics), + "w_det_dist_std": stdev(w_metrics), + "no_w_det_dist_mean": -mean(no_w_metrics), + "no_w_det_dist_std": stdev(no_w_metrics), } if args.msg_type == "binary": metrics_dict["Bit_acc"] = mean(bit_accs) @@ -286,85 +333,85 @@ def main(args): if (i - args.start) > 0: wandb.log(metrics_dict) - - print(f'clip_score_mean: {mean(clip_scores)}') - print(f'w_clip_score_mean: {mean(clip_scores_w)}') - print(f'auc: {auc}, acc: {acc}, TPR@1%FPR: {low}') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='diffusion watermark') - parser.add_argument('--project_name', default='watermark_attacks') - parser.add_argument('--run_name', default='test') - parser.add_argument('--dataset', default='Gustavosta/Stable-Diffusion-Prompts') - parser.add_argument('--start', default=0, type=int) - parser.add_argument('--end', default=10, type=int) - parser.add_argument('--image_length', default=512, type=int) - parser.add_argument('--model_id', default='stabilityai/stable-diffusion-2-1-base') - parser.add_argument('--with_tracking', action='store_true') + + print(f"clip_score_mean: {mean(clip_scores)}") + print(f"w_clip_score_mean: {mean(clip_scores_w)}") + print(f"auc: {auc}, acc: {acc}, TPR@1%FPR: {low}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="diffusion watermark") + parser.add_argument("--project_name", default="watermark_attacks") + parser.add_argument("--run_name", default="test") + parser.add_argument("--dataset", default="Gustavosta/Stable-Diffusion-Prompts") + parser.add_argument("--start", default=0, type=int) + parser.add_argument("--end", default=10, type=int) + parser.add_argument("--image_length", default=512, type=int) + parser.add_argument("--model_id", default="stabilityai/stable-diffusion-2-1-base") + parser.add_argument("--with_tracking", action="store_true") # logs and metrics: - parser.add_argument('--freq_log', default=20, type=int) - parser.add_argument('--save_locally', action='store_true') - 
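The ROC block above condenses the two detection distances into three summary numbers: AUC, the best threshold accuracy, and TPR@1%FPR. A self-contained sketch of the same computation on synthetic detection scores; the two Gaussian score distributions are made up purely for illustration:

    import numpy as np
    from sklearn import metrics

    rng = np.random.default_rng(0)
    # higher score = "looks watermarked"; labels: 0 = clean, 1 = watermarked
    preds = np.concatenate([rng.normal(0.0, 1.0, 500), rng.normal(2.0, 1.0, 500)])
    t_labels = np.concatenate([np.zeros(500), np.ones(500)])

    fpr, tpr, thresholds = metrics.roc_curve(t_labels, preds, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    acc = np.max(1 - (fpr + (1 - tpr)) / 2)   # best balanced accuracy over all thresholds
    low = tpr[np.where(fpr < 0.01)[0][-1]]    # TPR at the largest FPR still below 1%
    print(auc, acc, low)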
parser.add_argument('--local_path', default='generated_images') - - parser.add_argument('--num_images', default=1, type=int) - parser.add_argument('--guidance_scale', default=7.5, type=float) - parser.add_argument('--num_inference_steps', default=40, type=int) - parser.add_argument('--test_num_inference_steps', default=None, type=int) - parser.add_argument('--reference_model', default=None) - parser.add_argument('--reference_model_pretrain', default=None) - parser.add_argument('--max_num_log_image', default=100, type=int) - parser.add_argument('--gen_seed', default=0, type=int) + parser.add_argument("--freq_log", default=20, type=int) + parser.add_argument("--save_locally", action="store_true") + parser.add_argument("--local_path", default="generated_images") + + parser.add_argument("--num_images", default=1, type=int) + parser.add_argument("--guidance_scale", default=7.5, type=float) + parser.add_argument("--num_inference_steps", default=40, type=int) + parser.add_argument("--test_num_inference_steps", default=None, type=int) + parser.add_argument("--reference_model", default=None) + parser.add_argument("--reference_model_pretrain", default=None) + parser.add_argument("--max_num_log_image", default=100, type=int) + parser.add_argument("--gen_seed", default=0, type=int) # watermark - parser.add_argument('--w_seed', default=999999, type=int) - parser.add_argument('--w_channel', default=3, type=int) - parser.add_argument('--w_pattern', default='ring') - parser.add_argument('--w_mask_shape', default='circle') - parser.add_argument('--w_radius', default=10, type=int) - parser.add_argument('--w_measurement', default='l1_complex') - parser.add_argument('--w_injection', default='complex') - parser.add_argument('--w_pattern_const', default=0, type=float) - + parser.add_argument("--w_seed", default=999999, type=int) + parser.add_argument("--w_channel", default=3, type=int) + parser.add_argument("--w_pattern", default="ring") + parser.add_argument("--w_mask_shape", default="circle") + parser.add_argument("--w_radius", default=10, type=int) + parser.add_argument("--w_measurement", default="l1_complex") + parser.add_argument("--w_injection", default="complex") + parser.add_argument("--w_pattern_const", default=0, type=float) + # for image distortion - parser.add_argument('--r_degree', default=None, type=float) - parser.add_argument('--jpeg_ratio', default=None, type=int) - parser.add_argument('--crop_scale', default=None, type=float) - parser.add_argument('--crop_ratio', default=None, type=float) - parser.add_argument('--gaussian_blur_r', default=None, type=int) - parser.add_argument('--gaussian_std', default=None, type=float) - parser.add_argument('--brightness_factor', default=None, type=float) - parser.add_argument('--rand_aug', default=0, type=int) + parser.add_argument("--r_degree", default=None, type=float) + parser.add_argument("--jpeg_ratio", default=None, type=int) + parser.add_argument("--crop_scale", default=None, type=float) + parser.add_argument("--crop_ratio", default=None, type=float) + parser.add_argument("--gaussian_blur_r", default=None, type=int) + parser.add_argument("--gaussian_std", default=None, type=float) + parser.add_argument("--brightness_factor", default=None, type=float) + parser.add_argument("--rand_aug", default=0, type=int) # VAE or Diff attack - parser.add_argument('--use_attack', action='store_true') - parser.add_argument('--attack_type', default='diff') - parser.add_argument('--use_attack_prompt', action='store_true') - parser.add_argument('--diff_attack_steps', 
default=60, type=int) - parser.add_argument('--vae_attack_name', default='cheng2020-anchor') - parser.add_argument('--vae_attack_quality', default=3, type=int) + parser.add_argument("--use_attack", action="store_true") + parser.add_argument("--attack_type", default="diff") + parser.add_argument("--use_attack_prompt", action="store_true") + parser.add_argument("--diff_attack_steps", default=60, type=int) + parser.add_argument("--vae_attack_name", default="cheng2020-anchor") + parser.add_argument("--vae_attack_quality", default=3, type=int) # METR++ - parser.add_argument('--decoder_state_dict_path', default='finetune_ldm_decoder/ldm_decoder_checkpoint_000.pth') - parser.add_argument('--no_stable_sig', action='store_true') - parser.add_argument('--stable_sig_full_model_config', default="v2-inference.yaml") - parser.add_argument('--stable_sig_full_model_ckpt', default='v2-1_512-ema-pruned.ckpt') + parser.add_argument("--decoder_state_dict_path", default="finetune_ldm_decoder/ldm_decoder_checkpoint_000.pth") + parser.add_argument("--no_stable_sig", action="store_true") + parser.add_argument("--stable_sig_full_model_config", default="v2-inference.yaml") + parser.add_argument("--stable_sig_full_model_ckpt", default="v2-1_512-ema-pruned.ckpt") # Message encryption (for testing: putting the same message on each image, but they can be different): - parser.add_argument('--msg_type', default='rand', help="Can be: rand or binary or decimal") - parser.add_argument('--msg', default='1110101101') - parser.add_argument('--use_random_msgs', action='store_true', help="Generate random message each step of cycle") - parser.add_argument('--msg_scaler', default=100, type=int, help="Scaling coefficient of message") + parser.add_argument("--msg_type", default="rand", help="Can be: rand or binary or decimal") + parser.add_argument("--msg", default="1110101101") + parser.add_argument("--use_random_msgs", action="store_true", help="Generate random message each step of cycle") + parser.add_argument("--msg_scaler", default=100, type=int, help="Scaling coefficient of message") # For testing - parser.add_argument('--given_prompt', default=None, type=str) - parser.add_argument('--save_rev_lat', action='store_true', help="Flag to save reversed latents") - parser.add_argument('--path_rev_lat', default=None, type=str) + parser.add_argument("--given_prompt", default=None, type=str) + parser.add_argument("--save_rev_lat", action="store_true", help="Flag to save reversed latents") + parser.add_argument("--path_rev_lat", default=None, type=str) args = parser.parse_args() if args.test_num_inference_steps is None: args.test_num_inference_steps = args.num_inference_steps - - main(args) \ No newline at end of file + + main(args) diff --git a/src/metr/run_metr_fid.py b/src/metr/run_metr_fid.py index f56a0cb..10a306f 100644 --- a/src/metr/run_metr_fid.py +++ b/src/metr/run_metr_fid.py @@ -1,34 +1,32 @@ import argparse -import wandb import copy -from tqdm import tqdm +import glob import json +import math +import os +import sys +import numpy as np import PIL -import sys import torch -import os -import glob -import numpy as np - -from .inverse_stable_diffusion import InversableStableDiffusionPipeline from diffusers import DPMSolverMultistepScheduler -from .optim_utils import * -from .io_utils import * -from .pytorch_fid.fid_score import * - +from PIL import Image, ImageFile +from pytorch_msssim import ssim +from tqdm import tqdm from wm_attacks import ReSDPipeline - from wm_attacks.wmattacker_with_saving import DiffWMAttacker, 
VAEWMAttacker -import math -from pytorch_msssim import ssim +import wandb +from .inverse_stable_diffusion import InversableStableDiffusionPipeline +from .io_utils import * +from .optim_utils import * +from .pytorch_fid.fid_score import * from .stable_sig.utils_model import * -from PIL import Image, ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True + def compute_psnr(a, b): mse = torch.mean((a - b) ** 2).item() if mse == 0: @@ -37,16 +35,16 @@ def compute_psnr(a, b): def compute_msssim(a, b): - return ms_ssim(a, b, data_range=1.).item() + return ms_ssim(a, b, data_range=1.0).item() def compute_ssim(a, b): - return ssim(a, b, data_range=1.).item() + return ssim(a, b, data_range=1.0).item() def eval_psnr_ssim_msssim(ori_img_path, new_img_path): - ori_img = Image.open(ori_img_path).convert('RGB') - new_img = Image.open(new_img_path).convert('RGB') + ori_img = Image.open(ori_img_path).convert("RGB") + new_img = Image.open(new_img_path).convert("RGB") if ori_img.size != new_img.size: new_img = new_img.resize(ori_img.size) ori_x = transforms.ToTensor()(ori_img).unsqueeze(0) @@ -57,20 +55,20 @@ def eval_psnr_ssim_msssim(ori_img_path, new_img_path): def main(args): table = None if args.with_tracking: - wandb.init(project=args.project_name, name=args.run_name, tags=['tree_ring_watermark_fid']) + wandb.init(project=args.project_name, name=args.run_name, tags=["tree_ring_watermark_fid"]) wandb.config.update(args) - table = wandb.Table(columns=['gen_no_w', 'gen_w', 'prompt']) - + table = wandb.Table(columns=["gen_no_w", "gen_w", "prompt"]) + # load diffusion model - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_id, subfolder='scheduler') + device = "cuda" if torch.cuda.is_available() else "cpu" + + scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_id, subfolder="scheduler") pipe = InversableStableDiffusionPipeline.from_pretrained( args.model_id, scheduler=scheduler, torch_dtype=torch.float16, - revision='fp16', - ) + revision="fp16", + ) pipe = pipe.to(device) if args.use_stable_sig: @@ -80,16 +78,16 @@ def main(args): # hard coding for now with open(args.prompt_file) as f: dataset = json.load(f) - image_files = dataset['images'] - dataset = dataset['annotations'] - prompt_key = 'caption' - - no_w_dir = args.image_folder + '/no_w_gen' - w_dir = args.image_folder + '/w_gen' - if args.attack_type == 'diff': - att_w_dir = args.image_folder + f'/diff_{args.diff_attack_steps}' - if args.attack_type == 'vae': - att_w_dir = args.image_folder + f'/{args.vae_attack_name}_{args.vae_attack_quality}' + image_files = dataset["images"] + dataset = dataset["annotations"] + prompt_key = "caption" + + no_w_dir = args.image_folder + "/no_w_gen" + w_dir = args.image_folder + "/w_gen" + if args.attack_type == "diff": + att_w_dir = args.image_folder + f"/diff_{args.diff_attack_steps}" + if args.attack_type == "vae": + att_w_dir = args.image_folder + f"/{args.vae_attack_name}_{args.vae_attack_quality}" os.makedirs(no_w_dir, exist_ok=True) os.makedirs(w_dir, exist_ok=True) @@ -100,16 +98,16 @@ def main(args): if args.run_generation: for i in tqdm(range(args.start, args.end)): seed = i + args.gen_seed - + current_prompt = dataset[i][prompt_key] if args.use_random_msgs: msg_key = torch.randint(0, 2, (1, args.w_radius), dtype=torch.float32, device="cpu") - msg_str = "".join([ str(int(ii)) for ii in msg_key.tolist()[0]]) + msg_str = "".join([str(int(ii)) for ii in msg_key.tolist()[0]]) if args.use_random_msgs: gt_patch = 
get_watermarking_pattern(pipe, args, device, message=msg_str) - + ### generation # generation without watermarking set_random_seed(seed) @@ -124,11 +122,11 @@ def main(args): height=args.image_length, width=args.image_length, latents=init_latents_no_w, - ) + ) orig_image_no_w = outputs_no_w.images[0] else: orig_image_no_w = None - + # generation with watermarking if init_latents_no_w is None: set_random_seed(seed) @@ -140,7 +138,7 @@ def main(args): watermarking_mask = get_watermarking_mask(init_latents_w, args, device) # inject watermark - init_latents_w = inject_watermark(init_latents_w, watermarking_mask,gt_patch, args) + init_latents_w = inject_watermark(init_latents_w, watermarking_mask, gt_patch, args) outputs_w = pipe( current_prompt, @@ -150,7 +148,7 @@ def main(args): height=args.image_length, width=args.image_length, latents=init_latents_w, - ) + ) orig_image_w = outputs_w.images[0] # distortion @@ -165,11 +163,11 @@ def main(args): table.add_data(None, wandb.Image(orig_image_w), current_prompt) else: table.add_data(None, None, current_prompt) - - image_file_name = image_files[i]['file_name'] + + image_file_name = image_files[i]["file_name"] if args.run_no_w: - orig_image_no_w.save(f'{no_w_dir}/{image_file_name}') - orig_image_w.save(f'{w_dir}/{image_file_name}') + orig_image_no_w.save(f"{no_w_dir}/{image_file_name}") + orig_image_w.save(f"{w_dir}/{image_file_name}") ### calculate fid try: @@ -179,17 +177,21 @@ def main(args): num_workers = min(num_cpus, 8) if num_cpus is not None else 0 - ori_img_paths = glob.glob(os.path.join(no_w_dir, '*.*')) - ori_img_paths = sorted([path for path in ori_img_paths if path.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif'))]) + ori_img_paths = glob.glob(os.path.join(no_w_dir, "*.*")) + ori_img_paths = sorted( + [path for path in ori_img_paths if path.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))] + ) if args.use_attack: if args.attack_type == "diff": - attack_pipe = ReSDPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, revision="fp16") + attack_pipe = ReSDPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, revision="fp16" + ) attack_pipe.set_progress_bar_config(disable=True) attack_pipe.to(device) attacker = DiffWMAttacker(attack_pipe, noise_step=args.diff_attack_steps, batch_size=5, captions={}) if args.attack_type == "vae": - attacker = VAEWMAttacker(args.vae_attack_name, quality=args.vae_attack_quality, metric='mse', device=device) + attacker = VAEWMAttacker(args.vae_attack_name, quality=args.vae_attack_quality, metric="mse", device=device) wm_img_paths = [] att_img_paths = [] @@ -199,7 +201,7 @@ def main(args): wm_img_paths.append(os.path.join(w_dir, img_name)) att_img_paths.append(os.path.join(att_w_dir, img_name)) # using attacker on whole folder: we are using from with_saving file - attacker.attack(wm_img_paths, att_img_paths) + attacker.attack(wm_img_paths, att_img_paths) # fid for no_w target_folder = args.gt_folder @@ -207,28 +209,16 @@ def main(args): target_folder = no_w_dir if args.run_no_w: - fid_value_no_w = calculate_fid_given_paths([target_folder, no_w_dir], - 50, - device, - 2048, - num_workers) + fid_value_no_w = calculate_fid_given_paths([target_folder, no_w_dir], 50, device, 2048, num_workers) else: fid_value_no_w = None # fid for w - fid_value_w = calculate_fid_given_paths([target_folder, w_dir], - 50, - device, - 2048, - num_workers) + fid_value_w = calculate_fid_given_paths([target_folder, w_dir], 50, device, 2048, 
num_workers) # fid for att_w if args.use_attack: - fid_value_w_att = calculate_fid_given_paths([target_folder, att_w_dir], - 50, - device, - 2048, - num_workers) + fid_value_w_att = calculate_fid_given_paths([target_folder, att_w_dir], 50, device, 2048, num_workers) # psnr and ssim if args.additional_metrics: @@ -236,7 +226,7 @@ # w_dir is the wm_path # att_w_dir is the att_path clean_psnr_list = [] - clean_ssim_list = [] + clean_ssim_list = [] wm_psnr_list = [] wm_ssim_list = [] @@ -272,88 +262,88 @@ def main(args): att_ssim = np.array(att_ssim_list).mean() if args.with_tracking: - wandb.log({'Table': table}) - metrics_table = {'fid_no_w': fid_value_no_w, 'fid_w': fid_value_w} + wandb.log({"Table": table}) + metrics_table = {"fid_no_w": fid_value_no_w, "fid_w": fid_value_w} if args.use_attack: - metrics_table['fid_att_w'] = fid_value_w_att + metrics_table["fid_att_w"] = fid_value_w_att if args.additional_metrics: - metrics_table['psnr_no_w'] = clean_psnr - metrics_table['ssim_no_w'] = clean_ssim - metrics_table['psnr_w'] = wm_psnr - metrics_table['ssim_w'] = wm_ssim + metrics_table["psnr_no_w"] = clean_psnr + metrics_table["ssim_no_w"] = clean_ssim + metrics_table["psnr_w"] = wm_psnr + metrics_table["ssim_w"] = wm_ssim if args.use_attack: - metrics_table['psnr_att_w'] = att_psnr - metrics_table['ssim_att_w'] = att_ssim + metrics_table["psnr_att_w"] = att_psnr + metrics_table["ssim_att_w"] = att_ssim wandb.log(metrics_table) - print(f'fid_no_w: {fid_value_no_w}, fid_w: {fid_value_w}') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='diffusion watermark') - parser.add_argument('--project_name', default='watermark_attacks') - parser.add_argument('--run_name', default='test') - parser.add_argument('--start', default=0, type=int) - parser.add_argument('--end', default=10, type=int) - parser.add_argument('--image_length', default=512, type=int) - parser.add_argument('--model_id', default='stabilityai/stable-diffusion-2-1-base') - parser.add_argument('--with_tracking', action='store_true') - parser.add_argument('--num_images', default=1, type=int) - parser.add_argument('--guidance_scale', default=7.5, type=float) - parser.add_argument('--num_inference_steps', default=40, type=int) - parser.add_argument('--max_num_log_image', default=100, type=int) - parser.add_argument('--run_no_w', action='store_true') - parser.add_argument('--gen_seed', default=0, type=int) - - parser.add_argument('--prompt_file', default='fid_outputs/coco/meta_data.json') - parser.add_argument('--gt_folder', default='fid_outputs/coco/ground_truth') - parser.add_argument('--image_folder', default='fid_outputs/coco/fid_run') + print(f"fid_no_w: {fid_value_no_w}, fid_w: {fid_value_w}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="diffusion watermark") + parser.add_argument("--project_name", default="watermark_attacks") + parser.add_argument("--run_name", default="test") + parser.add_argument("--start", default=0, type=int) + parser.add_argument("--end", default=10, type=int) + parser.add_argument("--image_length", default=512, type=int) + parser.add_argument("--model_id", default="stabilityai/stable-diffusion-2-1-base") + parser.add_argument("--with_tracking", action="store_true") + parser.add_argument("--num_images", default=1, type=int) + parser.add_argument("--guidance_scale", default=7.5, type=float) + parser.add_argument("--num_inference_steps", default=40, type=int) + parser.add_argument("--max_num_log_image", default=100, type=int) +
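The block above averages compute_psnr, compute_ssim and the FID values into the logged metrics. As a quick reference for the PSNR numbers: with images mapped to [0, 1] by transforms.ToTensor(), the usual formula reduces to PSNR = -10 * log10(MSE), so an average pixel error around 0.01 corresponds to roughly 40 dB. A short sanity check; the perturbation size is arbitrary:

    import math

    import torch

    a = torch.rand(1, 3, 64, 64)
    b = (a + 0.01).clamp(0, 1)         # small, uniform perturbation
    mse = torch.mean((a - b) ** 2).item()
    print(mse, -10 * math.log10(mse))  # roughly 1e-4 and ~40 dB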
parser.add_argument("--run_no_w", action="store_true") + parser.add_argument("--gen_seed", default=0, type=int) + + parser.add_argument("--prompt_file", default="fid_outputs/coco/meta_data.json") + parser.add_argument("--gt_folder", default="fid_outputs/coco/ground_truth") + parser.add_argument("--image_folder", default="fid_outputs/coco/fid_run") # Compute metrics with gen_no_w and gen_w: - parser.add_argument('--target_clean_generated', action='store_true') + parser.add_argument("--target_clean_generated", action="store_true") - parser.add_argument('--run_generation', action='store_true') - parser.add_argument('--additional_metrics', action='store_true') + parser.add_argument("--run_generation", action="store_true") + parser.add_argument("--additional_metrics", action="store_true") # watermark - parser.add_argument('--w_seed', default=999999, type=int) - parser.add_argument('--w_channel', default=3, type=int) - parser.add_argument('--w_pattern', default='ring') - parser.add_argument('--w_mask_shape', default='circle') - parser.add_argument('--w_radius', default=10, type=int) - parser.add_argument('--w_measurement', default='l1_complex') - parser.add_argument('--w_injection', default='complex') - parser.add_argument('--w_pattern_const', default=0, type=float) + parser.add_argument("--w_seed", default=999999, type=int) + parser.add_argument("--w_channel", default=3, type=int) + parser.add_argument("--w_pattern", default="ring") + parser.add_argument("--w_mask_shape", default="circle") + parser.add_argument("--w_radius", default=10, type=int) + parser.add_argument("--w_measurement", default="l1_complex") + parser.add_argument("--w_injection", default="complex") + parser.add_argument("--w_pattern_const", default=0, type=float) # VAE or Diff attack - parser.add_argument('--use_attack', action='store_true') - parser.add_argument('--attack_type', default='diff') - parser.add_argument('--use_attack_prompt', action='store_true') - parser.add_argument('--diff_attack_steps', default=60, type=int) - parser.add_argument('--vae_attack_name', default='cheng2020-anchor') - parser.add_argument('--vae_attack_quality', default=3, type=int) + parser.add_argument("--use_attack", action="store_true") + parser.add_argument("--attack_type", default="diff") + parser.add_argument("--use_attack_prompt", action="store_true") + parser.add_argument("--diff_attack_steps", default=60, type=int) + parser.add_argument("--vae_attack_name", default="cheng2020-anchor") + parser.add_argument("--vae_attack_quality", default=3, type=int) # Message encryption (for testing: putting the same message on each image, but they can be different): - parser.add_argument('--msg_type', default='rand', help="Can be: rand or binary or decimal") - parser.add_argument('--msg', default='1110101101') - parser.add_argument('--use_random_msgs', action='store_true', help="Generate random message each step of cycle") - parser.add_argument('--msg_scaler', default=100, type=int, help="Scaling coefficient of message") + parser.add_argument("--msg_type", default="rand", help="Can be: rand or binary or decimal") + parser.add_argument("--msg", default="1110101101") + parser.add_argument("--use_random_msgs", action="store_true", help="Generate random message each step of cycle") + parser.add_argument("--msg_scaler", default=100, type=int, help="Scaling coefficient of message") # METR++: - parser.add_argument('--use_stable_sig', action='store_true') - parser.add_argument('--decoder_state_dict_path', 
default='finetune_ldm_decoder/ldm_decoder_checkpoint_000.pth') - parser.add_argument('--stable_sig_full_model_config', default="v2-inference.yaml") - parser.add_argument('--stable_sig_full_model_ckpt', default='v2-1_512-ema-pruned.ckpt') + parser.add_argument("--use_stable_sig", action="store_true") + parser.add_argument("--decoder_state_dict_path", default="finetune_ldm_decoder/ldm_decoder_checkpoint_000.pth") + parser.add_argument("--stable_sig_full_model_config", default="v2-inference.yaml") + parser.add_argument("--stable_sig_full_model_ckpt", default="v2-1_512-ema-pruned.ckpt") # for image distortion - parser.add_argument('--r_degree', default=None, type=float) - parser.add_argument('--jpeg_ratio', default=None, type=int) - parser.add_argument('--crop_scale', default=None, type=float) - parser.add_argument('--crop_ratio', default=None, type=float) - parser.add_argument('--gaussian_blur_r', default=None, type=int) - parser.add_argument('--gaussian_std', default=None, type=float) - parser.add_argument('--brightness_factor', default=None, type=float) - parser.add_argument('--rand_aug', default=0, type=int) + parser.add_argument("--r_degree", default=None, type=float) + parser.add_argument("--jpeg_ratio", default=None, type=int) + parser.add_argument("--crop_scale", default=None, type=float) + parser.add_argument("--crop_ratio", default=None, type=float) + parser.add_argument("--gaussian_blur_r", default=None, type=int) + parser.add_argument("--gaussian_std", default=None, type=float) + parser.add_argument("--brightness_factor", default=None, type=float) + parser.add_argument("--rand_aug", default=0, type=int) args = parser.parse_args() - - main(args) \ No newline at end of file + + main(args) diff --git a/src/metr/stable_sig/__init__.py b/src/metr/stable_sig/__init__.py index b794fd4..3dc1f76 100644 --- a/src/metr/stable_sig/__init__.py +++ b/src/metr/stable_sig/__init__.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = "0.1.0" diff --git a/src/metr/stable_sig/utils.py b/src/metr/stable_sig/utils.py index 6058303..e20cda8 100644 --- a/src/metr/stable_sig/utils.py +++ b/src/metr/stable_sig/utils.py @@ -4,52 +4,56 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import math -import time import datetime +import functools +import math import os import subprocess -import functools +import time from collections import defaultdict, deque import numpy as np -from PIL import Image - import torch +from PIL import Image from torch.utils.data import DataLoader, Subset -from torchvision.datasets.folder import is_image_file, default_loader +from torchvision.datasets.folder import default_loader, is_image_file ### Optimizer building + def parse_params(s): """ Parse parameters into a dictionary, used for optimizer and scheduler parsing. 
- Example: + Example: "SGD,lr=0.01" -> {"name": "SGD", "lr": 0.01} """ - s = s.replace(' ', '').split(',') + s = s.replace(" ", "").split(",") params = {} - params['name'] = s[0] + params["name"] = s[0] for x in s[1:]: - x = x.split('=') - params[x[0]]=float(x[1]) + x = x.split("=") + params[x[0]] = float(x[1]) return params + def build_optimizer(name, model_params, **optim_params): - """ Build optimizer from a dictionary of parameters """ - torch_optimizers = sorted(name for name in torch.optim.__dict__ - if name[0].isupper() and not name.startswith("__") - and callable(torch.optim.__dict__[name])) + """Build optimizer from a dictionary of parameters""" + torch_optimizers = sorted( + name + for name in torch.optim.__dict__ + if name[0].isupper() and not name.startswith("__") and callable(torch.optim.__dict__[name]) + ) if hasattr(torch.optim, name): return getattr(torch.optim, name)(model_params, **optim_params) raise ValueError(f'Unknown optimizer "{name}", choose among {str(torch_optimizers)}') + def adjust_learning_rate(optimizer, step, steps, warmup_steps, blr, min_lr=1e-6): """Decay the learning rate with half-cycle cosine after warmup""" if step < warmup_steps: - lr = blr * step / warmup_steps + lr = blr * step / warmup_steps else: - lr = min_lr + (blr - min_lr) * 0.5 * (1. + math.cos(math.pi * (step - warmup_steps) / (steps - warmup_steps))) + lr = min_lr + (blr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * (step - warmup_steps) / (steps - warmup_steps))) for param_group in optimizer.param_groups: if "lr_scale" in param_group: param_group["lr"] = lr * param_group["lr_scale"] @@ -57,8 +61,10 @@ def adjust_learning_rate(optimizer, step, steps, warmup_steps, blr, min_lr=1e-6) param_group["lr"] = lr return lr + ### Data loading + @functools.lru_cache() def get_image_paths(path): paths = [] @@ -67,6 +73,7 @@ def get_image_paths(path): paths.append(os.path.join(path, filename)) return sorted([fn for fn in paths if is_image_file(fn)]) + class ImageFolder: """An image folder dataset intended for self-supervised learning.""" @@ -85,24 +92,37 @@ def __getitem__(self, idx: int): def __len__(self): return len(self.samples) + def collate_fn(batch): - """ Collate function for data loader. Allows to have img of different size""" + """Collate function for data loader. Allows to have img of different size""" return batch -def get_dataloader(data_dir, transform, batch_size=128, num_imgs=None, shuffle=False, num_workers=4, collate_fn=collate_fn): - """ Get dataloader for the images in the data_dir. The data_dir must be of the form: input/0/... """ + +def get_dataloader( + data_dir, transform, batch_size=128, num_imgs=None, shuffle=False, num_workers=4, collate_fn=collate_fn +): + """Get dataloader for the images in the data_dir. 
The data_dir must be of the form: input/0/...""" dataset = ImageFolder(data_dir, transform=transform) if num_imgs is not None: dataset = Subset(dataset, np.random.choice(len(dataset), num_imgs, replace=False)) - return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, drop_last=False, collate_fn=collate_fn) + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=True, + drop_last=False, + collate_fn=collate_fn, + ) + def pil_imgs_from_folder(folder): - """ Get all images in the folder as PIL images """ + """Get all images in the folder as PIL images""" images = [] filenames = [] for filename in os.listdir(folder): try: - img = Image.open(os.path.join(folder,filename)) + img = Image.open(os.path.join(folder, filename)) if img is not None: filenames.append(filename) images.append(img) @@ -110,8 +130,10 @@ def pil_imgs_from_folder(folder): print("Error opening image: ", filename) return images, filenames + ### Metric logging + class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. @@ -154,11 +176,9 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + class MetricLogger(object): def __init__(self, delimiter="\t"): @@ -177,15 +197,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def add_meter(self, name, meter): @@ -194,31 +211,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.6f}') - data_time = SmoothedValue(fmt='{avg:.6f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.6f}") + data_time = SmoothedValue(fmt="{avg:.6f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -228,49 +242,60 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = 
str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {} ({:.6f} s / it)'.format(header, total_time_str, total_time / (len(iterable)+1))) + print("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / (len(iterable) + 1))) + + +### Misc -### Misc def bool_inst(v): if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise ValueError('Boolean value expected in args') + raise ValueError("Boolean value expected in args") + def get_sha(): cwd = os.path.dirname(os.path.abspath(__file__)) def _run(command): - return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() - sha = 'N/A' + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" diff = "clean" - branch = 'N/A' + branch = "N/A" try: - sha = _run(['git', 'rev-parse', 'HEAD']) - subprocess.check_output(['git', 'diff'], cwd=cwd) - diff = _run(['git', 'diff-index', 'HEAD']) + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) diff = "has uncommited changes" if diff else "clean" - branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) except Exception: pass message = f"sha: {sha}, status: {diff}, branch: {branch}" - return message \ No newline at end of file + return message diff --git a/src/metr/stable_sig/utils_img.py b/src/metr/stable_sig/utils_img.py index a412156..222b887 100644 --- a/src/metr/stable_sig/utils_img.py +++ b/src/metr/stable_sig/utils_img.py @@ -7,103 +7,120 @@ # pyright: reportMissingModuleSource=false import numpy as np -from augly.image import functional as aug_functional import torch +from augly.image import functional as aug_functional from torchvision import transforms from torchvision.transforms import functional device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -default_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) - -normalize_vqgan = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize (x - 0.5) / 0.5 -unnormalize_vqgan = transforms.Normalize(mean=[-1, -1, -1], std=[1/0.5, 1/0.5, 1/0.5]) # Unnormalize (x * 0.5) + 0.5 -normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize (x - mean) / std -unnormalize_img = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]) 
# Unnormalize (x * std) + mean - -def psnr(x, y, img_space='vqgan'): - """ - Return PSNR +default_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) + +normalize_vqgan = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize (x - 0.5) / 0.5 +unnormalize_vqgan = transforms.Normalize( + mean=[-1, -1, -1], std=[1 / 0.5, 1 / 0.5, 1 / 0.5] +) # Unnormalize (x * 0.5) + 0.5 +normalize_img = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] +) # Normalize (x - mean) / std +unnormalize_img = transforms.Normalize( + mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], std=[1 / 0.229, 1 / 0.224, 1 / 0.225] +) # Unnormalize (x * std) + mean + + +def psnr(x, y, img_space="vqgan"): + """ + Return PSNR Args: x: Image tensor with values approx. between [-1,1] y: Image tensor with values approx. between [-1,1], ex: original image """ - if img_space == 'vqgan': + if img_space == "vqgan": delta = torch.clamp(unnormalize_vqgan(x), 0, 1) - torch.clamp(unnormalize_vqgan(y), 0, 1) - elif img_space == 'img': + elif img_space == "img": delta = torch.clamp(unnormalize_img(x), 0, 1) - torch.clamp(unnormalize_img(y), 0, 1) else: delta = x - y delta = 255 * delta - delta = delta.reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1]) # BxCxHxW - psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2, dim=(1,2,3))) # B + delta = delta.reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1]) # BxCxHxW + psnr = 20 * np.log10(255) - 10 * torch.log10(torch.mean(delta**2, dim=(1, 2, 3))) # B return psnr + def center_crop(x, scale): - """ Perform center crop such that the target area of the crop is at a given scale + """Perform center crop such that the target area of the crop is at a given scale Args: x: PIL image - scale: target area scale + scale: target area scale """ scale = np.sqrt(scale) - new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] + new_edges_size = [int(s * scale) for s in x.shape[-2:]][::-1] return functional.center_crop(x, new_edges_size) + def resize(x, scale): - """ Perform center crop such that the target area of the crop is at a given scale + """Resize the image such that the target area is at a given scale Args: x: PIL image - scale: target area scale + scale: target area scale """ scale = np.sqrt(scale) - new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] + new_edges_size = [int(s * scale) for s in x.shape[-2:]][::-1] return functional.resize(x, new_edges_size) + def rotate(x, angle): - """ Rotate image by angle + """Rotate image by angle Args: x: image (PIL or tensor) angle: angle in degrees """ return functional.rotate(x, angle) + def adjust_brightness(x, brightness_factor): - """ Adjust brightness of an image + """Adjust brightness of an image Args: x: PIL image brightness_factor: brightness factor """ return normalize_img(functional.adjust_brightness(unnormalize_img(x), brightness_factor)) + def adjust_contrast(x, contrast_factor): - """ Adjust contrast of an image + """Adjust contrast of an image Args: x: PIL image contrast_factor: contrast factor """ return normalize_img(functional.adjust_contrast(unnormalize_img(x), contrast_factor)) + def adjust_saturation(x, saturation_factor): - """ Adjust saturation of an image + """Adjust saturation of an image Args: x: PIL image saturation_factor: saturation factor """ return normalize_img(functional.adjust_saturation(unnormalize_img(x), saturation_factor)) + def adjust_hue(x, 
hue_factor): - """ Adjust hue of an image + """Adjust hue of an image Args: x: PIL image hue_factor: hue factor """ return normalize_img(functional.adjust_hue(unnormalize_img(x), hue_factor)) + def adjust_gamma(x, gamma, gain=1): - """ Adjust gamma of an image + """Adjust gamma of an image Args: x: PIL image gamma: gamma factor @@ -111,16 +128,18 @@ def adjust_gamma(x, gamma, gain=1): """ return normalize_img(functional.adjust_gamma(unnormalize_img(x), gamma, gain)) + def adjust_sharpness(x, sharpness_factor): - """ Adjust sharpness of an image + """Adjust sharpness of an image Args: x: PIL image sharpness_factor: sharpness factor """ return normalize_img(functional.adjust_sharpness(unnormalize_img(x), sharpness_factor)) -def overlay_text(x, text='Lorem Ipsum'): - """ Overlay text on image + +def overlay_text(x, text="Lorem Ipsum"): + """Overlay text on image Args: x: PIL image text: text to overlay @@ -132,13 +151,14 @@ def overlay_text(x, text='Lorem Ipsum'): to_pil = transforms.ToPILImage() to_tensor = transforms.ToTensor() img_aug = torch.zeros_like(x, device=x.device) - for ii,img in enumerate(x): + for ii, img in enumerate(x): pil_img = to_pil(unnormalize_img(img)) img_aug[ii] = to_tensor(aug_functional.overlay_text(pil_img, text=text)) return normalize_img(img_aug) + def jpeg_compress(x, quality_factor): - """ Apply jpeg compression to image + """Apply jpeg compression to image Args: x: PIL image quality_factor: quality factor @@ -146,7 +166,7 @@ def jpeg_compress(x, quality_factor): to_pil = transforms.ToPILImage() to_tensor = transforms.ToTensor() img_aug = torch.zeros_like(x, device=x.device) - for ii,img in enumerate(x): + for ii, img in enumerate(x): pil_img = to_pil(unnormalize_img(img)) img_aug[ii] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor)) return normalize_img(img_aug) diff --git a/src/metr/stable_sig/utils_model.py b/src/metr/stable_sig/utils_model.py index b299fa2..ba31224 100644 --- a/src/metr/stable_sig/utils_model.py +++ b/src/metr/stable_sig/utils_model.py @@ -1,36 +1,40 @@ # File with supplementary functions for stable-tree watermarking import importlib + import torch import torch.nn as nn - -from omegaconf import OmegaConf +from omegaconf import OmegaConf # from diffusers.models import AutoencoderKL ### Load HiDDeN models + class ConvBNRelu(nn.Module): """ Building block used in HiDDeN network. Is a sequence of Convolution, Batch Normalization, and ReLU activation """ + def __init__(self, channels_in, channels_out): super(ConvBNRelu, self).__init__() - + self.layers = nn.Sequential( nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1), nn.BatchNorm2d(channels_out, eps=1e-3), - nn.GELU() + nn.GELU(), ) def forward(self, x): return self.layers(x) + class HiddenDecoder(nn.Module): """ Decoder module. Receives a watermarked image and extracts the watermark. 
""" + def __init__(self, num_blocks, num_bits, channels, redundancy=1): super(HiddenDecoder, self).__init__() @@ -39,35 +43,37 @@ def __init__(self, num_blocks, num_bits, channels, redundancy=1): for _ in range(num_blocks - 1): layers.append(ConvBNRelu(channels, channels)) - layers.append(ConvBNRelu(channels, num_bits*redundancy)) + layers.append(ConvBNRelu(channels, num_bits * redundancy)) layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1))) self.layers = nn.Sequential(*layers) - self.linear = nn.Linear(num_bits*redundancy, num_bits*redundancy) + self.linear = nn.Linear(num_bits * redundancy, num_bits * redundancy) self.num_bits = num_bits self.redundancy = redundancy def forward(self, img_w): - x = self.layers(img_w) # b d 1 1 - x = x.squeeze(-1).squeeze(-1) # b d + x = self.layers(img_w) # b d 1 1 + x = x.squeeze(-1).squeeze(-1) # b d x = self.linear(x) - x = x.view(-1, self.num_bits, self.redundancy) # b k*r -> b k r - x = torch.sum(x, dim=-1) # b k r -> b k + x = x.view(-1, self.num_bits, self.redundancy) # b k*r -> b k r + x = torch.sum(x, dim=-1) # b k r -> b k return x + class HiddenEncoder(nn.Module): """ Inserts a watermark into an image. """ + def __init__(self, num_blocks, num_bits, channels, last_tanh=True): super(HiddenEncoder, self).__init__() layers = [ConvBNRelu(3, channels)] - for _ in range(num_blocks-1): + for _ in range(num_blocks - 1): layer = ConvBNRelu(channels, channels) layers.append(layer) @@ -81,8 +87,8 @@ def __init__(self, num_blocks, num_bits, channels, last_tanh=True): def forward(self, imgs, msgs): - msgs = msgs.unsqueeze(-1).unsqueeze(-1) # b l 1 1 - msgs = msgs.expand(-1,-1, imgs.size(-2), imgs.size(-1)) # b l h w + msgs = msgs.unsqueeze(-1).unsqueeze(-1) # b l 1 1 + msgs = msgs.expand(-1, -1, imgs.size(-2), imgs.size(-1)) # b l h w encoded_image = self.conv_bns(imgs) @@ -95,35 +101,50 @@ def forward(self, imgs, msgs): return im_w + def get_hidden_decoder(num_bits, redundancy=1, num_blocks=7, channels=64): decoder = HiddenDecoder(num_blocks=num_blocks, num_bits=num_bits, channels=channels, redundancy=redundancy) return decoder + def get_hidden_decoder_ckpt(ckpt_path): ckpt = torch.load(ckpt_path, map_location="cpu") - decoder_ckpt = { k.replace('module.', '').replace('decoder.', '') : v for k,v in ckpt['encoder_decoder'].items() if 'decoder' in k} + decoder_ckpt = { + k.replace("module.", "").replace("decoder.", ""): v + for k, v in ckpt["encoder_decoder"].items() + if "decoder" in k + } return decoder_ckpt + def get_hidden_encoder(num_bits, num_blocks=4, channels=64): encoder = HiddenEncoder(num_blocks=num_blocks, num_bits=num_bits, channels=channels) return encoder + def get_hidden_encoder_ckpt(ckpt_path): ckpt = torch.load(ckpt_path, map_location="cpu") - encoder_ckpt = { k.replace('module.', '').replace('encoder.', '') : v for k,v in ckpt['encoder_decoder'].items() if 'encoder' in k} + encoder_ckpt = { + k.replace("module.", "").replace("encoder.", ""): v + for k, v in ckpt["encoder_decoder"].items() + if "encoder" in k + } return encoder_ckpt + ### Load LDM models + def instantiate_from_config(config): if not "target" in config: - if config == '__is_first_stage__': + if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) + def get_obj_from_str(string, reload=False): module, cls = string.rsplit(".", 1) # module = "." 
+ module @@ -132,6 +153,7 @@ def get_obj_from_str(string, reload=False): importlib.reload(module_imp) return getattr(importlib.import_module(module, package="tree_ring_watermark"), cls) + def load_model_from_config(config, ckpt, verbose=False): print(f"Loading model from {ckpt}") @@ -153,31 +175,30 @@ def load_model_from_config(config, ckpt, verbose=False): model.eval() return model + class Sampler: - ''' + """ Crutch to make decode(x).sample return x - ''' + """ + def __init__(self, x): self.sample = x - - -def change_pipe_vae_decoder(pipe, - weights_path, - args - ): - ''' - - loads dict of weights into predefined vae config + + +def change_pipe_vae_decoder(pipe, weights_path, args): + """ + - loads dict of weights into predefined vae config - changes pipe.vae.decode function into decoding with this vae ------------- weights_path: path to weights of decoder - ''' + """ config_path = args.stable_sig_full_model_config ckpt_path = args.stable_sig_full_model_ckpt ldm_config = config_path ldm_ckpt = ckpt_path - print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...') + print(f">>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...") # ---------------- config = OmegaConf.load(f"{ldm_config}") @@ -190,9 +211,9 @@ def change_pipe_vae_decoder(pipe, # loading the fine-tuned decoder weights state_dict = torch.load(weights_path) - print(f'>>> Loaded VAE decoder weights from {weights_path}') + print(f">>> Loaded VAE decoder weights from {weights_path}") unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False) - pipe.vae.decode = (lambda x, *args, **kwargs: Sampler(ldm_aef.decode(x))) # there was also a .unsqueeze(0) here before + pipe.vae.decode = lambda x, *args, **kwargs: Sampler(ldm_aef.decode(x)) # there was also a .unsqueeze(0) here before - return pipe \ No newline at end of file + return pipe diff --git a/src/metr/taming/data/ade20k.py b/src/metr/taming/data/ade20k.py index 366dae9..507b671 100644 --- a/src/metr/taming/data/ade20k.py +++ b/src/metr/taming/data/ade20k.py @@ -1,30 +1,33 @@ import os -import numpy as np -import cv2 + import albumentations +import cv2 +import numpy as np from PIL import Image +from taming.data.sflckr import SegmentationBase # for examples included in repo from torch.utils.data import Dataset -from taming.data.sflckr import SegmentationBase # for examples included in repo - class Examples(SegmentationBase): def __init__(self, size=256, random_crop=False, interpolation="bicubic"): - super().__init__(data_csv="data/ade20k_examples.txt", - data_root="data/ade20k_images", - segmentation_root="data/ade20k_segmentations", - size=size, random_crop=random_crop, - interpolation=interpolation, - n_labels=151, shift_segmentation=False) + super().__init__( + data_csv="data/ade20k_examples.txt", + data_root="data/ade20k_images", + segmentation_root="data/ade20k_segmentations", + size=size, + random_crop=random_crop, + interpolation=interpolation, + n_labels=151, + shift_segmentation=False, + ) # With semantic map and scene label class ADE20kBase(Dataset): def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): self.split = self.get_split() - self.n_labels = 151 # unknown + 150 - self.data_csv = {"train": "data/ade20k_train.txt", - "validation": "data/ade20k_test.txt"}[self.split] + self.n_labels = 151 # unknown + 150 + self.data_csv = {"train": "data/ade20k_train.txt", "validation": "data/ade20k_test.txt"}[self.split] self.data_root = "data/ade20k_root" with open(os.path.join(self.data_root, 
"sceneCategories.txt"), "r") as f: self.scene_categories = f.read().splitlines() @@ -34,18 +37,15 @@ def __init__(self, config=None, size=None, random_crop=False, interpolation="bic self._length = len(self.image_paths) self.labels = { "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, "images", l) - for l in self.image_paths], - "relative_segmentation_path_": [l.replace(".jpg", ".png") - for l in self.image_paths], - "segmentation_path_": [os.path.join(self.data_root, "annotations", - l.replace(".jpg", ".png")) - for l in self.image_paths], - "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] - for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, "images", l) for l in self.image_paths], + "relative_segmentation_path_": [l.replace(".jpg", ".png") for l in self.image_paths], + "segmentation_path_": [ + os.path.join(self.data_root, "annotations", l.replace(".jpg", ".png")) for l in self.image_paths + ], + "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] for l in self.image_paths], } - size = None if size is not None and size<=0 else size + size = None if size is not None and size <= 0 else size self.size = size if crop_size is None: self.crop_size = size if size is not None else None @@ -58,11 +58,12 @@ def __init__(self, config=None, size=None, random_crop=False, interpolation="bic "bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC, "area": cv2.INTER_AREA, - "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] - self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=self.interpolation) - self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=cv2.INTER_NEAREST) + "lanczos": cv2.INTER_LANCZOS4, + }[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize( + max_size=self.size, interpolation=cv2.INTER_NEAREST + ) if crop_size is not None: self.center_crop = not random_crop @@ -91,7 +92,7 @@ def __getitem__(self, i): processed = self.preprocessor(image=image, mask=segmentation) else: processed = {"image": image, "mask": segmentation} - example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + example["image"] = (processed["image"] / 127.5 - 1.0).astype(np.float32) segmentation = processed["mask"] onehot = np.eye(self.n_labels)[segmentation] example["segmentation"] = onehot @@ -101,8 +102,9 @@ def __getitem__(self, i): class ADE20kTrain(ADE20kBase): # default to random_crop=True def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): - super().__init__(config=config, size=size, random_crop=random_crop, - interpolation=interpolation, crop_size=crop_size) + super().__init__( + config=config, size=size, random_crop=random_crop, interpolation=interpolation, crop_size=crop_size + ) def get_split(self): return "train" diff --git a/src/metr/taming/data/annotated_objects_coco.py b/src/metr/taming/data/annotated_objects_coco.py index af000ec..0f89650 100644 --- a/src/metr/taming/data/annotated_objects_coco.py +++ b/src/metr/taming/data/annotated_objects_coco.py @@ -1,76 +1,82 @@ import json +from collections import defaultdict from itertools import chain from pathlib import Path -from typing import Iterable, Dict, List, Callable, Any -from collections import defaultdict - -from tqdm import tqdm +from typing 
import Any, Callable, Dict, Iterable, List from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset -from taming.data.helper_types import Annotation, ImageDescription, Category +from taming.data.helper_types import Annotation, Category, ImageDescription +from tqdm import tqdm COCO_PATH_STRUCTURE = { - 'train': { - 'top_level': '', - 'instances_annotations': 'annotations/instances_train2017.json', - 'stuff_annotations': 'annotations/stuff_train2017.json', - 'files': 'train2017' + "train": { + "top_level": "", + "instances_annotations": "annotations/instances_train2017.json", + "stuff_annotations": "annotations/stuff_train2017.json", + "files": "train2017", + }, + "validation": { + "top_level": "", + "instances_annotations": "annotations/instances_val2017.json", + "stuff_annotations": "annotations/stuff_val2017.json", + "files": "val2017", }, - 'validation': { - 'top_level': '', - 'instances_annotations': 'annotations/instances_val2017.json', - 'stuff_annotations': 'annotations/stuff_val2017.json', - 'files': 'val2017' - } } def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: return { - str(img['id']): ImageDescription( - id=img['id'], - license=img.get('license'), - file_name=img['file_name'], - coco_url=img['coco_url'], - original_size=(img['width'], img['height']), - date_captured=img.get('date_captured'), - flickr_url=img.get('flickr_url') + str(img["id"]): ImageDescription( + id=img["id"], + license=img.get("license"), + file_name=img["file_name"], + coco_url=img["coco_url"], + original_size=(img["width"], img["height"]), + date_captured=img.get("date_captured"), + flickr_url=img.get("flickr_url"), ) for img in description_json } def load_categories(category_json: Iterable) -> Dict[str, Category]: - return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) - for cat in category_json if cat['name'] != 'other'} + return { + str(cat["id"]): Category(id=str(cat["id"]), super_category=cat["supercategory"], name=cat["name"]) + for cat in category_json + if cat["name"] != "other" + } -def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], - category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: +def load_annotations( + annotations_json: List[Dict], + image_descriptions: Dict[str, ImageDescription], + category_no_for_id: Callable[[str], int], + split: str, +) -> Dict[str, List[Annotation]]: annotations = defaultdict(list) total = sum(len(a) for a in annotations_json) - for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): - image_id = str(ann['image_id']) + for ann in tqdm(chain(*annotations_json), f"Loading {split} annotations", total=total): + image_id = str(ann["image_id"]) if image_id not in image_descriptions: - raise ValueError(f'image_id [{image_id}] has no image description.') - category_id = ann['category_id'] + raise ValueError(f"image_id [{image_id}] has no image description.") + category_id = ann["category_id"] try: category_no = category_no_for_id(str(category_id)) except KeyError: continue width, height = image_descriptions[image_id].original_size - bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) + bbox = (ann["bbox"][0] / width, ann["bbox"][1] / height, ann["bbox"][2] / width, ann["bbox"][3] / height) annotations[image_id].append( Annotation( - id=ann['id'], - area=bbox[2]*bbox[3], # use bbox area - 
is_group_of=ann['iscrowd'], - image_id=ann['image_id'], + id=ann["id"], + area=bbox[2] * bbox[3], # use bbox area + is_group_of=ann["iscrowd"], + image_id=ann["image_id"], bbox=bbox, category_id=str(category_id), - category_no=category_no + category_no=category_no, ) ) return dict(annotations) @@ -101,38 +107,39 @@ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): self.use_things = use_things self.use_stuff = use_stuff - with open(self.paths['instances_annotations']) as f: + with open(self.paths["instances_annotations"]) as f: inst_data_json = json.load(f) - with open(self.paths['stuff_annotations']) as f: + with open(self.paths["stuff_annotations"]) as f: stuff_data_json = json.load(f) category_jsons = [] annotation_jsons = [] if self.use_things: - category_jsons.append(inst_data_json['categories']) - annotation_jsons.append(inst_data_json['annotations']) + category_jsons.append(inst_data_json["categories"]) + annotation_jsons.append(inst_data_json["annotations"]) if self.use_stuff: - category_jsons.append(stuff_data_json['categories']) - annotation_jsons.append(stuff_data_json['annotations']) + category_jsons.append(stuff_data_json["categories"]) + annotation_jsons.append(stuff_data_json["annotations"]) self.categories = load_categories(chain(*category_jsons)) self.filter_categories() self.setup_category_id_and_number() - self.image_descriptions = load_image_descriptions(inst_data_json['images']) + self.image_descriptions = load_image_descriptions(inst_data_json["images"]) annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) - self.annotations = self.filter_object_number(annotations, self.min_object_area, - self.min_objects_per_image, self.max_objects_per_image) + self.annotations = self.filter_object_number( + annotations, self.min_object_area, self.min_objects_per_image, self.max_objects_per_image + ) self.image_ids = list(self.annotations.keys()) self.clean_up_annotations_and_image_descriptions() def get_path_structure(self) -> Dict[str, str]: if self.split not in COCO_PATH_STRUCTURE: - raise ValueError(f'Split [{self.split} does not exist for COCO data.]') + raise ValueError(f"Split [{self.split} does not exist for COCO data.]") return COCO_PATH_STRUCTURE[self.split] def get_image_path(self, image_id: str) -> Path: - return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) + return self.paths["files"].joinpath(self.image_descriptions[str(image_id)].file_name) def get_image_description(self, image_id: str) -> Dict[str, Any]: # noinspection PyProtectedMember diff --git a/src/metr/taming/data/annotated_objects_dataset.py b/src/metr/taming/data/annotated_objects_dataset.py index 53cc346..6c47da0 100644 --- a/src/metr/taming/data/annotated_objects_dataset.py +++ b/src/metr/taming/data/annotated_objects_dataset.py @@ -1,26 +1,43 @@ -from pathlib import Path -from typing import Optional, List, Callable, Dict, Any, Union import warnings +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union import PIL.Image as pil_image -from torch import Tensor -from torch.utils.data import Dataset -from torchvision import transforms - from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder from taming.data.conditional_builder.utils import load_object_from_string -from taming.data.helper_types import BoundingBox, 
CropMethodType, Image, Annotation, SplitType -from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \ - Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor +from taming.data.helper_types import Annotation, BoundingBox, CropMethodType, Image, SplitType +from taming.data.image_transforms import ( + CenterCropReturnCoordinates, + Random2dCropReturnCoordinates, + RandomCrop1dReturnCoordinates, + RandomHorizontalFlipReturn, + convert_pil_to_tensor, +) +from torch import Tensor +from torch.utils.data import Dataset +from torchvision import transforms class AnnotatedObjectsDataset(Dataset): - def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int, - min_object_area: float, min_objects_per_image: int, max_objects_per_image: int, - crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool, - encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "", - no_object_classes: Optional[int] = None): + def __init__( + self, + data_path: Union[str, Path], + split: SplitType, + keys: List[str], + target_image_size: int, + min_object_area: float, + min_objects_per_image: int, + max_objects_per_image: int, + crop_method: CropMethodType, + random_flip: bool, + no_tokens: int, + use_group_parameter: bool, + encode_crop: bool, + category_allow_list_target: str = "", + category_mapping_target: str = "", + no_object_classes: Optional[int] = None, + ): self.data_path = data_path self.split = split self.keys = keys @@ -57,47 +74,46 @@ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]: sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()} for path in sub_paths.values(): if not path.exists(): - raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.') + raise FileNotFoundError(f"{type(self).__name__} data structure error: [{path}] does not exist.") return sub_paths @staticmethod def load_image_from_disk(path: Path) -> Image: - return pil_image.open(path).convert('RGB') + return pil_image.open(path).convert("RGB") @staticmethod def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool): transform_functions = [] - if crop_method == 'none': + if crop_method == "none": transform_functions.append(transforms.Resize((target_image_size, target_image_size))) - elif crop_method == 'center': - transform_functions.extend([ - transforms.Resize(target_image_size), - CenterCropReturnCoordinates(target_image_size) - ]) - elif crop_method == 'random-1d': - transform_functions.extend([ - transforms.Resize(target_image_size), - RandomCrop1dReturnCoordinates(target_image_size) - ]) - elif crop_method == 'random-2d': - transform_functions.extend([ - Random2dCropReturnCoordinates(target_image_size), - transforms.Resize(target_image_size) - ]) + elif crop_method == "center": + transform_functions.extend( + [transforms.Resize(target_image_size), CenterCropReturnCoordinates(target_image_size)] + ) + elif crop_method == "random-1d": + transform_functions.extend( + [transforms.Resize(target_image_size), RandomCrop1dReturnCoordinates(target_image_size)] + ) + elif crop_method == "random-2d": + transform_functions.extend( + [Random2dCropReturnCoordinates(target_image_size), transforms.Resize(target_image_size)] + ) elif crop_method is None: return None else: - raise ValueError(f'Received invalid crop method 
[{crop_method}].') + raise ValueError(f"Received invalid crop method [{crop_method}].") if random_flip: transform_functions.append(RandomHorizontalFlipReturn()) - transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.)) + transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.0)) return transform_functions def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor): crop_bbox = None flipped = None for t in self.transform_functions: - if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)): + if isinstance( + t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates) + ): crop_bbox, x = t(x) elif isinstance(t, RandomHorizontalFlipReturn): flipped, x = t(x) @@ -114,22 +130,22 @@ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder: # cannot set this up in init because no_classes is only known after loading data in init of superclass if self._conditional_builders is None: self._conditional_builders = { - 'objects_center_points': ObjectsCenterPointsConditionalBuilder( + "objects_center_points": ObjectsCenterPointsConditionalBuilder( self.no_classes, self.max_objects_per_image, self.no_tokens, self.encode_crop, self.use_group_parameter, - getattr(self, 'use_additional_parameters', False) + getattr(self, "use_additional_parameters", False), ), - 'objects_bbox': ObjectsBoundingBoxConditionalBuilder( + "objects_bbox": ObjectsBoundingBoxConditionalBuilder( self.no_classes, self.max_objects_per_image, self.no_tokens, self.encode_crop, self.use_group_parameter, - getattr(self, 'use_additional_parameters', False) - ) + getattr(self, "use_additional_parameters", False), + ), } return self._conditional_builders @@ -142,14 +158,19 @@ def filter_categories(self) -> None: def setup_category_id_and_number(self) -> None: self.category_ids = list(self.categories.keys()) self.category_ids.sort() - if '/m/01s55n' in self.category_ids: - self.category_ids.remove('/m/01s55n') - self.category_ids.append('/m/01s55n') + if "/m/01s55n" in self.category_ids: + self.category_ids.remove("/m/01s55n") + self.category_ids.append("/m/01s55n") self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)} - if self.category_allow_list is not None and self.category_mapping is None \ - and len(self.category_ids) != len(self.category_allow_list): - warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. ' - 'Make sure all names in category_allow_list exist.') + if ( + self.category_allow_list is not None + and self.category_mapping is None + and len(self.category_ids) != len(self.category_allow_list) + ): + warnings.warn( + "Unexpected number of categories: Mismatch with category_allow_list. " + "Make sure all names in category_allow_list exist." 
+ ) def clean_up_annotations_and_image_descriptions(self) -> None: image_id_set = set(self.image_ids) @@ -157,8 +178,12 @@ def clean_up_annotations_and_image_descriptions(self) -> None: self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set} @staticmethod - def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float, - min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]: + def filter_object_number( + all_annotations: Dict[str, List[Annotation]], + min_object_area: float, + min_objects_per_image: int, + max_objects_per_image: int, + ) -> Dict[str, List[Annotation]]: filtered = {} for image_id, annotations in all_annotations.items(): annotations_with_min_area = [a for a in annotations if a.area > min_object_area] @@ -172,18 +197,18 @@ def __len__(self): def __getitem__(self, n: int) -> Dict[str, Any]: image_id = self.get_image_id(n) sample = self.get_image_description(image_id) - sample['annotations'] = self.get_annotation(image_id) + sample["annotations"] = self.get_annotation(image_id) - if 'image' in self.keys: - sample['image_path'] = str(self.get_image_path(image_id)) - sample['image'] = self.load_image_from_disk(sample['image_path']) - sample['image'] = convert_pil_to_tensor(sample['image']) - sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image']) - sample['image'] = sample['image'].permute(1, 2, 0) + if "image" in self.keys: + sample["image_path"] = str(self.get_image_path(image_id)) + sample["image"] = self.load_image_from_disk(sample["image_path"]) + sample["image"] = convert_pil_to_tensor(sample["image"]) + sample["crop_bbox"], sample["flipped"], sample["image"] = self.image_transform(sample["image"]) + sample["image"] = sample["image"].permute(1, 2, 0) for conditional, builder in self.conditional_builders.items(): if conditional in self.keys: - sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped']) + sample[conditional] = builder.build(sample["annotations"], sample["crop_bbox"], sample["flipped"]) if self.keys: # only return specified keys diff --git a/src/metr/taming/data/annotated_objects_open_images.py b/src/metr/taming/data/annotated_objects_open_images.py index aede680..c9f882e 100644 --- a/src/metr/taming/data/annotated_objects_open_images.py +++ b/src/metr/taming/data/annotated_objects_open_images.py @@ -1,77 +1,79 @@ +import warnings from collections import defaultdict -from csv import DictReader, reader as TupleReader +from csv import DictReader +from csv import reader as TupleReader from pathlib import Path -from typing import Dict, List, Any -import warnings +from typing import Any, Dict, List from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset from taming.data.helper_types import Annotation, Category from tqdm import tqdm OPEN_IMAGES_STRUCTURE = { - 'train': { - 'top_level': '', - 'class_descriptions': 'class-descriptions-boxable.csv', - 'annotations': 'oidv6-train-annotations-bbox.csv', - 'file_list': 'train-images-boxable.csv', - 'files': 'train' + "train": { + "top_level": "", + "class_descriptions": "class-descriptions-boxable.csv", + "annotations": "oidv6-train-annotations-bbox.csv", + "file_list": "train-images-boxable.csv", + "files": "train", + }, + "validation": { + "top_level": "", + "class_descriptions": "class-descriptions-boxable.csv", + "annotations": "validation-annotations-bbox.csv", + "file_list": "validation-images.csv", + "files": 
"validation", }, - 'validation': { - 'top_level': '', - 'class_descriptions': 'class-descriptions-boxable.csv', - 'annotations': 'validation-annotations-bbox.csv', - 'file_list': 'validation-images.csv', - 'files': 'validation' + "test": { + "top_level": "", + "class_descriptions": "class-descriptions-boxable.csv", + "annotations": "test-annotations-bbox.csv", + "file_list": "test-images.csv", + "files": "test", }, - 'test': { - 'top_level': '', - 'class_descriptions': 'class-descriptions-boxable.csv', - 'annotations': 'test-annotations-bbox.csv', - 'file_list': 'test-images.csv', - 'files': 'test' - } } -def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], - category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: +def load_annotations( + descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], category_no_for_id: Dict[str, int] +) -> Dict[str, List[Annotation]]: annotations: Dict[str, List[Annotation]] = defaultdict(list) with open(descriptor_path) as file: reader = DictReader(file) - for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): - width = float(row['XMax']) - float(row['XMin']) - height = float(row['YMax']) - float(row['YMin']) + for i, row in tqdm(enumerate(reader), total=14620000, desc="Loading OpenImages annotations"): + width = float(row["XMax"]) - float(row["XMin"]) + height = float(row["YMax"]) - float(row["YMin"]) area = width * height - category_id = row['LabelName'] + category_id = row["LabelName"] if category_id in category_mapping: category_id = category_mapping[category_id] if area >= min_object_area and category_id in category_no_for_id: - annotations[row['ImageID']].append( + annotations[row["ImageID"]].append( Annotation( id=i, - image_id=row['ImageID'], - source=row['Source'], + image_id=row["ImageID"], + source=row["Source"], category_id=category_id, category_no=category_no_for_id[category_id], - confidence=float(row['Confidence']), - bbox=(float(row['XMin']), float(row['YMin']), width, height), + confidence=float(row["Confidence"]), + bbox=(float(row["XMin"]), float(row["YMin"]), width, height), area=area, - is_occluded=bool(int(row['IsOccluded'])), - is_truncated=bool(int(row['IsTruncated'])), - is_group_of=bool(int(row['IsGroupOf'])), - is_depiction=bool(int(row['IsDepiction'])), - is_inside=bool(int(row['IsInside'])) + is_occluded=bool(int(row["IsOccluded"])), + is_truncated=bool(int(row["IsTruncated"])), + is_group_of=bool(int(row["IsGroupOf"])), + is_depiction=bool(int(row["IsDepiction"])), + is_inside=bool(int(row["IsInside"])), ) ) - if 'train' in str(descriptor_path) and i < 14000000: - warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') + if "train" in str(descriptor_path) and i < 14000000: + warnings.warn(f"Running with subset of Open Images. 
Train dataset has length [{len(annotations)}].") return dict(annotations) def load_image_ids(csv_path: Path) -> List[str]: with open(csv_path) as file: reader = DictReader(file) - return [row['image_name'] for row in reader] + return [row["image_name"] for row in reader] def load_categories(csv_path: Path) -> Dict[str, Category]: @@ -112,26 +114,28 @@ def __init__(self, use_additional_parameters: bool, **kwargs): super().__init__(**kwargs) self.use_additional_parameters = use_additional_parameters - self.categories = load_categories(self.paths['class_descriptions']) + self.categories = load_categories(self.paths["class_descriptions"]) self.filter_categories() self.setup_category_id_and_number() self.image_descriptions = {} - annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, - self.category_number) - self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, - self.max_objects_per_image) + annotations = load_annotations( + self.paths["annotations"], self.min_object_area, self.category_mapping, self.category_number + ) + self.annotations = self.filter_object_number( + annotations, self.min_object_area, self.min_objects_per_image, self.max_objects_per_image + ) self.image_ids = list(self.annotations.keys()) self.clean_up_annotations_and_image_descriptions() def get_path_structure(self) -> Dict[str, str]: if self.split not in OPEN_IMAGES_STRUCTURE: - raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') + raise ValueError(f"Split [{self.split} does not exist for Open Images data.]") return OPEN_IMAGES_STRUCTURE[self.split] def get_image_path(self, image_id: str) -> Path: - return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') + return self.paths["files"].joinpath(f"{image_id:0>16}.jpg") def get_image_description(self, image_id: str) -> Dict[str, Any]: image_path = self.get_image_path(image_id) - return {'file_path': str(image_path), 'file_name': image_path.name} + return {"file_path": str(image_path), "file_name": image_path.name} diff --git a/src/metr/taming/data/base.py b/src/metr/taming/data/base.py index e21667d..f97feb8 100644 --- a/src/metr/taming/data/base.py +++ b/src/metr/taming/data/base.py @@ -1,12 +1,14 @@ import bisect -import numpy as np + import albumentations +import numpy as np from PIL import Image -from torch.utils.data import Dataset, ConcatDataset +from torch.utils.data import ConcatDataset, Dataset class ConcatDatasetWithIndex(ConcatDataset): """Modified from original pytorch code to return dataset idx""" + def __getitem__(self, idx): if idx < 0: if -idx > len(self): @@ -30,11 +32,11 @@ def __init__(self, paths, size=None, random_crop=False, labels=None): self._length = len(paths) if self.size is not None and self.size > 0: - self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) if not self.random_crop: - self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) + self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) else: - self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) + self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) else: self.preprocessor = lambda **kwargs: kwargs @@ -48,7 +50,7 @@ def preprocess_image(self, image_path): image = image.convert("RGB") image = 
np.array(image).astype(np.uint8) image = self.preprocessor(image=image)["image"] - image = (image/127.5 - 1.0).astype(np.float32) + image = (image / 127.5 - 1.0).astype(np.float32) return image def __getitem__(self, i): @@ -62,9 +64,9 @@ def __getitem__(self, i): class NumpyPaths(ImagePaths): def preprocess_image(self, image_path): image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 - image = np.transpose(image, (1,2,0)) + image = np.transpose(image, (1, 2, 0)) image = Image.fromarray(image, mode="RGB") image = np.array(image).astype(np.uint8) image = self.preprocessor(image=image)["image"] - image = (image/127.5 - 1.0).astype(np.float32) + image = (image / 127.5 - 1.0).astype(np.float32) return image diff --git a/src/metr/taming/data/coco.py b/src/metr/taming/data/coco.py index 2b2f783..709abe9 100644 --- a/src/metr/taming/data/coco.py +++ b/src/metr/taming/data/coco.py @@ -1,28 +1,42 @@ -import os import json +import os + import albumentations import numpy as np from PIL import Image -from tqdm import tqdm +from taming.data.sflckr import SegmentationBase # for examples included in repo from torch.utils.data import Dataset - -from taming.data.sflckr import SegmentationBase # for examples included in repo +from tqdm import tqdm class Examples(SegmentationBase): def __init__(self, size=256, random_crop=False, interpolation="bicubic"): - super().__init__(data_csv="data/coco_examples.txt", - data_root="data/coco_images", - segmentation_root="data/coco_segmentations", - size=size, random_crop=random_crop, - interpolation=interpolation, - n_labels=183, shift_segmentation=True) + super().__init__( + data_csv="data/coco_examples.txt", + data_root="data/coco_images", + segmentation_root="data/coco_segmentations", + size=size, + random_crop=random_crop, + interpolation=interpolation, + n_labels=183, + shift_segmentation=True, + ) class CocoBase(Dataset): """needed for (image, caption, segmentation) pairs""" - def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, - crop_size=None, force_no_crop=False, given_files=None): + + def __init__( + self, + size=None, + dataroot="", + datajson="", + onehot_segmentation=False, + use_stuffthing=False, + crop_size=None, + force_no_crop=False, + given_files=None, + ): self.split = self.get_split() self.size = size if crop_size is None: @@ -30,12 +44,14 @@ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=Fals else: self.crop_size = crop_size - self.onehot = onehot_segmentation # return segmentation as rgb or one hot - self.stuffthing = use_stuffthing # include thing in segmentation + self.onehot = onehot_segmentation # return segmentation as rgb or one hot + self.stuffthing = use_stuffthing # include thing in segmentation if self.onehot and not self.stuffthing: - raise NotImplemented("One hot mode is only supported for the " - "stuffthings version because labels are stored " - "a bit different.") + raise NotImplementedError( + "One hot mode is only supported for the " + "stuffthings version because labels are stored " + "a bit differently." 
+ ) data_json = datajson with open(data_json) as json_file: @@ -44,18 +60,19 @@ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=Fals self.img_id_to_filepath = dict() self.img_id_to_segmentation_filepath = dict() - assert data_json.split("/")[-1] in ["captions_train2017.json", - "captions_val2017.json"] + assert data_json.split("/")[-1] in ["captions_train2017.json", "captions_val2017.json"] if self.stuffthing: self.segmentation_prefix = ( - "data/cocostuffthings/val2017" if - data_json.endswith("captions_val2017.json") else - "data/cocostuffthings/train2017") + "data/cocostuffthings/val2017" + if data_json.endswith("captions_val2017.json") + else "data/cocostuffthings/train2017" + ) else: self.segmentation_prefix = ( - "data/coco/annotations/stuff_val2017_pixelmaps" if - data_json.endswith("captions_val2017.json") else - "data/coco/annotations/stuff_train2017_pixelmaps") + "data/coco/annotations/stuff_val2017_pixelmaps" + if data_json.endswith("captions_val2017.json") + else "data/coco/annotations/stuff_train2017_pixelmaps" + ) imagedirs = self.json_data["images"] self.labels = {"image_ids": list()} @@ -63,8 +80,7 @@ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=Fals self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) self.img_id_to_captions[imgdir["id"]] = list() pngfilename = imgdir["file_name"].replace("jpg", "png") - self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( - self.segmentation_prefix, pngfilename) + self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(self.segmentation_prefix, pngfilename) if given_files is not None: if pngfilename in given_files: self.labels["image_ids"].append(imgdir["id"]) @@ -77,18 +93,16 @@ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=Fals self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) - if self.split=="validation": + if self.split == "validation": self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) else: self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) self.preprocessor = albumentations.Compose( - [self.rescaler, self.cropper], - additional_targets={"segmentation": "image"}) + [self.rescaler, self.cropper], additional_targets={"segmentation": "image"} + ) if force_no_crop: self.rescaler = albumentations.Resize(height=self.size, width=self.size) - self.preprocessor = albumentations.Compose( - [self.rescaler], - additional_targets={"segmentation": "image"}) + self.preprocessor = albumentations.Compose([self.rescaler], additional_targets={"segmentation": "image"}) def __len__(self): return len(self.labels["image_ids"]) @@ -138,24 +152,30 @@ def __getitem__(self, i): captions = self.img_id_to_captions[self.labels["image_ids"][i]] # randomly draw one of all available captions per image caption = captions[np.random.randint(0, len(captions))] - example = {"image": image, - "caption": [str(caption[0])], - "segmentation": segmentation, - "img_path": img_path, - "seg_path": seg_path, - "filename_": img_path.split(os.sep)[-1] - } + example = { + "image": image, + "caption": [str(caption[0])], + "segmentation": segmentation, + "img_path": img_path, + "seg_path": seg_path, + "filename_": img_path.split(os.sep)[-1], + } return example class CocoImagesAndCaptionsTrain(CocoBase): """returns a pair of (image, caption)""" + def __init__(self, size, 
onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): - super().__init__(size=size, - dataroot="data/coco/train2017", - datajson="data/coco/annotations/captions_train2017.json", - onehot_segmentation=onehot_segmentation, - use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) + super().__init__( + size=size, + dataroot="data/coco/train2017", + datajson="data/coco/annotations/captions_train2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, + crop_size=crop_size, + force_no_crop=force_no_crop, + ) def get_split(self): return "train" @@ -163,14 +183,26 @@ def get_split(self): class CocoImagesAndCaptionsValidation(CocoBase): """returns a pair of (image, caption)""" - def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, - given_files=None): - super().__init__(size=size, - dataroot="data/coco/val2017", - datajson="data/coco/annotations/captions_val2017.json", - onehot_segmentation=onehot_segmentation, - use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, - given_files=given_files) + + def __init__( + self, + size, + onehot_segmentation=False, + use_stuffthing=False, + crop_size=None, + force_no_crop=False, + given_files=None, + ): + super().__init__( + size=size, + dataroot="data/coco/val2017", + datajson="data/coco/annotations/captions_val2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, + crop_size=crop_size, + force_no_crop=force_no_crop, + given_files=given_files, + ) def get_split(self): return "validation" diff --git a/src/metr/taming/data/conditional_builder/objects_bbox.py b/src/metr/taming/data/conditional_builder/objects_bbox.py index 15881e7..17e815f 100644 --- a/src/metr/taming/data/conditional_builder/objects_bbox.py +++ b/src/metr/taming/data/conditional_builder/objects_bbox.py @@ -1,16 +1,25 @@ from itertools import cycle -from typing import List, Tuple, Callable, Optional +from typing import Callable, List, Optional, Tuple -from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont from more_itertools.recipes import grouper +from PIL import Image as pil_image +from PIL import ImageDraw as pil_img_draw +from PIL import ImageFont +from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from taming.data.conditional_builder.utils import ( + BLACK, + COLOR_PALETTE, + GRAY_75, + WHITE, + absolute_bbox, + additional_parameters_string, + get_plot_font_size, + pad_list, +) +from taming.data.helper_types import Annotation, BoundingBox from taming.data.image_transforms import convert_pil_to_tensor from torch import LongTensor, Tensor -from taming.data.helper_types import BoundingBox, Annotation -from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder -from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ - pad_list, get_plot_font_size, absolute_bbox - class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): @property @@ -19,8 +28,7 @@ def object_descriptor_length(self) -> int: def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: object_triples = [ - (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) - for ann in annotations + (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) for ann in 
annotations ] empty_triple = (self.none, self.none, self.none) object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) @@ -36,25 +44,31 @@ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, Boundi assert conditional.shape[0] == self.embedding_dim return [ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) - for object_triple in object_triples if object_triple[0] != self.none + for object_triple in object_triples + if object_triple[0] != self.none ], crop_coordinates - def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], - line_width: int = 3, font_size: Optional[int] = None) -> Tensor: - plot = pil_image.new('RGB', figure_size, WHITE) + def plot( + self, + conditional: LongTensor, + label_for_category_no: Callable[[int], str], + figure_size: Tuple[int, int], + line_width: int = 3, + font_size: Optional[int] = None, + ) -> Tensor: + plot = pil_image.new("RGB", figure_size, WHITE) draw = pil_img_draw.Draw(plot) font = ImageFont.truetype( - "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", - size=get_plot_font_size(font_size, figure_size) + "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", size=get_plot_font_size(font_size, figure_size) ) width, height = plot.size description, crop_coordinates = self.inverse_build(conditional) for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): annotation = self.representation_to_annotation(representation) - class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) + class_label = label_for_category_no(annotation.category_no) + " " + additional_parameters_string(annotation) bbox = absolute_bbox(bbox, width, height) draw.rectangle(bbox, outline=color, width=line_width) - draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) + draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor="la", fill=BLACK, font=font) if crop_coordinates is not None: draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) - return convert_pil_to_tensor(plot) / 127.5 - 1. 
+ return convert_pil_to_tensor(plot) / 127.5 - 1.0 diff --git a/src/metr/taming/data/conditional_builder/objects_center_points.py b/src/metr/taming/data/conditional_builder/objects_center_points.py index 9a48032..f394004 100644 --- a/src/metr/taming/data/conditional_builder/objects_center_points.py +++ b/src/metr/taming/data/conditional_builder/objects_center_points.py @@ -2,21 +2,42 @@ import random import warnings from itertools import cycle -from typing import List, Optional, Tuple, Callable +from typing import Callable, List, Optional, Tuple -from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont from more_itertools.recipes import grouper -from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ - additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ - absolute_bbox, rescale_annotations -from taming.data.helper_types import BoundingBox, Annotation +from PIL import Image as pil_image +from PIL import ImageDraw as pil_img_draw +from PIL import ImageFont +from taming.data.conditional_builder.utils import ( + BLACK, + COLOR_PALETTE, + FULL_CROP, + GRAY_75, + WHITE, + absolute_bbox, + additional_parameters_string, + filter_annotations, + get_circle_size, + get_plot_font_size, + horizontally_flip_bbox, + pad_list, + rescale_annotations, +) +from taming.data.helper_types import Annotation, BoundingBox from taming.data.image_transforms import convert_pil_to_tensor from torch import LongTensor, Tensor class ObjectsCenterPointsConditionalBuilder: - def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, - use_group_parameter: bool, use_additional_parameters: bool): + def __init__( + self, + no_object_classes: int, + no_max_objects: int, + no_tokens: int, + encode_crop: bool, + use_group_parameter: bool, + use_additional_parameters: bool, + ): self.no_object_classes = no_object_classes self.no_max_objects = no_max_objects self.no_tokens = no_tokens @@ -66,11 +87,13 @@ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: return x0, y0, x1 - x0, y1 - y0 def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: - return self.tokenize_coordinates(bbox[0], bbox[1]), \ - self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) + return self.tokenize_coordinates(bbox[0], bbox[1]), self.tokenize_coordinates( + bbox[0] + bbox[2], bbox[1] + bbox[3] + ) - def inverse_build(self, conditional: LongTensor) \ - -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: + def inverse_build( + self, conditional: LongTensor + ) -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: conditional_list = conditional.tolist() crop_coordinates = None if self.encode_crop: @@ -80,28 +103,36 @@ def inverse_build(self, conditional: LongTensor) \ assert conditional.shape[0] == self.embedding_dim return [ (object_tuple[0], self.coordinates_from_token(object_tuple[1])) - for object_tuple in table_of_content if object_tuple[0] != self.none + for object_tuple in table_of_content + if object_tuple[0] != self.none ], crop_coordinates - def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], - line_width: int = 3, font_size: Optional[int] = None) -> Tensor: - plot = pil_image.new('RGB', figure_size, WHITE) + def plot( + self, + conditional: LongTensor, + label_for_category_no: Callable[[int], str], + figure_size: Tuple[int, int], + 
line_width: int = 3, + font_size: Optional[int] = None, + ) -> Tensor: + plot = pil_image.new("RGB", figure_size, WHITE) draw = pil_img_draw.Draw(plot) circle_size = get_circle_size(figure_size) - font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', - size=get_plot_font_size(font_size, figure_size)) + font = ImageFont.truetype( + "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", size=get_plot_font_size(font_size, figure_size) + ) width, height = plot.size description, crop_coordinates = self.inverse_build(conditional) for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): x_abs, y_abs = x * width, y * height ann = self.representation_to_annotation(representation) - label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) + label = label_for_category_no(ann.category_no) + " " + additional_parameters_string(ann) ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] draw.ellipse(ellipse_bbox, fill=color, width=0) - draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) + draw.text((x_abs, y_abs), label, anchor="md", fill=BLACK, font=font) if crop_coordinates is not None: draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) - return convert_pil_to_tensor(plot) / 127.5 - 1. + return convert_pil_to_tensor(plot) / 127.5 - 1.0 def object_representation(self, annotation: Annotation) -> int: modifier = 0 @@ -118,12 +149,18 @@ def representation_to_annotation(self, representation: int) -> Annotation: modifier = representation // self.no_object_classes # noinspection PyTypeChecker return Annotation( - area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, + area=None, + image_id=None, + bbox=None, + category_id=None, + id=None, + source=None, + confidence=None, category_no=category_no, is_group_of=bool((modifier & 1) * self.use_group_parameter), is_occluded=bool((modifier & 2) * self.use_additional_parameters), is_depiction=bool((modifier & 4) * self.use_additional_parameters), - is_inside=bool((modifier & 8) * self.use_additional_parameters) + is_inside=bool((modifier & 8) * self.use_additional_parameters), ) def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: @@ -131,21 +168,24 @@ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: object_tuples = [ - (self.object_representation(a), - self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) + ( + self.object_representation(a), + self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2), + ) for a in annotations ] empty_tuple = (self.none, self.none) object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) return object_tuples - def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ - -> LongTensor: + def build( + self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False + ) -> LongTensor: if len(annotations) == 0: - warnings.warn('Did not receive any annotations.') + warnings.warn("Did not receive any annotations.") if len(annotations) > self.no_max_objects: - warnings.warn('Received more annotations than allowed.') - annotations = annotations[:self.no_max_objects] + warnings.warn("Received more annotations than allowed.") + annotations = 
annotations[: self.no_max_objects] if not crop_coordinates: crop_coordinates = FULL_CROP diff --git a/src/metr/taming/data/conditional_builder/utils.py b/src/metr/taming/data/conditional_builder/utils.py index d0ee175..ae272d8 100644 --- a/src/metr/taming/data/conditional_builder/utils.py +++ b/src/metr/taming/data/conditional_builder/utils.py @@ -1,17 +1,27 @@ import importlib -from typing import List, Any, Tuple, Optional +from typing import Any, List, Optional, Tuple -from taming.data.helper_types import BoundingBox, Annotation +from taming.data.helper_types import Annotation, BoundingBox # source: seaborn, color palette tab10 -COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), - (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] +COLOR_PALETTE = [ + (30, 118, 179), + (255, 126, 13), + (43, 159, 43), + (213, 38, 39), + (147, 102, 188), + (139, 85, 74), + (226, 118, 193), + (126, 126, 126), + (187, 188, 33), + (22, 189, 206), +] BLACK = (0, 0, 0) GRAY_75 = (63, 63, 63) GRAY_50 = (127, 127, 127) GRAY_25 = (191, 191, 191) WHITE = (255, 255, 255) -FULL_CROP = (0., 0., 1., 1.) +FULL_CROP = (0.0, 0.0, 1.0, 1.0) def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: @@ -22,8 +32,8 @@ def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float """ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] - x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) - y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) + x_overlap = max(0.0, min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) + y_overlap = max(0.0, min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) return x_overlap * y_overlap @@ -41,10 +51,9 @@ def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: return list_ + [pad_element for _ in range(pad_to_length - len(list_))] -def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ - List[Annotation]: +def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> List[Annotation]: def clamp(x: float): - return max(min(x, 1.), 0.) 
+ return max(min(x, 1.0), 0.0) def rescale_bbox(bbox: BoundingBox) -> BoundingBox: x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) @@ -64,18 +73,18 @@ def filter_annotations(annotations: List[Annotation], crop_coordinates: Bounding def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: sl = slice(1) if short else slice(None) - string = '' + string = "" if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): return string if annotation.is_group_of: - string += 'group'[sl] + ',' + string += "group"[sl] + "," if annotation.is_occluded: - string += 'occluded'[sl] + ',' + string += "occluded"[sl] + "," if annotation.is_depiction: - string += 'depiction'[sl] + ',' + string += "depiction"[sl] + "," if annotation.is_inside: - string += 'inside'[sl] - return '(' + string.strip(",") + ')' + string += "inside"[sl] + return "(" + string.strip(",") + ")" def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: diff --git a/src/metr/taming/data/custom.py b/src/metr/taming/data/custom.py index 33f302a..adee08e 100644 --- a/src/metr/taming/data/custom.py +++ b/src/metr/taming/data/custom.py @@ -1,10 +1,10 @@ import os -import numpy as np + import albumentations +import numpy as np +from taming.data.base import ConcatDatasetWithIndex, ImagePaths, NumpyPaths from torch.utils.data import Dataset -from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex - class CustomBase(Dataset): def __init__(self, *args, **kwargs): @@ -19,7 +19,6 @@ def __getitem__(self, i): return example - class CustomTrain(CustomBase): def __init__(self, size, training_images_list_file): super().__init__() @@ -34,5 +33,3 @@ def __init__(self, size, test_images_list_file): with open(test_images_list_file, "r") as f: paths = f.read().splitlines() self.data = ImagePaths(paths=paths, size=size, random_crop=False) - - diff --git a/src/metr/taming/data/faceshq.py b/src/metr/taming/data/faceshq.py index 6912d04..03083e0 100644 --- a/src/metr/taming/data/faceshq.py +++ b/src/metr/taming/data/faceshq.py @@ -1,10 +1,10 @@ import os -import numpy as np + import albumentations +import numpy as np +from taming.data.base import ConcatDatasetWithIndex, ImagePaths, NumpyPaths from torch.utils.data import Dataset -from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex - class FacesBase(Dataset): def __init__(self, *args, **kwargs): @@ -78,10 +78,9 @@ def __init__(self, size, keys=None, crop_size=None, coord=False): self.data = ConcatDatasetWithIndex([d1, d2]) self.coord = coord if crop_size is not None: - self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + self.cropper = albumentations.RandomCrop(height=crop_size, width=crop_size) if self.coord: - self.cropper = albumentations.Compose([self.cropper], - additional_targets={"coord": "image"}) + self.cropper = albumentations.Compose([self.cropper], additional_targets={"coord": "image"}) def __len__(self): return len(self.data) @@ -93,8 +92,8 @@ def __getitem__(self, i): out = self.cropper(image=ex["image"]) ex["image"] = out["image"] else: - h,w,_ = ex["image"].shape - coord = np.arange(h*w).reshape(h,w,1)/(h*w) + h, w, _ = ex["image"].shape + coord = np.arange(h * w).reshape(h, w, 1) / (h * w) out = self.cropper(image=ex["image"], coord=coord) ex["image"] = out["image"] ex["coord"] = out["coord"] @@ -110,10 +109,9 @@ def __init__(self, size, keys=None, crop_size=None, coord=False): self.data = 
ConcatDatasetWithIndex([d1, d2]) self.coord = coord if crop_size is not None: - self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + self.cropper = albumentations.CenterCrop(height=crop_size, width=crop_size) if self.coord: - self.cropper = albumentations.Compose([self.cropper], - additional_targets={"coord": "image"}) + self.cropper = albumentations.Compose([self.cropper], additional_targets={"coord": "image"}) def __len__(self): return len(self.data) @@ -125,8 +123,8 @@ def __getitem__(self, i): out = self.cropper(image=ex["image"]) ex["image"] = out["image"] else: - h,w,_ = ex["image"].shape - coord = np.arange(h*w).reshape(h,w,1)/(h*w) + h, w, _ = ex["image"].shape + coord = np.arange(h * w).reshape(h, w, 1) / (h * w) out = self.cropper(image=ex["image"], coord=coord) ex["image"] = out["image"] ex["coord"] = out["coord"] diff --git a/src/metr/taming/data/helper_types.py b/src/metr/taming/data/helper_types.py index fb51e30..88f9b91 100644 --- a/src/metr/taming/data/helper_types.py +++ b/src/metr/taming/data/helper_types.py @@ -1,16 +1,17 @@ -from typing import Dict, Tuple, Optional, NamedTuple, Union +from typing import Dict, NamedTuple, Optional, Tuple, Union + from PIL.Image import Image as pil_image from torch import Tensor try: - from typing import Literal + from typing import Literal except ImportError: - from typing_extensions import Literal + from typing_extensions import Literal Image = Union[Tensor, pil_image] BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h -CropMethodType = Literal['none', 'random', 'center', 'random-2d'] -SplitType = Literal['train', 'validation', 'test'] +CropMethodType = Literal["none", "random", "center", "random-2d"] +SplitType = Literal["train", "validation", "test"] class ImageDescription(NamedTuple): diff --git a/src/metr/taming/data/image_transforms.py b/src/metr/taming/data/image_transforms.py index 657ac33..e30f715 100644 --- a/src/metr/taming/data/image_transforms.py +++ b/src/metr/taming/data/image_transforms.py @@ -3,12 +3,12 @@ from typing import Union import torch +from taming.data.helper_types import BoundingBox, Image from torch import Tensor -from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor +from torchvision.transforms import CenterCrop, PILToTensor, RandomCrop, RandomHorizontalFlip +from torchvision.transforms import functional as F from torchvision.transforms.functional import _get_image_size as get_image_size -from taming.data.helper_types import BoundingBox, Image - pil_to_tensor = PILToTensor() @@ -89,11 +89,11 @@ def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: w = height / width h = 1.0 x0 = 0.5 - w / 2 - y0 = 0. + y0 = 0.0 else: w = 1.0 h = width / height - x0 = 0. 
+ x0 = 0.0 y0 = 0.5 - h / 2 return x0, y0, w, h @@ -110,7 +110,7 @@ def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tenso torchvision.transforms.RandomHorizontalFlip (version 1.7.0) """ width, height = get_image_size(img) - return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) + return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) class RandomHorizontalFlipReturn(RandomHorizontalFlip): diff --git a/src/metr/taming/data/imagenet.py b/src/metr/taming/data/imagenet.py index 9a02ec4..2d45839 100644 --- a/src/metr/taming/data/imagenet.py +++ b/src/metr/taming/data/imagenet.py @@ -1,15 +1,18 @@ -import os, tarfile, glob, shutil -import yaml -import numpy as np -from tqdm import tqdm -from PIL import Image +import glob +import os +import shutil +import tarfile + import albumentations +import numpy as np +import taming.data.utils as bdu +import yaml from omegaconf import OmegaConf -from torch.utils.data import Dataset - +from PIL import Image from taming.data.base import ImagePaths from taming.util import download, retrieve -import taming.data.utils as bdu +from torch.utils.data import Dataset +from tqdm import tqdm def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"): @@ -41,7 +44,7 @@ def str_to_indices(string): class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config)==dict: + if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) self._prepare() self._prepare_synset_to_human() @@ -58,9 +61,11 @@ def _prepare(self): raise NotImplementedError() def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) + ignore = set( + [ + "n06596364_9591.JPEG", + ] + ) relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] if "sub_indices" in self.config: indices = str_to_indices(self.config["sub_indices"]) @@ -78,14 +83,13 @@ def _prepare_synset_to_human(self): SIZE = 2655750 URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): + if not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE: download(URL, self.human_dict) def _prepare_idx_to_synset(self): URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): + if not os.path.exists(self.idx2syn): download(URL, self.idx2syn) def _load(self): @@ -114,10 +118,9 @@ def _load(self): "class_label": np.array(self.class_labels), "human_label": np.array(self.human_labels), } - self.data = ImagePaths(self.abspaths, - labels=labels, - size=retrieve(self.config, "size", default=0), - random_crop=self.random_crop) + self.data = ImagePaths( + self.abspaths, labels=labels, size=retrieve(self.config, "size", default=0), random_crop=self.random_crop + ) class ImageNetTrain(ImageNetBase): @@ -132,8 +135,7 @@ class ImageNetTrain(ImageNetBase): ] def _prepare(self): - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) self.datadir = 
os.path.join(self.root, "data") @@ -146,8 +148,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -159,16 +162,15 @@ def _prepare(self): print("Extracting sub-tars.") subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] + subdir = subpath[: -len(".tar")] os.makedirs(subdir, exist_ok=True) with tarfile.open(subpath, "r:") as tar: tar.extractall(path=subdir) - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) @@ -190,8 +192,7 @@ class ImageNetValidation(ImageNetBase): ] def _prepare(self): - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) self.datadir = os.path.join(self.root, "data") @@ -204,8 +205,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -215,7 +217,7 @@ def _prepare(self): tar.extractall(path=datadir) vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]: download(self.VS_URL, vspath) with open(vspath, "r") as f: @@ -234,37 +236,34 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) bdu.mark_prepared(self.root) -def get_preprocessor(size=None, random_crop=False, additional_targets=None, - crop_size=None): +def get_preprocessor(size=None, random_crop=False, additional_targets=None, crop_size=None): if size is not None and size > 0: transforms = list() - rescaler = albumentations.SmallestMaxSize(max_size = size) + rescaler = albumentations.SmallestMaxSize(max_size=size) transforms.append(rescaler) if not random_crop: - cropper = albumentations.CenterCrop(height=size,width=size) + cropper = albumentations.CenterCrop(height=size, width=size) transforms.append(cropper) else: - cropper = albumentations.RandomCrop(height=size,width=size) + cropper = albumentations.RandomCrop(height=size, width=size) transforms.append(cropper) flipper = albumentations.HorizontalFlip() transforms.append(flipper) - preprocessor = albumentations.Compose(transforms, - additional_targets=additional_targets) + preprocessor = albumentations.Compose(transforms, 
additional_targets=additional_targets) elif crop_size is not None and crop_size > 0: if not random_crop: - cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + cropper = albumentations.CenterCrop(height=crop_size, width=crop_size) else: - cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + cropper = albumentations.RandomCrop(height=crop_size, width=crop_size) transforms = [cropper] - preprocessor = albumentations.Compose(transforms, - additional_targets=additional_targets) + preprocessor = albumentations.Compose(transforms, additional_targets=additional_targets) else: preprocessor = lambda **kwargs: kwargs return preprocessor @@ -280,22 +279,19 @@ def rgba_to_depth(x): class BaseWithDepth(Dataset): - DEFAULT_DEPTH_ROOT="data/imagenet_depth" + DEFAULT_DEPTH_ROOT = "data/imagenet_depth" - def __init__(self, config=None, size=None, random_crop=False, - crop_size=None, root=None): + def __init__(self, config=None, size=None, random_crop=False, crop_size=None, root=None): self.config = config self.base_dset = self.get_base_dset() self.preprocessor = get_preprocessor( - size=size, - crop_size=crop_size, - random_crop=random_crop, - additional_targets={"depth": "image"}) + size=size, crop_size=crop_size, random_crop=random_crop, additional_targets={"depth": "image"} + ) self.crop_size = crop_size if self.crop_size is not None: self.rescaler = albumentations.Compose( - [albumentations.SmallestMaxSize(max_size = self.crop_size)], - additional_targets={"depth": "image"}) + [albumentations.SmallestMaxSize(max_size=self.crop_size)], additional_targets={"depth": "image"} + ) if root is not None: self.DEFAULT_DEPTH_ROOT = root @@ -305,16 +301,16 @@ def __len__(self): def preprocess_depth(self, path): rgba = np.array(Image.open(path)) depth = rgba_to_depth(rgba) - depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) - depth = 2.0*depth-1.0 + depth = (depth - depth.min()) / max(1e-8, depth.max() - depth.min()) + depth = 2.0 * depth - 1.0 return depth def __getitem__(self, i): e = self.base_dset[i] e["depth"] = self.preprocess_depth(self.get_depth_path(e)) # up if necessary - h,w,c = e["image"].shape - if self.crop_size and min(h,w) < self.crop_size: + h, w, c = e["image"].shape + if self.crop_size and min(h, w) < self.crop_size: # have to upscale to be able to crop - this just uses bilinear out = self.rescaler(image=e["image"], depth=e["depth"]) e["image"] = out["image"] @@ -338,7 +334,7 @@ def get_base_dset(self): return ImageNetTrain({"sub_indices": self.sub_indices}) def get_depth_path(self, e): - fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.splitext(e["relpath"])[0] + ".png" fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid) return fid @@ -355,7 +351,7 @@ def get_base_dset(self): return ImageNetValidation({"sub_indices": self.sub_indices}) def get_depth_path(self, e): - fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.splitext(e["relpath"])[0] + ".png" fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid) return fid @@ -363,15 +359,17 @@ def get_depth_path(self, e): class RINTrainWithDepth(ImageNetTrainWithDepth): def __init__(self, config=None, size=None, random_crop=True, crop_size=None): sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" - super().__init__(config=config, size=size, random_crop=random_crop, - sub_indices=sub_indices, crop_size=crop_size) + super().__init__( + config=config, size=size, random_crop=random_crop, sub_indices=sub_indices, crop_size=crop_size 
+ ) class RINValidationWithDepth(ImageNetValidationWithDepth): def __init__(self, config=None, size=None, random_crop=False, crop_size=None): sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" - super().__init__(config=config, size=size, random_crop=random_crop, - sub_indices=sub_indices, crop_size=crop_size) + super().__init__( + config=config, size=size, random_crop=random_crop, sub_indices=sub_indices, crop_size=crop_size + ) class DRINExamples(Dataset): @@ -379,10 +377,8 @@ def __init__(self): self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"}) with open("data/drin_examples.txt", "r") as f: relpaths = f.read().splitlines() - self.image_paths = [os.path.join("data/drin_images", - relpath) for relpath in relpaths] - self.depth_paths = [os.path.join("data/drin_depth", - relpath.replace(".JPEG", ".png")) for relpath in relpaths] + self.image_paths = [os.path.join("data/drin_images", relpath) for relpath in relpaths] + self.depth_paths = [os.path.join("data/drin_depth", relpath.replace(".JPEG", ".png")) for relpath in relpaths] def __len__(self): return len(self.image_paths) @@ -393,14 +389,14 @@ def preprocess_image(self, image_path): image = image.convert("RGB") image = np.array(image).astype(np.uint8) image = self.preprocessor(image=image)["image"] - image = (image/127.5 - 1.0).astype(np.float32) + image = (image / 127.5 - 1.0).astype(np.float32) return image def preprocess_depth(self, path): rgba = np.array(Image.open(path)) depth = rgba_to_depth(rgba) - depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) - depth = 2.0*depth-1.0 + depth = (depth - depth.min()) / max(1e-8, depth.max() - depth.min()) + depth = 2.0 * depth - 1.0 return depth def __getitem__(self, i): @@ -414,7 +410,7 @@ def __getitem__(self, i): def imscale(x, factor, keepshapes=False, keepmode="bicubic"): - if factor is None or factor==1: + if factor is None or factor == 1: return x dtype = x.dtype @@ -422,30 +418,30 @@ def imscale(x, factor, keepshapes=False, keepmode="bicubic"): assert x.min() >= -1 assert x.max() <= 1 - keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR, - "bicubic": Image.BICUBIC}[keepmode] + keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR, "bicubic": Image.BICUBIC}[keepmode] - lr = (x+1.0)*127.5 - lr = lr.clip(0,255).astype(np.uint8) + lr = (x + 1.0) * 127.5 + lr = lr.clip(0, 255).astype(np.uint8) lr = Image.fromarray(lr) h, w, _ = x.shape - nh = h//factor - nw = w//factor + nh = h // factor + nw = w // factor assert nh > 0 and nw > 0, (nh, nw) - lr = lr.resize((nw,nh), Image.BICUBIC) + lr = lr.resize((nw, nh), Image.BICUBIC) if keepshapes: - lr = lr.resize((w,h), keepmode) - lr = np.array(lr)/127.5-1.0 + lr = lr.resize((w, h), keepmode) + lr = np.array(lr) / 127.5 - 1.0 lr = lr.astype(dtype) return lr class ImageNetScale(Dataset): - def __init__(self, size=None, crop_size=None, random_crop=False, - up_factor=None, hr_factor=None, keep_mode="bicubic"): + def __init__( + self, size=None, crop_size=None, random_crop=False, up_factor=None, hr_factor=None, keep_mode="bicubic" + ): self.base = self.get_base() self.size = size @@ -458,18 +454,18 @@ def __init__(self, size=None, crop_size=None, random_crop=False, transforms = list() if self.size is not None and self.size > 0: - rescaler = albumentations.SmallestMaxSize(max_size = self.size) + rescaler = albumentations.SmallestMaxSize(max_size=self.size) self.rescaler = rescaler transforms.append(rescaler) if self.crop_size is not None and 
self.crop_size > 0: if len(transforms) == 0: - self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size) + self.rescaler = albumentations.SmallestMaxSize(max_size=self.crop_size) if not self.random_crop: - cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size) + cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) else: - cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size) + cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) transforms.append(cropper) if len(transforms) > 0: @@ -477,8 +473,7 @@ def __init__(self, size=None, crop_size=None, random_crop=False, additional_targets = {"lr": "image"} else: additional_targets = None - self.preprocessor = albumentations.Compose(transforms, - additional_targets=additional_targets) + self.preprocessor = albumentations.Compose(transforms, additional_targets=additional_targets) else: self.preprocessor = lambda **kwargs: kwargs @@ -490,16 +485,15 @@ def __getitem__(self, i): image = example["image"] # adjust resolution image = imscale(image, self.hr_factor, keepshapes=False) - h,w,c = image.shape - if self.crop_size and min(h,w) < self.crop_size: + h, w, c = image.shape + if self.crop_size and min(h, w) < self.crop_size: # have to upscale to be able to crop - this just uses bilinear image = self.rescaler(image=image)["image"] if self.up_factor is None: image = self.preprocessor(image=image)["image"] example["image"] = image else: - lr = imscale(image, self.up_factor, keepshapes=True, - keepmode=self.keep_mode) + lr = imscale(image, self.up_factor, keepshapes=True, keepmode=self.keep_mode) out = self.preprocessor(image=image, lr=lr) example["image"] = out["image"] @@ -507,6 +501,7 @@ def __getitem__(self, i): return example + class ImageNetScaleTrain(ImageNetScale): def __init__(self, random_crop=True, **kwargs): super().__init__(random_crop=random_crop, **kwargs) @@ -514,13 +509,14 @@ def __init__(self, random_crop=True, **kwargs): def get_base(self): return ImageNetTrain() + class ImageNetScaleValidation(ImageNetScale): def get_base(self): return ImageNetValidation() -from skimage.feature import canny from skimage.color import rgb2gray +from skimage.feature import canny class ImageNetEdges(ImageNetScale): @@ -530,14 +526,14 @@ def __init__(self, up_factor=1, **kwargs): def __getitem__(self, i): example = self.base[i] image = example["image"] - h,w,c = image.shape - if self.crop_size and min(h,w) < self.crop_size: + h, w, c = image.shape + if self.crop_size and min(h, w) < self.crop_size: # have to upscale to be able to crop - this just uses bilinear image = self.rescaler(image=image)["image"] lr = canny(rgb2gray(image), sigma=2) lr = lr.astype(np.float32) - lr = lr[:,:,None][:,:,[0,0,0]] + lr = lr[:, :, None][:, :, [0, 0, 0]] out = self.preprocessor(image=image, lr=lr) example["image"] = out["image"] @@ -553,6 +549,7 @@ def __init__(self, random_crop=True, **kwargs): def get_base(self): return ImageNetTrain() + class ImageNetEdgesValidation(ImageNetEdges): def get_base(self): return ImageNetValidation() diff --git a/src/metr/taming/data/open_images_helper.py b/src/metr/taming/data/open_images_helper.py index 8feb7c6..8aa7d31 100644 --- a/src/metr/taming/data/open_images_helper.py +++ b/src/metr/taming/data/open_images_helper.py @@ -1,379 +1,372 @@ open_images_unify_categories_for_coco = { - '/m/03bt1vf': '/m/01g317', - '/m/04yx4': '/m/01g317', - '/m/05r655': '/m/01g317', - '/m/01bl7v': '/m/01g317', - '/m/0cnyhnx': 
'/m/01xq0k1', - '/m/01226z': '/m/018xm', - '/m/05ctyq': '/m/018xm', - '/m/058qzx': '/m/04ctx', - '/m/06pcq': '/m/0l515', - '/m/03m3pdh': '/m/02crq1', - '/m/046dlr': '/m/01x3z', - '/m/0h8mzrc': '/m/01x3z', + "/m/03bt1vf": "/m/01g317", + "/m/04yx4": "/m/01g317", + "/m/05r655": "/m/01g317", + "/m/01bl7v": "/m/01g317", + "/m/0cnyhnx": "/m/01xq0k1", + "/m/01226z": "/m/018xm", + "/m/05ctyq": "/m/018xm", + "/m/058qzx": "/m/04ctx", + "/m/06pcq": "/m/0l515", + "/m/03m3pdh": "/m/02crq1", + "/m/046dlr": "/m/01x3z", + "/m/0h8mzrc": "/m/01x3z", } top_300_classes_plus_coco_compatibility = [ - ('Man', 1060962), - ('Clothing', 986610), - ('Tree', 748162), - ('Woman', 611896), - ('Person', 610294), - ('Human face', 442948), - ('Girl', 175399), - ('Building', 162147), - ('Car', 159135), - ('Plant', 155704), - ('Human body', 137073), - ('Flower', 133128), - ('Window', 127485), - ('Human arm', 118380), - ('House', 114365), - ('Wheel', 111684), - ('Suit', 99054), - ('Human hair', 98089), - ('Human head', 92763), - ('Chair', 88624), - ('Boy', 79849), - ('Table', 73699), - ('Jeans', 57200), - ('Tire', 55725), - ('Skyscraper', 53321), - ('Food', 52400), - ('Footwear', 50335), - ('Dress', 50236), - ('Human leg', 47124), - ('Toy', 46636), - ('Tower', 45605), - ('Boat', 43486), - ('Land vehicle', 40541), - ('Bicycle wheel', 34646), - ('Palm tree', 33729), - ('Fashion accessory', 32914), - ('Glasses', 31940), - ('Bicycle', 31409), - ('Furniture', 30656), - ('Sculpture', 29643), - ('Bottle', 27558), - ('Dog', 26980), - ('Snack', 26796), - ('Human hand', 26664), - ('Bird', 25791), - ('Book', 25415), - ('Guitar', 24386), - ('Jacket', 23998), - ('Poster', 22192), - ('Dessert', 21284), - ('Baked goods', 20657), - ('Drink', 19754), - ('Flag', 18588), - ('Houseplant', 18205), - ('Tableware', 17613), - ('Airplane', 17218), - ('Door', 17195), - ('Sports uniform', 17068), - ('Shelf', 16865), - ('Drum', 16612), - ('Vehicle', 16542), - ('Microphone', 15269), - ('Street light', 14957), - ('Cat', 14879), - ('Fruit', 13684), - ('Fast food', 13536), - ('Animal', 12932), - ('Vegetable', 12534), - ('Train', 12358), - ('Horse', 11948), - ('Flowerpot', 11728), - ('Motorcycle', 11621), - ('Fish', 11517), - ('Desk', 11405), - ('Helmet', 10996), - ('Truck', 10915), - ('Bus', 10695), - ('Hat', 10532), - ('Auto part', 10488), - ('Musical instrument', 10303), - ('Sunglasses', 10207), - ('Picture frame', 10096), - ('Sports equipment', 10015), - ('Shorts', 9999), - ('Wine glass', 9632), - ('Duck', 9242), - ('Wine', 9032), - ('Rose', 8781), - ('Tie', 8693), - ('Butterfly', 8436), - ('Beer', 7978), - ('Cabinetry', 7956), - ('Laptop', 7907), - ('Insect', 7497), - ('Goggles', 7363), - ('Shirt', 7098), - ('Dairy Product', 7021), - ('Marine invertebrates', 7014), - ('Cattle', 7006), - ('Trousers', 6903), - ('Van', 6843), - ('Billboard', 6777), - ('Balloon', 6367), - ('Human nose', 6103), - ('Tent', 6073), - ('Camera', 6014), - ('Doll', 6002), - ('Coat', 5951), - ('Mobile phone', 5758), - ('Swimwear', 5729), - ('Strawberry', 5691), - ('Stairs', 5643), - ('Goose', 5599), - ('Umbrella', 5536), - ('Cake', 5508), - ('Sun hat', 5475), - ('Bench', 5310), - ('Bookcase', 5163), - ('Bee', 5140), - ('Computer monitor', 5078), - ('Hiking equipment', 4983), - ('Office building', 4981), - ('Coffee cup', 4748), - ('Curtain', 4685), - ('Plate', 4651), - ('Box', 4621), - ('Tomato', 4595), - ('Coffee table', 4529), - ('Office supplies', 4473), - ('Maple', 4416), - ('Muffin', 4365), - ('Cocktail', 4234), - ('Castle', 4197), - ('Couch', 4134), - ('Pumpkin', 3983), - 
('Computer keyboard', 3960), - ('Human mouth', 3926), - ('Christmas tree', 3893), - ('Mushroom', 3883), - ('Swimming pool', 3809), - ('Pastry', 3799), - ('Lavender (Plant)', 3769), - ('Football helmet', 3732), - ('Bread', 3648), - ('Traffic sign', 3628), - ('Common sunflower', 3597), - ('Television', 3550), - ('Bed', 3525), - ('Cookie', 3485), - ('Fountain', 3484), - ('Paddle', 3447), - ('Bicycle helmet', 3429), - ('Porch', 3420), - ('Deer', 3387), - ('Fedora', 3339), - ('Canoe', 3338), - ('Carnivore', 3266), - ('Bowl', 3202), - ('Human eye', 3166), - ('Ball', 3118), - ('Pillow', 3077), - ('Salad', 3061), - ('Beetle', 3060), - ('Orange', 3050), - ('Drawer', 2958), - ('Platter', 2937), - ('Elephant', 2921), - ('Seafood', 2921), - ('Monkey', 2915), - ('Countertop', 2879), - ('Watercraft', 2831), - ('Helicopter', 2805), - ('Kitchen appliance', 2797), - ('Personal flotation device', 2781), - ('Swan', 2739), - ('Lamp', 2711), - ('Boot', 2695), - ('Bronze sculpture', 2693), - ('Chicken', 2677), - ('Taxi', 2643), - ('Juice', 2615), - ('Cowboy hat', 2604), - ('Apple', 2600), - ('Tin can', 2590), - ('Necklace', 2564), - ('Ice cream', 2560), - ('Human beard', 2539), - ('Coin', 2536), - ('Candle', 2515), - ('Cart', 2512), - ('High heels', 2441), - ('Weapon', 2433), - ('Handbag', 2406), - ('Penguin', 2396), - ('Rifle', 2352), - ('Violin', 2336), - ('Skull', 2304), - ('Lantern', 2285), - ('Scarf', 2269), - ('Saucer', 2225), - ('Sheep', 2215), - ('Vase', 2189), - ('Lily', 2180), - ('Mug', 2154), - ('Parrot', 2140), - ('Human ear', 2137), - ('Sandal', 2115), - ('Lizard', 2100), - ('Kitchen & dining room table', 2063), - ('Spider', 1977), - ('Coffee', 1974), - ('Goat', 1926), - ('Squirrel', 1922), - ('Cello', 1913), - ('Sushi', 1881), - ('Tortoise', 1876), - ('Pizza', 1870), - ('Studio couch', 1864), - ('Barrel', 1862), - ('Cosmetics', 1841), - ('Moths and butterflies', 1841), - ('Convenience store', 1817), - ('Watch', 1792), - ('Home appliance', 1786), - ('Harbor seal', 1780), - ('Luggage and bags', 1756), - ('Vehicle registration plate', 1754), - ('Shrimp', 1751), - ('Jellyfish', 1730), - ('French fries', 1723), - ('Egg (Food)', 1698), - ('Football', 1697), - ('Musical keyboard', 1683), - ('Falcon', 1674), - ('Candy', 1660), - ('Medical equipment', 1654), - ('Eagle', 1651), - ('Dinosaur', 1634), - ('Surfboard', 1630), - ('Tank', 1628), - ('Grape', 1624), - ('Lion', 1624), - ('Owl', 1622), - ('Ski', 1613), - ('Waste container', 1606), - ('Frog', 1591), - ('Sparrow', 1585), - ('Rabbit', 1581), - ('Pen', 1546), - ('Sea lion', 1537), - ('Spoon', 1521), - ('Sink', 1512), - ('Teddy bear', 1507), - ('Bull', 1495), - ('Sofa bed', 1490), - ('Dragonfly', 1479), - ('Brassiere', 1478), - ('Chest of drawers', 1472), - ('Aircraft', 1466), - ('Human foot', 1463), - ('Pig', 1455), - ('Fork', 1454), - ('Antelope', 1438), - ('Tripod', 1427), - ('Tool', 1424), - ('Cheese', 1422), - ('Lemon', 1397), - ('Hamburger', 1393), - ('Dolphin', 1390), - ('Mirror', 1390), - ('Marine mammal', 1387), - ('Giraffe', 1385), - ('Snake', 1368), - ('Gondola', 1364), - ('Wheelchair', 1360), - ('Piano', 1358), - ('Cupboard', 1348), - ('Banana', 1345), - ('Trumpet', 1335), - ('Lighthouse', 1333), - ('Invertebrate', 1317), - ('Carrot', 1268), - ('Sock', 1260), - ('Tiger', 1241), - ('Camel', 1224), - ('Parachute', 1224), - ('Bathroom accessory', 1223), - ('Earrings', 1221), - ('Headphones', 1218), - ('Skirt', 1198), - ('Skateboard', 1190), - ('Sandwich', 1148), - ('Saxophone', 1141), - ('Goldfish', 1136), - ('Stool', 1104), - ('Traffic light', 
1097), - ('Shellfish', 1081), - ('Backpack', 1079), - ('Sea turtle', 1078), - ('Cucumber', 1075), - ('Tea', 1051), - ('Toilet', 1047), - ('Roller skates', 1040), - ('Mule', 1039), - ('Bust', 1031), - ('Broccoli', 1030), - ('Crab', 1020), - ('Oyster', 1019), - ('Cannon', 1012), - ('Zebra', 1012), - ('French horn', 1008), - ('Grapefruit', 998), - ('Whiteboard', 997), - ('Zucchini', 997), - ('Crocodile', 992), - - ('Clock', 960), - ('Wall clock', 958), - - ('Doughnut', 869), - ('Snail', 868), - - ('Baseball glove', 859), - - ('Panda', 830), - ('Tennis racket', 830), - - ('Pear', 652), - - ('Bagel', 617), - ('Oven', 616), - ('Ladybug', 615), - ('Shark', 615), - ('Polar bear', 614), - ('Ostrich', 609), - - ('Hot dog', 473), - ('Microwave oven', 467), - ('Fire hydrant', 20), - ('Stop sign', 20), - ('Parking meter', 20), - ('Bear', 20), - ('Flying disc', 20), - ('Snowboard', 20), - ('Tennis ball', 20), - ('Kite', 20), - ('Baseball bat', 20), - ('Kitchen knife', 20), - ('Knife', 20), - ('Submarine sandwich', 20), - ('Computer mouse', 20), - ('Remote control', 20), - ('Toaster', 20), - ('Sink', 20), - ('Refrigerator', 20), - ('Alarm clock', 20), - ('Wall clock', 20), - ('Scissors', 20), - ('Hair dryer', 20), - ('Toothbrush', 20), - ('Suitcase', 20) + ("Man", 1060962), + ("Clothing", 986610), + ("Tree", 748162), + ("Woman", 611896), + ("Person", 610294), + ("Human face", 442948), + ("Girl", 175399), + ("Building", 162147), + ("Car", 159135), + ("Plant", 155704), + ("Human body", 137073), + ("Flower", 133128), + ("Window", 127485), + ("Human arm", 118380), + ("House", 114365), + ("Wheel", 111684), + ("Suit", 99054), + ("Human hair", 98089), + ("Human head", 92763), + ("Chair", 88624), + ("Boy", 79849), + ("Table", 73699), + ("Jeans", 57200), + ("Tire", 55725), + ("Skyscraper", 53321), + ("Food", 52400), + ("Footwear", 50335), + ("Dress", 50236), + ("Human leg", 47124), + ("Toy", 46636), + ("Tower", 45605), + ("Boat", 43486), + ("Land vehicle", 40541), + ("Bicycle wheel", 34646), + ("Palm tree", 33729), + ("Fashion accessory", 32914), + ("Glasses", 31940), + ("Bicycle", 31409), + ("Furniture", 30656), + ("Sculpture", 29643), + ("Bottle", 27558), + ("Dog", 26980), + ("Snack", 26796), + ("Human hand", 26664), + ("Bird", 25791), + ("Book", 25415), + ("Guitar", 24386), + ("Jacket", 23998), + ("Poster", 22192), + ("Dessert", 21284), + ("Baked goods", 20657), + ("Drink", 19754), + ("Flag", 18588), + ("Houseplant", 18205), + ("Tableware", 17613), + ("Airplane", 17218), + ("Door", 17195), + ("Sports uniform", 17068), + ("Shelf", 16865), + ("Drum", 16612), + ("Vehicle", 16542), + ("Microphone", 15269), + ("Street light", 14957), + ("Cat", 14879), + ("Fruit", 13684), + ("Fast food", 13536), + ("Animal", 12932), + ("Vegetable", 12534), + ("Train", 12358), + ("Horse", 11948), + ("Flowerpot", 11728), + ("Motorcycle", 11621), + ("Fish", 11517), + ("Desk", 11405), + ("Helmet", 10996), + ("Truck", 10915), + ("Bus", 10695), + ("Hat", 10532), + ("Auto part", 10488), + ("Musical instrument", 10303), + ("Sunglasses", 10207), + ("Picture frame", 10096), + ("Sports equipment", 10015), + ("Shorts", 9999), + ("Wine glass", 9632), + ("Duck", 9242), + ("Wine", 9032), + ("Rose", 8781), + ("Tie", 8693), + ("Butterfly", 8436), + ("Beer", 7978), + ("Cabinetry", 7956), + ("Laptop", 7907), + ("Insect", 7497), + ("Goggles", 7363), + ("Shirt", 7098), + ("Dairy Product", 7021), + ("Marine invertebrates", 7014), + ("Cattle", 7006), + ("Trousers", 6903), + ("Van", 6843), + ("Billboard", 6777), + ("Balloon", 6367), + ("Human nose", 6103), 
+ ("Tent", 6073), + ("Camera", 6014), + ("Doll", 6002), + ("Coat", 5951), + ("Mobile phone", 5758), + ("Swimwear", 5729), + ("Strawberry", 5691), + ("Stairs", 5643), + ("Goose", 5599), + ("Umbrella", 5536), + ("Cake", 5508), + ("Sun hat", 5475), + ("Bench", 5310), + ("Bookcase", 5163), + ("Bee", 5140), + ("Computer monitor", 5078), + ("Hiking equipment", 4983), + ("Office building", 4981), + ("Coffee cup", 4748), + ("Curtain", 4685), + ("Plate", 4651), + ("Box", 4621), + ("Tomato", 4595), + ("Coffee table", 4529), + ("Office supplies", 4473), + ("Maple", 4416), + ("Muffin", 4365), + ("Cocktail", 4234), + ("Castle", 4197), + ("Couch", 4134), + ("Pumpkin", 3983), + ("Computer keyboard", 3960), + ("Human mouth", 3926), + ("Christmas tree", 3893), + ("Mushroom", 3883), + ("Swimming pool", 3809), + ("Pastry", 3799), + ("Lavender (Plant)", 3769), + ("Football helmet", 3732), + ("Bread", 3648), + ("Traffic sign", 3628), + ("Common sunflower", 3597), + ("Television", 3550), + ("Bed", 3525), + ("Cookie", 3485), + ("Fountain", 3484), + ("Paddle", 3447), + ("Bicycle helmet", 3429), + ("Porch", 3420), + ("Deer", 3387), + ("Fedora", 3339), + ("Canoe", 3338), + ("Carnivore", 3266), + ("Bowl", 3202), + ("Human eye", 3166), + ("Ball", 3118), + ("Pillow", 3077), + ("Salad", 3061), + ("Beetle", 3060), + ("Orange", 3050), + ("Drawer", 2958), + ("Platter", 2937), + ("Elephant", 2921), + ("Seafood", 2921), + ("Monkey", 2915), + ("Countertop", 2879), + ("Watercraft", 2831), + ("Helicopter", 2805), + ("Kitchen appliance", 2797), + ("Personal flotation device", 2781), + ("Swan", 2739), + ("Lamp", 2711), + ("Boot", 2695), + ("Bronze sculpture", 2693), + ("Chicken", 2677), + ("Taxi", 2643), + ("Juice", 2615), + ("Cowboy hat", 2604), + ("Apple", 2600), + ("Tin can", 2590), + ("Necklace", 2564), + ("Ice cream", 2560), + ("Human beard", 2539), + ("Coin", 2536), + ("Candle", 2515), + ("Cart", 2512), + ("High heels", 2441), + ("Weapon", 2433), + ("Handbag", 2406), + ("Penguin", 2396), + ("Rifle", 2352), + ("Violin", 2336), + ("Skull", 2304), + ("Lantern", 2285), + ("Scarf", 2269), + ("Saucer", 2225), + ("Sheep", 2215), + ("Vase", 2189), + ("Lily", 2180), + ("Mug", 2154), + ("Parrot", 2140), + ("Human ear", 2137), + ("Sandal", 2115), + ("Lizard", 2100), + ("Kitchen & dining room table", 2063), + ("Spider", 1977), + ("Coffee", 1974), + ("Goat", 1926), + ("Squirrel", 1922), + ("Cello", 1913), + ("Sushi", 1881), + ("Tortoise", 1876), + ("Pizza", 1870), + ("Studio couch", 1864), + ("Barrel", 1862), + ("Cosmetics", 1841), + ("Moths and butterflies", 1841), + ("Convenience store", 1817), + ("Watch", 1792), + ("Home appliance", 1786), + ("Harbor seal", 1780), + ("Luggage and bags", 1756), + ("Vehicle registration plate", 1754), + ("Shrimp", 1751), + ("Jellyfish", 1730), + ("French fries", 1723), + ("Egg (Food)", 1698), + ("Football", 1697), + ("Musical keyboard", 1683), + ("Falcon", 1674), + ("Candy", 1660), + ("Medical equipment", 1654), + ("Eagle", 1651), + ("Dinosaur", 1634), + ("Surfboard", 1630), + ("Tank", 1628), + ("Grape", 1624), + ("Lion", 1624), + ("Owl", 1622), + ("Ski", 1613), + ("Waste container", 1606), + ("Frog", 1591), + ("Sparrow", 1585), + ("Rabbit", 1581), + ("Pen", 1546), + ("Sea lion", 1537), + ("Spoon", 1521), + ("Sink", 1512), + ("Teddy bear", 1507), + ("Bull", 1495), + ("Sofa bed", 1490), + ("Dragonfly", 1479), + ("Brassiere", 1478), + ("Chest of drawers", 1472), + ("Aircraft", 1466), + ("Human foot", 1463), + ("Pig", 1455), + ("Fork", 1454), + ("Antelope", 1438), + ("Tripod", 1427), + ("Tool", 1424), + 
("Cheese", 1422), + ("Lemon", 1397), + ("Hamburger", 1393), + ("Dolphin", 1390), + ("Mirror", 1390), + ("Marine mammal", 1387), + ("Giraffe", 1385), + ("Snake", 1368), + ("Gondola", 1364), + ("Wheelchair", 1360), + ("Piano", 1358), + ("Cupboard", 1348), + ("Banana", 1345), + ("Trumpet", 1335), + ("Lighthouse", 1333), + ("Invertebrate", 1317), + ("Carrot", 1268), + ("Sock", 1260), + ("Tiger", 1241), + ("Camel", 1224), + ("Parachute", 1224), + ("Bathroom accessory", 1223), + ("Earrings", 1221), + ("Headphones", 1218), + ("Skirt", 1198), + ("Skateboard", 1190), + ("Sandwich", 1148), + ("Saxophone", 1141), + ("Goldfish", 1136), + ("Stool", 1104), + ("Traffic light", 1097), + ("Shellfish", 1081), + ("Backpack", 1079), + ("Sea turtle", 1078), + ("Cucumber", 1075), + ("Tea", 1051), + ("Toilet", 1047), + ("Roller skates", 1040), + ("Mule", 1039), + ("Bust", 1031), + ("Broccoli", 1030), + ("Crab", 1020), + ("Oyster", 1019), + ("Cannon", 1012), + ("Zebra", 1012), + ("French horn", 1008), + ("Grapefruit", 998), + ("Whiteboard", 997), + ("Zucchini", 997), + ("Crocodile", 992), + ("Clock", 960), + ("Wall clock", 958), + ("Doughnut", 869), + ("Snail", 868), + ("Baseball glove", 859), + ("Panda", 830), + ("Tennis racket", 830), + ("Pear", 652), + ("Bagel", 617), + ("Oven", 616), + ("Ladybug", 615), + ("Shark", 615), + ("Polar bear", 614), + ("Ostrich", 609), + ("Hot dog", 473), + ("Microwave oven", 467), + ("Fire hydrant", 20), + ("Stop sign", 20), + ("Parking meter", 20), + ("Bear", 20), + ("Flying disc", 20), + ("Snowboard", 20), + ("Tennis ball", 20), + ("Kite", 20), + ("Baseball bat", 20), + ("Kitchen knife", 20), + ("Knife", 20), + ("Submarine sandwich", 20), + ("Computer mouse", 20), + ("Remote control", 20), + ("Toaster", 20), + ("Sink", 20), + ("Refrigerator", 20), + ("Alarm clock", 20), + ("Wall clock", 20), + ("Scissors", 20), + ("Hair dryer", 20), + ("Toothbrush", 20), + ("Suitcase", 20), ] diff --git a/src/metr/taming/data/sflckr.py b/src/metr/taming/data/sflckr.py index 91101be..f49e450 100644 --- a/src/metr/taming/data/sflckr.py +++ b/src/metr/taming/data/sflckr.py @@ -1,17 +1,24 @@ import os -import numpy as np -import cv2 + import albumentations +import cv2 +import numpy as np from PIL import Image from torch.utils.data import Dataset class SegmentationBase(Dataset): - def __init__(self, - data_csv, data_root, segmentation_root, - size=None, random_crop=False, interpolation="bicubic", - n_labels=182, shift_segmentation=False, - ): + def __init__( + self, + data_csv, + data_root, + segmentation_root, + size=None, + random_crop=False, + interpolation="bicubic", + n_labels=182, + shift_segmentation=False, + ): self.n_labels = n_labels self.shift_segmentation = shift_segmentation self.data_csv = data_csv @@ -22,13 +29,13 @@ def __init__(self, self._length = len(self.image_paths) self.labels = { "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], - "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) - for l in self.image_paths] + "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], + "segmentation_path_": [ + os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) for l in self.image_paths + ], } - size = None if size is not None and size<=0 else size + size = None if size is not None and size <= 0 else size self.size = size if self.size is not None: self.interpolation = interpolation @@ -37,11 +44,12 @@ def __init__(self, "bilinear": 
cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC, "area": cv2.INTER_AREA, - "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] - self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=self.interpolation) - self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=cv2.INTER_NEAREST) + "lanczos": cv2.INTER_LANCZOS4, + }[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize( + max_size=self.size, interpolation=cv2.INTER_NEAREST + ) self.center_crop = not random_crop if self.center_crop: self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) @@ -65,18 +73,14 @@ def __getitem__(self, i): segmentation = np.array(segmentation).astype(np.uint8) if self.shift_segmentation: # used to support segmentations containing unlabeled==255 label - segmentation = segmentation+1 + segmentation = segmentation + 1 if self.size is not None: segmentation = self.segmentation_rescaler(image=segmentation)["image"] if self.size is not None: - processed = self.preprocessor(image=image, - mask=segmentation - ) + processed = self.preprocessor(image=image, mask=segmentation) else: - processed = {"image": image, - "mask": segmentation - } - example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + processed = {"image": image, "mask": segmentation} + example["image"] = (processed["image"] / 127.5 - 1.0).astype(np.float32) segmentation = processed["mask"] onehot = np.eye(self.n_labels)[segmentation] example["segmentation"] = onehot @@ -85,7 +89,11 @@ def __getitem__(self, i): class Examples(SegmentationBase): def __init__(self, size=None, random_crop=False, interpolation="bicubic"): - super().__init__(data_csv="data/sflckr_examples.txt", - data_root="data/sflckr_images", - segmentation_root="data/sflckr_segmentations", - size=size, random_crop=random_crop, interpolation=interpolation) + super().__init__( + data_csv="data/sflckr_examples.txt", + data_root="data/sflckr_images", + segmentation_root="data/sflckr_segmentations", + size=size, + random_crop=random_crop, + interpolation=interpolation, + ) diff --git a/src/metr/taming/data/utils.py b/src/metr/taming/data/utils.py index 2b3c3d5..9dffe55 100644 --- a/src/metr/taming/data/utils.py +++ b/src/metr/taming/data/utils.py @@ -9,7 +9,7 @@ import torch from taming.data.helper_types import Annotation from torch._six import string_classes -from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format +from torch.utils.data._utils.collate import default_collate_err_msg_format, np_str_obj_array_pattern from tqdm import tqdm @@ -24,9 +24,7 @@ def unpack(path): with zipfile.ZipFile(path, "r") as f: f.extractall(path=os.path.split(path)[0]) else: - raise NotImplementedError( - "Unknown file extension: {}".format(os.path.splitext(path)[1]) - ) + raise NotImplementedError("Unknown file extension: {}".format(os.path.splitext(path)[1])) def reporthook(bar): @@ -58,19 +56,11 @@ def mark_prepared(root): def prompt_download(file_, source, target_dir, content_dir=None): targetpath = os.path.join(target_dir, file_) while not os.path.exists(targetpath): - if content_dir is not None and os.path.exists( - os.path.join(target_dir, content_dir) - ): + if content_dir is not None and os.path.exists(os.path.join(target_dir, content_dir)): break - print( - "Please download '{}' from '{}' to '{}'.".format(file_, source, 
targetpath) - ) + print("Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)) if content_dir is not None: - print( - "Or place its content into '{}'.".format( - os.path.join(target_dir, content_dir) - ) - ) + print("Or place its content into '{}'.".format(os.path.join(target_dir, content_dir))) input("Press Enter when done...") return targetpath @@ -78,9 +68,7 @@ def prompt_download(file_, source, target_dir, content_dir=None): def download_url(file_, url, target_dir): targetpath = os.path.join(target_dir, file_) os.makedirs(target_dir, exist_ok=True) - with tqdm( - unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ - ) as bar: + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_) as bar: urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) return targetpath @@ -104,9 +92,7 @@ def quadratic_crop(x, bbox, alpha=1.0): l = int(alpha * max(w, h)) l = max(l, 2) - required_padding = -1 * min( - center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) - ) + required_padding = -1 * min(center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)) required_padding = int(np.ceil(required_padding)) if required_padding > 0: padding = [ @@ -122,7 +108,7 @@ def quadratic_crop(x, bbox, alpha=1.0): def custom_collate(batch): - r"""source: pytorch 1.9.0, only one modification to original code """ + r"""source: pytorch 1.9.0, only one modification to original code""" elem = batch[0] elem_type = type(elem) @@ -135,9 +121,8 @@ def custom_collate(batch): storage = elem.storage()._new_shared(numel) out = elem.new(storage) return torch.stack(batch, 0, out=out) - elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ - and elem_type.__name__ != 'string_': - if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": + if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: raise TypeError(default_collate_err_msg_format.format(elem.dtype)) @@ -153,7 +138,7 @@ def custom_collate(batch): return batch elif isinstance(elem, collections.abc.Mapping): return {key: custom_collate([d[key] for d in batch]) for key in elem} - elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple return elem_type(*(custom_collate(samples) for samples in zip(*batch))) if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added return batch # added @@ -162,7 +147,7 @@ def custom_collate(batch): it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): - raise RuntimeError('each element in list of batch should be of equal size') + raise RuntimeError("each element in list of batch should be of equal size") transposed = zip(*batch) return [custom_collate(samples) for samples in transposed] diff --git a/src/metr/taming/lr_scheduler.py b/src/metr/taming/lr_scheduler.py index e598ed1..d6266c3 100644 --- a/src/metr/taming/lr_scheduler.py +++ b/src/metr/taming/lr_scheduler.py @@ -5,18 +5,20 @@ class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): self.lr_warm_up_steps = warm_up_steps self.lr_start = 
lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. + self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n): if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr @@ -24,11 +26,9 @@ def schedule(self, n): else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) self.last_lr = lr return lr def __call__(self, n): return self.schedule(n) - diff --git a/src/metr/taming/models/cond_transformer.py b/src/metr/taming/models/cond_transformer.py index e4c6373..af549a8 100644 --- a/src/metr/taming/models/cond_transformer.py +++ b/src/metr/taming/models/cond_transformer.py @@ -1,8 +1,9 @@ -import os, math +import math +import os + +import pytorch_lightning as pl import torch import torch.nn.functional as F -import pytorch_lightning as pl - from main import instantiate_from_config from taming.modules.util import SOSProvider @@ -14,20 +15,21 @@ def disabled_train(self, mode=True): class Net2NetTransformer(pl.LightningModule): - def __init__(self, - transformer_config, - first_stage_config, - cond_stage_config, - permuter_config=None, - ckpt_path=None, - ignore_keys=[], - first_stage_key="image", - cond_stage_key="depth", - downsample_cond_size=-1, - pkeep=1.0, - sos_token=0, - unconditional=False, - ): + def __init__( + self, + transformer_config, + first_stage_config, + cond_stage_config, + permuter_config=None, + ckpt_path=None, + ignore_keys=[], + first_stage_key="image", + cond_stage_key="depth", + downsample_cond_size=-1, + pkeep=1.0, + sos_token=0, + unconditional=False, + ): super().__init__() self.be_unconditional = unconditional self.sos_token = sos_token @@ -66,8 +68,10 @@ def init_cond_stage_from_ckpt(self, config): print("Using first stage also as cond stage.") self.cond_stage_model = self.first_stage_model elif config == "__is_unconditional__" or self.be_unconditional: - print(f"Using no cond stage. Assuming the training is intended to be unconditional. " - f"Prepending {self.sos_token} as a sos token.") + print( + f"Using no cond stage. Assuming the training is intended to be unconditional. " + f"Prepending {self.sos_token} as a sos token." 
+ ) self.be_unconditional = True self.cond_stage_key = self.first_stage_key self.cond_stage_model = SOSProvider(self.sos_token) @@ -83,11 +87,10 @@ def forward(self, x, c): _, c_indices = self.encode_to_c(c) if self.training and self.pkeep < 1.0: - mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, - device=z_indices.device)) + mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape, device=z_indices.device)) mask = mask.round().to(dtype=torch.int64) r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) - a_indices = mask*z_indices+(1-mask)*r_indices + a_indices = mask * z_indices + (1 - mask) * r_indices else: a_indices = z_indices @@ -99,29 +102,28 @@ def forward(self, x, c): # make the prediction logits, _ = self.transformer(cz_indices[:, :-1]) # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1: c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size)) - quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c) + quant_c, _, [_, _, indices] = self.cond_stage_model.encode(c) if len(indices.shape) > 2: indices = indices.view(c.shape[0], -1) return quant_c, indices @@ -184,9 +186,8 @@ def encode_to_c(self, c): @torch.no_grad() def decode_to_img(self, index, zshape): index = self.permuter(index, reverse=True) - bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) - quant_z = self.first_stage_model.quantize.get_codebook_entry( - index.reshape(-1), shape=bhwc) + bhwc = (zshape[0], zshape[2], zshape[3], zshape[1]) + quant_z = self.first_stage_model.quantize.get_codebook_entry(index.reshape(-1), shape=bhwc) x = self.first_stage_model.decode(quant_z) return x @@ -206,31 +207,40 @@ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_inte quant_c, c_indices = self.encode_to_c(c) # create a "half"" sample - z_start_indices = z_indices[:,:z_indices.shape[1]//2] - index_sample = self.sample(z_start_indices, c_indices, - steps=z_indices.shape[1]-z_start_indices.shape[1], - temperature=temperature if temperature is not None else 1.0, - sample=True, - top_k=top_k if top_k is not None else 100, - callback=callback if callback is not None else lambda k: None) + z_start_indices = z_indices[:, : z_indices.shape[1] // 2] + index_sample = self.sample( + z_start_indices, + c_indices, + steps=z_indices.shape[1] - z_start_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None, + ) x_sample = self.decode_to_img(index_sample, quant_z.shape) # sample z_start_indices = z_indices[:, :0] - index_sample = self.sample(z_start_indices, c_indices, - steps=z_indices.shape[1], - temperature=temperature if temperature is not None else 1.0, - sample=True, - top_k=top_k if top_k is not None else 100, - callback=callback if callback is not None else lambda k: None) + index_sample = self.sample( + z_start_indices, + c_indices, + steps=z_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None, + ) x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape) # det sample z_start_indices = z_indices[:, :0] - index_sample = self.sample(z_start_indices, c_indices, - steps=z_indices.shape[1], - sample=False, - callback=callback if callback is not None else lambda k: None) + index_sample = self.sample( + z_start_indices, + 
c_indices, + steps=z_indices.shape[1], + sample=False, + callback=callback if callback is not None else lambda k: None, + ) x_sample_det = self.decode_to_img(index_sample, quant_z.shape) # reconstruction @@ -316,32 +326,33 @@ def configure_optimizers(self): # separate out all parameters to those that will and won't experience regularizing weight decay decay = set() no_decay = set() - whitelist_weight_modules = (torch.nn.Linear, ) + whitelist_weight_modules = (torch.nn.Linear,) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.transformer.named_modules(): for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name - if pn.endswith('bias'): + if pn.endswith("bias"): # all biases will not be decayed no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) # special case the position embedding parameter in the root GPT module as not decayed - no_decay.add('pos_emb') + no_decay.add("pos_emb") # validate that we considered every parameter param_dict = {pn: p for pn, p in self.transformer.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters %s were not separated into either decay/no_decay set!" 
% (str(param_dict.keys() - union_params),) # create the pytorch optimizer object optim_groups = [ diff --git a/src/metr/taming/models/vqgan.py b/src/metr/taming/models/vqgan.py index a6950ba..0afb660 100644 --- a/src/metr/taming/models/vqgan.py +++ b/src/metr/taming/models/vqgan.py @@ -1,42 +1,40 @@ +import pytorch_lightning as pl import torch import torch.nn.functional as F -import pytorch_lightning as pl - from main import instantiate_from_config - -from taming.modules.diffusionmodules.model import Encoder, Decoder +from taming.modules.diffusionmodules.model import Decoder, Encoder +from taming.modules.vqvae.quantize import EMAVectorQuantizer, GumbelQuantize from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer -from taming.modules.vqvae.quantize import GumbelQuantize -from taming.modules.vqvae.quantize import EMAVectorQuantizer + class VQModel(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ): + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): super().__init__() self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, - remap=remap, sane_index_shape=sane_index_shape) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.image_key = image_key if colorize_nlabels is not None: - assert type(colorize_nlabels)==int + assert type(colorize_nlabels) == int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -86,8 +84,9 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 0: # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + aeloss, log_dict_ae = self.loss( + qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train" + ) self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) @@ -95,8 +94,9 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 1: # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train" + ) self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return discloss @@ -104,30 +104,32 @@ def training_step(self, batch, batch_idx, optimizer_idx): def 
validation_step(self, batch, batch_idx): x = self.get_input(batch, self.image_key) xrec, qloss = self(x) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") + aeloss, log_dict_ae = self.loss( + qloss, x, xrec, 0, self.global_step, last_layer=self.get_last_layer(), split="val" + ) - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") + discloss, log_dict_disc = self.loss( + qloss, x, xrec, 1, self.global_step, last_layer=self.get_last_layer(), split="val" + ) rec_loss = log_dict_ae["val/rec_loss"] - self.log("val/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) - self.log("val/aeloss", aeloss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) return [opt_ae, opt_disc], [] def get_last_layer(self): @@ -152,7 +154,7 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -163,12 +165,15 @@ def __init__(self, n_labels, *args, **kwargs): def configure_optimizers(self): lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9), + ) return opt_ae def training_step(self, batch, batch_idx): @@ -184,8 +189,7 @@ def validation_step(self, batch, batch_idx): aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) total_loss = log_dict_ae["val/total_loss"] - self.log("val/total_loss", total_loss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/total_loss", total_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) return aeloss @torch.no_grad() @@ -209,19 +213,27 @@ def log_images(self, batch, **kwargs): class VQNoDiscModel(VQModel): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None - ): - super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, - ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, - colorize_nlabels=colorize_nlabels) + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + ): + super().__init__( + ddconfig=ddconfig, + lossconfig=lossconfig, + n_embed=n_embed, + embed_dim=embed_dim, + ckpt_path=ckpt_path, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + ) def training_step(self, batch, batch_idx): x = self.get_input(batch, self.image_key) @@ -229,8 +241,7 @@ def training_step(self, batch, batch_idx): # autoencode aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") output = pl.TrainResult(minimize=aeloss) - output.log("train/aeloss", aeloss, - prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) return output @@ -240,61 +251,63 @@ def validation_step(self, batch, batch_idx): aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") rec_loss = log_dict_ae["val/rec_loss"] output = pl.EvalResult(checkpoint_on=rec_loss) - output.log("val/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=True, on_epoch=True) - output.log("val/aeloss", aeloss, - prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("val/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) output.log_dict(log_dict_ae) return output def configure_optimizers(self): - optimizer = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - 
list(self.post_quant_conv.parameters()), - lr=self.learning_rate, betas=(0.5, 0.9)) + optimizer = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=self.learning_rate, + betas=(0.5, 0.9), + ) return optimizer class GumbelVQ(VQModel): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - temperature_scheduler_config, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - kl_weight=1e-8, - remap=None, - ): + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + temperature_scheduler_config, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + kl_weight=1e-8, + remap=None, + ): z_channels = ddconfig["z_channels"] - super().__init__(ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=ignore_keys, - image_key=image_key, - colorize_nlabels=colorize_nlabels, - monitor=monitor, - ) + super().__init__( + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) self.loss.n_classes = n_embed self.vocab_size = n_embed - self.quantize = GumbelQuantize(z_channels, embed_dim, - n_embed=n_embed, - kl_weight=kl_weight, temp_init=1.0, - remap=remap) + self.quantize = GumbelQuantize( + z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0, remap=remap + ) - self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp + self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) @@ -317,8 +330,9 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 0: # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + aeloss, log_dict_ae = self.loss( + qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train" + ) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) @@ -326,24 +340,25 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 1: # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train" + ) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return discloss def validation_step(self, batch, batch_idx): x = self.get_input(batch, self.image_key) xrec, qloss = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") + aeloss, log_dict_ae = self.loss( + qloss, x, xrec, 0, self.global_step, last_layer=self.get_last_layer(), split="val" + ) - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") + discloss, log_dict_disc = self.loss( + qloss, x, xrec, 1, 
self.global_step, last_layer=self.get_last_layer(), split="val" + ) rec_loss = log_dict_ae["val/rec_loss"] - self.log("val/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log("val/aeloss", aeloss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict @@ -364,41 +379,43 @@ def log_images(self, batch, **kwargs): class EMAVQ(VQModel): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ): - super().__init__(ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=ignore_keys, - image_key=image_key, - colorize_nlabels=colorize_nlabels, - monitor=monitor, - ) - self.quantize = EMAVectorQuantizer(n_embed=n_embed, - embedding_dim=embed_dim, - beta=0.25, - remap=remap) + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__( + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + self.quantize = EMAVectorQuantizer(n_embed=n_embed, embedding_dim=embed_dim, beta=0.25, remap=remap) + def configure_optimizers(self): lr = self.learning_rate - #Remove self.quantize from parameter list since it is updated via EMA - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] \ No newline at end of file + # Remove self.quantize from parameter list since it is updated via EMA + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] diff --git a/src/metr/taming/modules/diffusionmodules/model.py b/src/metr/taming/modules/diffusionmodules/model.py index d3a5db6..7e5e749 100644 --- a/src/metr/taming/modules/diffusionmodules/model.py +++ b/src/metr/taming/modules/diffusionmodules/model.py @@ -1,8 +1,9 @@ # pytorch_diffusion + derived encoder decoder import math + +import numpy as np import torch import torch.nn as nn -import numpy as np def get_timestep_embedding(timesteps, embedding_dim): @@ -22,13 +23,13 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): 
# swish - return x*torch.sigmoid(x) + return x * torch.sigmoid(x) def Normalize(in_channels): @@ -40,11 +41,7 @@ def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -59,15 +56,11 @@ def __init__(self, in_channels, with_conv): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): if self.with_conv: - pad = (0,1,0,1) + pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: @@ -76,8 +69,7 @@ def forward(self, x): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -85,34 +77,17 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): h = x @@ -121,7 +96,7 @@ def forward(self, x, temb): h = self.conv1(h) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] h = self.norm2(h) h = nonlinearity(h) @@ -134,7 +109,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + return x + h class AttnBlock(nn.Module): @@ -143,27 +118,10 @@ def __init__(self, in_channels): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - 
padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -173,32 +131,43 @@ def forward(self, x): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True + ): super().__init__() self.ch = ch - self.temb_ch = self.ch*4 + self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -208,70 +177,69 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) down = nn.Module() down.block = block down.attn = attn - if 
i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) @@ -281,19 +249,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, t=None): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if self.use_timestep: # timestep embedding @@ -313,7 +276,7 @@ def forward(self, x, t=None): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -324,9 +287,8 @@ def forward(self, x, t=None): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -340,9 +302,22 @@ def forward(self, x, t=None): class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, **ignore_kwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + 
**ignore_kwargs + ): super().__init__() self.ch = ch self.temb_ch = 0 @@ -352,59 +327,51 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) - + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): - #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) # timestep embedding temb = None @@ -417,7 +384,7 @@ def forward(self, x): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -434,9 +401,22 @@ def forward(self, x): class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, **ignorekwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + **ignorekwargs + ): super().__init__() self.ch = ch self.temb_ch = 0 @@ -447,43 +427,37 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, self.give_pre_end = give_pre_end # compute in_ch_mult, block_in and curr_res 
at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) @@ -493,18 +467,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] + # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding @@ -520,7 +490,7 @@ def forward(self, z): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): + for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) @@ -538,13 +508,26 @@ def forward(self, z): class VUNet(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - in_channels, c_channels, - resolution, z_channels, use_timestep=False, **ignore_kwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + c_channels, + resolution, + z_channels, + 
use_timestep=False, + **ignore_kwargs + ): super().__init__() self.ch = ch - self.temb_ch = self.ch*4 + self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -553,75 +536,70 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) # downsampling - self.conv_in = torch.nn.Conv2d(c_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) - self.z_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=1, - stride=1, - padding=0) + self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=2*block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=2 * block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) @@ -631,19 +609,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), 
num_res_blocks, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, z): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if self.use_timestep: # timestep embedding @@ -663,22 +636,21 @@ def forward(self, x, z): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] z = self.z_in(z) - h = torch.cat((h,z),dim=1) + h = torch.cat((h, z), dim=1) h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -694,29 +666,23 @@ def forward(self, x, z): class SimpleDecoder(nn.Module): def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) # end self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): for i, layer in enumerate(self.model): - if i in [1,2,3]: + if i in [1, 2, 3]: x = layer(x, None) else: x = layer(x) @@ -728,8 +694,7 @@ def forward(self, x): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): super().__init__() # upsampling self.temb_ch = 0 @@ -743,10 +708,11 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - 
res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + res_block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -755,11 +721,7 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): # upsampling @@ -773,4 +735,3 @@ def forward(self, x): h = nonlinearity(h) h = self.conv_out(h) return h - diff --git a/src/metr/taming/modules/discriminator/model.py b/src/metr/taming/modules/discriminator/model.py index 2aaa311..0817acd 100644 --- a/src/metr/taming/modules/discriminator/model.py +++ b/src/metr/taming/modules/discriminator/model.py @@ -1,23 +1,23 @@ import functools -import torch.nn as nn - +import torch.nn as nn from taming.modules.util import ActNorm def weights_init(m): classname = m.__class__.__name__ - if classname.find('Conv') != -1: + if classname.find("Conv") != -1: nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find('BatchNorm') != -1: + elif classname.find("BatchNorm") != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) class NLayerDiscriminator(nn.Module): """Defines a PatchGAN discriminator as in Pix2Pix - --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): """Construct a PatchGAN discriminator Parameters: @@ -43,23 +43,24 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult - nf_mult = min(2 ** n, 8) + nf_mult = min(2**n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) + nn.LeakyReLU(0.2, True), ] nf_mult_prev = nf_mult - nf_mult = min(2 ** n_layers, 8) + nf_mult = min(2**n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) + nn.LeakyReLU(0.2, True), ] sequence += [ - nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map self.main = nn.Sequential(*sequence) def forward(self, input): diff --git a/src/metr/taming/modules/losses/__init__.py b/src/metr/taming/modules/losses/__init__.py index d09caf9..4008db3 100644 --- a/src/metr/taming/modules/losses/__init__.py +++ b/src/metr/taming/modules/losses/__init__.py @@ -1,2 +1 @@ from taming.modules.losses.vqperceptual import DummyLoss - diff --git a/src/metr/taming/modules/losses/lpips.py b/src/metr/taming/modules/losses/lpips.py index a728044..0c7025d 100644 --- a/src/metr/taming/modules/losses/lpips.py +++ b/src/metr/taming/modules/losses/lpips.py @@ -1,11 +1,11 @@ """Stripped version of 
https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" -import torch -import torch.nn as nn -from torchvision import models from collections import namedtuple +import torch +import torch.nn as nn from taming.util import get_ckpt_path +from torchvision import models class LPIPS(nn.Module): @@ -57,19 +57,28 @@ def forward(self, input, target): class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() - self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) - self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): - """ A single linear layer which does a 1x1 conv """ + """A single linear layer which does a 1x1 conv""" + def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() - layers = [nn.Dropout(), ] if (use_dropout) else [] - layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] self.model = nn.Sequential(*layers) @@ -108,16 +117,15 @@ def forward(self, X): h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h - vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out -def normalize_tensor(x,eps=1e-10): - norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) - return x/(norm_factor+eps) +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) def spatial_average(x, keepdim=True): - return x.mean([2,3],keepdim=keepdim) - + return x.mean([2, 3], keepdim=keepdim) diff --git a/src/metr/taming/modules/losses/segmentation.py b/src/metr/taming/modules/losses/segmentation.py index 4ba77de..45bc97d 100644 --- a/src/metr/taming/modules/losses/segmentation.py +++ b/src/metr/taming/modules/losses/segmentation.py @@ -4,19 +4,20 @@ class BCELoss(nn.Module): def forward(self, prediction, target): - loss = F.binary_cross_entropy_with_logits(prediction,target) + loss = F.binary_cross_entropy_with_logits(prediction, target) return loss, {} class BCELossWithQuant(nn.Module): - def __init__(self, codebook_weight=1.): + def __init__(self, codebook_weight=1.0): super().__init__() self.codebook_weight = codebook_weight def forward(self, qloss, target, prediction, split): - bce_loss = F.binary_cross_entropy_with_logits(prediction,target) - loss = bce_loss + self.codebook_weight*qloss - return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/bce_loss".format(split): bce_loss.detach().mean(), - "{}/quant_loss".format(split): qloss.detach().mean() - } + bce_loss = F.binary_cross_entropy_with_logits(prediction, target) + loss = bce_loss + self.codebook_weight * qloss + return loss, { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/bce_loss".format(split): bce_loss.detach().mean(), + "{}/quant_loss".format(split): qloss.detach().mean(), + } diff --git 
a/src/metr/taming/modules/losses/vqperceptual.py b/src/metr/taming/modules/losses/vqperceptual.py index c2febd4..3ba7be6 100644 --- a/src/metr/taming/modules/losses/vqperceptual.py +++ b/src/metr/taming/modules/losses/vqperceptual.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F - -from taming.modules.losses.lpips import LPIPS from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS class DummyLoss(nn.Module): @@ -11,31 +10,42 @@ def __init__(self): super().__init__() -def adopt_weight(weight, global_step, threshold=0, value=0.): +def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value return weight def hinge_d_loss(logits_real, logits_fake): - loss_real = torch.mean(F.relu(1. - logits_real)) - loss_fake = torch.mean(F.relu(1. + logits_fake)) + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) d_loss = 0.5 * (loss_real + loss_fake) return d_loss def vanilla_d_loss(logits_real, logits_fake): d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) + - torch.mean(torch.nn.functional.softplus(logits_fake))) + torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) return d_loss class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge"): + def __init__( + self, + disc_start, + codebook_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_ndf=64, + disc_loss="hinge", + ): super().__init__() assert disc_loss in ["hinge", "vanilla"] self.codebook_weight = codebook_weight @@ -43,11 +53,9 @@ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, ndf=disc_ndf + ).apply(weights_init) self.discriminator_iter_start = disc_start if disc_loss == "hinge": self.disc_loss = hinge_d_loss @@ -73,8 +81,17 @@ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): d_weight = d_weight * self.discriminator_weight return d_weight - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train"): + def forward( + self, + codebook_loss, + inputs, + reconstructions, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split="train", + ): rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) if self.perceptual_weight > 0: p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) @@ -83,7 +100,7 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, p_loss = torch.tensor([0.0]) nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = 
torch.mean(nll_loss) # now the GAN part @@ -106,15 +123,16 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } return loss, log if optimizer_idx == 1: @@ -129,8 +147,9 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } return d_loss, log diff --git a/src/metr/taming/modules/misc/coord.py b/src/metr/taming/modules/misc/coord.py index ee69b0c..aaebf5f 100644 --- a/src/metr/taming/modules/misc/coord.py +++ b/src/metr/taming/modules/misc/coord.py @@ -1,5 +1,6 @@ import torch + class CoordStage(object): def __init__(self, n_embed, down_factor): self.n_embed = n_embed @@ -11,13 +12,12 @@ def eval(self): def encode(self, c): """fake vqmodel interface""" assert 0.0 <= c.min() and c.max() <= 1.0 - b,ch,h,w = c.shape + b, ch, h, w = c.shape assert ch == 1 - c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, - mode="area") + c = torch.nn.functional.interpolate(c, scale_factor=1 / self.down_factor, mode="area") c = c.clamp(0.0, 1.0) - c = self.n_embed*c + c = self.n_embed * c c_quant = c.round() c_ind = c_quant.to(dtype=torch.long) @@ -25,7 +25,6 @@ def encode(self, c): return c_quant, None, info def decode(self, c): - c = c/self.n_embed - c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, - mode="nearest") + c = c / self.n_embed + c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, mode="nearest") return c diff --git a/src/metr/taming/modules/transformer/mingpt.py b/src/metr/taming/modules/transformer/mingpt.py index d14b7b6..50f79ec 100644 --- a/src/metr/taming/modules/transformer/mingpt.py +++ b/src/metr/taming/modules/transformer/mingpt.py @@ -8,8 +8,8 @@ - the final decoder is a linear projection into a vanilla Softmax classifier """ -import math import logging +import math import torch import torch.nn as nn @@ -20,7 
+20,8 @@ class GPTConfig: - """ base GPT config, params common to all GPT versions """ + """base GPT config, params common to all GPT versions""" + embd_pdrop = 0.1 resid_pdrop = 0.1 attn_pdrop = 0.1 @@ -28,12 +29,13 @@ class GPTConfig: def __init__(self, vocab_size, block_size, **kwargs): self.vocab_size = vocab_size self.block_size = block_size - for k,v in kwargs.items(): + for k, v in kwargs.items(): setattr(self, k, v) class GPT1Config(GPTConfig): - """ GPT-1 like network roughly 125M params """ + """GPT-1 like network roughly 125M params""" + n_layer = 12 n_head = 12 n_embd = 768 @@ -59,10 +61,9 @@ def __init__(self, config): # output projection self.proj = nn.Linear(config.n_embd, config.n_embd) # causal mask to ensure that attention is only applied to the left in the input sequence - mask = torch.tril(torch.ones(config.block_size, - config.block_size)) + mask = torch.tril(torch.ones(config.block_size, config.block_size)) if hasattr(config, "n_unmasked"): - mask[:config.n_unmasked, :config.n_unmasked] = 1 + mask[: config.n_unmasked, : config.n_unmasked] = 1 self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) self.n_head = config.n_head @@ -70,9 +71,9 @@ def forward(self, x, layer_past=None): B, T, C = x.size() # calculate query, key, values for all heads in batch and move head forward to be the batch dim - k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) present = torch.stack((k, v)) if layer_past is not None: @@ -83,20 +84,21 @@ def forward(self, x, layer_past=None): # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) if layer_past is None: - att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_drop(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.resid_drop(self.proj(y)) - return y, present # TODO: check that this does not break anything + return y, present # TODO: check that this does not break anything class Block(nn.Module): - """ an unassuming Transformer block """ + """an unassuming Transformer block""" + def __init__(self, config): super().__init__() self.ln1 = nn.LayerNorm(config.n_embd) @@ -111,7 +113,8 @@ def __init__(self, config): def forward(self, x, layer_past=None, return_present=False): # TODO: check that training still works - if return_present: assert not self.training + if return_present: + assert not self.training # layer past: tuple of length two with B, nh, T, hs attn, present = self.attn(self.ln1(x), layer_past=layer_past) @@ -123,14 +126,32 @@ def forward(self, x, layer_past=None, return_present=False): 
class GPT(nn.Module): - """ the full GPT language model, with a context size of block_size """ - def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, - embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + """the full GPT language model, with a context size of block_size""" + + def __init__( + self, + vocab_size, + block_size, + n_layer=12, + n_head=8, + n_embd=256, + embd_pdrop=0.0, + resid_pdrop=0.0, + attn_pdrop=0.0, + n_unmasked=0, + ): super().__init__() - config = GPTConfig(vocab_size=vocab_size, block_size=block_size, - embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, - n_layer=n_layer, n_head=n_head, n_embd=n_embd, - n_unmasked=n_unmasked) + config = GPTConfig( + vocab_size=vocab_size, + block_size=block_size, + embd_pdrop=embd_pdrop, + resid_pdrop=resid_pdrop, + attn_pdrop=attn_pdrop, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + n_unmasked=n_unmasked, + ) # input embedding stem self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) @@ -159,14 +180,14 @@ def _init_weights(self, module): def forward(self, idx, embeddings=None, targets=None): # forward the GPT model - token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector - if embeddings is not None: # prepend explicit embeddings + if embeddings is not None: # prepend explicit embeddings token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) t = token_embeddings.shape[1] assert t <= self.block_size, "Cannot forward, model block size is exhausted." - position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector x = self.drop(token_embeddings + position_embeddings) x = self.blocks(x) x = self.ln_f(x) @@ -182,19 +203,26 @@ def forward(self, idx, embeddings=None, targets=None): def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None): # inference only assert not self.training - token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector - if embeddings is not None: # prepend explicit embeddings + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + if embeddings is not None: # prepend explicit embeddings token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) if past is not None: assert past_length is not None - past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head + past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head past_shape = list(past.shape) - expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] + expected_shape = [ + self.config.n_layer, + 2, + idx.shape[0], + self.config.n_head, + past_length, + self.config.n_embd // self.config.n_head, + ] assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector else: - position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] + position_embeddings = self.pos_emb[:, : token_embeddings.shape[1], :] x = self.drop(token_embeddings + position_embeddings) presents = [] # accumulate over layers @@ -224,13 +252,32 @@ def forward(self, idx): class CodeGPT(nn.Module): """Takes in 
semi-embeddings""" - def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, - embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + + def __init__( + self, + vocab_size, + block_size, + in_channels, + n_layer=12, + n_head=8, + n_embd=256, + embd_pdrop=0.0, + resid_pdrop=0.0, + attn_pdrop=0.0, + n_unmasked=0, + ): super().__init__() - config = GPTConfig(vocab_size=vocab_size, block_size=block_size, - embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, - n_layer=n_layer, n_head=n_head, n_embd=n_embd, - n_unmasked=n_unmasked) + config = GPTConfig( + vocab_size=vocab_size, + block_size=block_size, + embd_pdrop=embd_pdrop, + resid_pdrop=resid_pdrop, + attn_pdrop=attn_pdrop, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + n_unmasked=n_unmasked, + ) # input embedding stem self.tok_emb = nn.Linear(in_channels, config.n_embd) self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) @@ -259,14 +306,14 @@ def _init_weights(self, module): def forward(self, idx, embeddings=None, targets=None): # forward the GPT model - token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector - if embeddings is not None: # prepend explicit embeddings + if embeddings is not None: # prepend explicit embeddings token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) t = token_embeddings.shape[1] assert t <= self.block_size, "Cannot forward, model block size is exhausted." - position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector x = self.drop(token_embeddings + position_embeddings) x = self.blocks(x) x = self.taming_cinln_f(x) @@ -280,15 +327,16 @@ def forward(self, idx, embeddings=None, targets=None): return logits, loss - #### sampling utils + def top_k_logits(logits, k): v, ix = torch.topk(logits, k) out = logits.clone() - out[out < v[:, [-1]]] = -float('Inf') + out[out < v[:, [-1]]] = -float("Inf") return out + @torch.no_grad() def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): """ @@ -321,8 +369,7 @@ def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): @torch.no_grad() -def sample_with_past(x, model, steps, temperature=1., sample_logits=True, - top_k=None, top_p=None, callback=None): +def sample_with_past(x, model, steps, temperature=1.0, sample_logits=True, top_k=None, top_p=None, callback=None): # x is conditioning sample = x cond_len = x.shape[1] @@ -330,7 +377,7 @@ def sample_with_past(x, model, steps, temperature=1., sample_logits=True, for n in range(steps): if callback is not None: callback(n) - logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1)) + logits, _, present = model.forward_with_past(x, past=past, past_length=(n + cond_len - 1)) if past is None: past = [present] else: @@ -353,15 +400,16 @@ def sample_with_past(x, model, steps, temperature=1., sample_logits=True, #### clustering utils + class KMeans(nn.Module): def __init__(self, ncluster=512, nc=3, niter=10): super().__init__() self.ncluster = ncluster self.nc = nc self.niter = niter - self.shape = (3,32,32) - self.register_buffer("C", torch.zeros(self.ncluster,nc)) - self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + self.shape = (3, 32, 32) + self.register_buffer("C", torch.zeros(self.ncluster, nc)) + 
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) def is_initialized(self): return self.initialized.item() == 1 @@ -370,31 +418,30 @@ def is_initialized(self): def initialize(self, x): N, D = x.shape assert D == self.nc, D - c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random + c = x[torch.randperm(N)[: self.ncluster]] # init clusters at random for i in range(self.niter): # assign all pixels to the closest codebook element - a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1) + a = ((x[:, None, :] - c[None, :, :]) ** 2).sum(-1).argmin(1) # move each codebook element to be the mean of the pixels that assigned to it - c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)]) + c = torch.stack([x[a == k].mean(0) for k in range(self.ncluster)]) # re-assign any poorly positioned codebook elements nanix = torch.any(torch.isnan(c), dim=1) ndead = nanix.sum().item() - print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead)) - c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters + print("done step %d/%d, re-initialized %d dead clusters" % (i + 1, self.niter, ndead)) + c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters self.C.copy_(c) self.initialized.fill_(1) - def forward(self, x, reverse=False, shape=None): if not reverse: # flatten - bs,c,h,w = x.shape + bs, c, h, w = x.shape assert c == self.nc - x = x.reshape(bs,c,h*w,1) - C = self.C.permute(1,0) - C = C.reshape(1,c,1,self.ncluster) - a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices + x = x.reshape(bs, c, h * w, 1) + C = self.C.permute(1, 0) + C = C.reshape(1, c, 1, self.ncluster) + a = ((x - C) ** 2).sum(1).argmin(-1) # bs, h*w indices return a else: # flatten @@ -408,7 +455,7 @@ def forward(self, x, reverse=False, shape=None): x = torch.gather(c, dim=3, index=x) """ x = self.C[x] - x = x.permute(0,2,1) + x = x.permute(0, 2, 1) shape = shape if shape is not None else self.shape x = x.reshape(bs, *shape) diff --git a/src/metr/taming/modules/transformer/permuter.py b/src/metr/taming/modules/transformer/permuter.py index 0d43bb1..224c807 100644 --- a/src/metr/taming/modules/transformer/permuter.py +++ b/src/metr/taming/modules/transformer/permuter.py @@ -1,11 +1,12 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np class AbstractPermuter(nn.Module): def __init__(self, *args, **kwargs): super().__init__() + def forward(self, x, reverse=False): raise NotImplementedError @@ -22,20 +23,18 @@ class Subsample(AbstractPermuter): def __init__(self, H, W): super().__init__() C = 1 - indices = np.arange(H*W).reshape(C,H,W) + indices = np.arange(H * W).reshape(C, H, W) while min(H, W) > 1: - indices = indices.reshape(C,H//2,2,W//2,2) - indices = indices.transpose(0,2,4,1,3) - indices = indices.reshape(C*4,H//2, W//2) - H = H//2 - W = W//2 - C = C*4 + indices = indices.reshape(C, H // 2, 2, W // 2, 2) + indices = indices.transpose(0, 2, 4, 1, 3) + indices = indices.reshape(C * 4, H // 2, W // 2) + H = H // 2 + W = W // 2 + C = C * 4 assert H == W == 1 idx = torch.tensor(indices.ravel()) - self.register_buffer('forward_shuffle_idx', - nn.Parameter(idx, requires_grad=False)) - self.register_buffer('backward_shuffle_idx', - nn.Parameter(torch.argsort(idx), requires_grad=False)) + self.register_buffer("forward_shuffle_idx", nn.Parameter(idx, requires_grad=False)) + self.register_buffer("backward_shuffle_idx", nn.Parameter(torch.argsort(idx), requires_grad=False)) def forward(self, x, reverse=False): if not reverse: @@ -52,24 +51,23 
@@ def mortonify(i, j): z = np.uint(0) for pos in range(32): - z = (z | - ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | - ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) - ) + z = ( + z + | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) + | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos + 1)) + ) return z class ZCurve(AbstractPermuter): def __init__(self, H, W): super().__init__() - reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] + reverseidx = [np.int64(mortonify(i, j)) for i in range(H) for j in range(W)] idx = np.argsort(reverseidx) idx = torch.tensor(idx) reverseidx = torch.tensor(reverseidx) - self.register_buffer('forward_shuffle_idx', - idx) - self.register_buffer('backward_shuffle_idx', - reverseidx) + self.register_buffer("forward_shuffle_idx", idx) + self.register_buffer("backward_shuffle_idx", reverseidx) def forward(self, x, reverse=False): if not reverse: @@ -83,17 +81,17 @@ def __init__(self, H, W): super().__init__() assert H == W size = W - indices = np.arange(size*size).reshape(size,size) + indices = np.arange(size * size).reshape(size, size) - i0 = size//2 - j0 = size//2-1 + i0 = size // 2 + j0 = size // 2 - 1 i = i0 j = j0 idx = [indices[i0, j0]] step_mult = 0 - for c in range(1, size//2+1): + for c in range(1, size // 2 + 1): step_mult += 1 # steps left for k in range(step_mult): @@ -108,7 +106,7 @@ def __init__(self, H, W): idx.append(indices[i, j]) step_mult += 1 - if c < size//2: + if c < size // 2: # step right for k in range(step_mult): i = i + 1 @@ -122,14 +120,14 @@ def __init__(self, H, W): idx.append(indices[i, j]) else: # end reached - for k in range(step_mult-1): + for k in range(step_mult - 1): i = i + 1 idx.append(indices[i, j]) - assert len(idx) == size*size + assert len(idx) == size * size idx = torch.tensor(idx) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + self.register_buffer("forward_shuffle_idx", idx) + self.register_buffer("backward_shuffle_idx", torch.argsort(idx)) def forward(self, x, reverse=False): if not reverse: @@ -143,17 +141,17 @@ def __init__(self, H, W): super().__init__() assert H == W size = W - indices = np.arange(size*size).reshape(size,size) + indices = np.arange(size * size).reshape(size, size) - i0 = size//2 - j0 = size//2-1 + i0 = size // 2 + j0 = size // 2 - 1 i = i0 j = j0 idx = [indices[i0, j0]] step_mult = 0 - for c in range(1, size//2+1): + for c in range(1, size // 2 + 1): step_mult += 1 # steps left for k in range(step_mult): @@ -168,7 +166,7 @@ def __init__(self, H, W): idx.append(indices[i, j]) step_mult += 1 - if c < size//2: + if c < size // 2: # step right for k in range(step_mult): i = i + 1 @@ -182,15 +180,15 @@ def __init__(self, H, W): idx.append(indices[i, j]) else: # end reached - for k in range(step_mult-1): + for k in range(step_mult - 1): i = i + 1 idx.append(indices[i, j]) - assert len(idx) == size*size + assert len(idx) == size * size idx = idx[::-1] idx = torch.tensor(idx) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + self.register_buffer("forward_shuffle_idx", idx) + self.register_buffer("backward_shuffle_idx", torch.argsort(idx)) def forward(self, x, reverse=False): if not reverse: @@ -202,10 +200,10 @@ def forward(self, x, reverse=False): class Random(nn.Module): def __init__(self, H, W): super().__init__() - indices = np.random.RandomState(1).permutation(H*W) + indices = 
np.random.RandomState(1).permutation(H * W) idx = torch.tensor(indices.ravel()) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + self.register_buffer("forward_shuffle_idx", idx) + self.register_buffer("backward_shuffle_idx", torch.argsort(idx)) def forward(self, x, reverse=False): if not reverse: @@ -217,14 +215,14 @@ def forward(self, x, reverse=False): class AlternateParsing(AbstractPermuter): def __init__(self, H, W): super().__init__() - indices = np.arange(W*H).reshape(H,W) + indices = np.arange(W * H).reshape(H, W) for i in range(1, H, 2): indices[i, :] = indices[i, ::-1] idx = indices.flatten() - assert len(idx) == H*W + assert len(idx) == H * W idx = torch.tensor(idx) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + self.register_buffer("forward_shuffle_idx", idx) + self.register_buffer("backward_shuffle_idx", torch.argsort(idx)) def forward(self, x, reverse=False): if not reverse: diff --git a/src/metr/taming/modules/util.py b/src/metr/taming/modules/util.py index 9ee1638..66fc332 100644 --- a/src/metr/taming/modules/util.py +++ b/src/metr/taming/modules/util.py @@ -8,8 +8,7 @@ def count_params(model): class ActNorm(nn.Module): - def __init__(self, num_features, logdet=False, affine=True, - allow_reverse_init=False): + def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): assert affine super().__init__() self.logdet = logdet @@ -17,25 +16,13 @@ def __init__(self, num_features, logdet=False, affine=True, self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) self.allow_reverse_init = allow_reverse_init - self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) def initialize(self, input): with torch.no_grad(): flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) - mean = ( - flatten.mean(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - std = ( - flatten.std(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) + mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) + std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) self.loc.data.copy_(-mean) self.scale.data.copy_(1 / (std + 1e-6)) @@ -44,7 +31,7 @@ def forward(self, input, reverse=False): if reverse: return self.reverse(input) if len(input.shape) == 2: - input = input[:,:,None,None] + input = input[:, :, None, None] squeeze = True else: squeeze = False @@ -62,7 +49,7 @@ def forward(self, input, reverse=False): if self.logdet: log_abs = torch.log(torch.abs(self.scale)) - logdet = height*width*torch.sum(log_abs) + logdet = height * width * torch.sum(log_abs) logdet = logdet * torch.ones(input.shape[0]).to(input) return h, logdet @@ -80,7 +67,7 @@ def reverse(self, output): self.initialized.fill_(1) if len(output.shape) == 2: - output = output[:,:,None,None] + output = output[:, :, None, None] squeeze = True else: squeeze = False @@ -102,13 +89,14 @@ def encode(self, *args, **kwargs): class Labelator(AbstractEncoder): """Net2Net Interface for Class-Conditional Model""" + def __init__(self, n_classes, quantize_interface=True): super().__init__() self.n_classes = n_classes self.quantize_interface = quantize_interface def encode(self, c): - c = c[:,None] + c = c[:, None] if self.quantize_interface: return c, None, [None, None, c.long()] 
return c @@ -123,7 +111,7 @@ def __init__(self, sos_token, quantize_interface=True): def encode(self, x): # get batch size from data and replicate sos_token - c = torch.ones(x.shape[0], 1)*self.sos_token + c = torch.ones(x.shape[0], 1) * self.sos_token c = c.long().to(x.device) if self.quantize_interface: return c, None, [None, None, c] diff --git a/src/metr/taming/modules/vqvae/quantize.py b/src/metr/taming/modules/vqvae/quantize.py index d75544e..16e2663 100644 --- a/src/metr/taming/modules/vqvae/quantize.py +++ b/src/metr/taming/modules/vqvae/quantize.py @@ -1,9 +1,9 @@ +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -from torch import einsum from einops import rearrange +from torch import einsum class VectorQuantizer(nn.Module): @@ -46,17 +46,18 @@ def forward(self, z): z_flattened = z.view(-1, self.e_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ - torch.sum(self.embedding.weight**2, dim=1) - 2 * \ - torch.matmul(z_flattened, self.embedding.weight.t()) + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + ) ## could possible replace this here # #\start... # find closest encodings min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - min_encodings = torch.zeros( - min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z) min_encodings.scatter_(1, min_encoding_indices, 1) # dtype min encodings: torch.float32 @@ -65,17 +66,16 @@ def forward(self, z): # get quantized latent vectors z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) - #.........\end + # .........\end # with: # .........\start - #min_encoding_indices = torch.argmin(d, dim=1) - #z_q = self.embedding(min_encoding_indices) + # min_encoding_indices = torch.argmin(d, dim=1) + # z_q = self.embedding(min_encoding_indices) # ......\end......... (TODO) # compute loss for embedding - loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ - torch.mean((z_q - z.detach()) ** 2) + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) # preserve gradients z_q = z + (z_q - z).detach() @@ -93,7 +93,7 @@ def get_codebook_entry(self, indices, shape): # shape specifying (batch, height, width, channel) # TODO: check for more easy handling with nn.Embedding min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) - min_encodings.scatter_(1, indices[:,None], 1) + min_encodings.scatter_(1, indices[:, None], 1) # get quantized latent vectors z_q = torch.matmul(min_encodings.float(), self.embedding.weight) @@ -114,9 +114,19 @@ class GumbelQuantize(nn.Module): Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 https://arxiv.org/abs/1611.01144 """ - def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, - kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, - remap=None, unknown_index="random"): + + def __init__( + self, + num_hiddens, + embedding_dim, + n_embed, + straight_through=True, + kl_weight=5e-4, + temp_init=1.0, + use_vqinterface=True, + remap=None, + unknown_index="random", + ): super().__init__() self.embedding_dim = embedding_dim @@ -135,37 +145,39 @@ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer + self.unknown_index = unknown_index # "random" or "extra" or integer if self.unknown_index == "extra": self.unknown_index = self.re_embed - self.re_embed = self.re_embed+1 - print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices.") + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) else: self.re_embed = n_embed def remap_to_used(self, inds): ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) - match = (inds[:,:,None]==used[None,None,...]).long() + match = (inds[:, :, None] == used[None, None, ...]).long() new = match.argmax(-1) - unknown = match.sum(2)<1 + unknown = match.sum(2) < 1 if self.unknown_index == "random": - new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) else: new[unknown] = self.unknown_index return new.reshape(ishape) def unmap_to_all(self, inds): ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds>=self.used.shape[0]] = 0 # simply set to zero - back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) return back.reshape(ishape) def forward(self, z, temp=None, return_logits=False): @@ -177,14 +189,14 @@ def forward(self, z, temp=None, return_logits=False): if self.remap is not None: # continue only with used logits full_zeros = torch.zeros_like(logits) - logits = logits[:,self.used,...] + logits = logits[:, self.used, ...] soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) if self.remap is not None: # go back to all entries but unused set to zero - full_zeros[:,self.used,...] = soft_one_hot + full_zeros[:, self.used, ...] 
= soft_one_hot soft_one_hot = full_zeros - z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) # + kl divergence to the prior loss qy = F.softmax(logits, dim=1) @@ -201,12 +213,12 @@ def forward(self, z, temp=None, return_logits=False): def get_codebook_entry(self, indices, shape): b, h, w, c = shape - assert b*h*w == indices.shape[0] - indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + assert b * h * w == indices.shape[0] + indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) if self.remap is not None: indices = self.unmap_to_all(indices) one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() - z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) return z_q @@ -215,11 +227,11 @@ class VectorQuantizer2(nn.Module): Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix multiplications and allows for post-hoc remapping of indices. """ + # NOTE: due to a bug the beta term was applied to the wrong term. for # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", - sane_index_shape=False, legacy=True): + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): super().__init__() self.n_e = n_e self.e_dim = e_dim @@ -233,12 +245,14 @@ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer + self.unknown_index = unknown_index # "random" or "extra" or integer if self.unknown_index == "extra": self.unknown_index = self.re_embed - self.re_embed = self.re_embed+1 - print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices.") + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) else: self.re_embed = n_e @@ -246,40 +260,42 @@ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", def remap_to_used(self, inds): ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) - match = (inds[:,:,None]==used[None,None,...]).long() + match = (inds[:, :, None] == used[None, None, ...]).long() new = match.argmax(-1) - unknown = match.sum(2)<1 + unknown = match.sum(2) < 1 if self.unknown_index == "random": - new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) else: new[unknown] = self.unknown_index return new.reshape(ishape) def unmap_to_all(self, inds): ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds>=self.used.shape[0]] = 0 # simply set to zero - back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) return back.reshape(ishape) def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" - assert rescale_logits==False, "Only for interface compatible with Gumbel" - assert return_logits==False, "Only for interface compatible with Gumbel" + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits == False, "Only for interface compatible with Gumbel" + assert return_logits == False, "Only for interface compatible with Gumbel" # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, 'b c h w -> b h w c').contiguous() + z = rearrange(z, "b c h w -> b h w c").contiguous() z_flattened = z.view(-1, self.e_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ - torch.sum(self.embedding.weight**2, dim=1) - 2 * \ - torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) + ) min_encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) @@ -288,35 +304,32 @@ def forward(self, z, temp=None, rescale_logits=False, return_logits=False): # compute loss for embedding if not self.legacy: - loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ - torch.mean((z_q - z.detach()) ** 2) + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) else: - loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ - torch.mean((z_q - z.detach()) ** 2) + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape - z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() if self.remap is not None: - min_encoding_indices = 
min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape( - z_q.shape[0], z_q.shape[2], z_q.shape[3]) + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) return z_q, loss, (perplexity, min_encodings, min_encoding_indices) def get_codebook_entry(self, indices, shape): # shape specifying (batch, height, width, channel) if self.remap is not None: - indices = indices.reshape(shape[0],-1) # add batch axis + indices = indices.reshape(shape[0], -1) # add batch axis indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again + indices = indices.reshape(-1) # flatten again # get quantized latent vectors z_q = self.embedding(indices) @@ -328,15 +341,16 @@ def get_codebook_entry(self, indices, shape): return z_q + class EmbeddingEMA(nn.Module): def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): super().__init__() self.decay = decay - self.eps = eps + self.eps = eps weight = torch.randn(num_tokens, codebook_dim) - self.weight = nn.Parameter(weight, requires_grad = False) - self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) - self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) self.update = True def forward(self, embed_id): @@ -345,22 +359,19 @@ def forward(self, embed_id): def cluster_size_ema_update(self, new_cluster_size): self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) - def embed_avg_ema_update(self, new_embed_avg): + def embed_avg_ema_update(self, new_embed_avg): self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) def weight_update(self, num_tokens): n = self.cluster_size.sum() - smoothed_cluster_size = ( - (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n - ) - #normalize embedding average with smoothed cluster size + smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + # normalize embedding average with smoothed cluster size embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) - self.weight.data.copy_(embed_normalized) + self.weight.data.copy_(embed_normalized) class EMAVectorQuantizer(nn.Module): - def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, - remap=None, unknown_index="random"): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, remap=None, unknown_index="random"): super().__init__() self.codebook_dim = codebook_dim self.num_tokens = num_tokens @@ -371,75 +382,78 @@ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer + self.unknown_index = unknown_index # "random" or "extra" or integer if self.unknown_index == "extra": self.unknown_index = self.re_embed - self.re_embed = self.re_embed+1 - 
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices.") + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) else: self.re_embed = n_embed def remap_to_used(self, inds): ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) - match = (inds[:,:,None]==used[None,None,...]).long() + match = (inds[:, :, None] == used[None, None, ...]).long() new = match.argmax(-1) - unknown = match.sum(2)<1 + unknown = match.sum(2) < 1 if self.unknown_index == "random": - new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) else: new[unknown] = self.unknown_index return new.reshape(ishape) def unmap_to_all(self, inds): ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds>=self.used.shape[0]] = 0 # simply set to zero - back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) return back.reshape(ishape) def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten - #z, 'b c h w -> b h w c' - z = rearrange(z, 'b c h w -> b h w c') + # z, 'b c h w -> b h w c' + z = rearrange(z, "b c h w -> b h w c") z_flattened = z.reshape(-1, self.codebook_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ - self.embedding.weight.pow(2).sum(dim=1) - 2 * \ - torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + z_flattened.pow(2).sum(dim=1, keepdim=True) + + self.embedding.weight.pow(2).sum(dim=1) + - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) + ) # 'n d -> d n' encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(encoding_indices).view(z.shape) - encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) avg_probs = torch.mean(encodings, dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) if self.training and self.embedding.update: - #EMA cluster size - encodings_sum = encodings.sum(0) + # EMA cluster size + encodings_sum = encodings.sum(0) self.embedding.cluster_size_ema_update(encodings_sum) - #EMA embedding average - embed_sum = encodings.transpose(0,1) @ z_flattened + # EMA embedding average + embed_sum = encodings.transpose(0, 1) @ z_flattened self.embedding.embed_avg_ema_update(embed_sum) - #normalize embed_avg and update weight + # normalize embed_avg and update weight self.embedding.weight_update(self.num_tokens) # compute loss for embedding - loss = self.beta * F.mse_loss(z_q.detach(), z) + loss = self.beta * F.mse_loss(z_q.detach(), z) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape - #z_q, 'b h w c -> b c h w' - z_q = rearrange(z_q, 'b h 
w c -> b c h w') + # z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, "b h w c -> b c h w") return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/src/metr/taming/util.py b/src/metr/taming/util.py index 06053e5..d02479c 100644 --- a/src/metr/taming/util.py +++ b/src/metr/taming/util.py @@ -1,18 +1,14 @@ -import os, hashlib +import hashlib +import os + import requests from tqdm import tqdm -URL_MAP = { - "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" -} +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} -CKPT_MAP = { - "vgg_lpips": "vgg.pth" -} +CKPT_MAP = {"vgg_lpips": "vgg.pth"} -MD5_MAP = { - "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" -} +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} def download(url, local_path, chunk_size=1024): @@ -59,9 +55,7 @@ def __init__(self, cause, keys=None, visited=None): super().__init__(message) -def retrieve( - list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False -): +def retrieve(list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False): """Given a nested list or dict return the desired value at key expanding callable nodes if necessary and :attr:`expand` is ``True``. The expansion is done in-place. @@ -104,9 +98,7 @@ def retrieve( if callable(list_or_dict): if not expand: raise KeyNotFoundError( - ValueError( - "Trying to get past callable node with expand=False." - ), + ValueError("Trying to get past callable node with expand=False."), keys=keys, visited=visited, ) @@ -143,15 +135,16 @@ def retrieve( if __name__ == "__main__": - config = {"keya": "a", - "keyb": "b", - "keyc": - {"cc1": 1, - "cc2": 2, - } - } + config = { + "keya": "a", + "keyb": "b", + "keyc": { + "cc1": 1, + "cc2": 2, + }, + } from omegaconf import OmegaConf + config = OmegaConf.create(config) print(config) retrieve(config, "keya") - diff --git a/src/metr/utils.py b/src/metr/utils.py index 5ad552d..e7f6a1c 100644 --- a/src/metr/utils.py +++ b/src/metr/utils.py @@ -1,9 +1,11 @@ import os + from huggingface_hub import hf_api -HOME_PATH = os.path.expanduser('~') -CACHE_PATH = os.path.join(HOME_PATH, './cache/trk/') -FILE_PATH = 'current_org' +HOME_PATH = os.path.expanduser("~") +CACHE_PATH = os.path.join(HOME_PATH, "./cache/trk/") +FILE_PATH = "current_org" + def set_org(org: str): file_path = os.path.join(CACHE_PATH, FILE_PATH) @@ -13,7 +15,7 @@ def set_org(org: str): os.makedirs(CACHE_PATH) # Create a new text file in the directory - with open(file_path, 'w') as file: + with open(file_path, "w") as file: file.write(org) # write an empty string to the file @@ -22,7 +24,7 @@ def get_org() -> str: if not os.path.isfile(file_path): raise ValueError(f"{file_path} does not exist. Make sure to run `trk.setup_repo(...)` first.") - with open(file_path, 'r') as file: + with open(file_path, "r") as file: current_org = file.read() return current_org