FedGlobalContrast and FedSimCLR baseline #354

Merged: 56 commits into alibaba:master on Nov 8, 2022

Conversation

@xkxxfyf (Contributor) commented on Aug 30, 2022

Current problem: computing the global contrastive loss at the epoch level is slow because of the epoch's data size, but computing it at the batch level incurs more communication cost.
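To make the trade-off concrete, here is a purely illustrative back-of-envelope sketch (all sizes below are assumptions, not measurements from this PR): batch-level exchange means many small communication rounds per epoch, while epoch-level exchange means one large transfer but a slow loss computation over the whole epoch's embeddings.

```python
# Purely illustrative estimate of what one client exchanges per epoch.
# All numbers are assumptions for illustration, not measurements from this PR.
num_samples = 50_000          # assumed local training-set size
batch_size = 256
embedding_dim = 128
bytes_per_float = 4

num_batches = num_samples // batch_size
per_batch_payload = batch_size * embedding_dim * bytes_per_float   # bytes uploaded per step
per_epoch_payload = num_samples * embedding_dim * bytes_per_float  # bytes uploaded once per epoch

print(f"batch-level:  {num_batches} communication rounds, "
      f"{per_batch_payload / 1e6:.2f} MB each")
print(f"epoch-level:  1 communication round, "
      f"{per_epoch_payload / 1e6:.2f} MB total")
```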

@rayrayraykk added the Feature (New feature) label on Aug 31, 2022
@joneswong self-assigned this on Aug 31, 2022
# Split data into dict
data_dict = dict()
splitter = get_splitter(config)
data_train = splitter(data_train)
Collaborator

Splitting data_train and data_test separately cannot ensure the i.i.d. property between the train and test sets of the same client.

Collaborator

@rayrayraykk any suggestions on modifying the splitter so that it can split (data1, data2, ...) with the same categorical distributions?

Collaborator

We already have an argument to control the distribution. See:

splitter(dataset[split], prior=train_label_distribution)

Contributor Author

I will add a parameter to decide whether or not to use the same splitter for training and evaluation.

Collaborator

Keep the train/val/test splits i.i.d.:

def __call__(self, dataset, prior=None):

for idx in range(1, self._cfg.federate.client_num + 1)
}
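A minimal sketch of the suggested behavior (the helper name and the way the per-client label lists are derived are assumptions, not the PR's actual code): split the training data first, collect each client's labels, and reuse them as the prior when splitting the test data, so each client's train/test pair keeps the same categorical distribution.

```python
def split_train_test_consistently(splitter, data_train, data_test):
    # Split the training data first.
    train_splits = splitter(data_train)

    # Collect each client's training labels (assuming samples are (x, y) pairs).
    train_label_distribution = [
        [y for _, y in client_data] for client_data in train_splits
    ]

    # Reuse the same per-client label prior when splitting the test data,
    # so the train and test of each client stay identically distributed.
    test_splits = splitter(data_test, prior=train_label_distribution)
    return train_splits, test_splits
```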

def _register_default_handlers(self):
Collaborator

Is it possible to call the base class's method to execute the first four registrations?
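A minimal sketch of that suggestion (the message type and callback name below are illustrative, not the PR's exact identifiers): call the base class to get the default registrations, then add only the handlers specific to global contrastive training.

```python
def _register_default_handlers(self):
    # Reuse the registrations already performed by the base worker class
    # instead of repeating them line by line.
    super()._register_default_handlers()

    # Register only the handler that is specific to this worker,
    # e.g. the one receiving the global contrastive loss.
    self.register_handlers('global_loss', self.callback_funcs_for_global_loss)
```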

if other_client_id != client_id]
# print("start cal loss")
self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2)
print(self.loss_list[client_id])
Collaborator

no print!

Contributor Author

it's deleted
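For context, a minimal sketch of what a server-side global contrastive loss of this shape could look like, with other clients' embeddings acting as extra negatives and the debug output going through logging instead of print (function and variable names are assumptions for illustration, not the PR's implementation of global_NT_xentloss):

```python
import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def global_nt_xent_loss(z1, z2, others_z2=(), temperature=0.5):
    """NT-Xent-style loss where other clients' embeddings serve as extra negatives."""
    z1 = F.normalize(z1, dim=1)   # local view 1, shape (B, D)
    z2 = F.normalize(z2, dim=1)   # local view 2, shape (B, D)
    if len(others_z2) > 0:
        negatives = torch.cat([F.normalize(z, dim=1) for z in others_z2], dim=0)
    else:
        negatives = z1.new_zeros((0, z1.size(1)))

    pos = torch.sum(z1 * z2, dim=1, keepdim=True) / temperature   # (B, 1)
    neg = z1 @ negatives.t() / temperature                        # (B, N_neg)

    logits = torch.cat([pos, neg], dim=1)                         # (B, 1 + N_neg)
    labels = torch.zeros(z1.size(0), dtype=torch.long, device=z1.device)
    loss = F.cross_entropy(logits, labels)

    logger.debug("global contrastive loss: %.4f", loss.item())    # no print statements
    return loss
```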

GlobalContrastFL (Fedgc) Client: receives the aggregated model weights from the server and updates the local
weights; it also receives the global loss from the server to train the model and update the weights locally.
"""
def _register_default_handlers(self):
Collaborator

Is it possible to call the base class's method to avoid repeating several of these lines?

round, sender, content = message.state, message.sender, message.content
global_loss = content['global_loss']
model_para = self.trainer.train_with_global_loss(global_loss)
self.trainer.update(model_para)
Collaborator

It is extremely confusing to update the local model with a state_dict produced by itself.

Contributor Author

it will be deleted
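A minimal sketch of the simplified callback after dropping the redundant self-update (only the names quoted above are taken from the PR; the rest is illustrative): let the trainer perform the local step with the received global loss, rather than feeding the resulting state_dict back through update().

```python
def callback_funcs_for_global_loss(self, message):
    # Unpack the server message carrying the aggregated global contrastive loss.
    round, sender, content = message.state, message.sender, message.content
    global_loss = content['global_loss']

    # The trainer updates the model in place during this call, so there is
    # no need to re-apply its own state_dict via self.trainer.update().
    self.trainer.train_with_global_loss(global_loss)
```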

@joneswong (Collaborator) left a comment

1. It is OK to alternate between local and global updates. However, we suggested enabling a combination of these losses in a single gradient-descent step, and it seems this implementation still fails to do that.
2. There are still many issues remaining.
3. This PR would not pass a linter check, IMO, right? @rayrayraykk

The quality of this PR is exceptionally low. Please complete it ASAP.
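A minimal sketch of what combining the two losses in a single gradient-descent step could look like (the weighting factor and function name are assumptions for illustration, not the PR's code):

```python
def combined_update(optimizer, local_loss, global_loss, global_loss_weight=1.0):
    """Apply the local SimCLR loss and the global contrastive loss in one step.

    Both losses are assumed to be scalar tensors still attached to the
    computation graph of the current local model.
    """
    optimizer.zero_grad()
    total_loss = local_loss + global_loss_weight * global_loss
    total_loss.backward()
    optimizer.step()
    return total_loss.detach()
```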

@rayrayraykk (Collaborator), replying to the comment above:

Yes, the unit-test provided still does not work.

@rayrayraykk (Collaborator) left a comment

Run pre-commit run --all-files and fix the formatting issues before merging. Also, please see the inline comments.

Once done, please remove [WIP] from the title of this PR.

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
Collaborator

delete if never used

Contributor Author

sure

@@ -0,0 +1,105 @@
import torch
import numpy as np
Collaborator

delete if never used

Contributor Author

sure


return loss

# def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5, device='cpu'):
Collaborator

delete if never used

Contributor Author

sure

from federatedscope.core.workers.server import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.cl.fedgc.utils import global_NT_xentloss
from torchviz import make_dot, make_dot_from_trace
Collaborator

Please add torchviz to the [cl] dependency group; as it stands, I can't run CL with the minimal version of FS.

Contributor Author

It's deleted in the new PR.
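For reference, a minimal sketch of how the suggested optional dependency could have been declared (the setup.py fragment below is hypothetical, not FederatedScope's actual packaging file):

```python
# Hypothetical setup.py fragment declaring torchviz under the [cl] extra,
# installable via: pip install federatedscope[cl]
from setuptools import setup

setup(
    name='federatedscope',
    extras_require={
        'cl': ['torchviz'],
    },
)
```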

return data, modified_config


register_data("Cifar4CL", load_cifar_dataset)
Collaborator

Since you use register_data, your function should take two args:
def load_cifar_dataset(config, client_cfgs=None):

@xkxxfyf (Contributor Author) commented on Oct 28, 2022

I have deleted register_data and finished the unit test.
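For reference, a minimal sketch of the two-argument loader signature the reviewer asked for (assuming register_data is imported from federatedscope.register; the function body is an illustrative placeholder, not the PR's code):

```python
from federatedscope.register import register_data


def load_cifar_dataset(config, client_cfgs=None):
    # Build and split the CIFAR dataset according to `config`
    # (loading details omitted; this only illustrates the expected signature).
    data = ...
    modified_config = config
    return data, modified_config


register_data("Cifar4CL", load_cifar_dataset)
```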

@xkxxfyf changed the title from "[WIP] FedGlobalContrast and FedSimCLR baseline" to "FedGlobalContrast and FedSimCLR baseline" on Oct 31, 2022
@xkxxfyf requested a review from joneswong on October 31, 2022 at 11:52
@joneswong (Collaborator) left a comment

The functionalities have been validated. However, the performance of FedSimCLR cannot be reproduced exactly. We annotate this as a TODO.

@joneswong merged commit 94e0d97 into alibaba:master on Nov 8, 2022
Labels: Feature (New feature)
Participants: 3