-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(whl): add single loss landscape visualizer (#37)
* init commit * fix * add hessian eigen value * remove redundant files * polish doc * new * chage * polish * polish --------- Co-authored-by: ‘whl’ <‘[email protected]’>
- Loading branch information
Showing
14 changed files
with
713 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .generic_model_pipeline import generic_model_pipeline | ||
from .personalized_model_pipeline import personalized_model_pipeline | ||
from .generic_model_visualization_pipeline import generic_model_visualization_pipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import copy | ||
import os | ||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
from fling.component.client import get_client | ||
from fling.component.server import get_server | ||
from fling.component.group import get_group | ||
from fling.dataset import get_dataset | ||
|
||
from fling.utils.data_utils import data_sampling | ||
from fling.utils import Logger, compile_config, client_sampling, VariableMonitor, LRScheduler | ||
from fling.utils import get_launcher | ||
from fling.utils import plot_2d_loss_landscape | ||
|
||
|
||
def generic_model_visualization_pipeline(args: dict, seed: int = 0) -> None: | ||
r""" | ||
Overview: | ||
Pipeline for generic federated learning. Under this setting, models of each client is the same. | ||
We plot the loss landscape before and after aggregation in each round. | ||
The final performance of this generic model is tested on the server (typically using a global test dataset). | ||
Arguments: | ||
- args: dict type arguments. | ||
- seed: random seed. | ||
""" | ||
# Compile the input arguments first. | ||
args = compile_config(args, seed) | ||
|
||
# Construct logger. | ||
logger = Logger(args.other.logging_path) | ||
|
||
# Load dataset. | ||
train_set = get_dataset(args, train=True) | ||
test_set = get_dataset(args, train=False) | ||
|
||
part_test = [test_set[i] for i in range(100)] | ||
part_test = DataLoader(part_test, batch_size=args.learn.batch_size, shuffle=True) | ||
|
||
# Split dataset into clients. | ||
train_sets = data_sampling(train_set, args, seed, train=True) | ||
|
||
# Initialize group, clients and server. | ||
group = get_group(args, logger) | ||
group.server = get_server(args, test_dataset=test_set) | ||
for i in range(args.client.client_num): | ||
group.append(get_client(args=args, client_id=i, train_dataset=train_sets[i])) | ||
group.initialize() | ||
|
||
# Setup lr_scheduler. | ||
lr_scheduler = LRScheduler(args) | ||
# Setup launcher. | ||
launcher = get_launcher(args) | ||
|
||
# Variables for visualization | ||
last_global_model = None | ||
|
||
# Training loop | ||
for i in range(args.learn.global_eps): | ||
logger.logging('Starting round: ' + str(i)) | ||
# Initialize variable monitor. | ||
train_monitor = VariableMonitor() | ||
|
||
# Random sample participated clients in each communication round. | ||
participated_clients = client_sampling(range(args.client.client_num), args.client.sample_rate) | ||
|
||
# Adjust learning rate. | ||
cur_lr = lr_scheduler.get_lr(train_round=i) | ||
|
||
# Local training for each participated client and add results to the monitor. | ||
# Use multiprocessing for acceleration. | ||
train_results = launcher.launch( | ||
clients=[group.clients[j] for j in participated_clients], lr=cur_lr, task_name='train' | ||
) | ||
for item in train_results: | ||
train_monitor.append(item) | ||
|
||
# Testing | ||
if i % args.other.test_freq == 0 and "before_aggregation" in args.learn.test_place: | ||
test_result = group.server.test(model=group.clients[0].model) | ||
# Logging test variables. | ||
logger.add_scalars_dict(prefix='before_aggregation_test', dic=test_result, rnd=i) | ||
|
||
if last_global_model is not None and i > 45: | ||
plot_2d_loss_landscape( | ||
model=last_global_model, | ||
dataloader=part_test, | ||
device=args.learn.device, | ||
caption='Global-test Loss Landscape', | ||
save_path=os.path.join(args.other.logging_path, f"losslandscape_gt_{i}.pdf"), | ||
target_model1=group.clients[0].model, | ||
target_model2=group.clients[1].model, | ||
resolution=20, | ||
noise_range=(-0.1, 1.0), | ||
log_scale=True, | ||
max_val=20 | ||
) | ||
plot_2d_loss_landscape( | ||
model=last_global_model, | ||
dataloader=group.clients[0].train_dataloader, | ||
device=args.learn.device, | ||
caption='Client-1-train Loss Landscape', | ||
save_path=os.path.join(args.other.logging_path, f"losslandscape_ct1_{i}.pdf"), | ||
target_model1=group.clients[0].model, | ||
target_model2=group.clients[1].model, | ||
resolution=20, | ||
noise_range=(-0.1, 1.1), | ||
log_scale=True, | ||
max_val=20 | ||
) | ||
plot_2d_loss_landscape( | ||
model=last_global_model, | ||
dataloader=group.clients[1].train_dataloader, | ||
device=args.learn.device, | ||
caption='Client-2-train Loss Landscape', | ||
save_path=os.path.join(args.other.logging_path, f"losslandscape_ct2_{i}.pdf"), | ||
target_model1=group.clients[0].model, | ||
target_model2=group.clients[1].model, | ||
resolution=20, | ||
noise_range=(-0.1, 1.1), | ||
log_scale=True, | ||
max_val=20 | ||
) | ||
|
||
# Aggregate parameters in each client. | ||
trans_cost = group.aggregate(i) | ||
|
||
last_global_model = copy.deepcopy(group.clients[0].model) | ||
|
||
# Logging train variables. | ||
mean_train_variables = train_monitor.variable_mean() | ||
mean_train_variables.update({'trans_cost': trans_cost / 1e6, 'lr': cur_lr}) | ||
logger.add_scalars_dict(prefix='train', dic=mean_train_variables, rnd=i) | ||
|
||
# Testing | ||
if i % args.other.test_freq == 0 and "after_aggregation" in args.learn.test_place: | ||
test_result = group.server.test(model=group.clients[0].model) | ||
|
||
# Logging test variables. | ||
logger.add_scalars_dict(prefix='after_aggregation_test', dic=test_result, rnd=i) | ||
|
||
# Saving model checkpoints. | ||
torch.save(group.server.glob_dict, os.path.join(args.other.logging_path, 'model.ckpt')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
from .loss_landscape import plot_2d_loss_landscape | ||
from .conv_kernel_visualizer import plot_conv_kernels | ||
from .hessian_eigen_value import calculate_hessian_dominant_eigen_values | ||
from .activation_maximization import ActivationMaximizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torchvision | ||
from torch import nn | ||
|
||
from fling.utils import Logger | ||
|
||
|
||
def plot_conv_kernels(logger: Logger, layer: nn.Conv2d, name: str) -> None: | ||
""" | ||
Overview: | ||
Plot the kernels in a certain convolution layer for better visualization. | ||
Arguments: | ||
logger: The logger to write result image. | ||
layer: The convolution layer to visualize. | ||
name: The name of the plotted figure. | ||
""" | ||
param = layer.weight | ||
in_channels = param.shape[1] | ||
k_w, k_h = param.size()[3], param.size()[2] | ||
kernel_all = param.view(-1, 1, k_w, k_h) | ||
kernel_grid = torchvision.utils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=in_channels) | ||
logger.add_image(f'{name}', kernel_grid, global_step=0) |
25 changes: 25 additions & 0 deletions
25
fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from torchvision.models import resnet18 | ||
|
||
from fling.utils import Logger | ||
from fling.utils.visualize_utils import plot_conv_kernels | ||
from easydict import EasyDict | ||
from fling.utils.registry_utils import MODEL_REGISTRY | ||
|
||
if __name__ == '__main__': | ||
# Step 1: prepare the model. | ||
model_arg = EasyDict(dict( | ||
name='resnet8', | ||
input_channel=3, | ||
class_number=100, | ||
)) | ||
model_name = model_arg.pop('name') | ||
model = MODEL_REGISTRY.build(model_name, **model_arg) | ||
|
||
# You can also initialize the model without using configurations. | ||
# e.g. model = resnet18(pretrained=True) | ||
|
||
# Step 2: prepare the logger. | ||
logger = Logger('resnet18_conv_kernels') | ||
|
||
# Step 3: save the kernels. | ||
plot_conv_kernels(logger, model.conv1, name='pre-conv') |
86 changes: 86 additions & 0 deletions
86
fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from easydict import EasyDict | ||
|
||
import torch | ||
from torch import nn | ||
from torch.utils.data import DataLoader | ||
|
||
from fling import dataset | ||
from fling.utils.visualize_utils import calculate_hessian_dominant_eigen_values | ||
from fling.utils.registry_utils import DATASET_REGISTRY | ||
from easydict import EasyDict | ||
from fling.utils.registry_utils import MODEL_REGISTRY | ||
|
||
|
||
class ToyModel(nn.Module): | ||
""" | ||
Overview: | ||
A toy model for demonstrating attacking results. | ||
""" | ||
|
||
def __init__(self): | ||
super(ToyModel, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) | ||
self.relu1 = nn.ReLU() | ||
|
||
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) | ||
self.relu2 = nn.ReLU() | ||
|
||
self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1) | ||
self.relu3 = nn.ReLU() | ||
|
||
self.pool = nn.AdaptiveAvgPool2d((1, 1)) | ||
self.flat = nn.Flatten() | ||
|
||
self.fc = nn.Linear(128, 10) | ||
|
||
def forward(self, x): | ||
x = self.relu1(self.conv1(x)) | ||
x = self.relu2(self.conv2(x)) | ||
x = self.relu3(self.conv3(x)) | ||
x = self.flat(self.pool(x)) | ||
return self.fc(x) | ||
|
||
|
||
if __name__ == '__main__': | ||
# Step 1: prepare the dataset. | ||
dataset_config = EasyDict(dict(data=dict(data_path='./data/cifar10', transforms=dict()))) | ||
dataset = DATASET_REGISTRY.build('cifar10', dataset_config, train=False) | ||
|
||
# Test dataset is for generating loss landscape. | ||
test_dataset = [dataset[i] for i in range(100)] | ||
test_dataloader = DataLoader(test_dataset, batch_size=100) | ||
|
||
# Step 2: prepare the model. | ||
model_arg = EasyDict(dict( | ||
name='resnet8', | ||
input_channel=3, | ||
class_number=10, | ||
)) | ||
model_name = model_arg.pop('name') | ||
model = MODEL_REGISTRY.build(model_name, **model_arg) | ||
|
||
# You can also initialize the model without using configurations. | ||
# e.g. model = ToyModel() | ||
|
||
# Step 3: train the randomly initialized model. | ||
dataloader = DataLoader(dataset, batch_size=100) | ||
device = 'cuda' | ||
model = model.to(device) | ||
model.train() | ||
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4) | ||
criterion = torch.nn.CrossEntropyLoss() | ||
for _ in range(0): | ||
for _, (data) in enumerate(dataloader): | ||
data_x, data_y = data['input'], data['class_id'] | ||
data_x, data_y = data_x.to(device), data_y.to(device) | ||
pred_y = model(data_x) | ||
loss = criterion(pred_y, data_y) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
model.to('cpu') | ||
|
||
# Step 4: plot the loss landscape after training the model. | ||
# Only one line of code for visualization. | ||
res = calculate_hessian_dominant_eigen_values(model, iter_num=20, dataloader=test_dataloader, device='cuda') | ||
print(res) |
Oops, something went wrong.