Skip to content

Commit

Permalink
feature(whl): add single loss landscape visualizer (#37)
Browse files Browse the repository at this point in the history
* init commit

* fix

* add hessian eigen value

* remove redundant files

* polish doc

* new

* chage

* polish

* polish

---------

Co-authored-by: ‘whl’ <‘[email protected]’>
  • Loading branch information
kxzxvbk and ‘whl’ authored Jun 14, 2024
1 parent 4c8df63 commit 2c8800a
Show file tree
Hide file tree
Showing 14 changed files with 713 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ For attacking methods, please refer to our examples in: [demo for attack](https:
- Support for a variety of algorithms and datasets.
- Support multiprocessing training on each client for better efficiency.
- Using single GPU to simulate Federated Learning process (multi-GPU version will be released soon).
- Strong visualization utilities. See [demo](https://github.com/kxzxvbk/Fling/blob/main/fling/utils/visualize_utils/demo) for detailed information.

## Supported Algorithms

Expand Down
1 change: 1 addition & 0 deletions fling/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .generic_model_pipeline import generic_model_pipeline
from .personalized_model_pipeline import personalized_model_pipeline
from .generic_model_visualization_pipeline import generic_model_visualization_pipeline
143 changes: 143 additions & 0 deletions fling/pipeline/generic_model_visualization_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import copy
import os
import torch
from torch.utils.data import DataLoader

from fling.component.client import get_client
from fling.component.server import get_server
from fling.component.group import get_group
from fling.dataset import get_dataset

from fling.utils.data_utils import data_sampling
from fling.utils import Logger, compile_config, client_sampling, VariableMonitor, LRScheduler
from fling.utils import get_launcher
from fling.utils import plot_2d_loss_landscape


def generic_model_visualization_pipeline(args: dict, seed: int = 0) -> None:
r"""
Overview:
Pipeline for generic federated learning. Under this setting, models of each client is the same.
We plot the loss landscape before and after aggregation in each round.
The final performance of this generic model is tested on the server (typically using a global test dataset).
Arguments:
- args: dict type arguments.
- seed: random seed.
"""
# Compile the input arguments first.
args = compile_config(args, seed)

# Construct logger.
logger = Logger(args.other.logging_path)

# Load dataset.
train_set = get_dataset(args, train=True)
test_set = get_dataset(args, train=False)

part_test = [test_set[i] for i in range(100)]
part_test = DataLoader(part_test, batch_size=args.learn.batch_size, shuffle=True)

# Split dataset into clients.
train_sets = data_sampling(train_set, args, seed, train=True)

# Initialize group, clients and server.
group = get_group(args, logger)
group.server = get_server(args, test_dataset=test_set)
for i in range(args.client.client_num):
group.append(get_client(args=args, client_id=i, train_dataset=train_sets[i]))
group.initialize()

# Setup lr_scheduler.
lr_scheduler = LRScheduler(args)
# Setup launcher.
launcher = get_launcher(args)

# Variables for visualization
last_global_model = None

# Training loop
for i in range(args.learn.global_eps):
logger.logging('Starting round: ' + str(i))
# Initialize variable monitor.
train_monitor = VariableMonitor()

# Random sample participated clients in each communication round.
participated_clients = client_sampling(range(args.client.client_num), args.client.sample_rate)

# Adjust learning rate.
cur_lr = lr_scheduler.get_lr(train_round=i)

# Local training for each participated client and add results to the monitor.
# Use multiprocessing for acceleration.
train_results = launcher.launch(
clients=[group.clients[j] for j in participated_clients], lr=cur_lr, task_name='train'
)
for item in train_results:
train_monitor.append(item)

# Testing
if i % args.other.test_freq == 0 and "before_aggregation" in args.learn.test_place:
test_result = group.server.test(model=group.clients[0].model)
# Logging test variables.
logger.add_scalars_dict(prefix='before_aggregation_test', dic=test_result, rnd=i)

if last_global_model is not None and i > 45:
plot_2d_loss_landscape(
model=last_global_model,
dataloader=part_test,
device=args.learn.device,
caption='Global-test Loss Landscape',
save_path=os.path.join(args.other.logging_path, f"losslandscape_gt_{i}.pdf"),
target_model1=group.clients[0].model,
target_model2=group.clients[1].model,
resolution=20,
noise_range=(-0.1, 1.0),
log_scale=True,
max_val=20
)
plot_2d_loss_landscape(
model=last_global_model,
dataloader=group.clients[0].train_dataloader,
device=args.learn.device,
caption='Client-1-train Loss Landscape',
save_path=os.path.join(args.other.logging_path, f"losslandscape_ct1_{i}.pdf"),
target_model1=group.clients[0].model,
target_model2=group.clients[1].model,
resolution=20,
noise_range=(-0.1, 1.1),
log_scale=True,
max_val=20
)
plot_2d_loss_landscape(
model=last_global_model,
dataloader=group.clients[1].train_dataloader,
device=args.learn.device,
caption='Client-2-train Loss Landscape',
save_path=os.path.join(args.other.logging_path, f"losslandscape_ct2_{i}.pdf"),
target_model1=group.clients[0].model,
target_model2=group.clients[1].model,
resolution=20,
noise_range=(-0.1, 1.1),
log_scale=True,
max_val=20
)

# Aggregate parameters in each client.
trans_cost = group.aggregate(i)

last_global_model = copy.deepcopy(group.clients[0].model)

# Logging train variables.
mean_train_variables = train_monitor.variable_mean()
mean_train_variables.update({'trans_cost': trans_cost / 1e6, 'lr': cur_lr})
logger.add_scalars_dict(prefix='train', dic=mean_train_variables, rnd=i)

# Testing
if i % args.other.test_freq == 0 and "after_aggregation" in args.learn.test_place:
test_result = group.server.test(model=group.clients[0].model)

# Logging test variables.
logger.add_scalars_dict(prefix='after_aggregation_test', dic=test_result, rnd=i)

# Saving model checkpoints.
torch.save(group.server.glob_dict, os.path.join(args.other.logging_path, 'model.ckpt'))
1 change: 1 addition & 0 deletions fling/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .utils import Logger, client_sampling, VariableMonitor
from .data_utils import get_data_transform
from .launcher_utils import get_launcher
from .visualize_utils import plot_2d_loss_landscape
22 changes: 22 additions & 0 deletions fling/utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import math
import pickle
import random
Expand Down Expand Up @@ -244,3 +245,24 @@ def forward(self, x):

def _tensor_size(self, t):
return t.size()[1] * t.size()[2] * t.size()[3]


def model_add(model1: nn.Module, model2: nn.Module) -> nn.Module:
ret = copy.deepcopy(model1)
sd1, sd2 = model1.state_dict(), model2.state_dict()
ret.load_state_dict({k: sd1[k] + sd2[k] for k in sd1.keys()})
return ret


def model_sub(model1: nn.Module, model2: nn.Module) -> nn.Module:
ret = copy.deepcopy(model1)
sd1, sd2 = model1.state_dict(), model2.state_dict()
ret.load_state_dict({k: sd1[k] - sd2[k] for k in sd1.keys()})
return ret


def model_mul(scalar: float, model: nn.Module) -> nn.Module:
ret = copy.deepcopy(model)
sd = model.state_dict()
ret.load_state_dict({k: scalar * sd[k] for k in sd.keys()})
return ret
3 changes: 3 additions & 0 deletions fling/utils/visualize_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .loss_landscape import plot_2d_loss_landscape
from .conv_kernel_visualizer import plot_conv_kernels
from .hessian_eigen_value import calculate_hessian_dominant_eigen_values
from .activation_maximization import ActivationMaximizer
21 changes: 21 additions & 0 deletions fling/utils/visualize_utils/conv_kernel_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torchvision
from torch import nn

from fling.utils import Logger


def plot_conv_kernels(logger: Logger, layer: nn.Conv2d, name: str) -> None:
"""
Overview:
Plot the kernels in a certain convolution layer for better visualization.
Arguments:
logger: The logger to write result image.
layer: The convolution layer to visualize.
name: The name of the plotted figure.
"""
param = layer.weight
in_channels = param.shape[1]
k_w, k_h = param.size()[3], param.size()[2]
kernel_all = param.view(-1, 1, k_w, k_h)
kernel_grid = torchvision.utils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=in_channels)
logger.add_image(f'{name}', kernel_grid, global_step=0)
25 changes: 25 additions & 0 deletions fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from torchvision.models import resnet18

from fling.utils import Logger
from fling.utils.visualize_utils import plot_conv_kernels
from easydict import EasyDict
from fling.utils.registry_utils import MODEL_REGISTRY

if __name__ == '__main__':
# Step 1: prepare the model.
model_arg = EasyDict(dict(
name='resnet8',
input_channel=3,
class_number=100,
))
model_name = model_arg.pop('name')
model = MODEL_REGISTRY.build(model_name, **model_arg)

# You can also initialize the model without using configurations.
# e.g. model = resnet18(pretrained=True)

# Step 2: prepare the logger.
logger = Logger('resnet18_conv_kernels')

# Step 3: save the kernels.
plot_conv_kernels(logger, model.conv1, name='pre-conv')
86 changes: 86 additions & 0 deletions fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from easydict import EasyDict

import torch
from torch import nn
from torch.utils.data import DataLoader

from fling import dataset
from fling.utils.visualize_utils import calculate_hessian_dominant_eigen_values
from fling.utils.registry_utils import DATASET_REGISTRY
from easydict import EasyDict
from fling.utils.registry_utils import MODEL_REGISTRY


class ToyModel(nn.Module):
"""
Overview:
A toy model for demonstrating attacking results.
"""

def __init__(self):
super(ToyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.relu1 = nn.ReLU()

self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.relu2 = nn.ReLU()

self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.relu3 = nn.ReLU()

self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.flat = nn.Flatten()

self.fc = nn.Linear(128, 10)

def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.relu2(self.conv2(x))
x = self.relu3(self.conv3(x))
x = self.flat(self.pool(x))
return self.fc(x)


if __name__ == '__main__':
# Step 1: prepare the dataset.
dataset_config = EasyDict(dict(data=dict(data_path='./data/cifar10', transforms=dict())))
dataset = DATASET_REGISTRY.build('cifar10', dataset_config, train=False)

# Test dataset is for generating loss landscape.
test_dataset = [dataset[i] for i in range(100)]
test_dataloader = DataLoader(test_dataset, batch_size=100)

# Step 2: prepare the model.
model_arg = EasyDict(dict(
name='resnet8',
input_channel=3,
class_number=10,
))
model_name = model_arg.pop('name')
model = MODEL_REGISTRY.build(model_name, **model_arg)

# You can also initialize the model without using configurations.
# e.g. model = ToyModel()

# Step 3: train the randomly initialized model.
dataloader = DataLoader(dataset, batch_size=100)
device = 'cuda'
model = model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()
for _ in range(0):
for _, (data) in enumerate(dataloader):
data_x, data_y = data['input'], data['class_id']
data_x, data_y = data_x.to(device), data_y.to(device)
pred_y = model(data_x)
loss = criterion(pred_y, data_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.to('cpu')

# Step 4: plot the loss landscape after training the model.
# Only one line of code for visualization.
res = calculate_hessian_dominant_eigen_values(model, iter_num=20, dataloader=test_dataloader, device='cuda')
print(res)
Loading

0 comments on commit 2c8800a

Please sign in to comment.