Add federated contrastive learning baseline SimCLR and its linear probe evaluation #278
Conversation
class SimCLRTransform():
    def __init__(self, is_sup, image_size=32):
Please provide Python-style docstrings for the newly added classes and functions.
transform_train = SimCLRTransform(is_sup=False, image_size=32)
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
Is it conventional to use 0.5 rather than the sample mean?
Setting both parameters to 0.5 and combining with T.ToTensor() forces the data to be scaled to the [-1, 1] interval.
Here the first argument is the mean of the signals; to my knowledge, it is usually calculated from the available examples.
T.ToTensor() maps image values from [0, 255] to [0, 1], and T.Normalize(0.5, 0.5) then maps [0, 1] to [-1, 1] via (x - mean) / std.
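For reference, a minimal sketch (not the PR's code) of the per-channel arithmetic behind ToTensor followed by Normalize with mean = std = 0.5:

```python
# Stand-in functions mirroring what torchvision's ToTensor and Normalize
# do to a single channel value.
def to_tensor(pixel_uint8):
    """ToTensor: map a raw 0-255 channel value into [0.0, 1.0]."""
    return pixel_uint8 / 255.0

def normalize(x, mean=0.5, std=0.5):
    """Normalize: (x - mean) / std maps [0, 1] onto [-1, 1]."""
    return (x - mean) / std

for raw in (0, 128, 255):
    print(raw, "->", normalize(to_tensor(raw)))
# The endpoints 0 and 255 map to -1.0 and 1.0 respectively, confirming
# the [-1, 1] scaling described above.
```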
federatedscope/cl/model/SimCLR.py
Outdated
class Bottleneck(nn.Module):
    expansion = 4
Is it necessary to use a class attribute?
representations = torch.cat([z1, z2], dim=0)
similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)

l_pos = torch.diag(similarity_matrix, N)
I think, up to now, similarity_matrix is a 2N-by-2N matrix. Am I wrong? Why do we need to take the diagonals above and below the main diagonal?
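The reading above is right: after the concatenation, similarity_matrix is 2N-by-2N. A pure-Python sketch with toy embeddings (a stand-in for the torch code, not the PR's implementation) shows why the offset-N diagonal is taken: it holds exactly the positive pairs, i.e. view 1 of example i against view 2 of example i; the offset -N diagonal holds the symmetric copies.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

# Two augmented views of N=2 toy examples; z1[i] and z2[i] form a positive pair.
z1 = [[1.0, 0.0], [0.0, 1.0]]
z2 = [[0.9, 0.1], [0.1, 0.9]]
reps = z1 + z2                  # mirrors torch.cat([z1, z2], dim=0): 2N rows
N = len(z1)

# mirrors F.cosine_similarity(...): a 2N-by-2N matrix
sim = [[cosine(u, v) for v in reps] for u in reps]

# torch.diag(sim, N) reads the entries sim[i][i + N]: precisely the
# similarities between the two views of the same example.
l_pos = [sim[i][i + N] for i in range(N)]
print(l_pos)  # high values, since each pair comes from the same example
```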
federatedscope/cl/trainer/trainer.py
Outdated
# print(len(x), x[0].size(), x[1].size(), label.size())
x1, x2 = x[0], x[1]
z1, z2 = ctx.model(x1, x2)
if len(label.size()) == 0:
When will we enter such a branch?
We enter this branch in contrastive learning with two augmented views of the data.
I mean, when does len(label.size()) become zero?
I followed torch_trainer and added this branch. Should I remove it?
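For reference, the guard fires only for a 0-dim ("scalar") label tensor, whose shape is empty. A sketch with plain shape tuples (in torch, torch.tensor(3).size() is torch.Size([]), which behaves like the empty tuple here), e.g. when an unbatched scalar label slips through without a batch dimension:

```python
def needs_unsqueeze(shape):
    """Mirror the trainer's `if len(label.size()) == 0` guard."""
    return len(shape) == 0

scalar_shape = ()        # shape of a 0-dim label tensor
batched_shape = (32,)    # shape of a batch of 32 labels

print(needs_unsqueeze(scalar_shape))   # True: wrap the scalar into a batch of 1
print(needs_unsqueeze(batched_shape))  # False: already batched
```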
@@ -10,7 +10,7 @@ def get_optimizer(model, type, lr, **kwargs):
     if isinstance(type, str):
         if hasattr(torch.optim, type):
             if isinstance(model, torch.nn.Module):
-                return getattr(torch.optim, type)(model.parameters(), lr,
+                return getattr(torch.optim, type)(filter(lambda p: p.requires_grad, model.parameters()), lr,
Is any pFL algorithm affected by such a change? @yxdyc
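As a sanity check on this concern: when every parameter has requires_grad=True, the filter is a no-op, so an algorithm should only be affected if it deliberately freezes parameters. A minimal sketch with stand-in parameter objects (not real torch tensors; all names are illustrative) of what the filter changes for linear-probe evaluation, where the encoder is frozen:

```python
# Stand-in for torch.nn.Parameter: only the requires_grad flag matters here.
class Param:
    def __init__(self, name, requires_grad):
        self.name = name
        self.requires_grad = requires_grad

params = [
    Param("encoder.conv1.weight", requires_grad=False),  # frozen backbone
    Param("encoder.conv2.weight", requires_grad=False),
    Param("linear_head.weight", requires_grad=True),     # trainable probe
    Param("linear_head.bias", requires_grad=True),
]

# The change in get_optimizer: hand the optimizer only trainable parameters.
trainable = list(filter(lambda p: p.requires_grad, params))
print([p.name for p in trainable])
# prints ['linear_head.weight', 'linear_head.bias']: the frozen encoder
# parameters never reach the optimizer.
```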
For any code snippet copied from or inspired by another place, please provide a copyright notice in your files.
This PR includes a new trainer, which is designed for conducting contrastive learning. @DavdGao could you have a look at that part for us? Thanks!
It seems that no readily available splitter has been adopted in your experiments, right? So how do you construct the non-IIDness? We have planned to start with the LDA splitter. Please conduct the experiments accordingly.
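For reference, a minimal pure-Python sketch of the idea behind an LDA-style splitter: for each class, client shares are drawn from a symmetric Dirichlet(alpha), so smaller alpha yields stronger label skew across clients. Function names and defaults here are illustrative, not FederatedScope's API.

```python
import random

def dirichlet(alpha, k, rng):
    """Sample a symmetric Dirichlet(alpha) vector via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]

def lda_split(labels, num_clients, alpha=0.5, seed=0):
    """Partition sample indices across clients with Dirichlet label skew."""
    rng = random.Random(seed)
    clients = [[] for _ in range(num_clients)]
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    for y, idxs in by_class.items():
        rng.shuffle(idxs)
        props = dirichlet(alpha, num_clients, rng)
        # Convert the proportions into contiguous slice boundaries; the last
        # client absorbs any rounding remainder.
        start = 0
        for c in range(num_clients):
            end = len(idxs) if c == num_clients - 1 \
                else start + int(props[c] * len(idxs))
            clients[c].extend(idxs[start:end])
            start = end
    return clients

labels = [i % 10 for i in range(1000)]  # toy CIFAR-10-like label list
parts = lda_split(labels, num_clients=5, alpha=0.5)
print([len(p) for p in parts])  # skewed, but together a full partition
```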
Please see the inline comments and follow the contributor guidelines to format your code.
config = config
return data_dict, config


def Cifar4LP(config):
Duplicated code from line 118 to line 156.
federatedscope/cl/model/SimCLR.py
Outdated
# Model class
class ResNet(nn.Module):
Please import ResNet from https://github.com/alibaba/FederatedScope/blob/master/federatedscope/contrib/model/resnet.py if there is no other concern.
In PR #267, a ResNet model was already added in federatedscope/contrib/model/resnet.py. Please check whether we still need to add a new ResNet.
Shell scripts for reproducing the results of standalone and FedAvg should be provided.
Please see the inline comments and keep the code consistent with the master branch.
Please add a new unit test for the new trainer and dataset
federatedscope/core/aggregator.py
Outdated
@@ -104,12 +104,37 @@ def _para_weighted_avg(self, models, recover_fun=None):
         return avg_model


-class NoCommunicationAggregator(Aggregator):
+class NoCommunicationAggregator(ClientsAvgAggregator):
@yxdyc Please have a look at this change to local mode, thanks.
# Split data into dict
data_dict = dict()

# Splitter
Please refer to the splitter.
# Build dict of Dataloader |
federatedscope/cl/model/SimCLR.py
Outdated
@@ -0,0 +1,235 @@
import torch
This file looks like a copy from https://github.com/akhilmathurs/orchestra/blob/main/models.py.
Please consider the copyright issues.
@@ -0,0 +1,41 @@
import torch
@@ -0,0 +1,190 @@
import math
To reproduce baselines, the sample_splitter is needed: https://github.com/akhilmathurs/orchestra/blob/228d7a6379b6788e7dc288d3a9557d62b940c47a/utils.py#L44
T.RandomResizedCrop(32, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC),
T.RandomHorizontalFlip(p=0.5),
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
I am still very curious about why we have to use such a mean and std. If it is conventional usage in CL, please explain it for us. The ultimate image classification task does not use this transformation, right?
T.ToTensor() maps image values from [0, 255] to [0, 1], and T.Normalize(0.5, 0.5) then maps [0, 1] to [-1, 1] via (x - mean) / std. The data augmentation is already time-consuming, and computing the sample mean and std would take even more time.
splitter = get_splitter(config)
data_train = splitter(data_train)
data_val = data_train
data_test = splitter(data_test)
Although the original train and test data of CIFAR-10 are IID, how do we ensure that splitting them separately with our splitter keeps a specific client's train and test data identically distributed?
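One possible way to address this (a sketch of the idea, not the current implementation): draw the per-client, per-class Dirichlet proportions once, then apply the same proportions to both the train and the test pool, so that each client's train and test label distributions match by construction. All names and sizes below are illustrative.

```python
import random

def dirichlet(alpha, k, rng):
    """Symmetric Dirichlet(alpha) sample via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]

rng = random.Random(42)
num_clients, num_classes = 3, 4
# Draw each class's split across clients ONCE...
props = [dirichlet(0.5, num_clients, rng) for _ in range(num_classes)]

def expected_counts(per_class_pool):
    """...then apply the same proportions to any pool (train or test)."""
    return [[props[y][c] * per_class_pool for y in range(num_classes)]
            for c in range(num_clients)]

train_counts = expected_counts(500)  # e.g. 500 train images per class
test_counts = expected_counts(100)   # e.g. 100 test images per class
# For every client c and class y, train_counts[c][y] == 5 * test_counts[c][y],
# i.e. each client's train and test label distributions coincide.
```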