From bea66ac307736d814d576965bc924f678fac37a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 22 Jan 2024 14:34:03 +0800 Subject: [PATCH 1/9] init commit --- README.md | 1 + fling/component/client/fedmoon_client.py | 1 - fling/component/group/base_group.py | 5 +- fling/model/cnn.py | 2 +- fling/utils/__init__.py | 1 + fling/utils/torch_utils.py | 6 +- fling/utils/visualize_utils/__init__.py | 1 + .../demo/demo_single_loss_landscape.py | 49 +++++++ fling/utils/visualize_utils/loss_landscape.py | 135 ++++++++++++++++++ flzoo/cifar10/cifar10_fedmoon_cnn_config.py | 9 +- .../cifar100_fedmoon_resnet_config.py | 8 +- setup.py | 2 +- 12 files changed, 200 insertions(+), 20 deletions(-) create mode 100644 fling/utils/visualize_utils/__init__.py create mode 100644 fling/utils/visualize_utils/demo/demo_single_loss_landscape.py create mode 100644 fling/utils/visualize_utils/loss_landscape.py diff --git a/README.md b/README.md index 3421499..c0431d5 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ For other algorithms and datasets, users can refer to `argzoo/` or customize you - Support for a variety of algorithms and datasets. - Support multiprocessing training on each client for better efficiency. - Using single GPU to simulate Federated Learning process (multi-GPU version will be released soon). +- Strong visualization utilities. See [demo](https://github.com/kxzxvbk/Fling/blob/main/fling/utils/visualize_utils/demo) for detailed information. ## Supported Algorithms diff --git a/fling/component/client/fedmoon_client.py b/fling/component/client/fedmoon_client.py index 42a7e21..0e1965f 100644 --- a/fling/component/client/fedmoon_client.py +++ b/fling/component/client/fedmoon_client.py @@ -37,7 +37,6 @@ def _store_prev_model(self, model: nn.Module) -> None: self.prev_models.pop(0) self.prev_models.append(copy.deepcopy(model)) - def _store_global_model(self, model: nn.Module) -> None: r""" Overview: diff --git a/fling/component/group/base_group.py b/fling/component/group/base_group.py index 6d93aae..a0f6847 100644 --- a/fling/component/group/base_group.py +++ b/fling/component/group/base_group.py @@ -98,8 +98,9 @@ def aggregate(self, train_round: int, aggr_parameter_args: dict = None) -> int: # Pick out the parameters for aggregation if needed. 
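        # ``fed_keys`` decides which parameters are shared with the server; the current keys are
        # backed up first, so the temporary selection below can presumably be restored after aggregation.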
if aggr_parameter_args is not None: fed_keys_bak = self.clients[0].fed_keys - new_fed_keys = get_weights(self.clients[0].model, aggr_parameter_args, - return_dict=True, include_non_param=True).keys() + new_fed_keys = get_weights( + self.clients[0].model, aggr_parameter_args, return_dict=True, include_non_param=True + ).keys() for client in self.clients: client.set_fed_keys(new_fed_keys) diff --git a/fling/model/cnn.py b/fling/model/cnn.py index d953798..c57af13 100644 --- a/fling/model/cnn.py +++ b/fling/model/cnn.py @@ -19,7 +19,7 @@ def __init__( activation='relu' ): super(CNNModel, self).__init__() - + self.layers = [] self.layers.append(nn.Conv2d(input_channel, hidden_dims[0], kernel_size=kernel_sizes[0], padding=paddings[0])) self.layers.append(get_activation(name=activation)) diff --git a/fling/utils/__init__.py b/fling/utils/__init__.py index 6c75daf..93a32c9 100644 --- a/fling/utils/__init__.py +++ b/fling/utils/__init__.py @@ -4,3 +4,4 @@ from .utils import Logger, client_sampling, VariableMonitor from .data_utils import get_data_transform from .launcher_utils import get_launcher +from .visualize_utils import plot_2d_loss_landscape diff --git a/fling/utils/torch_utils.py b/fling/utils/torch_utils.py index 3a9cf61..bcf4a6d 100644 --- a/fling/utils/torch_utils.py +++ b/fling/utils/torch_utils.py @@ -64,8 +64,10 @@ def calculate_mean_std(train_dataset: Dataset, test_dataset: Dataset) -> tuple: return reduce(lambda x, y: x + y, res) / len(res), reduce(lambda x, y: x + y, res_std) / len(res) -def get_weights(model: nn.Module, parameter_args: dict, - return_dict: bool = False, include_non_param: bool = False) -> Union[List, Dict]: +def get_weights(model: nn.Module, + parameter_args: dict, + return_dict: bool = False, + include_non_param: bool = False) -> Union[List, Dict]: """ Overview: Get model parameters, using the given ``parameter_args``. diff --git a/fling/utils/visualize_utils/__init__.py b/fling/utils/visualize_utils/__init__.py new file mode 100644 index 0000000..7ce876e --- /dev/null +++ b/fling/utils/visualize_utils/__init__.py @@ -0,0 +1 @@ +from .loss_landscape import plot_2d_loss_landscape diff --git a/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py new file mode 100644 index 0000000..ec0c9e4 --- /dev/null +++ b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py @@ -0,0 +1,49 @@ +import torch +from torch.utils.data import DataLoader +from torchvision.models import resnet18 +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +from fling.utils.visualize_utils import plot_2d_loss_landscape + +if __name__ == '__main__': + # Step 1: prepare the dataset. + transform = transforms.Compose([transforms.ToTensor()]) + dataset = CIFAR10('./data/cifar10', transform=transform) + + # Test dataset is for generating loss landscape. + test_dataset = [dataset[i] for i in range(100)] + test_dataloader = DataLoader(test_dataset, batch_size=100) + + # Step 2: prepare the model. + model = resnet18(pretrained=False, num_classes=10) + + # Step 3: train the randomly initialized model. 
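+    # A short optimization run, so that the landscape below is plotted around a trained point
+    # rather than a random initialization.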
+    dataloader = DataLoader(dataset, batch_size=100)
+    device = 'cuda'
+    model = model.to(device)
+    model.train()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    criterion = torch.nn.CrossEntropyLoss()
+    for _ in range(10):
+        for _, (data_x, data_y) in enumerate(dataloader):
+            data_x, data_y = data_x.to(device), data_y.to(device)
+            pred_y = model(data_x)
+            loss = criterion(pred_y, data_y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+    model.to('cpu')
+
+    # Step 4: plot the loss landscape after training the model.
+    # Only one line of code for visualization!
+    plot_2d_loss_landscape(
+        model=model,
+        dataloader=test_dataloader,
+        device='cuda',
+        caption='Loss Landscape Trained',
+        save_path='./landscape.pdf',
+        noise_range=(-0.01, 0.01),
+        resolution=30,
+        log_scale=True
+    )
diff --git a/fling/utils/visualize_utils/loss_landscape.py b/fling/utils/visualize_utils/loss_landscape.py
new file mode 100644
index 0000000..70ef4e3
--- /dev/null
+++ b/fling/utils/visualize_utils/loss_landscape.py
@@ -0,0 +1,135 @@
+import copy
+from typing import Tuple, Dict
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+
+def _gen_rand_like(tensor: torch.Tensor) -> torch.Tensor:
+    # Return a tensor whose shape is identical to the input tensor.
+    # The returned tensor is filled with random noise, and each row keeps the same norm as the corresponding row of the input.
+    tmp = torch.rand_like(tensor)
+    tmp = tmp / torch.norm(tmp, dim=1, keepdim=True)
+    tmp = tmp * torch.norm(tensor, dim=1, keepdim=True)
+    return tmp
+
+
+def _calc_loss_value(
+    model: nn.Module, data_loader: DataLoader, device: str, criterion: nn.Module = nn.CrossEntropyLoss()
+):
+    # Given a model and corresponding dataset, calculate the mean loss value.
+    model = model.to(device)
+    model.eval()
+    tot_loss = []
+    for _, (data_x, data_y) in enumerate(data_loader):
+        data_x, data_y = data_x.to(device), data_y.to(device)
+        pred_y = model(data_x)
+        loss = criterion(pred_y, data_y)
+        tot_loss.append(loss.item())
+    model.to('cpu')
+    return sum(tot_loss) / len(tot_loss)
+
+
+def plot_2d_loss_landscape(
+        model: nn.Module,
+        dataloader: DataLoader,
+        device: str,
+        caption: str,
+        save_path: str,
+        parameter_args: Dict = {"name": "all"},
+        noise_range: Tuple[float, float] = (-1, 1),
+        resolution: int = 20,
+        visualize: bool = False,
+        log_scale: bool = False
+) -> None:
+    """
+    Overview:
+        This function uses the visualization technique proposed in: Visualizing the Loss Landscape of Neural Nets.
+        Currently, only linear layers and convolution layers will be considered.
+    Arguments:
+        model: The model that is needed to be checked for loss landscape.
+        dataloader: The dataloader used to check the landscape.
+        caption: The caption of generated graph.
+        save_path: The save path of the generated loss landscape picture.
+        parameter_args: Specify what parameters should add noises. Default to be ``{"name": "all"}``. For other \
+            usages, please refer to the usage of ``aggregation_parameters`` in our configuration. A tutorial can \
+            be found in: https://github.com/kxzxvbk/Fling/docs/meaning_for_configurations_en.md.
+        device: The device to run on, such as ``"cuda"`` or ``"cpu"``.
+        noise_range: The coordinate range of the loss-landscape, default to be ``(-1, 1)``.
+        resolution: The resolution of generated landscape. A larger resolution will cost longer time for computation, \
+            but a lower resolution may result in unclear contours.
Default to be ``20``. + visualize: Whether to directly show the picture in GUI. Default to be ``False``. + log_scale: Whether to use a log function to normalize the loss. Default to be ``False``. + """ + # Copy the original model. + orig_model = model + model = copy.deepcopy(model) + + # Generate two random directions. + rand_x, rand_y = {}, {} + for k, layer in model.named_modules(): + if parameter_args['name'] == 'all': + incl = True + elif parameter_args['name'] == 'contain': + kw = parameter_args['keywords'] + incl = any([kk in k for kk in kw]) + elif parameter_args['name'] == 'except': + kw = parameter_args['keywords'] + incl = all([kk not in k for kk in kw]) + else: + raise ValueError(f"Illegal parameter_args: {parameter_args}") + if not incl: + continue + + if isinstance(layer, nn.Linear): + orig_weight = copy.deepcopy(layer.weight) + rand_x0 = _gen_rand_like(orig_weight) + rand_y0 = _gen_rand_like(orig_weight) + elif isinstance(layer, nn.Conv2d): + orig_weight = copy.deepcopy(layer.weight) + orig_weight = orig_weight.reshape(orig_weight.shape[0], -1) + rand_x0 = _gen_rand_like(orig_weight) + rand_y0 = _gen_rand_like(orig_weight) + else: + continue + rand_x[k], rand_y[k] = rand_x0, rand_y0 + + # Generate the meshgrid for loss landscape. + x_coords = torch.linspace(noise_range[0], noise_range[1], resolution) + y_coords = torch.linspace(noise_range[0], noise_range[1], resolution) + loss_values = torch.zeros((resolution, resolution)).float() + + orig_layers = dict(orig_model.named_modules()) + with torch.no_grad(): + for i in tqdm(range(resolution)): + for j in range(resolution): + x_coord, y_coord = x_coords[i], y_coords[j] + for k, layer in model.named_modules(): + if k not in rand_x.keys(): + continue + elif isinstance(layer, nn.Linear): + orig_weight = copy.deepcopy(orig_layers[k].weight) + orig_weight += rand_x[k] * x_coord + rand_y[k] * y_coord + layer.weight = orig_weight + elif isinstance(layer, nn.Conv2d): + orig_weight = copy.deepcopy(orig_layers[k].weight) + orig_shape = orig_weight.shape + orig_weight = orig_weight.reshape(orig_weight.shape[0], -1) + orig_weight += rand_x[k] * x_coord + rand_y[k] * y_coord + layer.weight.data = orig_weight.reshape(orig_shape) + loss_values[i][j] = _calc_loss_value(model=model, data_loader=dataloader, device=device) + if log_scale: + loss_values = torch.log(loss_values) + + # Plot the result. + x_mesh, y_mesh = torch.meshgrid(x_coords, y_coords) + ax = plt.axes(projection='3d') + ax.plot_surface(x_mesh, y_mesh, loss_values, rstride=1, cstride=1, cmap='viridis', edgecolor='none') + ax.set_title(caption) + plt.savefig(save_path) + if visualize: + plt.show() + plt.cla() diff --git a/flzoo/cifar10/cifar10_fedmoon_cnn_config.py b/flzoo/cifar10/cifar10_fedmoon_cnn_config.py index b9fb215..3bbc507 100644 --- a/flzoo/cifar10/cifar10_fedmoon_cnn_config.py +++ b/flzoo/cifar10/cifar10_fedmoon_cnn_config.py @@ -14,15 +14,10 @@ optimizer=dict(name='sgd', lr=0.01, momentum=0.9), # The weight of fedmoon loss. 
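        # (``temperature`` scales the contrastive similarities in the MOON loss, and ``queue_len``
        #  sets how many previous local models the fedmoon client keeps as negatives.)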
mu=5, - temperature=0.5, + temperature=0.5, queue_len=1, ), - model=dict( - name='cnn', - input_channel=3, - linear_hidden_dims=[256], - class_number=10 - ), + model=dict(name='cnn', input_channel=3, linear_hidden_dims=[256], class_number=10), client=dict(name='fedmoon_client', client_num=10), server=dict(name='base_server'), group=dict(name='base_group', aggregation_method='avg'), diff --git a/flzoo/cifar100/cifar100_fedmoon_resnet_config.py b/flzoo/cifar100/cifar100_fedmoon_resnet_config.py index 226e667..75d8487 100644 --- a/flzoo/cifar100/cifar100_fedmoon_resnet_config.py +++ b/flzoo/cifar100/cifar100_fedmoon_resnet_config.py @@ -14,14 +14,10 @@ optimizer=dict(name='sgd', lr=0.01, momentum=0.9), # The weight of fedmoon loss. mu=1, - temperature=0.5, + temperature=0.5, queue_len=1, ), - model=dict( - name='resnet8', - input_channel=3, - class_number=100 - ), + model=dict(name='resnet8', input_channel=3, class_number=100), client=dict(name='fedmoon_client', client_num=10), server=dict(name='base_server'), group=dict(name='base_group', aggregation_method='avg'), diff --git a/setup.py b/setup.py index 989d65a..bce2d70 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ python_requires=">=3.7", install_requires=[ 'setuptools<=66.1.1', 'yapf==0.29.0', 'torch>=1.7.0', 'torchvision', 'numpy>=1.18.0', 'easydict==1.9', - 'tensorboard>=2.10.1', 'tqdm', 'timm', 'click', 'prettytable', 'einops', 'scipy', 'six', 'lmdb' + 'tensorboard>=2.10.1', 'tqdm', 'timm', 'click', 'prettytable', 'einops', 'scipy', 'six', 'lmdb', 'matplotlib' ], extras_require={ 'test': [ From f2f39df8deb7c4b96d15ab492731d33511c8ce4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 26 Jan 2024 11:59:40 +0800 Subject: [PATCH 2/9] fix --- fling/utils/visualize_utils/__init__.py | 1 + .../visualize_utils/conv_kernel_visualizer.py | 21 ++++++++ .../demo/demo_conv_kernel_visualize.py | 14 +++++ .../demo/demo_single_loss_landscape.py | 52 ++++++++++++++++--- fling/utils/visualize_utils/loss_landscape.py | 38 +++++++------- 5 files changed, 100 insertions(+), 26 deletions(-) create mode 100644 fling/utils/visualize_utils/conv_kernel_visualizer.py create mode 100644 fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py diff --git a/fling/utils/visualize_utils/__init__.py b/fling/utils/visualize_utils/__init__.py index 7ce876e..6105a89 100644 --- a/fling/utils/visualize_utils/__init__.py +++ b/fling/utils/visualize_utils/__init__.py @@ -1 +1,2 @@ from .loss_landscape import plot_2d_loss_landscape +from .conv_kernel_visualizer import plot_conv_kernels diff --git a/fling/utils/visualize_utils/conv_kernel_visualizer.py b/fling/utils/visualize_utils/conv_kernel_visualizer.py new file mode 100644 index 0000000..ab29e59 --- /dev/null +++ b/fling/utils/visualize_utils/conv_kernel_visualizer.py @@ -0,0 +1,21 @@ +import torchvision +from torch import nn + +from fling.utils import Logger + + +def plot_conv_kernels(logger: Logger, layer: nn.Conv2d, name: str) -> None: + """ + Overview: + Plot the kernels in a certain convolution layer for better visualization. + Arguments: + logger: The logger to write result image. + layer: The convolution layer to visualize. + name: The name of the plotted figure. 
+    """
+    param = layer.weight
+    in_channels = param.shape[1]
+    k_w, k_h = param.size()[3], param.size()[2]
+    kernel_all = param.view(-1, 1, k_w, k_h)
+    kernel_grid = torchvision.utils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=in_channels)
+    logger.add_image(f'{name}', kernel_grid, global_step=0)
diff --git a/fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py b/fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py
new file mode 100644
index 0000000..4741bf3
--- /dev/null
+++ b/fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py
@@ -0,0 +1,14 @@
+from torchvision.models import resnet18
+
+from fling.utils import Logger
+from fling.utils.visualize_utils import plot_conv_kernels
+
+if __name__ == '__main__':
+    # Step 1: prepare the model.
+    model = resnet18(pretrained=True)
+
+    # Step 2: prepare the logger.
+    logger = Logger('resnet18_conv_kernels')
+
+    # Step 3: save the kernels.
+    plot_conv_kernels(logger, model.conv1, name='pre-conv')
diff --git a/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
index ec0c9e4..df0dc10 100644
--- a/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
+++ b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
@@ -1,15 +1,49 @@
+from easydict import EasyDict
+
 import torch
+from torch import nn
 from torch.utils.data import DataLoader
 from torchvision.models import resnet18
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
+from fling import dataset
 from fling.utils.visualize_utils import plot_2d_loss_landscape
+from fling.utils.registry_utils import DATASET_REGISTRY
+
+
+class ToyModel(nn.Module):
+    """
+    Overview:
+        A toy model for demonstrating the loss landscape visualization.
+    """
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+        self.relu1 = nn.ReLU()
+
+        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
+        self.relu2 = nn.ReLU()
+
+        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.relu3 = nn.ReLU()
+
+        self.pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.flat = nn.Flatten()
+
+        self.fc = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.relu1(self.conv1(x))
+        x = self.relu2(self.conv2(x))
+        x = self.relu3(self.conv3(x))
+        x = self.flat(self.pool(x))
+        return self.fc(x)
+
 
 if __name__ == '__main__':
     # Step 1: prepare the dataset.
-    transform = transforms.Compose([transforms.ToTensor()])
-    dataset = CIFAR10('./data/cifar10', transform=transform)
+    dataset_config = EasyDict(dict(data=dict(data_path='./data/cifar10', transforms=dict())))
+    dataset = DATASET_REGISTRY.build('cifar10', dataset_config, train=False)
 
     # Test dataset is for generating loss landscape.
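     # A fixed 100-sample subset keeps each of the resolution ** 2 landscape evaluations cheap.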
 test_dataset = [dataset[i] for i in range(100)]
@@ -23,10 +57,11 @@
     device = 'cuda'
     model = model.to(device)
     model.train()
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
     criterion = torch.nn.CrossEntropyLoss()
     for _ in range(10):
-        for _, (data_x, data_y) in enumerate(dataloader):
+        for _, (data) in enumerate(dataloader):
+            data_x, data_y = data['input'], data['class_id']
             data_x, data_y = data_x.to(device), data_y.to(device)
             pred_y = model(data_x)
             loss = criterion(pred_y, data_y)
@@ -43,7 +78,8 @@
         device='cuda',
         caption='Loss Landscape Trained',
         save_path='./landscape.pdf',
-        noise_range=(-0.01, 0.01),
+        noise_range=(-1, 1),
         resolution=30,
-        log_scale=True
+        log_scale=True,
+        max_val=20,
     )
diff --git a/fling/utils/visualize_utils/loss_landscape.py b/fling/utils/visualize_utils/loss_landscape.py
index 70ef4e3..6add623 100644
--- a/fling/utils/visualize_utils/loss_landscape.py
+++ b/fling/utils/visualize_utils/loss_landscape.py
@@ -11,10 +11,8 @@
 def _gen_rand_like(tensor: torch.Tensor) -> torch.Tensor:
     # Return a tensor whose shape is identical to the input tensor.
     # The returned tensor is filled with random noise, and each row keeps the same norm as the corresponding row of the input.
-    tmp = torch.rand_like(tensor)
-    tmp = tmp / torch.norm(tmp, dim=1, keepdim=True)
-    tmp = tmp * torch.norm(tensor, dim=1, keepdim=True)
-    return tmp
+    tmp = torch.randn_like(tensor)
+    return tmp * torch.norm(tensor, dim=1, keepdim=True) / torch.norm(tmp, dim=1, keepdim=True)
 
 
 def _calc_loss_value(
@@ -24,8 +22,8 @@
     model = model.to(device)
     model.eval()
     tot_loss = []
-    for _, (data_x, data_y) in enumerate(data_loader):
-        data_x, data_y = data_x.to(device), data_y.to(device)
+    for _, (data) in enumerate(data_loader):
+        data_x, data_y = data['input'].to(device), data['class_id'].to(device)
         pred_y = model(data_x)
         loss = criterion(pred_y, data_y)
         tot_loss.append(loss.item())
@@ -43,7 +41,8 @@
     noise_range: Tuple[float, float] = (-1, 1),
     resolution: int = 20,
     visualize: bool = False,
-    log_scale: bool = False
+    log_scale: bool = False,
+    max_val: float = 5
 ) -> None:
     """
     Overview:
@@ -63,6 +62,7 @@
         but a lower resolution may result in unclear contours. Default to be ``20``.
         visualize: Whether to directly show the picture in GUI. Default to be ``False``.
        log_scale: Whether to use a log function to normalize the loss. Default to be ``False``.
+        max_val: The maximum permitted loss value; larger values are clipped to it for better visualization.
     """
     # Copy the original model.
     orig_model = model
@@ -71,6 +71,7 @@
     # Generate two random directions.
     rand_x, rand_y = {}, {}
     for k, layer in model.named_modules():
+        # Decide which parameters should be included.
         if parameter_args['name'] == 'all':
             incl = True
         elif parameter_args['name'] == 'contain':
@@ -84,13 +85,12 @@
         if not incl:
             continue
 
+        # Generate random noises.
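+        # Each random direction is rescaled row by row (per output neuron / filter) to match the
+        # weight norm, i.e. the filter normalization trick from the paper cited in the docstring.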
if isinstance(layer, nn.Linear): - orig_weight = copy.deepcopy(layer.weight) - rand_x0 = _gen_rand_like(orig_weight) - rand_y0 = _gen_rand_like(orig_weight) + rand_x0 = _gen_rand_like(layer.weight) + rand_y0 = _gen_rand_like(layer.weight) elif isinstance(layer, nn.Conv2d): - orig_weight = copy.deepcopy(layer.weight) - orig_weight = orig_weight.reshape(orig_weight.shape[0], -1) + orig_weight = layer.weight.reshape(layer.weight.shape[0], -1) rand_x0 = _gen_rand_like(orig_weight) rand_y0 = _gen_rand_like(orig_weight) else: @@ -111,16 +111,18 @@ def plot_2d_loss_landscape( if k not in rand_x.keys(): continue elif isinstance(layer, nn.Linear): - orig_weight = copy.deepcopy(orig_layers[k].weight) - orig_weight += rand_x[k] * x_coord + rand_y[k] * y_coord - layer.weight = orig_weight + orig_weight = orig_layers[k].weight.clone() + delta_w = rand_x[k] * x_coord + rand_y[k] * y_coord + orig_weight += delta_w + layer.weight.data = orig_weight elif isinstance(layer, nn.Conv2d): - orig_weight = copy.deepcopy(orig_layers[k].weight) + orig_weight = orig_layers[k].weight.clone() orig_shape = orig_weight.shape orig_weight = orig_weight.reshape(orig_weight.shape[0], -1) - orig_weight += rand_x[k] * x_coord + rand_y[k] * y_coord + delta_w = rand_x[k] * x_coord + rand_y[k] * y_coord + orig_weight += delta_w layer.weight.data = orig_weight.reshape(orig_shape) - loss_values[i][j] = _calc_loss_value(model=model, data_loader=dataloader, device=device) + loss_values[i][j] = min(_calc_loss_value(model=model, data_loader=dataloader, device=device), max_val) if log_scale: loss_values = torch.log(loss_values) From 3a198cf160fc16d18145471779e08706680ebc9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 31 Jan 2024 20:13:34 +0800 Subject: [PATCH 3/9] add hessian eigen value --- dlg_attacker/txt_logger_output.txt | 38 +++++++ fling/utils/visualize_utils/__init__.py | 1 + .../demo/demo_hessian_eigen_value.py | 75 +++++++++++++ .../demo/demo_single_loss_landscape.py | 3 +- .../visualize_utils/hessian_eigen_value.py | 106 ++++++++++++++++++ 5 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 dlg_attacker/txt_logger_output.txt create mode 100644 fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py create mode 100644 fling/utils/visualize_utils/hessian_eigen_value.py diff --git a/dlg_attacker/txt_logger_output.txt b/dlg_attacker/txt_logger_output.txt new file mode 100644 index 0000000..ab2f4dd --- /dev/null +++ b/dlg_attacker/txt_logger_output.txt @@ -0,0 +1,38 @@ +[Fri Jan 26 12:44:10 2024] +----------------+-------+-----------+----------+ +| Phase | Round | last_psnr | max_psnr | ++----------------+-------+-----------+----------+ +| reconstruction | 1 | 12.25909 | 12.98523 | ++----------------+-------+-----------+----------+ +[Fri Jan 26 12:44:10 2024] Star batch: 1... +[Fri Jan 26 12:46:14 2024] +----------------+-------+-----------+----------+ +| Phase | Round | last_psnr | max_psnr | ++----------------+-------+-----------+----------+ +| reconstruction | 2 | 16.02017 | 17.04575 | ++----------------+-------+-----------+----------+ +[Fri Jan 26 12:46:14 2024] Final reconstruction PSNR: 14.139629364013672. Max reconstruction PSNR: 15.015491485595703. +[Fri Jan 26 12:49:05 2024] Star batch: 0... 
+[Fri Jan 26 12:51:11 2024] +----------------+-------+-----------+----------+
+| Phase | Round | last_psnr | max_psnr |
++----------------+-------+-----------+----------+
+| reconstruction | 1 | 19.19348 | 19.27219 |
++----------------+-------+-----------+----------+
+[Fri Jan 26 12:51:11 2024] Star batch: 1...
+[Fri Jan 26 12:53:17 2024] +----------------+-------+-----------+----------+
+| Phase | Round | last_psnr | max_psnr |
++----------------+-------+-----------+----------+
+| reconstruction | 2 | 7.22477 | 9.89688 |
++----------------+-------+-----------+----------+
+[Fri Jan 26 12:53:17 2024] Final reconstruction PSNR: 13.209124565124512. Max reconstruction PSNR: 14.584534168243408.
+[Fri Jan 26 12:55:51 2024] Star batch: 0...
+[Fri Jan 26 12:57:57 2024] +----------------+-------+-----------+----------+
+| Phase | Round | last_psnr | max_psnr |
++----------------+-------+-----------+----------+
+| reconstruction | 1 | 17.92364 | 17.92364 |
++----------------+-------+-----------+----------+
+[Fri Jan 26 12:57:57 2024] Star batch: 1...
+[Fri Jan 26 13:00:02 2024] +----------------+-------+-----------+----------+
+| Phase | Round | last_psnr | max_psnr |
++----------------+-------+-----------+----------+
+| reconstruction | 2 | 10.95893 | 10.95893 |
++----------------+-------+-----------+----------+
+[Fri Jan 26 13:00:02 2024] Final reconstruction PSNR: 14.441285133361816. Max reconstruction PSNR: 14.441285133361816.
diff --git a/fling/utils/visualize_utils/__init__.py b/fling/utils/visualize_utils/__init__.py
index 6105a89..3634fd8 100644
--- a/fling/utils/visualize_utils/__init__.py
+++ b/fling/utils/visualize_utils/__init__.py
@@ -1,2 +1,3 @@
 from .loss_landscape import plot_2d_loss_landscape
 from .conv_kernel_visualizer import plot_conv_kernels
+from .hessian_eigen_value import calculate_hessian_dominant_eigen_values
diff --git a/fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py b/fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py
new file mode 100644
index 0000000..b3875d9
--- /dev/null
+++ b/fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+from fling import dataset
+from fling.utils.visualize_utils import calculate_hessian_dominant_eigen_values
+from fling.utils.registry_utils import DATASET_REGISTRY
+
+
+class ToyModel(nn.Module):
+    """
+    Overview:
+        A toy model for demonstrating the Hessian eigen value computation.
+    """
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+        self.relu1 = nn.ReLU()
+
+        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
+        self.relu2 = nn.ReLU()
+
+        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.relu3 = nn.ReLU()
+
+        self.pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.flat = nn.Flatten()
+
+        self.fc = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.relu1(self.conv1(x))
+        x = self.relu2(self.conv2(x))
+        x = self.relu3(self.conv3(x))
+        x = self.flat(self.pool(x))
+        return self.fc(x)
+
+
+if __name__ == '__main__':
+    # Step 1: prepare the dataset.
+    dataset_config = EasyDict(dict(data=dict(data_path='./data/cifar10', transforms=dict())))
+    dataset = DATASET_REGISTRY.build('cifar10', dataset_config, train=False)
+
+    # Test dataset is for calculating the Hessian eigen values.
+    test_dataset = [dataset[i] for i in range(100)]
+    test_dataloader = DataLoader(test_dataset, batch_size=100)
+
+    # Step 2: prepare the model.
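+    # A compact hand-written CNN keeps the per-layer Hessian-vector products cheap to evaluate.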
+    model = ToyModel()
+
+    # Step 3: train the randomly initialized model.
+    dataloader = DataLoader(dataset, batch_size=100)
+    device = 'cuda'
+    model = model.to(device)
+    model.train()
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+    criterion = torch.nn.CrossEntropyLoss()
+    for _ in range(0):
+        for _, (data) in enumerate(dataloader):
+            data_x, data_y = data['input'], data['class_id']
+            data_x, data_y = data_x.to(device), data_y.to(device)
+            pred_y = model(data_x)
+            loss = criterion(pred_y, data_y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+    model.to('cpu')
+
+    # Step 4: calculate the dominant Hessian eigen values after training the model.
+    # Only one line of code for visualization!
+    res = calculate_hessian_dominant_eigen_values(model, iter_num=20, dataloader=test_dataloader, device='cuda')
+    print(res)
diff --git a/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
index df0dc10..4ab1289 100644
--- a/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
+++ b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
@@ -3,7 +3,6 @@
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
-from torchvision.models import resnet18
 
 from fling import dataset
 from fling.utils.visualize_utils import plot_2d_loss_landscape
@@ -50,7 +49,7 @@
     test_dataloader = DataLoader(test_dataset, batch_size=100)
 
     # Step 2: prepare the model.
-    model = resnet18(pretrained=False, num_classes=10)
+    model = ToyModel()
 
     # Step 3: train the randomly initialized model.
     dataloader = DataLoader(dataset, batch_size=100)
diff --git a/fling/utils/visualize_utils/hessian_eigen_value.py b/fling/utils/visualize_utils/hessian_eigen_value.py
new file mode 100644
index 0000000..853de94
--- /dev/null
+++ b/fling/utils/visualize_utils/hessian_eigen_value.py
@@ -0,0 +1,106 @@
+from typing import Sequence, List, Dict
+import copy
+
+import torch
+from torch import nn
+from torch.autograd import grad
+from torch.utils.data import DataLoader
+
+
+def _get_first_grad(loss: torch.Tensor, w: List) -> Sequence:
+    """
+    Calculate: g_i = \\frac{dL}{dW_i}
+    """
+    return grad(loss, w, retain_graph=True, create_graph=True)
+
+
+def _get_hv(g: Sequence, w: Sequence, v: Sequence) -> Sequence:
+    """
+    Calculate: Hv = \\frac{d(gv)}{dW_i}
+    """
+    assert len(w) == len(v)
+    return grad(g, w, grad_outputs=v, retain_graph=True)
+
+
+def _normalize(vs: Sequence) -> None:
+    """
+    Normalize vectors in ``vs``.
+    """
+    for i in range(len(vs)):
+        vs[i] = vs[i] / torch.norm(vs[i])
+
+
+def _calc_loss_value(
+    model: nn.Module, data_loader: DataLoader, device: str, criterion: nn.Module = nn.CrossEntropyLoss()
+):
+    # Given a model and corresponding dataset, calculate the mean loss value.
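+    # Note: the losses are kept as tensors (no ``.item()``) so that the averaged loss stays
+    # differentiable for the gradient and Hessian-vector products computed later.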
+    model.eval()
+    tot_loss = []
+    for _, (data) in enumerate(data_loader):
+        data_x, data_y = data['input'].to(device), data['class_id'].to(device)
+        pred_y = model(data_x)
+        loss = criterion(pred_y, data_y)
+        tot_loss.append(loss)
+    tot_loss = torch.stack(tot_loss, dim=0)
+    return torch.mean(tot_loss)
+
+
+def _rayleigh_quotient(hv: Sequence, v: Sequence) -> List:
+    """
+    Calculate: \\lambda = \\frac{v^THv}{v^Tv}
+    """
+    return [((torch.flatten(v[i].T) @ torch.flatten(hv[i])) /
+             (torch.flatten(v[i].T) @ torch.flatten(v[i]))).item() for i in range(len(hv))]
+
+
+def calculate_hessian_dominant_eigen_values(
+        model: nn.Module,
+        iter_num: int,
+        dataloader: DataLoader,
+        device: str
+) -> Dict:
+    """
+    Overview:
+        Calculate each dominant eigen value of each layer in the model.
+    Arguments:
+        model: The neural network that calculates ``loss``.
+        iter_num: Number of iterations using power iteration.
+        dataloader: The dataloader used to calculate hessian eigen values.
+        device: The device to run on, such as ``"cuda"`` or ``"cpu"``.
+    Returns:
+        A dict of dominant eigen values for each layer.
+    """
+    # Copy the original model.
+    orig_model = model
+    model = copy.deepcopy(model).to(device)
+
+    # Calculate loss value using given data.
+    loss = _calc_loss_value(model, data_loader=dataloader, device=device)
+
+    # Calculate eigen values and return.
+    # Flatten the parameter weights.
+    ws = dict(model.named_parameters())
+    keys = list(ws.keys())
+    ws = list(ws.values())
+
+    # Calculate grad.
+    g = _get_first_grad(loss, ws)
+
+    # Initialize vs and normalize them.
+    vs = [torch.randn_like(g[i]) for i in range(len(g))]
+    _normalize(vs)
+
+    # Power iteration.
+    for i in range(iter_num):
+        hv = _get_hv(g, ws, vs)
+        print(f'Iteration: {i}')
+        print(_rayleigh_quotient(hv, vs))
+        vs = [hv[i].detach() for i in range(len(hv))]
+        _normalize(vs)
+
+    # Calculate eigen values.
+    hv = _get_hv(g, ws, vs)
+    lambdas = _rayleigh_quotient(hv, vs)
+    dict_lambdas = {keys[i]: lambdas[i] for i in range(len(lambdas))}
+
+    return dict_lambdas
From f6769f2c9c0181bcbf3d619128fc7c99c7cc4c0e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 31 Jan 2024 20:14:01 +0800
Subject: [PATCH 4/9] remove redundant files

---
 dlg_attacker/txt_logger_output.txt | 38 ------------------------------
 1 file changed, 38 deletions(-)
 delete mode 100644 dlg_attacker/txt_logger_output.txt

diff --git a/dlg_attacker/txt_logger_output.txt b/dlg_attacker/txt_logger_output.txt
deleted file mode 100644
index ab2f4dd..0000000
--- a/dlg_attacker/txt_logger_output.txt
+++ /dev/null
@@ -1,38 +0,0 @@
-[Fri Jan 26 12:44:10 2024] +----------------+-------+-----------+----------+
-| Phase | Round | last_psnr | max_psnr |
-+----------------+-------+-----------+----------+
-| reconstruction | 1 | 12.25909 | 12.98523 |
-+----------------+-------+-----------+----------+
-[Fri Jan 26 12:44:10 2024] Star batch: 1...
-[Fri Jan 26 12:46:14 2024] +----------------+-------+-----------+----------+
-| Phase | Round | last_psnr | max_psnr |
-+----------------+-------+-----------+----------+
-| reconstruction | 2 | 16.02017 | 17.04575 |
-+----------------+-------+-----------+----------+
-[Fri Jan 26 12:46:14 2024] Final reconstruction PSNR: 14.139629364013672. Max reconstruction PSNR: 15.015491485595703.
-[Fri Jan 26 12:49:05 2024] Star batch: 0...
-[Fri Jan 26 12:51:11 2024] +----------------+-------+-----------+----------+ -| Phase | Round | last_psnr | max_psnr | -+----------------+-------+-----------+----------+ -| reconstruction | 1 | 19.19348 | 19.27219 | -+----------------+-------+-----------+----------+ -[Fri Jan 26 12:51:11 2024] Star batch: 1... -[Fri Jan 26 12:53:17 2024] +----------------+-------+-----------+----------+ -| Phase | Round | last_psnr | max_psnr | -+----------------+-------+-----------+----------+ -| reconstruction | 2 | 7.22477 | 9.89688 | -+----------------+-------+-----------+----------+ -[Fri Jan 26 12:53:17 2024] Final reconstruction PSNR: 13.209124565124512. Max reconstruction PSNR: 14.584534168243408. -[Fri Jan 26 12:55:51 2024] Star batch: 0... -[Fri Jan 26 12:57:57 2024] +----------------+-------+-----------+----------+ -| Phase | Round | last_psnr | max_psnr | -+----------------+-------+-----------+----------+ -| reconstruction | 1 | 17.92364 | 17.92364 | -+----------------+-------+-----------+----------+ -[Fri Jan 26 12:57:57 2024] Star batch: 1... -[Fri Jan 26 13:00:02 2024] +----------------+-------+-----------+----------+ -| Phase | Round | last_psnr | max_psnr | -+----------------+-------+-----------+----------+ -| reconstruction | 2 | 10.95893 | 10.95893 | -+----------------+-------+-----------+----------+ -[Fri Jan 26 13:00:02 2024] Final reconstruction PSNR: 14.441285133361816. Max reconstruction PSNR: 14.441285133361816. From 4b10605dcd0b5509b95c4060984eece7f619af3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 31 Jan 2024 21:08:02 +0800 Subject: [PATCH 5/9] polish doc --- fling/utils/visualize_utils/hessian_eigen_value.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fling/utils/visualize_utils/hessian_eigen_value.py b/fling/utils/visualize_utils/hessian_eigen_value.py index 853de94..63422b7 100644 --- a/fling/utils/visualize_utils/hessian_eigen_value.py +++ b/fling/utils/visualize_utils/hessian_eigen_value.py @@ -61,7 +61,9 @@ def calculate_hessian_dominant_eigen_values( ) -> Dict: """ Overview: - Calculate each dominant eigen value of each layer in the model. + Using power iteration to calculate each dominant eigen value of each layer in the model. + Reference paper: HAWQ: Hessian AWare Quantization of Neural Networks with Mixed-Precision + Arguments: model: The neural network that calculates ``loss``. iter_num: Number of iterations using power iteration. @@ -93,8 +95,6 @@ def calculate_hessian_dominant_eigen_values( # Power iteration. 
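     # (Power iteration: v <- Hv / ||Hv||. For a symmetric Hessian, the iterate converges to the
     #  dominant eigenvector, provided the initial v is not orthogonal to it, and the Rayleigh
     #  quotient v^T H v / v^T v converges to the dominant eigen value.)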
 for i in range(iter_num):
         hv = _get_hv(g, ws, vs)
-        print(f'Iteration: {i}')
-        print(_rayleigh_quotient(hv, vs))
         vs = [hv[i].detach() for i in range(len(hv))]
         _normalize(vs)
 
From 3c9ba05fc01802cf7929102841afee05d44fcab2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Thu, 7 Mar 2024 23:53:53 +0800
Subject: [PATCH 6/9] new

---
 fling/pipeline/__init__.py | 2 +
 .../generic_model_visualization_pipeline.py | 122 ++++++++++++++++++
 fling/utils/torch_utils.py | 22 ++++
 fling/utils/visualize_utils/loss_landscape.py | 59 ++++++++-
 ...ar10_fedavg_visualization_resnet_config.py | 25 ++++
 5 files changed, 224 insertions(+), 6 deletions(-)
 create mode 100644 fling/pipeline/generic_model_visualization_pipeline.py
 create mode 100644 flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py

diff --git a/fling/pipeline/__init__.py b/fling/pipeline/__init__.py
index ea2c657..eae5bc6 100644
--- a/fling/pipeline/__init__.py
+++ b/fling/pipeline/__init__.py
@@ -1,2 +1,4 @@
 from .generic_model_pipeline import generic_model_pipeline
 from .personalized_model_pipeline import personalized_model_pipeline
+from .generic_model_visualization_pipeline import generic_model_visualization_pipeline
+
diff --git a/fling/pipeline/generic_model_visualization_pipeline.py b/fling/pipeline/generic_model_visualization_pipeline.py
new file mode 100644
index 0000000..9953e9b
--- /dev/null
+++ b/fling/pipeline/generic_model_visualization_pipeline.py
@@ -0,0 +1,122 @@
+import copy
+import os
+import torch
+from torch.utils.data import DataLoader
+
+from fling.component.client import get_client
+from fling.component.server import get_server
+from fling.component.group import get_group
+from fling.dataset import get_dataset
+
+from fling.utils.data_utils import data_sampling
+from fling.utils import Logger, compile_config, client_sampling, VariableMonitor, LRScheduler
+from fling.utils import get_launcher
+from fling.utils import plot_2d_loss_landscape
+
+
+def generic_model_visualization_pipeline(args: dict, seed: int = 0) -> None:
+    r"""
+    Overview:
+        Pipeline for generic federated learning. Under this setting, the models of all clients are the same.
+        We plot the loss landscape before and after aggregation in each round.
+        The final performance of this generic model is tested on the server (typically using a global test dataset).
+    Arguments:
+        - args: dict type arguments.
+        - seed: random seed.
+    """
+    # Compile the input arguments first.
+    args = compile_config(args, seed)
+
+    # Construct logger.
+    logger = Logger(args.other.logging_path)
+
+    # Load dataset.
+    train_set = get_dataset(args, train=True)
+    test_set = get_dataset(args, train=False)
+
+    part_test = [test_set[i] for i in range(100)]
+    part_test = DataLoader(part_test, batch_size=args.learn.batch_size, shuffle=True)
+
+    # Split dataset into clients.
+    train_sets = data_sampling(train_set, args, seed, train=True)
+
+    # Initialize group, clients and server.
+    group = get_group(args, logger)
+    group.server = get_server(args, test_dataset=test_set)
+    for i in range(args.client.client_num):
+        group.append(get_client(args=args, client_id=i, train_dataset=train_sets[i]))
+    group.initialize()
+
+    # Setup lr_scheduler.
+    lr_scheduler = LRScheduler(args)
+    # Setup launcher.
+    launcher = get_launcher(args)
+
+    # Variables for visualization
+    last_global_model = None
+
+    # Training loop
+    for i in range(args.learn.global_eps):
+        logger.logging('Starting round: ' + str(i))
+        # Initialize variable monitor.
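+        # The monitor gathers per-client training metrics; they are averaged via ``variable_mean()`` below.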
+        train_monitor = VariableMonitor()
+
+        # Randomly sample the clients participating in each communication round.
+        participated_clients = client_sampling(range(args.client.client_num), args.client.sample_rate)
+
+        # Adjust learning rate.
+        cur_lr = lr_scheduler.get_lr(train_round=i)
+
+        # Local training for each participating client; add the results to the monitor.
+        # Use multiprocessing for acceleration.
+        train_results = launcher.launch(
+            clients=[group.clients[j] for j in participated_clients], lr=cur_lr, task_name='train'
+        )
+        for item in train_results:
+            train_monitor.append(item)
+
+        # Testing
+        if i % args.other.test_freq == 0 and "before_aggregation" in args.learn.test_place:
+            test_result = group.server.test(model=group.clients[0].model)
+            # Logging test variables.
+            logger.add_scalars_dict(prefix='before_aggregation_test', dic=test_result, rnd=i)
+
+            if last_global_model is not None:
+                plot_2d_loss_landscape(model=last_global_model, dataloader=part_test,
+                                       device=args.learn.device, caption='Global-test Loss Landscape',
+                                       save_path=os.path.join(args.other.logging_path, f"losslandscape_gt_{i}.pdf"),
+                                       target_model1=group.clients[0].model, target_model2=group.clients[1].model,
+                                       resolution=20, noise_range=(-0.1, 1.0),
+                                       log_scale=True, max_val=20)
+                plot_2d_loss_landscape(model=last_global_model, dataloader=group.clients[0].train_dataloader,
+                                       device=args.learn.device, caption='Client-1-train Loss Landscape',
+                                       save_path=os.path.join(args.other.logging_path, f"losslandscape_ct1_{i}.pdf"),
+                                       target_model1=group.clients[0].model, target_model2=group.clients[1].model,
+                                       resolution=20, noise_range=(-0.1, 1.1),
+                                       log_scale=True, max_val=20)
+                plot_2d_loss_landscape(model=last_global_model, dataloader=group.clients[1].train_dataloader,
+                                       device=args.learn.device, caption='Client-2-train Loss Landscape',
+                                       save_path=os.path.join(args.other.logging_path, f"losslandscape_ct2_{i}.pdf"),
+                                       target_model1=group.clients[0].model, target_model2=group.clients[1].model,
+                                       resolution=20, noise_range=(-0.1, 1.1),
+                                       log_scale=True, max_val=20)
+
+        # Aggregate parameters in each client.
+        trans_cost = group.aggregate(i)
+
+        last_global_model = copy.deepcopy(group.clients[0].model)
+
+        # Logging train variables.
+        mean_train_variables = train_monitor.variable_mean()
+        mean_train_variables.update({'trans_cost': trans_cost / 1e6, 'lr': cur_lr})
+        logger.add_scalars_dict(prefix='train', dic=mean_train_variables, rnd=i)
+
+        # Testing
+        if i % args.other.test_freq == 0 and "after_aggregation" in args.learn.test_place:
+            test_result = group.server.test(model=group.clients[0].model)
+
+            # Logging test variables.
+            logger.add_scalars_dict(prefix='after_aggregation_test', dic=test_result, rnd=i)
+
+    # Saving model checkpoints.
+    torch.save(group.server.glob_dict, os.path.join(args.other.logging_path, 'model.ckpt'))
diff --git a/fling/utils/torch_utils.py b/fling/utils/torch_utils.py
index b5398c1..5f23805 100644
--- a/fling/utils/torch_utils.py
+++ b/fling/utils/torch_utils.py
@@ -1,3 +1,4 @@
+import copy
 import math
 import pickle
 import random
@@ -244,3 +245,24 @@ def forward(self, x):
 
     def _tensor_size(self, t):
         return t.size()[1] * t.size()[2] * t.size()[3]
+
+
+def model_add(model1: nn.Module, model2: nn.Module) -> nn.Module:
+    ret = copy.deepcopy(model1)
+    sd1, sd2 = model1.state_dict(), model2.state_dict()
+    ret.load_state_dict({k: sd1[k] + sd2[k] for k in sd1.keys()})
+    return ret
+
+
+def model_sub(model1: nn.Module, model2: nn.Module) -> nn.Module:
+    ret = copy.deepcopy(model1)
+    sd1, sd2 = model1.state_dict(), model2.state_dict()
+    ret.load_state_dict({k: sd1[k] - sd2[k] for k in sd1.keys()})
+    return ret
+
+
+def model_mul(scalar: float, model: nn.Module) -> nn.Module:
+    ret = copy.deepcopy(model)
+    sd = model.state_dict()
+    ret.load_state_dict({k: scalar * sd[k] for k in sd.keys()})
+    return ret
diff --git a/fling/utils/visualize_utils/loss_landscape.py b/fling/utils/visualize_utils/loss_landscape.py
index 6add623..8c33ef1 100644
--- a/fling/utils/visualize_utils/loss_landscape.py
+++ b/fling/utils/visualize_utils/loss_landscape.py
@@ -1,7 +1,9 @@
 import copy
-from typing import Tuple, Dict
+import math
+from typing import Tuple, Dict, Optional
 from matplotlib import pyplot as plt
 from tqdm import tqdm
+from fling.utils.torch_utils import model_add, model_sub, model_mul
 
 import torch
 from torch import nn
@@ -16,7 +18,7 @@
 
 
 def _calc_loss_value(
-    model: nn.Module, data_loader: DataLoader, device: str, criterion: nn.Module = nn.CrossEntropyLoss()
+        model: nn.Module, data_loader: DataLoader, device: str, criterion: nn.Module = nn.CrossEntropyLoss()
 ):
     # Given a model and corresponding dataset, calculate the mean loss value.
     model = model.to(device)
@@ -37,12 +39,14 @@
     device: str,
     caption: str,
     save_path: str,
+    target_model1: Optional[nn.Module] = None,
+    target_model2: Optional[nn.Module] = None,
     parameter_args: Dict = {"name": "all"},
     noise_range: Tuple[float, float] = (-1, 1),
     resolution: int = 20,
     visualize: bool = False,
     log_scale: bool = False,
-    max_val: float = 5
+    max_val: float = 5,
 ) -> None:
     """
     Overview:
@@ -53,6 +57,9 @@
         dataloader: The dataloader used to check the landscape.
         caption: The caption of generated graph.
         save_path: The save path of the generated loss landscape picture.
+        target_model1: If specified, the first direction of visualization will not be randomly chosen, but will be \
+            set as ``target_model1 - model``.
+        target_model2: Similar to ``target_model1``, determines the second direction of visualization.
         parameter_args: Specify what parameters should add noises. Default to be ``{"name": "all"}``. For other \
             usages, please refer to the usage of ``aggregation_parameters`` in our configuration. A tutorial can \
             be found in: https://github.com/kxzxvbk/Fling/docs/meaning_for_configurations_en.md.
@@ -67,6 +74,7 @@
     # Copy the original model.
     orig_model = model
     model = copy.deepcopy(model)
+    model.eval()
 
     # Generate two random directions.
     rand_x, rand_y = {}, {}
@@ -86,10 +94,19 @@
             continue
 
         # Generate random noises.
-        if isinstance(layer, nn.Linear):
+        if (isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d) or isinstance(layer, nn.BatchNorm2d)) \
+                and target_model1 is not None and target_model2 is not None:
+            # Detail: If the target models are specified, all parameters together with statistics (e.g. BN statistics)
+            # will be perturbed. If the target models are not specified, only weight tensors in linear layers and
+            # convolution layers will be perturbed.
+            rand_x0 = model_sub(dict(target_model1.named_modules())[k], layer)
+            rand_y0 = model_sub(dict(target_model2.named_modules())[k], layer)
+        elif isinstance(layer, nn.Linear):
+            # Generate random linear weight tensors.
             rand_x0 = _gen_rand_like(layer.weight)
             rand_y0 = _gen_rand_like(layer.weight)
         elif isinstance(layer, nn.Conv2d):
+            # Generate random convolution weight tensors.
             orig_weight = layer.weight.reshape(layer.weight.shape[0], -1)
             rand_x0 = _gen_rand_like(orig_weight)
             rand_y0 = _gen_rand_like(orig_weight)
         else:
@@ -110,12 +127,20 @@
         for k, layer in model.named_modules():
             if k not in rand_x.keys():
                 continue
+            elif target_model1 is not None and target_model2 is not None:
+                # If the target models are specified, perturb the whole layer module.
+                new_layer = model_add(model_mul(x_coord, rand_x[k]), model_mul(y_coord, rand_y[k]))
+                new_layer = model_add(new_layer, orig_layers[k])
+                # Copy the newly generated layer into the original object.
+                layer.load_state_dict(new_layer.state_dict())
             elif isinstance(layer, nn.Linear):
+                # Target models are not specified; only operate on the weight tensors.
                 orig_weight = orig_layers[k].weight.clone()
                 delta_w = rand_x[k] * x_coord + rand_y[k] * y_coord
                 orig_weight += delta_w
                 layer.weight.data = orig_weight
             elif isinstance(layer, nn.Conv2d):
+                # Operate on the convolution weight tensors.
                 orig_weight = orig_layers[k].weight.clone()
                 orig_shape = orig_weight.shape
                 orig_weight = orig_weight.reshape(orig_weight.shape[0], -1)
@@ -129,9 +154,31 @@
     # Plot the result.
     x_mesh, y_mesh = torch.meshgrid(x_coords, y_coords)
     ax = plt.axes(projection='3d')
-    ax.plot_surface(x_mesh, y_mesh, loss_values, rstride=1, cstride=1, cmap='viridis', edgecolor='none')
+
+    # Add special dots.
+    non_nan_tensor = loss_values[~torch.isnan(loss_values)]
+    max_loss = torch.max(non_nan_tensor)
+
+    loss1 = _calc_loss_value(model=orig_model, data_loader=dataloader, device=device)
+    if log_scale:
+        loss1 = math.log(loss1)
+    ax.text(0, 0, max_loss, "GM ({:.2f})".format(loss1), color='black', zorder=2)
+
+    # Add client model dots.
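+    # Mark the two local (client) models with their own loss values, so the surface can be read
+    # relative to the global model (GM) annotated above.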
+    if target_model1 is not None and target_model2 is not None:
+        loss1 = _calc_loss_value(model=target_model1, data_loader=dataloader, device=device)
+        if log_scale:
+            loss1 = math.log(loss1)
+        ax.text(1, 1, max_loss, "LM1 ({:.2f})".format(loss1), color='black', zorder=2)
+
+        loss2 = _calc_loss_value(model=target_model2, data_loader=dataloader, device=device)
+        if log_scale:
+            loss2 = math.log(loss2)
+        ax.text(0, 1, max_loss, "LM2 ({:.2f})".format(loss2), color='black', zorder=2)
+
+    ax.plot_surface(x_mesh, y_mesh, loss_values, rstride=1, cstride=1, cmap='viridis', edgecolor='none', zorder=1)
     ax.set_title(caption)
     plt.savefig(save_path)
     if visualize:
         plt.show()
-    plt.cla()
+    plt.close()
diff --git a/flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py b/flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py
new file mode 100644
index 0000000..898392c
--- /dev/null
+++ b/flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py
@@ -0,0 +1,25 @@
+from easydict import EasyDict
+
+exp_args = dict(
+    data=dict(
+        dataset='cifar10', data_path='./data/CIFAR10', sample_method=dict(name='iid', train_num=500, test_num=100)
+    ),
+    learn=dict(
+        device='cuda:0', local_eps=5, global_eps=10, batch_size=100, optimizer=dict(name='sgd', lr=0.1, momentum=0.9)
+    ),
+    model=dict(
+        name='resnet8',
+        input_channel=3,
+        class_number=10,
+    ),
+    client=dict(name='base_client', client_num=2),
+    server=dict(name='base_server'),
+    group=dict(name='base_group', aggregation_method='avg'),
+    other=dict(test_freq=3, logging_path='./logging/cifar10_fedavg_visualization_resnet_iid')
+)
+
+exp_args = EasyDict(exp_args)
+
+if __name__ == '__main__':
+    from fling.pipeline import generic_model_visualization_pipeline
+    generic_model_visualization_pipeline(exp_args, seed=0)
From 4b10605dcd0b5509b95c4060984eece7f619af3e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Fri, 8 Mar 2024 20:56:12 +0800
Subject: [PATCH 7/9] change

---
 fling/pipeline/generic_model_visualization_pipeline.py | 2 +-
 flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fling/pipeline/generic_model_visualization_pipeline.py b/fling/pipeline/generic_model_visualization_pipeline.py
index 9953e9b..b60c8a8 100644
--- a/fling/pipeline/generic_model_visualization_pipeline.py
+++ b/fling/pipeline/generic_model_visualization_pipeline.py
@@ -81,7 +81,7 @@
             # Logging test variables.
logger.add_scalars_dict(prefix='before_aggregation_test', dic=test_result, rnd=i) - if last_global_model is not None: + if last_global_model is not None and i > 45: plot_2d_loss_landscape(model=last_global_model, dataloader=part_test, device=args.learn.device, caption='Global-test Loss Landscape', save_path=os.path.join(args.other.logging_path, f"losslandscape_gt_{i}.pdf"), diff --git a/flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py b/flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py index 898392c..fa2b790 100644 --- a/flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py +++ b/flzoo/cifar10/cifar10_fedavg_visualization_resnet_config.py @@ -5,7 +5,7 @@ dataset='cifar10', data_path='./data/CIFAR10', sample_method=dict(name='iid', train_num=500, test_num=100) ), learn=dict( - device='cuda:0', local_eps=5, global_eps=10, batch_size=100, optimizer=dict(name='sgd', lr=0.1, momentum=0.9) + device='cuda:0', local_eps=5, global_eps=50, batch_size=100, optimizer=dict(name='sgd', lr=0.1, momentum=0.9) ), model=dict( name='resnet8', From 7ff95d51f340f89fffea1777c7e15f0f90d1aeb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 11 Jun 2024 12:08:49 +0800 Subject: [PATCH 8/9] polish --- fling/model/resnet.py | 1 - .../demo/demo_conv_kernel_visualize.py | 13 ++++++++++++- .../demo/demo_hessian_eigen_value.py | 15 +++++++++++++-- .../demo/demo_single_loss_landscape.py | 13 ++++++++++++- 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/fling/model/resnet.py b/fling/model/resnet.py index 7f058b3..b94cc0c 100644 --- a/fling/model/resnet.py +++ b/fling/model/resnet.py @@ -260,7 +260,6 @@ def _forward_impl(self, x: Tensor, mode: str = 'compute-logit') -> Tensor: x = self.avgpool(x) x = torch.flatten(x, 1) - x = self.mlp(x) y = self.fc(x) if mode == 'compute-logit': return y diff --git a/fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py b/fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py index 4741bf3..0058088 100644 --- a/fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py +++ b/fling/utils/visualize_utils/demo/demo_conv_kernel_visualize.py @@ -2,10 +2,21 @@ from fling.utils import Logger from fling.utils.visualize_utils import plot_conv_kernels +from easydict import EasyDict +from fling.utils.registry_utils import MODEL_REGISTRY if __name__ == '__main__': # Step 1: prepare the model. - model = resnet18(pretrained=True) + model_arg = EasyDict(dict( + name='resnet8', + input_channel=3, + class_number=100, + )) + model_name = model_arg.pop('name') + model = MODEL_REGISTRY.build(model_name, **model_arg) + + # You can also initialize the model without using configurations. + # e.g. model = resnet18(pretrained=True) # Step 2: prepare the logger. 
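     # The logger exposes ``add_image``, which ``plot_conv_kernels`` uses to write the kernel grid.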
 logger = Logger('resnet18_conv_kernels')
 
     # Step 3: save the kernels.
     plot_conv_kernels(logger, model.conv1, name='pre-conv')
diff --git a/fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py b/fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py
index b3875d9..f269893 100644
--- a/fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py
+++ b/fling/utils/visualize_utils/demo/demo_hessian_eigen_value.py
@@ -7,6 +7,8 @@
 from fling import dataset
 from fling.utils.visualize_utils import calculate_hessian_dominant_eigen_values
 from fling.utils.registry_utils import DATASET_REGISTRY
+from easydict import EasyDict
+from fling.utils.registry_utils import MODEL_REGISTRY
 
 
 class ToyModel(nn.Module):
@@ -49,7 +51,16 @@ def forward(self, x):
     test_dataloader = DataLoader(test_dataset, batch_size=100)
 
     # Step 2: prepare the model.
-    model = ToyModel()
+    model_arg = EasyDict(dict(
+        name='resnet8',
+        input_channel=3,
+        class_number=10,
+    ))
+    model_name = model_arg.pop('name')
+    model = MODEL_REGISTRY.build(model_name, **model_arg)
+
+    # You can also initialize the model without using configurations.
+    # e.g. model = ToyModel()
 
     # Step 3: train the randomly initialized model.
     dataloader = DataLoader(dataset, batch_size=100)
@@ -70,6 +81,6 @@ def forward(self, x):
     model.to('cpu')
 
     # Step 4: calculate the dominant Hessian eigen values after training the model.
-    # Only one line of code for visualization!
+    # Only one line of code for visualization.
     res = calculate_hessian_dominant_eigen_values(model, iter_num=20, dataloader=test_dataloader, device='cuda')
     print(res)
diff --git a/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
index 4ab1289..6fc574d 100644
--- a/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
+++ b/fling/utils/visualize_utils/demo/demo_single_loss_landscape.py
@@ -7,6 +7,8 @@
 from fling import dataset
 from fling.utils.visualize_utils import plot_2d_loss_landscape
 from fling.utils.registry_utils import DATASET_REGISTRY
+from easydict import EasyDict
+from fling.utils.registry_utils import MODEL_REGISTRY
 
 
 class ToyModel(nn.Module):
@@ -49,7 +51,16 @@ def forward(self, x):
     test_dataloader = DataLoader(test_dataset, batch_size=100)
 
     # Step 2: prepare the model.
-    model = ToyModel()
+    model_arg = EasyDict(dict(
+        name='resnet8',
+        input_channel=3,
+        class_number=10,
+    ))
+    model_name = model_arg.pop('name')
+    model = MODEL_REGISTRY.build(model_name, **model_arg)
+
+    # You can also initialize the model without using configurations.
+    # e.g. model = ToyModel()
 
     # Step 3: train the randomly initialized model.
dataloader = DataLoader(dataset, batch_size=100) From 86e66d0124bb0d43e4425fc5274d997ed44bd980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 11 Jun 2024 12:15:09 +0800 Subject: [PATCH 9/9] polish --- fling/pipeline/__init__.py | 1 - .../generic_model_visualization_pipeline.py | 57 +++++++++++++------ .../visualize_utils/hessian_eigen_value.py | 11 ++-- fling/utils/visualize_utils/loss_landscape.py | 28 ++++----- 4 files changed, 58 insertions(+), 39 deletions(-) diff --git a/fling/pipeline/__init__.py b/fling/pipeline/__init__.py index eae5bc6..19cb23f 100644 --- a/fling/pipeline/__init__.py +++ b/fling/pipeline/__init__.py @@ -1,4 +1,3 @@ from .generic_model_pipeline import generic_model_pipeline from .personalized_model_pipeline import personalized_model_pipeline from .generic_model_visualization_pipeline import generic_model_visualization_pipeline - diff --git a/fling/pipeline/generic_model_visualization_pipeline.py b/fling/pipeline/generic_model_visualization_pipeline.py index b60c8a8..7a9df11 100644 --- a/fling/pipeline/generic_model_visualization_pipeline.py +++ b/fling/pipeline/generic_model_visualization_pipeline.py @@ -82,24 +82,45 @@ def generic_model_visualization_pipeline(args: dict, seed: int = 0) -> None: logger.add_scalars_dict(prefix='before_aggregation_test', dic=test_result, rnd=i) if last_global_model is not None and i > 45: - plot_2d_loss_landscape(model=last_global_model, dataloader=part_test, - device=args.learn.device, caption='Global-test Loss Landscape', - save_path=os.path.join(args.other.logging_path, f"losslandscape_gt_{i}.pdf"), - target_model1=group.clients[0].model, target_model2=group.clients[1].model, - resolution=20, noise_range=(-0.1, 1.0), - log_scale=True, max_val=20) - plot_2d_loss_landscape(model=last_global_model, dataloader=group.clients[0].train_dataloader, - device=args.learn.device, caption='Client-1-train Loss Landscape', - save_path=os.path.join(args.other.logging_path, f"losslandscape_ct1_{i}.pdf"), - target_model1=group.clients[0].model, target_model2=group.clients[1].model, - resolution=20, noise_range=(-0.1, 1.1), - log_scale=True, max_val=20) - plot_2d_loss_landscape(model=last_global_model, dataloader=group.clients[1].train_dataloader, - device=args.learn.device, caption='Client-2-train Loss Landscape', - save_path=os.path.join(args.other.logging_path, f"losslandscape_ct2_{i}.pdf"), - target_model1=group.clients[0].model, target_model2=group.clients[1].model, - resolution=20, noise_range=(-0.1, 1.1), - log_scale=True, max_val=20) + plot_2d_loss_landscape( + model=last_global_model, + dataloader=part_test, + device=args.learn.device, + caption='Global-test Loss Landscape', + save_path=os.path.join(args.other.logging_path, f"losslandscape_gt_{i}.pdf"), + target_model1=group.clients[0].model, + target_model2=group.clients[1].model, + resolution=20, + noise_range=(-0.1, 1.0), + log_scale=True, + max_val=20 + ) + plot_2d_loss_landscape( + model=last_global_model, + dataloader=group.clients[0].train_dataloader, + device=args.learn.device, + caption='Client-1-train Loss Landscape', + save_path=os.path.join(args.other.logging_path, f"losslandscape_ct1_{i}.pdf"), + target_model1=group.clients[0].model, + target_model2=group.clients[1].model, + resolution=20, + noise_range=(-0.1, 1.1), + log_scale=True, + max_val=20 + ) + plot_2d_loss_landscape( + model=last_global_model, + dataloader=group.clients[1].train_dataloader, + device=args.learn.device, + caption='Client-2-train Loss Landscape', + 
save_path=os.path.join(args.other.logging_path, f"losslandscape_ct2_{i}.pdf"), + target_model1=group.clients[0].model, + target_model2=group.clients[1].model, + resolution=20, + noise_range=(-0.1, 1.1), + log_scale=True, + max_val=20 + ) # Aggregate parameters in each client. trans_cost = group.aggregate(i) diff --git a/fling/utils/visualize_utils/hessian_eigen_value.py b/fling/utils/visualize_utils/hessian_eigen_value.py index 63422b7..d109b01 100644 --- a/fling/utils/visualize_utils/hessian_eigen_value.py +++ b/fling/utils/visualize_utils/hessian_eigen_value.py @@ -49,15 +49,14 @@ def _rayleigh_quotient(hv: Sequence, v: Sequence) -> List: """ Calculate: \\lambda = \\frac{v^THv}{v^Tv} """ - return [((torch.flatten(v[i].T) @ torch.flatten(hv[i])) / - (torch.flatten(v[i].T) @ torch.flatten(v[i]))).item() for i in range(len(hv))] + return [ + ((torch.flatten(v[i].T) @ torch.flatten(hv[i])) / (torch.flatten(v[i].T) @ torch.flatten(v[i]))).item() + for i in range(len(hv)) + ] def calculate_hessian_dominant_eigen_values( - model: nn.Module, - iter_num: int, - dataloader: DataLoader, - device: str + model: nn.Module, iter_num: int, dataloader: DataLoader, device: str ) -> Dict: """ Overview: diff --git a/fling/utils/visualize_utils/loss_landscape.py b/fling/utils/visualize_utils/loss_landscape.py index 8c33ef1..69089bf 100644 --- a/fling/utils/visualize_utils/loss_landscape.py +++ b/fling/utils/visualize_utils/loss_landscape.py @@ -18,7 +18,7 @@ def _gen_rand_like(tensor: torch.Tensor) -> torch.Tensor: def _calc_loss_value( - model: nn.Module, data_loader: DataLoader, device: str, criterion: nn.Module = nn.CrossEntropyLoss() + model: nn.Module, data_loader: DataLoader, device: str, criterion: nn.Module = nn.CrossEntropyLoss() ): # Given a model and corresponding dataset, calculate the mean loss value. model = model.to(device) @@ -34,19 +34,19 @@ def _calc_loss_value( def plot_2d_loss_landscape( - model: nn.Module, - dataloader: DataLoader, - device: str, - caption: str, - save_path: str, - target_model1: Optional[nn.Module] = None, - target_model2: Optional[nn.Module] = None, - parameter_args: Dict = {"name": "all"}, - noise_range: Tuple[float, float] = (-1, 1), - resolution: int = 20, - visualize: bool = False, - log_scale: bool = False, - max_val: float = 5, + model: nn.Module, + dataloader: DataLoader, + device: str, + caption: str, + save_path: str, + target_model1: Optional[nn.Module] = None, + target_model2: Optional[nn.Module] = None, + parameter_args: Dict = {"name": "all"}, + noise_range: Tuple[float, float] = (-1, 1), + resolution: int = 20, + visualize: bool = False, + log_scale: bool = False, + max_val: float = 5, ) -> None: """ Overview: