diff --git a/federatedscope/contrib/trainer/torch_example.py b/federatedscope/contrib/trainer/torch_example.py index dbdf938b7..df0f7b3e3 100644 --- a/federatedscope/contrib/trainer/torch_example.py +++ b/federatedscope/contrib/trainer/torch_example.py @@ -24,12 +24,14 @@ def __init__(self, model, data, device, **kwargs): self.kwargs = kwargs # Criterion & Optimizer self.criterion = torch.nn.CrossEntropyLoss() + + def train(self): + import torch self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) - def train(self): # _hook_on_fit_start_init self.model.to(self.device) self.model.train()