Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor data-related interfaces & add interfaces for trainer and worker #365

Merged
merged 43 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
e9b3a8c
Refactor data-related interfaces
rayrayraykk Sep 6, 2022
7aa3387
fix minor bugs
rayrayraykk Sep 7, 2022
30f81d6
move interface
rayrayraykk Sep 7, 2022
0e85138
move data translator
rayrayraykk Sep 7, 2022
0c7cd41
rename file
rayrayraykk Sep 7, 2022
8d6b1c3
Merge branch 'alibaba:master' into refactor_data
rayrayraykk Sep 13, 2022
1477c61
add README for data protocal
rayrayraykk Sep 13, 2022
d9b1dc0
update interface of data translator
rayrayraykk Sep 14, 2022
2dd4bca
move toy to tabular folder
rayrayraykk Sep 14, 2022
c7ed566
WIP
rayrayraykk Sep 15, 2022
bb2f044
[WIP] refactor clientdata, TODO: update yaml, apply translator
rayrayraykk Sep 15, 2022
fc97a52
update yaml
rayrayraykk Sep 16, 2022
7e60fdc
remove torch in dataset
rayrayraykk Sep 16, 2022
c87c428
fix node trainer
rayrayraykk Sep 16, 2022
4e9709e
update graph-level interface
rayrayraykk Sep 16, 2022
e9e12a6
fix minor bugs
rayrayraykk Sep 16, 2022
fb2b15e
fix link graph
rayrayraykk Sep 16, 2022
94755da
fix link trainer bugs
rayrayraykk Sep 16, 2022
a032337
fix linktrainer
rayrayraykk Sep 16, 2022
3cea5f2
merge master
rayrayraykk Sep 16, 2022
577f167
add kwargs for splitter
rayrayraykk Sep 16, 2022
c283108
fix graph-level bugs
rayrayraykk Sep 16, 2022
85ae82d
fix minor bugs
rayrayraykk Sep 20, 2022
a7b2568
roll back
rayrayraykk Sep 20, 2022
482beb5
fix fedsageplus
rayrayraykk Sep 20, 2022
26716d7
fix minor bugs
rayrayraykk Sep 20, 2022
8eee1c7
add centralized torch trainer
rayrayraykk Sep 20, 2022
dc91315
update docs
rayrayraykk Sep 20, 2022
b42463c
add abc for client and server
rayrayraykk Sep 20, 2022
7860115
fix minor bug
rayrayraykk Sep 27, 2022
070b432
minor changes
rayrayraykk Oct 8, 2022
e6fcf99
fix bugs
rayrayraykk Oct 8, 2022
f81b17d
fix minor bug
rayrayraykk Oct 9, 2022
b0f70bb
Merge branch 'master' into refactor_data
rayrayraykk Oct 9, 2022
ba5df54
update merge_data according to #385
rayrayraykk Oct 9, 2022
edb4d4b
remove unnecessary clone of cfg
rayrayraykk Oct 10, 2022
b99cc31
minor changes
rayrayraykk Oct 10, 2022
551fe3a
fix minor bug
rayrayraykk Oct 10, 2022
05c6514
Update dataloader_builder.py
rayrayraykk Oct 13, 2022
8875422
update version
rayrayraykk Oct 17, 2022
e25a4eb
retriger UT
rayrayraykk Oct 17, 2022
bb9a9e9
Merge branch 'alibaba:master' into refactor_data
rayrayraykk Oct 19, 2022
47c89b1
fix format
rayrayraykk Oct 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,13 @@ Note that FederatedScope provides a unified interface for both standalone mode a

The standalone mode in FederatedScope means to simulate multiple participants (servers and clients) in a single device, while participants' data are isolated from each other and their models might be shared via message passing.

Here we demonstrate how to run a standard FL task with FederatedScope, with setting `cfg.data.type = 'FEMNIST'`and `cfg.model.type = 'ConvNet2'` to run vanilla FedAvg for an image classification task. Users can customize training configurations, such as `cfg.federated.total_round_num`, `cfg.data.batch_size`, and `cfg.train.optimizer.lr`, in the configuration (a .yaml file), and run a standard FL task as:
Here we demonstrate how to run a standard FL task with FederatedScope, with setting `cfg.data.type = 'FEMNIST'`and `cfg.model.type = 'ConvNet2'` to run vanilla FedAvg for an image classification task. Users can customize training configurations, such as `cfg.federated.total_round_num`, `cfg.dataloader.batch_size`, and `cfg.train.optimizer.lr`, in the configuration (a .yaml file), and run a standard FL task as:

```bash
# Run with default configurations
python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml
# Or with custom configurations
python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml federate.total_round_num 50 data.batch_size 128
python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml federate.total_round_num 50 dataloader.batch_size 128
```

Then you can observe some monitored metrics during the training process as:
Expand Down
3 changes: 3 additions & 0 deletions environment/extra_dependencies_torch1.10-application.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ conda install -y nltk
conda install -y sentencepiece textgrid typeguard -c conda-forge
conda install -y transformers==4.16.2 tokenizers==0.10.3 datasets -c huggingface -c conda-forge
conda install -y torchtext -c pytorch

# Tabular
conda install -y openml==0.12.2
21 changes: 11 additions & 10 deletions federatedscope/attack/auxiliary/poisoning_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def load_poisoned_dataset_edgeset(data, ctx, mode):
poison_testset.append((transforms_funcs(sample), label))
data['poison_' + mode] = DataLoader(
poison_testset,
batch_size=ctx.data.batch_size,
batch_size=ctx.dataloader.batch_size,
shuffle=False,
num_workers=ctx.data.num_workers)
num_workers=ctx.dataloader.num_workers)

elif "CIFAR10" in ctx.data.type:
target_label = int(ctx.attack.target_label_ind)
Expand Down Expand Up @@ -91,9 +91,9 @@ def load_poisoned_dataset_edgeset(data, ctx, mode):
poison_testset.append((transforms_funcs(sample), label))
data['poison_' + mode] = DataLoader(
poison_testset,
batch_size=ctx.data.batch_size,
batch_size=ctx.dataloader.batch_size,
shuffle=False,
num_workers=ctx.data.num_workers)
num_workers=ctx.dataloader.num_workers)

else:
raise RuntimeError(
Expand Down Expand Up @@ -213,9 +213,9 @@ def load_poisoned_dataset_pixel(data, ctx, mode):
poisoned_dataset[iii] = (transforms_funcs(sample), label)

data[mode] = DataLoader(poisoned_dataset,
batch_size=ctx.data.batch_size,
batch_size=ctx.dataloader.batch_size,
shuffle=True,
num_workers=ctx.data.num_workers)
num_workers=ctx.dataloader.num_workers)

if mode == MODE.TEST or mode == MODE.VAL:
poisoned_dataset = addTrigger(data[mode].dataset,
Expand All @@ -234,10 +234,11 @@ def load_poisoned_dataset_pixel(data, ctx, mode):
# (channel, height, width) = sample.shape #(c,h,w)
poisoned_dataset[iii] = (transforms_funcs(sample), label)

data['poison_' + mode] = DataLoader(poisoned_dataset,
batch_size=ctx.data.batch_size,
shuffle=False,
num_workers=ctx.data.num_workers)
data['poison_' + mode] = DataLoader(
poisoned_dataset,
batch_size=ctx.dataloader.batch_size,
shuffle=False,
num_workers=ctx.dataloader.num_workers)

return data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
from federatedscope.core.data.wrap_dataset import WrapDataset
from federatedscope.attack.auxiliary.MIA_get_target_data import get_target_data

logger = logging.getLogger(__name__)
Expand Down
7 changes: 0 additions & 7 deletions federatedscope/attack/trainer/benign_trainer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
from calendar import c
import logging
from typing import Type
import torch
import numpy as np

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.auxiliaries.transform_builder import get_transform
from federatedscope.attack.auxiliary.backdoor_utils import normalize
from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
from federatedscope.core.auxiliaries.ReIterator import ReIterator

logger = logging.getLogger(__name__)

Expand Down
7 changes: 4 additions & 3 deletions federatedscope/contrib/data/example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from federatedscope.register import register_data


def MyData(config):
def MyData(config, client_cfgs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we provide client_cfgs in MyData?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the client_cfgs will useful in MyData when using personalized cfg. (cfg.dataloader.batch_size varies.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we can set client_cfgs=None by default here since it just serves as an example.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree!

r"""
Returns:
data:
Expand All @@ -17,12 +17,13 @@ def MyData(config):
"""
data = None
config = config
client_cfgs = client_cfgs
return data, config


def call_my_data(config):
def call_my_data(config, client_cfgs):
if config.data.type == "mydata":
data, modified_config = MyData(config)
data, modified_config = MyData(config, client_cfgs)
return data, modified_config


Expand Down
101 changes: 101 additions & 0 deletions federatedscope/contrib/trainer/torch_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import inspect
from federatedscope.register import register_trainer
from federatedscope.core.trainers import BaseTrainer

# An example for converting torch training process to FS training process

# Refer to `federatedscope.core.trainers.BaseTrainer` for interface.

# Try with FEMNIST:
# python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml \
# trainer.type mytorchtrainer federate.sample_client_rate 0.01 \
# federate.total_round_num 5 eval.best_res_update_round_wise_key test_loss


class MyTorchTrainer(BaseTrainer):
def __init__(self, model, data, device, **kwargs):
import torch
# NN modules
self.model = model
# FS `ClientData` or your own data
self.data = data
# Device name
self.device = device
# kwargs
self.kwargs = kwargs
# Criterion & Optimizer
self.criterion = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.SGD(self.model.parameters(),
lr=0.001,
momentum=0.9,
weight_decay=1e-4)

def train(self):
# _hook_on_fit_start_init
self.model.to(self.device)
self.model.train()

total_loss = num_samples = 0
# _hook_on_batch_start_init
for x, y in self.data['train']:
# _hook_on_batch_forward
x, y = x.to(self.device), y.to(self.device)
outputs = self.model(x)
loss = self.criterion(outputs, y)

# _hook_on_batch_backward
self.optimizer.zero_grad()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest putting this line before the forward process (i.e., outputs = self.model(x)).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, should we modify GeneralTorchTrainer, which I follow the principle of.

loss.backward()
self.optimizer.step()

# _hook_on_batch_end
total_loss += loss.item() * y.shape[0]
num_samples += y.shape[0]

# _hook_on_fit_end
return num_samples, self.model.cpu().state_dict(), \
{'loss_total': total_loss}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can provide the avg_loss here


def evaluate(self, target_data_split_name='test'):
import torch
with torch.no_grad():
self.model.to(self.device)
self.model.eval()
total_loss = num_samples = 0
# _hook_on_batch_start_init
for x, y in self.data[target_data_split_name]:
# _hook_on_batch_forward
x, y = x.to(self.device), y.to(self.device)
pred = self.model(x)
loss = self.criterion(pred, y)

# _hook_on_batch_end
total_loss += loss.item() * y.shape[0]
num_samples += y.shape[0]

# _hook_on_fit_end
return {
f'{target_data_split_name}_loss': total_loss,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

f'{target_data_split_name}_total': num_samples
}

def update(self, model_parameters, strict=False):
self.model.load_state_dict(model_parameters, strict)

def get_model_para(self):
return self.model.cpu().state_dict()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we move it to cpu?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I follow the principle of GeneralTorchTrainer, see here.


def print_trainer_meta_info(self):
sign = inspect.signature(self.__init__).parameters.values()
meta_info = tuple([(val.name, getattr(self, val.name))
for val in sign])
return f'{self.__class__.__name__}{meta_info}'


def call_my_torch_trainer(trainer_type):
if trainer_type == 'mytorchtrainer':
trainer_builder = MyTorchTrainer
return trainer_builder


register_trainer('mytorchtrainer', call_my_torch_trainer)
Loading