Brief demo showing how ENRC can be applied to the Concatenated-MNIST data set introduced in the ENRC paper, reference below:

Lukas Miklautz, Dominik Mautz, Muzaffer Can Altinigneli, Christian Böhm, Claudia Plant: Deep Embedded Non-Redundant Clustering. AAAI 2020: 5174-5181

The jupyter notebook can be found here.

# Importing all necessary libraries

%load_ext autoreload
%autoreload 2

# internal packages
import os
from collections import Counter
# external packages
import torch
import torchvision
import numpy as np
import sklearn
from sklearn.metrics import normalized_mutual_info_score
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
import pandas as pd
import cluspy
from cluspy.deep import encode_batchwise, get_dataloader
from cluspy.alternative import NrKmeans
from cluspy.metrics.multipe_labelings_scoring import MultipleLabelingsConfusionMatrix

# Directory and file name used for stored model checkpoints
base_path = "material"
model_name = "autoencoder.pth"

# Report the versions of the main libraries used in this notebook
for _label, _module in (
    ("torch: ", torch),
    ("torchvision: ", torchvision),
    ("numpy: ", np),
    ("scikit-learn:", sklearn),
    ("cluspy:", cluspy),
):
    print(_label, _module.__version__)
# Some helper functions, you can ignore those in the beginning
def denormalize_fn(tensor:torch.Tensor, mean:float, std:float, w:int, h:int)->torch.Tensor:
    """Apply an inverse z-transformation and reshape so the images can be visualized.

    Args:
        tensor: z-transformed image data whose total size is a multiple of ``h * w``.
        mean: mean that was subtracted during the z-transformation.
        std: standard deviation the data was divided by during the z-transformation.
        w: width of a single image in pixels.
        h: height of a single image in pixels.

    Returns:
        A detached int32 tensor of shape ``(-1, 1, h, w)`` holding
        ``(tensor * std + mean) * 255`` truncated to integers.
    """
    pt_std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
    pt_mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
    # reshape (instead of view) also accepts non-contiguous input
    restored = tensor.mul(pt_std).add(pt_mean).reshape(-1, 1, h, w)
    return (restored * 255).int().detach()

def plot_images(images:torch.Tensor, pad:int=1):
    """Draw a batch of images as a single grid figure.

    The grid is built with ``torchvision.utils.make_grid`` (its default row
    length is used, which is 8 images per row according to the torchvision
    documentation) with ``pad`` pixels of padding filled with the value 255.
    """
    grid = torchvision.utils.make_grid(images, pad_value=255, normalize=False, padding=pad)
    plt.figure(figsize=(10, 20))
    # move the first axis to the end before handing the array to matplotlib
    channels_last = np.transpose(np.array(grid.numpy()), (1, 2, 0))
    plt.imshow(channels_last, vmin=0, vmax=1)
def detect_device():
    """Return the torch device to compute on: a CUDA GPU if one is available, otherwise the CPU.

    Returns:
        ``torch.device('cpu')`` when CUDA is unavailable. Otherwise the second
        GPU (``cuda:1``, the device this notebook originally selected) when
        more than one GPU is visible, else the only GPU (``cuda:0``).
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    # 'cuda:1' does not exist on single-GPU machines, so fall back to index 0 there
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    return torch.device(f'cuda:{gpu_index}')

Create Concatenated-MNIST data set

We randomly pair MNIST digits to create a data set with 100 classes (00 to 99), in which the left and the right digit are independent of each other.

def load_mnist(train=True):
    """Download (if necessary) and load MNIST as flattened float vectors.

    No z-transformation is applied here; the pixel values are only scaled
    by 1/255.

    Args:
        train: passed to ``torchvision.datasets.MNIST``; ``True`` selects the
            training split, ``False`` the test split.

    Returns:
        Tuple ``(data, labels)``: ``data`` is a float tensor with one flattened
        image per row, ``labels`` is the data set's ``targets`` attribute.
    """
    # download the MNIST data set into ./data
    dataset = torchvision.datasets.MNIST(root='./data', train=train, download=True)
    # Scale the raw pixel values (presumably uint8 in [0, 255]) to [0, 1]
    data = dataset.data.float() / 255
    # Flatten from a shape of (-1, 28, 28) to (-1, 28*28)
    data = data.reshape(-1, data.shape[1] * data.shape[2])
    labels = dataset.targets
    return data, labels

data, labels = load_mnist()

# Draw a random subsample of indices (without replacement) to hold out for evaluation
subsample_size = 10000
random_state = np.random.randint(100000)
rng = np.random.default_rng(random_state)
rand_idx = rng.choice(data.shape[0], subsample_size, replace=False)

# Held-out evaluation split and the remaining training split
data_eval, labels_eval = data[rand_idx], labels[rand_idx]
data_train, labels_train = (
    np.delete(full, rand_idx, axis=0) for full in (data, labels)
)

data_test, labels_test = load_mnist(train=False)

# Summary statistics of the full (unsplit) training data
print("Data Set Information")
print("Number of data points: ", len(data))
print("Number of dimensions: ", data.shape[1])
print(f"Mean: {data.mean():.2f}, Standard deviation: {data.std():.2f}")
print(f"Min: {data.min():.2f}, Max: {data.max():.2f}")
print("Number of classes: ", labels.unique().numel())
print("Class distribution:\n", sorted(Counter(labels.tolist()).items()))
def create_random_pairings(data, labels, random_state=None, data_cat_dim=3):
    """Pair every image with a randomly chosen partner and concatenate the two.

    Two independent random permutations of the sample indices are drawn; the
    i-th output sample joins the image at the i-th position of the first
    permutation (left) with the one at the i-th position of the second (right).

    Args:
        data: tensor of images, indexed by sample along dimension 0.
        labels: tensor with one label per sample.
        random_state: seed for the pairing; a random seed is drawn from
            numpy's global RNG when ``None``.
        data_cat_dim: dimension along which the two images are concatenated
            (the default 3 is the last axis of 4-D input, e.g. the width of
            ``(N, C, H, W)`` images).

    Returns:
        Tuple ``(concat_data, concat_labels)`` where ``concat_labels`` has
        shape ``(n_samples, 2)`` holding the left and right label per sample.
    """
    if random_state is None:
        random_state = np.random.randint(100000)
    rng = np.random.default_rng(random_state)
    n_samples = data.shape[0]
    left_idx = rng.choice(n_samples, n_samples, replace=False)
    right_idx = rng.choice(n_samples, n_samples, replace=False)

    # Indexing with an index array already returns copies, so the inputs are never modified
    concat_data = torch.cat([data[left_idx], data[right_idx]], data_cat_dim)
    concat_labels = torch.stack([labels[left_idx], labels[right_idx]], dim=1)
    return concat_data, concat_labels
# Set random_state to an integer to reproduce the same pairing across runs
random_state = None
concat_data, concat_labels = create_random_pairings(
    data.reshape(-1, 1, 28, 28), labels, random_state
)
# Flatten each 28x56 image into a vector for the feed-forward network
concat_data = concat_data.reshape(-1, 28 * 56)

# Global statistics used for the z-transformation
mean = concat_data.mean()
std = concat_data.std()


def denormalize(x):
    """Undo the z-transformation and reshape to 28x56 images for plotting."""
    return denormalize_fn(x, mean=mean, std=std, w=56, h=28)


# z-transform in place
concat_data -= mean
concat_data /= std

# Draw a random subsample of indices (without replacement) to hold out for evaluation
subsample_size = 10000
random_state = np.random.randint(100000)
rng = np.random.default_rng(random_state)
rand_idx = rng.choice(len(concat_data), subsample_size, replace=False)

# Held-out evaluation split and the remaining training split
data_eval, labels_eval = concat_data[rand_idx], concat_labels[rand_idx]
data_train, labels_train = (
    np.delete(full, rand_idx, axis=0) for full in (concat_data, concat_labels)
)

# Summary statistics of the full concatenated data set
print("Data Set Information")
print("Number of data points: ", len(concat_data))
print("Number of dimensions: ", concat_data.shape[1])
print(f"Mean: {concat_data.mean():.2f}, Standard deviation: {concat_data.std():.2f}")
print(f"Min: {concat_data.min():.2f}, Max: {concat_data.max():.2f}")
# One labeling per column: left digit, then right digit
for labeling in concat_labels.unbind(dim=1):
    print("Number of classes: ", labeling.unique().numel())
    print("Class distribution:\n", sorted(Counter(labeling.tolist()).items()))
print("Plot some images to see if everything worked:")

Plot some images to see if everything worked:


Pretrain Autoencoder

from cluspy.deep import FlexibleAutoencoder

# Set all parameters needed for training

# The size of the mini-batch that is passed in each training iteration
batch_size = 128
# The learning rate specifies the step size of the gradient descent algorithm
learning_rate = 1e-3
# Set device on which the model should be trained on (For most of you this will be the CPU)
device = detect_device()
print("Use device: ", device)

# load model to device
# NOTE(review): no statement follows this comment, yet `model` is used below
# (model.parameters()). The autoencoder construction (presumably a
# FlexibleAutoencoder moved to `device`) seems to be missing here — restore it
# before running this cell.

# create a Dataloader to train the autoencoder in mini-batch fashion
# (shuffled, incomplete last batch dropped, labels passed as additional input)
trainloader = get_dataloader(data_train, batch_size=batch_size, shuffle=True, drop_last=True, additional_inputs=[labels_train])

# create a Dataloader to evaluate the autoencoder in mini-batch fashion
# (held-out split, fixed order, all samples kept)
evalloader = get_dataloader(data_eval, batch_size=batch_size, shuffle=False, drop_last=False, additional_inputs=[labels_eval])

# create a Dataloader to evaluate the autoencoder in mini-batch fashion on the full data
# (fixed order, all samples kept)
fullloader = get_dataloader(concat_data, batch_size=batch_size, shuffle=False, drop_last=False, additional_inputs=[concat_labels])

# define optimizer (use a high weight decay as regularization and such that the embedded data has a small magnitude)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay=1e-2)

# define loss function (mean squared reconstruction error)
loss_fn = torch.nn.MSELoss()

# path to where we want to save/load the model to/from
pretrained_model_name = "pretrained_" + model_name
pretrained_model_path = os.path.join(base_path, pretrained_model_name)
# Train and save model
# The checkpoint of this run is written to / read from `model_path`.
model_path = "test_cat.pth"
# The scheduler is passed as a class together with its keyword arguments.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
# NOTE(review): the next two lines are not valid Python as they stand. The
# keyword arguments after the closing brace look like the tail of a training
# call (presumably `model.fit(...)`) whose beginning was lost — restore the
# call head, including any arguments that are no longer visible.
scheduler_params = {"factor":0.9, "patience":5, "verbose":True}, lr=learning_rate, dataloader=trainloader, evalloader=evalloader, device=device, model_path=model_path,
         scheduler=scheduler, scheduler_params=scheduler_params, print_step=5)
# load best model (can be executed in case model was previously saved)
# NOTE(review): `sd` is loaded but never applied to `model` in the visible code
# (presumably a `model.load_state_dict(sd)` call is missing) — confirm.
sd = torch.load(model_path)
# Plot how well we are at reconstructing the data:
to_plot = 8
# NOTE(review): only the captions are printed; the calls that display the
# original and reconstructed images are not present here.
print("Original Images")
print("Reconstructed Images")
# Reconstruct the first `to_plot` samples and move the result back to the CPU
reconstruction = model(concat_data[0:to_plot].to(device)).detach().cpu()
Apply NrKmeans to the AE Embedding as a baseline

from cluspy.alternative import NrKmeans
# Number of clusters to search for in each of the two clusterings
n_clusters = [10,10]
# Embed the full data set batch by batch with the autoencoder
embedded_data = encode_batchwise(fullloader, model, device)
# Fit NrKmeans on the embedding and keep its predicted labels
nrkmeans = NrKmeans(n_clusters=n_clusters)
preds = nrkmeans.fit_predict(embedded_data)

NMI for all cluster combinations

We see that both clusterings were found and that the two subspaces are mutually non-redundant: the NMI between each predicted clustering and the ground-truth labels of the other digit is close to zero.

# Compare the ground-truth digit labels with the labels predicted by NrKmeans.
# NOTE(review): `cm` is only constructed here; the statement that displays it is not present — confirm.
cm = MultipleLabelingsConfusionMatrix(labels_true=concat_labels, labels_pred=preds)


# Indices of the dimension of each clustering
# (one entry of nrkmeans.P per clustering)
for i, P_i in enumerate(nrkmeans.P):
    print(f"Clustering {i} dims: {P_i}")
Clustering 0 dims: [ 3  2  6 10 19  7 12 11 17 16]
Clustering 1 dims: [ 0  4  9  8 18 15 13  5 14  1]

Apply ENRC

from cluspy.deep.enrc import ENRC
# load best model
# NOTE(review): `sd` is loaded but not applied to `model` in the visible code — confirm.
sd = torch.load(model_path)

# Learning-rate scheduler class and its keyword arguments for the ENRC training
scheduler = torch.optim.lr_scheduler.StepLR
scheduler_params = {"step_size":40, "gamma":0.9}

# NOTE(review): the ENRC(...) call below is truncated. Only the explanatory
# comments of its remaining keyword arguments survived, the closing parenthesis
# is missing, and no fit call is visible — restore the arguments before running.
enrc = ENRC(n_clusters=[10,10], 
            # Reduce learning_rate by factor of 10 as usually done in Deep Clustering for stable training
            # Use nrkmeans to initialize ENRC
            # Use a random subsample of the data to speed up the initialization procedure
            # Prints training information
# Indices of the dimension of each clustering
# (one entry of enrc.P per clustering)
for i, P_i in enumerate(enrc.P):
    print(f"Clustering {i} dims: {P_i}")
Clustering 0 dims: [ 9 11 18  3 10  1 14 16 12  5]
Clustering 1 dims: [ 2  6 13  7  0 19 15  8  4 17]
# Soft beta weights: heatmap of enrc.betas with the color scale fixed to [0, 1].
# NOTE(review): presumably one row per clustering and one column per embedded dimension — confirm against the ENRC implementation.
sns.heatmap(enrc.betas, vmin=0, vmax=1.0)


NMI for all cluster combinations

We see that we could improve upon the results of NrKmeans and the two subspaces are non-redundant.

# Compare the ground-truth digit labels with the labels found by ENRC.
# NOTE(review): `cm` is only constructed here; the statement that displays it is not present — confirm.
cm = MultipleLabelingsConfusionMatrix(labels_true=concat_labels, labels_pred=enrc.labels_)


Plot reconstructed centroids from each clustering

We see that for the centers of clustering 0 only the left-side digits change, while for the centers of clustering 1 only the right-side digits change.

# Reconstruct the cluster centers of every clustering, one clustering at a time.
for subspace_i in range(len(enrc.P)):
    rec_centers = enrc.reconstruct_subspace_centroids(subspace_i)
    print("Centers of Clustering ", subspace_i)
    # NOTE(review): `rec_centers` is computed but never used; the call that
    # plots the centers (presumably via plot_images and denormalize) is not present — confirm.
Centers of Clustering  0


Centers of Clustering  1

