import argparse
import time
import datetime
import misc
import numpy as np
import os
import torch
import torch.backends.cudnn as cudnn
from pathlib import Path
from data import build_dataset
from model import get_resnet, Network
from misc import NativeScalerWithGradNormCount as NativeScaler
from loss import InstanceLossBoost, ClusterLossBoost
from engine import boost_one_epoch, evaluate
import json


def get_args_parser():
    """Build the argparse parser for the TCL boosting stage."""
    parser = argparse.ArgumentParser("TCL", add_help=False)
    parser.add_argument(
        "--batch_size", default=256, type=int, help="Batch size per GPU"
    )
    parser.add_argument("--epochs", default=200, type=int)

    # Model parameters
    parser.add_argument(
        "--model",
        default="resnet34",
        type=str,
        metavar="MODEL",
        choices=["resnet50", "resnet34", "resnet18"],
        help="Name of model to train",
    )
    parser.add_argument("--feat_dim", default=128, type=int, help="dimension of ICH")
    parser.add_argument(
        "--ins_temp",
        default=0.5,
        type=float,
        help="temperature of instance-level contrastive loss",
    )
    parser.add_argument(
        "--clu_temp",
        default=1.0,
        type=float,
        help="temperature of cluster-level contrastive loss",
    )

    # Optimizer parameters
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay")
    parser.add_argument(
        "--lr", type=float, default=1e-4, metavar="LR", help="learning rate (absolute lr)"
    )

    # Dataset parameters
    parser.add_argument("--data_path", default="./datasets/", type=str, help="dataset path")
    parser.add_argument(
        "--dataset",
        default="CIFAR-10",
        type=str,
        help="dataset",
        choices=["CIFAR-10", "CIFAR-100", "ImageNet-10", "ImageNet"],
    )
    parser.add_argument("--nb_cluster", default=10, type=int, help="number of the clusters")
    parser.add_argument(
        "--output_dir", default="./save/", help="path where to save, empty for no saving"
    )
    parser.add_argument("--device", default="cuda", help="device to use for training / testing")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument(
        "--resume", default="./save/checkpoint-0.pth", help="resume from checkpoint"
    )
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--save_freq", default=20, type=int, help="saving frequency")
    parser.add_argument("--eval_freq", default=10, type=int, help="evaluation frequency")
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument(
        "--dist_eval",
        action="store_true",
        default=False,
        help="Enabling distributed evaluation (recommended during training for faster monitor",
    )

    # distributed training parameters
    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )

    return parser


def main(args):
    """Run the TCL boosting stage: resume a pre-trained model, then alternate
    pseudo-label generation and boosted contrastive training.

    NOTE(review): relies on `misc.init_distributed_mode` to set `args.gpu` and
    `args.distributed` — confirm against misc.py.
    """
    misc.init_distributed_mode(args)

    print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(", ", ",\n"))

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so workers differ)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_train = build_dataset(type="train", args=args)
    dataset_pseudo = build_dataset(type="pseudo", args=args)
    dataset_val = build_dataset(type="val", args=args)

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    sampler_pseudo = torch.utils.data.DistributedSampler(
        dataset_pseudo, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print(
                "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. "
                "This will slightly alter validation results as extra duplicate entries are added to achieve "
                "equal num of samples per-process."
            )
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )  # shuffle=True to reduce monitor bias
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_ps = torch.utils.data.DataLoader(
        dataset_pseudo,
        sampler=sampler_pseudo,
        batch_size=1000,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    backbone, hidden_dim = get_resnet(args)
    model = Network(backbone, hidden_dim, args.feat_dim, args.nb_cluster)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        print("Load pre-trained checkpoint from: %s" % args.resume)
        checkpoint_model = checkpoint["model"]
        # load pre-trained model (strict=False: projector heads may be new)
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print(msg)

    model.to(device)

    # One pass over the val set to sanity-check the resumed features.
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = "Test:"
    model.eval()  # switch to evaluation mode

    feat_vector = []
    labels_vector = []
    for (images, labels, _) in metric_logger.log_every(data_loader_val, 20, header):
        images = images.to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            feat, c = model.forward_zc(images)
            c = torch.argmax(c, dim=1)
        feat_vector.extend(feat.cpu().detach().numpy())
        labels_vector.extend(labels.numpy())
    feat_vector = np.array(feat_vector)
    labels_vector = np.array(labels_vector)
    print("Feat shape {}, Label shape {}".format(feat_vector.shape, labels_vector.shape))

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model = %s" % str(model_without_ddp))
    print("number of params (M): %.2f" % (n_parameters / 1.0e6))

    eff_batch_size = args.batch_size * misc.get_world_size()
    print("base lr: %.3e" % args.lr)
    print("effective batch size: %d" % eff_batch_size)

    optimizer = torch.optim.Adam(
        [
            {"params": model.resnet.parameters(), "lr": args.lr},
            {"params": model.instance_projector.parameters(), "lr": args.lr},
            {"params": model.cluster_projector.parameters(), "lr": args.lr},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module

    loss_scaler = NativeScaler()
    criterion_ins = InstanceLossBoost(
        tau=args.ins_temp, distributed=True, alpha=0.99, gamma=0.5
    )
    criterion_clu = ClusterLossBoost(distributed=True, cluster_num=args.nb_cluster)

    misc.load_model(
        args=args,
        model_without_ddp=model_without_ddp,
        optimizer=optimizer,
        loss_scaler=loss_scaler,
    )

    print(f"Start training for {args.epochs} epochs")
    # -1 marks "no pseudo label assigned yet" for every training sample.
    pseudo_labels = -torch.ones(dataset_train.__len__(), dtype=torch.long)
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats, pseudo_labels = boost_one_epoch(
            model,
            criterion_ins,
            criterion_clu,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            pseudo_labels,
            args=args,
        )
        if args.output_dir and (epoch % args.save_freq == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args,
                model=model,
                model_without_ddp=model_without_ddp,
                optimizer=optimizer,
                loss_scaler=loss_scaler,
                epoch=epoch,
            )
        if epoch % args.eval_freq == 0 or epoch + 1 == args.epochs:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Clustering performance on the {len(dataset_val)} test images: NMI={test_stats['nmi']:.2f}%, ACC={test_stats['acc']:.2f}%, ARI={test_stats['ari']:.2f}%"
            )
            max_accuracy = max(max_accuracy, test_stats["acc"])
            print(f"Max accuracy: {max_accuracy:.2f}%")

        if epoch == args.start_epoch:
            # first epoch only warms up pseudo labels; log a placeholder
            test_stats = {"pred_num": 1000}

        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            **{f"test_{k}": v for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }
        if args.output_dir and misc.is_main_process():
            with open(
                os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
            ) as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)


# ---------------------------------------------------------------- data.py ----
import torchvision
from typing import Any, Callable, Optional
from PIL import Image
from torchvision.datasets.folder import default_loader
from transforms import build_transform
from torch.utils import data


class CIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 that also returns the sample index from ``__getitem__``.

    Test-split indices are offset by 50000 (the train-split size) so that the
    index is unique across a train+test ``ConcatDataset``.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        super(CIFAR10, self).__init__(
            root, train, transform, target_transform, download
        )
        self.train = train

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.train:
            return img, target, index
        else:
            # offset by the CIFAR train-set size so indices stay unique
            return img, target, index + 50000


class CIFAR100(CIFAR10):
    """CIFAR-100 Dataset; subclass of :class:`CIFAR10` with swapped metadata."""

    base_folder = "cifar-100-python"
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    train_list = [
        ["train", "16019d7e3df5f24257cddd939b257f8d"],
    ]
    test_list = [
        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
    ]
    meta = {
        "filename": "meta",
        "key": "fine_label_names",
        "md5": "7973b15100ade9c7d40fb424638fde48",
    }


IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)


class ImageFolder(torchvision.datasets.DatasetFolder):
    """Generic class-per-subdirectory image loader that also returns the index.

    Layout: ``root/<class_name>/<image file>``.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super(ImageFolder, self).__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target, index) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target, index


def build_dataset(type, args):
    """Build the dataset selected by ``args.dataset``.

    Args:
        type (str): split role; only ``"train"`` selects the training transform.
        args: namespace with ``dataset`` and ``data_path``.

    Returns:
        A torch ``Dataset`` (CIFAR splits are concatenated train+test).

    Raises:
        ValueError: if ``args.dataset`` is not a recognised name.
    """
    is_train = type == "train"
    transform = build_transform(is_train, args)
    root = args.data_path

    if args.dataset == "CIFAR-10":
        dataset = data.ConcatDataset(
            [
                CIFAR10(root=root + "CIFAR-10", train=True, download=True, transform=transform),
                CIFAR10(root=root + "CIFAR-10", train=False, download=True, transform=transform),
            ]
        )
    elif args.dataset == "CIFAR-100":
        dataset = data.ConcatDataset(
            [
                CIFAR100(root=root + "CIFAR-100", train=True, download=True, transform=transform),
                CIFAR100(root=root + "CIFAR-100", train=False, download=True, transform=transform),
            ]
        )
    elif args.dataset == "ImageNet-10":
        dataset = ImageFolder(root=root + "ImageNet-10", transform=transform)
    elif args.dataset == "ImageNet":
        dataset = ImageFolder(root=root + "ImageNet/train", transform=transform)
    else:
        # BUGFIX: previously fell through with `dataset` unbound and crashed
        # with NameError on the print below; fail with a clear message instead.
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    print(dataset)

    return dataset


# -------------------------------------------------------------- engine.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import math
import sys

import torch
import torch.nn.functional as F

import misc
import numpy as np
from evaluate import cluster_metric


def train_one_epoch(
    model,
    criterion_ins,
    criterion_clu,
    data_loader,
    optimizer,
    device,
    epoch,
    loss_scaler,
    args,
):
    """One warm-up epoch of twin contrastive training (no pseudo labels).

    Returns a dict of the epoch-averaged meter values.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    optimizer.zero_grad()

    for data_iter_step, ((x_w, x_s, x), _, index) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        x_w = x_w.to(device, non_blocking=True)
        x_s = x_s.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            z_i, z_j, c_i, c_j = model(x_w, x_s)
            c_i = F.softmax(c_i, dim=1)
            c_j = F.softmax(c_j, dim=1)
            loss_ins = criterion_ins(torch.concat((z_i, z_j), dim=0))
            loss_clu = criterion_clu(torch.concat((c_i, c_j), dim=0))
            loss = loss_ins + loss_clu

        loss_ins_value = loss_ins.item()
        loss_clu_value = loss_clu.item()

        # abort on NaN/inf rather than silently diverging
        if not math.isfinite(loss_ins_value) or not math.isfinite(loss_clu_value):
            print("Loss is {}, {}, stopping training".format(loss_ins_value, loss_clu_value))
            sys.exit(1)

        loss_scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            create_graph=False,
            update_grad=True,
        )
        optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss_ins=loss_ins_value)
        metric_logger.update(loss_clu=loss_clu_value)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    """Cluster the whole val loader and report NMI/ACC/ARI (in percent)."""
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = "Test:"

    # switch to evaluation mode
    model.eval()

    pred_vector = []
    labels_vector = []
    for (images, labels, _) in metric_logger.log_every(data_loader, 20, header):
        images = images.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            preds = model.module.forward_c(images)
            preds = torch.argmax(preds, dim=1)

        pred_vector.extend(preds.cpu().detach().numpy())
        labels_vector.extend(labels.numpy())
    pred_vector = np.array(pred_vector)
    labels_vector = np.array(labels_vector)
    print("Pred shape {}, Label shape {}".format(pred_vector.shape, labels_vector.shape))

    nmi, ari, acc = cluster_metric(labels_vector, pred_vector)
    print(nmi, ari, acc)

    metric_logger.meters["nmi"].update(nmi)
    metric_logger.meters["acc"].update(acc)
    metric_logger.meters["ari"].update(ari)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def boost_one_epoch(
    model,
    criterion_ins,
    criterion_clu,
    data_loader,
    optimizer,
    device,
    epoch,
    loss_scaler,
    pseudo_labels,
    args,
):
    """One boosting epoch: refresh pseudo labels, then train on them.

    At ``args.start_epoch`` only pseudo labels are collected (no updates).
    Returns (epoch stats dict, updated pseudo_labels tensor).
    """
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    optimizer.zero_grad()
    for data_iter_step, ((x_w, x_s, x), _, index) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        x_w = x_w.to(device, non_blocking=True)
        x_s = x_s.to(device, non_blocking=True)
        x = x.to(device, non_blocking=True)

        # refresh pseudo labels from the current cluster posteriors
        model.eval()
        with torch.cuda.amp.autocast(), torch.no_grad():
            _, _, c = model(x, x, return_ci=False)
            c = F.softmax(c / args.clu_temp, dim=1)
            pseudo_labels_cur, index_cur = criterion_ins.generate_pseudo_labels(
                c, pseudo_labels[index].to(c.device), index.to(c.device)
            )
            pseudo_labels[index_cur] = pseudo_labels_cur
        pseudo_index = pseudo_labels != -1
        metric_logger.update(pseudo_num=pseudo_index.sum().item())
        metric_logger.update(
            pseudo_cluster=torch.unique(pseudo_labels[pseudo_index]).shape[0]
        )
        if epoch == args.start_epoch:
            # warm-up epoch: only collect pseudo labels, skip the update
            continue

        model.train(True)
        with torch.cuda.amp.autocast():
            z_i, z_j, c_j = model(x_w, x_s, return_ci=False)
            loss_ins = criterion_ins(
                torch.concat((z_i, z_j), dim=0), pseudo_labels[index].to(x_s.device)
            )
            loss_clu = criterion_clu(c_j, pseudo_labels[index].to(x_s.device))
            loss = loss_ins + loss_clu

        loss_ins_value = loss_ins.item()
        loss_clu_value = loss_clu.item()

        if not math.isfinite(loss_ins_value) or not math.isfinite(loss_clu_value):
            print("Loss is {}, {}, stopping training".format(loss_ins_value, loss_clu_value))
            sys.exit(1)
        loss_scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            create_graph=False,
            update_grad=True,
        )
        optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss_ins=loss_ins_value)
        metric_logger.update(loss_clu=loss_clu_value)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return (
        {k: meter.global_avg for k, meter in metric_logger.meters.items()},
        pseudo_labels,
    )


# ------------------------------------------------------------ evaluate.py ----
from sklearn import metrics
from munkres import Munkres
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment


def reorder_preds(predictions, targets):
    """Relabel predicted cluster ids to best match the ground-truth ids.

    Uses a Hungarian one-to-one assignment between clusters and labels.

    BUGFIX: the cluster count was hard-coded to 1000 and the tensors were
    forced onto CUDA (crashing on CPU-only hosts). The count is now derived
    from the data; padding with empty clusters cannot change the optimal
    assignment of the populated ones, so results are unchanged.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    predictions = torch.from_numpy(predictions).to(device)
    targets = torch.from_numpy(targets).to(device)
    num_k = int(max(predictions.max().item(), targets.max().item())) + 1
    match = _hungarian_match(predictions, targets, preds_k=num_k, targets_k=num_k)
    reordered_preds = torch.zeros(
        predictions.shape[0], dtype=predictions.dtype, device=predictions.device
    )
    for pred_i, target_i in match:
        reordered_preds[predictions == int(pred_i)] = int(target_i)
    return reordered_preds.cpu().numpy()
def _hungarian_match(flat_preds, flat_targets, preds_k, targets_k):
    """Optimal one-to-one matching between predicted and target cluster ids.

    Based on the IIC implementation. Returns a list of (pred_id, target_id)
    tuples maximising the number of agreeing samples.
    """
    num_samples = flat_targets.shape[0]

    assert preds_k == targets_k  # one to one
    num_k = preds_k

    # vote matrix: num_correct[c1, c2] = #samples with pred c1 and target c2
    num_correct = np.zeros((num_k, num_k))
    for c1 in range(num_k):
        for c2 in range(num_k):
            # elementwise, so each sample contributes once
            num_correct[c1, c2] = int(((flat_preds == c1) * (flat_targets == c2)).sum())

    # num_correct is small; minimise "missed" samples = maximise agreement
    row_ind, col_ind = linear_sum_assignment(num_samples - num_correct)

    # return as list of tuples, out_c to gt_c
    return [(out_c, gt_c) for out_c, gt_c in zip(row_ind, col_ind)]


def cluster_metric(label, pred):
    """Return (NMI, ARI, ACC) in percent for a clustering vs. ground truth."""
    nmi = metrics.normalized_mutual_info_score(label, pred)
    ari = metrics.adjusted_rand_score(label, pred)
    # pred_adjusted = get_y_preds(label, pred, len(set(label)))
    pred_adjusted = reorder_preds(pred, label)
    acc = metrics.accuracy_score(pred_adjusted, label)
    return nmi * 100, ari * 100, acc * 100


def calculate_cost_matrix(C, n_clusters):
    """Build a Munkres cost matrix from a confusion matrix ``C``.

    cost_matrix[i, j] is the cost of assigning cluster i to label j
    (samples in the cluster that do NOT carry that label).
    """
    cost_matrix = np.zeros((n_clusters, n_clusters))
    for j in range(n_clusters):
        cluster_size = np.sum(C[:, j])  # number of examples in cluster j
        for i in range(n_clusters):
            cost_matrix[j, i] = cluster_size - C[i, j]
    return cost_matrix


def get_cluster_labels_from_indices(indices):
    """Flatten Munkres (cluster, label) pairs into a label lookup array."""
    n_clusters = len(indices)
    cluster_labels = np.zeros(n_clusters)
    for i in range(n_clusters):
        cluster_labels[i] = indices[i][1]
    return cluster_labels


def get_y_preds(y_true, cluster_assignments, n_clusters):
    """
    Computes the predicted labels, where label assignments now
    correspond to the actual labels in y_true (as estimated by Munkres)
    cluster_assignments: array of labels, outputted by kmeans
    y_true: true labels
    n_clusters: number of clusters in the dataset
    returns: a tuple containing the accuracy and confusion matrix,
    in that order
    """
    confusion_matrix = metrics.confusion_matrix(y_true, cluster_assignments, labels=None)
    # compute accuracy based on optimal 1:1 assignment of clusters to labels
    cost_matrix = calculate_cost_matrix(confusion_matrix, n_clusters)
    indices = Munkres().compute(cost_matrix)
    kmeans_to_true_cluster_labels = get_cluster_labels_from_indices(indices)

    if np.min(cluster_assignments) != 0:
        cluster_assignments = cluster_assignments - np.min(cluster_assignments)
    y_pred = kmeans_to_true_cluster_labels[cluster_assignments]
    return y_pred


# ---------------------------------------------------------------- loss.py ----
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import diffdist
import torch.distributed as dist


def gather(z):
    """Differentiable all-gather of ``z`` across all ranks, concatenated."""
    gather_z = [torch.zeros_like(z) for _ in range(torch.distributed.get_world_size())]
    gather_z = diffdist.functional.all_gather(gather_z, z)
    gather_z = torch.cat(gather_z)
    return gather_z


def accuracy(logits, labels, k):
    """1.0 where the sorted top-k indices exactly equal the sorted labels."""
    topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0]
    labels = torch.sort(labels, 1)[0]
    acc = (topk == labels).all(1).float()
    return acc


def mean_cumulative_gain(logits, labels, k):
    """Fraction of sorted top-k indices that coincide with the sorted labels."""
    topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0]
    labels = torch.sort(labels, 1)[0]
    mcg = (topk == labels).float().mean(1)
    return mcg


def mean_average_precision(logits, labels, k):
    # TODO: not the fastest solution but looks fine
    argsort = torch.argsort(logits, dim=1, descending=True)
    labels_to_sorted_idx = (
        torch.sort(torch.gather(torch.argsort(argsort, dim=1), 1, labels), dim=1)[0] + 1
    )
    precision = (
        1 + torch.arange(k, device=logits.device).float()
    ) / labels_to_sorted_idx
    return precision.sum(1) / k


class InstanceLoss(nn.Module):
    """
    Contrastive loss with distributed data parallel support
    """

    LARGE_NUMBER = 1e4
distributed=False): + super().__init__() + self.tau = tau + self.multiplier = multiplier + self.distributed = distributed + + def forward(self, z, get_map=False): + n = z.shape[0] + assert n % self.multiplier == 0 + + z = z / np.sqrt(self.tau) + + if self.distributed: + z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())] + # all_gather fills the list as [, , ...] + # TODO: try to rewrite it with pytorch official tools + z_list = diffdist.functional.all_gather(z_list, z) + # split it into [, , ..., , , ...] + z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)] + # sort it to [, , ...] that simply means [, , ...] as expected below + z_sorted = [] + for m in range(self.multiplier): + for i in range(dist.get_world_size()): + z_sorted.append(z_list[i * self.multiplier + m]) + z = torch.cat(z_sorted, dim=0) + n = z.shape[0] + + logits = z @ z.t() + logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER + + logprob = F.log_softmax(logits, dim=1) + + # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1) + m = self.multiplier + labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n // m, n)) % n + # remove labels pointet to itself, i.e. (i, i) + labels = labels.reshape(n, m)[:, 1:].reshape(-1) + + loss = -logprob[np.repeat(np.arange(n), m - 1), labels].sum() / n / (m - 1) + + return loss + + +class ClusterLoss(nn.Module): + """ + Contrastive loss with distributed data parallel support + """ + + LARGE_NUMBER = 1e4 + + def __init__(self, tau=1.0, multiplier=2, distributed=False): + super().__init__() + self.tau = tau + self.multiplier = multiplier + self.distributed = distributed + + def forward(self, c, get_map=False): + n = c.shape[0] + assert n % self.multiplier == 0 + + # c = c / np.sqrt(self.tau) + + if self.distributed: + c_list = [torch.zeros_like(c) for _ in range(dist.get_world_size())] + # all_gather fills the list as [, , ...] 
+ c_list = diffdist.functional.all_gather(c_list, c) + # split it into [, , ..., , , ...] + c_list = [chunk for x in c_list for chunk in x.chunk(self.multiplier)] + # sort it to [, , ...] that simply means [, , ...] as expected below + c_sorted = [] + for m in range(self.multiplier): + for i in range(dist.get_world_size()): + c_sorted.append(c_list[i * self.multiplier + m]) + c_aug0 = torch.cat( + c_sorted[: int(self.multiplier * dist.get_world_size() / 2)], dim=0 + ) + c_aug1 = torch.cat( + c_sorted[int(self.multiplier * dist.get_world_size() / 2) :], dim=0 + ) + + p_i = c_aug0.sum(0).view(-1) + p_i /= p_i.sum() + en_i = np.log(p_i.size(0)) + (p_i * torch.log(p_i)).sum() + p_j = c_aug1.sum(0).view(-1) + p_j /= p_j.sum() + en_j = np.log(p_j.size(0)) + (p_j * torch.log(p_j)).sum() + en_loss = en_i + en_j + + c = torch.cat((c_aug0.t(), c_aug1.t()), dim=0) + n = c.shape[0] + + c = F.normalize(c, p=2, dim=1) / np.sqrt(self.tau) + + logits = c @ c.t() + logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER + + logprob = F.log_softmax(logits, dim=1) + + # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1) + m = self.multiplier + labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n // m, n)) % n + # remove labels pointet to itself, i.e. 
(i, i) + labels = labels.reshape(n, m)[:, 1:].reshape(-1) + + loss = -logprob[np.repeat(np.arange(n), m - 1), labels].sum() / n / (m - 1) + + return loss + en_loss + + +class InstanceLossBoost(nn.Module): + """ + Contrastive loss with distributed data parallel support + """ + + LARGE_NUMBER = 1e4 + + def __init__( + self, + tau=0.5, + multiplier=2, + distributed=False, + alpha=0.9, + gamma=0.1, + cluster_num=10, + ): + super().__init__() + self.tau = tau + self.multiplier = multiplier + self.distributed = distributed + self.alpha = alpha + self.gamma = gamma + self.cluster_num = cluster_num + + @torch.no_grad() + def generate_pseudo_labels(self, c, pseudo_label_cur, index): + if self.distributed: + c_list = [torch.zeros_like(c) for _ in range(dist.get_world_size())] + pseudo_label_cur_list = [torch.zeros_like(pseudo_label_cur) for _ in range(dist.get_world_size())] + index_list = [torch.zeros_like(index) for _ in range(dist.get_world_size())] + # all_gather fills the list as [, , ...] + c_list = diffdist.functional.all_gather(c_list, c) + pseudo_label_cur_list = diffdist.functional.all_gather(pseudo_label_cur_list, pseudo_label_cur) + index_list = diffdist.functional.all_gather(index_list, index) + c = torch.cat(c_list, dim=0,) + pseudo_label_cur = torch.cat(pseudo_label_cur_list, dim=0,) + index = torch.cat(index_list, dim=0,) + batch_size = c.shape[0] + device = c.device + pseudo_label_nxt = -torch.ones(batch_size, dtype=torch.long).to(device) + tmp = torch.arange(0, batch_size).to(device) + + prediction = c.argmax(dim=1) + confidence = c.max(dim=1).values + unconfident_pred_index = confidence < self.alpha + pseudo_per_class = np.ceil(batch_size / self.cluster_num * self.gamma).astype( + int + ) + for i in range(self.cluster_num): + class_idx = prediction == i + if class_idx.sum() == 0: + continue + confidence_class = confidence[class_idx] + num = min(confidence_class.shape[0], pseudo_per_class) + confident_idx = torch.argsort(-confidence_class) + for j in 
range(num): + idx = tmp[class_idx][confident_idx[j]] + pseudo_label_nxt[idx] = i + + todo_index = pseudo_label_cur == -1 + pseudo_label_cur[todo_index] = pseudo_label_nxt[todo_index] + pseudo_label_nxt = pseudo_label_cur + pseudo_label_nxt[unconfident_pred_index] = -1 + return pseudo_label_nxt.cpu(), index + + def forward(self, z, pseudo_label): + n = z.shape[0] + assert n % self.multiplier == 0 + + if self.distributed: + z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())] + pseudo_label_list = [ + torch.zeros_like(pseudo_label) for _ in range(dist.get_world_size()) + ] + # all_gather fills the list as [, , ...] + z_list = diffdist.functional.all_gather(z_list, z) + pseudo_label_list = diffdist.functional.all_gather( + pseudo_label_list, pseudo_label + ) + # split it into [, , ..., , , ...] + z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)] + pseudo_label_list = [ + chunk for x in pseudo_label_list for chunk in x.chunk(self.multiplier) + ] + # sort it to [, , ...] that simply means [, , ...] 
as expected below + z_sorted = [] + pesudo_label_sorted = [] + for m in range(self.multiplier): + for i in range(dist.get_world_size()): + z_sorted.append(z_list[i * self.multiplier + m]) + pesudo_label_sorted.append( + pseudo_label_list[i * self.multiplier + m] + ) + z_i = torch.cat( + z_sorted[: int(self.multiplier * dist.get_world_size() / 2)], dim=0 + ) + z_j = torch.cat( + z_sorted[int(self.multiplier * dist.get_world_size() / 2) :], dim=0 + ) + pseudo_label = torch.cat(pesudo_label_sorted, dim=0,) + n = z_i.shape[0] + + invalid_index = pseudo_label == -1 + mask = torch.eq(pseudo_label.view(-1, 1), pseudo_label.view(1, -1)).to( + z_i.device + ) + mask[invalid_index, :] = False + mask[:, invalid_index] = False + mask_eye = torch.eye(n).float().to(z_i.device) + mask &= ~(mask_eye.bool()) + mask = mask.float() + + contrast_count = self.multiplier + contrast_feature = torch.cat((z_i, z_j), dim=0) + + anchor_feature = contrast_feature + anchor_count = contrast_count + + # compute logits + anchor_dot_contrast = torch.div( + torch.matmul(anchor_feature, contrast_feature.T), self.tau + ) + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + # mask_with_eye = mask | mask_eye.bool() + # mask = torch.cat(mask) + mask = mask.repeat(anchor_count, contrast_count) + mask_eye = mask_eye.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter( + torch.ones_like(mask), + 1, + torch.arange(n * anchor_count).view(-1, 1).to(z_i.device), + 0, + ) + logits_mask *= 1 - mask + mask_eye = mask_eye * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + mean_log_prob_pos = (mask_eye * log_prob).sum(1) / mask_eye.sum(1) + + # loss + instance_loss = -mean_log_prob_pos + instance_loss = 
instance_loss.view(anchor_count, n).mean() + + return instance_loss + + +class ClusterLossBoost(nn.Module): + """ + Contrastive loss with distributed data parallel support + """ + + LARGE_NUMBER = 1e4 + + def __init__(self, multiplier=1, distributed=False, cluster_num=10): + super().__init__() + self.multiplier = multiplier + self.distributed = distributed + self.cluster_num = cluster_num + + def forward(self, c, pseudo_label): + if self.distributed: + # c_list = [torch.zeros_like(c) for _ in range(dist.get_world_size())] + pesudo_label_list = [ + torch.zeros_like(pseudo_label) for _ in range(dist.get_world_size()) + ] + # all_gather fills the list as [, , ...] + # c_list = diffdist.functional.all_gather(c_list, c) + pesudo_label_list = diffdist.functional.all_gather( + pesudo_label_list, pseudo_label + ) + # split it into [, , ..., , , ...] + # c_list = [chunk for x in c_list for chunk in x.chunk(self.multiplier)] + pesudo_label_list = [ + chunk for x in pesudo_label_list for chunk in x.chunk(self.multiplier) + ] + # sort it to [, , ...] that simply means [, , ...] 
as expected below + # c_sorted = [] + pesudo_label_sorted = [] + for m in range(self.multiplier): + for i in range(dist.get_world_size()): + # c_sorted.append(c_list[i * self.multiplier + m]) + pesudo_label_sorted.append( + pesudo_label_list[i * self.multiplier + m] + ) + # c = torch.cat(c_sorted, dim=0) + pesudo_label_all = torch.cat(pesudo_label_sorted, dim=0) + pseudo_index = pesudo_label_all != -1 + pesudo_label_all = pesudo_label_all[pseudo_index] + idx, counts = torch.unique(pesudo_label_all, return_counts=True) + freq = pesudo_label_all.shape[0] / counts.float() + weight = torch.ones(self.cluster_num).to(c.device) + weight[idx] = freq + pseudo_index = pseudo_label != -1 + if pseudo_index.sum() > 0: + criterion = nn.CrossEntropyLoss(weight=weight).to(c.device) + loss_ce = criterion( + c[pseudo_index], pseudo_label[pseudo_index].to(c.device) + ) + else: + loss_ce = torch.tensor(0.0, requires_grad=True).to(c.device) + return loss_ce diff --git a/misc.py b/misc.py new file mode 100644 index 0000000..642d7bc --- /dev/null +++ b/misc.py @@ -0,0 +1,381 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch._six import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def 
synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] 
".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.dist_url = "tcp://%s:%s" % ( + os.environ["MASTER_ADDR"], + os.environ["MASTER_PORT"], + ) + os.environ["LOCAL_RANK"] = str(args.gpu) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + 
setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + # "encoder_optimizer": encoder_optimizer.state_dict(), + "epoch": epoch, 
+ "scaler": loss_scaler.state_dict(), + "args": args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {"epoch": epoch} + model.save_checkpoint( + save_dir=args.output_dir, + tag="checkpoint-%s" % epoch_name, + client_state=client_state, + ) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + print("Resume checkpoint %s" % args.resume) + if "optimizer" in checkpoint and "epoch" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer"]) + args.start_epoch = checkpoint["epoch"] + 1 + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + print("With optim!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x diff --git a/model.py b/model.py new file mode 100644 index 0000000..e8fbffb --- /dev/null +++ b/model.py @@ -0,0 +1,211 @@ +from torch import nn +from torch.nn import functional as F +from torchvision.models.resnet import Bottleneck, BasicBlock, conv1x1 +import torch +from timm.models.layers import trunc_normal_ + + +class ResNet(nn.Module): + def __init__( + self, + block, + layers, + in_channel=3, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 
3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + in_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.rep_dim = 512 * block.expansion + self.fc = nn.Linear(self.rep_dim, self.rep_dim) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # for name, param in self.named_parameters(): + # if ( + # name.startswith("conv1") + # or name.startswith("bn1") + # or name.startswith("layer1") + # or name.startswith("layer2") + # ): + # print("Freeze gradient for", name) + # param.requires_grad = False + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + # with torch.no_grad(): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + + +def get_resnet(args): + if args.model == "resnet18": + return ResNet(BasicBlock, [2, 2, 2, 2]), 512 + elif args.model == "resnet34": + return ResNet(BasicBlock, [3, 4, 6, 3]), 512 + elif args.model == 
"resnet50": + return ResNet(Bottleneck, [3, 4, 6, 3]), 2048 + else: + raise NotImplementedError + + +class Network(nn.Module): + def __init__(self, resnet, hidden_dim, feature_dim, class_num): + super(Network, self).__init__() + self.resnet = resnet + self.feature_dim = feature_dim + self.cluster_num = class_num + self.instance_projector = nn.Sequential( + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, self.feature_dim), + ) + self.cluster_projector = nn.Sequential( + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, self.cluster_num), + ) + trunc_normal_(self.cluster_projector[2].weight, std=0.02) + trunc_normal_(self.cluster_projector[5].weight, std=0.02) + + def forward(self, x_i, x_j, return_ci=True): + h_i = self.resnet(x_i) + h_j = self.resnet(x_j) + + z_i = F.normalize(self.instance_projector(h_i), dim=1) + z_j = F.normalize(self.instance_projector(h_j), dim=1) + + c_j = self.cluster_projector(h_j) + + if return_ci: + c_i = self.cluster_projector(h_i) + return z_i, z_j, c_i, c_j + else: + return z_i, z_j, c_j + + def forward_c(self, x): + h = self.resnet(x) + c = self.cluster_projector(h) + c = F.softmax(c, dim=1) + return c + + def forward_zc(self, x): + h = self.resnet(x) + z = F.normalize(self.instance_projector(h), dim=1) + c = self.cluster_projector(h) + c = F.softmax(c, dim=1) + return z, c diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..cf75841 --- /dev/null +++ b/readme.md @@ -0,0 +1,68 @@ +# Twin Contrastive Learning for Online Clustering (TCL) + +This is the code for the paper "Twin Contrastive Learning for Online Clustering" (IJCV 2022). 
+
+TCL extends the previous work "Contrastive Clustering" (AAAI 2021, https://github.com/Yunfan-Li/Contrastive-Clustering) by selecting the most confident predictions to fine-tune both the instance- and cluster-level contrastive learning.
+
+TCL proposes to mix weak and strong augmentations for both the image and text modalities. More performance gains are observed with the twin contrastive learning framework compared with the standard instance-level contrastive learning.
+
+The code supports multi-GPU training.
+
+Paper Link: https://link.springer.com/article/10.1007/s11263-022-01639-z
+
+# Environment
+
+- diffdist=0.1
+- python=3.9.12
+- pytorch=1.11.0
+- torchvision=0.12.0
+- munkres=1.1.4
+- numpy=1.22.3
+- opencv-python=4.6.0.66
+- scikit-learn=1.0.2
+- cudatoolkit=11.3.1
+
+# Usage
+
+TCL is composed of the training and boosting stages. Configurations such as model, dataset, temperature, etc. can be set with argparse. Clustering performance is evaluated during the training or boosting.
+
+## Training
+
+The following command is used for training on CIFAR-10 with a 4-GPU machine,
+
+> OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 train.py
+
+## Boosting
+
+The following command is used for boosting on CIFAR-10 with a 4-GPU machine,
+
+> OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 boost.py
+
+## Clustering on ImageNet
+To cluster datasets like ImageNet with a large number of classes, a reasonable batch size is needed. However, considering the GPU memory consumption, we recommend inheriting the moco v2 pretrained model (https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar) and freezing part of the network parameters (see details in model.py).
+
+# Dataset
+
+CIFAR-10 and CIFAR-100 can be automatically downloaded by PyTorch. For ImageNet-10 and ImageNet-dogs, we provide their indices from ImageNet in the "dataset" folder.
+ +To run TCL on ImageNet and its subsets, you need to prepare the data and pass the image folder path to the `--data_path` argment. + +# Citation + +If you find TCL useful in your research, please consider citing: +``` + +``` + +or the previous conference version +``` +@inproceedings{li2021contrastive, + title={Contrastive clustering}, + author={Li, Yunfan and Hu, Peng and Liu, Zitao and Peng, Dezhong and Zhou, Joey Tianyi and Peng, Xi}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={35}, + number={10}, + pages={8547--8555}, + year={2021} +} +``` \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..97d8089 --- /dev/null +++ b/train.py @@ -0,0 +1,282 @@ +import argparse +import copy +import time +import datetime +import misc +import numpy as np +import os +import torch +import torch.backends.cudnn as cudnn +from pathlib import Path +from data import build_dataset +from model import get_resnet, Network +from misc import NativeScalerWithGradNormCount as NativeScaler +from loss import InstanceLoss, ClusterLoss +from engine import train_one_epoch, evaluate +import json + + +def get_args_parser(): + parser = argparse.ArgumentParser("TCL", add_help=False) + parser.add_argument( + "--batch_size", default=256, type=int, help="Batch size per GPU" + ) + parser.add_argument("--epochs", default=1000, type=int) + + # Model parameters + parser.add_argument( + "--model", + default="resnet34", + type=str, + metavar="MODEL", + choices=["resnet50", "resnet34", "resnet18"], + help="Name of model to train", + ) + parser.add_argument("--feat_dim", default=128, type=int, help="dimension of ICH") + parser.add_argument( + "--ins_temp", + default=0.5, + type=float, + help="temperature of instance-level contrastive loss", + ) + parser.add_argument( + "--clu_temp", + default=1.0, + type=float, + help="temperature of cluster-level contrastive loss", + ) + + # Optimizer parameters + 
parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay") + parser.add_argument( + "--lr", + type=float, + default=1e-4, + metavar="LR", + help="learning rate (absolute lr)", + ) + + # Dataset parameters + parser.add_argument( + "--data_path", default="./datasets/", type=str, help="dataset path", + ) + parser.add_argument( + "--dataset", + default="CIFAR-10", + type=str, + help="dataset", + choices=["CIFAR-10", "CIFAR-100", "ImageNet-10", "ImageNet"], + ) + parser.add_argument( + "--nb_cluster", default=10, type=int, help="number of the clusters", + ) + parser.add_argument( + "--output_dir", + default="./save/", + help="path where to save, empty for no saving", + ) + parser.add_argument( + "--device", default="cuda", help="device to use for training / testing" + ) + parser.add_argument("--seed", default=0, type=int) + parser.add_argument( + "--resume", + default=False, + help="resume from checkpoint", + ) + parser.add_argument( + "--start_epoch", default=0, type=int, metavar="N", help="start epoch" + ) + parser.add_argument("--save_freq", default=50, type=int, help="saving frequency") + parser.add_argument( + "--eval_freq", default=10, type=int, help="evaluation frequency" + ) + parser.add_argument("--num_workers", default=10, type=int) + parser.add_argument( + "--pin_mem", + action="store_true", + help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.", + ) + parser.add_argument( + "--dist_eval", + action="store_true", + default=False, + help="Enabling distributed evaluation (recommended during training for faster monitor", + ) + + # distributed training parameters + parser.add_argument( + "--world_size", default=1, type=int, help="number of distributed processes" + ) + parser.add_argument("--local_rank", default=-1, type=int) + parser.add_argument("--dist_on_itp", action="store_true") + parser.add_argument( + "--dist_url", default="env://", help="url used to set up distributed training" + ) + + return parser 
+ + +def main(args): + misc.init_distributed_mode(args) + + print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(", ", ",\n")) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + dataset_train = build_dataset(type="train", args=args) + dataset_val = build_dataset(type="val", args=args) + + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print( + "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. " + "This will slightly alter validation results as extra duplicate entries are added to achieve " + "equal num of samples per-process." 
+ ) + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) # shuffle=True to reduce monitor bias + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if global_rank == 0 and args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, + sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, + sampler=sampler_val, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False, + ) + + backbone, hidden_dim = get_resnet(args) + model = Network(backbone, hidden_dim, args.feat_dim, args.nb_cluster) + + if args.resume: + checkpoint = torch.load(args.resume, map_location="cpu") + state_dict = checkpoint["state_dict"] + state_dict_copy = copy.deepcopy(state_dict) + for param in list(state_dict): + if param.startswith("module.encoder_q"): + param_copy = "resnet" + param[16:] + state_dict[param_copy] = state_dict_copy[param] + state_dict.pop(param) + + print("Load pre-trained checkpoint from: %s" % args.resume) + msg = model.load_state_dict(state_dict, strict=False) + print(msg) + + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print("Model = %s" % str(model_without_ddp)) + print("number of params (M): %.2f" % (n_parameters / 1.0e6)) + + eff_batch_size = args.batch_size * misc.get_world_size() + print("effective batch size: %d" % eff_batch_size) + + optimizer = torch.optim.Adam( + [ + {"params": model.resnet.parameters(), "lr": args.lr,}, + {"params": model.instance_projector.parameters(), "lr": args.lr}, + {"params": model.cluster_projector.parameters(), "lr": args.lr}, + ], + lr=args.lr, + weight_decay=args.weight_decay, + ) + 
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + loss_scaler = NativeScaler() + + criterion_ins = InstanceLoss(tau=args.ins_temp, distributed=True) + criterion_clu = ClusterLoss(tau=args.clu_temp, distributed=True) + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + max_accuracy = 0.0 + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + train_stats = train_one_epoch( + model, + criterion_ins, + criterion_clu, + data_loader_train, + optimizer, + device, + epoch, + loss_scaler, + args=args, + ) + if args.output_dir and ( + epoch % args.save_freq == 0 or epoch + 1 == args.epochs + ): + misc.save_model( + args=args, + model=model, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + ) + if epoch % args.eval_freq == 0 or epoch + 1 == args.epochs: + test_stats = evaluate(data_loader_val, model, device) + print( + f"Clustering performance on {len(dataset_val)} test images: NMI={test_stats['nmi']:.2f}%, ACC={test_stats['acc']:.2f}%, ARI={test_stats['ari']:.2f}%" + ) + max_accuracy = max(max_accuracy, test_stats["acc"]) + print(f"Max accuracy: {max_accuracy:.2f}%") + + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + **{f"test_{k}": v for k, v in test_stats.items()}, + "epoch": epoch, + "n_parameters": n_parameters, + } + + if args.output_dir and misc.is_main_process(): + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + +if __name__ == "__main__": + args = get_args_parser() + args = args.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + 
main(args) diff --git a/transforms.py b/transforms.py new file mode 100644 index 0000000..3f12899 --- /dev/null +++ b/transforms.py @@ -0,0 +1,271 @@ +# List of augmentations based on randaugment +import random +from PIL import Image, ImageFilter, ImageOps, ImageOps, ImageEnhance +import numpy as np +import torch +from torchvision import transforms + +random_mirror = True + + +def ShearX(img, v): + if random_mirror and random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0)) + + +def ShearY(img, v): + if random_mirror and random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0)) + + +def Identity(img, v): + return img + + +def TranslateX(img, v): + if random_mirror and random.random() > 0.5: + v = -v + v = v * img.size[0] + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateY(img, v): + if random_mirror and random.random() > 0.5: + v = -v + v = v * img.size[1] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def TranslateXAbs(img, v): + if random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateYAbs(img, v): + if random.random() > 0.5: + v = -v + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def Rotate(img, v): + if random_mirror and random.random() > 0.5: + v = -v + return img.rotate(v) + + +def AutoContrast(img, _): + return ImageOps.autocontrast(img) + + +def Invert(img, _): + return ImageOps.invert(img) + + +def Equalize(img, _): + return ImageOps.equalize(img) + + +def Solarize(img, v): + return ImageOps.solarize(img, v) + + +def Posterize(img, v): + v = int(v) + return ImageOps.posterize(img, v) + + +def Contrast(img, v): + return ImageEnhance.Contrast(img).enhance(v) + + +def Color(img, v): + return ImageEnhance.Color(img).enhance(v) + + +def Brightness(img, v): + return ImageEnhance.Brightness(img).enhance(v) + + +def Sharpness(img, v): + 
return ImageEnhance.Sharpness(img).enhance(v) + + +def augment_list(): + l = [ + (Identity, 0, 1), + (AutoContrast, 0, 1), + (Equalize, 0, 1), + (Rotate, -30, 30), + (Solarize, 0, 256), + (Color, 0.05, 0.95), + (Contrast, 0.05, 0.95), + (Brightness, 0.05, 0.95), + (Sharpness, 0.05, 0.95), + (ShearX, -0.1, 0.1), + (TranslateX, -0.1, 0.1), + (TranslateY, -0.1, 0.1), + (Posterize, 4, 8), + (ShearY, -0.1, 0.1), + ] + return l + + +augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} + + +class AutoAugment: + def __init__(self, n): + self.n = n + self.augment_list = augment_list() + + def __call__(self, img): + ops = random.choices(self.augment_list, k=self.n) + for op, minval, maxval in ops: + val = (random.random()) * float(maxval - minval) + minval + img = op(img, val) + + return img + + +def get_augment(name): + return augment_dict[name] + + +def apply_augment(img, name, level): + augment_fn, low, high = get_augment(name) + return augment_fn(img.copy(), level * (high - low) + low) + + +class Cutout(object): + def __init__(self, n_holes, length): + self.n_holes = n_holes + self.length = length + + def __call__(self, img): + h = img.size(1) + w = img.size(2) + length = random.randint(1, self.length) + mask = np.ones((h, w), np.float32) + + for n in range(self.n_holes): + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - length // 2, 0, h) + y2 = np.clip(y + length // 2, 0, h) + x1 = np.clip(x - length // 2, 0, w) + x2 = np.clip(x + length // 2, 0, w) + + mask[y1:y2, x1:x2] = 0.0 + + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img = img * mask + + return img + + +class GaussianBlur(object): + """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" + + def __init__(self, sigma=[0.1, 2.0]): + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + +class Augmentation: + def 
__init__( + self, + img_size=224, + val_img_size=256, + s=1, + num_aug=4, + cutout_holes=1, + cutout_size=75, + blur=1.0, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ): + self.weak_aug = transforms.Compose( + [ + transforms.RandomResizedCrop( + img_size, interpolation=Image.BICUBIC, scale=(0.2, 1.0) + ), + transforms.RandomHorizontalFlip(), + transforms.RandomApply( + [transforms.ColorJitter(0.8 * s, 0.8 * s, 0.4 * s, 0.2 * s)], p=0.8 + ), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=blur), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ) + self.strong_aug = transforms.Compose( + [ + transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC), + transforms.RandomHorizontalFlip(), + AutoAugment(n=num_aug), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + Cutout(n_holes=cutout_holes, length=cutout_size), + ] + ) + self.val_aug = transforms.Compose( + [ + transforms.Resize( + (val_img_size, val_img_size), interpolation=Image.BICUBIC + ), + transforms.CenterCrop(img_size), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ) + + def __call__(self, x): + return self.weak_aug(x), self.strong_aug(x), self.val_aug(x) + + +def build_transform(is_train, args): + if args.dataset == "CIFAR-10": + augmentation = Augmentation( + img_size=256, + val_img_size=224, + s=0.5, + num_aug=4, + cutout_holes=1, + cutout_size=75, + mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010], + ) + elif args.dataset == "CIFAR-100": + augmentation = Augmentation( + img_size=256, + val_img_size=224, + s=0.5, + num_aug=4, + cutout_holes=1, + cutout_size=75, + mean=[0.5071, 0.4867, 0.4408], + std=[0.2675, 0.2565, 0.2761], + ) + elif args.dataset == "ImageNet-10" or args.dataset == "ImageNet": + augmentation = Augmentation( + img_size=256, + val_img_size=224, + s=0.5, + num_aug=4, + cutout_holes=1, + cutout_size=75, + blur=0.5, + mean=[0.485, 0.456, 
0.406], + std=[0.229, 0.224, 0.225], + ) + return augmentation if is_train else augmentation.val_aug