[Fixbug] Fix twitter related bugs and merge_test_data #284

Merged: 3 commits, Aug 5, 2022
42 changes: 42 additions & 0 deletions benchmark/FedHPOB/scripts/lr/run_hpo_twitter_lr.sh
@@ -0,0 +1,42 @@
set -e

cudaid=$1
dataset=$2

cd ../..

out_dir=out_${dataset}

if [ ! -d $out_dir ];then
  mkdir $out_dir
fi

echo "HPO starts..."

sample_rates=(0.01)
lrs=(0.00001 0.0001 0.001 0.01 0.1 1.0)
wds=(0.0 0.001 0.01 0.1)
steps=(1 2 3 4)
batch_sizes=(64)

for (( sr=0; sr<${#sample_rates[@]}; sr++ ))
do
  for (( l=0; l<${#lrs[@]}; l++ ))
  do
    for (( w=0; w<${#wds[@]}; w++ ))
    do
      for (( s=0; s<${#steps[@]}; s++ ))
      do
        for (( b=0; b<${#batch_sizes[@]}; b++ ))
        do
          for k in {1..3}
          do
            python federatedscope/main.py --cfg fedhpo/openml/openml_lr.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} federate.local_update_steps ${steps[$s]} data.type ${dataset}@openml data.batch_size ${batch_sizes[$b]} federate.sample_client_rate ${sample_rates[$sr]} model.out_channels $out_channels seed $k outdir lr/${out_dir}_${sample_rates[$sr]} expname lr${lrs[$l]}_wd${wds[$w]}_dropout0_step${steps[$s]}_batch${batch_sizes[$b]}_seed${k} >/dev/null 2>&1
          done
        done
      done
    done
  done
done

echo "HPO ends."
36 changes: 36 additions & 0 deletions benchmark/FedHPOB/scripts/lr/twitter.yaml
@@ -0,0 +1,36 @@
use_gpu: True
device: 0
early_stop:
  patience: 100
federate:
  mode: standalone
  total_round_num: 500
  sample_client_rate: 0.01
  make_global_eval: True
  merge_test_data: True
  share_local_model: True
  online_aggr: True
data:
  root: data/
  type: twitter
  batch_size: 5
  subsample: 0.005
  num_workers: 0
model:
  type: lr
  out_channels: 2
  dropout: 0.0
train:
  local_update_steps: 10
  optimizer:
    lr: 0.0003
    weight_decay: 0.0
criterion:
  type: CrossEntropyLoss
trainer:
  type: nlptrainer
eval:
  freq: 1
  metrics: ['acc', 'correct', 'f1']
  split: [ 'test' ]
  best_res_update_round_wise_key: 'test_loss'
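As a quick way to see what this new config enables, a minimal sketch that reads the file with plain PyYAML rather than FederatedScope's own config loader; the path is the one in the diff header and the check assumes you run from the repository root.

```python
import yaml

# path taken from the file header above; assumed to be run from the repo root
with open("benchmark/FedHPOB/scripts/lr/twitter.yaml") as f:
    cfg = yaml.safe_load(f)

# merge_test_data pools the clients' test splits so that make_global_eval
# can evaluate the global model on one merged test set
assert cfg["federate"]["merge_test_data"] is True
assert cfg["federate"]["make_global_eval"] is True
print(cfg["eval"]["best_res_update_round_wise_key"])  # test_loss
```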
28 changes: 22 additions & 6 deletions federatedscope/core/auxiliaries/data_builder.py
@@ -644,28 +644,44 @@ def merge_data(all_data, merged_max_data_id, specified_dataset_name=None):
     assert len(dataset_names) >= 1, \
         "At least one sub-dataset is required in client 1"
     data_name = "test" if "test" in dataset_names else dataset_names[0]
-    if isinstance(all_data[1][data_name], dict):
-        data_elem_names = list(all_data[1][data_name].keys())  # e.g., x, y
+    id_has_key = 1
+    while "test" not in all_data[id_has_key]:
+        id_has_key += 1
+        if len(all_data) <= id_has_key:
+            raise KeyError(f'All data do not key {data_name}.')
+    if isinstance(all_data[id_has_key][data_name], dict):
+        data_elem_names = list(
+            all_data[id_has_key][data_name].keys())  # e.g., x, y
         merged_data = {name: defaultdict(list) for name in dataset_names}
         for data_id in range(1, merged_max_data_id):
             for d_name in dataset_names:
+                if d_name not in all_data[data_id]:
+                    continue
                 for elem_name in data_elem_names:
                     merged_data[d_name][elem_name].append(
                         all_data[data_id][d_name][elem_name])
         for d_name in dataset_names:
             for elem_name in data_elem_names:
                 merged_data[d_name][elem_name] = np.concatenate(
                     merged_data[d_name][elem_name])
-    elif issubclass(type(all_data[1][data_name]), torch.utils.data.DataLoader):
-        merged_data = {name: all_data[1][name] for name in dataset_names}
-        for data_id in range(2, merged_max_data_id):
+    elif issubclass(type(all_data[id_has_key][data_name]),
+                    torch.utils.data.DataLoader):
+        merged_data = {
+            name: all_data[id_has_key][name]
+            for name in dataset_names
+        }
+        for data_id in range(1, merged_max_data_id):
+            if data_id == id_has_key:
+                continue
             for d_name in dataset_names:
+                if d_name not in all_data[data_id]:
+                    continue
                 merged_data[d_name].dataset.extend(
                     all_data[data_id][d_name].dataset)
     else:
         raise NotImplementedError(
             "Un-supported type when merging data across different clients."
-            f"Your data type is {type(all_data[1][data_name])}. "
+            f"Your data type is {type(all_data[id_has_key][data_name])}. "
             f"Currently we only support the following forms: "
             " 1): {data_id: {train: {x:ndarray, y:ndarray}} }"
             " 2): {data_id: {train: DataLoader }")
3 changes: 2 additions & 1 deletion federatedscope/core/monitors/monitor.py
@@ -541,7 +541,8 @@ def update_best_result(self,
         if round_wise_update_key not in [
                 "val_loss", "test_loss", "loss", "val_avg_loss",
                 "test_avg_loss", "avg_loss", "test_acc", "test_std",
-                "val_acc", "val_std", "val_imp_ratio"
+                "val_acc", "val_std", "val_imp_ratio", "train_loss",
+                "train_avg_loss"
         ]:
             raise NotImplementedError(
                 f"We currently support round_wise_update_key as one "
6 changes: 5 additions & 1 deletion federatedscope/cv/dataset/leaf.py
@@ -91,7 +91,7 @@ def process(self):
 
 class LocalDataset(Dataset):
     """
-    Convert data list to torch Dataset to save memory usage.
+        Convert data list to torch Dataset to save memory usage.
     """
     def __init__(self,
                  Xs,
@@ -122,3 +122,7 @@ def __getitem__(self, idx):
             target = self.target_transform(target)
 
         return data, target
+
+    def extend(self, dataset):
+        self.Xs = np.vstack((self.Xs, dataset.Xs))
+        self.targets = np.hstack((self.targets, dataset.targets))
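The new `extend` method just stacks features row-wise and appends targets; a quick NumPy check with arbitrary shapes (not LEAF data) showing the intended result.

```python
import numpy as np

# two clients' local data, shapes chosen only for illustration
Xs_a, ys_a = np.random.rand(10, 32), np.zeros(10, dtype=int)
Xs_b, ys_b = np.random.rand(4, 32), np.ones(4, dtype=int)

# what LocalDataset.extend does under the hood
merged_X = np.vstack((Xs_a, Xs_b))
merged_y = np.hstack((ys_a, ys_b))
print(merged_X.shape, merged_y.shape)  # (14, 32) (14,)
```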
8 changes: 5 additions & 3 deletions federatedscope/nlp/baseline/fedavg_lr_on_twitter.yaml
@@ -6,6 +6,8 @@ federate:
   mode: standalone
   total_round_num: 100
   sample_client_num: 10
+  share_local_model: True
+  online_aggr: True
 data:
   root: data/
   type: twitter
@@ -26,7 +28,7 @@ criterion:
 trainer:
   type: nlptrainer
 eval:
-  freq: 10
-  metrics: ['acc', 'correct']
-  split: ['train']
+  freq: 1
+  metrics: ['acc', 'correct', 'f1']
+  split: [ 'train' ]
+  best_res_update_round_wise_key: 'train_loss'
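The eval section now reports 'f1' alongside accuracy. A standalone example with made-up labels and predictions, using scikit-learn's f1_score rather than whatever routine FederatedScope calls internally, for a binary task like the Twitter LR model configured above (out_channels: 2 in twitter.yaml).

```python
from sklearn.metrics import f1_score

# made-up labels and predictions for a binary classification task
y_true = [0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0, 1, 0]
print(f1_score(y_true, y_pred))  # 0.75
```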