Skip to content

Commit

Permalink
Merge pull request #4 from hyp1231/master
Browse files Browse the repository at this point in the history
FEA: data & dataset
  • Loading branch information
hyp1231 authored Jun 28, 2020
2 parents fb3e75b + 3547913 commit a0bb1cb
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.vscode/
2 changes: 2 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .dataset import *
from .data import Data
60 changes: 60 additions & 0 deletions data/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# @Author : Yupeng Hou
# @Email : [email protected]
# @File : data.py

from torch.utils.data import DataLoader, Dataset

class Data(Dataset):
    '''
    Map-style dataset over a dict of equally sized columns that also owns
    a DataLoader, so iterating the object yields batches directly.
    '''

    def __init__(self, config, interaction):
        '''
        :param config(config.Config()): global configurations; the keys
            'train.batch_size' and 'data.num_workers' are read here
        :param interaction(dict): dict of {
            Name: Tensor (batch, )
        }
        '''
        self.config = config
        self.interaction = interaction

        # Validate first: this also sets self.length, which __len__ needs.
        self._check()

        batch_size = config['train.batch_size']
        num_workers = config['data.num_workers']
        self.dataloader = DataLoader(
            self,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    def _check(self):
        '''
        Assert that there is at least one column and that every column has
        the same length as the first one; record that length in self.length.
        '''
        assert len(self.interaction.keys()) > 0
        for position, name in enumerate(self.interaction):
            if position > 0:
                assert len(self.interaction[name]) == self.length
                continue
            # The first column defines the reference length.
            self.length = len(self.interaction[name])

    def __getitem__(self, index):
        '''
        :param index: index applied to every column
        :return: dict mapping each column name to its value at (index)
        '''
        return {name: self.interaction[name][index] for name in self.interaction}

    def __len__(self):
        '''
        :return: the common column length recorded by _check
        '''
        return self.length

    def __iter__(self):
        '''
        :return: an iterator over the internal DataLoader
        '''
        return iter(self.dataloader)

    def split(self, ratio):
        '''
        Cut every column at the same position and wrap both parts.

        :param ratio(float): A float in (0, 1), the share of rows that goes
            into the first object
        :return: Two objects of class Data sharing this object's config; the
            first holds the leading int(ratio * len) rows, the second the rest
        '''
        cut = int(ratio * len(self))
        names = list(self.interaction)
        head = {name: self.interaction[name][:cut] for name in names}
        tail = {name: self.interaction[name][cut:] for name in names}
        return (
            Data(config=self.config, interaction=head),
            Data(config=self.config, interaction=tail),
        )
73 changes: 73 additions & 0 deletions data/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# @Author : Yupeng Hou
# @Email : [email protected]
# @File : dataset.py

from os.path import isdir, isfile
import torch
from .data import Data

class AbstractDataset(object):
    '''
    Base class for datasets: reads the dataset name and path from the
    config and delegates the actual loading to the subclass.
    '''

    def __init__(self, config):
        '''
        :param config: global configurations; the keys 'data.name' and
            'data.path' are read here, then (config) is handed to _load_data
        '''
        self.token = config['data.name']
        self.dataset_path = config['data.path']
        self.dataset = self._load_data(config)

    def __str__(self):
        return 'Dataset - {}'.format(self.token)

    def _load_data(self, config):
        '''
        Load the dataset; must be overridden by subclasses.

        :param config: global configurations
        :return: data.Data
        :raises NotImplementedError: always, in this base class
        '''
        raise NotImplementedError('Func [_load_data] of [{}] has not been implemented'.format(
            self.__str__()
        ))

    def _download_dataset(self):
        '''
        Download dataset from url

        Placeholder: nothing is downloaded yet and None is returned.

        :return: path of the downloaded dataset
        '''
        return None

    def preprocessing(self, workflow=None):
        '''
        Preprocessing of the dataset

        :param workflow List(List(str, *args)): sequence of (func, params)
            steps; None (the default) means no steps. Only the 'split' step
            is handled, other step names are ignored.
        :return: the loaded dataset if no step changed it, otherwise the
            result of the last applied step
        '''
        cur = self.dataset
        # The default used to be iterated directly, which raised TypeError.
        if workflow is None:
            return cur
        for func, params in workflow:
            if func == 'split':
                cur = cur.split(*params)
        return cur

class ML100kDataset(AbstractDataset):
    '''
    Dataset read from a tab-separated text file with four integer columns
    per line: user id, item id, rating and timestamp.
    '''

    def __init__(self, config):
        super(ML100kDataset, self).__init__(config)

    def _load_data(self, config):
        '''
        Parse the interaction file into four LongTensor columns.

        :param config: global configurations; 'data.path' is used when a
            path was configured, otherwise _download_dataset() is asked
        :return: data.Data with the columns 'user_id', 'item_id', 'rating'
            and 'timestamp'
        :raises ValueError: if no usable file path is available, if the file
            holds no data lines, or if a field is not an integer
        '''
        if self.dataset_path:
            dataset_path = config['data.path']
        else:
            # _download_dataset takes no arguments; passing the token here
            # used to raise TypeError.
            dataset_path = self._download_dataset()

        # A downloader may return None, which isfile() rejects with
        # TypeError, so test for a missing path first.
        if not dataset_path or not isfile(dataset_path):
            raise ValueError('[{}] is a illegal path.'.format(dataset_path))

        rows = []
        with open(dataset_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    # Blank lines carry no record; int('') would fail.
                    continue
                rows.append([int(field) for field in line.split('\t')])

        if not rows:
            raise ValueError('[{}] contains no interactions.'.format(dataset_path))

        # Transpose the rows into one tuple per column.
        user_id, item_id, rating, timestamp = map(torch.LongTensor, zip(*rows))
        return Data(
            config=config,
            interaction={
                'user_id': user_id,
                'item_id': item_id,
                'rating': rating,
                'timestamp': timestamp
            }
        )

0 comments on commit a0bb1cb

Please sign in to comment.