-
-
Notifications
You must be signed in to change notification settings - Fork 744
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #529 from FedML-AI/test/v0.7.0
Test/v0.7.0
- Loading branch information
Showing
16 changed files
with
297 additions
and
53 deletions.
There are no files selected for viewing
66 changes: 66 additions & 0 deletions
66
...examples/cross_silo/mqtt_s3_fedavg_defense_mnist_lr_example/config/crfl/fedml_config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
common_args: | ||
training_type: "cross_silo" | ||
scenario: "horizontal" | ||
using_mlops: false | ||
random_seed: 0 | ||
config_version: release | ||
|
||
environment_args: | ||
bootstrap: config/bootstrap.sh | ||
|
||
data_args: | ||
dataset: "mnist" | ||
data_cache_dir: ~/fedml_data | ||
partition_method: "hetero" | ||
partition_alpha: 0.5 | ||
|
||
model_args: | ||
model: "lr" | ||
model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically | ||
global_model_file_path: "./model_file_cache/global_model.pt" | ||
|
||
train_args: | ||
federated_optimizer: "FedAvg" | ||
# for CLI running, this can be None; in MLOps deployment, `client_id_list` will be replaced with real-time selected devices | ||
client_id_list: | ||
# for FoolsGold Defense, if use_memory is true, then client_num_in_total should be equal to client_number_per_round | ||
client_num_in_total: 1000 | ||
client_num_per_round: 4 | ||
comm_round: 10 | ||
epochs: 1 | ||
batch_size: 10 | ||
client_optimizer: sgd | ||
learning_rate: 0.03 | ||
weight_decay: 0.001 | ||
|
||
validation_args: | ||
frequency_of_the_test: 1 | ||
|
||
device_args: | ||
worker_num: 4 | ||
using_gpu: false | ||
gpu_mapping_file: config/gpu_mapping.yaml | ||
gpu_mapping_key: mapping_default | ||
|
||
comm_args: | ||
backend: "MQTT_S3" | ||
mqtt_config_path: | ||
s3_config_path: | ||
grpc_ipconfig_path: ./config/grpc_ipconfig.csv | ||
|
||
tracking_args: | ||
# the default log path is at ~/fedml-client/fedml/logs/ and ~/fedml-server/fedml/logs/ | ||
enable_wandb: false | ||
wandb_key: ee0b5f53d949c84cee7decbe7a629e63fb2f8408 | ||
wandb_project: fedml | ||
wandb_name: fedml_torch_fedavg_mnist_lr | ||
|
||
attack_args: | ||
enable_attack: false | ||
attack_type: None | ||
|
||
defense_args: | ||
enable_defense: true | ||
defense_type: crfl | ||
federated_optimizer: FedAvg | ||
sigma: 0.02 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from .defense_base import BaseDefenseMethod | ||
from typing import Callable, List, Tuple, Dict, Any | ||
from ..common import utils | ||
from ...dp.mechanisms import Gaussian | ||
|
||
""" | ||
CRFL: Certifiably Robust Federated Learning against Backdoor Attacks (ICML 2021) | ||
http://proceedings.mlr.press/v139/xie21a/xie21a.pdf | ||
""" | ||
|
||
|
||
class CRFLDefense(BaseDefenseMethod): | ||
def __init__(self, config): | ||
self.config = config | ||
self.epoch = 1 | ||
if hasattr(config, "clip_threshold"): | ||
self.clip_threshold = config.clip_threshold | ||
else: | ||
self.clip_threshold = None | ||
if hasattr(config, "sigma") and isinstance(config.sigma, float): | ||
self.sigma = config.sigma | ||
else: | ||
self.sigma = 0.01 # in the code of CRFL, the author set sigma to 0.01 | ||
|
||
def run( | ||
self, | ||
raw_client_grad_list: List[Tuple[float, Dict]], | ||
base_aggregation_func: Callable = None, | ||
extra_auxiliary_info: Any = None, | ||
): | ||
new_grad_list = self.defend_before_aggregation( | ||
raw_client_grad_list, extra_auxiliary_info | ||
) | ||
avg_params = self.defend_on_aggregation(new_grad_list, base_aggregation_func) | ||
return self.defend_after_aggregation(avg_params) | ||
|
||
def defend_before_aggregation( | ||
self, | ||
raw_client_grad_list: List[Tuple[float, Dict]], | ||
extra_auxiliary_info: Any = None, | ||
): | ||
return raw_client_grad_list | ||
|
||
def defend_on_aggregation( | ||
self, | ||
raw_client_grad_list: List[Tuple[float, Dict]], | ||
base_aggregation_func: Callable = None, | ||
extra_auxiliary_info: Any = None, | ||
): | ||
avg_params = base_aggregation_func(args=self.config, raw_grad_list=raw_client_grad_list) | ||
""" | ||
clip the global model; dynamic threshold is adjusted according to the dataset; | ||
in the experiment, the authors set the dynamic threshold as follows: | ||
dataset == MNIST: dynamic_thres = epoch * 0.1 + 2 | ||
dataseet == LOAN: dynamic_thres = epoch * 0.025 + 2 | ||
datset == EMNIST: dynamic_thres = epoch * 0.25 + 4 | ||
""" | ||
print(f"avg params = {avg_params}") | ||
dynamic_threshold = self.epoch * 0.1 + 2 | ||
if self.clip_threshold is None or self.clip_threshold > dynamic_threshold: | ||
self.clip_threshold = dynamic_threshold | ||
self.epoch += 1 | ||
|
||
print(f"self.clip_threshold={self.clip_threshold}") | ||
new_model = self.clip_weight_norm(avg_params, self.clip_threshold) | ||
# the output model is new model; later the algo adds dp noise to the global model | ||
return new_model | ||
|
||
def defend_after_aggregation(self, global_model): | ||
# todo: to discuss with chaoyang: the output is the clipped model (real model); | ||
# add dp noise to the real model and sent the permuted model to clients; how to get the last iteration? | ||
new_global_model = dict() | ||
for k in global_model.keys(): | ||
new_global_model[k] = global_model[k] + Gaussian.add_noise_using_sigma(self.sigma, global_model[k].shape) | ||
return new_global_model | ||
|
||
@staticmethod | ||
def clip_weight_norm(model, clip_threshold): | ||
total_norm = utils.compute_model_norm(model) | ||
print(f"total_norm = {total_norm}") | ||
if total_norm > clip_threshold: | ||
clip_coef = clip_threshold / (total_norm + 1e-6) | ||
new_model = dict() | ||
for k in model.keys(): | ||
new_model[k] = model[k] * clip_coef | ||
return new_model | ||
return model |
Oops, something went wrong.