
Commit

Merge branch 'dev'
nagadomi committed Oct 17, 2024

2 parents 4c854f2 + b6eef3e commit 7e8e187
Showing 13 changed files with 177 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -48,7 +48,7 @@ CLI tools are also available to filter out low quality images using these result

#### Dependencies

- Python 3 (Probably works with Python 3.9 or later, developed with 3.10)
- Python 3 (Works with Python 3.10 or later, developed with 3.10)
- [PyTorch](https://pytorch.org/get-started/locally/)
- See requirements.txt

2 changes: 2 additions & 0 deletions iw3/__init__.py
@@ -1,2 +1,4 @@
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import truststore
truststore.inject_into_ssl()
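This change makes Python's ssl module validate certificates against the operating-system trust store instead of a bundled CA file, which helps when model downloads run behind a proxy that injects a corporate root CA. A minimal sketch of the effect (the URL is illustrative):

import truststore
truststore.inject_into_ssl()

import urllib.request
# After inject_into_ssl(), standard clients such as urllib verify TLS
# certificates against the OS trust store.
urllib.request.urlopen("https://example.com")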
11 changes: 11 additions & 0 deletions iw3/gui.py
@@ -136,10 +136,20 @@ def initialize_component(self):
name="chk_metadata")
self.chk_metadata.SetValue(False)

self.sep_image_format = wx.StaticLine(self.pnl_file, size=(2, 16), style=wx.LI_VERTICAL)
self.lbl_image_format = wx.StaticText(self.pnl_file, label=" " + T("Image Format"))
self.cbo_image_format = wx.ComboBox(self.pnl_file, choices=["png", "jpeg", "webp"],
style=wx.CB_READONLY, name="cbo_image_format")
self.cbo_image_format.SetSelection(0)
self.cbo_image_format.SetToolTip(T("Output Image Format"))

sublayout = wx.BoxSizer(wx.HORIZONTAL)
sublayout.Add(self.chk_resume, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL)
sublayout.Add(self.chk_recursive, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL)
sublayout.Add(self.chk_metadata, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL)
sublayout.Add(self.sep_image_format, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL)
sublayout.Add(self.lbl_image_format, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL)
sublayout.Add(self.cbo_image_format, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL)

layout = wx.GridBagSizer(vgap=4, hgap=4)
layout.Add(self.lbl_input, (0, 0), flag=wx.ALIGN_RIGHT | wx.ALIGN_CENTER_VERTICAL)
@@ -1121,6 +1131,7 @@ def parse_args(self):
pix_fmt=self.cbo_pix_fmt.GetValue(),
colorspace=self.cbo_colorspace.GetValue(),
video_format=self.cbo_video_format.GetValue(),
format=self.cbo_image_format.GetValue(),
video_codec=self.cbo_video_codec.GetValue(),
crf=int(self.cbo_crf.GetValue()),
profile_level=profile_level,
1 change: 1 addition & 0 deletions iw3/locales/ja.yml
@@ -11,6 +11,7 @@
"Skip processing when the output file already exists": "出力ファイルが存在する場合に処理をスキップ"
"Process all subfolders": "すべてのサブフォルダを処理"
"Add metadata to filename": "ファイル名にメタデータを追加"
"Image Format": "画像フォーマット"
"Play": "再生"

"Error": "エラー"
48 changes: 40 additions & 8 deletions iw3/utils.py
@@ -301,7 +301,7 @@ def to_deciaml(f, scale, zfill=0):
else:
metadata = ""

return basename + metadata + auto_detect_suffix + (args.video_extension if video else ".png")
return basename + metadata + auto_detect_suffix + (args.video_extension if video else get_image_ext(args.format))


def make_video_codec_option(args):
@@ -333,8 +333,37 @@ def make_video_codec_option(args):
return options


def save_image(im, output_filename):
im.save(output_filename)
def get_image_ext(format):
if format == "png":
return ".png"
elif format == "webp":
return ".webp"
elif format == "jpeg":
return ".jpg"
else:
raise NotImplementedError(format)


def save_image(im, output_filename, format="png"):
if format == "png":
options = {
"compress_level": 6
}
elif format == "webp":
options = {
"quality": 95,
"method": 4,
"lossless": True
}
elif format == "jpeg":
options = {
"quality": 95,
"subsampling": "4:2:0",
}
else:
raise NotImplementedError(format)

im.save(output_filename, format=format, **options)


def remove_bg_from_image(im, bg_session):
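The rewritten save_image maps each output format to fixed Pillow encoder options: PNG uses a mid-level compress_level, WebP is written losslessly (quality and method tune encoder effort, not fidelity), and JPEG uses quality 95 with 4:2:0 chroma subsampling. A hedged usage sketch (file names are illustrative; save_image as defined above):

from PIL import Image

im = Image.open("frame_input.png")
save_image(im, "frame_out.webp", format="webp")  # lossless WebP
save_image(im, "frame_out.jpg", format="jpeg")   # quality 95, 4:2:0 subsampling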
@@ -568,7 +597,7 @@ def process_images(files, output_dir, args, depth_model, side_model, title=None)
continue
im = TF.to_tensor(im).to(args.state["device"])
output = process_image(im, args, depth_model, side_model)
f = pool.submit(save_image, output, output_filename)
f = pool.submit(save_image, output, output_filename, format=args.format)
# f.result() # for debug
futures.append(f)
pbar.update(1)
@@ -747,11 +776,12 @@ def process_video_keyframes(input_filename, output_path, args, depth_model, side
futures = []

def frame_callback(frame):
output = process_image(frame.to_image(), args, depth_model, side_model)
im = TF.to_tensor(frame.to_image()).to(args.state["device"])
output = process_image(im, args, depth_model, side_model)
output_filename = path.join(
output_dir,
path.basename(output_dir) + "_" + str(frame.pts).zfill(8) + FULL_SBS_SUFFIX + ".png")
f = pool.submit(save_image, output, output_filename)
path.basename(output_dir) + "_" + str(frame.pts).zfill(8) + FULL_SBS_SUFFIX + get_image_ext(args.format))
f = pool.submit(save_image, output, output_filename, format=args.format)
futures.append(f)
VU.process_video_keyframes(input_filename, frame_callback=frame_callback,
min_interval_sec=args.keyframe_interval,
@@ -1284,7 +1314,7 @@ def fix_rgb_depth_pair(files1, files2):
output_filename = path.join(
output_dir,
make_output_filename(rgb_filename, args, video=False))
f = pool.submit(save_image, sbs, output_filename)
f = pool.submit(save_image, sbs, output_filename, format=args.format)
futures.append(f)
pbar.update(1)
if suspend_event is not None:
@@ -1450,6 +1480,8 @@ def __repr__(self):
help="max inference worker threads for video processing. 0 is disabled")
parser.add_argument("--video-format", "-vf", type=str, default="mp4", choices=["mp4", "mkv", "avi"],
help="video container format")
parser.add_argument("--format", "-f", type=str, default="png", choices=["png", "webp", "jpeg"],
help="output image format")
parser.add_argument("--video-codec", "-vc", type=str, default=None, help="video codec")

parser.add_argument("--metadata", type=str, nargs="?", default=None, const="filename", choices=["filename"],
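With --format/-f wired into the parser, image output is no longer hard-coded to PNG. A hedged command-line example (entry-point module and paths are assumptions, not taken from this diff):

python -m iw3.cli -i ./input_images -o ./output_dir --format webp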
9 changes: 5 additions & 4 deletions nunif/models/onnx_helper_models.py
@@ -6,6 +6,7 @@
import onnx
import copy
from .model import I2IBaseModel
from ..modules.reflection_pad2d import reflection_pad2d_loop
from ..utils.alpha import ChannelWiseSum
from ..logger import logger

@@ -15,7 +16,7 @@ def __init__(self):
super().__init__({}, scale=1, offset=0, in_channels=3)

def forward(self, x: torch.Tensor, left: int, right: int, top: int, bottom: int):
return F.pad(x, (left, right, top, bottom), mode="reflect")
return reflection_pad2d_loop(x, (left, right, top, bottom))

def export_onnx(self, f, **kwargs):
"""
@@ -25,11 +26,11 @@ def export_onnx(self, f, **kwargs):
var out = await ses.run({"x": x, "left": pad, "right": pad, "top": pad, "bottom": pad});
"""
x = torch.rand([1, 3, 256, 256], dtype=torch.float32)
pad = 4
pad = [512, 120, 512, 120]
model = torch.jit.script(self.to_inference_model())
torch.onnx.export(
model,
[x, pad, pad, pad, pad],
[x, *pad],
f,
input_names=["x", "left", "right", "top", "bottom"],
output_names=["y"],
@@ -432,7 +433,7 @@ def _test_alpha_border():


if __name__ == "__main__":
# _test_pad()
_test_pad()
# _test_blend_filter()
# _test_alpha_border()
# _test_resize()
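The switch from F.pad to reflection_pad2d_loop matters because reflect padding cannot exceed the input size in a single step, in PyTorch and (per the comment in reflection_pad2d.py below) in onnxruntime; the export test now deliberately pads [512, 120, 512, 120] against a 256-px dummy input to exercise the loop. A small sketch of the limitation (import path taken from this diff):

import torch
import torch.nn.functional as F
from nunif.modules.reflection_pad2d import reflection_pad2d_loop

x = torch.rand(1, 3, 256, 256)
# F.pad(x, (512, 120, 512, 120), mode="reflect")  # RuntimeError: pad >= dim
y = reflection_pad2d_loop(x, (512, 120, 512, 120))
print(y.shape)  # torch.Size([1, 3, 888, 888])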
61 changes: 50 additions & 11 deletions nunif/modules/reflection_pad2d.py
@@ -3,34 +3,61 @@
import torch.nn as nn


def reflection_pad2d_naive(x, padding, detach=False):
def _detach_fn(x, flag: bool):
if flag:
return x.detach()
else:
return x


def reflection_pad2d_naive(x, padding: tuple[int, int, int, int], detach: bool = False):
assert x.ndim == 4 and len(padding) == 4
# TODO: over 2x size support
assert padding[0] < x.shape[3] and padding[1] < x.shape[3]
assert padding[2] < x.shape[2] and padding[3] < x.shape[2]
assert padding[0] <= x.shape[3] and padding[1] <= x.shape[3]
assert padding[2] <= x.shape[2] and padding[3] <= x.shape[2]
left, right, top, bottom = padding

detach_fn = lambda t: t.detach() if detach else t
if left > 0:
x = torch.cat((torch.flip(detach_fn(x[:, :, :, 1:left + 1]), dims=[3]), x), dim=3)
x = torch.cat((torch.flip(_detach_fn(x[:, :, :, 1:left + 1], detach), dims=[3]), x), dim=3)
elif left < 0:
x = x[:, :, :, -left:]
if right > 0:
x = torch.cat((x, torch.flip(detach_fn(x[:, :, :, -right - 1:-1]), dims=[3])), dim=3)
x = torch.cat((x, torch.flip(_detach_fn(x[:, :, :, -right - 1:-1], detach), dims=[3])), dim=3)
elif right < 0:
x = x[:, :, :, :right]
if top > 0:
x = torch.cat((torch.flip(detach_fn(x[:, :, 1:top + 1, :]), dims=[2]), x), dim=2)
x = torch.cat((torch.flip(_detach_fn(x[:, :, 1:top + 1, :], detach), dims=[2]), x), dim=2)
elif top < 0:
x = x[:, :, -top:, :]
if bottom > 0:
x = torch.cat((x, torch.flip(detach_fn(x[:, :, -bottom - 1:-1, :]), dims=[2])), dim=2)
x = torch.cat((x, torch.flip(_detach_fn(x[:, :, -bottom - 1:-1, :], detach), dims=[2])), dim=2)
elif bottom < 0:
x = x[:, :, :bottom, :]

return x.contiguous()


def _loop_step(pad: int, base: int) -> tuple[int, int]:
remain = 0
if pad > base:
remain = pad - base
pad = base
return pad, remain


def reflection_pad2d_loop(x, padding: tuple[int, int, int, int], detach: bool = False):
# Limit one-step padding size to image size
# For onnxruntime
height, width = x.shape[2:]
left, right, top, bottom = padding
while left != 0 or right != 0 or top != 0 or bottom != 0:
left_step, left = _loop_step(left, width)
right_step, right = _loop_step(right, width)
top_step, top = _loop_step(top, height)
bottom_step, bottom = _loop_step(bottom, height)
x = reflection_pad2d_naive(x, (left_step, right_step, top_step, bottom_step), detach=detach)
return x


class ReflectionPad2dNaive(nn.Module):
def __init__(self, padding, detach=False):
super().__init__()
@@ -99,7 +126,19 @@ def _test_grad():
print(x.grad)


def _test_loop():
import torchvision.io as IO
import torchvision.transforms.functional as TF

x = IO.read_image("cc0/dog2.jpg") / 255.0
x = x[:, :256, :256].unsqueeze(0)

x = reflection_pad2d_loop(x, (640, -10, 320, -10))
TF.to_pil_image(x[0]).show()


if __name__ == "__main__":
_test()
_test_grad()
# _test()
# _test_grad()
# _test_vis()
_test_loop()
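_loop_step caps each pass at the original image size and carries the remainder, so an oversized pad becomes several reflection passes. For example, a left pad of 640 on a 256-px-wide image decomposes into three steps; a quick standalone check reusing the helper's logic (positive pads only):

def _loop_step(pad, base):
    remain = 0
    if pad > base:
        remain = pad - base
        pad = base
    return pad, remain

left, width = 640, 256
steps = []
while left > 0:
    step, left = _loop_step(left, width)
    steps.append(step)
print(steps)  # [256, 256, 128] -> three reflection passes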
1 change: 1 addition & 0 deletions requirements.txt
@@ -28,3 +28,4 @@ timm
numba # only for iw3 sbs training
av >= 12.2.0
rembg # for --remove-bg
truststore
3 changes: 3 additions & 0 deletions waifu2x/__init__.py
@@ -1,5 +1,8 @@
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import truststore
truststore.inject_into_ssl()

from .utils import Waifu2x
from . import models

6 changes: 5 additions & 1 deletion waifu2x/export_onnx.py
@@ -7,6 +7,7 @@
import argparse
from nunif.models import load_model
from nunif.models.onnx_helper_models import (
ONNXReflectionPadding,
ONNXReplicationPadding,
ONNXTTASplit,
ONNXTTAMerge,
@@ -113,8 +114,11 @@ def convert_utils(output_dir):
utils_dir = path.join(output_dir, "utils")
os.makedirs(utils_dir, exist_ok=True)

pad = ONNXReflectionPadding()
pad.export_onnx(path.join(utils_dir, "reflection_pad.onnx"))
pad = ONNXReplicationPadding()
pad.export_onnx(path.join(utils_dir, "pad.onnx"))
pad.export_onnx(path.join(utils_dir, "replication_pad.onnx"))
pad.export_onnx(path.join(utils_dir, "pad.onnx")) # for compatibility

tta_split = ONNXTTASplit()
tta_split.export_onnx(path.join(utils_dir, "tta_split.onnx"))
37 changes: 36 additions & 1 deletion waifu2x/models/winc_unet.py
@@ -181,6 +181,7 @@ def __init__(self, in_channels, out_channels, base_dim=96,
scale_factor=2):
super(WincUNetBase, self).__init__()
assert scale_factor in {1, 2, 4}
self.scale_factor = scale_factor
C = base_dim
C2 = int(C * lv2_ratio)
# assert C % 32 == 0 and C2 % 32 == 0 # slow when C % 32 != 0
@@ -212,7 +213,12 @@ def __init__(self, in_channels, out_channels, base_dim=96,
basic_module_init(self.wac1_proj)
basic_module_init(self.fusion)

def forward(self, x):
self.tile_mode = False

def set_tile_mode(self):
self.tile_mode = True

def _forward(self, x):
ov = self.overscan(x)
x = self.patch(x)
x = F.leaky_relu(x, 0.2, inplace=True)
@@ -232,6 +238,32 @@ def forward(self, x):

return z

def _forward_tile4x4(self, x):
tl = x[:, :, :64, :64]
tr = x[:, :, :64, -64:]
bl = x[:, :, -64:, :64]
br = x[:, :, -64:, -64:]
x = torch.cat([tl, tr, bl, br], dim=0).contiguous()
x = self._forward(x)
tl, tr, bl, br = x.split(x.shape[0] // 4, dim=0)
top = torch.cat([tl, tr], dim=3)
bottom = torch.cat([bl, br], dim=3)
x = torch.cat([top, bottom], dim=2).contiguous()
return x

def forward(self, x):
if self.tile_mode:
B, C, H, W = x.shape
if self.scale_factor == 4:
assert H == 110 and W == H
pass
else:
raise NotImplementedError()
return self._forward_tile4x4(x)
else:
return self._forward(x)


def tile_size_validator(size):
return (size > 16 and
@@ -308,6 +340,9 @@ def forward(self, x):
else:
return torch.clamp(z, 0., 1.)

def set_tile_mode(self):
self.unet.set_tile_mode()

def to_2x(self, shared=True):
unet = self.unet if shared else copy.deepcopy(self.unet)
return WincUNetDownscaled(unet, downscale_factor=2,
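In tile mode the four 64x64 corner crops of a 110x110 input are stacked along the batch dimension, so a single forward pass covers all tiles; because 64 + 64 > 110 the crops overlap by 18 pixels, which the model's context margin presumably absorbs. A shape-only sketch of the split/merge pattern (identity pass-through in place of the model):

import torch

x = torch.rand(1, 3, 110, 110)
tl, tr = x[:, :, :64, :64], x[:, :, :64, -64:]
bl, br = x[:, :, -64:, :64], x[:, :, -64:, -64:]
batch = torch.cat([tl, tr, bl, br], dim=0)  # (4, 3, 64, 64)
# _forward(batch) would run here; this sketch passes shapes through unchanged.
tl, tr, bl, br = batch.split(batch.shape[0] // 4, dim=0)
top = torch.cat([tl, tr], dim=3)
bottom = torch.cat([bl, br], dim=3)
y = torch.cat([top, bottom], dim=2)  # (1, 3, 128, 128)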
23 changes: 12 additions & 11 deletions waifu2x/training/trainer.py
@@ -146,11 +146,12 @@ def get_last_layer(model):
"waifu2x.swin_unet_8x",
"waifu2x.winc_unet_1x",
"waifu2x.winc_unet_2x",
"waifu2x.winc_unet_4x",
"waifu2x.winc_unet_1x_small",
"waifu2x.winc_unet_2x_small",
}:
return model.unet.to_image.proj.weight
elif model.name in {"waifu2x.winc_unet_4x"}:
return model.unet.to_image.conv.weight
elif model.name in {"waifu2x.cunet", "waifu2x.upcunet"}:
return model.unet2.conv_bottom.weight
elif model.name in {"waifu2x.upconv_7", "waifu2x.vgg_7"}:
@@ -278,12 +279,8 @@ def train_step(self, data):
else:
z, y = self.diff_aug(z, y)
fake = z
if isinstance(self.discriminator, SelfSupervisedDiscriminator):
*z_real, _ = self.discriminator(torch.clamp(fake, 0, 1), y, scale_factor)
if len(z_real) == 1:
z_real = z_real[0]
else:
z_real = self.discriminator(torch.clamp(fake, 0, 1), y, scale_factor)

z_real = self.discriminator(torch.clamp(fake, 0, 1), y, scale_factor)
recon_loss = self.criterion(z, y)
generator_loss = self.discriminator_criterion(z_real)
self.sum_p_loss += recon_loss.item()
@@ -303,8 +300,8 @@ def train_step(self, data):
self.discriminator.requires_grad_(True)
if isinstance(self.discriminator, SelfSupervisedDiscriminator):
*z_fake, fake_ss_loss = self.discriminator(torch.clamp(fake.detach(), 0, 1),
y, scale_factor)
*z_real, real_ss_loss = self.discriminator(y, y, scale_factor)
y, scale_factor, train=True)
*z_real, real_ss_loss = self.discriminator(y, y, scale_factor, train=True)
if len(z_fake) == 1:
z_fake = z_fake[0]
z_real = z_real[0]
@@ -479,6 +476,8 @@ def setup_model(self):
if self.args.freeze and hasattr(self.model, "freeze"):
self.model.freeze()
logger.debug("call model.freeze()")
if self.args.tile_mode:
self.model.set_tile_mode()

def create_model(self):
kwargs = {"in_channels": 3, "out_channels": 3}
@@ -650,8 +649,8 @@ def train(args):
"waifu2x.swin_unet_2x",
"waifu2x.swin_unet_4x"}
assert args.discriminator_stop_criteria < args.generator_start_criteria
if args.size % 4 != 0:
raise ValueError("--size must be a multiple of 4")
# if args.size % 4 != 0:
# raise ValueError("--size must be a multiple of 4")
if args.arch in ARCH_SWIN_UNET and ((args.size - 16) % 12 != 0 or (args.size - 16) % 16 != 0):
raise ValueError("--size must be `(SIZE - 16) % 12 == 0 and (SIZE - 16) % 16 == 0` for SwinUNet models")
if args.method in {"noise", "noise_scale", "noise_scale4x"} and args.noise_level is None:
@@ -775,6 +774,8 @@ def register(subparsers, default_parser):
help="use only bicubic downsampling for bicubic downsampling restoration (classic super-resolution)")
parser.add_argument("--freeze", action="store_true",
help="call model.freeze() if avaliable")
parser.add_argument("--tile-mode", action="store_true",
help="call model.set_tile_mode()")
parser.add_argument("--pre-antialias", action="store_true",
help=("Set `pre_antialias=True` for SwinUNet4x."))
parser.add_argument("--privilege", action="store_true",
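The new --tile-mode flag simply calls model.set_tile_mode() during setup, enabling the tiled forward path added to WincUNet above. A hedged invocation (command shape inferred from the subparser registration; arch and method options are illustrative):

python train.py waifu2x --arch waifu2x.winc_unet_4x --method scale4x --tile-mode ...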
19 changes: 10 additions & 9 deletions waifu2x/unlimited_waifu2x/public_html/script.js
@@ -28,9 +28,9 @@ function gen_arch_config()

/* swin_unet */
config["swin_unet"] = {
art: {color_stability: true},
art_scan: {color_stability: false},
photo: {color_stability: false}};
art: {color_stability: true, padding: "replication"},
art_scan: {color_stability: false, padding: "replication"},
photo: {color_stability: false, padding: "reflection"}};
var swin = config["swin_unet"];
const calc_tile_size_swin_unet = function (tile_size, config) {
while (true) {
@@ -66,7 +66,8 @@ function gen_arch_config()
};
var base_config = {
arch: "cunet", domain: "art", calc_tile_size: calc_tile_size_cunet,
color_stability: true
color_stability: true,
padding: "replication",
};
config["cunet"]["art"] = {
scale2x: {...base_config, scale: 2, offset: 36},
@@ -459,9 +460,9 @@ const onnx_runner = {
x = await this.alpha_border_padding(rgb, alpha1, BigInt(config.offset));
// _debug_print_image_data(this.to_image_data(x.data, null, x.dims[3], x.dims[2]));
x = await this.padding(x, BigInt(p.pad[0]), BigInt(p.pad[1]),
BigInt(p.pad[2]), BigInt(p.pad[3]));
BigInt(p.pad[2]), BigInt(p.pad[3]), config.padding);
alpha3 = await this.padding(alpha3, BigInt(p.pad[0]), BigInt(p.pad[1]),
BigInt(p.pad[2]), BigInt(p.pad[3]));
BigInt(p.pad[2]), BigInt(p.pad[3]), config.padding);
alpha1 = null;
} else {
var alpha3 = {data: null};
@@ -470,7 +471,7 @@ const onnx_runner = {
await seam_blending.build();
var p = seam_blending.get_rendering_config();
x = await this.padding(x, BigInt(p.pad[0]), BigInt(p.pad[1]),
BigInt(p.pad[2]), BigInt(p.pad[3]));
BigInt(p.pad[2]), BigInt(p.pad[3]), config.padding);
}
var ch, h, w;
[ch, h, w] = [x.dims[1], x.dims[2], x.dims[3]];
@@ -556,8 +557,8 @@ const onnx_runner = {
console.timeEnd("render");
this.running = false;
},
padding: async function(x, left, right, top, bottom) {
const ses = await onnx_session.get_session(CONFIG.get_helper_model_path("pad"));
padding: async function(x, left, right, top, bottom, mode) {
const ses = await onnx_session.get_session(CONFIG.get_helper_model_path(mode + "_pad"));
left = new ort.Tensor('int64', BigInt64Array.from([left]), []);
right = new ort.Tensor('int64', BigInt64Array.from([right]), []);
top = new ort.Tensor('int64', BigInt64Array.from([top]), []);
