diff --git a/federatedscope/cv/dataset/leaf.py b/federatedscope/cv/dataset/leaf.py
index bd1c70b2a..a90cabf91 100644
--- a/federatedscope/cv/dataset/leaf.py
+++ b/federatedscope/cv/dataset/leaf.py
@@ -1,5 +1,8 @@
 import zipfile
 import os
+import torch
+
+import numpy as np
 import os.path as osp
 
 from torch.utils.data import Dataset
@@ -16,14 +19,15 @@ def is_exists(path, names):
 
 class LEAF(Dataset):
     """Base class for LEAF dataset from "LEAF: A Benchmark for Federated Settings"
-    
+
     Arguments:
         root (str): root path.
         name (str): name of dataset, in `LEAF_NAMES`.
         transform: transform for x.
         target_transform: transform for y.
-    
+
     """
+
     def __init__(self, root, name, transform, target_transform):
         self.root = root
         self.name = name
@@ -84,3 +88,34 @@ def process_file(self):
 
     def process(self):
         raise NotImplementedError
+
+
+class LocalDataset(Dataset):
+    """
+    Convert data list to torch Dataset to save memory usage.
+    """
+    def __init__(self, Xs, targets, pre_process=None, transform=None, target_transform=None):
+        assert len(Xs) == len(
+            targets), "The number of data and labels are not equal."
+        self.Xs = np.array(Xs)
+        self.targets = np.array(targets)
+        self.pre_process = pre_process
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self):
+        return len(self.Xs)
+
+    def __getitem__(self, idx):
+        data, target = self.Xs[idx], self.targets[idx]
+        if self.pre_process:
+            data = self.pre_process(data)
+
+        if self.transform:
+            data = self.transform(data)
+
+        if self.target_transform:
+            target = self.target_transform(target)
+
+        return data, target
diff --git a/federatedscope/nlp/dataset/leaf_twitter.py b/federatedscope/nlp/dataset/leaf_twitter.py
index 29b9f8e30..8db2ad28c 100644
--- a/federatedscope/nlp/dataset/leaf_twitter.py
+++ b/federatedscope/nlp/dataset/leaf_twitter.py
@@ -1,30 +1,26 @@
 import os
 import random
-import pickle
 import json
 
-import numpy as np
 import torch
 import math
 import os.path as osp
 
 from tqdm import tqdm
-from collections import defaultdict
-
 from sklearn.model_selection import train_test_split
 
 from federatedscope.core.auxiliaries.utils import save_local_data, download_url
-from federatedscope.cv.dataset.leaf import LEAF
+from federatedscope.cv.dataset.leaf import LEAF, LocalDataset
 from federatedscope.nlp.dataset.utils import *
 
 
 class LEAF_TWITTER(LEAF):
     """
     LEAF NLP dataset from
-    
+
     leaf.cmu.edu
-    
+
     Arguments:
         root (str): root path.
         name (str): name of dataset, ‘shakespeare’ or ‘xxx’.
@@ -45,6 +41,8 @@ def __init__(self,
                  seed=123,
                  transform=None,
                  target_transform=None):
+        self.root = root
+        self.name = name
         self.s_frac = s_frac
         self.tr_frac = tr_frac
         self.val_frac = val_frac
@@ -53,12 +51,18 @@ def __init__(self,
         if name != 'twitter':
             raise ValueError(f'`name` should be `twitter`.')
         else:
+            if not os.path.exists(
+                    osp.join(osp.join(root, name, 'raw'), 'embs.json')):
+                self.download()
+                self.extract()
             print('Loading embs...')
-            with open(osp.join(osp.join(root, name, 'raw'), 'embs.json'), 'r') as inf:
+            with open(osp.join(osp.join(root, name, 'raw'), 'embs.json'),
+                      'r') as inf:
                 embs = json.load(inf)
             self.id2word = embs['vocab']
             self.word2id = {v: k for k, v in enumerate(self.id2word)}
-        super(LEAF_TWITTER, self).__init__(root, name, transform, target_transform)
+        super(LEAF_TWITTER, self).__init__(root, name, transform,
+                                           target_transform)
         files = os.listdir(self.processed_dir)
         files = [f for f in files if f.startswith('task_')]
         if len(files):
@@ -97,15 +101,26 @@ def download(self):
         for name in self.raw_file_names:
             download_url(f'{url}/{name}', self.raw_dir)
 
+    def _to_bag_of_word(self, text):
+        bag = np.zeros(len(self.word2id))
+        for i in text:
+            if i != -1:
+                bag[i] += 1
+            else:
+                break
+        text = torch.FloatTensor(bag)
+
+        return text
+
     def __getitem__(self, index):
         """
         Arguments:
             index (int): Index
 
         :returns:
-            dict: {'train':[(text, target)],
-                   'test':[(text, target)],
-                   'val':[(text, target)]}
+            dict: {'train':Dataset,
+                   'test':Dataset,
+                   'val':Dataset}
             where target is the target class.
         """
         text_dict = {}
@@ -113,45 +128,29 @@ def __getitem__(self, index):
         for key in data:
             text_dict[key] = []
             texts, targets = data[key]
-            for idx in range(targets.shape[0]):
-                text = texts[idx]
-
-                if self.transform is not None:
-                    text = self.transform(text)
-                else:
-                    # Bag of word
-                    bag = np.zeros(len(self.word2id))
-                    for i in text:
-                        if i != -1:
-                            bag[i] += 1
-                        else:
-                            break
-                    text = torch.FloatTensor(bag)
-
-                if self.target_transform is not None:
-                    target = self.target_transform(target)
-
-                text_dict[key].append((text, targets[idx]))
+            if self.transform:
+                text_dict[key] = LocalDataset(texts, targets, None,
+                                              self.transform,
+                                              self.target_transform)
+            else:
+                text_dict[key] = LocalDataset(texts, targets, None,
+                                              self._to_bag_of_word,
+                                              self.target_transform)
 
         return text_dict
 
     def tokenizer(self, data, targets):
-        """
-        TOKENIZER = {
-            'twitter': {
-                'x': bag_of_words,
-                'y': target_to_binary
-            }
-        }
-        """
         # [ID, Date, Query, User, Content]
         processed_data = []
         for raw_text in data:
-            ids = [self.word2id[w] if w in self.word2id else 0 for w in split_line(raw_text[4])]
+            ids = [
+                self.word2id[w] if w in self.word2id else 0
+                for w in split_line(raw_text[4])
+            ]
             if len(ids) < self.max_len:
-                ids += [-1] * (self.max_len-len(ids))
+                ids += [-1] * (self.max_len - len(ids))
             else:
-                ids = ids[: self.max_len]
+                ids = ids[:self.max_len]
             processed_data.append(ids)
 
         targets = [target_to_binary(raw_target) for raw_target in targets]
@@ -169,7 +168,6 @@ def process(self):
         for num, file in enumerate(files):
             with open(osp.join(raw_path, file), 'r') as f:
                 raw_data = json.load(f)
-
             user_list = list(raw_data['user_data'].keys())
             n_tasks = math.ceil(len(user_list) * self.s_frac)
             random.shuffle(user_list)
@@ -189,7 +187,7 @@ def process(self):
             targets = torch.LongTensor(targets)
 
             try:
-                train_data, test_data, train_targets, test_targets =\
+                train_data, test_data, train_targets, test_targets = \
                     train_test_split(
                         data,
                         targets,
@@ -207,7 +205,7 @@ def process(self):
                     train_test_split(
                         test_data,
                         test_targets,
-                        train_size=self.val_frac / (1.-self.tr_frac),
+                        train_size=self.val_frac / (1. - self.tr_frac),
                         random_state=self.seed
                     )
             except:
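
Note: below is a minimal usage sketch (not part of the patch) of the new `LocalDataset`, which keeps samples as compact integer arrays and only materializes the dense bag-of-words tensor in `__getitem__`; that lazy conversion is where the memory saving comes from. `VOCAB_SIZE` and the `to_bag_of_words` helper are illustrative stand-ins that mirror `LEAF_TWITTER._to_bag_of_word`; the import assumes the patched `federatedscope` package is on the path.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    from federatedscope.cv.dataset.leaf import LocalDataset

    VOCAB_SIZE = 8  # hypothetical vocabulary size, for illustration only

    def to_bag_of_words(ids):
        # Mirrors LEAF_TWITTER._to_bag_of_word: count token ids until the
        # -1 padding marker introduced by tokenizer().
        bag = np.zeros(VOCAB_SIZE)
        for i in ids:
            if i != -1:
                bag[i] += 1
            else:
                break
        return torch.FloatTensor(bag)

    # Token-id sequences padded with -1, plus binary labels (toy data).
    Xs = [[1, 3, 3, -1], [2, 5, -1, -1]]
    targets = [0, 1]

    # The dense tensor is built lazily per item, not stored up front.
    dataset = LocalDataset(Xs, targets, transform=to_bag_of_words)
    loader = DataLoader(dataset, batch_size=2)
    for x, y in loader:
        print(x.shape, y)  # torch.Size([2, 8]) tensor([0, 1])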