Enabled to minimize local entropy #468

Merged · 4 commits · Dec 8, 2022
144 changes: 144 additions & 0 deletions federatedscope/contrib/trainer/local_entropy.py
@@ -0,0 +1,144 @@
import math
from collections import defaultdict

import torch

from federatedscope.core.trainers import BaseTrainer
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer


def copy_params(src):
    tgt = dict()
    for name, t in src.named_parameters():
        if t.requires_grad:
            tgt[name] = t.detach().clone()
    return tgt


def prox_term(cur, last):
    loss = .0
    for name, tensor in last.items():
        loss += 0.5 * torch.sum((cur[name] - tensor)**2)
    return loss


def add_noise(model, sigma):
    for p in model.parameters():
        if p.requires_grad:
            p.data += sigma * torch.randn(size=p.shape, device=p.device)


def moving_avg(cur, new, alpha):
    for k, v in cur.items():
        v.data = (1 - alpha) * v + alpha * new[k]


class LocalEntropyTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # configs
        self.kwargs = kwargs
        self.config = kwargs['config']
        self.optim_config = self.config.train.optimizer
        self.local_entropy_config = self.config.trainer.local_entropy

    def train(self):
        # Criterion & Optimizer
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optimizer = get_optimizer(self.model, **self.optim_config)

        # _hook_on_fit_start_init
        self.model.to(self.device)
        current_global_model = copy_params(self.model)
        mu = copy_params(self.model)
        self.model.train()

        num_samples, total_loss = self.run_epoch(optimizer, criterion,
                                                 current_global_model, mu)
        # adopt the smoothed iterate mu as the local model to be returned
        for name, param in self.model.named_parameters():
            if name in mu:
                param.data = mu[name]

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss,
             'avg_loss': total_loss / float(num_samples)}

    def run_epoch(self, optimizer, criterion, current_global_model, mu):
        running_loss = 0.0
        num_samples = 0
        thermal = self.local_entropy_config.gamma
        for inputs, targets in self.data['train']:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Descent step: cross-entropy plus a proximal term that couples
            # the iterate to the received global model. Pass named_parameters
            # (not state_dict) so gradients flow through the prox term.
            optimizer.zero_grad()
            outputs = self.model(inputs)
            ce_loss = criterion(outputs, targets)
            loss = ce_loss + thermal * prox_term(
                dict(self.model.named_parameters()), current_global_model)
            loss.backward()
            optimizer.step()

            # add noise for Langevin dynamics
            add_noise(
                self.model,
                math.sqrt(self.optim_config.lr) *
                self.local_entropy_config.eps)

            # accumulate local updates via an exponential moving average
            moving_avg(mu, self.model.state_dict(),
                       self.local_entropy_config.alpha)

            with torch.no_grad():
                running_loss += targets.shape[0] * ce_loss.item()

            num_samples += targets.shape[0]
            # anneal the coupling strength
            thermal *= 1.001

        return num_samples, running_loss

    def evaluate(self, target_data_split_name='test'):
        if target_data_split_name != 'test':
            return {}

        with torch.no_grad():
            criterion = torch.nn.CrossEntropyLoss().to(self.device)

            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = num_corrects = 0
            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = criterion(pred, y)
                cor = torch.sum(torch.argmax(pred, dim=-1).eq(y))

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]
                num_corrects += cor.item()

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_acc':
                    float(num_corrects) / float(num_samples),
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss':
                    total_loss / float(num_samples)
            }

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()
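
For reference, a rough, untested sketch of driving this trainer outside a federated run. The toy model, random data, and hand-built config below are illustrative assumptions (not part of this PR); the field names simply mirror what `__init__` and `train` read above, and `CN` is assumed to be the config node class used in `cfg_training.py`.

import torch
from federatedscope.core.configs.config import CN
from federatedscope.contrib.trainer.local_entropy import LocalEntropyTrainer

cfg = CN()
cfg.train = CN()
cfg.train.optimizer = CN()
cfg.train.optimizer.type = 'SGD'
cfg.train.optimizer.lr = 0.1
cfg.trainer = CN()
cfg.trainer.local_entropy = CN()
cfg.trainer.local_entropy.gamma = 1e-4
cfg.trainer.local_entropy.eps = 1e-3
cfg.trainer.local_entropy.alpha = 0.75

# toy model and data, purely for illustration
model = torch.nn.Linear(32, 10)
data = {'train': [(torch.randn(8, 32), torch.randint(0, 10, (8,)))],
        'test': [(torch.randn(8, 32), torch.randint(0, 10, (8,)))]}

trainer = LocalEntropyTrainer(model, data, 'cpu', config=cfg)
num_samples, state_dict, metrics = trainer.train()
print(metrics)
print(trainer.evaluate('test'))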
12 changes: 12 additions & 0 deletions federatedscope/contrib/trainer/local_entropy_trainer.py
@@ -0,0 +1,12 @@
from federatedscope.register import register_trainer
from federatedscope.core.trainers import BaseTrainer


def call_local_entropy_trainer(trainer_type):
    if trainer_type == 'local_entropy_trainer':
        from federatedscope.contrib.trainer.local_entropy \
            import LocalEntropyTrainer
        return LocalEntropyTrainer


register_trainer('local_entropy_trainer', call_local_entropy_trainer)
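
As a quick illustrative check (not part of the PR), the builder returns the trainer class only for the matching type string and implicitly returns None otherwise:

from federatedscope.contrib.trainer.local_entropy_trainer import \
    call_local_entropy_trainer

assert call_local_entropy_trainer('local_entropy_trainer').__name__ == \
    'LocalEntropyTrainer'
assert call_local_entropy_trainer('some_other_trainer') is None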
5 changes: 5 additions & 0 deletions federatedscope/core/configs/cfg_training.py
@@ -15,6 +15,11 @@ def extend_training_cfg(cfg):
    cfg.trainer.sam.rho = 1.0
    cfg.trainer.sam.eta = .0

    cfg.trainer.local_entropy = CN()
    cfg.trainer.local_entropy.gamma = 1e-4
    cfg.trainer.local_entropy.eps = 1e-3
    cfg.trainer.local_entropy.alpha = 0.75

    # ---------------------------------------------------------------------- #
    # Training related options
    # ---------------------------------------------------------------------- #
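
These defaults are the values `LocalEntropyTrainer` reads at construction time; a minimal sketch of how they are consumed, assuming a fully built config object `cfg` (variable names here are illustrative):

gamma = cfg.trainer.local_entropy.gamma   # prox coupling to the global model
eps = cfg.trainer.local_entropy.eps       # Langevin noise scale
alpha = cfg.trainer.local_entropy.alpha   # moving-average weight for mu
sigma = (cfg.train.optimizer.lr ** 0.5) * eps   # per-step noise std, as in run_epoch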
52 changes: 52 additions & 0 deletions scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml
@@ -0,0 +1,52 @@
use_gpu: True
device: 0
early_stop:
  patience: 0
federate:
  mode: standalone
  total_round_num: 10000
  client_num: 100
  sample_client_num: 5
  make_global_eval: True
  merge_test_data: True
fedopt:
  use: True
  optimizer:
    lr: 0.0001
    weight_decay: 0.0
    momentum: 0.0
data:
  root: data/
  type: 'CIFAR10@torchvision'
  splits: [1.0,0.0,0.0]
  num_workers: 0
  transform: [['RandomCrop', {'size': 32, 'padding': 4}], ['RandomHorizontalFlip'], ['ToTensor'], ['Normalize', {'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.2010]}]]
  test_transform: [['ToTensor'], ['Normalize', {'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.2010]}]]
  args: [{'download': True}]
  splitter: 'fedsam_cifar10_splitter'
  splitter_args: [{'alpha': 0.05}]
dataloader:
  batch_size: 64
model:
  type: fedsam_conv2
  out_channels: 10
  dropout: 0.0
criterion:
  type: CrossEntropyLoss
trainer:
  type: local_entropy_trainer
  local_entropy:
    gamma: 0.0001
    eps: 0.001
    alpha: 0.75
train:
  batch_or_epoch: 'epoch'
  optimizer:
    lr: 0.1
    weight_decay: 0.0
    momentum: 0.0
eval:
  freq: 100
  metrics: ['acc', 'correct']
  best_res_update_round_wise_key: test_loss
  count_flops: False
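
A hedged sketch of loading this YAML programmatically; it assumes the yacs-style `global_cfg` object used by FederatedScope's entry point (in practice, `federatedscope/main.py --cfg ...` performs this step, as in the run script below):

from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
cfg.merge_from_file('scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml')
assert cfg.trainer.type == 'local_entropy_trainer'
assert cfg.trainer.local_entropy.alpha == 0.75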
10 changes: 10 additions & 0 deletions scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh
@@ -0,0 +1,10 @@
set -e

bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 0 1e-4 1e-4 0.1 >/dev/null 2>/dev/null &
bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 1 1e-4 1e-4 1.0 >/dev/null 2>/dev/null &
bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 2 1e-4 1e-3 0.1 >/dev/null 2>/dev/null &
bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 3 1e-4 1e-3 1.0 >/dev/null 2>/dev/null &
bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 4 1e-3 1e-4 0.1 >/dev/null 2>/dev/null &
bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 5 1e-3 1e-4 1.0 >/dev/null 2>/dev/null &
bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 6 1e-3 1e-3 0.1 >/dev/null 2>/dev/null &
bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 7 1e-3 1e-3 1.0 >/dev/null 2>/dev/null &
18 changes: 18 additions & 0 deletions scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh
@@ -0,0 +1,18 @@
set -e

lda_alpha=$1
cudaid=$2
gamma=$3
eps=$4
lr=$5

echo $lda_alpha
echo $cudaid
echo $gamma
echo $eps
echo $lr

for (( i=0; i<5; i++ ))
do
python federatedscope/main.py --cfg scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml seed $i device $cudaid data.splitter_args "[{'alpha': ${lda_alpha}}]" trainer.local_entropy.gamma $gamma fedopt.optimizer.lr $gamma trainer.local_entropy.eps $eps train.optimizer.lr $lr expname fedentsgd_${lda_alpha}_${gamma}_${eps}_${lr}_${i}
done