
Using EfficientNet-b0 makes qfedavg stop working #57

Open
Xiaxuanxuan opened this issue Jan 30, 2024 · 3 comments

Comments

@Xiaxuanxuan

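Thank you for your federated learning framework! It is very concise and easy to port.
I have one question I'd like to ask. When I switch the model to efficientnet-b0 and run qfedavg, fedfv, or fedprox on CIFAR-10, the loss stays the same from the first round to the last. This is the model I defined: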
```python
from torch import nn
from flgo.utils.fmodule import FModule
from efficientnet_pytorch import EfficientNet


class Model(FModule):
    def __init__(self):
        super(Model, self).__init__()
        pretrained = True
        self.base_model = (
            EfficientNet.from_pretrained("efficientnet-b0")
            if pretrained
            else EfficientNet.from_name("efficientnet-b0")
        )
        # self.base_model = torchvision.models.efficientnet_v2_s(pretrained=pretrained)
        nftrs = self.base_model._fc.in_features
        # print("Number of features output by EfficientNet", nftrs)
        # Replace the ImageNet head with a 10-class head for CIFAR-10
        self.base_model._fc = nn.Linear(nftrs, 10)

    def forward(self, x):
        # Convolution layers
        x = self.base_model.extract_features(x)
        # Pooling and final linear layer
        feature_x = self.base_model._avg_pooling(x)
        if self.base_model._global_params.include_top:
            x = feature_x.flatten(start_dim=1)
            x = self.base_model._dropout(x)
            x = self.base_model._fc(x)
        return x


def init_local_module(object):
    pass


def init_global_module(object):
    # Only the server holds the global model
    if 'Server' in object.__class__.__name__:
        object.model = Model().to(object.device)
```
This gives the following result:
[screenshot]

@WwZzz
Owner

WwZzz commented Feb 3, 2024

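Hi, someone raised the same problem in the flgo discussion group earlier. It happens because the qfedavg code computes the model norm directly through the `norm` interface, which by default calls `flgo.utils.fmodule._model_dict_norm`. `model.state_dict()` usually also contains statistics buffers (for BN layers, the running mean and variance), so for any model with BN layers the norm returned by this interface is very large. When the update is divided by such a norm, the model update is scaled down to 0. Here is my fixed qfedavg code; I will merge the fix into flgo later.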
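To see why the update collapses, here is a toy calculation in plain Python (no flgo or PyTorch needed). All numbers are made up for illustration. It uses the q-FedAvg client quantities h_k = q F_k^(q-1) ||Δw_k||² + L F_k^q with Δw_k = L (w - w_k) and L = 1/lr, and compares a squared norm over trainable parameters only against one that also includes BN buffers. `num_batches_tracked` is a real `nn.BatchNorm2d` buffer that counts training batches, so its global-minus-local difference grows with the number of local steps; the specific values here are assumptions.

```python
# Toy illustration (made-up numbers): a squared norm taken over the whole
# state_dict, BN buffers included, makes the q-FedAvg step vanish.

def sq_norm(groups):
    """Sum of squares over a dict of name -> list of floats."""
    return sum(v * v for values in groups.values() for v in values)


lr, q, Fk = 0.01, 1.0, 2.0
L = 1.0 / lr  # Lipschitz estimate used by q-FedAvg: 1 / learning_rate

# Hypothetical global-minus-local differences after local training.
delta_params = {"conv.weight": [0.01, -0.02, 0.015], "bn.weight": [0.005, -0.005]}
# BN buffers: running statistics and a batch counter (here 50 local steps).
delta_buffers = {"bn.running_var": [0.3, 0.4], "bn.num_batches_tracked": [-50.0]}


def step_scale(norm_sq):
    """For a single client, new_w = w - scale * (w - w_k); returns scale."""
    hk = q * Fk ** (q - 1) * norm_sq + L * Fk ** q
    return Fk ** q * L / hk


# Delta w_k = L * (global - local), so its squared norm carries a factor L**2.
scale_params = step_scale(L ** 2 * sq_norm(delta_params))
scale_full = step_scale(L ** 2 * (sq_norm(delta_params) + sq_norm(delta_buffers)))

print(f"parameters only:  {scale_params:.3f}")  # close to 1: a real step
print(f"whole state_dict: {scale_full:.2e}")    # effectively 0: no step
```

The batch counter alone dominates h_k, which is exactly the "update scaled to 0" behaviour described above.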

@WwZzz
Owner

WwZzz commented Feb 3, 2024

```python
"""This is a non-official implementation of 'Fair Resource Allocation in
Federated Learning' (http://arxiv.org/abs/1905.10497). And this implementation
refers to the official github repository https://github.com/litian96/fair_flearn """
import copy

import flgo.algorithm.fedbase as fedbase
import flgo.utils.fmodule as fmodule


class Server(fedbase.BasicServer):
    def initialize(self, *args, **kwargs):
        # q controls how strongly clients with high loss are up-weighted
        self.init_algo_para({'q': 1.0})

    def iterate(self):
        self.selected_clients = self.sample()
        res = self.communicate(self.selected_clients)
        # Global step: w <- w - sum_k(Delta_k) / sum_k(h_k)
        self.model = self.model - fmodule._model_sum(res['dk']) / sum(res['hk'])
        return len(self.received_clients) > 0


class Client(fedbase.BasicClient):
    def unpack(self, package):
        model = package['model']
        # Keep a copy of the received global model to compute F_k and the update
        self.global_model = copy.deepcopy(model)
        return model

    def pack(self, model):
        # F_k: training loss of the global model on local data (+eps avoids 0)
        Fk = self.test(self.global_model, 'train')['loss'] + 1e-8
        L = 1.0 / self.learning_rate
        delta_wk = L * (self.global_model - model)
        dk = (Fk ** self.q) * delta_wk
        # Squared norm over parameters() only: buffers such as BN running
        # statistics are skipped, which is the fix for this issue
        norm_dwk = 0.0
        for p in delta_wk.parameters():
            norm_dwk += (p ** 2).sum()
        hk = self.q * (Fk ** (self.q - 1)) * norm_dwk + L * (Fk ** self.q)
        self.global_model = None
        return {'dk': dk, 'hk': hk}
```
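The `pack`/`iterate` pair above boils down to a few lines of arithmetic. The following dependency-free sketch mirrors it on flat lists of floats; `qfedavg_round` is a hypothetical helper, not part of flgo, and real models would iterate over `parameters()` rather than a list. With q = 0 every h_k equals L and the update reduces to the plain average of the client models, which makes a convenient sanity check.

```python
def qfedavg_round(w, local_models, losses, lr, q):
    """One q-FedAvg server round on flat lists of floats (toy sketch).

    w: global weights; local_models: one weight list per client;
    losses: F_k, each client's training loss of the global model.
    """
    L = 1.0 / lr
    dks, hks = [], []
    for wk, Fk in zip(local_models, losses):
        Fk = Fk + 1e-8
        delta = [L * (g - l) for g, l in zip(w, wk)]      # Delta w_k
        dks.append([Fk ** q * d for d in delta])           # Delta_k
        sq = sum(d * d for d in delta)                     # ||Delta w_k||^2, weights only
        hks.append(q * Fk ** (q - 1) * sq + L * Fk ** q)   # h_k
    total_h = sum(hks)
    return [wi - sum(dk[i] for dk in dks) / total_h for i, wi in enumerate(w)]


w = [1.0, 2.0]
clients = [[0.0, 2.0], [2.0, 4.0]]
losses = [1.0, 3.0]
print(qfedavg_round(w, clients, losses, lr=0.1, q=0.0))  # plain average: [1.0, 3.0]
print(qfedavg_round(w, clients, losses, lr=0.1, q=1.0))  # first coordinate drifts toward the high-loss client
```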

@WwZzz
Owner

WwZzz commented Feb 3, 2024

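Computing every norm from `model.parameters()` instead of the state dict fixes the problem. However, because BN is inherently at odds with non-IID data in federated learning, I recommend using a model without BN, or replacing BN with GN (GroupNorm).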
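For the BN-to-GN route, one possible approach in plain PyTorch is to walk the module tree and swap each `nn.BatchNorm2d` for an `nn.GroupNorm` over the same channels. This is a sketch, not something flgo provides: the helper name `bn_to_gn` and the default of 8 groups are assumptions. It does not transfer pretrained BN weights or statistics, so a pretrained EfficientNet-b0 would lose its BN calibration and need retraining. Because GroupNorm keeps no running statistics, the converted model's `state_dict()` no longer carries BN buffers.

```python
import torch
from torch import nn


def bn_to_gn(module: nn.Module, num_groups: int = 8) -> nn.Module:
    """Recursively replace every nn.BatchNorm2d with nn.GroupNorm (sketch).

    num_groups is an illustrative default; GroupNorm needs it to divide the
    channel count, so fall back to the largest smaller divisor when it does not.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            channels = child.num_features
            groups = num_groups
            while channels % groups != 0:
                groups -= 1
            setattr(module, name, nn.GroupNorm(groups, channels, affine=child.affine))
        else:
            bn_to_gn(child, num_groups)
    return module


# Usage on a tiny stand-in network (not EfficientNet itself)
net = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 12, 3), nn.BatchNorm2d(12),
)
net = bn_to_gn(net)
print(net)
print(net(torch.randn(2, 3, 8, 8)).shape)
```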
