Added PixTransform guided upsampling (#1506)
* Added PixTransform guided upsampling

* Added grayscale image op

* WIP

* Fixed type error

* Formatting

* Add desc

* Formatting

* Added tile-based auto split

* Use dataclass

* Improved exact auto split

* Fixed normalization

* Split in Lab by default

* Pass split mode param

* Fixed exact split

* Optimize single colors

* Reduce iters to 1k by default

* Remove redundant logging

* Reduce iterations

* Fixed write_into for single-channel images

* Rebranding
RunDevelopment authored Mar 1, 2023
1 parent 1aa7c68 commit 0c408a9
Showing 12 changed files with 571 additions and 15 deletions.
44 changes: 44 additions & 0 deletions backend/src/nodes/impl/image_op.py
@@ -0,0 +1,44 @@
from typing import Callable

import numpy as np
from typing_extensions import Concatenate, ParamSpec

ImageOp = Callable[[np.ndarray], np.ndarray]
"""
An image processing operation that takes an image and produces a new image.
The given image is guaranteed to *not* be modified.
"""


def clipped(op: ImageOp) -> ImageOp:
    """
    Ensures that all values in the returned image are between 0 and 1.
    """
    return lambda i: np.clip(op(i), 0, 1)


P = ParamSpec("P")


def to_op(fn: Callable[Concatenate[np.ndarray, P], np.ndarray]) -> Callable[P, ImageOp]:
    """
    Applies a form of currying to convert the given function into a constructor for an image operation.
    Example: a simple resize function could be defined as `resize(np.ndarray, Size2D) -> np.ndarray`.
    It takes an image and its new size and returns the resized image.
    To convert it into an image operation, we need a function with the signature `resize_op(Size2D) -> ImageOp`.
    Its implementation is straightforward: it takes all arguments of `resize` except for the image, like this:
    ```py
    def resize_op(size: Size2D) -> ImageOp:
        return lambda img: resize(img, size)
    ```
    `to_op` does exactly this transformation, but for any number of arguments.
    Note: This only works if the input image is the first argument of the given function.
    """

    def p(*args: P.args, **kwargs: P.kwargs) -> ImageOp:
        return lambda i: fn(i, *args, **kwargs)

    return p
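
To make the helpers above concrete, here is a small usage sketch. It is not part of the commit: `box_blur` is a hypothetical image function used only to illustrate `to_op` and `clipped`, assuming OpenCV is available.

```py
import cv2
import numpy as np

def box_blur(img: np.ndarray, ksize: int) -> np.ndarray:
    # The image is the first parameter, as `to_op` requires.
    return cv2.blur(img, (ksize, ksize))

blur_op = to_op(box_blur)(5)  # an ImageOp that applies a 5x5 box blur
safe_op = clipped(blur_op)    # the same op, with its output clamped to [0, 1]

img = np.random.rand(64, 64, 3).astype(np.float32)
out = safe_op(img)
assert 0 <= out.min() and out.max() <= 1
```
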
21 changes: 21 additions & 0 deletions backend/src/nodes/impl/pytorch/pix_transform/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Riccardo de Lutio

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
90 changes: 90 additions & 0 deletions backend/src/nodes/impl/pytorch/pix_transform/auto_split.py
@@ -0,0 +1,90 @@
from __future__ import annotations

import gc

import numpy as np
import torch

from ....utils.utils import Region, Size, get_h_w_c
from ...image_op import to_op
from ...upscale.auto_split import Split, auto_split
from ...upscale.grayscale import SplitMode, grayscale_split
from ...upscale.passthrough import passthrough_single_color
from ...upscale.tiler import Tiler
from .pix_transform import Params, PixTransform


class _PixTiler(Tiler):
    def __init__(self, max_tile_size: int = 2048) -> None:
        self.max_tile_size: int = max_tile_size

    def allow_smaller_tile_size(self) -> bool:
        return False

    def starting_tile_size(self, width: int, height: int, _channels: int) -> Size:
        square = min(width, height, self.max_tile_size)
        return square, square

    def split(self, tile_size: Size) -> Size:
        # half the tile size plus a bit extra to account for overlap
        size = tile_size[0] // 2 + tile_size[0] // 8
        if size < 16:
            raise ValueError("Cannot split any further.")
        return size, size


def pix_transform_auto_split(
    source: np.ndarray,
    guide: np.ndarray,
    device: torch.device,
    params: Params,
    split_mode: SplitMode = SplitMode.LAB,
) -> np.ndarray:
    """
    Automatically splits the source and guide image into segments that can be processed by PixTransform.
    The source and guide image may have any number of channels and any size, as long as the size of the guide image is a whole-number multiple (greater than 1) of the size of the source image.
    """

    s_h, s_w, _ = get_h_w_c(source)
    g_h, g_w, _ = get_h_w_c(guide)

    assert (
        g_h > s_h and g_w > s_w
    ), "The guide image must be larger than the source image."
    assert (
        g_w / s_w == g_w // s_w and g_w / s_w == g_h / s_h
    ), "The size of the guide image must be an integer multiple of the size of the source image (e.g. 2x, 3x, 4x, ...)."

    tiler = _PixTiler()
    scale = g_w // s_w

    def upscale(tile: np.ndarray, region: Region):
        try:
            tile_guide = region.scale(scale).read_from(guide)
            pix_op = to_op(PixTransform)(
                guide_img=np.transpose(tile_guide, (2, 0, 1)),
                device=device,
                params=params,
            )
            # pass through single colors to speed up alpha channels
            pass_op = to_op(passthrough_single_color)(scale, pix_op)

            return grayscale_split(tile, pass_op, split_mode)
        except RuntimeError as e:
            # Check to see if it's actually a CUDA out-of-memory error
            if "allocate" in str(e) or "CUDA" in str(e):
                # Collect garbage (clear VRAM)
                gc.collect()
                torch.cuda.empty_cache()
                return Split()
            else:
                # Re-raise the exception if it's not an OOM error
                raise

    try:
        return auto_split(source, upscale, tiler)
    finally:
        del device
        gc.collect()
        torch.cuda.empty_cache()
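
For orientation, a minimal sketch of how this entry point might be called, assuming float32 images in [0, 1] whose sizes differ by an exact integer factor; the file names are placeholders, not part of the commit.

```py
import cv2
import numpy as np
import torch

source = cv2.imread("low_res.png").astype(np.float32) / 255  # e.g. 64x64
guide = cv2.imread("high_res.png").astype(np.float32) / 255  # e.g. 256x256 (4x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
result = pix_transform_auto_split(source, guide, device, Params())

cv2.imwrite("upsampled.png", (np.clip(result, 0, 1) * 255).round().astype(np.uint8))
```
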
124 changes: 124 additions & 0 deletions backend/src/nodes/impl/pytorch/pix_transform/pix_transform.py
@@ -0,0 +1,124 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Tuple

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data

from .pix_transform_net import PixTransformNet


@dataclass
class Params:
    spatial_features_input: bool = True
    # spatial color head
    weights_regularizer: Tuple[float, float, float] | None = (0.0001, 0.001, 0.001)
    loss: Literal["mse", "l1"] = "l1"
    lr: float = 0.001
    batch_size: int = 32
    iteration: int = 32 * 1024


def PixTransform(
    source_img: np.ndarray,
    guide_img: np.ndarray,
    device: torch.device,
    params: Params,
) -> np.ndarray:
    if len(guide_img.shape) < 3:
        guide_img = np.expand_dims(guide_img, 0)

    _n_channels, hr_height, hr_width = guide_img.shape

    source_img = source_img.squeeze()
    lr_height, lr_width = source_img.shape

    assert hr_height == hr_width
    assert lr_height == lr_width
    assert hr_height % lr_height == 0

    D = hr_height // lr_height
    M = lr_height
    _N = hr_height

    # normalize guide and source
    guide_img = (
        guide_img - np.mean(guide_img, axis=(1, 2), keepdims=True)
    ) / np.maximum(0.0001, np.std(guide_img, axis=(1, 2), keepdims=True))

    source_img_mean = np.mean(source_img)
    source_img_std = np.std(source_img)
    source_img = (source_img - source_img_mean) / np.maximum(0.0001, source_img_std)

    if params.spatial_features_input:
        x = np.linspace(-0.5, 0.5, hr_width)
        x_grid, y_grid = np.meshgrid(x, x, indexing="ij")

        x_grid = np.expand_dims(x_grid, axis=0)
        y_grid = np.expand_dims(y_grid, axis=0)

        guide_img = np.concatenate([guide_img, x_grid, y_grid], axis=0)

    #### prepare_patches #########################################################################
    # guide_patches is M^2 x C x D x D
    # source_pixels is M^2 x 1

    guide_tensor = torch.from_numpy(guide_img).float().to(device)
    source_tensor = torch.from_numpy(source_img).float().to(device)

    guide_patches = torch.zeros((M * M, guide_tensor.shape[0], D, D)).to(device)
    source_pixels = torch.zeros((M * M, 1)).to(device)
    for i in range(0, M):
        for j in range(0, M):
            guide_patches[j + i * M, :, :, :] = guide_tensor[
                :, i * D : (i + 1) * D, j * D : (j + 1) * D
            ]
            source_pixels[j + i * M] = source_tensor[i : (i + 1), j : (j + 1)]

    train_data = torch.utils.data.TensorDataset(guide_patches, source_pixels)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=params.batch_size, shuffle=True
    )
    ###############################################################################################

    #### setup network ############################################################################
    mynet = (
        PixTransformNet(
            channels_in=guide_tensor.shape[0],
            weights_regularizer=params.weights_regularizer,
        )
        .train()
        .to(device)
    )
    optimizer = optim.Adam(mynet.params_with_regularizer, lr=params.lr)
    if params.loss == "mse":
        myloss = torch.nn.MSELoss()
    elif params.loss == "l1":
        myloss = torch.nn.L1Loss()
    else:
        assert False, "unknown loss!"
    ###############################################################################################

    # With M*M training samples per epoch, this many epochs amounts to ~params.iteration optimizer steps.
    epochs = params.batch_size * params.iteration // (M * M)
    for _epoch in range(0, epochs):
        for x, y in train_loader:
            optimizer.zero_grad()

            y_pred = mynet(x)
            y_mean_pred = torch.mean(y_pred, dim=[2, 3])

            source_patch_consistency = myloss(y_mean_pred, y)

            source_patch_consistency.backward()
            optimizer.step()

    # compute final prediction, un-normalize, and convert back to numpy
    mynet.eval()
    predicted_target_img = mynet(guide_tensor.unsqueeze(0)).squeeze()
    predicted_target_img = source_img_mean + source_img_std * predicted_target_img
    predicted_target_img = predicted_target_img.cpu().detach().squeeze().numpy()

    return predicted_target_img
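
As a quick smoke test of the function above, a self-contained sketch with synthetic data; the shapes are chosen to satisfy the square-image and integer-scale asserts, and the iteration count is lowered only to keep the run short.

```py
import numpy as np
import torch

rng = np.random.default_rng(0)
source = rng.random((32, 32), dtype=np.float32)      # low-res source, single channel
guide = rng.random((3, 128, 128), dtype=np.float32)  # (C, H, W) high-res guide, 4x scale

out = PixTransform(source, guide, torch.device("cpu"), Params(iteration=64))
print(out.shape)  # (128, 128)
```
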
68 changes: 68 additions & 0 deletions backend/src/nodes/impl/pytorch/pix_transform/pix_transform_net.py
@@ -0,0 +1,68 @@
from __future__ import annotations

from typing import Tuple

import torch.nn as nn


class PixTransformNet(nn.Module):
    def __init__(
        self,
        channels_in: int = 5,
        kernel_size: int = 1,
        weights_regularizer: Tuple[float, float, float] | None = None,
    ):
        super(PixTransformNet, self).__init__()

        self.channels_in = channels_in

        self.spatial_net = nn.Sequential(
            nn.Conv2d(2, 32, (1, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(
                32, 2048, (kernel_size, kernel_size), padding=(kernel_size - 1) // 2
            ),
        )
        self.color_net = nn.Sequential(
            nn.Conv2d(channels_in - 2, 32, (1, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(
                32, 2048, (kernel_size, kernel_size), padding=(kernel_size - 1) // 2
            ),
        )
        self.head_net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(
                2048, 32, (kernel_size, kernel_size), padding=(kernel_size - 1) // 2
            ),
            nn.ReLU(),
            nn.Conv2d(32, 1, (1, 1), padding=0),
        )

        if weights_regularizer is None:
            reg_spatial = 0.0001
            reg_color = 0.001
            reg_head = 0.0001
        else:
            reg_spatial = weights_regularizer[0]
            reg_color = weights_regularizer[1]
            reg_head = weights_regularizer[2]

        self.params_with_regularizer = []
        self.params_with_regularizer += [
            {"params": self.spatial_net.parameters(), "weight_decay": reg_spatial}
        ]
        self.params_with_regularizer += [
            {"params": self.color_net.parameters(), "weight_decay": reg_color}
        ]
        self.params_with_regularizer += [
            {"params": self.head_net.parameters(), "weight_decay": reg_head}
        ]

    def forward(self, input_):
        # the last 2 channels are the spatial (x/y) features, the rest are color channels
        input_spatial = input_[:, self.channels_in - 2 :, :, :]
        input_color = input_[:, 0 : self.channels_in - 2, :, :]

        merged_features = self.spatial_net(input_spatial) + self.color_net(input_color)

        return self.head_net(merged_features)
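
`params_with_regularizer` is a list of PyTorch per-parameter-group options, which is how `pix_transform.py` passes it to Adam. A brief illustration of that mechanism, using only standard PyTorch API:

```py
import torch.optim as optim

net = PixTransformNet(channels_in=5)
# Each dict is one parameter group; Adam applies each group's own weight_decay.
optimizer = optim.Adam(net.params_with_regularizer, lr=1e-3)
for group in optimizer.param_groups:
    print(group["weight_decay"])  # 0.0001, 0.001, 0.0001 with the defaults above
```
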
3 changes: 2 additions & 1 deletion backend/src/nodes/impl/upscale/auto_split.py
@@ -75,11 +75,12 @@ def no_split_upscale(i: np.ndarray, r: Region) -> np.ndarray:
     MAX_ITER = 20
     for _ in range(MAX_ITER):
         try:
+            max_overlap = min(*starting_tile_size) // 4
             return exact_split(
                 img=img,
                 exact_size=starting_tile_size,
                 upscale=no_split_upscale,
-                overlap=overlap,
+                overlap=min(max_overlap, overlap),
             )
         except _SplitEx:
             starting_tile_size = split_tile_size(starting_tile_size)
9 changes: 3 additions & 6 deletions backend/src/nodes/impl/upscale/convenient_upscale.py
@@ -1,15 +1,12 @@
-from typing import Callable, Tuple
+from typing import Tuple
 
 import numpy as np
 
 from ...utils.utils import get_h_w_c
+from ..image_op import ImageOp, clipped
 from ..image_utils import as_target_channels
 
 
-def clipped(upscale: Callable[[np.ndarray], np.ndarray]) -> Callable:
-    return lambda i: np.clip(upscale(i), 0, 1)
-
-
 def with_black_and_white_backgrounds(img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
     c = get_h_w_c(img)[2]
     assert c == 4
@@ -27,7 +24,7 @@ def convenient_upscale(
     img: np.ndarray,
     model_in_nc: int,
     model_out_nc: int,
-    upscale: Callable[[np.ndarray], np.ndarray],
+    upscale: ImageOp,
 ) -> np.ndarray:
     """
     Upscales the given image in an intuitive/convenient way.
