
Refactor data-related interfaces & add interfaces for trainer and worker #365

Merged
merged 43 commits into alibaba:master on Oct 19, 2022

Conversation


@rayrayraykk commented Sep 6, 2022

Main Changes

Data

Data module overview

federatedscope.core.auxiliaries.data_builder.get_data -700 lines!

  • Load Dataset (federatedscope.core.data.utils.load_dataset):
    • Load local file to torch dataset, -300 lines!
  • Translate data (federatedscope.core.data.BaseDataTranslator)
    • Dataset -> ML split -> FL split -> DataLoader for FS
  • Convert mode (federatedscope.core.data.utils.convert_data_mode)
    • To adapt simulation mode and distributed mode
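
As a usage sketch of the refactored entry point (the two-value return convention is an assumption, inferred from the register_data example further below):

    from federatedscope.core.auxiliaries.data_builder import get_data

    # Builds the FS data object in three steps:
    # load dataset -> translate (ML/FL split, DataLoader) -> convert mode
    data, modified_cfg = get_data(config)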

Data Translator

Dataset -> (ML split) -> (FL split) -> DataLoader for FS

  • ML split (split_train_val_test):
    • Build train/val/test
  • FL split (split_to_client):
    • Build Data for each client
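
A minimal sketch of the intended use (the constructor signature is an assumption; the ClientData section below shows the per-client result):

    from federatedscope.core.data import BaseDataTranslator

    # Assumed constructor: a global cfg plus optional per-client cfgs
    translator = BaseDataTranslator(global_cfg, client_cfgs)
    # Dataset -> ML split -> FL split -> DataLoader for FS
    fs_data = translator(torch_dataset)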

Data interface

  • ClientData(federatedscope.core.data.ClientData)

    • A subclass of dict with train, val and test splits.

    • Convert dataset to DataLoader.

      • cfg.dataloader.type, cfg.dataloader.batch_size, cfg.dataloader.shuffle
    • Example:

        # Instantiate client_data for each Client
        client_data = ClientData(PyGDataLoader, 
                                 cfg, 
                                 train=train_data, 
                                 val=None, 
                                 test=test_data)
        # other_cfg with different batch size
        client_data.setup(other_cfg)
        print(client_data)
        
        >> {'train': PyGDataLoader(train_data), 'test': PyGDataLoader(test_data)}
  • StandaloneDataDict(federatedscope.core.data.StandaloneDataDict)

    • A subclass of dict with client_id as keys:
      • {1: ClientData, 2: ClientData, ...}
    • Responsible for training/evaluation method conversion:
      • Global evaluation, global training, etc.
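
For illustration, access then looks like the following sketch (construction is normally handled inside the data builder; the variable names here are illustrative):

    # fs_data behaves like {1: ClientData, 2: ClientData, ...}
    client_data = fs_data[1]
    train_loader = client_data['train']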

Trainer

  • BaseTrainer

    • Code without hook-like functions (for starters)

    • Example:

      import abc

      class BaseTrainer(abc.ABC):
          def __init__(self, model, data, device, **kwargs):
              self.model = model
              self.data = data
              self.device = device
              self.kwargs = kwargs
      
          @abc.abstractmethod
          def train(self):
              raise NotImplementedError
      
          @abc.abstractmethod
          def evaluate(self, target_data_split_name='test'):
              raise NotImplementedError
      
          @abc.abstractmethod
          def update(self, model_parameters, strict=False):
              raise NotImplementedError
      
          @abc.abstractmethod
          def get_model_para(self):
              raise NotImplementedError
      
          @abc.abstractmethod
          def print_trainer_meta_info(self):
              raise NotImplementedError

    Example of a trainer without hook-like functions (a fuller sketch follows below)

    • Less than 100 lines vs. 300+ lines
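
For illustration, a minimal hook-free trainer built on this interface could look as follows. This is a sketch, not the repository's implementation; the optimizer, loss, and the loss_total bookkeeping are illustrative choices:

    import torch

    class MyTorchTrainer(BaseTrainer):
        def __init__(self, model, data, device, **kwargs):
            super().__init__(model, data, device, **kwargs)
            self.criterion = torch.nn.CrossEntropyLoss()
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

        def train(self):
            self.model.to(self.device)
            self.model.train()
            total_loss, num_samples = 0.0, 0
            for x, y in self.data['train']:
                x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(x), y)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item() * y.size(0)
                num_samples += y.size(0)
            return num_samples, self.model.cpu().state_dict(), \
                {'loss_total': total_loss}

        def evaluate(self, target_data_split_name='test'):
            self.model.to(self.device)
            self.model.eval()
            total_loss, num_samples = 0.0, 0
            with torch.no_grad():
                for x, y in self.data[target_data_split_name]:
                    x, y = x.to(self.device), y.to(self.device)
                    total_loss += self.criterion(self.model(x),
                                                 y).item() * y.size(0)
                    num_samples += y.size(0)
            return {f'{target_data_split_name}_total': num_samples,
                    f'{target_data_split_name}_loss': total_loss}

        def update(self, model_parameters, strict=False):
            self.model.load_state_dict(model_parameters, strict)

        def get_model_para(self):
            return self.model.cpu().state_dict()

        def print_trainer_meta_info(self):
            return f'{self.__class__.__name__}: model={type(self.model).__name__}'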

FedRunner

  • Move data processing to the Data module
    • Merging data, global evaluation, etc.

Worker

  • BaseClient & BaseServer

    • Example

      import abc
      import logging

      logger = logging.getLogger(__name__)

      class BaseClient(Worker):
          def __init__(self, ID, state, config, model, strategy):
              super(BaseClient, self).__init__(ID, state, config, model, strategy)
              self.msg_handlers = dict()
      
          def register_handlers(self, msg_type, callback_func):
              if msg_type in self.msg_handlers.keys():
                  logger.warning(f"Overwriting msg_handlers {msg_type}.")
              self.msg_handlers[msg_type] = callback_func

          def _register_default_handlers(self):
              pass
      
          @abc.abstractmethod
          def run(self):
              raise NotImplementedError
      
          @abc.abstractmethod
          def callback_funcs_for_model_para(self, message):
              raise NotImplementedError
      
          @abc.abstractmethod
          def callback_funcs_for_assign_id(self, message):
              raise NotImplementedError
      
          @abc.abstractmethod
          def callback_funcs_for_join_in_info(self, message):
              raise NotImplementedError
      
          @abc.abstractmethod
          def callback_funcs_for_address(self, message):
              raise NotImplementedError
      
          @abc.abstractmethod
          def callback_funcs_for_evaluate(self, message):
              raise NotImplementedError
      
          @abc.abstractmethod
          def callback_funcs_for_finish(self, message):
              raise NotImplementedError
      
          @abc.abstractmethod
          def callback_funcs_for_converged(self, message):
              raise NotImplementedError
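
As a usage note, a concrete subclass would wire message types to these callbacks via register_handlers; the message-type strings below are illustrative:

    # Sketch: inside a concrete client's _register_default_handlers
    self.register_handlers('model_para', self.callback_funcs_for_model_para)
    self.register_handlers('evaluate', self.callback_funcs_for_evaluate)
    self.register_handlers('finish', self.callback_funcs_for_finish)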

@rayrayraykk linked an issue Sep 7, 2022 that may be closed by this pull request
@rayrayraykk added the enhancement and Feature labels and removed the Feature label Sep 9, 2022
@rayrayraykk changed the title Refactor data-related interfaces → [WIP]Refactor data-related interfaces Sep 13, 2022
@rayrayraykk changed the title [WIP]Refactor data-related interfaces → Refactor data-related interfaces Sep 13, 2022
@rayrayraykk changed the title Refactor data-related interfaces → [WIP] Refactor data-related interfaces Sep 14, 2022
@joneswong left a comment

Generally, looks good to me. As all of this design and implementation has been discussed before, there are just two minor questions I want to discuss: (1) Do these class names make sense to you? E.g., the dict containing client-wise data is called "standalone", the class responsible for converting a vanilla torch dataset into a federated counterpart is called a "translator", etc. (2) Is it necessary to make the methods and attributes that would not be accessed outside the class private (i.e., self._balabala)? @xieyxclack @yxdyc @DavdGao @Osier-Yi

@joneswong self-assigned this Oct 3, 2022
joneswong previously approved these changes Oct 10, 2022
@joneswong left a comment

LGTM


@xieyxclack closed this Oct 10, 2022
@xieyxclack reopened this Oct 10, 2022
xieyxclack previously approved these changes Oct 10, 2022
@xieyxclack left a comment

LGTM, please see the inline comments for minor suggestions, thx!

        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()
Collaborator:

Why do we move it to cpu?

Collaborator Author:

I follow the principle of GeneralTorchTrainer, see here.

        raise NotImplementedError

    @abc.abstractmethod
    def print_trainer_meta_info(self):
Collaborator:

Is it necessary for an FL course?

Collaborator Author:

This function is called in FedRunner.

@@ -1,7 +1,7 @@
 from federatedscope.register import register_data


-def MyData(config):
+def MyData(config, client_cfgs):
Collaborator:

Shall we provide client_cfgs in MyData?

Collaborator Author:

I believe client_cfgs will be useful in MyData when using personalized configs (e.g., cfg.dataloader.batch_size varies across clients).

Collaborator:

So we can set client_cfgs=None by default here since it just serves as an example.

Collaborator Author:

Agree!

loss = self.criterion(outputs, y)

# _hook_on_batch_backward
self.optimizer.zero_grad()
Collaborator:

I suggest putting this line before the forward process (i.e., outputs = self.model(x)).

Collaborator Author:

If so, should we also modify GeneralTorchTrainer, whose convention I follow?
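
For reference, the suggested ordering would be (a sketch reusing the names from the snippet above):

    self.optimizer.zero_grad()         # clear stale gradients first
    outputs = self.model(x)            # forward pass
    loss = self.criterion(outputs, y)
    loss.backward()                    # _hook_on_batch_backward
    self.optimizer.step()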


# _hook_on_fit_end
return num_samples, self.model.cpu().state_dict(), \
    {'loss_total': total_loss}
Collaborator:

We can provide the avg_loss here

@@ -10,9 +10,7 @@ federate:
 data:
   root: data/
   type: shakespeare
-  batch_size: 64
Collaborator:

Same as above

@@ -11,9 +11,7 @@ data:
   root: data/
   type: femnist
   splits: [0.6,0.2,0.2]
-  batch_size: 64
Collaborator:

Same as above

@@ -11,9 +11,7 @@ data:
   root: data/
   type: femnist
   splits: [0.6,0.2,0.2]
-  batch_size: 64
Collaborator:

Same as above

@@ -11,7 +11,6 @@ federate:
 data:
   root: data/
   type: synthetic
-  batch_size: 64
Collaborator:

Same as above

@@ -8,9 +8,7 @@ federate:
 data:
   root: data/
   type: shakespeare
-  batch_size: 64
Collaborator:

Same as above


xieyxclack commented Oct 10, 2022

@rayrayraykk @joneswong
(1) IMO, some names such as standalone and translator might be confusing for new users, but we can provide docs and comments to explain. (And I cannot give better names for them at this time)
(2) Since we have mostly not distinguished private methods/attributes from others in this version, maybe we can leave it as a TODO item and fix them later.

yxdyc previously approved these changes Oct 13, 2022
@yxdyc left a comment

LGTM

        datadict = self.split_to_client(train, val, test)
        return datadict

    def split_train_val_test(self, dataset):
Collaborator:

So what is the recommended way to implement customized split funcs? e.g., VMF and HMF datasets need to be split according to user/item ids.

Collaborator Author:

In that case, we use a dummy splitter, which treats MF datasets as FL datasets.

@DavdGao left a comment

Please see the inline comments

@@ -616,22 +584,22 @@ def get_data(config):
     # will restore the user-specified one after the generation
     setup_seed(12345)
Collaborator:

This random seed is out of the control of cfg.seed.

Collaborator Author:

The seed is set to generate data; we'd better keep it fixed, or we could use a cfg.data.seed instead of cfg.seed.


# DataLoader related args
cfg.dataloader = CN()
cfg.dataloader.type = 'base'
Collaborator:

Maybe list all the options for cfg.dataloader.type in an annotation.

cfg.dataloader.batch_size = 64
cfg.dataloader.shuffle = True
cfg.dataloader.num_workers = 0
cfg.dataloader.drop_last = False
Collaborator:

Is drop_last only valid for the training dataloader?

Collaborator Author:

Thanks for your suggestion! I've updated it accordingly.

# Split train/val/test to client
if len(train) > 0:
    split_train = self.splitter(train)
    if self.global_cfg.data.consistent_label_distribution:
Collaborator:

What's the meaning of consistent_label_distribution?

Collaborator Author:

When using a splitter, the train/val/test splits might be non-IID. With consistent_label_distribution set to True, the ML split is IID.

        self.kwargs = kwargs

    @abc.abstractmethod
    def train(self):
Collaborator:

Do we need an abstract method for finetuning?

Collaborator Author:

I think finetuning is optional; if someone needs it, they can implement finetuning in their own way.

# Noise multiplier
tmp = cfg.sgdmf.constant * np.power(sample_ratio, 2) * (
    cfg.federate.total_round_num * ctx.num_total_train_batch) * np.log(
        1. / cfg.sgdmf.delta)
noise_multipler = np.sqrt(tmp / np.power(cfg.sgdmf.epsilon, 2))
-ctx.scale = max(cfg.sgdmf.theta, 1.) * noise_multipler * np.power(
+ctx.scale = max(cfg.dataloader.theta, 1.) * noise_multipler * np.power(
Collaborator:

Since theta is only used in sgdmf, is it appropriate to place theta under the dataloader namespace (where all users can see it)?

Collaborator Author:

We can add a docstring to explain this, as many other args in dataloader are optional (e.g., sizes is graph-related).

@rayrayraykk (Author):

I changed the version to 0.2.1.

@joneswong left a comment

approved.

@joneswong merged commit 84a3722 into alibaba:master Oct 19, 2022
Labels: enhancement (New feature or request)

Successfully merging this pull request may close these issues:

  • Encapsulation of Trainer class
  • Cannot set different data related parameters for different clients

5 participants