Refactor data-related interfaces & add interfaces for trainer and worker (#365)
rayrayraykk authored Oct 19, 2022
1 parent d2c889d commit 84a3722
Showing 161 changed files with 2,218 additions and 1,481 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -162,13 +162,13 @@ Note that FederatedScope provides a unified interface for both standalone mode a…

The standalone mode in FederatedScope simulates multiple participants (servers and clients) on a single device; the participants' data are isolated from each other, while their models may be shared via message passing.

Here we demonstrate how to run a standard FL task with FederatedScope by setting `cfg.data.type = 'FEMNIST'` and `cfg.model.type = 'ConvNet2'` to run vanilla FedAvg for an image classification task. Users can customize training configurations, such as `cfg.federate.total_round_num`, `cfg.data.batch_size`, and `cfg.train.optimizer.lr`, in the configuration (a .yaml file), and run a standard FL task as:
Here we demonstrate how to run a standard FL task with FederatedScope by setting `cfg.data.type = 'FEMNIST'` and `cfg.model.type = 'ConvNet2'` to run vanilla FedAvg for an image classification task. Users can customize training configurations, such as `cfg.federate.total_round_num`, `cfg.dataloader.batch_size`, and `cfg.train.optimizer.lr`, in the configuration (a .yaml file), and run a standard FL task as:

```bash
# Run with default configurations
python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml
# Or with custom configurations
python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml federate.total_round_num 50 data.batch_size 128
python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml federate.total_round_num 50 dataloader.batch_size 128
```
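
For orientation, here is a minimal sketch of what such a .yaml could contain, restricted to the keys named above; the values are illustrative and the shipped `scripts/example_configs/femnist.yaml` may differ:

```yaml
# Illustrative sketch only -- not the shipped femnist.yaml
data:
  type: FEMNIST
model:
  type: ConvNet2
federate:
  total_round_num: 50
dataloader:
  batch_size: 128
train:
  optimizer:
    lr: 0.01        # value chosen for illustration
```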

Then you can observe some monitored metrics during the training process as:
3 changes: 3 additions & 0 deletions environment/extra_dependencies_torch1.10-application.sh
@@ -9,3 +9,6 @@ conda install -y nltk
conda install -y sentencepiece textgrid typeguard -c conda-forge
conda install -y transformers==4.16.2 tokenizers==0.10.3 datasets -c huggingface -c conda-forge
conda install -y torchtext -c pytorch

# Tabular
conda install -y openml==0.12.2
2 changes: 1 addition & 1 deletion federatedscope/__init__.py
@@ -1,6 +1,6 @@
from __future__ import absolute_import, division, print_function

__version__ = '0.2.0'
__version__ = '0.2.1'


def _setup_logger():
21 changes: 11 additions & 10 deletions federatedscope/attack/auxiliary/poisoning_data.py
@@ -50,9 +50,9 @@ def load_poisoned_dataset_edgeset(data, ctx, mode):
poison_testset.append((transforms_funcs(sample), label))
data['poison_' + mode] = DataLoader(
poison_testset,
batch_size=ctx.data.batch_size,
batch_size=ctx.dataloader.batch_size,
shuffle=False,
num_workers=ctx.data.num_workers)
num_workers=ctx.dataloader.num_workers)

elif "CIFAR10" in ctx.data.type:
target_label = int(ctx.attack.target_label_ind)
@@ -91,9 +91,9 @@ def load_poisoned_dataset_edgeset(data, ctx, mode):
poison_testset.append((transforms_funcs(sample), label))
data['poison_' + mode] = DataLoader(
poison_testset,
batch_size=ctx.data.batch_size,
batch_size=ctx.dataloader.batch_size,
shuffle=False,
num_workers=ctx.data.num_workers)
num_workers=ctx.dataloader.num_workers)

else:
raise RuntimeError(
@@ -213,9 +213,9 @@ def load_poisoned_dataset_pixel(data, ctx, mode):
poisoned_dataset[iii] = (transforms_funcs(sample), label)

data[mode] = DataLoader(poisoned_dataset,
batch_size=ctx.data.batch_size,
batch_size=ctx.dataloader.batch_size,
shuffle=True,
num_workers=ctx.data.num_workers)
num_workers=ctx.dataloader.num_workers)

if mode == MODE.TEST or mode == MODE.VAL:
poisoned_dataset = addTrigger(data[mode].dataset,
@@ -234,10 +234,11 @@ def load_poisoned_dataset_pixel(data, ctx, mode):
# (channel, height, width) = sample.shape #(c,h,w)
poisoned_dataset[iii] = (transforms_funcs(sample), label)

data['poison_' + mode] = DataLoader(poisoned_dataset,
batch_size=ctx.data.batch_size,
shuffle=False,
num_workers=ctx.data.num_workers)
data['poison_' + mode] = DataLoader(
poisoned_dataset,
batch_size=ctx.dataloader.batch_size,
shuffle=False,
num_workers=ctx.dataloader.num_workers)

return data
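
The recurring edit in this file is the same one applied across the refactor: DataLoader options are now read from the `dataloader` config namespace rather than `data` (which keeps dataset-level settings such as `data.type`). A hedged sketch of the pattern, with `ctx` and `dataset` standing in for whatever the surrounding code provides:

```python
from torch.utils.data import DataLoader

def build_loader(ctx, dataset, shuffle=False):
    # After this commit, loader options live under ctx.dataloader.*
    # (previously ctx.data.*); ctx.data still holds dataset-level options.
    return DataLoader(dataset,
                      batch_size=ctx.dataloader.batch_size,
                      shuffle=shuffle,
                      num_workers=ctx.dataloader.num_workers)
```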

@@ -4,7 +4,7 @@
import torch

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
from federatedscope.core.data.wrap_dataset import WrapDataset
from federatedscope.attack.auxiliary.MIA_get_target_data import get_target_data

logger = logging.getLogger(__name__)
7 changes: 0 additions & 7 deletions federatedscope/attack/trainer/benign_trainer.py
@@ -1,15 +1,8 @@
from calendar import c
import logging
from typing import Type
import torch
import numpy as np

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.auxiliaries.transform_builder import get_transform
from federatedscope.attack.auxiliary.backdoor_utils import normalize
from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
from federatedscope.core.auxiliaries.ReIterator import ReIterator

logger = logging.getLogger(__name__)

7 changes: 4 additions & 3 deletions federatedscope/contrib/data/example.py
@@ -1,7 +1,7 @@
from federatedscope.register import register_data


def MyData(config):
def MyData(config, client_cfgs=None):
r"""
Returns:
data:
@@ -17,12 +17,13 @@ def MyData(config):
"""
data = None
config = config
client_cfgs = client_cfgs
return data, config


def call_my_data(config):
def call_my_data(config, client_cfgs):
if config.data.type == "mydata":
data, modified_config = MyData(config)
data, modified_config = MyData(config, client_cfgs)
return data, modified_config


104 changes: 104 additions & 0 deletions federatedscope/contrib/trainer/torch_example.py
@@ -0,0 +1,104 @@
import inspect
from federatedscope.register import register_trainer
from federatedscope.core.trainers import BaseTrainer

# An example for converting torch training process to FS training process

# Refer to `federatedscope.core.trainers.BaseTrainer` for interface.

# Try with FEMNIST:
# python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml \
# trainer.type mytorchtrainer federate.sample_client_rate 0.01 \
# federate.total_round_num 5 eval.best_res_update_round_wise_key test_loss


class MyTorchTrainer(BaseTrainer):
def __init__(self, model, data, device, **kwargs):
import torch
# NN modules
self.model = model
# FS `ClientData` or your own data
self.data = data
# Device name
self.device = device
# kwargs
self.kwargs = kwargs
# Criterion & Optimizer
self.criterion = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.SGD(self.model.parameters(),
lr=0.001,
momentum=0.9,
weight_decay=1e-4)

def train(self):
# _hook_on_fit_start_init
self.model.to(self.device)
self.model.train()

total_loss = num_samples = 0
# _hook_on_batch_start_init
for x, y in self.data['train']:
# _hook_on_batch_forward
x, y = x.to(self.device), y.to(self.device)
outputs = self.model(x)
loss = self.criterion(outputs, y)

# _hook_on_batch_backward
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

# _hook_on_batch_end
total_loss += loss.item() * y.shape[0]
num_samples += y.shape[0]

# _hook_on_fit_end
return num_samples, self.model.cpu().state_dict(), \
{'loss_total': total_loss, 'avg_loss': total_loss/float(
num_samples)}

def evaluate(self, target_data_split_name='test'):
import torch
with torch.no_grad():
self.model.to(self.device)
self.model.eval()
total_loss = num_samples = 0
# _hook_on_batch_start_init
for x, y in self.data[target_data_split_name]:
# _hook_on_batch_forward
x, y = x.to(self.device), y.to(self.device)
pred = self.model(x)
loss = self.criterion(pred, y)

# _hook_on_batch_end
total_loss += loss.item() * y.shape[0]
num_samples += y.shape[0]

# _hook_on_fit_end
return {
f'{target_data_split_name}_loss': total_loss,
f'{target_data_split_name}_total': num_samples,
f'{target_data_split_name}_avg_loss': total_loss /
float(num_samples)
}

def update(self, model_parameters, strict=False):
self.model.load_state_dict(model_parameters, strict)

def get_model_para(self):
return self.model.cpu().state_dict()

def print_trainer_meta_info(self):
sign = inspect.signature(self.__init__).parameters.values()
meta_info = tuple([(val.name, getattr(self, val.name))
for val in sign])
return f'{self.__class__.__name__}{meta_info}'


def call_my_torch_trainer(trainer_type):
if trainer_type == 'mytorchtrainer':
trainer_builder = MyTorchTrainer
return trainer_builder


register_trainer('mytorchtrainer', call_my_torch_trainer)
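
Since the whole file is new, a brief standalone sketch of driving the class above by hand (outside the FS runner) may make the BaseTrainer contract concrete; the model, data, and shapes below are illustrative stand-ins:

```python
# Assumes MyTorchTrainer from above is defined or importable in this session.
import torch
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(10, 2)
features, labels = torch.randn(32, 10), torch.randint(0, 2, (32,))
splits = {
    'train': DataLoader(TensorDataset(features, labels), batch_size=8),
    'test': DataLoader(TensorDataset(features, labels), batch_size=8),
}

trainer = MyTorchTrainer(model, splits, device='cpu')
num_samples, new_params, train_metrics = trainer.train()
trainer.update(new_params)       # e.g. load aggregated parameters back in
print(trainer.evaluate('test'))  # {'test_loss': ..., 'test_total': ..., 'test_avg_loss': ...}
```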