Add more models and datasets from external packages. #42

Merged
merged 12 commits on May 7, 2022
@@ -51,8 +51,9 @@ RUN conda install -y pyg==2.0.4 -c pyg \
&& conda clean -a -y

# for speech and nlp
RUN conda install -y sentencepiece textgrid typeguard -c conda-forge \
RUN conda install -y sentencepiece textgrid typeguard transformers -c conda-forge \
&& conda install -y torchtext -c pytorch \
&& conda install -y datasets -c huggingface -c conda-forge \
&& conda clean -a -y


@@ -42,11 +42,13 @@ RUN conda install -y pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoo
# for graph
RUN conda install -y pyg==2.0.1 -c pyg \
&& conda install -y rdkit=2021.09.4 -c conda-forge \
&& conda install -y nltk \
Review comment (Collaborator): nltk should be put at the back of the line (the NLP part below)?

Reply (Collaborator, author): NLTK is used for generating features of some graph datasets.
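For context, a minimal sketch of that kind of use, assuming NLTK tokenization of a dataset's text attributes into bag-of-words features (illustrative only; the actual feature-generation code is not part of this diff, and `text_to_bow` and its arguments are hypothetical names):

import nltk

nltk.download('punkt', quiet=True)  # tokenizer data used by word_tokenize

def text_to_bow(text, vocab):
    # Tokenize a text attribute and count occurrences of known vocabulary terms.
    tokens = nltk.word_tokenize(text.lower())
    feature = [0] * len(vocab)
    for tok in tokens:
        if tok in vocab:
            feature[vocab[tok]] += 1
    return feature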

&& conda clean -a -y

# for speech and nlp
RUN conda install -y sentencepiece textgrid typeguard -c conda-forge \
RUN conda install -y sentencepiece textgrid typeguard transformers -c conda-forge \
&& conda install -y torchtext -c pytorch \
&& conda install -y datasets -c huggingface -c conda-forge \
&& conda clean -a -y

# auxiliaries (communications, monitoring, etc.)
3 changes: 2 additions & 1 deletion enviroment/requirements-torch1.10-application.txt
@@ -21,5 +21,6 @@ textgrid
typeguard
nltk
torchtext

transformers
datasets

2 changes: 2 additions & 0 deletions enviroment/requirements-torch1.8-application.txt
@@ -21,5 +21,7 @@ textgrid
typeguard
nltk
torchtext
transformers
datasets


181 changes: 150 additions & 31 deletions federatedscope/core/auxiliaries/data_builder.py
@@ -239,7 +239,6 @@ def load_torchvision_data(name, splits=None, config=None):

def load_torchtext_data(name, splits=None, config=None):
from torch.nn.utils.rnn import pad_sequence
from torchtext.data import get_tokenizer
from federatedscope.nlp.dataset.utils import label_to_index

dataset_func = getattr(import_module('torchtext.datasets'), name)
@@ -253,40 +252,92 @@

# torchtext.transforms requires torchtext >= 0.12.0 and torch == 1.11.0,
# so we do not use `get_transform` in torchtext.
tokenizer = get_tokenizer("basic_english")
if len(config.data.transform) == 0:
raise ValueError(
"`transform` must be one pretrained Word Embeddings from \
['GloVe', 'FastText', 'CharNGram']")
if len(config.data.transform) == 1:
config.data.transform.append({})
vocab = getattr(import_module('torchtext.vocab'),
config.data.transform[0])(dim=config.model.in_channels,
**config.data.transform[1])
data_list = []

# Merge all data and tokenize
x_list = []
y_list = []
for data_iter in dataset:
# TODO: we may need a more general and principled load function for the `IterableDataset`.
data, targets = [], []
if config.model.task == 'seq2seq':
for item in data_iter:
data.append(
vocab.get_vecs_by_tokens(tokenizer(item[1]),
lower_case_backup=True))
targets.append(
vocab.get_vecs_by_tokens(tokenizer(item[0]),
lower_case_backup=True))
for i, item in enumerate(data_iter):
data.append(item[1])
targets.append(item[0])
x_list.append(data)
y_list.append(targets)

x_all, y_all = [], []
for i in range(len(x_list)):
x_all += x_list[i]
y_all += y_list[i]

if config.model.type.endswith('transformers'):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
config.model.type.split('@')[0])

x_all = tokenizer(x_all,
return_tensors='pt',
padding=True,
truncation=True,
max_length=raw_args['max_len'])
data = [{key: value[i]
for key, value in x_all.items()}
for i in range(len(next(iter(x_all.values()))))]
if 'classification' in config.model.task.lower():
targets = label_to_index(y_all)
else:
y_all = tokenizer(y_all,
return_tensors='pt',
padding=True,
truncation=True,
max_length=raw_args['max_len'])
targets = [{key: value[i]
for key, value in y_all.items()}
for i in range(len(next(iter(y_all.values()))))]
else:
from torchtext.data import get_tokenizer
tokenizer = get_tokenizer("basic_english")
if len(config.data.transform) == 0:
raise ValueError(
"`transform` must be one pretrained Word Embeddings from \
['GloVe', 'FastText', 'CharNGram']")
if len(config.data.transform) == 1:
config.data.transform.append({})
vocab = getattr(import_module('torchtext.vocab'),
config.data.transform[0])(
dim=config.model.in_channels,
**config.data.transform[1])

if 'classification' in config.model.task.lower():
data = [
vocab.get_vecs_by_tokens(tokenizer(x),
lower_case_backup=True)
for x in x_all
]
targets = label_to_index(y_all)
else:
data = [
vocab.get_vecs_by_tokens(tokenizer(x),
lower_case_backup=True)
for x in x_all
]
targets = [
vocab.get_vecs_by_tokens(tokenizer(y),
lower_case_backup=True)
for y in y_all
]
targets = pad_sequence(targets).transpose(
0, 1)[:, :raw_args['max_len'], :]
else:
for item in data_iter:
data.append(
vocab.get_vecs_by_tokens(tokenizer(item[1]),
lower_case_backup=True))
targets.append(item[0])
targets = label_to_index(targets)
data = pad_sequence(data).transpose(0,
1)[:, :raw_args['max_len'], :]
data_list.append([(x, y) for x, y in zip(data, targets)])
# Split data to raw
num_items = [len(ds) for ds in x_list]
data_list, cnt = [], 0
for num in num_items:
data_list.append([
(x, y)
for x, y in zip(data[cnt:cnt + num], targets[cnt:cnt + num])
])
cnt += num

if len(data_list) == 3:
# Use raw splits
@@ -334,11 +385,79 @@ def load_torch_geometric_data(name, splits=None, config=None):
dataset_func = getattr(import_module('torch_geometric.datasets'), name)
raise NotImplementedError

def load_huggingface_datasets_data(name, splits=None, config=None):
from datasets import load_dataset

if config.data.args:
raw_args = config.data.args[0]
else:
raw_args = {}
assert 'max_len' in raw_args, "Missing key 'max_len' in `config.data.args`."
filtered_args = filter_dict(load_dataset, raw_args)
dataset = load_dataset(path=config.data.root,
name=name,
**filtered_args)
if config.model.type.endswith('transformers'):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
config.model.type.split('@')[0])

for split in dataset:
x_all = [i['sentence'] for i in dataset[split]]
targets = [i['label'] for i in dataset[split]]

x_all = tokenizer(x_all,
return_tensors='pt',
padding=True,
truncation=True,
max_length=raw_args['max_len'])
data = [{key: value[i]
for key, value in x_all.items()}
for i in range(len(next(iter(x_all.values()))))]
dataset[split] = (data, targets)
data_dict = {
'train': [(x, y)
for x, y in zip(dataset['train'][0], dataset['train'][1])
],
'val': [(x, y) for x, y in zip(dataset['validation'][0],
dataset['validation'][1])],
'test': [
(x, y) for x, y in zip(dataset['test'][0], dataset['test'][1])
] if (set(dataset['test'][1]) - set([-1])) else None,
}
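# Note (added for clarity, not part of this diff): for GLUE tasks such as
# SST-2, the HuggingFace test split ships with placeholder labels of -1,
# so the 'test' entry above falls back to None when no real labels exist.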
return data_dict

def load_openml_data(tid, splits=None, config=None):
import openml
from sklearn.model_selection import train_test_split

task = openml.tasks.get_task(int(tid))
did = task.dataset_id
dataset = openml.datasets.get_dataset(did)
data, targets, _, _ = dataset.get_data(
dataset_format="array", target=dataset.default_target_attribute)

train_data, test_data, train_targets, test_targets = train_test_split(
data, targets, train_size=splits[0], random_state=config.seed)
val_data, test_data, val_targets, test_targets = train_test_split(
test_data,
test_targets,
train_size=splits[1] / (1. - splits[0]),
random_state=config.seed)
data_dict = {
'train': [(x, y) for x, y in zip(train_data, train_targets)],
'val': [(x, y) for x, y in zip(val_data, val_targets)],
'test': [(x, y) for x, y in zip(test_data, test_targets)]
}
return data_dict

DATA_LOAD_FUNCS = {
'torchvision': load_torchvision_data,
'torchtext': load_torchtext_data,
'torchaudio': load_torchaudio_data,
'torch_geometric': load_torch_geometric_data
'torch_geometric': load_torch_geometric_data,
'huggingface_datasets': load_huggingface_datasets_data,
'openml': load_openml_data
}
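# Illustrative note (not part of this diff): `config.data.type` follows the
# '<name>@<package>' convention (e.g. 'sst2@huggingface_datasets',
# '10101@openml'), and the package suffix selects the loader from the
# registry above, roughly:
#   name, package = config.data.type.split('@')
#   dataset = DATA_LOAD_FUNCS[package](name, config.data.splits, config)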

modified_config = config.clone()
@@ -441,4 +560,4 @@ def merge_data(all_data):
merged_data[d_name][elem_name] = np.concatenate(
merged_data[d_name][elem_name])

return merged_data
return merged_data
3 changes: 3 additions & 0 deletions federatedscope/core/auxiliaries/model_builder.py
@@ -55,6 +55,9 @@ def get_model(model_config, local_data, backend='torch'):
elif model_config.type.lower() in ['lstm']:
from federatedscope.nlp.model import get_rnn
model = get_rnn(model_config, local_data)
elif model_config.type.lower().endswith('transformers'):
from federatedscope.nlp.model import get_transformer
model = get_transformer(model_config, local_data)
elif model_config.type.lower() in ['gcn', 'sage', 'gpr', 'gat', 'gin']:
from federatedscope.gfl.model import get_gnn
model = get_gnn(model_config, local_data)
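The `get_transformer` builder referenced above lives in `federatedscope.nlp.model` and is not shown in this diff; a minimal sketch of what such a builder could look like for the 'SequenceClassification' task used in the SST-2 config below (the body is an assumption, not the PR's implementation):

from transformers import AutoModelForSequenceClassification

def get_transformer_sketch(model_config, local_data):
    # model_config.type has the form '<hf-model-name>@transformers',
    # e.g. 'google/bert_uncased_L-2_H-128_A-2@transformers'.
    hf_name = model_config.type.split('@')[0]
    return AutoModelForSequenceClassification.from_pretrained(
        hf_name, num_labels=model_config.out_channels)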
18 changes: 18 additions & 0 deletions federatedscope/core/auxiliaries/utils.py
@@ -222,6 +222,24 @@ def download_url(url: str, folder='folder'):
return path


def move_to(obj, device):
import torch
if torch.is_tensor(obj):
return obj.to(device)
elif isinstance(obj, dict):
res = {}
for k, v in obj.items():
res[k] = move_to(v, device)
return res
elif isinstance(obj, list):
res = []
for v in obj:
res.append(move_to(v, device))
return res
else:
raise TypeError("Invalid type for move_to")
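# Example (illustrative, not part of this diff): a batch produced by a
# HuggingFace tokenizer is a dict of tensors, so it can be moved as a whole:
#   batch = {'input_ids': torch.ones(2, 8, dtype=torch.long),
#            'attention_mask': torch.ones(2, 8, dtype=torch.long)}
#   batch = move_to(batch, 'cuda:0' if torch.cuda.is_available() else 'cpu')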


class Timeout(object):
def __init__(self, seconds, max_failure=5):
self.seconds = seconds
2 changes: 1 addition & 1 deletion federatedscope/core/trainers/__init__.py
@@ -11,4 +11,4 @@
'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer',
'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer',
'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server'
]
]
1 change: 0 additions & 1 deletion federatedscope/core/worker/server.py
@@ -523,7 +523,6 @@ def trigger_for_start(self):
self.broadcast_model_para(msg_type='model_para',
sample_client_num=self.sample_client_num)


def terminate(self, msg_type='finish'):
"""
To terminate the FL course
25 changes: 25 additions & 0 deletions federatedscope/example_configs/openml_lr.yaml
@@ -0,0 +1,25 @@
use_gpu: True
device: 1
early_stop:
patience: 100
federate:
mode: 'standalone'
total_round_num: 400
client_num: 5
share_local_model: True
online_aggr: True
trainer:
type: 'general'
eval:
freq: 10
metrics: ['acc', 'correct']
data:
type: '10101@openml' # task_id@openml
splits: [0.8, 0.1, 0.1]
splitter: 'lda'
splitter_args: [{'alpha': 0.5}]
model:
type: lr
out_channels: 2
criterion:
type: CrossEntropyLoss
34 changes: 34 additions & 0 deletions federatedscope/nlp/baseline/fedavg_bert_on_sst2.yaml
@@ -0,0 +1,34 @@
use_gpu: True
device: 2
federate:
mode: standalone
local_update_steps: 1
total_round_num: 40
batch_or_epoch: 'epoch'
client_num: 5
share_local_model: True
online_aggr: True
sample_client_rate: 1.0
data:
root: 'glue'
type: 'sst2@huggingface_datasets'
args: [{'max_len': 512}]
batch_size: 128
splitter: 'lda'
splitter_args: [{'alpha': 0.5}]
num_workers: 0
model:
type: 'google/bert_uncased_L-2_H-128_A-2@transformers'
task: 'SequenceClassification'
out_channels: 2
optimizer:
lr: 0.0001
weight_decay: 0.0
criterion:
type: 'CrossEntropyLoss'
trainer:
type: 'nlptrainer'
eval:
freq: 2
metrics: ['acc', 'correct', 'f1']
split: ['val', 'train']