feature(whl): add ag_news dataset. (#31)
* init commit

* polish

* debug

* debug

* debug

* add sogounews dataset

* debug

* debug

* debug

* debug

* polish config

* polish docs

---------

Co-authored-by: whl <[email protected]>
Co-authored-by: 汪昊霖 <PJLAB\[email protected]>
3 people authored Apr 23, 2024
1 parent 9fb4241 commit 3cf4f7c
Showing 18 changed files with 1,092 additions and 317 deletions.
2 changes: 2 additions & 0 deletions fling/dataset/__init__.py
@@ -4,4 +4,6 @@
from .tiny_imagenet import TinyImagenetDataset
from .mini_imagenet import MiniImagenetDataset
from .imagenet import ImagenetDataset
from .ag_news import AGNewsDataset
from .sogou_news import SogouNewsDataset
from .build_dataset import get_dataset
68 changes: 68 additions & 0 deletions fling/dataset/ag_news.py
@@ -0,0 +1,68 @@
import torch
from torch.utils.data import Dataset
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from fling.utils.registry_utils import DATASET_REGISTRY


@DATASET_REGISTRY.register('ag_news')
class AGNewsDataset(Dataset):
"""
    Implementation of the AG News dataset. The original AG corpus contains more than 1 million news articles \
    grouped into 4 categories. For more information, please refer to the link: \
    http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html .
"""
vocab = None

def __init__(self, cfg: dict, train: bool):
super(AGNewsDataset, self).__init__()
self.train = train
self.cfg = cfg
split = 'train' if self.train else 'test'
self.dataset = list(AG_NEWS(cfg.data.data_path, split=split))
self.tokenizer = get_tokenizer("basic_english")
self.max_length = cfg.data.get('max_length', float('inf'))

        def _yield_tokens(data_iter):
            # Tokenize each text sample; labels are ignored when building the vocabulary.
            for _, text in data_iter:
                yield self.tokenizer(text)

        # Build the vocabulary once and cache it on the class, so that the train and test splits share one mapping.
        if AGNewsDataset.vocab is None:
            AGNewsDataset.vocab = build_vocab_from_iterator(
                _yield_tokens(iter(self.dataset)), specials=['<unk>', '<pad>'], min_freq=5
            )
            AGNewsDataset.vocab.set_default_index(AGNewsDataset.vocab['<unk>'])

        real_max_len = max(len(self._process_text(text)) for _, text in self.dataset)
self.max_length = min(self.max_length, real_max_len)

print(
f'Dataset Generated. Total vocab size: {len(self.vocab)}; '
f'Max length of the input: {self.max_length}; '
f'Dataset length: {len(self.dataset)}.'
)

def _process_text(self, x):
return AGNewsDataset.vocab(self.tokenizer(x))

    def _process_label(self, x):
        # Raw labels are 1-based; shift them to 0-based class ids.
        return int(x) - 1

def __len__(self):
return len(self.dataset)

def __getitem__(self, item):
label, text = self.dataset[item]
label = self._process_label(label)
text = self._process_text(text)

        # Truncate over-long samples and pad short ones so every sample has length ``self.max_length``.
        if len(text) > self.max_length:
            text = text[:self.max_length]
        else:
            text += [self.vocab['<pad>']] * (self.max_length - len(text))

assert len(text) == self.max_length
return {'input': torch.LongTensor(text), 'class_id': label}
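
For reference, a minimal usage sketch. It assumes an attribute-style config such as easydict's EasyDict (the concrete config class used elsewhere in fling is not shown in this diff), and the data_path and max_length values below are illustrative, not taken from the repository's configs:

from easydict import EasyDict

# Hypothetical config for illustration; point data_path to where torchtext should store AG_NEWS.
cfg = EasyDict(dict(data=dict(data_path='./data/ag_news', max_length=128)))
train_set = AGNewsDataset(cfg, train=True)
sample = train_set[0]
print(sample['input'].shape, sample['class_id'])  # e.g. torch.Size([128]) and a class id in [0, 3]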
72 changes: 72 additions & 0 deletions fling/dataset/sogou_news.py
@@ -0,0 +1,72 @@
import csv

import torch
from torch.utils.data import Dataset
from torchtext.datasets import SogouNews
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from fling.utils.registry_utils import DATASET_REGISTRY


@DATASET_REGISTRY.register('sogou_news')
class SogouNewsDataset(Dataset):
"""
    Implementation of the Sogou News dataset. The Sogou News dataset is a mixture of 2,909,551 news articles from \
    the SogouCA and SogouCS news corpora, in 5 categories. For more information, please refer to the paper that \
    introduced this benchmark: https://arxiv.org/abs/1509.01626 .
"""
vocab = None

def __init__(self, cfg: dict, train: bool):
super(SogouNewsDataset, self).__init__()
self.train = train
self.cfg = cfg
split = 'train' if self.train else 'test'
        # Some Sogou News rows are extremely long; raise the csv field size limit so parsing does not fail.
        csv.field_size_limit(int(1e8))
self.dataset = list(SogouNews(cfg.data.data_path, split=split))
self.tokenizer = get_tokenizer("basic_english")
self.max_length = cfg.data.get('max_length', float('inf'))

        def _yield_tokens(data_iter):
            # Tokenize each text sample; labels are ignored when building the vocabulary.
            for _, text in data_iter:
                yield self.tokenizer(text)

        # Build the vocabulary once and cache it on the class, so that the train and test splits share one mapping.
        if SogouNewsDataset.vocab is None:
            SogouNewsDataset.vocab = build_vocab_from_iterator(
                _yield_tokens(iter(self.dataset)), specials=['<unk>', '<pad>'], min_freq=5
            )
            SogouNewsDataset.vocab.set_default_index(SogouNewsDataset.vocab['<unk>'])

        real_max_len = max(len(self._process_text(text)) for _, text in self.dataset)
self.max_length = min(self.max_length, real_max_len)

print(
f'Dataset Generated. Total vocab size: {len(self.vocab)}; '
f'Max length of the input: {self.max_length}; '
f'Dataset length: {len(self.dataset)}.'
)

def _process_text(self, x):
return SogouNewsDataset.vocab(self.tokenizer(x))

    def _process_label(self, x):
        # Raw labels are 1-based; shift them to 0-based class ids.
        return int(x) - 1

def __len__(self):
return len(self.dataset)

def __getitem__(self, item):
label, text = self.dataset[item]
label = self._process_label(label)
text = self._process_text(text)

        # Truncate over-long samples and pad short ones so every sample has length ``self.max_length``.
        if len(text) > self.max_length:
            text = text[:self.max_length]
        else:
            text += [self.vocab['<pad>']] * (self.max_length - len(text))

assert len(text) == self.max_length
return {'input': torch.LongTensor(text), 'class_id': label}
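
Since both datasets return fixed-length tensors under the same keys, they batch directly with PyTorch's default collate function. A minimal sketch, continuing the AG News example above (the batch size is illustrative):

from torch.utils.data import DataLoader

loader = DataLoader(train_set, batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch['input'].shape)     # torch.Size([32, 128]), i.e. (batch_size, max_length)
print(batch['class_id'].shape)  # torch.Size([32])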
6 changes: 1 addition & 5 deletions fling/model/__init__.py
@@ -5,9 +5,5 @@
from .resnet import resnet4, resnet6, resnet8, resnet10, resnet18, resnet34, resnet50
from .swin_transformer import SwinTransformer
from .vit import ViT
from .language_classifier import TransformerClassifier
from .build_model import get_model

# Algorithm specific models
# FedRoD
from .fedrod_resnet import fedrod_resnet4, fedrod_resnet6, fedrod_resnet8, fedrod_resnet10, fedrod_resnet18,\
fedrod_resnet34, fedrod_resnet50