FedGlobalContrast and FedSimCLR baseline #354
Conversation
# Split data into dict
data_dict = dict()
splitter = get_splitter(config)
data_train = splitter(data_train)
Splitting data_train and data_test separately cannot ensure the IID property between the train and test sets of the same client.
@rayrayraykk any suggestion on modifying the splitter so that (data1, data2, ...) can be split with the same categorical distributions?
We already have args to control the distribution. See:
splitter(dataset[split], prior=train_label_distribution)):
Add a parameter to decide whether to use the same splitter for training and evaluation.
Keep the train/val/test splits IID:
def __call__(self, dataset, prior=None):
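A minimal sketch of how the two points above could fit together, assuming `prior` accepts the per-client label lists from the training split (as the quoted `__call__` signature suggests; the exact form of `prior` may differ per splitter, and the (x, y) sample layout is an assumption):

```python
from federatedscope.core.auxiliaries.splitter_builder import get_splitter

splitter = get_splitter(config)

# Split the training data first.
client_train = splitter(data_train)

# Collect the labels each client received (assuming samples are (x, y) pairs);
# these per-client label lists serve as the prior for the remaining splits.
train_label_distribution = [[y for _, y in client_data]
                            for client_data in client_train]

# Reuse the same splitter and prior for val/test so that every client's
# train/val/test parts follow the same categorical distribution.
client_val = splitter(data_val, prior=train_label_distribution)
client_test = splitter(data_test, prior=train_label_distribution)
```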
for idx in range(1, self._cfg.federate.client_num + 1)
}

def _register_default_handlers(self):
Is it possible to call the base class's method to execute the first four registrations?
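For reference, a sketch of what that could look like; the FedGC-specific message type and callback name below are hypothetical, and the same idea applies to the client-side `_register_default_handlers` further down:

```python
def _register_default_handlers(self):
    # Reuse all registrations from the base Server class ...
    super()._register_default_handlers()
    # ... and only add the FedGC-specific handler(s) on top
    # (message type and callback name are illustrative).
    self.register_handlers('local_embedding',
                           self.callback_funcs_for_local_embedding)
```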
federatedscope/cl/fedgc/server.py
Outdated
if other_client_id != client_id]
# print("start cal loss")
self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2)
print(self.loss_list[client_id])
no print!
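If the value is still useful for debugging, a module-level logger would be the usual replacement (a sketch, assuming the standard Python logging setup used elsewhere in the repo):

```python
import logging

logger = logging.getLogger(__name__)

# Instead of print(self.loss_list[client_id]):
logger.debug('Global loss for client %s: %s', client_id,
             self.loss_list[client_id])
```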
it's deleted
GlobalContrastFL (FedGC) Client receives the aggregated model weights from the server and updates the local
weights; it also receives the global loss from the server to train the model and update the weights locally.
"""
def _register_default_handlers(self):
Is it possible to call the base class's method to avoid repeating several of these lines?
round, sender, content = message.state, message.sender, message.content
global_loss = content['global_loss']
model_para = self.trainer.train_with_global_loss(global_loss)
self.trainer.update(model_para)
It is extremely confusing to update the local model with a state_dict produced by itself.
it will be deleted
1. It is OK to alternate local and global updates. However, we have suggested enabling a combination of these losses in one gradient-descent step, and it seems this implementation still fails to do that.
2. There are still many issues remaining.
3. This PR would not pass a linter check, IMO, right? @rayrayraykk
The quality of this PR is exceptionally low. Please complete it ASAP.
Yes, the provided unit test still does not work.
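A minimal sketch of what point 1 asks for, i.e. combining the local and global contrastive losses in a single gradient-descent step; every name below (model, loss functions, beta) is illustrative rather than the actual trainer API:

```python
# Hypothetical training step combining both losses in one backward pass.
optimizer.zero_grad()
z1, z2 = model(x_aug1), model(x_aug2)                  # embeddings of two augmented views
local_loss = nt_xent_loss(z1, z2)                      # standard SimCLR (NT-Xent) loss
global_loss = global_nt_xent_loss(z1, z2, others_z2)   # uses embeddings from other clients
loss = local_loss + beta * global_loss                 # beta: weighting hyperparameter
loss.backward()
optimizer.step()
```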
Run `pre-commit run --all-files` and fix the format issues before merging.
Please also see the inline comments.
Once done, please remove [WIP] from the title of this PR.
federatedscope/cl/fedgc/utils.py
Outdated
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
delete if never used
sure
federatedscope/cl/fedgc/utils.py
Outdated
@@ -0,0 +1,105 @@
import torch
import numpy as np
delete if never used
sure
federatedscope/cl/fedgc/utils.py
Outdated
return loss

# def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5, device='cpu'):
delete if never used
sure
federatedscope/cl/fedgc/server.py
Outdated
from federatedscope.core.workers.server import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.cl.fedgc.utils import global_NT_xentloss
from torchviz import make_dot, make_dot_from_trace
Please add torchviz to the [cl] dependency group; currently I can't run CL with the minimal version of FS.
It's deleted in the new PR.
return data, modified_config

register_data("Cifar4CL", load_cifar_dataset)
As you use register_data, your function should take two args:
def load_cifar_dataset(config, client_cfgs=None)
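For reference, a registered loader would look roughly like this (a sketch assuming the usual federatedscope.register import; the body is a placeholder):

```python
from federatedscope.register import register_data


def load_cifar_dataset(config, client_cfgs=None):
    # Build the client-wise data dict from config here (omitted in this sketch),
    # adjusting config values along the way if necessary.
    data = dict()
    modified_config = config
    return data, modified_config


register_data('Cifar4CL', load_cifar_dataset)
```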
I have deleted register_data and finished the unit test.
The functionalities have been validated. However, the performance of FedSimCLR cannot be reproduced exactly. We annotate this as a TODO.
Current problem: computing the global contrastive loss over a whole epoch of data is slow, but computing it per batch incurs more communication cost.
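For illustration, a per-batch variant of the global contrastive loss could look roughly like the sketch below; it mirrors the NT-Xent formulation (the exact FedGC loss may differ) and trades the single large per-epoch computation for one embedding exchange per batch:

```python
import torch
import torch.nn.functional as F


def global_nt_xent_loss(z1, z2, others_z2=(), temperature=0.5):
    """Per-batch NT-Xent-style loss where embeddings received from other
    clients (others_z2, an iterable of tensors) act as extra negatives."""
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    candidates = [z1, z2] + [F.normalize(z, dim=1) for z in others_z2]
    all_z = torch.cat(candidates, dim=0)        # (M, d), M >= 2N

    logits = z1 @ all_z.t() / temperature       # (N, M)
    n = z1.size(0)
    idx = torch.arange(n, device=z1.device)
    # Remove each sample's similarity with itself (the first N columns are z1).
    self_mask = torch.zeros_like(logits, dtype=torch.bool)
    self_mask[idx, idx] = True
    logits = logits.masked_fill(self_mask, float('-inf'))
    # The positive for row i is its other augmented view z2[i], at column n + i.
    return F.cross_entropy(logits, idx + n)
```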