Added PixTransform guided upsampling (#1506)
* Added PixTransform guided upsampling
* Added grayscale image op
* WIP
* Fixed type error
* Formatting
* Add desc
* Formatting
* Added tile-based auto split
* Use dataclass
* Improved exact auto split
* Fixed normalization
* Split in Lab by default
* Pass split mode param
* Fixed exact split
* Optimize single colors
* Reduce iters to 1k by default
* Remove redundant logging
* Reduce iterations
* Fixed write_into for single-channel images
* Rebranding
1 parent 1aa7c68, commit 0c408a9
Showing 12 changed files with 571 additions and 15 deletions.
@@ -0,0 +1,44 @@
from typing import Callable

import numpy as np
from typing_extensions import Concatenate, ParamSpec

ImageOp = Callable[[np.ndarray], np.ndarray]
"""
An image processing operation that takes an image and produces a new image.
The given image is guaranteed to *not* be modified.
"""


def clipped(op: ImageOp) -> ImageOp:
    """
    Ensures that all values in the returned image are between 0 and 1.
    """
    return lambda i: np.clip(op(i), 0, 1)


P = ParamSpec("P")


def to_op(fn: Callable[Concatenate[np.ndarray, P], np.ndarray]) -> Callable[P, ImageOp]:
    """
    Applies a form of currying to convert the given function into a constructor for an image operation.

    Example: A simple resize function could be defined as `resize(np.ndarray, Size2D) -> np.ndarray`.
    It takes an image and its new size and returns the resized image.
    If we want to convert it into an image operation, we have to create a function with the signature `resize_op(Size2D) -> ImageOp`.
    The implementation would be rather simple: it would take all arguments of `resize` except for the image, like this:

    ```py
    def resize_op(size: Size2D) -> ImageOp:
        return lambda img: resize(img, size)
    ```

    `to_op` does exactly this transformation, but for any number of arguments.

    Note: This only works if the input image is the first argument of the given function.
    """

    def p(*args: P.args, **kwargs: P.kwargs) -> ImageOp:
        return lambda i: fn(i, *args, **kwargs)

    return p
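To illustrate how these helpers compose, here is a minimal usage sketch. The `brighten` function is a hypothetical example (not part of this commit), and the flat `image_op` import is an assumption about how the module is reachable.

```py
import numpy as np

from image_op import ImageOp, clipped, to_op  # import path is an assumption


def brighten(img: np.ndarray, amount: float) -> np.ndarray:
    # hypothetical op: add a constant to every pixel
    return img + amount


# `to_op` curries away the image argument; `clipped` clamps the result to [0, 1]
brighten_op: ImageOp = clipped(to_op(brighten)(0.5))

img = np.full((8, 8, 3), 0.75, dtype=np.float32)
out = brighten_op(img)  # every value becomes min(0.75 + 0.5, 1.0) == 1.0
```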
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Riccardo de Lutio

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
90 changes: 90 additions & 0 deletions
backend/src/nodes/impl/pytorch/pix_transform/auto_split.py
@@ -0,0 +1,90 @@
from __future__ import annotations

import gc

import numpy as np
import torch

from ....utils.utils import Region, Size, get_h_w_c
from ...image_op import to_op
from ...upscale.auto_split import Split, auto_split
from ...upscale.grayscale import SplitMode, grayscale_split
from ...upscale.passthrough import passthrough_single_color
from ...upscale.tiler import Tiler
from .pix_transform import Params, PixTransform


class _PixTiler(Tiler):
    def __init__(self, max_tile_size: int = 2048) -> None:
        self.max_tile_size: int = max_tile_size

    def allow_smaller_tile_size(self) -> bool:
        return False

    def starting_tile_size(self, width: int, height: int, _channels: int) -> Size:
        square = min(width, height, self.max_tile_size)
        return square, square

    def split(self, tile_size: Size) -> Size:
        # half the tile size plus a bit extra to account for overlap
        size = tile_size[0] // 2 + tile_size[0] // 8
        if size < 16:
            raise ValueError("Cannot split any further.")
        return size, size


def pix_transform_auto_split(
    source: np.ndarray,
    guide: np.ndarray,
    device: torch.device,
    params: Params,
    split_mode: SplitMode = SplitMode.LAB,
) -> np.ndarray:
    """
    Automatically splits the source and guide image into segments that can be processed by PixTransform.
    The source and guide image may have any number of channels and any size, as long as the size of the guide image is a whole-number multiple (greater than 1) of the size of the source image.
    """

    s_h, s_w, _ = get_h_w_c(source)
    g_h, g_w, _ = get_h_w_c(guide)

    assert (
        g_h > s_h and g_w > s_w
    ), "The guide image must be larger than the source image."
    assert (
        g_w / s_w == g_w // s_w and g_w / s_w == g_h / s_h
    ), "The size of the guide image must be an integer multiple of the size of the source image (e.g. 2x, 3x, 4x, ...)."

    tiler = _PixTiler()
    scale = g_w // s_w

    def upscale(tile: np.ndarray, region: Region):
        try:
            tile_guide = region.scale(scale).read_from(guide)
            pix_op = to_op(PixTransform)(
                guide_img=np.transpose(tile_guide, (2, 0, 1)),
                device=device,
                params=params,
            )
            # passthrough single colors to speed up alpha channels
            pass_op = to_op(passthrough_single_color)(scale, pix_op)

            return grayscale_split(tile, pass_op, split_mode)
        except RuntimeError as e:
            # Check to see if it's actually the CUDA out of memory error
            if "allocate" in str(e) or "CUDA" in str(e):
                # Collect garbage (clear VRAM)
                gc.collect()
                torch.cuda.empty_cache()
                return Split()
            else:
                # Re-raise the exception if not an OOM error
                raise

    try:
        return auto_split(source, upscale, tiler)
    finally:
        del device
        gc.collect()
        torch.cuda.empty_cache()
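For reference, here is a hedged usage sketch of the new entry point (not taken from the commit). The image file names, the exact 4x size relationship, and the fully qualified import paths are assumptions.

```py
import cv2
import numpy as np
import torch

# import paths assume backend/src is on sys.path
from nodes.impl.pytorch.pix_transform.auto_split import pix_transform_auto_split
from nodes.impl.pytorch.pix_transform.pix_transform import Params

# a low-resolution source and a guide that is an exact whole-number multiple larger,
# e.g. a 128x128 source with a 512x512 guide (4x)
source = cv2.imread("source_lowres.png").astype(np.float32) / 255
guide = cv2.imread("guide_highres.png").astype(np.float32) / 255

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
result = pix_transform_auto_split(source, guide, device, Params(iteration=1024))

cv2.imwrite("upscaled.png", (np.clip(result, 0, 1) * 255).round().astype(np.uint8))
```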
124 changes: 124 additions & 0 deletions
backend/src/nodes/impl/pytorch/pix_transform/pix_transform.py
@@ -0,0 +1,124 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Tuple

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data

from .pix_transform_net import PixTransformNet


@dataclass
class Params:
    spatial_features_input: bool = True
    # spatial color head
    weights_regularizer: Tuple[float, float, float] | None = (0.0001, 0.001, 0.001)
    loss: Literal["mse", "l1"] = "l1"
    lr: float = 0.001
    batch_size: int = 32
    iteration: int = 32 * 1024


def PixTransform(
    source_img: np.ndarray,
    guide_img: np.ndarray,
    device: torch.device,
    params: Params,
) -> np.ndarray:
    if len(guide_img.shape) < 3:
        guide_img = np.expand_dims(guide_img, 0)

    _n_channels, hr_height, hr_width = guide_img.shape

    source_img = source_img.squeeze()
    lr_height, lr_width = source_img.shape

    assert hr_height == hr_width
    assert lr_height == lr_width
    assert hr_height % lr_height == 0

    D = hr_height // lr_height
    M = lr_height
    _N = hr_height

    # normalize guide and source
    guide_img = (
        guide_img - np.mean(guide_img, axis=(1, 2), keepdims=True)
    ) / np.maximum(0.0001, np.std(guide_img, axis=(1, 2), keepdims=True))

    source_img_mean = np.mean(source_img)
    source_img_std = np.std(source_img)
    source_img = (source_img - source_img_mean) / np.maximum(0.0001, source_img_std)

    if params.spatial_features_input:
        x = np.linspace(-0.5, 0.5, hr_width)
        x_grid, y_grid = np.meshgrid(x, x, indexing="ij")

        x_grid = np.expand_dims(x_grid, axis=0)
        y_grid = np.expand_dims(y_grid, axis=0)

        guide_img = np.concatenate([guide_img, x_grid, y_grid], axis=0)

    #### prepare_patches #########################################################################
    # guide_patches is M^2 x C x D x D
    # source_pixels is M^2 x 1

    guide_tensor = torch.from_numpy(guide_img).float().to(device)
    source_tensor = torch.from_numpy(source_img).float().to(device)

    guide_patches = torch.zeros((M * M, guide_tensor.shape[0], D, D)).to(device)
    source_pixels = torch.zeros((M * M, 1)).to(device)
    for i in range(0, M):
        for j in range(0, M):
            guide_patches[j + i * M, :, :, :] = guide_tensor[
                :, i * D : (i + 1) * D, j * D : (j + 1) * D
            ]
            source_pixels[j + i * M] = source_tensor[i : (i + 1), j : (j + 1)]

    train_data = torch.utils.data.TensorDataset(guide_patches, source_pixels)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=params.batch_size, shuffle=True
    )
    ###############################################################################################

    #### setup network ############################################################################
    mynet = (
        PixTransformNet(
            channels_in=guide_tensor.shape[0],
            weights_regularizer=params.weights_regularizer,
        )
        .train()
        .to(device)
    )
    optimizer = optim.Adam(mynet.params_with_regularizer, lr=params.lr)
    if params.loss == "mse":
        myloss = torch.nn.MSELoss()
    elif params.loss == "l1":
        myloss = torch.nn.L1Loss()
    else:
        assert False, "unknown loss!"
    ###############################################################################################

    epochs = params.batch_size * params.iteration // (M * M)
    for _epoch in range(0, epochs):
        for x, y in train_loader:
            optimizer.zero_grad()

            y_pred = mynet(x)
            y_mean_pred = torch.mean(y_pred, dim=[2, 3])

            source_patch_consistency = myloss(y_mean_pred, y)

            source_patch_consistency.backward()
            optimizer.step()

    # compute final prediction, un-normalize, and back to numpy
    mynet.eval()
    predicted_target_img = mynet(guide_tensor.unsqueeze(0)).squeeze()
    predicted_target_img = source_img_mean + source_img_std * predicted_target_img
    predicted_target_img = predicted_target_img.cpu().detach().squeeze().numpy()

    return predicted_target_img
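To make the patch and epoch bookkeeping above concrete, here is a small self-contained sketch with made-up sizes (a 32x32 single-channel source and a 3-channel guide at 4x); the import path is an assumption.

```py
import numpy as np
import torch

# import path assumes backend/src is on sys.path
from nodes.impl.pytorch.pix_transform.pix_transform import Params, PixTransform

M, D = 32, 4  # 32x32 source, upsampled by 4x
source = np.random.rand(M, M).astype(np.float32)            # (32, 32)
guide = np.random.rand(3, M * D, M * D).astype(np.float32)  # CHW, (3, 128, 128)

params = Params(batch_size=32, iteration=2048)
# M*M = 1024 guide patches of size D x D, so epochs = 32 * 2048 // 1024 = 64
result = PixTransform(source, guide, torch.device("cpu"), params)
print(result.shape)  # (128, 128), values in the scale of the original source
```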
68 changes: 68 additions & 0 deletions
backend/src/nodes/impl/pytorch/pix_transform/pix_transform_net.py
@@ -0,0 +1,68 @@
from __future__ import annotations

from typing import Tuple

import torch.nn as nn


class PixTransformNet(nn.Module):
    def __init__(
        self,
        channels_in: int = 5,
        kernel_size: int = 1,
        weights_regularizer: Tuple[float, float, float] | None = None,
    ):
        super(PixTransformNet, self).__init__()

        self.channels_in = channels_in

        self.spatial_net = nn.Sequential(
            nn.Conv2d(2, 32, (1, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(
                32, 2048, (kernel_size, kernel_size), padding=(kernel_size - 1) // 2
            ),
        )
        self.color_net = nn.Sequential(
            nn.Conv2d(channels_in - 2, 32, (1, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(
                32, 2048, (kernel_size, kernel_size), padding=(kernel_size - 1) // 2
            ),
        )
        self.head_net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(
                2048, 32, (kernel_size, kernel_size), padding=(kernel_size - 1) // 2
            ),
            nn.ReLU(),
            nn.Conv2d(32, 1, (1, 1), padding=0),
        )

        if weights_regularizer is None:
            reg_spatial = 0.0001
            reg_color = 0.001
            reg_head = 0.0001
        else:
            reg_spatial = weights_regularizer[0]
            reg_color = weights_regularizer[1]
            reg_head = weights_regularizer[2]

        self.params_with_regularizer = []
        self.params_with_regularizer += [
            {"params": self.spatial_net.parameters(), "weight_decay": reg_spatial}
        ]
        self.params_with_regularizer += [
            {"params": self.color_net.parameters(), "weight_decay": reg_color}
        ]
        self.params_with_regularizer += [
            {"params": self.head_net.parameters(), "weight_decay": reg_head}
        ]

    def forward(self, input_):
        input_spatial = input_[:, self.channels_in - 2 :, :, :]
        input_color = input_[:, 0 : self.channels_in - 2, :, :]

        merged_features = self.spatial_net(input_spatial) + self.color_net(input_color)

        return self.head_net(merged_features)
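As a quick shape sanity check (channel counts here are illustrative): with the default channels_in=5, the first three channels are the guide's color features and the last two are the x/y coordinate grids that PixTransform appends; the import path is an assumption.

```py
import torch

# import path assumes backend/src is on sys.path
from nodes.impl.pytorch.pix_transform.pix_transform_net import PixTransformNet

net = PixTransformNet(channels_in=5)  # 3 color channels + 2 spatial channels
x = torch.rand(1, 5, 64, 64)          # NCHW; the x/y grids are the last two channels
y = net(x)
print(y.shape)  # torch.Size([1, 1, 64, 64]) -- one predicted value per guide pixel
```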