Skip to content

Commit

Permalink
Fix twitter dataset (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk authored Jun 27, 2022
1 parent d8268ff commit 7b64422
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 47 deletions.
39 changes: 37 additions & 2 deletions federatedscope/cv/dataset/leaf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import zipfile
import os
import torch

import numpy as np
import os.path as osp

from torch.utils.data import Dataset
Expand All @@ -16,14 +19,15 @@ def is_exists(path, names):

class LEAF(Dataset):
"""Base class for LEAF dataset from "LEAF: A Benchmark for Federated Settings"
Arguments:
root (str): root path.
name (str): name of dataset, in `LEAF_NAMES`.
transform: transform for x.
target_transform: transform for y.
"""

def __init__(self, root, name, transform, target_transform):
self.root = root
self.name = name
Expand Down Expand Up @@ -84,3 +88,34 @@ def process_file(self):

def process(self):
raise NotImplementedError


class LocalDataset(Dataset):
    """In-memory map-style ``Dataset`` that applies conversions lazily.

    Raw samples and labels are stored once as numpy arrays; the optional
    ``pre_process`` and ``transform`` callables are applied to a sample —
    and ``target_transform`` to its label — only when an item is fetched,
    which avoids materialising the converted data up front.

    Arguments:
        Xs: sequence of raw samples.
        targets: sequence of labels, same length as ``Xs``.
        pre_process: optional callable applied to a sample first.
        transform: optional callable applied after ``pre_process``.
        target_transform: optional callable applied to the label.
    """

    def __init__(self,
                 Xs,
                 targets,
                 pre_process=None,
                 transform=None,
                 target_transform=None):
        # Reject mismatched data/label sequences before storing anything.
        assert len(Xs) == len(
            targets), "The number of data and labels are not equal."
        self.Xs = np.array(Xs)
        self.targets = np.array(targets)
        self.pre_process = pre_process
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.Xs.shape[0]

    def __getitem__(self, idx):
        sample = self.Xs[idx]
        label = self.targets[idx]
        # Apply the sample-side conversions in their declared order.
        for fn in (self.pre_process, self.transform):
            if fn:
                sample = fn(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return sample, label
88 changes: 43 additions & 45 deletions federatedscope/nlp/dataset/leaf_twitter.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
import os
import random
import pickle
import json

import numpy as np
import torch
import math

import os.path as osp

from tqdm import tqdm
from collections import defaultdict

from sklearn.model_selection import train_test_split

from federatedscope.core.auxiliaries.utils import save_local_data, download_url
from federatedscope.cv.dataset.leaf import LEAF
from federatedscope.cv.dataset.leaf import LEAF, LocalDataset
from federatedscope.nlp.dataset.utils import *


class LEAF_TWITTER(LEAF):
"""
LEAF NLP dataset from
leaf.cmu.edu
Arguments:
root (str): root path.
name (str): name of dataset, ‘shakespeare’ or ‘xxx’.
Expand All @@ -45,6 +41,8 @@ def __init__(self,
seed=123,
transform=None,
target_transform=None):
self.root = root
self.name = name
self.s_frac = s_frac
self.tr_frac = tr_frac
self.val_frac = val_frac
Expand All @@ -53,12 +51,18 @@ def __init__(self,
if name != 'twitter':
raise ValueError(f'`name` should be `twitter`.')
else:
if not os.path.exists(
osp.join(osp.join(root, name, 'raw'), 'embs.json')):
self.download()
self.extract()
print('Loading embs...')
with open(osp.join(osp.join(root, name, 'raw'), 'embs.json'), 'r') as inf:
with open(osp.join(osp.join(root, name, 'raw'), 'embs.json'),
'r') as inf:
embs = json.load(inf)
self.id2word = embs['vocab']
self.word2id = {v: k for k, v in enumerate(self.id2word)}
super(LEAF_TWITTER, self).__init__(root, name, transform, target_transform)
super(LEAF_TWITTER, self).__init__(root, name, transform,
target_transform)
files = os.listdir(self.processed_dir)
files = [f for f in files if f.startswith('task_')]
if len(files):
Expand Down Expand Up @@ -97,61 +101,56 @@ def download(self):
for name in self.raw_file_names:
download_url(f'{url}/{name}', self.raw_dir)

def _to_bag_of_word(self, text):
    """Encode a padded id sequence as a bag-of-words float tensor.

    Counts the occurrences of each vocabulary id in ``text``. The first
    ``-1`` is treated as the start of padding and terminates the scan,
    so trailing padding never contributes to the counts.

    Arguments:
        text: iterable of token ids, right-padded with ``-1``.
    :returns:
        torch.FloatTensor of length ``len(self.word2id)``.
    """
    counts = np.zeros(len(self.word2id))
    for token_id in text:
        if token_id == -1:
            # Padding reached; the rest of the sequence is padding too.
            break
        counts[token_id] += 1
    return torch.FloatTensor(counts)

def __getitem__(self, index):
    """Return the splits of one client/task as lazy datasets.

    Arguments:
        index (int): Index
    :returns:
        dict: {'train': Dataset,
               'test': Dataset,
               'val': Dataset}
        where each Dataset yields (text, target) pairs and target is
        the target class.
    """
    # NOTE(review): the scraped span interleaved the removed pre-commit
    # loop (which referenced an undefined `target` and a dead
    # `text_dict[key] = []`) with the added code; this is the coherent
    # post-change implementation.
    text_dict = {}
    data = self.data_dict[index]
    for key in data:
        texts, targets = data[key]
        if self.transform:
            # A caller-supplied transform converts each raw id sequence.
            text_dict[key] = LocalDataset(texts, targets, None,
                                          self.transform,
                                          self.target_transform)
        else:
            # Fall back to the bag-of-word encoding so downstream models
            # receive fixed-size feature vectors.
            text_dict[key] = LocalDataset(texts, targets, None,
                                          self._to_bag_of_word,
                                          self.target_transform)

    return text_dict

def tokenizer(self, data, targets):
"""
TOKENIZER = {
'twitter': {
'x': bag_of_words,
'y': target_to_binary
}
}
"""
# [ID, Date, Query, User, Content]
processed_data = []
for raw_text in data:
ids = [self.word2id[w] if w in self.word2id else 0 for w in split_line(raw_text[4])]
ids = [
self.word2id[w] if w in self.word2id else 0
for w in split_line(raw_text[4])
]
if len(ids) < self.max_len:
ids += [-1] * (self.max_len-len(ids))
ids += [-1] * (self.max_len - len(ids))
else:
ids = ids[: self.max_len]
ids = ids[:self.max_len]
processed_data.append(ids)
targets = [target_to_binary(raw_target) for raw_target in targets]

Expand All @@ -169,7 +168,6 @@ def process(self):
for num, file in enumerate(files):
with open(osp.join(raw_path, file), 'r') as f:
raw_data = json.load(f)

user_list = list(raw_data['user_data'].keys())
n_tasks = math.ceil(len(user_list) * self.s_frac)
random.shuffle(user_list)
Expand All @@ -189,7 +187,7 @@ def process(self):
targets = torch.LongTensor(targets)

try:
train_data, test_data, train_targets, test_targets =\
train_data, test_data, train_targets, test_targets = \
train_test_split(
data,
targets,
Expand All @@ -207,7 +205,7 @@ def process(self):
train_test_split(
test_data,
test_targets,
train_size=self.val_frac / (1.-self.tr_frac),
train_size=self.val_frac / (1. - self.tr_frac),
random_state=self.seed
)
except:
Expand Down

0 comments on commit 7b64422

Please sign in to comment.