feature(whl): add ag_news dataset. (#31)

* init commit
* polish
* debug
* debug
* debug
* add sogounews dataset
* debug
* debug
* debug
* debug
* polish config
* polish docs

---------

Co-authored-by: 'whl' <'[email protected]'>
Co-authored-by: 汪昊霖 <PJLAB\[email protected]>
1 parent 9fb4241 · commit 3cf4f7c
Showing 18 changed files with 1,092 additions and 317 deletions.
@@ -0,0 +1,68 @@
import torch
from torch.utils.data import Dataset
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from fling.utils.registry_utils import DATASET_REGISTRY


@DATASET_REGISTRY.register('ag_news')
class AGNewsDataset(Dataset):
    """
    Implementation of the AG News dataset. This dataset contains more than 1 million news articles in 4 categories.
    For more information, please refer to the link: http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html .
    """
    # The vocabulary is cached on the class so that train and test splits share one index mapping.
    vocab = None

    def __init__(self, cfg: dict, train: bool):
        super(AGNewsDataset, self).__init__()
        self.train = train
        self.cfg = cfg
        split = 'train' if self.train else 'test'
        self.dataset = list(AG_NEWS(cfg.data.data_path, split=split))
        self.tokenizer = get_tokenizer("basic_english")
        self.max_length = cfg.data.get('max_length', float('inf'))

        def _yield_tokens(data_iter):
            for _, text in data_iter:
                yield self.tokenizer(text)

        # Prepare the vocabulary table. Tokens rarer than ``min_freq`` map to '<unk>'.
        if AGNewsDataset.vocab is None:
            AGNewsDataset.vocab = build_vocab_from_iterator(
                _yield_tokens(iter(self.dataset)), specials=['<unk>', '<pad>'], min_freq=5
            )
            AGNewsDataset.vocab.set_default_index(self.vocab['<unk>'])

        # Clip the configured max length to the longest sequence actually present.
        real_max_len = max(len(self._process_text(self.dataset[i][1])) for i in range(len(self.dataset)))
        self.max_length = min(self.max_length, real_max_len)

        print(
            f'Dataset Generated. Total vocab size: {len(self.vocab)}; '
            f'Max length of the input: {self.max_length}; '
            f'Dataset length: {len(self.dataset)}.'
        )

    def _process_text(self, x):
        # Tokenize, then map each token to its vocabulary index.
        return AGNewsDataset.vocab(self.tokenizer(x))

    def _process_label(self, x):
        # Raw labels are 1-based; shift them to 0-based class ids.
        return int(x) - 1

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        label, text = self.dataset[item]
        label = self._process_label(label)
        text = self._process_text(text)

        # Truncate long sequences and pad short ones to the fixed length.
        if len(text) > self.max_length:
            text = text[:self.max_length]
        else:
            text += [self.vocab['<pad>']] * (self.max_length - len(text))

        assert len(text) == self.max_length
        return {'input': torch.LongTensor(text), 'class_id': label}
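As a quick sanity check of the interface above, a minimal usage sketch follows. It is not part of the diff: EasyDict stands in for whatever attribute-style config object fling passes in (the class only relies on cfg.data.data_path and cfg.data.get('max_length', ...)), and the path, batch size, and max length are illustrative.

# Hypothetical usage sketch; EasyDict is an assumption, not fling's real config class.
from easydict import EasyDict
from torch.utils.data import DataLoader

cfg = EasyDict(dict(data=dict(data_path='./data/ag_news', max_length=128)))

train_set = AGNewsDataset(cfg, train=True)   # builds the vocab from the 'train' split
test_set = AGNewsDataset(cfg, train=False)   # reuses the cached class-level vocab

loader = DataLoader(train_set, batch_size=32, shuffle=True)
batch = next(iter(loader))
# Every sample is already padded/truncated to a fixed length, so no collate_fn is needed.
# Assuming the longest article exceeds 128 tokens, the shapes are:
print(batch['input'].shape)     # torch.Size([32, 128])
print(batch['class_id'].shape)  # torch.Size([32])

Caching the vocabulary as a class attribute is what makes the second construction cheap and consistent: the test split is indexed with exactly the token-to-id mapping built from the training data.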
@@ -0,0 +1,72 @@
import csv

import torch
from torch.utils.data import Dataset
from torchtext.datasets import SogouNews
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from fling.utils.registry_utils import DATASET_REGISTRY


@DATASET_REGISTRY.register('sogou_news')
class SogouNewsDataset(Dataset):
    """
    Implementation of the Sogou News dataset. The Sogou News dataset is a mixture of 2,909,551 news articles from \
    the SogouCA and SogouCS news corpora, in 5 categories. For more information, please refer to the link: \
    http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html .
    """
    # The vocabulary is cached on the class so that train and test splits share one index mapping.
    vocab = None

    def __init__(self, cfg: dict, train: bool):
        super(SogouNewsDataset, self).__init__()
        self.train = train
        self.cfg = cfg
        split = 'train' if self.train else 'test'
        # Sogou articles are long; raise the csv field size limit so parsing does not fail.
        csv.field_size_limit(int(1e8))
        self.dataset = list(SogouNews(cfg.data.data_path, split=split))
        self.tokenizer = get_tokenizer("basic_english")
        self.max_length = cfg.data.get('max_length', float('inf'))

        def _yield_tokens(data_iter):
            for _, text in data_iter:
                yield self.tokenizer(text)

        # Prepare the vocabulary table. Tokens rarer than ``min_freq`` map to '<unk>'.
        if SogouNewsDataset.vocab is None:
            SogouNewsDataset.vocab = build_vocab_from_iterator(
                _yield_tokens(iter(self.dataset)), specials=['<unk>', '<pad>'], min_freq=5
            )
            SogouNewsDataset.vocab.set_default_index(self.vocab['<unk>'])

        # Clip the configured max length to the longest sequence actually present.
        real_max_len = max(len(self._process_text(self.dataset[i][1])) for i in range(len(self.dataset)))
        self.max_length = min(self.max_length, real_max_len)

        print(
            f'Dataset Generated. Total vocab size: {len(self.vocab)}; '
            f'Max length of the input: {self.max_length}; '
            f'Dataset length: {len(self.dataset)}.'
        )

    def _process_text(self, x):
        # Tokenize, then map each token to its vocabulary index.
        return SogouNewsDataset.vocab(self.tokenizer(x))

    def _process_label(self, x):
        # Raw labels are 1-based; shift them to 0-based class ids.
        return int(x) - 1

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        label, text = self.dataset[item]
        label = self._process_label(label)
        text = self._process_text(text)

        # Truncate long sequences and pad short ones to the fixed length.
        if len(text) > self.max_length:
            text = text[:self.max_length]
        else:
            text += [self.vocab['<pad>']] * (self.max_length - len(text))

        assert len(text) == self.max_length
        return {'input': torch.LongTensor(text), 'class_id': label}
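One design note on the padding logic both classes share: every sample is padded to a single corpus-wide maximum length, which for Sogou's very long articles means most tensors are dominated by '<pad>' ids. A hedged alternative is per-batch dynamic padding via a collate_fn; this is not what the commit implements, and it assumes a dataset variant that returns variable-length tensors instead of pre-padded ones.

# Alternative sketch (not in this commit): pad per batch with a collate_fn.
import torch
from torch.nn.utils.rnn import pad_sequence

def make_collate_fn(pad_idx: int):
    def collate(samples):
        # Assumes each sample's 'input' is a variable-length LongTensor,
        # i.e. a dataset variant that skips the fixed-length padding above.
        texts = [s['input'] for s in samples]
        labels = torch.tensor([s['class_id'] for s in samples], dtype=torch.long)
        padded = pad_sequence(texts, batch_first=True, padding_value=pad_idx)
        return {'input': padded, 'class_id': labels}
    return collate

# Usage: DataLoader(dataset, batch_size=32, collate_fn=make_collate_fn(vocab['<pad>']))

This trades a fixed input shape for less wasted computation on padding; models that require a constant sequence length would still need the fixed-length scheme used in the diff.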