From 8d4758be8606b2a50fb5f8c4eb3f64705a249534 Mon Sep 17 00:00:00 2001 From: Yunfan Li <54800821+Yunfan-Li@users.noreply.github.com> Date: Mon, 11 Jul 2022 11:57:23 +0800 Subject: [PATCH] Add files via upload --- boost.py | 333 +++++++++++++++++++++++++++++++++++++++++++ data.py | 195 ++++++++++++++++++++++++++ engine.py | 204 +++++++++++++++++++++++++++ evaluate.py | 93 ++++++++++++ loss.py | 368 ++++++++++++++++++++++++++++++++++++++++++++++++ misc.py | 381 ++++++++++++++++++++++++++++++++++++++++++++++++++ model.py | 211 ++++++++++++++++++++++++++++ readme.md | 68 +++++++++ train.py | 282 +++++++++++++++++++++++++++++++++++++ transforms.py | 271 +++++++++++++++++++++++++++++++++++ 10 files changed, 2406 insertions(+) create mode 100644 boost.py create mode 100644 data.py create mode 100644 engine.py create mode 100644 evaluate.py create mode 100644 loss.py create mode 100644 misc.py create mode 100644 model.py create mode 100644 readme.md create mode 100644 train.py create mode 100644 transforms.py diff --git a/boost.py b/boost.py new file mode 100644 index 0000000..cf040d7 --- /dev/null +++ b/boost.py @@ -0,0 +1,333 @@ +import argparse +import time +import datetime +import misc +import numpy as np +import os +import torch +import torch.backends.cudnn as cudnn +from pathlib import Path +from data import build_dataset +from model import get_resnet, Network +from misc import NativeScalerWithGradNormCount as NativeScaler +from loss import InstanceLossBoost, ClusterLossBoost +from engine import boost_one_epoch, evaluate +import json + + +def get_args_parser(): + parser = argparse.ArgumentParser("TCL", add_help=False) + parser.add_argument( + "--batch_size", default=256, type=int, help="Batch size per GPU" + ) + parser.add_argument("--epochs", default=200, type=int) + + # Model parameters + parser.add_argument( + "--model", + default="resnet34", + type=str, + metavar="MODEL", + choices=["resnet50", "resnet34", "resnet18"], + help="Name of model to train", + ) + parser.add_argument("--feat_dim", default=128, type=int, help="dimension of ICH") + parser.add_argument( + "--ins_temp", + default=0.5, + type=float, + help="temperature of instance-level contrastive loss", + ) + parser.add_argument( + "--clu_temp", + default=1.0, + type=float, + help="temperature of cluster-level contrastive loss", + ) + + # Optimizer parameters + parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay") + parser.add_argument( + "--lr", + type=float, + default=1e-4, + metavar="LR", + help="learning rate (absolute lr)", + ) + + # Dataset parameters + parser.add_argument( + "--data_path", default="./datasets/", type=str, help="dataset path", + ) + parser.add_argument( + "--dataset", + default="CIFAR-10", + type=str, + help="dataset", + choices=["CIFAR-10", "CIFAR-100", "ImageNet-10", "ImageNet"], + ) + parser.add_argument( + "--nb_cluster", default=10, type=int, help="number of the clusters", + ) + parser.add_argument( + "--output_dir", + default="./save/", + help="path where to save, empty for no saving", + ) + parser.add_argument( + "--device", default="cuda", help="device to use for training / testing" + ) + parser.add_argument("--seed", default=0, type=int) + parser.add_argument( + "--resume", + default="./save/checkpoint-0.pth", + help="resume from checkpoint", + ) + parser.add_argument( + "--start_epoch", default=0, type=int, metavar="N", help="start epoch" + ) + parser.add_argument("--save_freq", default=20, type=int, help="saving frequency") + parser.add_argument( + "--eval_freq", 
+        default=10, type=int, help="evaluation frequency"
+    )
+    parser.add_argument("--num_workers", default=10, type=int)
+    parser.add_argument(
+        "--pin_mem",
+        action="store_true",
+        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
+    )
+    parser.add_argument(
+        "--dist_eval",
+        action="store_true",
+        default=False,
+        help="Enabling distributed evaluation (recommended during training for faster monitoring)",
+    )
+
+    # distributed training parameters
+    parser.add_argument(
+        "--world_size", default=1, type=int, help="number of distributed processes"
+    )
+    parser.add_argument("--local_rank", default=-1, type=int)
+    parser.add_argument("--dist_on_itp", action="store_true")
+    parser.add_argument(
+        "--dist_url", default="env://", help="url used to set up distributed training"
+    )
+
+    return parser
+
+
+def main(args):
+    misc.init_distributed_mode(args)
+
+    print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
+    print("{}".format(args).replace(", ", ",\n"))
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + misc.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+    cudnn.benchmark = True
+
+    dataset_train = build_dataset(type="train", args=args)
+    dataset_pseudo = build_dataset(type="pseudo", args=args)
+    dataset_val = build_dataset(type="val", args=args)
+
+    num_tasks = misc.get_world_size()
+    global_rank = misc.get_rank()
+    sampler_train = torch.utils.data.DistributedSampler(
+        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+    )
+    sampler_pseudo = torch.utils.data.DistributedSampler(
+        dataset_pseudo, num_replicas=num_tasks, rank=global_rank, shuffle=True
+    )
+    print("Sampler_train = %s" % str(sampler_train))
+    if args.dist_eval:
+        if len(dataset_val) % num_tasks != 0:
+            print(
+                "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. "
+                "This will slightly alter validation results as extra duplicate entries are added to achieve "
+                "equal num of samples per-process."
+ ) + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) # shuffle=True to reduce monitor bias + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if global_rank == 0 and args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, + sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_ps = torch.utils.data.DataLoader( + dataset_pseudo, + sampler=sampler_pseudo, + batch_size=1000, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, + sampler=sampler_val, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False, + ) + + backbone, hidden_dim = get_resnet(args) + model = Network(backbone, hidden_dim, args.feat_dim, args.nb_cluster) + + if args.resume: + checkpoint = torch.load(args.resume, map_location="cpu") + + print("Load pre-trained checkpoint from: %s" % args.resume) + checkpoint_model = checkpoint["model"] + + # load pre-trained model + msg = model.load_state_dict(checkpoint_model, strict=False) + print(msg) + + model.to(device) + + metric_logger = misc.MetricLogger(delimiter=" ") + header = "Test:" + + # switch to evaluation mode + model.eval() + + feat_vector = [] + labels_vector = [] + for (images, labels, _) in metric_logger.log_every(data_loader_val, 20, header): + images = images.to(device, non_blocking=True) + + # compute output + with torch.cuda.amp.autocast(): + feat, c = model.forward_zc(images) + c = torch.argmax(c, dim=1) + + feat_vector.extend(feat.cpu().detach().numpy()) + labels_vector.extend(labels.numpy()) + feat_vector = np.array(feat_vector) + labels_vector = np.array(labels_vector) + print( + "Feat shape {}, Label shape {}".format(feat_vector.shape, labels_vector.shape) + ) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print("Model = %s" % str(model_without_ddp)) + print("number of params (M): %.2f" % (n_parameters / 1.0e6)) + + eff_batch_size = args.batch_size * misc.get_world_size() + + print("base lr: %.3e" % args.lr) + print("effective batch size: %d" % eff_batch_size) + + optimizer = torch.optim.Adam( + [ + {"params": model.resnet.parameters(), "lr": args.lr,}, + {"params": model.instance_projector.parameters(), "lr": args.lr}, + {"params": model.cluster_projector.parameters(), "lr": args.lr}, + ], + lr=args.lr, + weight_decay=args.weight_decay, + ) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + loss_scaler = NativeScaler() + + criterion_ins = InstanceLossBoost( + tau=args.ins_temp, distributed=True, alpha=0.99, gamma=0.5 + ) + criterion_clu = ClusterLossBoost(distributed=True, cluster_num=args.nb_cluster) + + misc.load_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + ) + + print(f"Start training for {args.epochs} epochs") + pseudo_labels = -torch.ones(dataset_train.__len__(), dtype=torch.long) + start_time = time.time() + max_accuracy = 0.0 + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + train_stats, pseudo_labels = boost_one_epoch( + model, + criterion_ins, + 
criterion_clu,
+            data_loader_train,
+            optimizer,
+            device,
+            epoch,
+            loss_scaler,
+            pseudo_labels,
+            args=args,
+        )
+        if args.output_dir and (
+            epoch % args.save_freq == 0 or epoch + 1 == args.epochs
+        ):
+            misc.save_model(
+                args=args,
+                model=model,
+                model_without_ddp=model_without_ddp,
+                optimizer=optimizer,
+                loss_scaler=loss_scaler,
+                epoch=epoch,
+            )
+        if (
+            epoch % args.eval_freq == 0
+            or epoch + 1 == args.epochs
+        ):
+            test_stats = evaluate(data_loader_val, model, device)
+            print(
+                f"Clustering performance on the {len(dataset_val)} test images: NMI={test_stats['nmi']:.2f}%, ACC={test_stats['acc']:.2f}%, ARI={test_stats['ari']:.2f}%"
+            )
+            max_accuracy = max(max_accuracy, test_stats["acc"])
+            print(f"Max accuracy: {max_accuracy:.2f}%")
+
+        if epoch == args.start_epoch:
+            test_stats = {"pred_num": 1000}
+
+        log_stats = {
+            **{f"train_{k}": v for k, v in train_stats.items()},
+            **{f"test_{k}": v for k, v in test_stats.items()},
+            "epoch": epoch,
+            "n_parameters": n_parameters,
+        }
+
+        if args.output_dir and misc.is_main_process():
+            with open(
+                os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
+            ) as f:
+                f.write(json.dumps(log_stats) + "\n")
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print("Training time {}".format(total_time_str))
+
+
+if __name__ == "__main__":
+    args = get_args_parser()
+    args = args.parse_args()
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    main(args)
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..16a3e20
--- /dev/null
+++ b/data.py
@@ -0,0 +1,195 @@
+import torchvision
+from typing import Any, Callable, Optional
+from PIL import Image
+from torchvision.datasets.folder import default_loader
+from transforms import build_transform
+from torch.utils import data
+
+
+class CIFAR10(torchvision.datasets.CIFAR10):
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        super(CIFAR10, self).__init__(
+            root, train, transform, target_transform, download
+        )
+        self.train = train
+
+    def __getitem__(self, index: int):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index) where target is index of the target class.
+        """
+        img, target = self.data[index], self.targets[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        if self.train:
+            return img, target, index
+        else:
+            return img, target, index + 50000
+
+
+class CIFAR100(CIFAR10):
+    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    This is a subclass of the `CIFAR10` Dataset.
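+    Only the dataset metadata below (URLs, file names, and checksums) differs;
+    all loading and indexing behavior is inherited from ``CIFAR10`` above.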
+ """ + + base_folder = "cifar-100-python" + url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" + filename = "cifar-100-python.tar.gz" + tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85" + train_list = [ + ["train", "16019d7e3df5f24257cddd939b257f8d"], + ] + + test_list = [ + ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"], + ] + meta = { + "filename": "meta", + "key": "fine_label_names", + "md5": "7973b15100ade9c7d40fb424638fde48", + } + + +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +class ImageFolder(torchvision.datasets.DatasetFolder): + """A generic data loader where the images are arranged in this way: :: + + root/dog/xxx.png + root/dog/xxy.png + root/dog/xxz.png + + root/cat/123.png + root/cat/nsdf3.png + root/cat/asd932_.png + + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + is_valid_file (callable, optional): A function that takes path of an Image file + and check if the file is a valid file (used to check of corrupt files) + + Attributes: + classes (list): List of the class names sorted alphabetically. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of (image path, class_index) tuples + """ + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + ): + super(ImageFolder, self).__init__( + root, + loader, + IMG_EXTENSIONS if is_valid_file is None else None, + transform=transform, + target_transform=target_transform, + is_valid_file=is_valid_file, + ) + self.imgs = self.samples + + def __getitem__(self, index: int): + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target, index) where target is class_index of the target class. 
+ """ + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target, index + + +def build_dataset(type, args): + is_train = type == "train" + transform = build_transform(is_train, args) + root = args.data_path + + if args.dataset == "CIFAR-10": + dataset = data.ConcatDataset( + [ + CIFAR10( + root=root + "CIFAR-10", + train=True, + download=True, + transform=transform, + ), + CIFAR10( + root=root + "CIFAR-10", + train=False, + download=True, + transform=transform, + ), + ] + ) + elif args.dataset == "CIFAR-100": + dataset = data.ConcatDataset( + [ + CIFAR100( + root=root + "CIFAR-100", + train=True, + download=True, + transform=transform, + ), + CIFAR100( + root=root + "CIFAR-100", + train=False, + download=True, + transform=transform, + ), + ] + ) + elif args.dataset == "ImageNet-10": + dataset = ImageFolder(root=root + "ImageNet-10", transform=transform) + elif args.dataset == "ImageNet": + dataset = ImageFolder(root=root + "ImageNet/train", transform=transform) + + print(dataset) + + return dataset diff --git a/engine.py b/engine.py new file mode 100644 index 0000000..e843720 --- /dev/null +++ b/engine.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import math +import sys + +import torch +import torch.nn.functional as F + +import misc +import numpy as np +from evaluate import cluster_metric + + +def train_one_epoch( + model, + criterion_ins, + criterion_clu, + data_loader, + optimizer, + device, + epoch, + loss_scaler, + args, +): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + header = "Epoch: [{}]".format(epoch) + print_freq = 20 + + optimizer.zero_grad() + + for data_iter_step, ((x_w, x_s, x), _, index) in enumerate( + metric_logger.log_every(data_loader, print_freq, header) + ): + x_w = x_w.to(device, non_blocking=True) + x_s = x_s.to(device, non_blocking=True) + + with torch.cuda.amp.autocast(): + z_i, z_j, c_i, c_j = model(x_w, x_s) + c_i = F.softmax(c_i, dim=1) + c_j = F.softmax(c_j, dim=1) + loss_ins = criterion_ins(torch.concat((z_i, z_j), dim=0)) + loss_clu = criterion_clu(torch.concat((c_i, c_j), dim=0)) + loss = loss_ins + loss_clu + + loss_ins_value = loss_ins.item() + loss_clu_value = loss_clu.item() + + if not math.isfinite(loss_ins_value) or not math.isfinite(loss_clu_value): + print( + "Loss is {}, {}, stopping training".format( + loss_ins_value, loss_clu_value + ) + ) + sys.exit(1) + + loss_scaler( + loss, + optimizer, + parameters=model.parameters(), + create_graph=False, + update_grad=True, + ) + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss_ins=loss_ins_value) + metric_logger.update(loss_clu=loss_clu_value) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(data_loader, model, device): + metric_logger = 
misc.MetricLogger(delimiter=" ") + header = "Test:" + + # switch to evaluation mode + model.eval() + + pred_vector = [] + labels_vector = [] + for (images, labels, _) in metric_logger.log_every(data_loader, 20, header): + images = images.to(device, non_blocking=True) + + # compute output + with torch.cuda.amp.autocast(): + preds = model.module.forward_c(images) + preds = torch.argmax(preds, dim=1) + + pred_vector.extend(preds.cpu().detach().numpy()) + labels_vector.extend(labels.numpy()) + pred_vector = np.array(pred_vector) + labels_vector = np.array(labels_vector) + print( + "Pred shape {}, Label shape {}".format(pred_vector.shape, labels_vector.shape) + ) + + nmi, ari, acc = cluster_metric(labels_vector, pred_vector) + print(nmi, ari, acc) + + metric_logger.meters["nmi"].update(nmi) + metric_logger.meters["acc"].update(acc) + metric_logger.meters["ari"].update(ari) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def boost_one_epoch( + model, + criterion_ins, + criterion_clu, + data_loader, + optimizer, + device, + epoch, + loss_scaler, + pseudo_labels, + args, +): + metric_logger = misc.MetricLogger(delimiter=" ") + header = "Epoch: [{}]".format(epoch) + print_freq = 20 + + optimizer.zero_grad() + for data_iter_step, ((x_w, x_s, x), _, index) in enumerate( + metric_logger.log_every(data_loader, print_freq, header) + ): + x_w = x_w.to(device, non_blocking=True) + x_s = x_s.to(device, non_blocking=True) + x = x.to(device, non_blocking=True) + + model.eval() + with torch.cuda.amp.autocast(), torch.no_grad(): + _, _, c = model(x, x, return_ci=False) + c = F.softmax(c / args.clu_temp, dim=1) + pseudo_labels_cur, index_cur = criterion_ins.generate_pseudo_labels( + c, pseudo_labels[index].to(c.device), index.to(c.device) + ) + pseudo_labels[index_cur] = pseudo_labels_cur + pseudo_index = pseudo_labels != -1 + metric_logger.update(pseudo_num=pseudo_index.sum().item()) + metric_logger.update( + pseudo_cluster=torch.unique(pseudo_labels[pseudo_index]).shape[0] + ) + if epoch == args.start_epoch: + continue + + model.train(True) + with torch.cuda.amp.autocast(): + z_i, z_j, c_j = model(x_w, x_s, return_ci=False) + loss_ins = criterion_ins( + torch.concat((z_i, z_j), dim=0), pseudo_labels[index].to(x_s.device) + ) + loss_clu = criterion_clu(c_j, pseudo_labels[index].to(x_s.device)) + loss = loss_ins + loss_clu + + loss_ins_value = loss_ins.item() + loss_clu_value = loss_clu.item() + + if not math.isfinite(loss_ins_value) or not math.isfinite(loss_clu_value): + print( + "Loss is {}, {}, stopping training".format( + loss_ins_value, loss_clu_value + ) + ) + sys.exit(1) + loss_scaler( + loss, + optimizer, + parameters=model.parameters(), + create_graph=False, + update_grad=True, + ) + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss_ins=loss_ins_value) + metric_logger.update(loss_clu=loss_clu_value) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return ( + {k: meter.global_avg for k, meter in metric_logger.meters.items()}, + pseudo_labels, + ) + diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..98674b7 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,93 @@ +from sklearn import metrics +from munkres import Munkres +import numpy as np +import torch +from scipy.optimize import linear_sum_assignment + + +def reorder_preds(predictions, targets): + 
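+    # Hungarian-match predicted cluster ids to ground-truth ids, then relabel the
+    # predictions. preds_k/targets_k are hard-coded to 1000 (ImageNet classes);
+    # for smaller datasets the unused ids receive zero votes and do not affect
+    # the assignment of the ids that actually occur.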
+    predictions = torch.from_numpy(predictions).cuda()
+    targets = torch.from_numpy(targets).cuda()
+    match = _hungarian_match(predictions, targets, preds_k=1000, targets_k=1000)
+    reordered_preds = torch.zeros(predictions.shape[0], dtype=predictions.dtype).cuda()
+    for pred_i, target_i in match:
+        reordered_preds[predictions == int(pred_i)] = int(target_i)
+    return reordered_preds.cpu().numpy()
+
+
+def _hungarian_match(flat_preds, flat_targets, preds_k, targets_k):
+    # Based on implementation from IIC
+    num_samples = flat_targets.shape[0]
+
+    assert preds_k == targets_k  # one to one
+    num_k = preds_k
+    num_correct = np.zeros((num_k, num_k))
+
+    for c1 in range(num_k):
+        for c2 in range(num_k):
+            # elementwise, so each sample contributes once
+            votes = int(((flat_preds == c1) * (flat_targets == c2)).sum())
+            num_correct[c1, c2] = votes
+
+    # num_correct is small
+    match = linear_sum_assignment(num_samples - num_correct)
+    match = np.array(list(zip(*match)))
+
+    # return as list of tuples, out_c to gt_c
+    res = []
+    for out_c, gt_c in match:
+        res.append((out_c, gt_c))
+
+    return res
+
+
+def cluster_metric(label, pred):
+    nmi = metrics.normalized_mutual_info_score(label, pred)
+    ari = metrics.adjusted_rand_score(label, pred)
+    # pred_adjusted = get_y_preds(label, pred, len(set(label)))
+    pred_adjusted = reorder_preds(pred, label)
+    acc = metrics.accuracy_score(pred_adjusted, label)
+    return nmi * 100, ari * 100, acc * 100
+
+
+def calculate_cost_matrix(C, n_clusters):
+    cost_matrix = np.zeros((n_clusters, n_clusters))
+    # cost_matrix[i,j] will be the cost of assigning cluster i to label j
+    for j in range(n_clusters):
+        s = np.sum(C[:, j])  # number of examples in cluster j
+        for i in range(n_clusters):
+            t = C[i, j]
+            cost_matrix[j, i] = s - t
+    return cost_matrix
+
+
+def get_cluster_labels_from_indices(indices):
+    n_clusters = len(indices)
+    cluster_labels = np.zeros(n_clusters)
+    for i in range(n_clusters):
+        cluster_labels[i] = indices[i][1]
+    return cluster_labels
+
+
+def get_y_preds(y_true, cluster_assignments, n_clusters):
+    """
+    Computes the predicted labels, where label assignments now
+    correspond to the actual labels in y_true (as estimated by Munkres).
+    cluster_assignments: array of labels, outputted by kmeans
+    y_true: true labels
+    n_clusters: number of clusters in the dataset
+    returns: the cluster assignments re-mapped onto the label space of y_true
+    """
+    confusion_matrix = metrics.confusion_matrix(
+        y_true, cluster_assignments, labels=None
+    )
+    # compute accuracy based on optimal 1:1 assignment of clusters to labels
+    cost_matrix = calculate_cost_matrix(confusion_matrix, n_clusters)
+    indices = Munkres().compute(cost_matrix)
+    kmeans_to_true_cluster_labels = get_cluster_labels_from_indices(indices)
+
+    if np.min(cluster_assignments) != 0:
+        cluster_assignments = cluster_assignments - np.min(cluster_assignments)
+    y_pred = kmeans_to_true_cluster_labels[cluster_assignments]
+    return y_pred
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000..e08dc28
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,368 @@
+import torch
+from torch import nn
+import numpy as np
+import torch.nn.functional as F
+import diffdist
+import torch.distributed as dist
+
+
+def gather(z):
+    gather_z = [torch.zeros_like(z) for _ in range(torch.distributed.get_world_size())]
+    gather_z = diffdist.functional.all_gather(gather_z, z)
+    gather_z = torch.cat(gather_z)
+
+    return gather_z
+
+
+def accuracy(logits, labels, k):
+    topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0]
+    labels = torch.sort(labels, 1)[0]
+    acc = (topk == labels).all(1).float()
+    return acc
+
+
+def mean_cumulative_gain(logits, labels, k):
+    topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0]
+    labels = torch.sort(labels, 1)[0]
+    mcg = (topk == labels).float().mean(1)
+    return mcg
+
+
+def mean_average_precision(logits, labels, k):
+    # TODO: not the fastest solution but looks fine
+    argsort = torch.argsort(logits, dim=1, descending=True)
+    labels_to_sorted_idx = (
+        torch.sort(torch.gather(torch.argsort(argsort, dim=1), 1, labels), dim=1)[0] + 1
+    )
+    precision = (
+        1 + torch.arange(k, device=logits.device).float()
+    ) / labels_to_sorted_idx
+    return precision.sum(1) / k
+
+
+class InstanceLoss(nn.Module):
+    """
+    Contrastive loss with distributed data parallel support
+    """
+
+    LARGE_NUMBER = 1e4
+
+    def __init__(self, tau=0.5, multiplier=2, distributed=False):
+        super().__init__()
+        self.tau = tau
+        self.multiplier = multiplier
+        self.distributed = distributed
+
+    def forward(self, z, get_map=False):
+        n = z.shape[0]
+        assert n % self.multiplier == 0
+
+        z = z / np.sqrt(self.tau)
+
+        if self.distributed:
+            z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
+            # all_gather fills the list as [<proc0>, <proc1>, ...]
+            # TODO: try to rewrite it with pytorch official tools
+            z_list = diffdist.functional.all_gather(z_list, z)
+            # split it into [<proc0_aug0>, <proc0_aug1>, ..., <proc1_aug0>, <proc1_aug1>, ...]
+            z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)]
+            # sort it to [<proc0_aug0>, <proc1_aug0>, ...] that simply means [<all_aug0>, <all_aug1>, ...] as expected below
+            z_sorted = []
+            for m in range(self.multiplier):
+                for i in range(dist.get_world_size()):
+                    z_sorted.append(z_list[i * self.multiplier + m])
+            z = torch.cat(z_sorted, dim=0)
+            n = z.shape[0]
+
+        logits = z @ z.t()
+        logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER
+
+        logprob = F.log_softmax(logits, dim=1)
+
+        # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1)
+        m = self.multiplier
+        labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n // m, n)) % n
+        # remove labels pointing to the sample itself, i.e. (i, i)
+        labels = labels.reshape(n, m)[:, 1:].reshape(-1)
+
+        loss = -logprob[np.repeat(np.arange(n), m - 1), labels].sum() / n / (m - 1)
+
+        return loss
+
+
+class ClusterLoss(nn.Module):
+    """
+    Contrastive loss with distributed data parallel support
+    """
+
+    LARGE_NUMBER = 1e4
+
+    def __init__(self, tau=1.0, multiplier=2, distributed=False):
+        super().__init__()
+        self.tau = tau
+        self.multiplier = multiplier
+        self.distributed = distributed
+
+    def forward(self, c, get_map=False):
+        n = c.shape[0]
+        assert n % self.multiplier == 0
+
+        # c = c / np.sqrt(self.tau)
+
+        if self.distributed:
+            c_list = [torch.zeros_like(c) for _ in range(dist.get_world_size())]
+            # all_gather fills the list as [<proc0>, <proc1>, ...]
+            c_list = diffdist.functional.all_gather(c_list, c)
+            # split it into [<proc0_aug0>, <proc0_aug1>, ..., <proc1_aug0>, <proc1_aug1>, ...]
+            c_list = [chunk for x in c_list for chunk in x.chunk(self.multiplier)]
+            # sort it to [<proc0_aug0>, <proc1_aug0>, ...] that simply means [<all_aug0>, <all_aug1>, ...] as expected below
+            c_sorted = []
+            for m in range(self.multiplier):
+                for i in range(dist.get_world_size()):
+                    c_sorted.append(c_list[i * self.multiplier + m])
+            c_aug0 = torch.cat(
+                c_sorted[: int(self.multiplier * dist.get_world_size() / 2)], dim=0
+            )
+            c_aug1 = torch.cat(
+                c_sorted[int(self.multiplier * dist.get_world_size() / 2) :], dim=0
+            )
+
+        p_i = c_aug0.sum(0).view(-1)
+        p_i /= p_i.sum()
+        en_i = np.log(p_i.size(0)) + (p_i * torch.log(p_i)).sum()
+        p_j = c_aug1.sum(0).view(-1)
+        p_j /= p_j.sum()
+        en_j = np.log(p_j.size(0)) + (p_j * torch.log(p_j)).sum()
+        en_loss = en_i + en_j
+
+        c = torch.cat((c_aug0.t(), c_aug1.t()), dim=0)
+        n = c.shape[0]
+
+        c = F.normalize(c, p=2, dim=1) / np.sqrt(self.tau)
+
+        logits = c @ c.t()
+        logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER
+
+        logprob = F.log_softmax(logits, dim=1)
+
+        # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1)
+        m = self.multiplier
+        labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n // m, n)) % n
+        # remove labels pointing to the sample itself, i.e. (i, i)
+        labels = labels.reshape(n, m)[:, 1:].reshape(-1)
+
+        loss = -logprob[np.repeat(np.arange(n), m - 1), labels].sum() / n / (m - 1)
+
+        return loss + en_loss
+
+
+class InstanceLossBoost(nn.Module):
+    """
+    Contrastive loss with distributed data parallel support
+    """
+
+    LARGE_NUMBER = 1e4
+
+    def __init__(
+        self,
+        tau=0.5,
+        multiplier=2,
+        distributed=False,
+        alpha=0.9,
+        gamma=0.1,
+        cluster_num=10,
+    ):
+        super().__init__()
+        self.tau = tau
+        self.multiplier = multiplier
+        self.distributed = distributed
+        self.alpha = alpha
+        self.gamma = gamma
+        self.cluster_num = cluster_num
+
+    @torch.no_grad()
+    def generate_pseudo_labels(self, c, pseudo_label_cur, index):
+        if self.distributed:
+            c_list = [torch.zeros_like(c) for _ in range(dist.get_world_size())]
+            pseudo_label_cur_list = [torch.zeros_like(pseudo_label_cur) for _ in range(dist.get_world_size())]
+            index_list = [torch.zeros_like(index) for _ in range(dist.get_world_size())]
+            # all_gather fills the list as [<proc0>, <proc1>, ...]
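+            # (diffdist's gather is autograd-aware, but under @torch.no_grad it
+            # simply collects predictions, pseudo-labels, and indices from all ranks)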
+            c_list = diffdist.functional.all_gather(c_list, c)
+            pseudo_label_cur_list = diffdist.functional.all_gather(
+                pseudo_label_cur_list, pseudo_label_cur
+            )
+            index_list = diffdist.functional.all_gather(index_list, index)
+            c = torch.cat(c_list, dim=0)
+            pseudo_label_cur = torch.cat(pseudo_label_cur_list, dim=0)
+            index = torch.cat(index_list, dim=0)
+        batch_size = c.shape[0]
+        device = c.device
+        pseudo_label_nxt = -torch.ones(batch_size, dtype=torch.long).to(device)
+        tmp = torch.arange(0, batch_size).to(device)
+
+        prediction = c.argmax(dim=1)
+        confidence = c.max(dim=1).values
+        unconfident_pred_index = confidence < self.alpha
+        pseudo_per_class = np.ceil(batch_size / self.cluster_num * self.gamma).astype(
+            int
+        )
+        for i in range(self.cluster_num):
+            class_idx = prediction == i
+            if class_idx.sum() == 0:
+                continue
+            confidence_class = confidence[class_idx]
+            num = min(confidence_class.shape[0], pseudo_per_class)
+            confident_idx = torch.argsort(-confidence_class)
+            for j in range(num):
+                idx = tmp[class_idx][confident_idx[j]]
+                pseudo_label_nxt[idx] = i
+
+        todo_index = pseudo_label_cur == -1
+        pseudo_label_cur[todo_index] = pseudo_label_nxt[todo_index]
+        pseudo_label_nxt = pseudo_label_cur
+        pseudo_label_nxt[unconfident_pred_index] = -1
+        return pseudo_label_nxt.cpu(), index
+
+    def forward(self, z, pseudo_label):
+        n = z.shape[0]
+        assert n % self.multiplier == 0
+
+        if self.distributed:
+            z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
+            pseudo_label_list = [
+                torch.zeros_like(pseudo_label) for _ in range(dist.get_world_size())
+            ]
+            # all_gather fills the list as [<proc0>, <proc1>, ...]
+            z_list = diffdist.functional.all_gather(z_list, z)
+            pseudo_label_list = diffdist.functional.all_gather(
+                pseudo_label_list, pseudo_label
+            )
+            # split it into [<proc0_aug0>, <proc0_aug1>, ..., <proc1_aug0>, <proc1_aug1>, ...]
+            z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)]
+            pseudo_label_list = [
+                chunk for x in pseudo_label_list for chunk in x.chunk(self.multiplier)
+            ]
+            # sort it to [<proc0_aug0>, <proc1_aug0>, ...] that simply means [<all_aug0>, <all_aug1>, ...] as expected below
+            z_sorted = []
+            pseudo_label_sorted = []
+            for m in range(self.multiplier):
+                for i in range(dist.get_world_size()):
+                    z_sorted.append(z_list[i * self.multiplier + m])
+                    pseudo_label_sorted.append(
+                        pseudo_label_list[i * self.multiplier + m]
+                    )
+            z_i = torch.cat(
+                z_sorted[: int(self.multiplier * dist.get_world_size() / 2)], dim=0
+            )
+            z_j = torch.cat(
+                z_sorted[int(self.multiplier * dist.get_world_size() / 2) :], dim=0
+            )
+            pseudo_label = torch.cat(pseudo_label_sorted, dim=0)
+            n = z_i.shape[0]
+
+        invalid_index = pseudo_label == -1
+        mask = torch.eq(pseudo_label.view(-1, 1), pseudo_label.view(1, -1)).to(
+            z_i.device
+        )
+        mask[invalid_index, :] = False
+        mask[:, invalid_index] = False
+        mask_eye = torch.eye(n).float().to(z_i.device)
+        mask &= ~(mask_eye.bool())
+        mask = mask.float()
+
+        contrast_count = self.multiplier
+        contrast_feature = torch.cat((z_i, z_j), dim=0)
+
+        anchor_feature = contrast_feature
+        anchor_count = contrast_count
+
+        # compute logits
+        anchor_dot_contrast = torch.div(
+            torch.matmul(anchor_feature, contrast_feature.T), self.tau
+        )
+        # for numerical stability
+        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+        logits = anchor_dot_contrast - logits_max.detach()
+
+        # tile mask
+        # mask_with_eye = mask | mask_eye.bool()
+        # mask = torch.cat(mask)
+        mask = mask.repeat(anchor_count, contrast_count)
+        mask_eye = mask_eye.repeat(anchor_count, contrast_count)
+        # mask-out self-contrast cases
+        logits_mask = torch.scatter(
+            torch.ones_like(mask),
+            1,
+            torch.arange(n * anchor_count).view(-1, 1).to(z_i.device),
+            0,
+        )
+        logits_mask *= 1 - mask
+        mask_eye = mask_eye * logits_mask
+
+        # compute log_prob
+        exp_logits = torch.exp(logits) * logits_mask
+        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+        # compute mean of log-likelihood over positive
+        mean_log_prob_pos = (mask_eye * log_prob).sum(1) / mask_eye.sum(1)
+
+        # loss
+        instance_loss = -mean_log_prob_pos
+        instance_loss = instance_loss.view(anchor_count, n).mean()
+
+        return instance_loss
+
+
+class ClusterLossBoost(nn.Module):
+    """
+    Class-weighted cross-entropy on confident pseudo-labels, with distributed
+    data parallel support
+    """
+
+    LARGE_NUMBER = 1e4
+
+    def __init__(self, multiplier=1, distributed=False, cluster_num=10):
+        super().__init__()
+        self.multiplier = multiplier
+        self.distributed = distributed
+        self.cluster_num = cluster_num
+
+    def forward(self, c, pseudo_label):
+        if self.distributed:
+            # c_list = [torch.zeros_like(c) for _ in range(dist.get_world_size())]
+            pseudo_label_list = [
+                torch.zeros_like(pseudo_label) for _ in range(dist.get_world_size())
+            ]
+            # all_gather fills the list as [<proc0>, <proc1>, ...]
+            # c_list = diffdist.functional.all_gather(c_list, c)
+            pseudo_label_list = diffdist.functional.all_gather(
+                pseudo_label_list, pseudo_label
+            )
+            # split it into [<proc0_aug0>, <proc0_aug1>, ..., <proc1_aug0>, <proc1_aug1>, ...]
+            # c_list = [chunk for x in c_list for chunk in x.chunk(self.multiplier)]
+            pseudo_label_list = [
+                chunk for x in pseudo_label_list for chunk in x.chunk(self.multiplier)
+            ]
+            # sort it to [<proc0_aug0>, <proc1_aug0>, ...] that simply means [<all_aug0>, <all_aug1>, ...] as expected below
+            # c_sorted = []
+            pseudo_label_sorted = []
+            for m in range(self.multiplier):
+                for i in range(dist.get_world_size()):
+                    # c_sorted.append(c_list[i * self.multiplier + m])
+                    pseudo_label_sorted.append(
+                        pseudo_label_list[i * self.multiplier + m]
+                    )
+            # c = torch.cat(c_sorted, dim=0)
+            pseudo_label_all = torch.cat(pseudo_label_sorted, dim=0)
+            pseudo_index = pseudo_label_all != -1
+            pseudo_label_all = pseudo_label_all[pseudo_index]
+            idx, counts = torch.unique(pseudo_label_all, return_counts=True)
+            freq = pseudo_label_all.shape[0] / counts.float()
+            weight = torch.ones(self.cluster_num).to(c.device)
+            weight[idx] = freq
+        pseudo_index = pseudo_label != -1
+        if pseudo_index.sum() > 0:
+            criterion = nn.CrossEntropyLoss(weight=weight).to(c.device)
+            loss_ce = criterion(
+                c[pseudo_index], pseudo_label[pseudo_index].to(c.device)
+            )
+        else:
+            loss_ce = torch.tensor(0.0, requires_grad=True).to(c.device)
+        return loss_ce
diff --git a/misc.py b/misc.py
new file mode 100644
index 0000000..642d7bc
--- /dev/null
+++ b/misc.py
@@ -0,0 +1,381 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from torch._six import inf
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
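+        Only count and total are all-reduced across ranks; the windowed
+        median/avg stay process-local.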
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] 
".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.dist_url = "tcp://%s:%s" % ( + os.environ["MASTER_ADDR"], + os.environ["MASTER_PORT"], + ) + os.environ["LOCAL_RANK"] = str(args.gpu) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def 
save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + # "encoder_optimizer": encoder_optimizer.state_dict(), + "epoch": epoch, + "scaler": loss_scaler.state_dict(), + "args": args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {"epoch": epoch} + model.save_checkpoint( + save_dir=args.output_dir, + tag="checkpoint-%s" % epoch_name, + client_state=client_state, + ) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + print("Resume checkpoint %s" % args.resume) + if "optimizer" in checkpoint and "epoch" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer"]) + args.start_epoch = checkpoint["epoch"] + 1 + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + print("With optim!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x diff --git a/model.py b/model.py new file mode 100644 index 0000000..e8fbffb --- /dev/null +++ b/model.py @@ -0,0 +1,211 @@ +from torch import nn +from torch.nn import functional as F +from torchvision.models.resnet import Bottleneck, BasicBlock, conv1x1 +import torch +from timm.models.layers import trunc_normal_ + + +class ResNet(nn.Module): + def __init__( + self, + block, + layers, + in_channel=3, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + in_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.rep_dim = 512 * block.expansion + self.fc = nn.Linear(self.rep_dim, self.rep_dim) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + 
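+                # affine norm layers start as identity: unit weight, zero bias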
nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # for name, param in self.named_parameters(): + # if ( + # name.startswith("conv1") + # or name.startswith("bn1") + # or name.startswith("layer1") + # or name.startswith("layer2") + # ): + # print("Freeze gradient for", name) + # param.requires_grad = False + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + # with torch.no_grad(): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + + +def get_resnet(args): + if args.model == "resnet18": + return ResNet(BasicBlock, [2, 2, 2, 2]), 512 + elif args.model == "resnet34": + return ResNet(BasicBlock, [3, 4, 6, 3]), 512 + elif args.model == "resnet50": + return ResNet(Bottleneck, [3, 4, 6, 3]), 2048 + else: + raise NotImplementedError + + +class Network(nn.Module): + def __init__(self, resnet, hidden_dim, feature_dim, class_num): + super(Network, self).__init__() + self.resnet = resnet + self.feature_dim = feature_dim + self.cluster_num = class_num + self.instance_projector = nn.Sequential( + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, self.feature_dim), + ) + self.cluster_projector = nn.Sequential( + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, self.cluster_num), + ) + trunc_normal_(self.cluster_projector[2].weight, std=0.02) + trunc_normal_(self.cluster_projector[5].weight, std=0.02) + + def forward(self, x_i, x_j, return_ci=True): + h_i = self.resnet(x_i) + h_j = self.resnet(x_j) + + z_i = F.normalize(self.instance_projector(h_i), dim=1) + z_j = F.normalize(self.instance_projector(h_j), dim=1) + + c_j = self.cluster_projector(h_j) + + if return_ci: + c_i = self.cluster_projector(h_i) + return z_i, z_j, c_i, c_j + else: + return z_i, z_j, c_j + + def forward_c(self, x): + h = self.resnet(x) + c = self.cluster_projector(h) + c = F.softmax(c, dim=1) + return 
c
+
+    def forward_zc(self, x):
+        h = self.resnet(x)
+        z = F.normalize(self.instance_projector(h), dim=1)
+        c = self.cluster_projector(h)
+        c = F.softmax(c, dim=1)
+        return z, c
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..cf75841
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,68 @@
+# Twin Contrastive Learning for Online Clustering (TCL)
+
+This is the code for the paper "Twin Contrastive Learning for Online Clustering" (IJCV 2022).
+
+TCL extends the previous work "Contrastive Clustering" (AAAI 2021, https://github.com/Yunfan-Li/Contrastive-Clustering) by selecting the most confident predictions to finetune both the instance- and cluster-level contrastive learning.
+
+TCL mixes weak and strong augmentations for both the image and text modalities. The twin contrastive learning framework yields further performance gains over standard instance-level contrastive learning.
+
+The code supports multi-GPU training.
+
+Paper Link: https://link.springer.com/article/10.1007/s11263-022-01639-z
+
+# Environment
+
+- diffdist=0.1
+- python=3.9.12
+- pytorch=1.11.0
+- torchvision=0.12.0
+- munkres=1.1.4
+- numpy=1.22.3
+- opencv-python=4.6.0.66
+- scikit-learn=1.0.2
+- cudatoolkit=11.3.1
+
+# Usage
+
+TCL is composed of a training stage and a boosting stage. Configurations such as the model, dataset, and temperatures can be set via argparse. Clustering performance is evaluated during training and boosting.
+
+## Training
+
+The following command trains on CIFAR-10 with a 4-GPU machine:
+
+> OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 train.py
+
+## Boosting
+
+The following command runs boosting on CIFAR-10 with a 4-GPU machine:
+
+> OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 boost.py
+
+## Clustering On ImageNet
+
+To cluster datasets like ImageNet with a large number of classes, a reasonably large batch size is needed. Considering the GPU memory consumption, we recommend inheriting the MoCo v2 pretrained model (https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar) and freezing part of the network parameters (see details in model.py).
+
+# Dataset
+
+CIFAR-10 and CIFAR-100 can be automatically downloaded by PyTorch. For ImageNet-10 and ImageNet-dogs, we provide their indices from ImageNet in the "dataset" folder.
+
+To run TCL on ImageNet and its subsets, prepare the data and pass the image folder path to the `--data_path` argument.
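+
+# Pseudo-Label Selection (Sketch)
+
+For intuition, the boosting stage keeps, within each predicted cluster, only the most confident samples as pseudo-labels. Below is a minimal single-GPU sketch of the selection rule implemented by `generate_pseudo_labels` in loss.py; the helper name `select_pseudo_labels` is ours for illustration, and the sketch omits the distributed gathering and the merge with pseudo-labels from earlier epochs. `alpha` is the confidence threshold and `gamma` the per-cluster selection ratio.
+
+```
+# illustrative helper, not part of the repo
+import math
+import torch
+
+def select_pseudo_labels(c, alpha=0.99, gamma=0.5, cluster_num=10):
+    # c: (N, K) softmax cluster assignments
+    confidence, prediction = c.max(dim=1)
+    pseudo = -torch.ones(c.shape[0], dtype=torch.long, device=c.device)
+    per_class = math.ceil(c.shape[0] / cluster_num * gamma)
+    for k in range(cluster_num):
+        idx = (prediction == k).nonzero(as_tuple=True)[0]
+        if idx.numel() == 0:
+            continue
+        # keep the most confident samples assigned to cluster k
+        top = idx[confidence[idx].argsort(descending=True)[:per_class]]
+        pseudo[top] = k
+    pseudo[confidence < alpha] = -1  # unconfident samples stay unlabeled
+    return pseudo
+```
+
+Samples left at `-1` are excluded from both boosted losses (see `InstanceLossBoost` and `ClusterLossBoost`).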
+
+# Citation
+
+If you find TCL useful in your research, please consider citing:
+```
+
+```
+
+or the previous conference version
+```
+@inproceedings{li2021contrastive,
+  title={Contrastive clustering},
+  author={Li, Yunfan and Hu, Peng and Liu, Zitao and Peng, Dezhong and Zhou, Joey Tianyi and Peng, Xi},
+  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+  volume={35},
+  number={10},
+  pages={8547--8555},
+  year={2021}
+}
+```
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..97d8089
--- /dev/null
+++ b/train.py
@@ -0,0 +1,282 @@
+import argparse
+import copy
+import time
+import datetime
+import misc
+import numpy as np
+import os
+import torch
+import torch.backends.cudnn as cudnn
+from pathlib import Path
+from data import build_dataset
+from model import get_resnet, Network
+from misc import NativeScalerWithGradNormCount as NativeScaler
+from loss import InstanceLoss, ClusterLoss
+from engine import train_one_epoch, evaluate
+import json
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser("TCL", add_help=False)
+    parser.add_argument(
+        "--batch_size", default=256, type=int, help="Batch size per GPU"
+    )
+    parser.add_argument("--epochs", default=1000, type=int)
+
+    # Model parameters
+    parser.add_argument(
+        "--model",
+        default="resnet34",
+        type=str,
+        metavar="MODEL",
+        choices=["resnet50", "resnet34", "resnet18"],
+        help="Name of model to train",
+    )
+    parser.add_argument("--feat_dim", default=128, type=int, help="dimension of ICH")
+    parser.add_argument(
+        "--ins_temp",
+        default=0.5,
+        type=float,
+        help="temperature of instance-level contrastive loss",
+    )
+    parser.add_argument(
+        "--clu_temp",
+        default=1.0,
+        type=float,
+        help="temperature of cluster-level contrastive loss",
+    )
+
+    # Optimizer parameters
+    parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay")
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1e-4,
+        metavar="LR",
+        help="learning rate (absolute lr)",
+    )
+
+    # Dataset parameters
+    parser.add_argument(
+        "--data_path", default="./datasets/", type=str, help="dataset path",
+    )
+    parser.add_argument(
+        "--dataset",
+        default="CIFAR-10",
+        type=str,
+        help="dataset",
+        choices=["CIFAR-10", "CIFAR-100", "ImageNet-10", "ImageNet"],
+    )
+    parser.add_argument(
+        "--nb_cluster", default=10, type=int, help="number of the clusters",
+    )
+    parser.add_argument(
+        "--output_dir",
+        default="./save/",
+        help="path where to save, empty for no saving",
+    )
+    parser.add_argument(
+        "--device", default="cuda", help="device to use for training / testing"
+    )
+    parser.add_argument("--seed", default=0, type=int)
+    parser.add_argument(
+        "--resume",
+        default=False,
+        help="resume from checkpoint",
+    )
+    parser.add_argument(
+        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
+    )
+    parser.add_argument("--save_freq", default=50, type=int, help="saving frequency")
+    parser.add_argument(
+        "--eval_freq", default=10, type=int, help="evaluation frequency"
+    )
+    parser.add_argument("--num_workers", default=10, type=int)
+    parser.add_argument(
+        "--pin_mem",
+        action="store_true",
+        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
+    )
+    parser.add_argument(
+        "--dist_eval",
+        action="store_true",
+        default=False,
+        help="Enabling distributed evaluation (recommended during training for faster monitoring)",
+    )
+
+    # distributed training parameters
+    parser.add_argument(
+        "--world_size", default=1, type=int,
help="number of distributed processes" + ) + parser.add_argument("--local_rank", default=-1, type=int) + parser.add_argument("--dist_on_itp", action="store_true") + parser.add_argument( + "--dist_url", default="env://", help="url used to set up distributed training" + ) + + return parser + + +def main(args): + misc.init_distributed_mode(args) + + print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(", ", ",\n")) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + dataset_train = build_dataset(type="train", args=args) + dataset_val = build_dataset(type="val", args=args) + + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print( + "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. " + "This will slightly alter validation results as extra duplicate entries are added to achieve " + "equal num of samples per-process." + ) + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) # shuffle=True to reduce monitor bias + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if global_rank == 0 and args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, + sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, + sampler=sampler_val, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False, + ) + + backbone, hidden_dim = get_resnet(args) + model = Network(backbone, hidden_dim, args.feat_dim, args.nb_cluster) + + if args.resume: + checkpoint = torch.load(args.resume, map_location="cpu") + state_dict = checkpoint["state_dict"] + state_dict_copy = copy.deepcopy(state_dict) + for param in list(state_dict): + if param.startswith("module.encoder_q"): + param_copy = "resnet" + param[16:] + state_dict[param_copy] = state_dict_copy[param] + state_dict.pop(param) + + print("Load pre-trained checkpoint from: %s" % args.resume) + msg = model.load_state_dict(state_dict, strict=False) + print(msg) + + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print("Model = %s" % str(model_without_ddp)) + print("number of params (M): %.2f" % (n_parameters / 1.0e6)) + + eff_batch_size = args.batch_size * misc.get_world_size() + print("effective batch size: %d" % eff_batch_size) + + optimizer = torch.optim.Adam( + [ + {"params": model.resnet.parameters(), "lr": args.lr,}, + {"params": model.instance_projector.parameters(), "lr": args.lr}, + {"params": model.cluster_projector.parameters(), "lr": args.lr}, + ], + lr=args.lr, + weight_decay=args.weight_decay, + ) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + loss_scaler = NativeScaler() + + criterion_ins = InstanceLoss(tau=args.ins_temp, distributed=True) + 
+    criterion_ins = InstanceLoss(tau=args.ins_temp, distributed=True)
+    criterion_clu = ClusterLoss(tau=args.clu_temp, distributed=True)
+
+    print(f"Start training for {args.epochs} epochs")
+    start_time = time.time()
+    max_accuracy = 0.0
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            data_loader_train.sampler.set_epoch(epoch)
+        train_stats = train_one_epoch(
+            model,
+            criterion_ins,
+            criterion_clu,
+            data_loader_train,
+            optimizer,
+            device,
+            epoch,
+            loss_scaler,
+            args=args,
+        )
+        if args.output_dir and (
+            epoch % args.save_freq == 0 or epoch + 1 == args.epochs
+        ):
+            misc.save_model(
+                args=args,
+                model=model,
+                model_without_ddp=model_without_ddp,
+                optimizer=optimizer,
+                loss_scaler=loss_scaler,
+                epoch=epoch,
+            )
+        if epoch % args.eval_freq == 0 or epoch + 1 == args.epochs:
+            test_stats = evaluate(data_loader_val, model, device)
+            print(
+                f"Clustering performance on {len(dataset_val)} test images: NMI={test_stats['nmi']:.2f}%, ACC={test_stats['acc']:.2f}%, ARI={test_stats['ari']:.2f}%"
+            )
+            max_accuracy = max(max_accuracy, test_stats["acc"])
+            print(f"Max accuracy: {max_accuracy:.2f}%")
+
+            log_stats = {
+                **{f"train_{k}": v for k, v in train_stats.items()},
+                **{f"test_{k}": v for k, v in test_stats.items()},
+                "epoch": epoch,
+                "n_parameters": n_parameters,
+            }
+
+            if args.output_dir and misc.is_main_process():
+                with open(
+                    os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
+                ) as f:
+                    f.write(json.dumps(log_stats) + "\n")
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print("Training time {}".format(total_time_str))
+
+
+if __name__ == "__main__":
+    args = get_args_parser()
+    args = args.parse_args()
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    main(args)
diff --git a/transforms.py b/transforms.py
new file mode 100644
index 0000000..3f12899
--- /dev/null
+++ b/transforms.py
@@ -0,0 +1,271 @@
+# List of augmentations based on RandAugment
+import random
+from PIL import Image, ImageEnhance, ImageFilter, ImageOps
+import numpy as np
+import torch
+from torchvision import transforms
+
+random_mirror = True
+
+
+def ShearX(img, v):
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0))
+
+
+def ShearY(img, v):
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0))
+
+
+def Identity(img, v):
+    return img
+
+
+def TranslateX(img, v):
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    v = v * img.size[0]
+    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateY(img, v):
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    v = v * img.size[1]
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def TranslateXAbs(img, v):
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateYAbs(img, v):
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def Rotate(img, v):
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.rotate(v)
+
+
+def AutoContrast(img, _):
+    return ImageOps.autocontrast(img)
+
+
+def Invert(img, _):
+    return ImageOps.invert(img)
+
+
+def Equalize(img, _):
+    return ImageOps.equalize(img)
+
+
+def Solarize(img, v):
+    return ImageOps.solarize(img, v)
+
+
+def Posterize(img, v):
+    v = int(v)
+    return ImageOps.posterize(img, v)
+
+
+def Contrast(img, v):
+    return ImageEnhance.Contrast(img).enhance(v)
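+
+
+# The ImageEnhance-based ops (Contrast above, plus Color, Brightness, and
+# Sharpness below) follow PIL's enhance(v) convention: v = 1.0 returns the
+# original image and smaller values weaken the attribute. augment_list()
+# samples v from [0.05, 0.95] for these ops, i.e., it only ever weakens them.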
+
+
+def Color(img, v):
+    return ImageEnhance.Color(img).enhance(v)
+
+
+def Brightness(img, v):
+    return ImageEnhance.Brightness(img).enhance(v)
+
+
+def Sharpness(img, v):
+    return ImageEnhance.Sharpness(img).enhance(v)
+
+
+def augment_list():
+    ops = [
+        (Identity, 0, 1),
+        (AutoContrast, 0, 1),
+        (Equalize, 0, 1),
+        (Rotate, -30, 30),
+        (Solarize, 0, 256),
+        (Color, 0.05, 0.95),
+        (Contrast, 0.05, 0.95),
+        (Brightness, 0.05, 0.95),
+        (Sharpness, 0.05, 0.95),
+        (ShearX, -0.1, 0.1),
+        (TranslateX, -0.1, 0.1),
+        (TranslateY, -0.1, 0.1),
+        (Posterize, 4, 8),
+        (ShearY, -0.1, 0.1),
+    ]
+    return ops
+
+
+augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
+
+
+class AutoAugment:
+    def __init__(self, n):
+        self.n = n
+        self.augment_list = augment_list()
+
+    def __call__(self, img):
+        # Sample n ops (with replacement) and apply each at a random strength
+        # within its (minval, maxval) range.
+        ops = random.choices(self.augment_list, k=self.n)
+        for op, minval, maxval in ops:
+            val = random.random() * float(maxval - minval) + minval
+            img = op(img, val)
+
+        return img
+
+
+def get_augment(name):
+    return augment_dict[name]
+
+
+def apply_augment(img, name, level):
+    augment_fn, low, high = get_augment(name)
+    return augment_fn(img.copy(), level * (high - low) + low)
+
+
+class Cutout(object):
+    def __init__(self, n_holes, length):
+        self.n_holes = n_holes
+        self.length = length
+
+    def __call__(self, img):
+        h = img.size(1)
+        w = img.size(2)
+        length = random.randint(1, self.length)
+        mask = np.ones((h, w), np.float32)
+
+        for n in range(self.n_holes):
+            y = np.random.randint(h)
+            x = np.random.randint(w)
+
+            y1 = np.clip(y - length // 2, 0, h)
+            y2 = np.clip(y + length // 2, 0, h)
+            x1 = np.clip(x - length // 2, 0, w)
+            x2 = np.clip(x + length // 2, 0, w)
+
+            mask[y1:y2, x1:x2] = 0.0
+
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img = img * mask
+
+        return img
+
+
+class GaussianBlur(object):
+    """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709"""
+
+    def __init__(self, sigma=[0.1, 2.0]):
+        self.sigma = sigma
+
+    def __call__(self, x):
+        sigma = random.uniform(self.sigma[0], self.sigma[1])
+        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
+        return x
+
+
+class Augmentation:
+    def __init__(
+        self,
+        img_size=224,
+        val_img_size=256,
+        s=1,
+        num_aug=4,
+        cutout_holes=1,
+        cutout_size=75,
+        blur=1.0,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+    ):
+        self.weak_aug = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    img_size, interpolation=Image.BICUBIC, scale=(0.2, 1.0)
+                ),
+                transforms.RandomHorizontalFlip(),
+                transforms.RandomApply(
+                    [transforms.ColorJitter(0.8 * s, 0.8 * s, 0.4 * s, 0.2 * s)], p=0.8
+                ),
+                transforms.RandomGrayscale(p=0.2),
+                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=blur),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=mean, std=std),
+            ]
+        )
+        self.strong_aug = transforms.Compose(
+            [
+                transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
+                transforms.RandomHorizontalFlip(),
+                AutoAugment(n=num_aug),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=mean, std=std),
+                Cutout(n_holes=cutout_holes, length=cutout_size),
+            ]
+        )
+        self.val_aug = transforms.Compose(
+            [
+                transforms.Resize(
+                    (val_img_size, val_img_size), interpolation=Image.BICUBIC
+                ),
+                transforms.CenterCrop(img_size),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=mean, std=std),
+            ]
+        )
+
+    def __call__(self, x):
+        return self.weak_aug(x), self.strong_aug(x), self.val_aug(x)
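+
+
+# A minimal usage sketch (illustrative, not part of the training pipeline):
+# each call on a PIL image yields three views, a weak view and a strong
+# (AutoAugment + Cutout) view for the twin contrastive training, plus a
+# deterministic view for evaluation:
+#
+#   aug = Augmentation(img_size=224, val_img_size=256)
+#   weak, strong, val = aug(pil_image)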
== "CIFAR-10": + augmentation = Augmentation( + img_size=256, + val_img_size=224, + s=0.5, + num_aug=4, + cutout_holes=1, + cutout_size=75, + mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010], + ) + elif args.dataset == "CIFAR-100": + augmentation = Augmentation( + img_size=256, + val_img_size=224, + s=0.5, + num_aug=4, + cutout_holes=1, + cutout_size=75, + mean=[0.5071, 0.4867, 0.4408], + std=[0.2675, 0.2565, 0.2761], + ) + elif args.dataset == "ImageNet-10" or args.dataset == "ImageNet": + augmentation = Augmentation( + img_size=256, + val_img_size=224, + s=0.5, + num_aug=4, + cutout_holes=1, + cutout_size=75, + blur=0.5, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) + return augmentation if is_train else augmentation.val_aug