This is the implementation of our paper FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning (accepted by AAAI 2024).
Key words: federated learning, data heterogeneity, model heterogeneity, communication overhead, intellectual property (IP) protection
Take away: We enhance the typical HtFL method FedProto with Trainable Global Prototypes (TGP) and Adaptive-margin-enhanced Contrastive Learning (ACL), making it more versatile and resilient to various model heterogeneities.
Citation
@inproceedings{zhang2024fedtgp,
title={FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning},
author={Zhang, Jianqing and Liu, Yang and Hua, Yang and Cao, Jian},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
year={2024}
}
The global and client prototypes in FedProto and our FedTGP. Different colors and numbers represent classes and clients, respectively. Circles represent the client prototypes and triangles represent the global prototypes. The black and yellow dotted arrows show the inter-class separation among the client and global prototypes, respectively. Triangles with dotted borders represent our Trainable Global Prototypes (TGP). The red arrows show the inter-class intervals between TGP and the client prototypes of other classes in our Adaptive-margin-enhanced Contrastive Learning (ACL).
Due to the file size limitation, we only upload the statistics (config.json) of the Cifar10 dataset in the practical setting (
Learning reasonable global prototypes can be challenging in some cases, particularly due to the limited number of client prototypes and the adaptive margin introduced during ACL. To address this, consider setting a larger top_cnt and ensuring that the global communication iteration number is larger than 1000, which should result in a Server loss smaller than 0.001. The best accuracy is typically achieved when a minimal Server loss is obtained. In most of our experiments, we achieved a Server loss of 0.0.
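To make the role of the margin and the near-zero Server loss more concrete, here is a minimal sketch of a margin-enhanced contrastive loss between one client prototype and a set of global prototypes. It is not the code in this repository: the function names are hypothetical, the distance is assumed to be squared Euclidean, and the margin is passed in as a fixed number, whereas FedTGP adapts it during training.

```python
import math


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def margin_contrastive_loss(client_proto, label, global_protos, margin):
    """Cross-entropy over negative distances to every global prototype.

    Assumption (illustrative only): the margin is added to the distance
    between the client prototype and the global prototype of its own
    class, so the loss only becomes small once the other classes are
    farther away by at least that margin.
    """
    logits = []
    for cls, g in enumerate(global_protos):
        d = sq_dist(client_proto, g)
        if cls == label:
            d += margin  # penalize the positive pair by the margin
        logits.append(-d)
    # Numerically stable log-sum-exp over all classes
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]


# Well-separated global prototypes: the loss is negligible.
far = margin_contrastive_loss([0.5, 0.0], 0, [[0.0, 0.0], [10.0, 0.0]], 1.0)

# Poorly separated global prototypes: the loss stays large.
near = margin_contrastive_loss([0.5, 0.0], 0, [[0.0, 0.0], [1.0, 0.0]], 1.0)
```

In this sketch, the loss for the well-separated case is far below the 0.001 threshold mentioned above, while the poorly separated case stays above 1. Increasing the margin raises the loss for the same prototypes, which is consistent with the note that the margin can make the global prototypes harder to learn.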
main.py: system configurations.
run_me.sh: command lines to run experiments. It is advisable to retune hyperparameters on new tasks.
flcore/:
utils/:
    data_utils.py: the code to read the dataset.
    mem_utils.py: the code to record memory usage.
    result_utils.py: the code to save results to files.
All code is stored in ./system. Run the following commands:
cd ./system
sh run_me.sh