Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: self supervised labeled cls dataset #597

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
95 changes: 95 additions & 0 deletions data/self_supervised_labeled_cls_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os.path
import warnings

import numpy as np
import torch

# import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import (
make_labeled_dataset,
make_labeled_path_dataset,
)
from data.utils import load_image


class SelfSupervisedLabeledClsDataset(BaseDataset):
"""
This dataset class can load unaligned/unpaired datasets.

It requires two directories to host training images from domain A '/path/to/data/trainA'
and from domain B '/path/to/data/trainB' respectively.

Domain A must have labels, at the moment the subdir of domain A acts as the label string (turned into an int)

You can train the model with the dataset flag '--dataroot /path/to/data'.
Similarly, you need to prepare two directories:
'/path/to/data/testA' and '/path/to/data/testB' during test time.
"""

def __init__(self, opt, phase):
"""Initialize this dataset class.

Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt, phase)

if not os.path.isfile(self.dir_A + "/paths.txt"):
self.A_img_paths, self.A_label = make_labeled_dataset(
self.dir_A, opt.data_max_dataset_size
) # load images from '/path/to/data/trainA' as well as labels
self.A_label = np.array(self.A_label)

else:
self.A_img_paths, self.A_label = make_labeled_path_dataset(
self.dir_A, "/paths.txt", opt.data_max_dataset_size
) # load images from '/path/to/data/trainA/paths.txt' as well as labels
self.A_label = np.array(self.A_label, dtype=np.float32)

self.A_size = len(self.A_img_paths) # get the size of dataset A

self.transform_A = get_transform(self.opt, grayscale=(self.input_nc == 1))

self.semantic_nclasses = self.opt.cls_semantic_nclasses

def get_img(
self,
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path,
B_label_mask_path,
B_label_cls,
index,
):
A_img = load_image(A_img_path)
# apply image transformation
A = self.transform_A(A_img)
# get labels
A_label = self.A_label[index % self.A_size]
A_label_mask = torch.ones_like(A, dtype=torch.long)
if A_label > self.semantic_nclasses - 1:
warnings.warn(
"A label is above number of semantic classes for img %s" % (A_img_path)
)
A_label = self.semantic_nclasses - 1

return {
"A": A,
"A_img_paths": A_img_path,
"A_label_cls": A_label,
"A_label_mask": A_label_mask,
"B": A,
"B_img_paths": A_img_path,
"B_label_cls": A_label,
"B_label_mask": A_label_mask,
}

def __len__(self):
"""Return the total number of images in the dataset.

As we have two datasets with potentially different number of images,
we take a maximum of
"""
return self.A_size
2 changes: 1 addition & 1 deletion docs/options.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Here are all the available options to call with `train.py`
| --name | string | experiment_name | name of the experiment. It decides where to store samples and models |
| --phase | string | train | train, val, test, etc |
| --suffix | string | | customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size} |
| --test_batch_size | int | 1 | input batch size |
| --test_batch_size | int | -1 | input batch size, default to train batch size |
| --warning_mode | flag | | whether to display warning |
| --with_amp | flag | | whether to activate torch amp on forward passes |
| --with_tf32 | flag | | whether to activate tf32 for faster computations (Ampere GPU and beyond only) |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Here are all the available options to call with ``train.py``
+----------------------+-----------------+-----------------+------------------------------------------------------------------------------------------+
| --suffix | string | | customized suffix: opt.name = opt.name + suffix: e.g., {model}\_{netG}_size{load_size} |
+----------------------+-----------------+-----------------+------------------------------------------------------------------------------------------+
| --test_batch_size | int | 1 | input batch size |
| --test_batch_size | int | -1 | input batch size |
+----------------------+-----------------+-----------------+------------------------------------------------------------------------------------------+
| --warning_mode | flag | | whether to display warning |
+----------------------+-----------------+-----------------+------------------------------------------------------------------------------------------+
Expand Down
2 changes: 1 addition & 1 deletion models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def get_current_visuals(self, phase="train"):
for name in group:
if phase == "test":
name = name + "_test"
if isinstance(name, str):
if isinstance(name, str) and hasattr(self, name):
cur_visual[name] = getattr(self, name)

visual_ret.append(cur_visual)
Expand Down
11 changes: 6 additions & 5 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ def __init__(self, opt, rank):
and self.opt.alg_diffusion_generate_per_class
and not self.use_ref
):
self.nb_classes_inference = (
max(self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses) - 1
# Take into account the default "background class" of the semantic classes.
self.nb_classes_inference = max(
self.opt.f_s_semantic_nclasses - 1, self.opt.cls_semantic_nclasses
)

for i in range(self.nb_classes_inference):
Expand Down Expand Up @@ -475,13 +476,13 @@ def inference(self):
):
for i in range(self.nb_classes_inference):
if "class" in self.opt.alg_diffusion_cond_embed:
cur_class = torch.ones_like(self.cls)[: self.inference_num] * (
i + 1
)
cur_class = torch.ones_like(self.cls)[: self.inference_num] * i
else:
cur_class = None

if "mask" in self.opt.alg_diffusion_cond_embed:
# Take into account the default "background class", add the
# offset to get the real class id.
cur_mask = self.mask[: self.inference_num].clone().clamp(
min=0, max=1
) * (i + 1)
Expand Down
27 changes: 14 additions & 13 deletions options/common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,25 +562,26 @@ def initialize(self, parser):
type=str,
default="unaligned",
choices=[
"unaligned",
"unaligned_labeled_cls",
"unaligned_labeled_mask",
"aligned",
"nuplet_unaligned_labeled_mask",
"self_supervised_labeled_cls",
"self_supervised_labeled_mask",
"unaligned_labeled_mask_cls",
"self_supervised_labeled_mask_cls",
"unaligned_labeled_mask_online",
"self_supervised_labeled_mask_online",
"unaligned_labeled_mask_cls_online",
"self_supervised_labeled_mask_cls_online",
"aligned",
"nuplet_unaligned_labeled_mask",
"temporal_labeled_mask_online",
"self_supervised_labeled_mask_online",
"self_supervised_labeled_mask_online_ref",
"self_supervised_labeled_mask_ref",
"self_supervised_temporal",
"single",
"unaligned_labeled_mask_ref",
"self_supervised_labeled_mask_ref",
"temporal_labeled_mask_online",
"unaligned",
"unaligned_labeled_cls",
"unaligned_labeled_mask",
"unaligned_labeled_mask_cls",
"unaligned_labeled_mask_cls_online",
"unaligned_labeled_mask_online",
"unaligned_labeled_mask_online_ref",
"self_supervised_labeled_mask_online_ref",
"unaligned_labeled_mask_ref",
],
help="chooses how datasets are loaded.",
)
Expand Down
3 changes: 3 additions & 0 deletions options/inference_diffusion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def initialize(self, parser):
parser.add_argument("--cond_rotation", type=float, default=0)
parser.add_argument("--cond_persp_horizontal", type=float, default=0)
parser.add_argument("--cond_persp_vertical", type=float, default=0)
parser.add_argument(
"--canny_in", type=str, help="canny image to use for conditionning"
)

parser.add_argument(
"--min_crop_bbox_ratio",
Expand Down
8 changes: 7 additions & 1 deletion options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def initialize(self, parser):
)

parser.add_argument(
"--test_batch_size", type=int, default=1, help="input batch size"
"--test_batch_size",
type=int,
default=-1,
help="input batch size, defaults to train batch size",
)

parser.add_argument(
Expand Down Expand Up @@ -653,6 +656,9 @@ def gather_specific_options(self, opt, parser, args):
def _after_parse(self, opt, set_device=True):
opt = super()._after_parse(opt=opt, set_device=set_device)

if opt.test_batch_size == -1:
opt.test_batch_size = opt.train_batch_size

# process opt.suffix
if opt.suffix:
suffix = ("_" + opt.suffix.format(**vars(opt))) if opt.suffix != "" else ""
Expand Down
55 changes: 44 additions & 11 deletions scripts/gen_single_image_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def generate(
mask_in,
ref_in,
bbox_in,
canny_in,
cond_in,
cond_keep_ratio,
bbox_width_factor,
Expand Down Expand Up @@ -230,6 +231,8 @@ def generate(
mask_delta[i].append(delta_values[0])

# Load image
if "class" in conditioning:
cls = cls_value

# reading image
img = cv2.imread(img_in)
Expand Down Expand Up @@ -524,23 +527,47 @@ def generate(
elif opt.alg_diffusion_cond_image_creation == "sketch":
cond_image = fill_img_with_sketch(img_tensor.unsqueeze(0), mask.unsqueeze(0))
elif opt.alg_diffusion_cond_image_creation == "canny":
mask_is_none = mask is None
if mask_is_none:
# `fill_img_with_canny` needs a mask.
mask = torch.ones_like(img_tensor, device=img_tensor.device)

clamp = torch.clamp(mask, 0, 1)
if cond_in:
# mask the background to avoid canny edges around cond image
img_tensor_canny = clamp * img_tensor + clamp - 1
else:
img_tensor_canny = img_tensor
cond_image = fill_img_with_canny(
img_tensor_canny.unsqueeze(0),
mask.unsqueeze(0),
low_threshold=alg_diffusion_sketch_canny_thresholds[0],
high_threshold=alg_diffusion_sketch_canny_thresholds[1],
low_threshold_random=-1,
high_threshold_random=-1,
)

if canny_in:
cond_image = cv2.imread(canny_in, cv2.IMREAD_GRAYSCALE)
cond_image = cv2.resize(
cond_image, (img_tensor.shape[-1], img_tensor.shape[-2])
)
cond_image = torch.Tensor(cond_image).to(img_tensor.device)
cond_image = (cond_image / 255) * 2 - 1
# Add the RGB channels.
cond_image = torch.stack([cond_image] * 3)
cond_image = cond_image.unsqueeze(0)
else:
# If no canny image is provided, the canny is generated from
# the conditioning image.
cond_image = fill_img_with_canny(
img_tensor_canny.unsqueeze(0),
mask.unsqueeze(0),
low_threshold=alg_diffusion_sketch_canny_thresholds[0],
high_threshold=alg_diffusion_sketch_canny_thresholds[1],
low_threshold_random=-1,
high_threshold_random=-1,
)

if cond_in:
# restore background
cond_image = cond_image * clamp + img_tensor * (1 - clamp)

if mask_is_none:
# Revert to "None".
mask = None
elif opt.alg_diffusion_cond_image_creation == "sam":
opt.f_s_weight_sam = "../" + opt.f_s_weight_sam
if not os.path.exists(opt.f_s_weight_sam):
Expand Down Expand Up @@ -637,7 +664,9 @@ def generate(
out_img_real_size = img_orig.copy()
else:
out_img_resized = out_img
out_img_real_size = img_orig.copy()
out_img_real_size = cv2.resize(
out_img, (img_orig.shape[1], img_orig.shape[0]), cv2.INTER_CUBIC
)

# fill out crop into original image
if bbox_in:
Expand All @@ -647,13 +676,17 @@ def generate(

if cond_image is not None:
cond_img = to_np(cond_image)
cond_img = cv2.resize(
cond_img, (img_orig.shape[1], img_orig.shape[0]), cv2.INTER_CUBIC
)

if write:
cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_orig)
if cond_image is not None:
cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img)
cv2.imwrite(os.path.join(dir_out, name + "_generated.png"), out_img_real_size)
cv2.imwrite(os.path.join(dir_out, name + "_generated_crop.png"), out_img)
cv2.imwrite(os.path.join(dir_out, name + "_y_t.png"), to_np(y_t))
if cond_image is not None:
cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img)
if mask is not None:
cv2.imwrite(os.path.join(dir_out, name + "_y_0.png"), to_np(img_tensor))
cv2.imwrite(os.path.join(dir_out, name + "_generated_crop.png"), out_img)
Expand Down
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,10 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):

if opt.output_display_id > 0:
metrics = model.get_current_metrics()
visualizer.plot_current_metrics(
epoch, float(epoch_iter) / trainset_size, metrics
)
if len(metrics) != 0:
visualizer.plot_current_metrics(
epoch, float(epoch_iter) / trainset_size, metrics
)

if (
total_iters % opt.train_D_accuracy_every < batch_size
Expand Down
Loading