Refactor data-related interfaces & add interfaces for trainer and worker #365
import inspect
from federatedscope.register import register_trainer
from federatedscope.core.trainers import BaseTrainer


# An example of converting a torch training process into an FS training
# process.

# Refer to `federatedscope.core.trainers.BaseTrainer` for the interface.

# Try with FEMNIST:
# python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml \
#     trainer.type mytorchtrainer federate.sample_client_rate 0.01 \
#     federate.total_round_num 5 eval.best_res_update_round_wise_key test_loss


class MyTorchTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        import torch
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # kwargs
        self.kwargs = kwargs
        # Criterion & Optimizer
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)

    def train(self):
        # _hook_on_fit_start_init
        self.model.to(self.device)
        self.model.train()

        total_loss = num_samples = 0
        # _hook_on_batch_start_init
        for x, y in self.data['train']:
            # _hook_on_batch_forward
            x, y = x.to(self.device), y.to(self.device)
            outputs = self.model(x)
            loss = self.criterion(outputs, y)

            # _hook_on_batch_backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # _hook_on_batch_end
            total_loss += loss.item() * y.shape[0]
            num_samples += y.shape[0]

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss,
             'avg_loss': total_loss / float(num_samples)}

    def evaluate(self, target_data_split_name='test'):
        import torch
        with torch.no_grad():
            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = 0
            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = self.criterion(pred, y)

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss':
                    total_loss / float(num_samples)
            }

Review comment (on the `f'{target_data_split_name}_loss'` entry): Same as above.

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()

Review comment (on `get_model_para`): Why do we move it to cpu?
Reply: I follow the principle of `GeneralTorchTrainer`.

    def print_trainer_meta_info(self):
        sign = inspect.signature(self.__init__).parameters.values()
        meta_info = tuple([(val.name, getattr(self, val.name))
                           for val in sign])
        return f'{self.__class__.__name__}{meta_info}'


def call_my_torch_trainer(trainer_type):
    if trainer_type == 'mytorchtrainer':
        trainer_builder = MyTorchTrainer
        return trainer_builder


register_trainer('mytorchtrainer', call_my_torch_trainer)
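The file above implements the `BaseTrainer` contract that this PR introduces. As a rough sketch only, inferred from the methods `MyTorchTrainer` overrides rather than quoted from the PR's base class, that contract can be pictured like this:

    import inspect
    from abc import ABC, abstractmethod


    class BaseTrainer(ABC):
        # Sketch of the assumed trainer interface; method names and signatures
        # mirror what the example trainer above overrides.
        def __init__(self, model, data, device, **kwargs):
            self.model = model
            self.data = data
            self.device = device
            self.kwargs = kwargs

        @abstractmethod
        def train(self):
            # Expected to return (num_samples, model_parameters, metrics_dict).
            raise NotImplementedError

        @abstractmethod
        def evaluate(self, target_data_split_name='test'):
            # Expected to return a dict of metrics keyed by split name.
            raise NotImplementedError

        @abstractmethod
        def update(self, model_parameters, strict=False):
            # Load model parameters received from the server.
            raise NotImplementedError

        @abstractmethod
        def get_model_para(self):
            # Return the local model parameters to be shared.
            raise NotImplementedError

Under this reading, any class that provides these five entry points can be plugged in through `register_trainer`, regardless of how its training loop is written internally.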
Review comment: I suggest putting this line before the forward process (i.e., before `outputs = self.model(x)`).
Reply: If so, should we also modify `GeneralTorchTrainer`, whose convention I follow here?
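If the line in question is `self.optimizer.zero_grad()` (an assumption; the comment's anchor line is not visible in this extraction), the suggested reordering of the training loop body would look roughly like this:

    # Hypothetical reordering, assuming the suggestion targets
    # `self.optimizer.zero_grad()`:
    for x, y in self.data['train']:
        x, y = x.to(self.device), y.to(self.device)
        # Clear stale gradients before the forward pass
        self.optimizer.zero_grad()
        outputs = self.model(x)
        loss = self.criterion(outputs, y)
        loss.backward()
        self.optimizer.step()

Either placement produces the same gradients here, since the gradients are cleared before every `backward()` call; the reordering is mainly a readability convention, which is why the reply asks whether `GeneralTorchTrainer` should be changed to match.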