From 37519c8a4c2debc26128b3bb1784d3160b7a26db Mon Sep 17 00:00:00 2001 From: Splendide Imaginarius <119545140+Splendide-Imaginarius@users.noreply.github.com> Date: Thu, 11 Apr 2024 02:39:08 +0000 Subject: [PATCH] Color Transfer: add Initial Reference Image parameter --- .../src/nodes/impl/color_transfer/mean_std.py | 5 +++- .../image_filter/correction/color_transfer.py | 29 ++++++++++++++----- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/backend/src/nodes/impl/color_transfer/mean_std.py b/backend/src/nodes/impl/color_transfer/mean_std.py index 89612dcbf..0717e3b01 100644 --- a/backend/src/nodes/impl/color_transfer/mean_std.py +++ b/backend/src/nodes/impl/color_transfer/mean_std.py @@ -89,6 +89,7 @@ def scale_array( def mean_std_transfer( img: np.ndarray, ref_img: np.ndarray, + init_img: np.ndarray, colorspace: TransferColorSpace, overflow_method: OverflowMethod, valid_indices: np.ndarray, @@ -118,12 +119,14 @@ def mean_std_transfer( c_clip_min, c_clip_max = (-127, 127) img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2LAB) + init_img = cv2.cvtColor(init_img, cv2.COLOR_BGR2LAB) elif colorspace == TransferColorSpace.RGB: a_clip_min, a_clip_max = (0, 1) b_clip_min, b_clip_max = (0, 1) c_clip_min, c_clip_max = (0, 1) img = img[:, :, :3] ref_img = ref_img[:, :, :3] + init_img = init_img[:, :, :3] else: raise ValueError(f"Invalid color space {colorspace}") @@ -135,7 +138,7 @@ def mean_std_transfer( b_std_tar, c_mean_tar, c_std_tar, - ) = image_stats(img[valid_indices]) + ) = image_stats(init_img[valid_indices]) ( a_mean_src, a_std_src, diff --git a/backend/src/packages/chaiNNer_standard/image_filter/correction/color_transfer.py b/backend/src/packages/chaiNNer_standard/image_filter/correction/color_transfer.py index 666b1cbf8..e09c123bf 100644 --- a/backend/src/packages/chaiNNer_standard/image_filter/correction/color_transfer.py +++ b/backend/src/packages/chaiNNer_standard/image_filter/correction/color_transfer.py @@ -42,7 +42,7 @@ class TransferColorAlgorithm(Enum): icon="MdInput", inputs=[ ImageInput("Image", channels=[3, 4]), - ImageInput("Reference Image", channels=[3, 4]), + ImageInput("Goal Reference Image", channels=[3, 4]), EnumInput( TransferColorAlgorithm, label="Algorithm", @@ -50,6 +50,9 @@ class TransferColorAlgorithm(Enum): default=TransferColorAlgorithm.MEAN_STD, ).with_id(5), if_enum_group(5, TransferColorAlgorithm.MEAN_STD)( + ImageInput("Initial Reference Image", channels=[3, 4]) + .make_optional() + .with_id(6), EnumInput( TransferColorSpace, label="Colorspace", @@ -65,10 +68,14 @@ def color_transfer_node( img: np.ndarray, ref_img: np.ndarray, algorithm: TransferColorAlgorithm, + init_img: np.ndarray | None, colorspace: TransferColorSpace, overflow_method: OverflowMethod, reciprocal_scale: bool, ) -> np.ndarray: + if init_img is None: + init_img = img + _, _, img_c = get_h_w_c(img) # Preserve alpha @@ -77,6 +84,13 @@ def color_transfer_node( alpha = img[:, :, 3] bgr_img = img[:, :, :3] + _, _, init_img_c = get_h_w_c(init_img) + + init_alpha = None + if init_img_c == 4: + init_alpha = init_img[:, :, 3] + bgr_init_img = init_img[:, :, :3] + _, _, ref_img_c = get_h_w_c(ref_img) ref_alpha = None @@ -86,9 +100,9 @@ def color_transfer_node( # Don't process RGB data if the pixel is fully transparent, since # such RGB data is indeterminate. - valid_rgb_indices = np.ones(img.shape[:-1], dtype=bool) - if alpha is not None: - valid_rgb_indices = alpha > 0 + init_valid_rgb_indices = np.ones(init_img.shape[:-1], dtype=bool) + if init_alpha is not None: + init_valid_rgb_indices = init_alpha > 0 ref_valid_rgb_indices = np.ones(ref_img.shape[:-1], dtype=bool) if ref_alpha is not None: @@ -99,19 +113,20 @@ def color_transfer_node( transfer = mean_std_transfer( bgr_img, bgr_ref_img, + bgr_init_img, colorspace, overflow_method, reciprocal_scale=reciprocal_scale, - valid_indices=valid_rgb_indices, + valid_indices=init_valid_rgb_indices, ref_valid_indices=ref_valid_rgb_indices, ) elif algorithm == TransferColorAlgorithm.LINEAR_HISTOGRAM: transfer = linear_histogram_transfer( - bgr_img, bgr_ref_img, valid_rgb_indices, ref_valid_rgb_indices + bgr_img, bgr_ref_img, init_valid_rgb_indices, ref_valid_rgb_indices ) elif algorithm == TransferColorAlgorithm.PRINCIPAL_COLOR: transfer = principal_color_transfer( - bgr_img, bgr_ref_img, valid_rgb_indices, ref_valid_rgb_indices + bgr_img, bgr_ref_img, init_valid_rgb_indices, ref_valid_rgb_indices ) if alpha is not None: