Haoxuan/add nextitnet #1126

Merged
1 change: 1 addition & 0 deletions README.md
@@ -62,6 +62,7 @@ The table below lists the recommender algorithms currently available in the repo
| Attentive Asynchronous Singular Value Decomposition (A2SVD)<sup>*</sup> | [Python CPU / Python GPU](notebooks/00_quick_start/sequential_recsys_amazondataset.ipynb) | Collaborative Filtering | Sequential-based algorithm that aims to capture both long and short-term user preferences using attention mechanism |
| Cornac/Bayesian Personalized Ranking (BPR) | [Python CPU](notebooks/02_model/cornac_bpr_deep_dive.ipynb) | Collaborative Filtering | Matrix factorization algorithm for predicting item ranking with implicit feedback |
| Convolutional Sequence Embedding Recommendation (Caser) | [Python CPU / Python GPU](notebooks/00_quick_start/sequential_recsys_amazondataset.ipynb) | Collaborative Filtering | Algorithm based on convolutions that aims to capture both user’s general preferences and sequential patterns |
| A Simple Convolutional Generative Network for Next Item Recommendation (NextItNet) | [Python CPU / Python GPU](notebooks/00_quick_start/sequential_recsys_amazondataset.ipynb) | Collaborative Filtering | Algorithm based on dilated convolutions and a residual network that aims to capture sequential patterns |
| Deep Knowledge-Aware Network (DKN)<sup>*</sup> | [Python CPU / Python GPU](notebooks/00_quick_start/dkn_synthetic.ipynb) | Content-Based Filtering | Deep learning algorithm incorporating a knowledge graph and article embeddings to provide powerful news or article recommendations |
| Extreme Deep Factorization Machine (xDeepFM)<sup>*</sup> | [Python CPU / Python GPU](notebooks/00_quick_start/xdeepfm_criteo.ipynb) | Hybrid | Deep learning based algorithm for implicit and explicit feedback with user/item features |
| FastAI Embedding Dot Bias (FAST) | [Python CPU / Python GPU](notebooks/00_quick_start/fastai_movielens.ipynb) | Collaborative Filtering | General purpose algorithm with embeddings and biases for users and items |
1 change: 1 addition & 0 deletions notebooks/00_quick_start/README.md
@@ -19,6 +19,7 @@ In this directory, notebooks are provided to perform a quick demonstration of di
| [sar_azureml](sar_movielens_with_azureml.ipynb)| MovieLens | Python CPU | An example of how to utilize and evaluate SAR using the [Azure Machine Learning service](https://docs.microsoft.com/azure/machine-learning/service/overview-what-is-azure-ml) (AzureML). It takes the content of the [sar quickstart notebook](sar_movielens.ipynb) and demonstrates how to use the power of the cloud to manage data, switch to powerful GPU machines, and monitor runs while training a model.
| [a2svd](sequential_recsys_amazondataset.ipynb) | Amazon | Python CPU, GPU | Use A2SVD [11] to predict a set of movies the user is going to interact with in a short time. |
| [caser](sequential_recsys_amazondataset.ipynb) | Amazon | Python CPU, GPU | Use Caser [12] to predict a set of movies the user is going to interact with in a short time. |
| [nextitnet](sequential_recsys_amazondataset.ipynb) | Amazon | Python CPU, GPU | Use NextItNet [12] to predict a set of movies the user is going to interact with in a short time. |
| [gru4rec](sequential_recsys_amazondataset.ipynb) | Amazon | Python CPU, GPU | Use GRU4Rec [13] to predict a set of movies the user is going to interact with in a short time. |
| [sli-rec](sequential_recsys_amazondataset.ipynb) | Amazon | Python CPU, GPU | Use SLi-Rec [11] to predict a set of movies the user is going to interact with in a short time. |
| [wide-and-deep](wide_deep_movielens.ipynb) | MovieLens | Python CPU, GPU | Utilizing Wide-and-Deep Model (Wide-and-Deep) [5] to predict movie ratings in a Python+GPU (TensorFlow) environment.
16 changes: 12 additions & 4 deletions notebooks/00_quick_start/sequential_recsys_amazondataset.ipynb
@@ -18,7 +18,7 @@
"### Example: SLi_Rec : Adaptive User Modeling with Long and Short-Term Preferences for Personailzed Recommendation\n",
"Unlike a general recommender such as Matrix Factorization or xDeepFM (in the repo) which doesn't consider the order of the user's activities, sequential recommender systems take the sequence of the user behaviors as context and the goal is to predict the items that the user will interact in a short time (in an extreme case, the item that the user will interact next).\n",
"\n",
"This notebook aims to give you a quick example of how to train a sequential model based on a public Amazon dataset. Currently, we can support GRU4Rec \\[2\\], Caser \\[3\\], A2SVD \\[1\\] and SLi_Rec \\[1\\]. Without loss of generality, this notebook takes [SLi_Rec model](https://www.microsoft.com/en-us/research/uploads/prod/2019/07/IJCAI19-ready_v1.pdf) for example.\n",
"This notebook aims to give you a quick example of how to train a sequential model based on a public Amazon dataset. Currently, we can support NextItNet \\[4\\], GRU4Rec \\[2\\], Caser \\[3\\], A2SVD \\[1\\] and SLi_Rec \\[1\\]. Without loss of generality, this notebook takes [SLi_Rec model](https://www.microsoft.com/en-us/research/uploads/prod/2019/07/IJCAI19-ready_v1.pdf) for example.\n",
"SLi_Rec \\[1\\] is a deep learning-based model aims at capturing both long and short-term user preferences for precise recommender systems. To summarize, SLi_Rec has the following key properties:\n",
"\n",
"* It adopts the attentive \"Asymmetric-SVD\" paradigm for long-term modeling;\n",
@@ -71,8 +71,10 @@
"#from reco_utils.recommender.deeprec.models.sequential.asvd import A2SVDModel\n",
"#from reco_utils.recommender.deeprec.models.sequential.caser import CaserModel\n",
"#from reco_utils.recommender.deeprec.models.sequential.gru4rec import GRU4RecModel\n",
"#from reco_utils.recommender.deeprec.models.sequential.nextitnet import NextItNetModel\n",
"\n",
"from reco_utils.recommender.deeprec.io.sequential_iterator import SequentialIterator\n",
"#from reco_utils.recommender.deeprec.io.nextitnet_iterator import NextItNetIterator\n",
"\n",
"print(\"System version: {}\".format(sys.version))\n",
"print(\"Tensorflow version: {}\".format(tf.__version__))\n",
@@ -181,7 +183,8 @@
"if not os.path.exists(train_file):\n",
" download_and_extract(reviews_name, reviews_file)\n",
" download_and_extract(meta_name, meta_file)\n",
" data_preprocessing(*input_files, sample_rate=sample_rate, valid_num_ngs=valid_num_ngs, test_num_ngs=test_num_ngs)"
" data_preprocessing(*input_files, sample_rate=sample_rate, valid_num_ngs=valid_num_ngs, test_num_ngs=test_num_ngs)\n",
" # data_preprocessing(*input_files, sample_rate=sample_rate, valid_num_ngs=valid_num_ngs, test_num_ngs=test_num_ngs, is_history_expanding=False)\n"
]
},
{
@@ -452,8 +455,11 @@
"| GRU4Rec | 0.8411 | 0.8332 | 0.3213 | 0.4547 | 439.0 | 4285.0 | max_seq_length=50, hidden_size=40|\n",
"| Caser | 0.8244 | 0.8171 | 0.283 | 0.4194 | 314.3 | 5369.9 | T=1, n_v=128, n_h=128, L=3, min_seq_length=5|\n",
"| SLi_Rec | 0.8631 | 0.8519 | 0.3491 | 0.4842 | 549.6 | 5014.0 | attention_size=40, max_seq_length=50, hidden_size=40|\n",
"| NextItNet* | 0.6745 | 0.6697 | 0.0307 | 0.1567 | 112.0 | 214.5 | min_seq_length=3, dilations=\\[1,2,4,1,2,4\\], kernel_size=3 |\n",
"\n",
" Note that the four models are grid searched with a coarse granularity and the results are for reference only. "
" Note 1: The five models are grid searched with a coarse granularity and the results are for reference only.\n",
" <br>Note 2: NextItNet model requires a dataset with strong sequence property, but the Amazon dataset used in this notebook does not meet that requirement, so NextItNet Model may not performance good. If you wish to use other datasets with strong sequence property, NextItNet is recommended.\n",
" <br>Note 3: Time cost of NextItNet Model is significantly shorter than other models because it doesn't need a history expanding of training data."
]
},
{
@@ -465,7 +471,9 @@
"\n",
"\\[2\\] Balázs Hidasi, Alexandros Karatzoglou, Linas Baltrunas, Domonkos Tikk. Session-based Recommendations with Recurrent Neural Networks. ICLR (Poster) 2016\n",
"\n",
"\\[3\\] Tang, Jiaxi, and Ke Wang. Personalized top-n sequential recommendation via convolutional sequence embedding. Proceedings of the Eleventh ACM International Conference on Web Search and Data Mining. ACM, 2018."
"\\[3\\] Tang, Jiaxi, and Ke Wang. Personalized top-n sequential recommendation via convolutional sequence embedding. Proceedings of the Eleventh ACM International Conference on Web Search and Data Mining. ACM, 2018.\n",
"\n",
"\\[4\\] Yuan, F., Karatzoglou, A., Arapakis, I., Jose, J. M., & He, X. A Simple Convolutional Generative Network for Next Item Recommendation. WSDM, 2019"
]
},
{
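The notes in the results cell above point out that NextItNet skips history expanding and uses its own iterator. As a rough orientation for readers of this PR, the sketch below shows one way the quick-start flow could be switched from SLi_Rec to NextItNet. It is a minimal sketch, not part of the diff, and it assumes the variables defined in the notebook's earlier cells (`yaml_file` pointing at the new `nextitnet.yaml`, `input_files`, the train/valid/test paths, the vocab paths, `sample_rate`, `valid_num_ngs`, `test_num_ngs`, `train_num_ngs`); the `fit`/`run_eval` signatures follow the other sequential models in the repo and may differ.

```python
# Minimal sketch (not part of this PR's diff): running NextItNet in the quick-start notebook.
# Assumes yaml_file, input_files, train_file, valid_file, test_file, user_vocab, item_vocab,
# cate_vocab, sample_rate, valid_num_ngs, test_num_ngs and train_num_ngs are defined in the
# earlier notebook cells, as in the SLi_Rec example.
from reco_utils.dataset.amazon_reviews import data_preprocessing
from reco_utils.recommender.deeprec.deeprec_utils import prepare_hparams
from reco_utils.recommender.deeprec.models.sequential.nextitnet import NextItNetModel
from reco_utils.recommender.deeprec.io.nextitnet_iterator import NextItNetIterator

# NextItNet trains on whole sequences, so preprocessing is run without history expanding.
data_preprocessing(
    *input_files,
    sample_rate=sample_rate,
    valid_num_ngs=valid_num_ngs,
    test_num_ngs=test_num_ngs,
    is_history_expanding=False,
)

hparams = prepare_hparams(
    yaml_file,  # reco_utils/recommender/deeprec/config/nextitnet.yaml
    user_vocab=user_vocab,
    item_vocab=item_vocab,
    cate_vocab=cate_vocab,
    train_num_ngs=train_num_ngs,
)

# NextItNet uses its own iterator instead of SequentialIterator.
model = NextItNetModel(hparams, NextItNetIterator)
model = model.fit(train_file, valid_file, valid_num_ngs=valid_num_ngs)
print(model.run_eval(test_file, num_ngs=test_num_ngs))
```

The key differences from the SLi_Rec flow are the `is_history_expanding=False` preprocessing switch and the NextItNet-specific iterator.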
83 changes: 82 additions & 1 deletion reco_utils/dataset/amazon_reviews.py
@@ -24,6 +24,7 @@ def data_preprocessing(
sample_rate=0.01,
valid_num_ngs=4,
test_num_ngs=9,
is_history_expanding=True,
):
"""Create data for training, validation and testing from original dataset

@@ -37,7 +38,10 @@
_create_item2cate(instance_output)
sampled_instance_file = _get_sampled_data(instance_output, sample_rate=sample_rate)
preprocessed_output = _data_processing(sampled_instance_file)
_data_generating(preprocessed_output, train_file, valid_file, test_file)
if is_history_expanding:
_data_generating(preprocessed_output, train_file, valid_file, test_file)
else:
_data_generating_no_history_expanding(preprocessed_output, train_file, valid_file, test_file)
_create_vocab(train_file, user_vocab, item_vocab, cate_vocab)
_negative_sampling_offline(
sampled_instance_file, valid_file, test_file, valid_num_ngs, test_num_ngs
@@ -233,6 +237,83 @@ def _data_generating(input_file, train_file, valid_file, test_file, min_sequence
cate_list.append(category)
dt_list.append(date_time)

def _data_generating_no_history_expanding(input_file, train_file, valid_file, test_file, min_sequence=1):
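    """Create train/valid/test instances without history expanding.

    Unlike ``_data_generating``, which writes one training instance per positive
    click (each with its growing click history), this variant writes a single
    training instance per user: the last click of the user's training segment is
    the target and all earlier clicks form the history. Validation and test
    instances are still produced one per record.
    """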
f_input = open(input_file, "r")
f_train = open(train_file, "w")
f_valid = open(valid_file, "w")
f_test = open(test_file, "w")
print("data generating...")

last_user_id = None
last_movie_id = None
last_category = None
last_datetime = None
last_tfile = None
for line in f_input:
line_split = line.strip().split("\t")
tfile = line_split[0]
label = int(line_split[1])
user_id = line_split[2]
movie_id = line_split[3]
date_time = line_split[4]
category = line_split[5]

if last_tfile == "train":
fo = f_train
elif last_tfile == "valid":
fo = f_valid
elif last_tfile == "test":
fo = f_test
if user_id != last_user_id or tfile == "valid" or tfile == "test":
if last_user_id != None:
history_clk_num = len(movie_id_list)
cat_str = ""
mid_str = ""
dt_str = ""
for c1 in cate_list[:-1]:
cat_str += c1 + ","
for mid in movie_id_list[:-1]:
mid_str += mid + ","
for dt_time in dt_list[:-1]:
dt_str += dt_time + ","
if len(cat_str) > 0:
cat_str = cat_str[:-1]
if len(mid_str) > 0:
mid_str = mid_str[:-1]
if len(dt_str) > 0:
dt_str = dt_str[:-1]
if history_clk_num > min_sequence:
fo.write(
line_split[1]
+ "\t"
+ last_user_id
+ "\t"
+ last_movie_id
+ "\t"
+ last_category
+ "\t"
+ last_datetime
+ "\t"
+ mid_str
+ "\t"
+ cat_str
+ "\t"
+ dt_str
+ "\n"
)
if tfile == "train" or last_user_id == None:
movie_id_list = []
cate_list = []
dt_list = []
last_user_id = user_id
last_movie_id = movie_id
last_category = category
last_datetime = date_time
last_tfile = tfile
if label:
movie_id_list.append(movie_id)
cate_list.append(category)
dt_list.append(date_time)

def _create_item2cate(instance_file):
print("creating item2cate dict")
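To make the intent of the new `is_history_expanding` switch concrete, here is a toy, self-contained illustration (not repo code, and not the repo's actual file format) of the two generation strategies for one user's ordered training clicks:

```python
# Toy illustration only: how the two strategies turn one user's ordered training
# clicks A -> B -> C -> D into (target, history) training instances.
clicks = ["A", "B", "C", "D"]

# _data_generating (history expanding): one instance per click, with the growing history.
expanding = [(clicks[i], clicks[:i]) for i in range(1, len(clicks))]
# [('B', ['A']), ('C', ['A', 'B']), ('D', ['A', 'B', 'C'])]

# _data_generating_no_history_expanding: one instance per user, with the full history.
no_expanding = [(clicks[-1], clicks[:-1])]
# [('D', ['A', 'B', 'C'])]

print(expanding)
print(no_expanding)
```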
59 changes: 59 additions & 0 deletions reco_utils/recommender/deeprec/config/nextitnet.yaml
@@ -0,0 +1,59 @@
#data
#data format:sequential model
data:
user_vocab : ./tests/resources/deeprec/slirec/user_vocab.pkl # the map file of user to id
item_vocab : ./tests/resources/deeprec/slirec/item_vocab.pkl # the map file of item to id
cate_vocab : ./tests/resources/deeprec/slirec/category_vocab.pkl # the map file of category to id

#model
model:
method : classification # classification or regression
model_type : NextItNet
    layer_sizes : [100, 64] # layer sizes of the DNN. In this example, the DNN has two layers, with 100 and 64 hidden nodes.
activation : [relu, relu] # activation function for DNN
user_dropout: True
    dropout : [0.3, 0.3] # dropout rates for the DNN layers
item_embedding_dim : 32 # the embedding dimension of items
cate_embedding_dim : 8 # the embedding dimension of categories
user_embedding_dim : 16 # the embedding dimension of users

#train
#init_method: normal,tnormal,uniform,he_normal,he_uniform,xavier_normal,xavier_uniform
train:
init_method: tnormal # method for initializing model parameters
init_value : 0.01 # stddev values for initializing model parameters
embed_l2 : 0.0001 # l2 regularization for embedding parameters
embed_l1 : 0.0000 # l1 regularization for embedding parameters
layer_l2 : 0.0001 # l2 regularization for hidden layer parameters
layer_l1 : 0.0000 # l1 regularization for hidden layer parameters
cross_l2 : 0.0000 # l2 regularization for cross layer parameters
cross_l1 : 0.000 # l1 regularization for cross layer parameters
learning_rate : 0.001
    loss : softmax # pointwise: log_loss, cross_entropy_loss, square_loss; pairwise: softmax
    optimizer : lazyadam # adam, adadelta, sgd, ftrl, gd, padagrad, pgd, rmsprop, lazyadam
    epochs : 50 # number of epochs for training
    batch_size : 400 # batch size; should be an integer multiple of (1 + train_num_ngs) when need_sample is True
    enable_BN : True # whether to use batch normalization in hidden layers
    EARLY_STOP : 10 # early stopping patience: stop if there is no improvement for this many epochs
max_seq_length : 50 # the maximum number of records in the history sequence
need_sample : True # whether to perform dynamic negative sampling in mini-batch
    train_num_ngs : 4 # number of negative instances that follow each positive instance in the training file when need_sample is True

min_seq_length : 3 # the minimum number of records in the history sequence
    dilations : [1, 2, 4, 1, 2, 4] # dilations in each dilated CNN layer
    kernel_size : 3 # kernel size in each dilated CNN layer

#show info
#metric :'auc', 'logloss', 'group_auc'
info:
    show_step : 100 # print training information after a certain number of mini-batches
save_model: True # whether to save models
save_epoch : 1 # if save_model is set to True, save the model every save_epoch.
metrics : ['auc','logloss'] # metrics for evaluation.
pairwise_metrics : ['mean_mrr', 'ndcg@2;4;6', "group_auc"] # pairwise metrics for evaluation, available when pairwise comparisons are needed
MODEL_DIR : ./tests/resources/deeprec/nextitnet/model/nextitnet_model/ # directory of saved models.
SUMMARIES_DIR : ./tests/resources/deeprec/nextitnet/summary/nextitnet_summary/ # directory of saved summaries.
write_tfevents : True # whether to save summaries.
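
For reviewers who want a quick sanity check of this config, the minimal sketch below loads it with `prepare_hparams`, assuming the same keyword-override behavior used for the other deeprec YAML configs; the new `dilations` and `kernel_size` entries then surface on the returned hparams via the `create_hparams` change further down in this PR.

```python
from reco_utils.recommender.deeprec.deeprec_utils import prepare_hparams

yaml_file = "reco_utils/recommender/deeprec/config/nextitnet.yaml"

# Keyword arguments override the values read from the YAML file.
hparams = prepare_hparams(yaml_file, dilations=[1, 2, 4, 1, 2, 4], kernel_size=3, batch_size=400)
print(hparams.model_type, hparams.dilations, hparams.kernel_size)
```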



32 changes: 31 additions & 1 deletion reco_utils/recommender/deeprec/deeprec_utils.py
@@ -74,6 +74,7 @@ def check_type(config):
"L",
"n_v",
"n_h",
"kernel_size",
"min_seq_length",
"attention_size",
"epochs",
@@ -119,7 +120,13 @@ def check_type(config):
if param in config and not isinstance(config[param], str):
raise TypeError("Parameters {0} must be str".format(param))

list_parameters = ["layer_sizes", "activation", "dropout", "att_fcn_layer_sizes"]
list_parameters = [
"layer_sizes",
"activation",
"dropout",
"att_fcn_layer_sizes",
"dilations",
]
for param in list_parameters:
if param in config and not isinstance(config[param], list):
raise TypeError("Parameters {0} must be list".format(param))
@@ -226,6 +233,26 @@ def check_nn_config(f_config):
"hidden_size",
"att_fcn_layer_sizes",
]
elif f_config["model_type"] in [
"nextitnet",
"next_it_net",
"NextItNet",
"NEXT_IT_NET",
]:
required_parameters = [
"item_embedding_dim",
"cate_embedding_dim",
"user_embedding_dim",
"max_seq_length",
"loss",
"method",
"user_vocab",
"item_vocab",
"cate_vocab",
"dilations",
"kernel_size",
"min_seq_length",
]
else:
required_parameters = []

@@ -413,6 +440,9 @@ def create_hparams(flags):
att_fcn_layer_sizes=flags["att_fcn_layer_sizes"]
if "att_fcn_layer_sizes" in flags
else None,
# nextitnet
dilations=flags["dilations"] if "dilations" in flags else None,
kernel_size=flags["kernel_size"] if "kernel_size" in flags else None,
)

