Skip to content

Commit

Permalink
remove unnecessary clone of cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk committed Oct 10, 2022
1 parent ba5df54 commit edb4d4b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
46 changes: 23 additions & 23 deletions federatedscope/core/data/base_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, datadict, global_cfg):
datadict: `Dict` with `client_id` as key, `ClientData` as value.
global_cfg: global CfgNode
"""
self.cfg = global_cfg
self.global_cfg = global_cfg
self.client_cfgs = None
datadict = self.preprocess(datadict)
super(StandaloneDataDict, self).__init__(datadict)
Expand All @@ -29,7 +29,7 @@ def resetup(self, global_cfg, client_cfgs=None):
global_cfg: enable new config for `ClientData`
client_cfgs: enable new client-specific config for `ClientData`
"""
self.cfg, self.client_cfgs = global_cfg, client_cfgs
self.global_cfg, self.client_cfgs = global_cfg, client_cfgs
for client_id, client_data in self.items():
if isinstance(client_data, ClientData):
if client_cfgs is not None:
Expand All @@ -53,17 +53,17 @@ def preprocess(self, datadict):
Args:
datadict: dict with `client_id` as key, `ClientData` as value.
"""
if self.cfg.federate.merge_test_data:
if self.global_cfg.federate.merge_test_data:
server_data = merge_data(
all_data=datadict,
merged_max_data_id=self.cfg.federate.client_num,
merged_max_data_id=self.global_cfg.federate.client_num,
specified_dataset_name=['test'])
# `0` indicate Server
datadict[0] = server_data

if self.cfg.federate.method == "global":
if self.cfg.federate.client_num != 1:
if self.cfg.data.server_holds_all:
if self.global_cfg.federate.method == "global":
if self.global_cfg.federate.client_num != 1:
if self.global_cfg.data.server_holds_all:
assert datadict[0] is not None \
and len(datadict[0]) != 0, \
"You specified cfg.data.server_holds_all=True " \
Expand All @@ -72,10 +72,10 @@ def preprocess(self, datadict):
datadict[1] = datadict[0]
else:
logger.info(f"Will merge data from clients whose ids in "
f"[1, {self.cfg.federate.client_num}]")
f"[1, {self.global_cfg.federate.client_num}]")
datadict[1] = merge_data(
all_data=datadict,
merged_max_data_id=self.cfg.federate.client_num)
merged_max_data_id=self.global_cfg.federate.client_num)
datadict = self.attack(datadict)
return datadict

Expand All @@ -84,40 +84,41 @@ def attack(self, datadict):
Apply attack to `StandaloneDataDict`.
"""
if 'backdoor' in self.cfg.attack.attack_method and 'edge' in \
self.cfg.attack.trigger_type:
if 'backdoor' in self.global_cfg.attack.attack_method and 'edge' in \
self.global_cfg.attack.trigger_type:
import os
import torch
from federatedscope.attack.auxiliary import \
create_ardis_poisoned_dataset, create_ardis_test_dataset
if not os.path.exists(self.cfg.attack.edge_path):
os.makedirs(self.cfg.attack.edge_path)
if not os.path.exists(self.global_cfg.attack.edge_path):
os.makedirs(self.global_cfg.attack.edge_path)
poisoned_edgeset = create_ardis_poisoned_dataset(
data_path=self.cfg.attack.edge_path)
data_path=self.global_cfg.attack.edge_path)

ardis_test_dataset = create_ardis_test_dataset(
self.cfg.attack.edge_path)
self.global_cfg.attack.edge_path)

logger.info("Writing poison_data to: {}".format(
self.cfg.attack.edge_path))
self.global_cfg.attack.edge_path))

with open(
self.cfg.attack.edge_path +
self.global_cfg.attack.edge_path +
"poisoned_edgeset_training", "wb") as saved_data_file:
torch.save(poisoned_edgeset, saved_data_file)

with open(self.cfg.attack.edge_path + "ardis_test_dataset.pt",
"wb") as ardis_data_file:
with open(
self.global_cfg.attack.edge_path +
"ardis_test_dataset.pt", "wb") as ardis_data_file:
torch.save(ardis_test_dataset, ardis_data_file)
logger.warning(
'please notice: downloading the poisoned dataset \
on cifar-10 from \
https://github.com/ksreenivasan/OOD_Federated_Learning'
)

if 'backdoor' in self.cfg.attack.attack_method:
if 'backdoor' in self.global_cfg.attack.attack_method:
from federatedscope.attack.auxiliary import poisoning
poisoning(datadict, self.cfg)
poisoning(datadict, self.global_cfg)
return datadict


Expand All @@ -126,8 +127,6 @@ class ClientData(dict):
`ClientData` converts dataset to train/val/test DataLoader.
Key `data` in `ClientData` is the raw dataset.
"""
client_cfg = None

def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs):
"""
Expand All @@ -139,6 +138,7 @@ def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs):
val: valid dataset, which will be converted to DataLoader
test: test dataset, which will be converted to DataLoader
"""
self.client_cfg = None
self.train = train
self.val = val
self.test = test
Expand Down
2 changes: 1 addition & 1 deletion federatedscope/core/data/base_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, global_cfg, client_cfgs=None):
global_cfg: global CfgNode
client_cfgs: client cfg `Dict`
"""
self.global_cfg = global_cfg.clone()
self.global_cfg = global_cfg
self.client_cfgs = client_cfgs
self.splitter = get_splitter(global_cfg)

Expand Down

0 comments on commit edb4d4b

Please sign in to comment.