Skip to content

Commit

Permalink
refactor: applies linting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jaydrennan authored and brycedrennan committed Nov 22, 2023
1 parent 1ed199a commit 1ec841a
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 55 deletions.
1 change: 1 addition & 0 deletions imaginairy/modules/sgm/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
if TYPE_CHECKING:
from .autoencoding.regularizers import AbstractRegularizer


# from .ema import LitEma
# from .util import (default, get_nested_attribute, get_obj_from_str,
# instantiate_from_config)
Expand Down
11 changes: 7 additions & 4 deletions imaginairy/modules/sgm/autoencoding/lpips/loss/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ def from_pretrained(cls, name="vgg_lpips"):
)
return model

def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
def forward(self, input_tensor, target):
in0_input, in1_input = (
self.scaling_layer(input_tensor),
self.scaling_layer(target),
)
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
Expand All @@ -59,8 +62,8 @@ def forward(self, input, target):
for kk in range(len(self.chns))
]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
for i in range(1, len(self.chns)):
val += res[i]
return val


Expand Down
4 changes: 2 additions & 2 deletions imaginairy/modules/sgm/autoencoding/lpips/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,6 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)

def forward(self, input_tensor):
    """Standard forward: feed the input through the discriminator stack."""
    return self.main(input_tensor)
39 changes: 21 additions & 18 deletions imaginairy/modules/sgm/autoencoding/lpips/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@

def download(url, local_path, chunk_size=1024):
    """Stream *url* to *local_path* with a byte-level progress bar.

    Creates the destination directory if needed and writes the response in
    ``chunk_size`` pieces so large files are never held in memory at once.
    The progress total comes from the Content-Length header (0 if absent).
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r, tqdm(
        total=int(r.headers.get("content-length", 0)), unit="B", unit_scale=True
    ) as pbar, open(local_path, "wb") as f:
        for data in r.iter_content(chunk_size=chunk_size):
            if data:
                f.write(data)
                # BUG FIX: advance by the bytes actually received. The final
                # chunk is usually smaller than chunk_size, so updating by a
                # constant chunk_size overshoots the Content-Length total.
                pbar.update(len(data))


def md5_hash(path):
Expand Down Expand Up @@ -55,9 +54,13 @@ def __init__(

self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

def initialize(self, input):
def initialize(self, input_tensor):
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
flatten = (
input_tensor.permute(1, 0, 2, 3)
.contiguous()
.view(input_tensor.shape[1], -1)
)
mean = (
flatten.mean(1)
.unsqueeze(1)
Expand All @@ -76,30 +79,30 @@ def initialize(self, input):
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))

def forward(self, input_tensor, reverse=False):
    """Apply activation normalization: ``scale * (x + loc)``.

    When ``reverse`` is True, delegates to ``self.reverse``. 2D inputs
    (N, C) are temporarily expanded to (N, C, 1, 1) and squeezed back
    afterwards. On the first training call (``initialized`` buffer is 0)
    the parameters are data-dependently initialized from the batch.
    Returns ``(out, logdet)`` when ``self.logdet`` is set, else ``out``.
    """
    if reverse:
        return self.reverse(input_tensor)

    was_2d = len(input_tensor.shape) == 2
    if was_2d:
        input_tensor = input_tensor[:, :, None, None]

    _, _, height, width = input_tensor.shape

    # Data-dependent init happens once, on the first training batch.
    if self.training and self.initialized.item() == 0:
        self.initialize(input_tensor)
        self.initialized.fill_(1)

    out = self.scale * (input_tensor + self.loc)

    if was_2d:
        out = out.squeeze(-1).squeeze(-1)

    if not self.logdet:
        return out

    log_abs = torch.log(torch.abs(self.scale))
    logdet = height * width * torch.sum(log_abs)
    logdet = logdet * torch.ones(input_tensor.shape[0]).to(input_tensor)
    return out, logdet
Expand Down
8 changes: 4 additions & 4 deletions imaginairy/modules/sgm/autoencoding/temporal_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwar
padding=padding,
)

def forward(self, input, timesteps, skip_video=False):
x = super().forward(input)
def forward(self, input_tensor, timesteps, skip_video=False):
x = super().forward(input_tensor)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
Expand Down Expand Up @@ -294,8 +294,8 @@ def make_time_attn(


class Conv2DWrapper(torch.nn.Conv2d):
    """A ``Conv2d`` whose forward tolerates (and ignores) extra kwargs.

    Lets callers that pass additional keyword arguments (e.g. temporal
    metadata) use a plain 2D convolution without changing its behavior.
    """

    def forward(self, input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra kwargs are deliberately dropped; only the tensor is convolved.
        return super().forward(input_tensor)


class VideoDecoder(Decoder):
Expand Down
4 changes: 2 additions & 2 deletions imaginairy/modules/sgm/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def sample(
):
randn = torch.randn(batch_size, *shape).to(self.device)

def denoiser(input, sigma, c):
return self.denoiser(self.model, input, sigma, c, **kwargs)
def denoiser(input_tensor, sigma, c):
return self.denoiser(self.model, input_tensor, sigma, c, **kwargs)

samples = self.sampler(denoiser, randn, cond, uc=uc)
return samples
Expand Down
9 changes: 5 additions & 4 deletions imaginairy/modules/sgm/diffusionmodules/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
def forward(
    self,
    network: nn.Module,
    input_tensor: torch.Tensor,
    sigma: torch.Tensor,
    cond: Dict,
    **additional_model_inputs,
) -> torch.Tensor:
    """Run one preconditioned denoising step.

    The (optionally quantized) noise level ``sigma`` is broadcast to the
    input's rank, turned into the scaling factors ``c_skip/c_out/c_in/c_noise``
    via ``self.scaling``, and the network output is combined with a scaled
    skip connection of the input.
    """
    sigma = self.possibly_quantize_sigma(sigma)
    original_sigma_shape = sigma.shape
    sigma = append_dims(sigma, input_tensor.ndim)

    c_skip, c_out, c_in, c_noise = self.scaling(sigma)
    # c_noise keeps the pre-broadcast shape expected by the network.
    c_noise = self.possibly_quantize_c_noise(c_noise.reshape(original_sigma_shape))

    network_out = network(
        input_tensor * c_in, c_noise, cond, **additional_model_inputs
    )
    return network_out * c_out + input_tensor * c_skip


Expand Down
34 changes: 16 additions & 18 deletions imaginairy/modules/sgm/diffusionmodules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
from imaginairy.modules.sgm.autoencoding.lpips.loss.lpips import LPIPS
from imaginairy.modules.sgm.encoders.modules import GeneralConditioner
from imaginairy.utils import instantiate_from_config
from imaginairy.vendored.k_diffusion.utils import append_dims

# from ...modules.autoencoding.lpips.loss.lpips import LPIPS
# from ...modules.encoders.modules import GeneralConditioner
# from ...util import append_dims, instantiate_from_config
from .denoiser import Denoiser


Expand Down Expand Up @@ -44,54 +42,54 @@ def __init__(
self.batch2model_keys = set(batch2model_keys)

def get_noised_input(
    self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
    """Return the input perturbed by sigma-scaled noise: ``x + noise * sigmas_bc``."""
    return input_tensor + noise * sigmas_bc

def forward(
    self,
    network: nn.Module,
    denoiser: Denoiser,
    conditioner: GeneralConditioner,
    input_tensor: torch.Tensor,
    batch: Dict,
) -> torch.Tensor:
    """Embed ``batch`` with the conditioner, then delegate to ``_forward``."""
    conditioning = conditioner(batch)
    return self._forward(network, denoiser, conditioning, input_tensor, batch)

def _forward(
    self,
    network: nn.Module,
    denoiser: Denoiser,
    cond: Dict,
    input_tensor: torch.Tensor,
    batch: Dict,
) -> Tuple[torch.Tensor, Dict]:
    """Compute the diffusion training loss for one batch.

    Samples per-example noise levels, noises the input, runs the denoiser,
    and returns the sigma-weighted loss against the clean input. Keys listed
    in ``self.batch2model_keys`` are forwarded from ``batch`` to the model.
    """
    additional_model_inputs = {
        key: batch[key] for key in self.batch2model_keys.intersection(batch)
    }
    # BUG FIX: the `input` -> `input_tensor` rename in this refactor left
    # three references to the bare name `input`, which now resolves to the
    # Python builtin and would raise at runtime (`.to(<builtin>)`,
    # `<builtin>.shape`). All occurrences below now use `input_tensor`.
    sigmas = self.sigma_sampler(input_tensor.shape[0]).to(input_tensor)

    noise = torch.randn_like(input_tensor)
    if self.offset_noise_level > 0.0:
        # Per-frame offset noise for video (n_frames set), per-channel otherwise.
        offset_shape = (
            (input_tensor.shape[0], 1, input_tensor.shape[2])
            if self.n_frames is not None
            else (input_tensor.shape[0], input_tensor.shape[1])
        )
        noise = noise + self.offset_noise_level * append_dims(
            torch.randn(offset_shape, device=input_tensor.device),
            input_tensor.ndim,
        )
    sigmas_bc = append_dims(sigmas, input_tensor.ndim)
    noised_input = self.get_noised_input(sigmas_bc, noise, input_tensor)

    model_output = denoiser(
        network, noised_input, sigmas, cond, **additional_model_inputs
    )
    w = append_dims(self.loss_weighting(sigmas), input_tensor.ndim)
    return self.get_loss(model_output, input_tensor, w)

def get_loss(self, model_output, target, w):
if self.loss_type == "l2":
Expand Down
20 changes: 19 additions & 1 deletion imaginairy/modules/sgm/diffusionmodules/video_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
from typing import List, Optional, Union

import torch as th
import torch.nn as nn
from einops import rearrange

from imaginairy.modules.sgm.diffusionmodules.openaimodel import *
from imaginairy.modules.sgm.diffusionmodules.openaimodel import (
Downsample,
ResBlock,
SpatialVideoTransformer,
Timestep,
TimestepEmbedSequential,
Upsample,
)
from imaginairy.modules.sgm.diffusionmodules.util import (
conv_nd,
linear,
normalization,
timestep_embedding,
zero_module,
)
from imaginairy.utils import default

from .util import AlphaBlender

# import torch.nn.functional as F


class VideoResBlock(ResBlock):
def __init__(
Expand Down
4 changes: 2 additions & 2 deletions imaginairy/modules/sgm/diffusionmodules/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
class IdentityWrapper(nn.Module):
    """Wrap a diffusion model, optionally compiling it with ``torch.compile``.

    Compilation is applied only when ``compile_model`` is True AND the
    installed torch is >= 2.0.0; otherwise the model is stored unchanged.
    ``forward`` is a pure pass-through to the wrapped model.
    """

    def __init__(self, diffusion_model, compile_model: bool = False):
        super().__init__()
        should_compile = (
            version.parse(torch.__version__) >= version.parse("2.0.0")
        ) and compile_model
        # Note: renamed from `compile` upstream to avoid shadowing the builtin.
        torch_compile = torch.compile if should_compile else (lambda x: x)
        self.diffusion_model = torch_compile(diffusion_model)

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)
Expand Down

0 comments on commit 1ec841a

Please sign in to comment.