Skip to content

Commit

Permalink
[Fixbug] Fix twitter related bugs and merge_test_data (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk authored Aug 5, 2022
1 parent 1c3b864 commit 41e14cc
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 11 deletions.
42 changes: 42 additions & 0 deletions benchmark/FedHPOB/scripts/lr/run_hpo_twitter_lr.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
set -e

cudaid=$1
dataset=$2

cd ../..

out_dir=out_${dataset}

if [ ! -d $out_dir ];then
mkdir $out_dir
fi

echo "HPO starts..."

sample_rates=(0.01)
lrs=(0.00001 0.0001 0.001 0.01 0.1 1.0)
wds=(0.0 0.001 0.01 0.1)
steps=(1 2 3 4)
batch_sizes=(64)

for (( sr=0; sr<${#sample_rates[@]}; sr++ ))
do
for (( l=0; l<${#lrs[@]}; l++ ))
do
for (( w=0; w<${#wds[@]}; w++ ))
do
for (( s=0; s<${#steps[@]}; s++ ))
do
for (( b=0; b<${#batch_sizes[@]}; b++ ))
do
for k in {1..3}
do
python federatedscope/main.py --cfg fedhpo/openml/openml_lr.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} federate.local_update_steps ${steps[$s]} data.type ${dataset}@openml data.batch_size ${batch_sizes[$b]} federate.sample_client_rate ${sample_rates[$sr]} model.out_channels $out_channels seed $k outdir lr/${out_dir}_${sample_rates[$sr]} expname lr${lrs[$l]}_wd${wds[$w]}_dropout0_step${steps[$s]}_batch${batch_sizes[$b]}_seed${k} >/dev/null 2>&1
done
done
done
done
done
done

echo "HPO ends."
36 changes: 36 additions & 0 deletions benchmark/FedHPOB/scripts/lr/twitter.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use_gpu: True
device: 0
early_stop:
patience: 100
federate:
mode: standalone
total_round_num: 500
sample_client_rate: 0.01
make_global_eval: True
merge_test_data: True
share_local_model: True
online_aggr: True
data:
root: data/
type: twitter
batch_size: 5
subsample: 0.005
num_workers: 0
model:
type: lr
out_channels: 2
dropout: 0.0
train:
local_update_steps: 10
optimizer:
lr: 0.0003
weight_decay: 0.0
criterion:
type: CrossEntropyLoss
trainer:
type: nlptrainer
eval:
freq: 1
metrics: ['acc', 'correct', 'f1']
split: [ 'test' ]
best_res_update_round_wise_key: 'test_loss'
28 changes: 22 additions & 6 deletions federatedscope/core/auxiliaries/data_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,28 +644,44 @@ def merge_data(all_data, merged_max_data_id, specified_dataset_name=None):
assert len(dataset_names) >= 1, \
"At least one sub-dataset is required in client 1"
data_name = "test" if "test" in dataset_names else dataset_names[0]
if isinstance(all_data[1][data_name], dict):
data_elem_names = list(all_data[1][data_name].keys()) # e.g., x, y
id_has_key = 1
while "test" not in all_data[id_has_key]:
id_has_key += 1
if len(all_data) <= id_has_key:
raise KeyError(f'All data do not key {data_name}.')
if isinstance(all_data[id_has_key][data_name], dict):
data_elem_names = list(
all_data[id_has_key][data_name].keys()) # e.g., x, y
merged_data = {name: defaultdict(list) for name in dataset_names}
for data_id in range(1, merged_max_data_id):
for d_name in dataset_names:
if d_name not in all_data[data_id]:
continue
for elem_name in data_elem_names:
merged_data[d_name][elem_name].append(
all_data[data_id][d_name][elem_name])
for d_name in dataset_names:
for elem_name in data_elem_names:
merged_data[d_name][elem_name] = np.concatenate(
merged_data[d_name][elem_name])
elif issubclass(type(all_data[1][data_name]), torch.utils.data.DataLoader):
merged_data = {name: all_data[1][name] for name in dataset_names}
for data_id in range(2, merged_max_data_id):
elif issubclass(type(all_data[id_has_key][data_name]),
torch.utils.data.DataLoader):
merged_data = {
name: all_data[id_has_key][name]
for name in dataset_names
}
for data_id in range(1, merged_max_data_id):
if data_id == id_has_key:
continue
for d_name in dataset_names:
if d_name not in all_data[data_id]:
continue
merged_data[d_name].dataset.extend(
all_data[data_id][d_name].dataset)
else:
raise NotImplementedError(
"Un-supported type when merging data across different clients."
f"Your data type is {type(all_data[1][data_name])}. "
f"Your data type is {type(all_data[id_has_key][data_name])}. "
f"Currently we only support the following forms: "
" 1): {data_id: {train: {x:ndarray, y:ndarray}} }"
" 2): {data_id: {train: DataLoader }")
Expand Down
3 changes: 2 additions & 1 deletion federatedscope/core/monitors/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,8 @@ def update_best_result(self,
if round_wise_update_key not in [
"val_loss", "test_loss", "loss", "val_avg_loss",
"test_avg_loss", "avg_loss", "test_acc", "test_std",
"val_acc", "val_std", "val_imp_ratio"
"val_acc", "val_std", "val_imp_ratio", "train_loss",
"train_avg_loss"
]:
raise NotImplementedError(
f"We currently support round_wise_update_key as one "
Expand Down
6 changes: 5 additions & 1 deletion federatedscope/cv/dataset/leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def process(self):

class LocalDataset(Dataset):
"""
Convert data list to torch Dataset to save memory usage.
Convert data list to torch Dataset to save memory usage.
"""
def __init__(self,
Xs,
Expand Down Expand Up @@ -122,3 +122,7 @@ def __getitem__(self, idx):
target = self.target_transform(target)

return data, target

def extend(self, dataset):
self.Xs = np.vstack((self.Xs, dataset.Xs))
self.targets = np.hstack((self.targets, dataset.targets))
8 changes: 5 additions & 3 deletions federatedscope/nlp/baseline/fedavg_lr_on_twitter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ federate:
mode: standalone
total_round_num: 100
sample_client_num: 10
share_local_model: True
online_aggr: True
data:
root: data/
type: twitter
Expand All @@ -26,7 +28,7 @@ criterion:
trainer:
type: nlptrainer
eval:
freq: 10
metrics: ['acc', 'correct']
split: ['train']
freq: 1
metrics: ['acc', 'correct', 'f1']
split: [ 'train' ]
best_res_update_round_wise_key: 'train_loss'

0 comments on commit 41e14cc

Please sign in to comment.