support using cached data and re-splitting for huggingface datasets #302
Merged

Commits (5 in total, all by yxdyc):
  0df2b4e  support using cached data for huggingface datasets; and re-splitting …
  2acbeb4  support using cached data for huggingface datasets; and re-splitting …
  f796bd8  Merge remote-tracking branch 'upstream/master' into Feature/dataset_e…
  795da26  minor fix according to weirui's comments
  d2056dd  minor fix for unittest

Changes from 2 commits:

@@ -1,8 +1,13 @@
import os
import pickle
import logging
from random import shuffle

import numpy as np
from collections import defaultdict

from federatedscope.core.auxiliaries.utils import setup_seed

import federatedscope.register as register

logger = logging.getLogger(__name__)
@@ -285,8 +290,14 @@ def load_torchtext_data(name, splits=None, config=None):

    if config.model.type.endswith('transformers'):
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            config.model.type.split('@')[0])
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                config.model.type.split('@')[0],
                local_files_only=True,
                cache_dir=os.path.join(os.getcwd(), "huggingface"))
        except:
            logging.error("")

        x_all = tokenizer(x_all,
                          return_tensors='pt',
@@ -402,6 +413,7 @@ def load_torch_geometric_data(name, splits=None, config=None):

def load_huggingface_datasets_data(name, splits=None, config=None):
    from datasets import load_dataset
    from datasets import load_from_disk

Review comment: Merge Line 415 and Line 416.

    if config.data.args:
        raw_args = config.data.args[0]
@@ -410,18 +422,46 @@ def load_huggingface_datasets_data(name, splits=None, config=None):
    assert 'max_len' in raw_args, "Miss key 'max_len' in " \
                                  "`config.data.args`."
    filtered_args = filter_dict(load_dataset, raw_args)
    dataset = load_dataset(path=config.data.root,
                           name=name,
                           **filtered_args)
    logger.info("Begin to load huggingface dataset")
    if "hg_cache_dir" in raw_args:
        hugging_face_path = raw_args["hg_cache_dir"]
    else:
        hugging_face_path = os.getcwd()

    if "load_disk_dir" in raw_args:
        dataset = load_from_disk(raw_args["load_disk_dir"])
    else:
        dataset = load_dataset(path=config.data.root,
                               name=name,
                               **filtered_args)
    if config.model.type.endswith('transformers'):
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        from transformers import AutoTokenizer
        logger.info("To load huggingface tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            config.model.type.split('@')[0])
            config.model.type.split('@')[0],
            local_files_only=True,
            cache_dir=os.path.join(hugging_face_path, "transformers"))

    for split in dataset:
        x_all = [i['sentence'] for i in dataset[split]]
        targets = [i['label'] for i in dataset[split]]

        if split == "train" and "used_train_ratio" in raw_args and \
                1 > raw_args['used_train_ratio'] > 0:
            selected_idx = [i for i in range(len(dataset[split]))]
            shuffle(selected_idx)
            selected_idx = selected_idx[:int(
                len(selected_idx) * raw_args['used_train_ratio'])]
            x_all = [
                element for i, element in enumerate(x_all)
                if i in selected_idx
            ]
            targets = [
                element for i, element in enumerate(targets)
                if i in selected_idx
            ]

        x_all = tokenizer(x_all,
                          return_tensors='pt',
                          padding=True,
@@ -441,6 +481,42 @@ def load_huggingface_datasets_data(name, splits=None, config=None):
            (x, y) for x, y in zip(dataset['test'][0], dataset['test'][1])
        ] if (set(dataset['test'][1]) - set([-1])) else None,
    }
    original_train_size = len(data_dict["train"])

    if "half_val_dummy_test" in raw_args and raw_args[
            "half_val_dummy_test"]:
        # since the "test" set from GLUE dataset may be masked, we need to
        # submit to get the ground-truth, for fast FL experiments,
        # we split the validation set into two parts with the same size as
        # new test/val data
        original_val = [(x, y) for x, y in zip(dataset['validation'][0],
                                               dataset['validation'][1])]
        data_dict["val"], data_dict[
            "test"] = original_val[:len(original_val) //
                                   2], original_val[len(original_val) //
                                                    2:]
    if "val_as_dummy_test" in raw_args and raw_args["val_as_dummy_test"]:
        # use the validation set as tmp test set,
        # and partial training set as validation set
        data_dict["test"] = data_dict["val"]
        data_dict["val"] = []
    if "part_train_dummy_val" in raw_args and 1 > raw_args[
            "part_train_dummy_val"] > 0:
        new_val_part = int(original_train_size *
                           raw_args["part_train_dummy_val"])
        data_dict["val"].extend(data_dict["train"][:new_val_part])
        data_dict["train"] = data_dict["train"][new_val_part:]
    if "part_train_dummy_test" in raw_args and 1 > raw_args[
            "part_train_dummy_test"] > 0:
        new_test_part = int(original_train_size *
                            raw_args["part_train_dummy_test"])
        data_dict["test"] = data_dict["val"]
        if data_dict["test"] is not None:
            data_dict["test"].extend(data_dict["train"][:new_test_part])
        else:
            data_dict["test"] = (data_dict["train"][:new_test_part])
        data_dict["train"] = data_dict["train"][new_test_part:]

    return data_dict

def load_openml_data(tid, splits=None, config=None):
@@ -529,6 +605,9 @@ def get_data(config):
        obj: The dataset object.
        cfg.node: The updated configuration.
    """
    # fix the seed for data generation,
    # will restore the user-specified one after the generation
    setup_seed(12345)
    for func in register.data_dict.values():
        data_and_config = func(config)
        if data_and_config is not None:
@@ -615,6 +694,8 @@ def get_data(config):
        from federatedscope.attack.auxiliary import poisoning
        poisoning(data, modified_config)

    setup_seed(config.seed)

    if config.federate.mode.lower() == 'standalone':
        return data, modified_config
    else:
@@ -631,6 +712,8 @@ def get_data(config):
        data_idx = config.distribute.data_idx
        return data[data_idx], config

    setup_seed(config.seed)


def merge_data(all_data, merged_max_data_id, specified_dataset_name=None):
    if specified_dataset_name is None:
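
The effect of the new re-splitting keys handled in `load_huggingface_datasets_data` above can be hard to see inside the diff, so here is a small self-contained sketch (not part of the PR; the toy data and option values are invented for illustration) of what `val_as_dummy_test` plus `part_train_dummy_val` do to the train/val/test split:

# Toy illustration of the re-splitting options added above; the data and
# option values below are made up for the example.
raw_args = {'val_as_dummy_test': True, 'part_train_dummy_val': 0.2}

data_dict = {
    'train': [('sent%d' % i, i % 2) for i in range(100)],
    'val': [('val%d' % i, i % 2) for i in range(20)],
    'test': None,  # e.g. GLUE test labels are masked, so no usable test set
}
original_train_size = len(data_dict['train'])

if raw_args.get('val_as_dummy_test'):
    # the original validation set becomes the (dummy) test set
    data_dict['test'] = data_dict['val']
    data_dict['val'] = []

ratio = raw_args.get('part_train_dummy_val', 0)
if 0 < ratio < 1:
    # carve the first 20% of the training pairs out as the new validation set
    new_val_part = int(original_train_size * ratio)
    data_dict['val'].extend(data_dict['train'][:new_val_part])
    data_dict['train'] = data_dict['train'][new_val_part:]

print(len(data_dict['train']), len(data_dict['val']), len(data_dict['test']))
# prints: 80 20 20
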
federatedscope/nlp/baseline/fedavg_transformer_on_cola.yaml (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
# different from federatedscope/nlp/baseline/fedavg_bert_on_sst2.yaml,
# this yaml demonstrates
# (1) using a cached dataset and tokenizer via `load_disk_dir` and `hg_cache_dir`
# (2) using some GLUE validation data as partial test data of the FL version

use_gpu: True
device: -1
early_stop:
  patience: 5
seed: 1
federate:
  mode: standalone
  total_round_num: 500
  client_num: 50
  sample_client_rate: 0.2
  unseen_clients_rate: 0.2
data:
  root: 'glue'
  type: 'cola@huggingface_datasets'
  args: [{'load_disk_dir': 'huggingface/datasets/glue/cola',
          'hg_cache_dir': 'huggingface', 'max_len': 128,
          'val_as_dummy_test': True, 'part_train_dummy_val': 0.2} ]
  batch_size: 64
  splitter: 'lda'
  splitter_args: [ { 'alpha': 0.4, 'min_size': 1} ]
  num_workers: 0
model:
  type: 'google/bert_uncased_L-2_H-128_A-2@transformers'
  task: 'SequenceClassification'
  out_channels: 2
train:
  local_update_steps: 1
  batch_or_epoch: epoch
  optimizer:
    lr: 0.1
    weight_decay: 0.0
criterion:
  type: CrossEntropyLoss
trainer:
  type: nlptrainer
eval:
  freq: 5
  metrics: ['acc', 'correct', 'f1']
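
The `load_disk_dir` and `hg_cache_dir` paths in `data.args` above point to caches that must already exist locally, since the loading code passes `local_files_only=True`. A minimal preparation sketch, run once on a machine with network access, might look like the following; the directory layout simply mirrors the yaml and is an assumption, not something prescribed by the PR:

# One-off cache preparation (assumed workflow, run on an online machine);
# the paths mirror the yaml config above.
import os
from datasets import load_dataset
from transformers import AutoTokenizer

# 1) download GLUE/CoLA and serialize it so that `load_from_disk` can read it
dataset = load_dataset('glue', 'cola')
dataset.save_to_disk(os.path.join('huggingface', 'datasets', 'glue', 'cola'))

# 2) populate the tokenizer cache that is later read with `local_files_only=True`
AutoTokenizer.from_pretrained(
    'google/bert_uncased_L-2_H-128_A-2',
    cache_dir=os.path.join('huggingface', 'transformers'))
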
Review comment (on the new try/except block): Why is a `try` needed here?

Reply: In case the cached file does not exist.
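
If the bare `except` stays in, a slightly more defensive variant (just a sketch building on the snippet in the diff above and assuming the same `config`, `logger`, and imports from that file; not what was merged) could report the failure and fall back to downloading the tokenizer instead of leaving `tokenizer` undefined:

try:
    tokenizer = AutoTokenizer.from_pretrained(
        config.model.type.split('@')[0],
        local_files_only=True,
        cache_dir=os.path.join(os.getcwd(), "huggingface"))
except OSError as err:
    # the local cache is missing or incomplete: say why, then fall back to
    # downloading the tokenizer so that `tokenizer` is still defined
    logger.error("Cached tokenizer not found (%s); downloading it instead",
                 err)
    tokenizer = AutoTokenizer.from_pretrained(
        config.model.type.split('@')[0],
        cache_dir=os.path.join(os.getcwd(), "huggingface"))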