Skip to content

Commit

Permalink
Merge pull request #2 from RUCAIBox/master
Browse files Browse the repository at this point in the history
pull code
  • Loading branch information
linzihan-backforward authored Jul 1, 2020
2 parents 8ec67ed + 8a1273a commit 4d9a5a6
Show file tree
Hide file tree
Showing 13 changed files with 564 additions and 129 deletions.
177 changes: 163 additions & 14 deletions data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
# @Email : [email protected]
# @File : data.py

import random
import torch
from torch.utils.data import DataLoader, Dataset
from sampler import Sampler

class Data(Dataset):
def __init__(self, config, interaction):
def __init__(self, config, interaction, batch_size=1, sampler=None):
'''
:param config(config.Config()): global configurations
:param interaction(dict): dict of {
Expand All @@ -15,12 +18,14 @@ def __init__(self, config, interaction):
'''
self.config = config
self.interaction = interaction
self.batch_size = batch_size
self.sampler = sampler

self._check()

self.dataloader = DataLoader(
dataset=self,
batch_size=config['train.batch_size'],
batch_size=self.batch_size,
shuffle=False,
num_workers=config['data.num_workers']
)
Expand All @@ -45,16 +50,160 @@ def __len__(self):
def __iter__(self):
return iter(self.dataloader)

def split(self, ratio):
'''
:param ratio(float): A float in (0, 1), representing the first object's ratio
:return: Two object of class Data, which has (ratio) and (1 - ratio), respectively
'''
div = int(ratio * self.__len__())
first_inter = {}
second_inter = {}
def split_by_ratio(self, train_ratio, test_ratio, valid_ratio=0,
train_batch_size=None, test_batch_size=None, valid_batch_size=None
):
if train_ratio <= 0:
raise ValueError('train ratio [{}] should be possitive'.format(train_ratio))
if test_ratio <= 0:
raise ValueError('test ratio [{}] should be possitive'.format(test_ratio))
if valid_ratio < 0:
raise ValueError('valid ratio [{}] should be none negative'.format(valid_ratio))

tot_ratio = train_ratio + test_ratio + valid_ratio
train_ratio /= tot_ratio
test_ratio /= tot_ratio
# valid_ratio /= tot_ratio

train_cnt = int(train_ratio * self.__len__())
if valid_ratio == 0:
test_cnt = self.__len__() - train_cnt
# valid_cnt = 0
else:
test_cnt = int(test_ratio * self.__len__())
# valid_cnt = self.__len__() - train_cnt - test_cnt

if train_batch_size is None:
train_batch_size = self.batch_size
if test_batch_size is None:
test_batch_size = self.batch_size
if valid_batch_size is None:
valid_batch_size = self.batch_size

train_inter = {}
test_inter = {}
valid_inter = {}
for k in self.interaction:
first_inter[k] = self.interaction[k][:div]
second_inter[k] = self.interaction[k][div:]
return Data(config=self.config, interaction=first_inter), \
Data(config=self.config, interaction=second_inter)
train_inter[k] = self.interaction[k][:train_cnt]
test_inter[k] = self.interaction[k][train_cnt : train_cnt+test_cnt]
if valid_ratio > 0:
valid_inter[k] = self.interaction[k][train_cnt+test_cnt:]

if valid_ratio > 0:
return Data(config=self.config, interaction=train_inter, batch_size=train_batch_size, sampler=self.sampler), \
Data(config=self.config, interaction=test_inter, batch_size=test_batch_size, sampler=self.sampler), \
Data(config=self.config, interaction=valid_inter, batch_size=valid_batch_size, sampler=self.sampler)
else:
return Data(config=self.config, interaction=train_inter, batch_size=train_batch_size, sampler=self.sampler), \
Data(config=self.config, interaction=test_inter, batch_size=test_batch_size, sampler=self.sampler)

def random_shuffle(self):
idx = list(range(self.__len__()))
random.shuffle(idx)
next_inter = {}
pass
# TODO torch.xxx to random shuffle self.interaction

def remove_lower_value_by_key(self, key, min_remain_value=0):
new_inter = {}
for k in self.interaction:
new_inter[k] = []
for i in range(self.__len__()):
if self.interaction[key][i] >= min_remain_value:
for k in self.interaction:
new_inter[k].append(self.interaction[k][i])
for k in self.interaction:
new_inter[k] = torch.stack(new_inter[k])

new_sampler = Sampler(
self.sampler.n_users, self.sampler.n_items,
new_inter['user_id'], new_inter['item_id'],
padding=self.sampler.padding, missing=self.sampler.missing
)

return Data(config=self.config, interaction=new_inter, batch_size=self.batch_size, sampler=new_sampler)

def neg_sample_1by1(self):
new_inter = {
'user_id': [],
'pos_item_id': [],
'neg_item_id': []
}
for i in range(self.__len__()):
uid = self.interaction['user_id'][i].item()
new_inter['user_id'].append(uid)
new_inter['pos_item_id'].append(self.interaction['item_id'][i].item())
new_inter['neg_item_id'].append(self.sampler.sample_by_user_id(uid)[0])
for k in new_inter:
new_inter[k] = torch.LongTensor(new_inter[k])
return Data(
config=self.config,
interaction=new_inter,
batch_size=self.batch_size,
sampler=self.sampler
)

# def neg_sample_to(self, num):
# new_inter = {
# 'user_id': [],
# 'item_list': [],
# 'label': []
# }

# uid2itemlist = {}
# for i in range(self.__len__()):
# uid = self.interaction['user_id'][i].item()
# iid = self.interaction['item_id'][i].item()
# if uid not in uid2itemlist:
# uid2itemlist[uid] = []
# uid2itemlist[uid].append(iid)
# for uid in uid2itemlist:
# pos_num = len(uid2itemlist[uid])
# if pos_num >= num:
# uid2itemlist[uid] = uid2itemlist[uid][:num-1]
# pos_num = num - 1
# neg_item_id = self.sampler.sample_by_user_id(uid, num - pos_num)
# uid2itemlist[uid] += neg_item_id
# label = [1] * pos_num + [0] * (num - pos_num)
# new_inter['user_id'].append(uid)
# new_inter['item_list'].append(uid2itemlist[uid])
# new_inter['label'].append(label)

# for k in new_inter:
# new_inter[k] = torch.LongTensor(new_inter[k])

# return Data(config=self.config, interaction=new_inter, batch_size=self.batch_size, sampler=self.sampler)

def neg_sample_to(self, num):
new_inter = {
'user_id': [],
'item_id': [],
'label': []
}

uid2itemlist = {}
for i in range(self.__len__()):
uid = self.interaction['user_id'][i].item()
iid = self.interaction['item_id'][i].item()
if uid not in uid2itemlist:
uid2itemlist[uid] = []
uid2itemlist[uid].append(iid)
for uid in uid2itemlist:
pos_num = len(uid2itemlist[uid])
if pos_num >= num:
uid2itemlist[uid] = uid2itemlist[uid][:num-1]
pos_num = num - 1
neg_item_id = self.sampler.sample_by_user_id(uid, num - pos_num)
for iid in uid2itemlist[uid]:
new_inter['user_id'].append(uid)
new_inter['item_id'].append(iid)
new_inter['label'].append(1)
for iid in neg_item_id:
new_inter['user_id'].append(uid)
new_inter['item_id'].append(iid)
new_inter['label'].append(0)

for k in new_inter:
new_inter[k] = torch.LongTensor(new_inter[k])

return Data(config=self.config, interaction=new_inter, batch_size=self.batch_size, sampler=self.sampler)
112 changes: 98 additions & 14 deletions data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
from os.path import isdir, isfile
import torch
from .data import Data
from sampler import Sampler

class AbstractDataset(object):
def __init__(self, config):
def __init__(self, config, padding=False, missing=False):
self.config = config
self.token = config['data.name']
self.dataset_path = config['data.path']
self.dataset = self._load_data(config)
self.padding = padding
self.missing = missing

self.dataset, self.sampler, self.n_users, self.n_items = self._load_data(config)

def __str__(self):
return 'Dataset - {}'.format(self.token)
Expand All @@ -31,17 +35,87 @@ def _download_dataset(self):
:return: path of the downloaded dataset
'''
pass

def _get_n_users(self, user_id):
self.user_ori2idx = {}
self.user_idx2ori = []
if self.missing:
tot_users = 2
self.user_ori2idx['padding'] = 0
self.user_ori2idx['missing'] = 1
self.user_idx2ori = ['padding', 'missing']
elif self.padding:
tot_users = 1
self.user_ori2idx['padding'] = 0
self.user_idx2ori = ['padding']
else:
tot_users = 0

for uid in user_id:
if uid not in self.user_ori2idx:
self.user_ori2idx[uid] = tot_users
self.user_idx2ori.append(uid)
tot_users += 1
return tot_users

def _get_n_items(self, item_id):
self.item_ori2idx = {}
self.item_idx2ori = []
if self.missing:
tot_items = 2
self.item_ori2idx['padding'] = 0
self.item_ori2idx['missing'] = 1
self.item_idx2ori = ['padding', 'missing']
elif self.padding:
tot_items = 1
self.item_ori2idx['padding'] = 0
self.item_idx2ori = ['padding']
else:
tot_items = 0

for iid in item_id:
if iid not in self.item_ori2idx:
self.item_ori2idx[iid] = tot_items
self.item_idx2ori.append(iid)
tot_items += 1
return tot_items

def preprocessing(self, workflow=None):
'''
Preprocessing of the dataset
:param workflow List(List(str, *args))
'''
cur = self.dataset
for func in workflow:
if func == 'split':
cur = cur.split(self.config['process.ratio'])
return cur
train_data = test_data = valid_data = None
for func in workflow['preprocessing']:
if func == 'remove_lower_value_by_key':
cur = cur.remove_lower_value_by_key(
key=self.config['process.remove_lower_value_by_key.key'],
min_remain_value=self.config['process.remove_lower_value_by_key.min_remain_value']
)
elif func == 'split_by_ratio':
train_data, test_data, valid_data = cur.split_by_ratio(
train_ratio=self.config['process.split_by_ratio.train_ratio'],
test_ratio=self.config['process.split_by_ratio.test_ratio'],
valid_ratio=self.config['process.split_by_ratio.valid_ratio'],
train_batch_size=self.config['train_batch_size'],
test_batch_size=self.config['test_batch_size'],
valid_batch_size=self.config['valid_batch_size']
)
break

for func in workflow['train']:
if func == 'neg_sample_1by1':
train_data = train_data.neg_sample_1by1()

for func in workflow['test']:
if func == 'neg_sample_to':
test_data = test_data.neg_sample_to(num=self.config['process.neg_sample_to.num'])

for func in workflow['valid']:
if func == 'neg_sample_to':
valid_data = valid_data.neg_sample_to(num=self.config['process.neg_sample_to.num'])

return train_data, test_data, valid_data

class UIRTDataset(AbstractDataset):
def __init__(self, config):
Expand All @@ -61,16 +135,26 @@ def _load_data(self, config):
for line in file:
line = map(int, line.strip().split('\t'))
lines.append(line)
user_id, item_id, rating, timestamp = map(torch.LongTensor, zip(*lines))
user_id, item_id, rating, timestamp = map(list, zip(*lines))
n_users = self._get_n_users(user_id)
n_items = self._get_n_items(item_id)

new_user_id = torch.LongTensor([self.user_ori2idx[_] for _ in user_id])
new_item_id = torch.LongTensor([self.item_ori2idx[_] for _ in item_id])

sampler = Sampler(n_users, n_items,
new_user_id, new_item_id, padding=self.padding, missing=self.missing)

return Data(
config=config,
interaction={
'user_id': user_id,
'item_id': item_id,
'rating': rating,
'timestamp': timestamp
}
)
'user_id': new_user_id,
'item_id': new_item_id,
'rating': torch.LongTensor(rating),
'timestamp': torch.LongTensor(timestamp)
},
sampler=sampler
), sampler, n_users, n_items

class ML100kDataset(UIRTDataset):
def __init__(self, config):
Expand Down
Loading

0 comments on commit 4d9a5a6

Please sign in to comment.