Refactor data-related interfaces & add interfaces for trainer and worker #365
Conversation
Generally, looks good to me. Since this design and its implementation have been discussed before, there are just two minor questions I want to raise: (1) Do these class names make sense to you? E.g., the dict containing client-wise data is called "standalone" (StandaloneDataDict), and the class responsible for converting a vanilla torch dataset into a federated counterpart is called a "translator". (2) Is it necessary to make the methods and attributes that would not be accessed outside the class private (i.e., self._xxx)? @xieyxclack @yxdyc @DavdGao @Osier-Yi
LGTM
LGTM, please see the inline comments for minor suggestions, thx!
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()
Why do we move it to CPU?
I follow the convention of GeneralTorchTrainer; see here.
        raise NotImplementedError

    @abc.abstractmethod
    def print_trainer_meta_info(self):
Is it necessary for an FL course?
This function is called in FedRunner.
@@ -1,7 +1,7 @@
 from federatedscope.register import register_data


-def MyData(config):
+def MyData(config, client_cfgs):
Shall we provide client_cfgs in MyData?
I believe client_cfgs will be useful in MyData when using personalized cfgs (e.g., cfg.dataloader.batch_size varies across clients).
So we can set client_cfgs=None by default here, since it just serves as an example.
Agree!
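For illustration, a minimal sketch of the registered data function with the default applied; build_my_dataset and the 'mydata' type name are hypothetical, and the call_* / register_data wiring simply follows the contrib-data pattern as I understand it.

```python
from federatedscope.register import register_data


def MyData(config, client_cfgs=None):
    # client_cfgs defaults to None: it is only needed when per-client
    # configurations (e.g. different cfg.dataloader.batch_size) are used.
    # build_my_dataset is a hypothetical helper standing in for the real
    # dataset construction.
    data = build_my_dataset(config, client_cfgs)
    return data, config


def call_my_data(config, client_cfgs=None):
    if config.data.type == 'mydata':
        return MyData(config, client_cfgs)


register_data('mydata', call_my_data)
```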
        loss = self.criterion(outputs, y)

        # _hook_on_batch_backward
        self.optimizer.zero_grad()
I suggest putting this line before the forward process (i.e., outputs = self.model(x)).
If so, should we also modify GeneralTorchTrainer, whose convention I follow?
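For reference, a small sketch of the suggested ordering in plain PyTorch; the variable names mirror the snippet above and are otherwise illustrative.

```python
def run_batch(model, criterion, optimizer, x, y):
    # Clear stale gradients before the forward pass, as suggested above.
    optimizer.zero_grad()

    # _hook_on_batch_forward
    outputs = model(x)
    loss = criterion(outputs, y)

    # _hook_on_batch_backward
    loss.backward()
    optimizer.step()
    return loss.item()
```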
        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss}
We can provide the avg_loss here.
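For instance, a sketch of what the fit-end return could look like with an average loss included; the names follow the snippet above.

```python
# _hook_on_fit_end
avg_loss = total_loss / float(num_samples) if num_samples > 0 else 0.0
return num_samples, self.model.cpu().state_dict(), {
    'loss_total': total_loss,
    'avg_loss': avg_loss
}
```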
@@ -10,9 +10,7 @@ federate:
 data:
   root: data/
   type: shakespeare
-  batch_size: 64
Same as above
@@ -11,9 +11,7 @@ data:
   root: data/
   type: femnist
   splits: [0.6,0.2,0.2]
-  batch_size: 64
Same as above
@@ -11,9 +11,7 @@ data:
   root: data/
   type: femnist
   splits: [0.6,0.2,0.2]
-  batch_size: 64
Same as above
@@ -11,7 +11,6 @@ federate:
 data:
   root: data/
   type: synthetic
-  batch_size: 64
Same as above
@@ -8,9 +8,7 @@ federate:
 data:
   root: data/
   type: shakespeare
-  batch_size: 64
Same as above
@rayrayraykk @joneswong
LGTM
        datadict = self.split_to_client(train, val, test)
        return datadict

    def split_train_val_test(self, dataset):
So what is the recommended way to implement customized split functions? E.g., VMF and HMF datasets need to be split according to user/item IDs.
In that case, we use a dummy splitter, where we treat the MF datasets as FL datasets.
Please see the inline comments
@@ -616,22 +584,22 @@ def get_data(config):
     # will restore the user-specified one after the generation
     setup_seed(12345)
This random seed is out of the control of cfg.seed.
The seed is set to generate data, so we'd better keep it fixed, or we could use a cfg.data.seed instead of cfg.seed.
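A rough sketch of the alternative being discussed; cfg.data.seed and generate_data are hypothetical here, while setup_seed is the helper used in the snippet above.

```python
# Use a dedicated (still configurable) seed for data generation instead of
# the hard-coded constant, then restore the user-specified seed afterwards.
setup_seed(getattr(cfg.data, 'seed', 12345))
data = generate_data(cfg)  # hypothetical generation step
setup_seed(cfg.seed)       # restore the user-specified seed
```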
    # DataLoader related args
    cfg.dataloader = CN()
    cfg.dataloader.type = 'base'
Maybe list all the options for cfg.dataloader.type in a comment.
    cfg.dataloader.batch_size = 64
    cfg.dataloader.shuffle = True
    cfg.dataloader.num_workers = 0
    cfg.dataloader.drop_last = False
Is drop_last only valid for the training dataloader?
Thanks for your suggestion! I've updated it accordingly.
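To illustrate the point, a sketch of a loader builder that applies drop_last (and shuffle) only to the training split; build_dataloader is a hypothetical helper, not the project's actual API.

```python
from torch.utils.data import DataLoader


def build_dataloader(dataset, cfg, split='train'):
    # drop_last/shuffle make sense for training; evaluation keeps all samples.
    is_train = (split == 'train')
    return DataLoader(dataset,
                      batch_size=cfg.dataloader.batch_size,
                      shuffle=cfg.dataloader.shuffle and is_train,
                      num_workers=cfg.dataloader.num_workers,
                      drop_last=cfg.dataloader.drop_last and is_train)
```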
        # Split train/val/test to client
        if len(train) > 0:
            split_train = self.splitter(train)
            if self.global_cfg.data.consistent_label_distribution:
What is the meaning of consistent_label_distribution?
When using a splitter, the train/val/test splits might be non-IID with respect to each other. With consistent_label_distribution set to True, the ML split is kept IID, i.e., the label distributions are consistent across train/val/test.
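Roughly, the idea can be sketched as follows; the prior keyword and the way the label distribution is collected are assumptions about how the splitter could be guided, not the exact implementation.

```python
# Split train first, then reuse its per-client label distribution as a
# prior so that val/test follow the same client-wise distribution.
split_train = splitter(train)
train_label_prior = [[label for _, label in part] for part in split_train]
split_val = splitter(val, prior=train_label_prior)
split_test = splitter(test, prior=train_label_prior)
```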
        self.kwargs = kwargs

    @abc.abstractmethod
    def train(self):
Do we need an abstract method for finetuning?
I think finetuning is optional; if someone needs it, they can implement finetuning in their own way.
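For example, a user-defined trainer could add finetuning on top of the abstract interface without it being required; this is only a sketch, and the BaseTrainer import path and the exact set of required methods are assumptions based on the snippets in this PR.

```python
from federatedscope.core.trainers import BaseTrainer


class MyTrainer(BaseTrainer):
    # ... other required methods (evaluate, update, get_model_para, ...)
    # omitted for brevity ...

    def train(self):
        ...  # one round of local training

    def finetune(self, steps=5):
        # Optional: finetuning is not part of the abstract interface, so
        # users who need it can implement it in their own subclass.
        for _ in range(steps):
            self.train()
```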
     # Noise multiplier
     tmp = cfg.sgdmf.constant * np.power(sample_ratio, 2) * (
         cfg.federate.total_round_num * ctx.num_total_train_batch) * np.log(
             1. / cfg.sgdmf.delta)
     noise_multipler = np.sqrt(tmp / np.power(cfg.sgdmf.epsilon, 2))
-    ctx.scale = max(cfg.sgdmf.theta, 1.) * noise_multipler * np.power(
+    ctx.scale = max(cfg.dataloader.theta, 1.) * noise_multipler * np.power(
Since theta is only used in SGDMF, is it appropriate to place theta under the dataloader namespace (where all users can see it)?
We can add a docstring to explain this, as many other args in dataloader (e.g., sizes, which is graph-related) are optional.
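For instance, the defaults could carry short annotations marking method-specific args; the default values and wording below are purely illustrative.

```python
cfg.dataloader.theta = -1       # Only used by SGDMF; ignored by other methods
cfg.dataloader.sizes = [10, 5]  # Only used by graph neighbor sampling
```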
I changed the version to 0.2.1.
approved.
Main Changes

Data

Data module overview
- federatedscope.core.auxiliaries.data_builder.get_data is about 700 lines shorter and now proceeds in three steps:
  - Load the dataset (federatedscope.core.data.utils.load_dataset)
  - Translate it into FS data (federatedscope.core.data.BaseDataTranslator)
  - Convert the data mode (federatedscope.core.data.utils.convert_data_mode)

Data Translator
- Pipeline: Dataset -> (ML split) -> (FL split) -> DataLoader for FS
- ML split (split_train_val_test)
- FL split (split_to_client)

Data interface
- ClientData (federatedscope.core.data.ClientData)
  - A subclass of dict with train, val and test.
  - Converts datasets into DataLoaders according to cfg.dataloader.type, cfg.dataloader.batch_size, cfg.dataloader.shuffle, etc.
  - Example (see the sketch below):
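A usage sketch based on my reading of the new interface; the constructor signature (loader class first), the setup call, and the surrounding names (cfg, train_data, test_data, client_cfg) are assumptions and may differ from the final code.

```python
from torch.utils.data import DataLoader
from federatedscope.core.data import ClientData

# Instantiate a ClientData for one client; it builds DataLoaders from
# cfg.dataloader.* (type, batch_size, shuffle, ...).
client_data = ClientData(DataLoader,
                         cfg,
                         train=train_data,
                         val=None,
                         test=test_data)
print(client_data)  # e.g. {'train': DataLoader(...), 'test': DataLoader(...)}

# A client-specific config (e.g. a different batch size) can rebuild the loaders.
client_data.setup(client_cfg)
```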
- StandaloneDataDict (federatedscope.core.data.StandaloneDataDict)
  - A dict with client_id as keys: {1: ClientData, 2: ClientData, ...}
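A corresponding sketch for the client-wise container; the constructor arguments are my assumption.

```python
from federatedscope.core.data import StandaloneDataDict

# Keys are client IDs; client_data_1 and client_data_2 are ClientData objects.
data = StandaloneDataDict({1: client_data_1, 2: client_data_2}, cfg)
```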
Trainer

BaseTrainer
- Code without hook-like functions (for starters)
- Example: a trainer without hook-like functions (sketched below)
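The linked example is not reproduced here; below is a sketch of what a trainer without hook-like functions can look like, pieced together from the snippets reviewed above. The import path, constructor signature, and the assumption that the base class stores model/data/device are all unverified.

```python
import torch
from federatedscope.core.trainers import BaseTrainer


class MyTorchTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        super().__init__(model, data, device, **kwargs)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

    def train(self):
        self.model.to(self.device)
        self.model.train()
        total_loss, num_samples = 0.0, 0
        for x, y in self.data['train']:
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(x)
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item() * y.size(0)
            num_samples += y.size(0)
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss}

    def evaluate(self, target_data_split_name='test'):
        ...

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()

    def print_trainer_meta_info(self):
        return f'{self.__class__.__name__}'
```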
FedRunner
- merge data, global eval, etc.

Worker
- BaseClient & BaseServer
- Example (sketched below)
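As a sketch, a worker can be customized by subclassing the existing Client (which implements BaseClient); the handler name callback_funcs_for_model_para follows my understanding of the current Client and may differ.

```python
from federatedscope.core.workers import Client


class MyClient(Client):
    def callback_funcs_for_model_para(self, message):
        # Add custom behaviour (e.g. logging) before the default handling.
        print(f'Client {self.ID} received model parameters '
              f'from {message.sender}')
        super().callback_funcs_for_model_para(message)
```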