From 991a92ca57419ec0c7f9d6abe5c3ebf906e932e5 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 12 Oct 2024 09:19:05 +0900 Subject: [PATCH 01/10] Add reflection_pad2d_loop for onnxruntime --- nunif/models/onnx_helper_models.py | 9 +++-- nunif/modules/reflection_pad2d.py | 61 ++++++++++++++++++++++++------ 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/nunif/models/onnx_helper_models.py b/nunif/models/onnx_helper_models.py index a70f51f0..9f4fa3fb 100644 --- a/nunif/models/onnx_helper_models.py +++ b/nunif/models/onnx_helper_models.py @@ -6,6 +6,7 @@ import onnx import copy from .model import I2IBaseModel +from ..modules.reflection_pad2d import reflection_pad2d_loop from ..utils.alpha import ChannelWiseSum from ..logger import logger @@ -15,7 +16,7 @@ def __init__(self): super().__init__({}, scale=1, offset=0, in_channels=3) def forward(self, x: torch.Tensor, left: int, right: int, top: int, bottom: int): - return F.pad(x, (left, right, top, bottom), mode="reflect") + return reflection_pad2d_loop(x, (left, right, top, bottom)) def export_onnx(self, f, **kwargs): """ @@ -25,11 +26,11 @@ def export_onnx(self, f, **kwargs): var out = await ses.run({"x": x, "left": pad, "right": pad, "top": pad, "bottom": pad}); """ x = torch.rand([1, 3, 256, 256], dtype=torch.float32) - pad = 4 + pad = [512, 120, 512, 120] model = torch.jit.script(self.to_inference_model()) torch.onnx.export( model, - [x, pad, pad, pad, pad], + [x, *pad], f, input_names=["x", "left", "right", "top", "bottom"], output_names=["y"], @@ -432,7 +433,7 @@ def _test_alpha_border(): if __name__ == "__main__": - # _test_pad() + _test_pad() # _test_blend_filter() # _test_alpha_border() # _test_resize() diff --git a/nunif/modules/reflection_pad2d.py b/nunif/modules/reflection_pad2d.py index 34955631..b7702e56 100644 --- a/nunif/modules/reflection_pad2d.py +++ b/nunif/modules/reflection_pad2d.py @@ -3,34 +3,61 @@ import torch.nn as nn -def reflection_pad2d_naive(x, padding, detach=False): +def _detach_fn(x, flag: bool): + if flag: + return x.detach() + else: + return x + + +def reflection_pad2d_naive(x, padding: tuple[int, int, int, int], detach: bool = False): assert x.ndim == 4 and len(padding) == 4 - # TODO: over 2x size support - assert padding[0] < x.shape[3] and padding[1] < x.shape[3] - assert padding[2] < x.shape[2] and padding[3] < x.shape[2] + assert padding[0] <= x.shape[3] and padding[1] <= x.shape[3] + assert padding[2] <= x.shape[2] and padding[3] <= x.shape[2] left, right, top, bottom = padding - detach_fn = lambda t: t.detach() if detach else t if left > 0: - x = torch.cat((torch.flip(detach_fn(x[:, :, :, 1:left + 1]), dims=[3]), x), dim=3) + x = torch.cat((torch.flip(_detach_fn(x[:, :, :, 1:left + 1], detach), dims=[3]), x), dim=3) elif left < 0: x = x[:, :, :, -left:] if right > 0: - x = torch.cat((x, torch.flip(detach_fn(x[:, :, :, -right - 1:-1]), dims=[3])), dim=3) + x = torch.cat((x, torch.flip(_detach_fn(x[:, :, :, -right - 1:-1], detach), dims=[3])), dim=3) elif right < 0: x = x[:, :, :, :right] if top > 0: - x = torch.cat((torch.flip(detach_fn(x[:, :, 1:top + 1, :]), dims=[2]), x), dim=2) + x = torch.cat((torch.flip(_detach_fn(x[:, :, 1:top + 1, :], detach), dims=[2]), x), dim=2) elif top < 0: x = x[:, :, -top:, :] if bottom > 0: - x = torch.cat((x, torch.flip(detach_fn(x[:, :, -bottom - 1:-1, :]), dims=[2])), dim=2) + x = torch.cat((x, torch.flip(_detach_fn(x[:, :, -bottom - 1:-1, :], detach), dims=[2])), dim=2) elif bottom < 0: x = x[:, :, :bottom, :] return x.contiguous() +def _loop_step(pad: 
int, base: int) -> tuple[int, int]:
+    remain = 0
+    if pad >= base:  # reflect padding requires pad < size, so cap one step at base - 1
+        remain = pad - (base - 1)
+        pad = base - 1
+    return pad, remain
+
+
+def reflection_pad2d_loop(x, padding: tuple[int, int, int, int], detach: bool = False):
+    # Limit one-step padding size to less than the image size.
+    # For onnxruntime
+    height, width = x.shape[2:]
+    left, right, top, bottom = padding
+    while left != 0 or right != 0 or top != 0 or bottom != 0:
+        left_step, left = _loop_step(left, width)
+        right_step, right = _loop_step(right, width)
+        top_step, top = _loop_step(top, height)
+        bottom_step, bottom = _loop_step(bottom, height)
+        x = reflection_pad2d_naive(x, (left_step, right_step, top_step, bottom_step), detach=detach)
+    return x
+
+
 class ReflectionPad2dNaive(nn.Module):
     def __init__(self, padding, detach=False):
         super().__init__()
@@ -99,7 +126,19 @@ def _test_grad():
     print(x.grad)
 
 
+def _test_loop():
+    import torchvision.io as IO
+    import torchvision.transforms.functional as TF
+
+    x = IO.read_image("cc0/dog2.jpg") / 255.0
+    x = x[:, :256, :256].unsqueeze(0)
+
+    x = reflection_pad2d_loop(x, (640, -10, 320, -10))
+    TF.to_pil_image(x[0]).show()
+
+
 if __name__ == "__main__":
-    _test()
-    _test_grad()
+    # _test()
+    # _test_grad()
     # _test_vis()
+    _test_loop()

From db1fbd29ad111300cfa84986acd22613e68ad020 Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Sat, 12 Oct 2024 09:22:09 +0900
Subject: [PATCH 02/10] waifu2x: unlimited: Use reflection padding instead of
 replication padding

---
 waifu2x/export_onnx.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/waifu2x/export_onnx.py b/waifu2x/export_onnx.py
index 722d7276..991fba62 100644
--- a/waifu2x/export_onnx.py
+++ b/waifu2x/export_onnx.py
@@ -7,7 +7,7 @@
 import argparse
 from nunif.models import load_model
 from nunif.models.onnx_helper_models import (
-    ONNXReplicationPadding,
+    ONNXReflectionPadding,
     ONNXTTASplit,
     ONNXTTAMerge,
     ONNXCreateSeamBlendingFilter,
@@ -113,7 +113,7 @@ def convert_utils(output_dir):
     utils_dir = path.join(output_dir, "utils")
     os.makedirs(utils_dir, exist_ok=True)
 
-    pad = ONNXReplicationPadding()
+    pad = ONNXReflectionPadding()
     pad.export_onnx(path.join(utils_dir, "pad.onnx"))
 
     tta_split = ONNXTTASplit()

From a02f622ca5ad37316c294ef8c10729380590a055 Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Sun, 13 Oct 2024 02:27:53 +0900
Subject: [PATCH 03/10] waifu2x: unlimited: Use reflection padding for photo
 model

---
 waifu2x/export_onnx.py                        |  6 +++++-
 .../unlimited_waifu2x/public_html/script.js   | 19 ++++++++++---------
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/waifu2x/export_onnx.py b/waifu2x/export_onnx.py
index 991fba62..44bd125a 100644
--- a/waifu2x/export_onnx.py
+++ b/waifu2x/export_onnx.py
@@ -8,6 +8,7 @@
 from nunif.models import load_model
 from nunif.models.onnx_helper_models import (
     ONNXReflectionPadding,
+    ONNXReplicationPadding,
     ONNXTTASplit,
     ONNXTTAMerge,
     ONNXCreateSeamBlendingFilter,
@@ -114,7 +115,10 @@ def convert_utils(output_dir):
     os.makedirs(utils_dir, exist_ok=True)
 
     pad = ONNXReflectionPadding()
-    pad.export_onnx(path.join(utils_dir, "pad.onnx"))
+    pad.export_onnx(path.join(utils_dir, "reflection_pad.onnx"))
+    pad = ONNXReplicationPadding()
+    pad.export_onnx(path.join(utils_dir, "replication_pad.onnx"))
+    pad.export_onnx(path.join(utils_dir, "pad.onnx"))  # for compatibility
 
     tta_split = ONNXTTASplit()
     tta_split.export_onnx(path.join(utils_dir, "tta_split.onnx"))
diff --git a/waifu2x/unlimited_waifu2x/public_html/script.js b/waifu2x/unlimited_waifu2x/public_html/script.js
index 8c0348f4..20c98bed 100644
--- 
a/waifu2x/unlimited_waifu2x/public_html/script.js +++ b/waifu2x/unlimited_waifu2x/public_html/script.js @@ -28,9 +28,9 @@ function gen_arch_config() /* swin_unet */ config["swin_unet"] = { - art: {color_stability: true}, - art_scan: {color_stability: false}, - photo: {color_stability: false}}; + art: {color_stability: true, padding: "replication"}, + art_scan: {color_stability: false, padding: "replication"}, + photo: {color_stability: false, padding: "reflection"}}; var swin = config["swin_unet"]; const calc_tile_size_swin_unet = function (tile_size, config) { while (true) { @@ -66,7 +66,8 @@ function gen_arch_config() }; var base_config = { arch: "cunet", domain: "art", calc_tile_size: calc_tile_size_cunet, - color_stability: true + color_stability: true, + padding: "replication", }; config["cunet"]["art"] = { scale2x: {...base_config, scale: 2, offset: 36}, @@ -459,9 +460,9 @@ const onnx_runner = { x = await this.alpha_border_padding(rgb, alpha1, BigInt(config.offset)); // _debug_print_image_data(this.to_image_data(x.data, null, x.dims[3], x.dims[2])); x = await this.padding(x, BigInt(p.pad[0]), BigInt(p.pad[1]), - BigInt(p.pad[2]), BigInt(p.pad[3])); + BigInt(p.pad[2]), BigInt(p.pad[3]), config.padding); alpha3 = await this.padding(alpha3, BigInt(p.pad[0]), BigInt(p.pad[1]), - BigInt(p.pad[2]), BigInt(p.pad[3])); + BigInt(p.pad[2]), BigInt(p.pad[3]), config.padding); alpha1 = null; } else { var alpha3 = {data: null}; @@ -470,7 +471,7 @@ const onnx_runner = { await seam_blending.build(); var p = seam_blending.get_rendering_config(); x = await this.padding(x, BigInt(p.pad[0]), BigInt(p.pad[1]), - BigInt(p.pad[2]), BigInt(p.pad[3])); + BigInt(p.pad[2]), BigInt(p.pad[3]), config.padding); } var ch, h, w; [ch, h, w] = [x.dims[1], x.dims[2], x.dims[3]]; @@ -556,8 +557,8 @@ const onnx_runner = { console.timeEnd("render"); this.running = false; }, - padding: async function(x, left, right, top, bottom) { - const ses = await onnx_session.get_session(CONFIG.get_helper_model_path("pad")); + padding: async function(x, left, right, top, bottom, mode) { + const ses = await onnx_session.get_session(CONFIG.get_helper_model_path(mode + "_pad")); left = new ort.Tensor('int64', BigInt64Array.from([left]), []); right = new ort.Tensor('int64', BigInt64Array.from([right]), []); top = new ort.Tensor('int64', BigInt64Array.from([top]), []); From dba3841c9e2eabf21f2411cc21e7348aefd46ad7 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Mon, 14 Oct 2024 10:18:17 +0900 Subject: [PATCH 04/10] waifu2x: Update get_last_layer in train --- waifu2x/training/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/waifu2x/training/trainer.py b/waifu2x/training/trainer.py index a510525f..790905fc 100644 --- a/waifu2x/training/trainer.py +++ b/waifu2x/training/trainer.py @@ -146,11 +146,12 @@ def get_last_layer(model): "waifu2x.swin_unet_8x", "waifu2x.winc_unet_1x", "waifu2x.winc_unet_2x", - "waifu2x.winc_unet_4x", "waifu2x.winc_unet_1x_small", "waifu2x.winc_unet_2x_small", }: return model.unet.to_image.proj.weight + elif model.name in {"waifu2x.winc_unet_4x"}: + return model.unet.to_image.conv.weight elif model.name in {"waifu2x.cunet", "waifu2x.upcunet"}: return model.unet2.conv_bottom.weight elif model.name in {"waifu2x.upconv_7", "waifu2x.vgg_7"}: From 87a248109ed77c8d75518551e3e3769cc1428c33 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Tue, 15 Oct 2024 03:01:47 +0900 Subject: [PATCH 05/10] waifu2x: Change SelfSupervisedDiscriminator argument --- waifu2x/training/trainer.py | 12 ++++-------- 1 file 
changed, 4 insertions(+), 8 deletions(-)

diff --git a/waifu2x/training/trainer.py b/waifu2x/training/trainer.py
index 790905fc..f1ab1ab4 100644
--- a/waifu2x/training/trainer.py
+++ b/waifu2x/training/trainer.py
@@ -279,12 +279,8 @@ def train_step(self, data):
             else:
                 z, y = self.diff_aug(z, y)
                 fake = z
-            if isinstance(self.discriminator, SelfSupervisedDiscriminator):
-                *z_real, _ = self.discriminator(torch.clamp(fake, 0, 1), y, scale_factor)
-                if len(z_real) == 1:
-                    z_real = z_real[0]
-            else:
-                z_real = self.discriminator(torch.clamp(fake, 0, 1), y, scale_factor)
+
+            z_real = self.discriminator(torch.clamp(fake, 0, 1), y, scale_factor)
             recon_loss = self.criterion(z, y)
             generator_loss = self.discriminator_criterion(z_real)
             self.sum_p_loss += recon_loss.item()
@@ -304,8 +300,8 @@ def train_step(self, data):
             self.discriminator.requires_grad_(True)
             if isinstance(self.discriminator, SelfSupervisedDiscriminator):
                 *z_fake, fake_ss_loss = self.discriminator(torch.clamp(fake.detach(), 0, 1),
-                                                           y, scale_factor)
-                *z_real, real_ss_loss = self.discriminator(y, y, scale_factor)
+                                                           y, scale_factor, train=True)
+                *z_real, real_ss_loss = self.discriminator(y, y, scale_factor, train=True)
                 if len(z_fake) == 1:
                     z_fake = z_fake[0]
                     z_real = z_real[0]

From 07a66cd673144702ccfde8531d525668bd2463b4 Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Wed, 16 Oct 2024 05:43:14 +0900
Subject: [PATCH 06/10] waifu2x: Add tile-mode training for non pixel-wise
 loss

---
 waifu2x/models/winc_unet.py | 37 ++++++++++++++++++++++++++++++++++++-
 waifu2x/training/trainer.py |  8 ++++++--
 2 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/waifu2x/models/winc_unet.py b/waifu2x/models/winc_unet.py
index 8fa589a9..46579468 100644
--- a/waifu2x/models/winc_unet.py
+++ b/waifu2x/models/winc_unet.py
@@ -181,6 +181,7 @@ def __init__(self, in_channels, out_channels, base_dim=96, scale_factor=2):
         super(WincUNetBase, self).__init__()
         assert scale_factor in {1, 2, 4}
+        self.scale_factor = scale_factor
         C = base_dim
         C2 = int(C * lv2_ratio)
         # assert C % 32 == 0 and C2 % 32 == 0  # slow when C % 32 != 0
@@ -212,7 +213,12 @@ def __init__(self, in_channels, out_channels, base_dim=96,
         basic_module_init(self.wac1_proj)
         basic_module_init(self.fusion)
 
-    def forward(self, x):
+        self.tile_mode = False
+
+    def set_tile_mode(self):
+        self.tile_mode = True
+
+    def _forward(self, x):
         ov = self.overscan(x)
         x = self.patch(x)
         x = F.leaky_relu(x, 0.2, inplace=True)
@@ -232,6 +238,32 @@ def forward(self, x):
 
         return z
 
+    def _forward_tile4x4(self, x):
+        # batch the four overlapping 64x64 corner tiles, run the network once,
+        # then stitch the per-tile outputs back into a single image
+        tl = x[:, :, :64, :64]
+        tr = x[:, :, :64, -64:]
+        bl = x[:, :, -64:, :64]
+        br = x[:, :, -64:, -64:]
+        x = torch.cat([tl, tr, bl, br], dim=0).contiguous()
+        x = self._forward(x)
+        tl, tr, bl, br = x.split(x.shape[0] // 4, dim=0)
+        top = torch.cat([tl, tr], dim=3)
+        bottom = torch.cat([bl, br], dim=3)
+        x = torch.cat([top, bottom], dim=2).contiguous()
+        return x
+
+    def forward(self, x):
+        if self.tile_mode:
+            B, C, H, W = x.shape
+            if self.scale_factor == 4:
+                assert H == 110 and W == H
+            else:
+                raise NotImplementedError()
+            return self._forward_tile4x4(x)
+        else:
+            return self._forward(x)
+
 
 def tile_size_validator(size):
     return (size > 16 and
@@ -308,6 +340,9 @@ def forward(self, x):
         else:
             return torch.clamp(z, 0., 1.)
 
+ def set_tile_mode(self): + self.unet.set_tile_mode() + def to_2x(self, shared=True): unet = self.unet if shared else copy.deepcopy(self.unet) return WincUNetDownscaled(unet, downscale_factor=2, diff --git a/waifu2x/training/trainer.py b/waifu2x/training/trainer.py index f1ab1ab4..86bec3ca 100644 --- a/waifu2x/training/trainer.py +++ b/waifu2x/training/trainer.py @@ -476,6 +476,8 @@ def setup_model(self): if self.args.freeze and hasattr(self.model, "freeze"): self.model.freeze() logger.debug("call model.freeze()") + if self.args.tile_mode: + self.model.set_tile_mode() def create_model(self): kwargs = {"in_channels": 3, "out_channels": 3} @@ -647,8 +649,8 @@ def train(args): "waifu2x.swin_unet_2x", "waifu2x.swin_unet_4x"} assert args.discriminator_stop_criteria < args.generator_start_criteria - if args.size % 4 != 0: - raise ValueError("--size must be a multiple of 4") + # if args.size % 4 != 0: + # raise ValueError("--size must be a multiple of 4") if args.arch in ARCH_SWIN_UNET and ((args.size - 16) % 12 != 0 or (args.size - 16) % 16 != 0): raise ValueError("--size must be `(SIZE - 16) % 12 == 0 and (SIZE - 16) % 16 == 0` for SwinUNet models") if args.method in {"noise", "noise_scale", "noise_scale4x"} and args.noise_level is None: @@ -772,6 +774,8 @@ def register(subparsers, default_parser): help="use only bicubic downsampling for bicubic downsampling restoration (classic super-resolution)") parser.add_argument("--freeze", action="store_true", help="call model.freeze() if avaliable") + parser.add_argument("--tile-mode", action="store_true", + help="call model.set_tile_mode()") parser.add_argument("--pre-antialias", action="store_true", help=("Set `pre_antialias=True` for SwinUNet4x.")) parser.add_argument("--privilege", action="store_true", From 06956a6fd97825bf08ab78701a9b67d3b4acb453 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Thu, 17 Oct 2024 09:12:39 +0900 Subject: [PATCH 07/10] iw3: Add --format option (image format); Fix --keyframe --- iw3/utils.py | 48 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/iw3/utils.py b/iw3/utils.py index 1f311985..c506d192 100644 --- a/iw3/utils.py +++ b/iw3/utils.py @@ -301,7 +301,7 @@ def to_deciaml(f, scale, zfill=0): else: metadata = "" - return basename + metadata + auto_detect_suffix + (args.video_extension if video else ".png") + return basename + metadata + auto_detect_suffix + (args.video_extension if video else get_image_ext(args.format)) def make_video_codec_option(args): @@ -333,8 +333,37 @@ def make_video_codec_option(args): return options -def save_image(im, output_filename): - im.save(output_filename) +def get_image_ext(format): + if format == "png": + return ".png" + elif format == "webp": + return ".webp" + elif format == "jpeg": + return ".jpg" + else: + raise NotImplementedError(format) + + +def save_image(im, output_filename, format="png"): + if format == "png": + options = { + "compress_level": 6 + } + elif format == "webp": + options = { + "quality": 95, + "method": 4, + "lossless": True + } + elif format == "jpeg": + options = { + "quality": 95, + "subsampling": "4:2:0", + } + else: + raise NotImplementedError(format) + + im.save(output_filename, format=format, **options) def remove_bg_from_image(im, bg_session): @@ -568,7 +597,7 @@ def process_images(files, output_dir, args, depth_model, side_model, title=None) continue im = TF.to_tensor(im).to(args.state["device"]) output = process_image(im, args, depth_model, side_model) - f = pool.submit(save_image, output, 
output_filename) + f = pool.submit(save_image, output, output_filename, format=args.format) # f.result() # for debug futures.append(f) pbar.update(1) @@ -747,11 +776,12 @@ def process_video_keyframes(input_filename, output_path, args, depth_model, side futures = [] def frame_callback(frame): - output = process_image(frame.to_image(), args, depth_model, side_model) + im = TF.to_tensor(frame.to_image()).to(args.state["device"]) + output = process_image(im, args, depth_model, side_model) output_filename = path.join( output_dir, - path.basename(output_dir) + "_" + str(frame.pts).zfill(8) + FULL_SBS_SUFFIX + ".png") - f = pool.submit(save_image, output, output_filename) + path.basename(output_dir) + "_" + str(frame.pts).zfill(8) + FULL_SBS_SUFFIX + get_image_ext(args.format)) + f = pool.submit(save_image, output, output_filename, format=args.format) futures.append(f) VU.process_video_keyframes(input_filename, frame_callback=frame_callback, min_interval_sec=args.keyframe_interval, @@ -1284,7 +1314,7 @@ def fix_rgb_depth_pair(files1, files2): output_filename = path.join( output_dir, make_output_filename(rgb_filename, args, video=False)) - f = pool.submit(save_image, sbs, output_filename) + f = pool.submit(save_image, sbs, output_filename, format=args.format) futures.append(f) pbar.update(1) if suspend_event is not None: @@ -1450,6 +1480,8 @@ def __repr__(self): help="max inference worker threads for video processing. 0 is disabled") parser.add_argument("--video-format", "-vf", type=str, default="mp4", choices=["mp4", "mkv", "avi"], help="video container format") + parser.add_argument("--format", "-f", type=str, default="png", choices=["png", "webp", "jpeg"], + help="output image format") parser.add_argument("--video-codec", "-vc", type=str, default=None, help="video codec") parser.add_argument("--metadata", type=str, nargs="?", default=None, const="filename", choices=["filename"], From 216a87ea4fbe3e582696c938ca5faffad9da1401 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Thu, 17 Oct 2024 09:36:05 +0900 Subject: [PATCH 08/10] iw3: Add Image Format option in GUI --- iw3/gui.py | 11 +++++++++++ iw3/locales/ja.yml | 1 + 2 files changed, 12 insertions(+) diff --git a/iw3/gui.py b/iw3/gui.py index 845c59d5..f16b2fe5 100644 --- a/iw3/gui.py +++ b/iw3/gui.py @@ -136,10 +136,20 @@ def initialize_component(self): name="chk_metadata") self.chk_metadata.SetValue(False) + self.sep_image_format = wx.StaticLine(self.pnl_file, size=(2, 16), style=wx.LI_VERTICAL) + self.lbl_image_format = wx.StaticText(self.pnl_file, label=" " + T("Image Format")) + self.cbo_image_format = wx.ComboBox(self.pnl_file, choices=["png", "jpeg", "webp"], + style=wx.CB_READONLY, name="cbo_image_format") + self.cbo_image_format.SetSelection(0) + self.cbo_image_format.SetToolTip(T("Output Image Format")) + sublayout = wx.BoxSizer(wx.HORIZONTAL) sublayout.Add(self.chk_resume, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL) sublayout.Add(self.chk_recursive, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL) sublayout.Add(self.chk_metadata, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL) + sublayout.Add(self.sep_image_format, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL) + sublayout.Add(self.lbl_image_format, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL) + sublayout.Add(self.cbo_image_format, flag=wx.ALIGN_LEFT | wx.ALIGN_CENTER_VERTICAL) layout = wx.GridBagSizer(vgap=4, hgap=4) layout.Add(self.lbl_input, (0, 0), flag=wx.ALIGN_RIGHT | wx.ALIGN_CENTER_VERTICAL) @@ -1121,6 +1131,7 @@ def parse_args(self): pix_fmt=self.cbo_pix_fmt.GetValue(), 
colorspace=self.cbo_colorspace.GetValue(),
             video_format=self.cbo_video_format.GetValue(),
+            format=self.cbo_image_format.GetValue(),
             video_codec=self.cbo_video_codec.GetValue(),
             crf=int(self.cbo_crf.GetValue()),
             profile_level=profile_level,
diff --git a/iw3/locales/ja.yml b/iw3/locales/ja.yml
index ed0fbcc7..36c3b684 100644
--- a/iw3/locales/ja.yml
+++ b/iw3/locales/ja.yml
@@ -11,6 +11,7 @@
 "Skip processing when the output file already exists": "出力ファイルが存在する場合に処理をスキップ"
 "Process all subfolders": "すべてのサブフォルダを処理"
 "Add metadata to filename": "ファイル名にメタデータを追加"
+"Image Format": "画像フォーマット"
 
 "Play": "再生"
 "Error": "エラー"

From ca38e20a81f3f67ad7671f28d963e48881e8425a Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Thu, 17 Oct 2024 10:36:41 +0900
Subject: [PATCH 09/10] Use truststore to fix SSL error in VMware and GCP

---
 iw3/__init__.py     | 2 ++
 requirements.txt    | 1 +
 waifu2x/__init__.py | 3 +++
 3 files changed, 6 insertions(+)

diff --git a/iw3/__init__.py b/iw3/__init__.py
index e62d23d2..61fff34d 100644
--- a/iw3/__init__.py
+++ b/iw3/__init__.py
@@ -1,2 +1,4 @@
 import os
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+import truststore
+truststore.inject_into_ssl()
diff --git a/requirements.txt b/requirements.txt
index 20e16962..b459b604 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,3 +28,4 @@ timm
 numba # only for iw3 sbs training
 av >= 12.2.0
 rembg # for --remove-bg
+truststore
diff --git a/waifu2x/__init__.py b/waifu2x/__init__.py
index 737b0ed5..f09ec866 100644
--- a/waifu2x/__init__.py
+++ b/waifu2x/__init__.py
@@ -1,5 +1,8 @@
 import os
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+import truststore
+truststore.inject_into_ssl()
+
 from .utils import Waifu2x
 from . import models

From b6eef3ef69b37cfa2855e05265925440e0b8aed3 Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Thu, 17 Oct 2024 10:39:53 +0900
Subject: [PATCH 10/10] Update README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 7a50a65f..5a841093 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ CLI tools are also available to filter out low quality images using these result
 
 #### Dependencies
 
-- Python 3 (Probably works with Python 3.9 or later, developed with 3.10)
+- Python 3 (Works with Python 3.10 or later, developed with 3.10)
 - [PyTorch](https://pytorch.org/get-started/locally/)
 - See requirements.txt
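
As a quick sanity check of the loop padding introduced in PATCH 01, the
following minimal sketch compares it against PyTorch's built-in reflect
padding (assuming the patched nunif package is importable; the tensor and
pad sizes are arbitrary):

    import torch
    import torch.nn.functional as F
    from nunif.modules.reflection_pad2d import reflection_pad2d_loop

    x = torch.rand(1, 3, 32, 32)

    # within F.pad's reflect limit (pad < size), both produce identical output
    pad = (8, 8, 8, 8)
    assert torch.equal(reflection_pad2d_loop(x, pad), F.pad(x, pad, mode="reflect"))

    # beyond that limit F.pad raises, but the loop version still works and
    # keeps the usual output-size rule (H + top + bottom, W + left + right)
    y = reflection_pad2d_loop(x, (80, 80, 48, 48))
    assert y.shape == (1, 3, 32 + 48 + 48, 32 + 80 + 80)

    # negative values crop, as in reflection_pad2d_naive
    y = reflection_pad2d_loop(x, (-4, -4, -4, -4))
    assert y.shape == (1, 3, 24, 24)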