-
Notifications
You must be signed in to change notification settings - Fork 627
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from hyp1231/master
FEA: data & dataset
- Loading branch information
Showing
4 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .dataset import * | ||
from .data import Data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Author : Yupeng Hou | ||
# @Email : [email protected] | ||
# @File : data.py | ||
|
||
from torch.utils.data import DataLoader, Dataset | ||
|
||
class Data(Dataset):
    '''
    Map-style dataset over a dict of equally long, index-aligned fields.

    Indexing returns one record as a dict holding the ``index``-th entry of
    every field; iterating yields batches from an internal ``DataLoader``.
    '''

    def __init__(self, config, interaction):
        '''
        :param config(config.Config()): global configurations; the keys
            'train.batch_size' and 'data.num_workers' are read here
        :param interaction(dict): dict of {
            Name: Tensor (batch, )
        }
        :raises ValueError: if interaction is empty or its fields differ in length
        '''
        self.config = config
        self.interaction = interaction

        # Validate first: _check also sets self.length, which __len__ reports.
        self._check()

        self.dataloader = DataLoader(
            dataset=self,
            batch_size=config['train.batch_size'],
            shuffle=False,
            num_workers=config['data.num_workers']
        )

    def _check(self):
        '''
        Validate the fields of the interaction and record their common
        length in ``self.length``.

        Explicit exceptions are raised instead of using ``assert`` so the
        validation still runs when Python is started with ``-O``.

        :raises ValueError: if interaction is empty or its fields differ in length
        '''
        names = list(self.interaction)
        if not names:
            raise ValueError('interaction must contain at least one field')

        self.length = len(self.interaction[names[0]])
        for name in names[1:]:
            cur_length = len(self.interaction[name])
            if cur_length != self.length:
                raise ValueError(
                    'Field [{}] has length {}, expected {} (length of [{}])'.format(
                        name, cur_length, self.length, names[0]
                    )
                )

    def __getitem__(self, index):
        '''
        :param index: position of the record, forwarded to every field's indexing
        :return: dict mapping each field name to that field's entry at ``index``
        '''
        return {k: self.interaction[k][index] for k in self.interaction}

    def __len__(self):
        '''
        :return: the common length of all fields
        '''
        return self.length

    def __iter__(self):
        '''
        :return: a fresh iterator over the internal DataLoader (batches, not single records)
        '''
        return iter(self.dataloader)

    def split(self, ratio):
        '''
        Cut every field at the same position, keeping the original order.

        :param ratio(float): A float in (0, 1), representing the first object's ratio
        :return: Two objects of class Data, which hold the first
            int(ratio * len(self)) records and the remaining records, respectively
        '''
        div = int(ratio * len(self))
        first_inter = {k: self.interaction[k][:div] for k in self.interaction}
        second_inter = {k: self.interaction[k][div:] for k in self.interaction}
        return Data(config=self.config, interaction=first_inter), \
            Data(config=self.config, interaction=second_inter)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Author : Yupeng Hou | ||
# @Email : [email protected] | ||
# @File : dataset.py | ||
|
||
from os.path import isdir, isfile | ||
import torch | ||
from .data import Data | ||
|
||
class AbstractDataset(object):
    '''
    Base class for datasets.

    Construction reads 'data.name' and 'data.path' from the config and
    immediately calls ``_load_data``, which subclasses must implement.
    '''

    def __init__(self, config):
        '''
        :param config: global configurations; the keys 'data.name' and
            'data.path' are read here, then config is passed to _load_data
        '''
        self.token = config['data.name']
        self.dataset_path = config['data.path']
        self.dataset = self._load_data(config)

    def __str__(self):
        return 'Dataset - {}'.format(self.token)

    def _load_data(self, config):
        '''
        Load the dataset; must be overridden by subclasses.

        :param config: global configurations
        :return: data.Data
        :raises NotImplementedError: always, in this base class
        '''
        raise NotImplementedError('Func [_load_data] of [{}] has not been implemented'.format(
            self.__str__()
        ))

    def _download_dataset(self):
        '''
        Download dataset from url
        :return: path of the downloaded dataset
            (this base implementation is a placeholder and returns None)
        '''
        pass

    def preprocessing(self, workflow=None):
        '''
        Preprocessing of the dataset

        Steps are applied in order, each to the result of the previous one.
        Only the step name 'split' is recognised; any other name is ignored.

        :param workflow List(List(str, *args)): sequence of (name, args)
            pairs; None (the default) means no preprocessing steps
        :return: the result of the last applied step, or self.dataset
            unchanged if no step was applied
        '''
        cur = self.dataset
        # The default workflow is None, which is not iterable: treat it as empty.
        for func, params in workflow or ():
            if func == 'split':
                cur = cur.split(*params)
        return cur
|
||
class ML100kDataset(AbstractDataset):
    '''
    Dataset read from a tab-separated text file whose rows are four
    integers: user_id, item_id, rating, timestamp.
    '''

    def __init__(self, config):
        super(ML100kDataset, self).__init__(config)

    def _load_data(self, config):
        '''
        Read the interaction file and wrap its four columns in a Data object.

        The file at ``self.dataset_path`` is used when that path is set;
        otherwise the path returned by ``_download_dataset`` is used.

        :param config: global configurations, forwarded to Data
        :return: data.Data with the fields 'user_id', 'item_id', 'rating'
            and 'timestamp', each a torch.LongTensor
        :raises ValueError: if no usable file path is available, if the file
            holds no rows, or if a field cannot be parsed as an integer
        '''
        # _download_dataset takes no arguments; the dataset name is already
        # available to it as self.token.
        dataset_path = self.dataset_path or self._download_dataset()

        # isfile(None) raises TypeError, so reject a missing path explicitly
        # to keep the documented ValueError.
        if not dataset_path or not isfile(dataset_path):
            raise ValueError('[{}] is a illegal path.'.format(dataset_path))

        rows = []
        with open(dataset_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                rows.append([int(field) for field in line.split('\t')])

        if not rows:
            raise ValueError('[{}] contains no interactions.'.format(dataset_path))

        # Transpose the rows into columns and turn each column into a tensor.
        user_id, item_id, rating, timestamp = map(torch.LongTensor, zip(*rows))
        return Data(
            config=config,
            interaction={
                'user_id': user_id,
                'item_id': item_id,
                'rating': rating,
                'timestamp': timestamp
            }
        )
|