diff --git a/invokeai/app/invocations/pbr_maps.py b/invokeai/app/invocations/pbr_maps.py new file mode 100644 index 00000000000..7475c96e4c5 --- /dev/null +++ b/invokeai/app/invocations/pbr_maps.py @@ -0,0 +1,57 @@ +import pathlib +from typing import Literal + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net +from invokeai.backend.image_util.pbr_maps.pbr_maps import NORMAL_MAP_MODEL, OTHER_MAP_MODEL, PBRMapsGenerator +from invokeai.backend.util.devices import TorchDevice + + +@invocation_output("pbr_maps-output") +class PBRMapsOutput(BaseInvocationOutput): + normal_map: ImageField = OutputField(default=None, description="The generated normal map") + roughness_map: ImageField = OutputField(default=None, description="The generated roughness map") + displacement_map: ImageField = OutputField(default=None, description="The generated displacement map") + + +@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0") +class PBRMapsInvocation(BaseInvocation): + """Generate Normal, Displacement and Roughness Map from a given image""" + + image: ImageField = InputField(default=None, description="Input image") + tile_size: int = InputField(default=512, description="Tile size") + border_mode: Literal["none", "seamless", "mirror", "replicate"] = InputField( + default="none", description="Border mode to apply to eliminate any artifacts or seams" + ) + + def invoke(self, context: InvocationContext) -> PBRMapsOutput: + image_pil = context.images.get_pil(self.image.image_name, mode="RGB") + + def loader(model_path: pathlib.Path): + return PBRMapsGenerator.load_model(model_path, TorchDevice.choose_torch_device()) + + with ( + context.models.load_remote_model(NORMAL_MAP_MODEL, loader) as normal_map_model, + context.models.load_remote_model(OTHER_MAP_MODEL, loader) as other_map_model, + ): + assert isinstance(normal_map_model, PBR_RRDB_Net) + assert isinstance(other_map_model, PBR_RRDB_Net) + pbr_pipeline = PBRMapsGenerator(normal_map_model, other_map_model, TorchDevice.choose_torch_device()) + normal_map, roughness_map, displacement_map = pbr_pipeline.generate_maps( + image_pil, self.tile_size, self.border_mode + ) + + normal_map = context.images.save(normal_map) + normal_map_field = ImageField(image_name=normal_map.image_name) + + roughness_map = context.images.save(roughness_map) + roughness_map_field = ImageField(image_name=roughness_map.image_name) + + displacement_map = context.images.save(displacement_map) + displacement_map_map_field = ImageField(image_name=displacement_map.image_name) + + return PBRMapsOutput( + normal_map=normal_map_field, roughness_map=roughness_map_field, displacement_map=displacement_map_map_field + ) diff --git a/invokeai/backend/image_util/pbr_maps/architecture/block.py b/invokeai/backend/image_util/pbr_maps/architecture/block.py new file mode 100644 index 00000000000..6c066c7a310 --- /dev/null +++ b/invokeai/backend/image_util/pbr_maps/architecture/block.py @@ -0,0 +1,367 @@ +# Original: https://github.com/joeyballentine/Material-Map-Generator +# Adopted and optimized for Invoke AI + +from collections import OrderedDict +from typing import Any, List, Literal, Optional + +import torch +import torch.nn as nn + +ACTIVATION_LAYER_TYPE = Literal["relu", "leakyrelu", "prelu"] +NORMALIZATION_LAYER_TYPE = Literal["batch", "instance"] +PADDING_LAYER_TYPE = Literal["zero", "reflect", "replicate"] +BLOCK_MODE = Literal["CNA", "NAC", "CNAC"] +UPCONV_BLOCK_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear"] + + +def act(act_type: ACTIVATION_LAYER_TYPE, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1): + """Helper to select Activation Layer""" + if act_type == "relu": + layer = nn.ReLU(inplace) + elif act_type == "leakyrelu": + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == "prelu": + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + return layer + + +def norm(norm_type: NORMALIZATION_LAYER_TYPE, nc: int): + """Helper to select Normalization Layer""" + if norm_type == "batch": + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == "instance": + layer = nn.InstanceNorm2d(nc, affine=False) + return layer + + +def pad(pad_type: PADDING_LAYER_TYPE, padding: int): + """Helper to select Padding Layer""" + if padding == 0 or pad_type == "zero": + return None + if pad_type == "reflect": + layer = nn.ReflectionPad2d(padding) + elif pad_type == "replicate": + layer = nn.ReplicationPad2d(padding) + return layer + + +def get_valid_padding(kernel_size: int, dilation: int): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +def sequential(*args: Any): + # Flatten Sequential. It unwraps nn.Sequential. + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError("sequential does not support OrderedDict input.") + return args[0] # No sequential is needed. + modules: List[nn.Module] = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +def conv_block( + in_nc: int, + out_nc: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + pad_type: Optional[PADDING_LAYER_TYPE] = "zero", + norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None, + act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu", + mode: BLOCK_MODE = "CNA", +): + """ + Conv layer with padding, normalization, activation + mode: CNA --> Conv -> Norm -> Act + NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) + """ + assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]" + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type else None + padding = padding if pad_type == "zero" else 0 + + c = nn.Conv2d( + in_nc, + out_nc, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=groups, + ) + a = act(act_type) if act_type else None + match mode: + case "CNA": + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + case "NAC": + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) + case "CNAC": + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) + + +class ConcatBlock(nn.Module): + # Concat the output of a submodule to its input + def __init__(self, submodule: nn.Module): + super(ConcatBlock, self).__init__() + self.sub = submodule + + def forward(self, x: torch.Tensor): + output = torch.cat((x, self.sub(x)), dim=1) + return output + + def __repr__(self): + tmpstr = "Identity .. \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlock(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule: nn.Module): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x: torch.Tensor): + output = x + self.sub(x) + return output + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlockSPSR(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule: nn.Module): + super(ShortcutBlockSPSR, self).__init__() + self.sub = submodule + + def forward(self, x: torch.Tensor): + return x, self.sub + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ResNetBlock(nn.Module): + """ + ResNet Block, 3-3 style + with extra residual scaling used in EDSR + (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) + """ + + def __init__( + self, + in_nc: int, + mid_nc: int, + out_nc: int, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + pad_type: PADDING_LAYER_TYPE = "zero", + norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None, + act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu", + mode: BLOCK_MODE = "CNA", + res_scale: int = 1, + ): + super(ResNetBlock, self).__init__() + conv0 = conv_block( + in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode + ) + if mode == "CNA": + act_type = None + if mode == "CNAC": # Residual path: |-CNAC-| + act_type = None + norm_type = None + conv1 = conv_block( + mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode + ) + + self.res = sequential(conv0, conv1) + self.res_scale = res_scale + + def forward(self, x: torch.Tensor): + res = self.res(x).mul(self.res_scale) + return x + res + + +class ResidualDenseBlock_5C(nn.Module): + """ + Residual Dense Block + style: 5 convs + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + """ + + def __init__( + self, + nc: int, + kernel_size: int = 3, + gc: int = 32, + stride: int = 1, + bias: bool = True, + pad_type: PADDING_LAYER_TYPE = "zero", + norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None, + act_type: ACTIVATION_LAYER_TYPE = "leakyrelu", + mode: BLOCK_MODE = "CNA", + ): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = conv_block( + nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode + ) + self.conv2 = conv_block( + nc + gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + ) + self.conv3 = conv_block( + nc + 2 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + ) + self.conv4 = conv_block( + nc + 3 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + ) + if mode == "CNA": + last_act = None + else: + last_act = act_type + self.conv5 = conv_block( + nc + 4 * gc, nc, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode + ) + + def forward(self, x: torch.Tensor): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5.mul(0.2) + x + + +class RRDB(nn.Module): + """ + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + """ + + def __init__( + self, + nc: int, + kernel_size: int = 3, + gc: int = 32, + stride: int = 1, + bias: bool = True, + pad_type: PADDING_LAYER_TYPE = "zero", + norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None, + act_type: ACTIVATION_LAYER_TYPE = "leakyrelu", + mode: BLOCK_MODE = "CNA", + ): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode) + self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode) + self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode) + + def forward(self, x: torch.Tensor): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out.mul(0.2) + x + + +# Upsampler +def pixelshuffle_block( + in_nc: int, + out_nc: int, + upscale_factor: int = 2, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + pad_type: PADDING_LAYER_TYPE = "zero", + norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None, + act_type: ACTIVATION_LAYER_TYPE = "relu", +): + """ + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + """ + conv = conv_block( + in_nc, + out_nc * (upscale_factor**2), + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=None, + act_type=None, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_blcok( + in_nc: int, + out_nc: int, + upscale_factor: int = 2, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + pad_type: PADDING_LAYER_TYPE = "zero", + norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None, + act_type: ACTIVATION_LAYER_TYPE = "relu", + mode: UPCONV_BLOCK_MODE = "nearest", +): + # Adopted from https://distill.pub/2016/deconv-checkerboard/ + upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block( + in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type + ) + return sequential(upsample, conv) diff --git a/invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py b/invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py new file mode 100644 index 00000000000..2d8e443a25e --- /dev/null +++ b/invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py @@ -0,0 +1,70 @@ +# Original: https://github.com/joeyballentine/Material-Map-Generator +# Adopted and optimized for Invoke AI + +import math +from typing import Literal, Optional + +import torch +import torch.nn as nn + +import invokeai.backend.image_util.pbr_maps.architecture.block as B + +UPSCALE_MODE = Literal["upconv", "pixelshuffle"] + + +class PBR_RRDB_Net(nn.Module): + def __init__( + self, + in_nc: int, + out_nc: int, + nf: int, + nb: int, + gc: int = 32, + upscale: int = 4, + norm_type: Optional[B.NORMALIZATION_LAYER_TYPE] = None, + act_type: B.ACTIVATION_LAYER_TYPE = "leakyrelu", + mode: B.BLOCK_MODE = "CNA", + res_scale: int = 1, + upsample_mode: UPSCALE_MODE = "upconv", + ): + super(PBR_RRDB_Net, self).__init__() + n_upscale = int(math.log(upscale, 2)) + if upscale == 3: + n_upscale = 1 + + fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) + rb_blocks = [ + B.RRDB( + nf, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=norm_type, + act_type=act_type, + mode="CNA", + ) + for _ in range(nb) + ] + LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) + + if upsample_mode == "upconv": + upsample_block = B.upconv_blcok + elif upsample_mode == "pixelshuffle": + upsample_block = B.pixelshuffle_block + + if upscale == 3: + upsampler = upsample_block(nf, nf, 3, act_type=act_type) + else: + upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] + + HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) + HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) + + self.model = B.sequential( + fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), *upsampler, HR_conv0, HR_conv1 + ) + + def forward(self, x: torch.Tensor): + return self.model(x) diff --git a/invokeai/backend/image_util/pbr_maps/pbr_maps.py b/invokeai/backend/image_util/pbr_maps/pbr_maps.py new file mode 100644 index 00000000000..fb1b09c8a58 --- /dev/null +++ b/invokeai/backend/image_util/pbr_maps/pbr_maps.py @@ -0,0 +1,104 @@ +# Original: https://github.com/joeyballentine/Material-Map-Generator +# Adopted and optimized for Invoke AI + +import pathlib +from typing import Any, Literal + +import cv2 +import numpy as np +import numpy.typing as npt +import torch +from PIL import Image + +from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net +from invokeai.backend.image_util.pbr_maps.utils.image_ops import crop_seamless, esrgan_launcher_split_merge + +NORMAL_MAP_MODEL = "https://github.com/joeyballentine/Material-Map-Generator/blob/master/utils/models/1x_NormalMapGenerator-CX-Lite_200000_G.pth" +OTHER_MAP_MODEL = "https://github.com/joeyballentine/Material-Map-Generator/blob/master/utils/models/1x_FrankenMapGenerator-CX-Lite_215000_G.pth" + + +class PBRMapsGenerator: + def __init__(self, normal_map_model: PBR_RRDB_Net, other_map_model: PBR_RRDB_Net, device: torch.device) -> None: + self.normal_map_model = normal_map_model + self.other_map_model = other_map_model + self.device = device + + @staticmethod + def load_model(model_path: pathlib.Path, device: torch.device) -> PBR_RRDB_Net: + state_dict = torch.load(model_path.as_posix(), map_location="cpu") + + model = PBR_RRDB_Net( + 3, + 3, + 32, + 12, + gc=32, + upscale=1, + norm_type=None, + act_type="leakyrelu", + mode="CNA", + res_scale=1, + upsample_mode="upconv", + ) + + model.load_state_dict(state_dict, strict=False) + del state_dict + model.eval() + + for _, v in model.named_parameters(): + v.requires_grad = False + + return model.to(device) + + def process(self, img: npt.NDArray[Any], model: PBR_RRDB_Net): + img = img.astype(np.float32) / np.iinfo(img.dtype).max + img = img[..., ::-1].copy() + tensor_img = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(self.device) + + with torch.no_grad(): + output = model(tensor_img).data.squeeze(0).float().cpu().clamp_(0, 1).numpy() + output = output[[2, 1, 0], :, :] + output = np.transpose(output, (1, 2, 0)) + output = (output * 255.0).round() + return output + + def _cv2_to_pil(self, image: npt.NDArray[Any]): + return Image.fromarray(cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2BGR)) + + def generate_maps( + self, + image: Image.Image, + tile_size: int = 512, + border_mode: Literal["none", "seamless", "mirror", "replicate"] = "none", + ): + models = [self.normal_map_model, self.other_map_model] + np_image = np.array(image).astype(np.uint8) + + match border_mode: + case "seamless": + np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_WRAP) + case "mirror": + np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REFLECT_101) + case "replicate": + np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REPLICATE) + case "none": + pass + + img_height, img_width = np_image.shape[:2] + + # Checking whether to perform tiled inference + do_split = img_height > tile_size or img_width > tile_size + + if do_split: + rlts = esrgan_launcher_split_merge(np_image, self.process, models, scale_factor=1, tile_size=tile_size) + else: + rlts = [self.process(np_image, model) for model in models] + + if border_mode != "none": + rlts = [crop_seamless(rlt) for rlt in rlts] + + normal_map = self._cv2_to_pil(rlts[0]) + roughness = self._cv2_to_pil(rlts[1][:, :, 1]) + displacement = self._cv2_to_pil(rlts[1][:, :, 0]) + + return normal_map, roughness, displacement diff --git a/invokeai/backend/image_util/pbr_maps/utils/image_ops.py b/invokeai/backend/image_util/pbr_maps/utils/image_ops.py new file mode 100644 index 00000000000..426620797cb --- /dev/null +++ b/invokeai/backend/image_util/pbr_maps/utils/image_ops.py @@ -0,0 +1,93 @@ +# Original: https://github.com/joeyballentine/Material-Map-Generator +# Adopted and optimized for Invoke AI + +import math +from typing import Any, Callable, List + +import numpy as np +import numpy.typing as npt + +from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net + + +def crop_seamless(img: npt.NDArray[Any]): + img_height, img_width = img.shape[:2] + y, x = 16, 16 + h, w = img_height - 32, img_width - 32 + img = img[y : y + h, x : x + w] + return img + + +# from https://github.com/ata4/esrgan-launcher/blob/master/upscale.py +def esrgan_launcher_split_merge( + input_image: npt.NDArray[Any], + upscale_function: Callable[[npt.NDArray[Any], PBR_RRDB_Net], npt.NDArray[Any]], + models: List[PBR_RRDB_Net], + scale_factor: int = 4, + tile_size: int = 512, + tile_padding: float = 0.125, +): + width, height, depth = input_image.shape + output_width = width * scale_factor + output_height = height * scale_factor + output_shape = (output_width, output_height, depth) + + # start with black image + output_images = [np.zeros(output_shape, np.uint8) for _ in range(len(models))] + + tile_padding = math.ceil(tile_size * tile_padding) + tile_size = math.ceil(tile_size / scale_factor) + + tiles_x = math.ceil(width / tile_size) + tiles_y = math.ceil(height / tile_size) + + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * tile_size + ofs_y = y * tile_size + + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + tile_size, width) + + input_start_y = ofs_y + input_end_y = min(ofs_y + tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - tile_padding, 0) + input_end_x_pad = min(input_end_x + tile_padding, width) + + input_start_y_pad = max(input_start_y - tile_padding, 0) + input_end_y_pad = min(input_end_y + tile_padding, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + + input_tile = input_image[input_start_x_pad:input_end_x_pad, input_start_y_pad:input_end_y_pad] + + for idx, model in enumerate(models): + # upscale tile + output_tile = upscale_function(input_tile, model) + + # output tile area on total image + output_start_x = input_start_x * scale_factor + output_end_x = input_end_x * scale_factor + + output_start_y = input_start_y * scale_factor + output_end_y = input_end_y * scale_factor + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * scale_factor + output_end_x_tile = output_start_x_tile + input_tile_width * scale_factor + + output_start_y_tile = (input_start_y - input_start_y_pad) * scale_factor + output_end_y_tile = output_start_y_tile + input_tile_height * scale_factor + + # put tile into output image + output_images[idx][output_start_x:output_end_x, output_start_y:output_end_y] = output_tile[ + output_start_x_tile:output_end_x_tile, output_start_y_tile:output_end_y_tile + ] + + return output_images