diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 55b05fd9d..2238d264c 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -40,7 +40,4 @@ jobs:
python -m pytest -v tests/config/test_config.py
export PYTHONPATH=.
python tests/config/test_command_line.py --use_gpu=False --valid_metric=Recall@10 --split_ratio=[0.7,0.2,0.1] --metrics=['Recall@10'] --epochs=200 --eval_setting='LO_RS' --learning_rate=0.3
- - name: Test evaluation_setting
- run: |
- python -m pytest -v tests/evaluation_setting
-
+
diff --git a/README.md b/README.md
index fbd298750..a19df9c22 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,12 @@ oriented to the GPU environment.
for testing and comparing recommendation algorithms.
## RecBole News
+**12/06/2020**: We release RecBole [v0.1.2](https://github.com/RUCAIBox/RecBole/releases/tag/v0.1.2).
+
+**11/29/2020**: We constructed preliminary experiments to test the time and memory cost on three
+different-sized datasets and provided the [test result](https://github.com/RUCAIBox/RecBole#time-and-memory-costs)
+for reference.
+
**11/03/2020**: We release the first version of RecBole **v0.1.1**.
@@ -154,35 +160,23 @@ python run_recbole.py --model=[model_name]
```
-## Time and memory cost of models
-We test our models on three datasets of different size (small size, medium size and large size) to estimate their time and memory cost. You can
-click links to check more information.
-(**NOTE:** Our test results only reflect the approximate time and memory cost of models. If you find any error in our result,
-please let us know.)
+## Time and Memory Costs
+We constructed preliminary experiments to test the time and memory cost on three different-sized datasets (small, medium and large). For detailed information, you can click the following links.
-* [General recommendation models](time_test_result/General_recommendation.md)
-* [Sequential recommendation models]()
-* [Context-aware recommendation models]()
-* [Knowledge-based recommendation models]()
+* [General recommendation models](asset/time_test_result/General_recommendation.md)
+* [Sequential recommendation models](asset/time_test_result/Sequential_recommendation.md)
+* [Context-aware recommendation models](asset/time_test_result/Context-aware_recommendation.md)
+* [Knowledge-based recommendation models](asset/time_test_result/Knowledge-based_recommendation.md)
-Here is our testing device information:
-```
-GPU: TITAN GTX
-Driver Version: 430.64
-CUDA Version: 10.1
-Memory size: 65412748 KB
-CPU: Intel(R) Xeon(R) Silver 4110 CPU @ 2.10GHz
-The number of CPU cores: 8
-Cache size: 11264KB
-```
+NOTE: Our test results only gave the approximate time and memory cost of our implementations in the RecBole library (based on our machine server). Any feedback or suggestions about the implementations and test are welcome. We will keep improving our implementations, and update these test results.
## RecBole Major Releases
| Releases | Date | Features |
|-----------|--------|-------------------------|
+| v0.1.2 | 12/06/2020 | Basic RecBole |
| v0.1.1 | 11/03/2020 | Basic RecBole |
-
## Contributing
Please let us know if you encounter a bug or have any suggestions by [filing an issue](https://github.com/RUCAIBox/RecBole/issues).
diff --git a/asset/time_test_result/Context-aware_recommendation.md b/asset/time_test_result/Context-aware_recommendation.md
new file mode 100644
index 000000000..39751b0b4
--- /dev/null
+++ b/asset/time_test_result/Context-aware_recommendation.md
@@ -0,0 +1,189 @@
+## Time and memory cost of context-aware recommendation models
+
+### Datasets information:
+
+| Dataset | #Interaction | #Feature Field | #Feature |
+| ------- | ------------: | --------------: | --------: |
+| ml-1m | 1,000,209 | 5 | 134 |
+| Criteo | 2,292,530 | 39 | 2,572,192 |
+| Avazu | 4,218,938 | 21 | 1,326,631 |
+
+### Device information
+
+```
+OS: Linux
+Python Version: 3.8.3
+PyTorch Version: 1.7.0
+cudatoolkit Version: 10.1
+GPU: TITAN RTX(24GB)
+Machine Specs: 32 CPU machine, 64GB RAM
+```
+
+### 1) ml-1m dataset:
+
+#### Time and memory cost on ml-1m dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| --------- | -----------------: | -----------------: | -----------: |
+| LR | 18.34 | 2.18 | 0.82 |
+| DIN | 20.37 | 2.26 | 1.16 |
+| DSSM | 21.93 | 2.24 | 0.95 |
+| FM | 19.33 | 2.34 | 0.83 |
+| DeepFM | 20.42 | 2.27 | 0.91 |
+| Wide&Deep | 26.13 | 2.95 | 0.89 |
+| NFM | 23.36 | 2.26 | 0.89 |
+| AFM | 20.08 | 2.26 | 0.92 |
+| AutoInt | 22.41 | 2.34 | 0.94 |
+| DCN | 28.33 | 2.97 | 0.93 |
+| FNN(DNN) | 19.51 | 2.21 | 0.91 |
+| PNN | 22.29 | 2.23 | 0.91 |
+| FFM | 22.98 | 2.47 | 0.87 |
+| FwFM | 23.38 | 2.50 | 0.85 |
+| xDeepFM | 24.40 | 2.30 | 1.06 |
+
+#### Config file of ml-1m dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: item_id
+LABEL_FIELD: label
+threshold:
+ rating: 4.0
+drop_filter_field : True
+load_col:
+ inter: [user_id, item_id, rating]
+ item: [item_id, release_year, genre]
+ user: [user_id, age, gender, occupation]
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+eval_setting: RO_RS
+group_by_user: False
+valid_metric: AUC
+metrics: ['AUC', 'LogLoss']
+```
+
+Other parameters (including model parameters) are default value.
+
+### 2)Criteo dataset:
+
+#### Time and memory cost on Criteo dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| --------- | -------------------------: | ---------------------------: | ---------------: |
+| LR | 7.65 | 0.61 | 1.11 |
+| DIN | - | - | - |
+| DSSM | - | - | - |
+| FM | 9.77 | 0.73 | 1.45 |
+| DeepFM | 13.64 | 0.83 | 1.72 |
+| Wide&Deep | 13.58 | 0.80 | 1.72 |
+| NFM | 13.36 | 0.75 | 1.72 |
+| AFM | 19.40 | 1.02 | 2.34 |
+| AutoInt | 19.40 | 0.98 | 2.06 |
+| DCN | 16.25 | 0.78 | 1.67 |
+| FNN(DNN) | 10.03 | 0.64 | 1.63 |
+| PNN | 12.92 | 0.72 | 1.85 |
+| FFM | - | - | - |
+| FwFM | 1175.24 | 8.90 | 2.12 |
+| xDeepFM | 32.27 | 1.34 | 2.25 |
+
+#### Config file of Criteo dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: ~
+ITEM_ID_FIELD: ~
+LABEL_FIELD: label
+
+load_col:
+ inter: '*'
+
+highest_val:
+ index: 2292530
+
+fill_nan: True
+normalize_all: True
+min_item_inter_num: 0
+min_user_inter_num: 0
+
+drop_filter_field : True
+
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+eval_setting: RO_RS
+group_by_user: False
+valid_metric: AUC
+metrics: ['AUC', 'LogLoss']
+```
+
+Other parameters (including model parameters) are default value.
+
+### 3)Avazu dataset:
+
+#### Time and memory cost on Avazu dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| --------- | -------------------------: | ---------------------------: | ---------------: |
+| LR | 9.30 | 0.76 | 1.42 |
+| DIN | - | - | - |
+| DSSM | - | - | - |
+| FM | 25.68 | 0.94 | 2.60 |
+| DeepFM | 28.41 | 1.19 | 2.66 |
+| Wide&Deep | 27.58 | 0.97 | 2.66 |
+| NFM | 30.46 | 1.06 | 2.66 |
+| AFM | 31.03 | 1.06 | 2.69 |
+| AutoInt | 38.11 | 1.41 | 2.84 |
+| DCN | 30.78 | 0.96 | 2.64 |
+| FNN(DNN) | 23.53 | 0.84 | 2.60 |
+| PNN | 25.86 | 0.90 | 2.68 |
+| FFM | - | - | - |
+| FwFM | 336.75 | 7.49 | 2.63 |
+| xDeepFM | 54.88 | 1.45 | 2.89 |
+
+#### Config file of Avazu dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: ~
+ITEM_ID_FIELD: ~
+LABEL_FIELD: label
+fill_nan: True
+normalize_all: True
+
+load_col:
+ inter: '*'
+
+lowest_val:
+ timestamp: 14102931
+drop_filter_field : False
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+eval_setting: RO_RS
+group_by_user: False
+valid_metric: AUC
+metrics: ['AUC', 'LogLoss']
+```
+
+Other parameters (including model parameters) are default value.
+
+
+
+
+
+
+
diff --git a/asset/time_test_result/General_recommendation.md b/asset/time_test_result/General_recommendation.md
new file mode 100644
index 000000000..e88472078
--- /dev/null
+++ b/asset/time_test_result/General_recommendation.md
@@ -0,0 +1,175 @@
+## Time and memory cost of general recommendation models
+
+### Datasets information:
+
+| Dataset | #User | #Item | #Interaction | Sparsity |
+| ------- | -------: | ------: | ------------: | --------: |
+| ml-1m | 6,041 | 3,707 | 1,000,209 | 0.9553 |
+| Netflix | 80,476 | 16,821 | 1,977,844 | 0.9985 |
+| Yelp | 102,046 | 98,408 | 2,903,648 | 0.9997 |
+
+### Device information
+
+```
+OS: Linux
+Python Version: 3.8.3
+PyTorch Version: 1.7.0
+cudatoolkit Version: 10.1
+GPU: TITAN RTX(24GB)
+Machine Specs: 32 CPU machine, 64GB RAM
+```
+
+### 1) ml-1m dataset:
+
+#### Time and memory cost on ml-1m dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| ---------- | ------------------------: | --------------------------: | --------------: |
+| Popularity | 2.11 | 8.08 | 0.82 |
+| ItemKNN | 2.00 | 11.76 | 0.82 |
+| BPRMF | 1.93 | 7.43 | 0.91 |
+| NeuMF | 4.94 | 13.12 | 0.94 |
+| DMF | 4.47 | 12.63 | 1.52 |
+| NAIS | 59.27 | 24.41 | 21.83 |
+| NGCF | 12.09 | 7.12 | 1.20 |
+| GCMC | 9.04 | 54.15 | 1.32 |
+| LightGCN | 7.83 | 7.47 | 1.15 |
+| DGCF | 181.66 | 8.06 | 6.59 |
+| ConvNCF | 8.46 | 19.60 | 1.31 |
+| FISM | 19.30 | 10.92 | 6.94 |
+| SpectralCF | 13.87 | 6.97 | 1.19 |
+
+#### Config file of ml-1m dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: item_id
+RATING_FIELD: rating
+TIME_FIELD: timestamp
+LABEL_FIELD: label
+NEG_PREFIX: neg_
+load_col:
+ inter: [user_id, item_id, rating, timestamp]
+min_user_inter_num: 0
+min_item_inter_num: 0
+
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+```
+
+Other parameters (including model parameters) are default value.
+
+### 2)Netflix dataset:
+
+#### Time and memory cost on Netflix dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| ---------- | ----------------: | -----------------: | -----------: |
+| Popularity | 3.98 | 58.86 | 0.86 |
+| ItemKNN | 5.42 | 69.64 | 0.86 |
+| BPRMF | 4.42 | 52.81 | 1.08 |
+| NeuMF | 11.33 | 238.92 | 1.26 |
+| DMF | 20.62 | 68.89 | 7.12 |
+| NAIS | - | - | - |
+| NGCF | 52.50 | 51.60 | 2.00 |
+| GCMC | 93.15 | 1810.43 | 3.17 |
+| LightGCN | 30.21 | 47.12 | 1.58 |
+| DGCF | 750.74 | 47.23 | 12.52 |
+| ConvNCF | 17.02 | 402.65 | 1.44 |
+| FISM | 86.52 | 83.26 | 20.54 |
+| SpectralCF | 59.92 | 46.94 | 1.88 |
+
+#### Config file of Netflix dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: item_id
+RATING_FIELD: rating
+TIME_FIELD: timestamp
+LABEL_FIELD: label
+NEG_PREFIX: neg_
+load_col:
+ inter: [user_id, item_id, rating, timestamp]
+min_user_inter_num: 3
+min_item_inter_num: 0
+lowest_val:
+ timestamp: 1133366400
+ rating: 3
+drop_filter_field : True
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+```
+
+Other parameters (including model parameters) are default value.
+
+### 3) Yelp dataset:
+
+#### Time and memory cost on Yelp dataset:
+
+| Method | Training Time (sec/epoch) | Evaluate Time (sec/epoch) | GPU Memory (GB) |
+| ---------- | -------------------------: | -------------------------: | ---------------: |
+| Popularity | 5.69 | 134.23 | 0.89 |
+| ItemKNN | 8.44 | 194.24 | 0.90 |
+| BPRMF | 6.31 | 120.03 | 1.29 |
+| NeuMF | 17.38 | 2069.53 | 1.67 |
+| DMF | 43.96 | 173.13 | 9.22 |
+| NAIS | - | - | - |
+| NGCF | 122.90 | 129.59 | 3.28 |
+| GCMC | 299.36 | 9833.24 | 5.96 |
+| LightGCN | 67.91 | 116.16 | 2.02 |
+| DGCF | 1542.00 | 119.00 | 17.17 |
+| ConvNCF | 87.56 | 11155.31 | 1.62 |
+| FISM | - | - | - |
+| SpectralCF | 138.99 | 133.37 | 3.10 |
+
+#### Config file of Yelp dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: business_id
+RATING_FIELD: stars
+TIME_FIELD: date
+LABEL_FIELD: label
+NEG_PREFIX: neg_
+load_col:
+ inter: [user_id, business_id, stars]
+min_user_inter_num: 10
+min_item_inter_num: 4
+lowest_val:
+ stars: 3
+drop_filter_field: True
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+```
+
+Other parameters (including model parameters) are default value.
+
+
+
+
+
+
+
+
+
diff --git a/asset/time_test_result/Knowledge-based_recommendation.md b/asset/time_test_result/Knowledge-based_recommendation.md
new file mode 100644
index 000000000..cb9e80992
--- /dev/null
+++ b/asset/time_test_result/Knowledge-based_recommendation.md
@@ -0,0 +1,167 @@
+## Time and memory cost of knowledge-based recommendation models
+
+### Datasets information:
+
+| Dataset | #User | #Item | #Interaction | Sparsity | #Entity | #Relation | #Triple |
+| ------- | ------: | -------: | ------------: | --------: | ---------: | ---------: | ---------: |
+| ml-1m | 6,040 | 3,629 | 836,478 | 0.9618 | 79,388 | 51 | 385,923 |
+| ml-10m | 69,864 | 10,599 | 8,242,124 | 0.9889 | 181,941 | 51 | 1,051,385 |
+| LFM-1b | 64,536 | 156,343 | 6,544,312 | 0.9994 | 1,751,586 | 10 | 3,054,516 |
+
+### Device information
+
+```
+OS: Linux
+Python Version: 3.8.3
+PyTorch Version: 1.7.0
+cudatoolkit Version: 10.1
+GPU: TITAN RTX(24GB)
+Machine Specs: 32 CPU machine, 64GB RAM
+```
+
+### 1) ml-1m dataset:
+
+#### Time and memory cost on ml-1m dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| --------- | -------------------------: | ---------------------------: | ---------------: |
+| CKE | 3.76 | 8.73 | 1.16 |
+| KTUP | 3.82 | 17.68 | 1.04 |
+| RippleNet | 9.39 | 13.13 | 4.57 |
+| KGAT | 9.59 | 8.63 | 3.52 |
+| KGNN-LS | 4.78 | 15.09 | 1.04 |
+| KGCN | 2.25 | 13.71 | 1.04 |
+| MKR | 6.25 | 14.89 | 1.29 |
+| CFKG | 1.49 | 9.76 | 0.97 |
+
+#### Config file of ml-1m dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: item_id
+RATING_FIELD: rating
+HEAD_ENTITY_ID_FIELD: head_id
+TAIL_ENTITY_ID_FIELD: tail_id
+RELATION_ID_FIELD: relation_id
+ENTITY_ID_FIELD: entity_id
+NEG_PREFIX: neg_
+LABEL_FIELD: label
+load_col:
+ inter: [user_id, item_id, rating]
+ kg: [head_id, relation_id, tail_id]
+ link: [item_id, entity_id]
+lowest_val:
+ rating: 3
+drop_filter_field: True
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+```
+
+Other parameters (including model parameters) are default value.
+
+### 2)ml-10m dataset:
+
+#### Time and memory cost on ml-10m dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| --------- | -------------------------: | ---------------------------: | ---------------: |
+| CKE | 8.65 | 85.53 | 1.46 |
+| KTUP | 40.71 | 507.56 | 1.43 |
+| RippleNet | 32.01 | 152.40 | 4.71 |
+| KGAT | 298.22 | 80.94 | 22.44 |
+| KGNN-LS | 15.47 | 241.57 | 1.42 |
+| KGCN | 7.73 | 244.93 | 1.42 |
+| MKR | 61.05 | 383.29 | 1.80 |
+| CFKG | 5.99 | 140.74 | 1.35 |
+
+#### Config file of ml-10m dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: item_id
+RATING_FIELD: rating
+HEAD_ENTITY_ID_FIELD: head_id
+TAIL_ENTITY_ID_FIELD: tail_id
+RELATION_ID_FIELD: relation_id
+ENTITY_ID_FIELD: entity_id
+NEG_PREFIX: neg_
+LABEL_FIELD: label
+load_col:
+ inter: [user_id, item_id, rating]
+ kg: [head_id, relation_id, tail_id]
+ link: [item_id, entity_id]
+lowest_val:
+ rating: 3
+drop_filter_field: True
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+```
+
+Other parameters (including model parameters) are default value.
+
+### 3)LFM-1b dataset:
+
+#### Time and memory cost on LFM-1b dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| --------- | -------------------------: | ---------------------------: | ---------------: |
+| CKE | 62.99 | 82.93 | 4.45 |
+| KTUP | 91.79 | 3218.69 | 4.36 |
+| RippleNet | 126.26 | 188.38 | 6.49 |
+| KGAT | 626.07 | 75.70 | 23.28 |
+| KGNN-LS | 62.55 | 1709.10 | 4.73 |
+| KGCN | 52.54 | 1763.03 | 4.71 |
+| MKR | 290.01 | 2341.91 | 6.96 |
+| CFKG | 53.35 | 553.58 | 4.22 |
+
+#### Config file of LFM-1b dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: tracks_id
+RATING_FIELD: rating
+HEAD_ENTITY_ID_FIELD: head_id
+TAIL_ENTITY_ID_FIELD: tail_id
+RELATION_ID_FIELD: relation_id
+ENTITY_ID_FIELD: entity_id
+NEG_PREFIX: neg_
+LABEL_FIELD: label
+load_col:
+ inter: [user_id, tracks_id, timestamp]
+ kg: [head_id, relation_id, tail_id]
+ link: [tracks_id, entity_id]
+lowest_val:
+ timestamp: 1356969600
+
+highest_val:
+ timestamp: 1362067200
+drop_filter_field: True
+min_user_inter_num: 2
+min_item_inter_num: 15
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+```
+
+Other parameters (including model parameters) are default value.
+
diff --git a/asset/time_test_result/Sequential_recommendation.md b/asset/time_test_result/Sequential_recommendation.md
new file mode 100644
index 000000000..3a0fb4f6c
--- /dev/null
+++ b/asset/time_test_result/Sequential_recommendation.md
@@ -0,0 +1,225 @@
+## Time and memory cost of sequential recommendation models
+
+### Datasets information:
+
+| Dataset | #User | #Item | #Interaction | Sparsity |
+| ---------- | -------: | ------: | ------------: | --------: |
+| ml-1m | 6,041 | 3,707 | 1,000,209 | 0.9553 |
+| DIGINETICA | 59,425 | 42,116 | 547,416 | 0.9998 |
+| Yelp | 102,046 | 98,408 | 2,903,648 | 0.9997 |
+
+### Device information
+
+```
+OS: Linux
+Python Version: 3.8.3
+PyTorch Version: 1.7.0
+cudatoolkit Version: 10.1
+GPU: TITAN RTX(24GB)
+Machine Specs: 32 CPU machine, 64GB RAM
+```
+
+### 1) ml-1m dataset:
+
+#### Time and memory cost on ml-1m dataset:
+
+| Method | Training Time (sec/epoch) | Evaluate Time (sec/epoch) | GPU Memory (GB) |
+| ---------------- | -----------------: | -----------------: | -----------: |
+| Improved GRU-Rec | 7.78 | 0.11 | 1.27 |
+| SASRec | 17.78 | 0.12 | 1.84 |
+| NARM | 8.29 | 0.11 | 1.29 |
+| FPMC | 7.51 | 0.11 | 1.18 |
+| STAMP | 7.32 | 0.11 | 1.20 |
+| Caser | 44.85 | 0.12 | 1.14 |
+| NextItNet | 16433.27 | 96.31 | 1.86 |
+| TransRec | 10.08 | 0.16 | 8.18 |
+| S3Rec | - | - | - |
+| GRU4RecF | 10.20 | 0.15 | 1.80 |
+| SASRecF | 18.84 | 0.17 | 1.78 |
+| BERT4Rec | 36.09 | 0.34 | 1.97 |
+| FDSA | 31.86 | 0.19 | 2.32 |
+| SRGNN | 327.38 | 2.19 | 1.21 |
+| GCSAN | 335.27 | 0.02 | 1.58 |
+| KSR | - | - | - |
+| GRU4RecKG | - | - | - |
+
+#### Config file of ml-1m dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: item_id
+TIME_FIELD: timestamp
+NEG_PREFIX: neg_
+ITEM_LIST_LENGTH_FIELD: item_length
+LIST_SUFFIX: _list
+MAX_ITEM_LIST_LENGTH: 20
+POSITION_FIELD: position_id
+load_col:
+ inter: [user_id, item_id, timestamp]
+min_user_inter_num: 0
+min_item_inter_num: 0
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+eval_setting: TO_LS,full
+training_neg_sample_num: 0
+```
+
+Other parameters (including model parameters) are default value.
+
+**NOTE :**
+
+1) For FPMC and TransRec model, `training_neg_sample_num` should be `1` .
+
+2) For SASRecF, GRU4RecF and FDSA, `load_col` should as below:
+
+```
+load_col:
+ inter: [user_id, item_id, timestamp]
+ item: [item_id, genre]
+```
+
+### 2)DIGINETICA dataset:
+
+#### Time and memory cost on DIGINETICA dataset:
+
+| Method | Training Time (sec/epoch) | Evaluate Time (sec/epoch) | GPU Memory (GB) |
+| ---------------- | -----------------: | -----------------: | -----------: |
+| Improved GRU-Rec | 4.10 | 1.05 | 4.02 |
+| SASRec | 8.36 | 1.21 | 4.43 |
+| NARM | 4.30 | 1.08 | 4.09 |
+| FPMC | 2.98 | 1.08 | 4.08 |
+| STAMP | 4.27 | 1.04 | 3.88 |
+| Caser | 17.15 | 1.18 | 3.94 |
+| NextItNet | - | - | - |
+| TransRec | - | - | - |
+| S3Rec | - | - | - |
+| GRU4RecF | 4.79 | 1.17 | 4.83 |
+| SASRecF | 8.66 | 1.29 | 5.11 |
+| BERT4Rec | 16.80 | 3.54 | 7.97 |
+| FDSA | 13.44 | 1.47 | 5.66 |
+| SRGNN | 88.59 | 15.37 | 4.01 |
+| GCSAN | 96.69 | 17.11 | 4.25 |
+| KSR | - | - | - |
+| GRU4RecKG | - | - | - |
+
+#### Config file of DIGINETICA dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: session_id
+ITEM_ID_FIELD: item_id
+TIME_FIELD: timestamp
+NEG_PREFIX: neg_
+ITEM_LIST_LENGTH_FIELD: item_length
+LIST_SUFFIX: _list
+MAX_ITEM_LIST_LENGTH: 20
+POSITION_FIELD: position_id
+load_col:
+ inter: [session_id, item_id, timestamp]
+min_user_inter_num: 6
+min_item_inter_num: 1
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+eval_setting: TO_LS,full
+training_neg_sample_num: 0
+```
+
+Other parameters (including model parameters) are default value.
+
+**NOTE :**
+
+1) For FPMC and TransRec model, `training_neg_sample_num` should be `1` .
+
+2) For SASRecF, GRU4RecF and FDSA, `load_col` should as below:
+
+```
+load_col:
+ inter: [session_id, item_id, timestamp]
+ item: [item_id, item_category]
+```
+
+### 3)Yelp dataset:
+
+#### Time and memory cost on Yelp dataset:
+
+| Method | Training Time (sec/epoch) | Evaluation Time (sec/epoch) | GPU Memory (GB) |
+| ---------------- | -----------------: | -----------------: | -----------: |
+| Improved GRU-Rec | 44.31 | 2.74 | 7.92 |
+| SASRec | 75.51 | 3.11 | 8.32 |
+| NARM | 45.65 | 2.76 | 7.98 |
+| FPMC | 21.05 | 3.05 | 8.22 |
+| STAMP | 42.08 | 2.72 | 7.77 |
+| Caser | 147.15 | 2.89 | 7.87 |
+| NextItNet | 45019.38 | 1670.76 | 8.44 |
+| TransRec | - | - | - |
+| S3Rec | - | - | - |
+| GRU4RecF | - | - | - |
+| SASRecF | - | - | - |
+| BERT4Rec | 193.74 | 8.43 | 16.57 |
+| FDSA | - | - | - |
+| SRGNN | 825.11 | 33.20 | 7.90 |
+| GCSAN | 837.23 | 33.00 | 8.14 |
+| KSR | - | - | - |
+| GRU4RecKG | - | - | - |
+
+#### Config file of DIGINETICA dataset:
+
+```
+# dataset config
+field_separator: "\t"
+seq_separator: " "
+USER_ID_FIELD: session_id
+ITEM_ID_FIELD: item_id
+TIME_FIELD: timestamp
+NEG_PREFIX: neg_
+ITEM_LIST_LENGTH_FIELD: item_length
+LIST_SUFFIX: _list
+MAX_ITEM_LIST_LENGTH: 20
+POSITION_FIELD: position_id
+load_col:
+ inter: [session_id, item_id, timestamp]
+min_user_inter_num: 6
+min_item_inter_num: 1
+
+# training and evaluation
+epochs: 500
+train_batch_size: 2048
+eval_batch_size: 2048
+valid_metric: MRR@10
+eval_setting: TO_LS,full
+training_neg_sample_num: 0
+```
+
+Other parameters (including model parameters) are default value.
+
+**NOTE :**
+
+1) For FPMC and TransRec model, `training_neg_sample_num` should be `1` .
+
+2) For SASRecF, GRU4RecF and FDSA, `load_col` should as below:
+
+```
+load_col:
+ inter: [session_id, item_id, timestamp]
+ item: [item_id, item_category]
+```
+
+
+
+
+
+
+
diff --git a/conda/conda_release.sh b/conda/conda_release.sh
new file mode 100644
index 000000000..a8b0e3efa
--- /dev/null
+++ b/conda/conda_release.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+conda-build --python 3.6 .
+printf "python 3.6 version is released \n"
+conda-build --python 3.7 .
+printf "python 3.7 version is released \n"
+conda-build --python 3.8 .
+printf "python 3.8 version is released \n"
diff --git a/recbole/__init__.py b/recbole/__init__.py
index 182dd10a5..83c915e76 100644
--- a/recbole/__init__.py
+++ b/recbole/__init__.py
@@ -2,4 +2,4 @@
from __future__ import print_function
from __future__ import division
-__version__ = '0.1.1'
+__version__ = '0.1.2'
diff --git a/recbole/config/configurator.py b/recbole/config/configurator.py
index 381164ae7..fb6784498 100644
--- a/recbole/config/configurator.py
+++ b/recbole/config/configurator.py
@@ -201,6 +201,15 @@ def _load_internal_config_dict(self, model, model_class, dataset):
sample_init_file = os.path.join(current_path, '../properties/dataset/sample.yaml')
dataset_init_file = os.path.join(current_path, '../properties/dataset/' + dataset + '.yaml')
+ context_aware_init = os.path.join(current_path, '../properties/quick_start_config/context-aware.yaml')
+ context_aware_on_ml_100k_init = os.path.join(current_path, '../properties/quick_start_config/context-aware_ml-100k.yaml')
+ DIN_init = os.path.join(current_path, '../properties/quick_start_config/sequential_DIN.yaml')
+ DIN_on_ml_100k_init = os.path.join(current_path, '../properties/quick_start_config/sequential_DIN_on_ml-100k.yaml')
+ sequential_init = os.path.join(current_path, '../properties/quick_start_config/sequential.yaml')
+ special_sequential_on_ml_100k_init = os.path.join(current_path, '../properties/quick_start_config/special_sequential_on_ml-100k.yaml')
+ sequential_embedding_model_init = os.path.join(current_path, '../properties/quick_start_config/sequential_embedding_model.yaml')
+ knowledge_base_init = os.path.join(current_path, '../properties/quick_start_config/knowledge_base.yaml')
+
self.internal_config_dict = dict()
for file in [overall_init_file, model_init_file, sample_init_file, dataset_init_file]:
if os.path.isfile(file):
@@ -215,51 +224,48 @@ def _load_internal_config_dict(self, model, model_class, dataset):
if self.internal_config_dict['MODEL_TYPE'] == ModelType.GENERAL:
pass
elif self.internal_config_dict['MODEL_TYPE'] == ModelType.CONTEXT:
- self.internal_config_dict.update({
- 'eval_setting': 'RO_RS',
- 'group_by_user': False,
- 'training_neg_sample_num': 0,
- 'metrics': ['AUC', 'LogLoss'],
- 'valid_metric': 'AUC',
- })
+ with open(context_aware_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
if dataset == 'ml-100k':
- self.internal_config_dict.update({
- 'threshold': {'rating': 4},
- 'load_col': {'inter': ['user_id', 'item_id', 'rating', 'timestamp'],
- 'user': ['user_id', 'age', 'gender', 'occupation'],
- 'item': ['item_id', 'release_year', 'class']},
- })
-
+ with open(context_aware_on_ml_100k_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
+
elif self.internal_config_dict['MODEL_TYPE'] == ModelType.SEQUENTIAL:
if model == 'DIN':
- self.internal_config_dict.update({
- 'eval_setting': 'TO_LS, uni100',
- 'metrics': ['AUC', 'LogLoss'],
- 'valid_metric': 'AUC',
- })
+ with open(DIN_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
if dataset == 'ml-100k':
- self.internal_config_dict.update({
- 'load_col': {'inter': ['user_id', 'item_id', 'rating', 'timestamp'],
- 'user': ['user_id', 'age', 'gender', 'occupation'],
- 'item': ['item_id', 'release_year']},
- })
-
+ with open(DIN_on_ml_100k_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
+ elif model in ['GRU4RecKG','KSR']:
+ with open(sequential_embedding_model_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
else:
- self.internal_config_dict.update({
- 'eval_setting': 'TO_LS,full',
- })
+ with open(sequential_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
if dataset == 'ml-100k' and model in ['GRU4RecF', 'SASRecF', 'FDSA', 'S3Rec']:
- self.internal_config_dict.update({
- 'load_col': {'inter': ['user_id', 'item_id', 'rating', 'timestamp'],
- 'item': ['item_id', 'release_year', 'class']},
- })
-
+ with open(special_sequential_on_ml_100k_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
+
elif self.internal_config_dict['MODEL_TYPE'] == ModelType.KNOWLEDGE:
- self.internal_config_dict.update({
- 'load_col': {'inter': ['user_id', 'item_id', 'rating', 'timestamp'],
- 'kg': ['head_id', 'relation_id', 'tail_id'],
- 'link': ['item_id', 'entity_id']}
- })
+ with open(knowledge_base_init, 'r', encoding='utf-8') as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
def _get_final_config_dict(self):
final_config_dict = dict()
diff --git a/recbole/data/dataloader/__init__.py b/recbole/data/dataloader/__init__.py
index 18e6f0674..c224d85a6 100644
--- a/recbole/data/dataloader/__init__.py
+++ b/recbole/data/dataloader/__init__.py
@@ -4,3 +4,4 @@
from recbole.data.dataloader.context_dataloader import *
from recbole.data.dataloader.sequential_dataloader import *
from recbole.data.dataloader.knowledge_dataloader import *
+from recbole.data.dataloader.xgboost_dataloader import *
\ No newline at end of file
diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py
index 69c41a9f6..e5464606d 100644
--- a/recbole/data/dataloader/abstract_dataloader.py
+++ b/recbole/data/dataloader/abstract_dataloader.py
@@ -60,6 +60,8 @@ def __init__(self, config, dataset,
self.history_item_matrix = self.dataset.history_item_matrix
self.history_user_matrix = self.dataset.history_user_matrix
self.inter_matrix = self.dataset.inter_matrix
+ self.get_user_feature = self.dataset.get_user_feature
+ self.get_item_feature = self.dataset.get_item_feature
for dataset_attr in self.dataset._dataloader_apis:
try:
@@ -80,7 +82,7 @@ def setup(self):
pass
def data_preprocess(self):
- """This function is used to do some data preprocess, such as pre-neg-sampling and pre-data-augmentation.
+ """This function is used to do some data preprocess, such as pre-data-augmentation.
By default, it will do nothing.
"""
pass
@@ -137,23 +139,3 @@ def upgrade_batch_size(self, batch_size):
"""
if self.batch_size < batch_size:
self.set_batch_size(batch_size)
-
- def get_user_feature(self):
- """It is similar to :meth:`~recbole.data.dataset.dataset.Dataset.get_user_feature`, but it will return an
- :class:`~recbole.data.interaction.Interaction` of user feature instead of a :class:`pandas.DataFrame`.
-
- Returns:
- Interaction: The interaction of user feature.
- """
- user_df = self.dataset.get_user_feature()
- return self._dataframe_to_interaction(user_df)
-
- def get_item_feature(self):
- """It is similar to :meth:`~recbole.data.dataset.dataset.Dataset.get_item_feature`, but it will return an
- :class:`~recbole.data.interaction.Interaction` of item feature instead of a :class:`pandas.DataFrame`.
-
- Returns:
- Interaction: The interaction of item feature.
- """
- item_df = self.dataset.get_item_feature()
- return self._dataframe_to_interaction(item_df)
diff --git a/recbole/data/dataloader/general_dataloader.py b/recbole/data/dataloader/general_dataloader.py
index 9bd3b733b..b80feca5f 100644
--- a/recbole/data/dataloader/general_dataloader.py
+++ b/recbole/data/dataloader/general_dataloader.py
@@ -20,6 +20,7 @@
from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.dataloader.neg_sample_mixin import NegSampleMixin, NegSampleByMixin
from recbole.utils import DataLoaderType, InputType
+from recbole.data.interaction import Interaction, cat_interactions
class GeneralDataLoader(AbstractDataLoader):
@@ -50,7 +51,7 @@ def _shuffle(self):
def _next_batch_data(self):
cur_data = self.dataset[self.pr: self.pr + self.step]
self.pr += self.step
- return self._dataframe_to_interaction(cur_data)
+ return cur_data
class GeneralNegSampleDataLoader(NegSampleByMixin, AbstractDataLoader):
@@ -72,6 +73,8 @@ class GeneralNegSampleDataLoader(NegSampleByMixin, AbstractDataLoader):
"""
def __init__(self, config, dataset, sampler, neg_sample_args,
batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
+ self.uid_field = dataset.uid_field
+ self.iid_field = dataset.iid_field
self.uid_list, self.uid2index, self.uid2items_num = None, None, None
super().__init__(config, dataset, sampler, neg_sample_args,
@@ -79,22 +82,23 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
def setup(self):
if self.user_inter_in_one_batch:
- self.uid_list, self.uid2index, self.uid2items_num = self.dataset.uid2index
- self._batch_size_adaptation()
-
- def data_preprocess(self):
- if self.user_inter_in_one_batch:
- new_inter_num = 0
- new_inter_feat = []
+ uid_field = self.dataset.uid_field
+ user_num = self.dataset.user_num
+ self.dataset.sort(by=uid_field, ascending=True)
+ self.uid_list = []
+ start, end = dict(), dict()
+ for i, uid in enumerate(self.dataset.inter_feat[uid_field].numpy()):
+ if uid not in start:
+ self.uid_list.append(uid)
+ start[uid] = i
+ end[uid] = i
+ self.uid2index = np.array([None] * user_num)
+ self.uid2items_num = np.zeros(user_num, dtype=np.int64)
for uid in self.uid_list:
- index = self.uid2index[uid]
- new_inter_feat.append(self._neg_sampling(self.dataset.inter_feat[index]))
- new_num = len(new_inter_feat[-1])
- self.uid2index[uid] = slice(new_inter_num, new_inter_num + new_num)
- self.uid2items_num[uid] = new_num
- self.dataset.inter_feat = pd.concat(new_inter_feat, ignore_index=True)
- else:
- self.dataset.inter_feat = self._neg_sampling(self.dataset.inter_feat)
+ self.uid2index[uid] = slice(start[uid], end[uid] + 1)
+ self.uid2items_num[uid] = end[uid] - start[uid] + 1
+ self.uid_list = np.array(self.uid_list)
+ self._batch_size_adaptation()
def _batch_size_adaptation(self):
if self.user_inter_in_one_batch:
@@ -111,7 +115,7 @@ def _batch_size_adaptation(self):
else:
batch_num = max(self.batch_size // self.times, 1)
new_batch_size = batch_num * self.times
- self.step = batch_num if self.real_time else new_batch_size
+ self.step = batch_num
self.upgrade_batch_size(new_batch_size)
@property
@@ -129,53 +133,44 @@ def _shuffle(self):
def _next_batch_data(self):
if self.user_inter_in_one_batch:
- sampling_func = self._neg_sampling if self.real_time else (lambda x: x)
- cur_data = []
- for uid in self.uid_list[self.pr: self.pr + self.step]:
+ uid_list = self.uid_list[self.pr: self.pr + self.step]
+ data_list = []
+ for uid in uid_list:
index = self.uid2index[uid]
- cur_data.append(sampling_func(self.dataset[index]))
- cur_data = pd.concat(cur_data, ignore_index=True)
- pos_len_list = self.uid2items_num[self.uid_list[self.pr: self.pr + self.step]]
+ data_list.append(self._neg_sampling(self.dataset[index]))
+ cur_data = cat_interactions(data_list)
+ pos_len_list = self.uid2items_num[uid_list]
user_len_list = pos_len_list * self.times
+ cur_data.set_additional_info(list(pos_len_list), list(user_len_list))
self.pr += self.step
- return self._dataframe_to_interaction(cur_data, list(pos_len_list), list(user_len_list))
+ return cur_data
else:
- cur_data = self.dataset[self.pr: self.pr + self.step]
+ cur_data = self._neg_sampling(self.dataset[self.pr: self.pr + self.step])
self.pr += self.step
- if self.real_time:
- cur_data = self._neg_sampling(cur_data)
- return self._dataframe_to_interaction(cur_data)
+ return cur_data
def _neg_sampling(self, inter_feat):
- uid_field = self.config['USER_ID_FIELD']
- iid_field = self.config['ITEM_ID_FIELD']
- uids = inter_feat[uid_field].to_list()
+ uids = inter_feat[self.uid_field]
neg_iids = self.sampler.sample_by_user_ids(uids, self.neg_sample_by)
- return self.sampling_func(uid_field, iid_field, neg_iids, inter_feat)
-
- def _neg_sample_by_pair_wise_sampling(self, uid_field, iid_field, neg_iids, inter_feat):
- inter_feat = pd.concat([inter_feat] * self.times, ignore_index=True)
- inter_feat.insert(len(inter_feat.columns), self.neg_item_id, neg_iids)
-
- if self.dataset.item_feat is not None:
- neg_prefix = self.config['NEG_PREFIX']
- neg_item_feat = self.dataset.item_feat.add_prefix(neg_prefix)
- inter_feat = pd.merge(inter_feat, neg_item_feat,
- on=self.neg_item_id, how='left', suffixes=('_inter', '_item'))
-
+ return self.sampling_func(inter_feat, neg_iids)
+
+ def _neg_sample_by_pair_wise_sampling(self, inter_feat, neg_iids):
+ inter_feat = inter_feat.repeat(self.times)
+ neg_item_feat = Interaction({self.iid_field: neg_iids})
+ neg_item_feat = self.dataset.join(neg_item_feat)
+ neg_item_feat.add_prefix(self.neg_prefix)
+ inter_feat.update(neg_item_feat)
return inter_feat
- def _neg_sample_by_point_wise_sampling(self, uid_field, iid_field, neg_iids, inter_feat):
+ def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_iids):
pos_inter_num = len(inter_feat)
-
- new_df = pd.concat([inter_feat] * self.times, ignore_index=True)
- new_df[iid_field].values[pos_inter_num:] = neg_iids
-
- labels = np.zeros(pos_inter_num * self.times, dtype=np.int64)
- labels[: pos_inter_num] = 1
- new_df[self.label_field] = labels
-
- return new_df
+ new_data = inter_feat.repeat(self.times)
+ new_data[self.iid_field][pos_inter_num:] = neg_iids
+ new_data = self.dataset.join(new_data)
+ labels = torch.zeros(pos_inter_num * self.times)
+ labels[: pos_inter_num] = 1.0
+ new_data.update(Interaction({self.label_field: labels}))
+ return new_data
def get_pos_len_list(self):
"""
@@ -227,7 +222,7 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
last_uid = None
positive_item = None
uid2used_item = sampler.used_ids
- for uid, iid in dataset.inter_feat[[uid_field, iid_field]].values:
+ for uid, iid in zip(dataset.inter_feat[uid_field].numpy(), dataset.inter_feat[iid_field].numpy()):
if uid != last_uid:
if last_uid is not None:
self._set_user_property(last_uid, uid2used_item[last_uid], positive_item)
@@ -236,7 +231,8 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
positive_item = set()
positive_item.add(iid)
self._set_user_property(last_uid, uid2used_item[last_uid], positive_item)
- self.user_df = dataset.join(pd.DataFrame(self.uid_list, columns=[uid_field]))
+ self.uid_list = torch.tensor(self.uid_list)
+ self.user_df = dataset.join(Interaction({uid_field: self.uid_list}))
super().__init__(config, dataset, sampler, neg_sample_args,
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
@@ -248,10 +244,7 @@ def _set_user_property(self, uid, used_item, positive_item):
swap_idx = torch.tensor(sorted(set(range(positive_item_num)) ^ positive_item))
self.uid2swap_idx[uid] = swap_idx
self.uid2rev_swap_idx[uid] = swap_idx.flip(0)
- self.uid2history_item[uid] = torch.tensor(list(history_item))
-
- def data_preprocess(self):
- pass
+ self.uid2history_item[uid] = torch.tensor(list(history_item), dtype=torch.int64)
def _batch_size_adaptation(self):
batch_num = max(self.batch_size // self.dataset.item_num, 1)
@@ -267,13 +260,17 @@ def _shuffle(self):
self.logger.warnning('GeneralFullDataLoader can\'t shuffle')
def _next_batch_data(self):
- cur_data = self._neg_sampling(self.user_df[self.pr: self.pr + self.step])
+ index = slice(self.pr, self.pr + self.step)
+ user_df = self.user_df[index]
+ pos_len_list = self.uid2items_num[self.uid_list[index]]
+ user_len_list = np.full(len(user_df), self.item_num)
+ user_df.set_additional_info(pos_len_list, user_len_list)
+ cur_data = self._neg_sampling(user_df)
self.pr += self.step
return cur_data
def _neg_sampling(self, user_df):
- uid_list = user_df[self.dataset.uid_field].values
- user_interaction = self._dataframe_to_interaction(user_df)
+ uid_list = list(user_df[self.dataset.uid_field])
history_item = self.uid2history_item[uid_list]
history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
@@ -284,7 +281,7 @@ def _neg_sampling(self, user_df):
swap_row = torch.cat([torch.full_like(swap, i) for i, swap in enumerate(swap_idx)])
swap_col_after = torch.cat(list(swap_idx))
swap_col_before = torch.cat(list(rev_swap_idx))
- return user_interaction, (history_row, history_col), swap_row, swap_col_after, swap_col_before
+ return user_df, (history_row, history_col), swap_row, swap_col_after, swap_col_before
def get_pos_len_list(self):
"""
diff --git a/recbole/data/dataloader/knowledge_dataloader.py b/recbole/data/dataloader/knowledge_dataloader.py
index efbd2c10c..0767eb900 100644
--- a/recbole/data/dataloader/knowledge_dataloader.py
+++ b/recbole/data/dataloader/knowledge_dataloader.py
@@ -14,6 +14,7 @@
from recbole.data.dataloader import AbstractDataLoader, GeneralNegSampleDataLoader
from recbole.utils import InputType, KGDataLoaderState
+from recbole.data.interaction import Interaction
class KGDataLoader(AbstractDataLoader):
@@ -63,24 +64,17 @@ def pr_end(self):
return len(self.dataset.kg_feat)
def _shuffle(self):
- self.dataset.kg_feat = self.dataset.kg_feat.sample(frac=1).reset_index(drop=True)
+ self.dataset.kg_feat.shuffle()
def _next_batch_data(self):
- cur_data = self.dataset.kg_feat[self.pr: self.pr + self.step]
+ cur_data = self._neg_sampling(self.dataset.kg_feat[self.pr: self.pr + self.step])
self.pr += self.step
- if self.real_time:
- cur_data = self._neg_sampling(cur_data)
- return self._dataframe_to_interaction(cur_data)
-
- def data_preprocess(self):
- """Do neg-sampling before training/evaluation.
- """
- self.dataset.kg_feat = self._neg_sampling(self.dataset.kg_feat)
+ return cur_data
def _neg_sampling(self, kg_feat):
- hids = kg_feat[self.hid_field].to_list()
+ hids = kg_feat[self.hid_field]
neg_tids = self.sampler.sample_by_entity_ids(hids, self.neg_sample_num)
- kg_feat.insert(len(kg_feat.columns), self.neg_tid_field, neg_tids)
+ kg_feat.update(Interaction({self.neg_tid_field: neg_tids}))
return kg_feat
@@ -129,63 +123,51 @@ def __init__(self, config, dataset, sampler, kg_sampler, neg_sample_args,
# using kg_sampler and dl_format is pairwise
self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler,
- batch_size=batch_size, dl_format=InputType.PAIRWISE, shuffle=shuffle)
+ batch_size=batch_size, dl_format=InputType.PAIRWISE, shuffle=True)
- self.main_dataloader = self.general_dataloader
+ self.state = None
super().__init__(config, dataset,
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
- @property
- def pr(self):
- """Pointer of :class:`KnowledgeBasedDataLoader`. It would be affect by self.state.
- """
- return self.main_dataloader.pr
-
- @pr.setter
- def pr(self, value):
- self.main_dataloader.pr = value
-
def __iter__(self):
- if not hasattr(self, 'state') or not hasattr(self, 'main_dataloader'):
- raise ValueError('The dataloader\'s state and main_dataloader must be set '
- 'when using the kg based dataloader')
- return super().__iter__()
+ if self.state is None:
+ raise ValueError('The dataloader\'s state must be set when using the kg based dataloader, '
+ 'you should call set_mode() before __iter__()')
+ if self.state == KGDataLoaderState.KG:
+ return self.kg_dataloader.__iter__()
+ elif self.state == KGDataLoaderState.RS:
+ return self.general_dataloader.__iter__()
+ elif self.state == KGDataLoaderState.RSKG:
+ self.kg_dataloader.__iter__()
+ self.general_dataloader.__iter__()
+ return self
def _shuffle(self):
- if self.state == KGDataLoaderState.RSKG:
- self.general_dataloader._shuffle()
- self.kg_dataloader._shuffle()
- else:
- self.main_dataloader._shuffle()
+ pass
def __next__(self):
- if self.pr >= self.pr_end:
- if self.state == KGDataLoaderState.RSKG:
- self.general_dataloader.pr = 0
- self.kg_dataloader.pr = 0
- else:
- self.pr = 0
+ if self.general_dataloader.pr >= self.general_dataloader.pr_end:
+ self.general_dataloader.pr = 0
+ self.kg_dataloader.pr = 0
raise StopIteration()
return self._next_batch_data()
def __len__(self):
- return len(self.main_dataloader)
+ return len(self.general_dataloader)
@property
def pr_end(self):
- return self.main_dataloader.pr_end
+ return self.general_dataloader.pr_end
def _next_batch_data(self):
- if self.state == KGDataLoaderState.KG:
- return self.kg_dataloader._next_batch_data()
- elif self.state == KGDataLoaderState.RS:
- return self.general_dataloader._next_batch_data()
- elif self.state == KGDataLoaderState.RSKG:
- kg_data = self.kg_dataloader._next_batch_data()
- rec_data = self.general_dataloader._next_batch_data()
- rec_data.update(kg_data)
- return rec_data
+ try:
+ kg_data = self.kg_dataloader.__next__()
+ except StopIteration:
+ kg_data = self.kg_dataloader.__next__()
+ rec_data = self.general_dataloader.__next__()
+ rec_data.update(kg_data)
+ return rec_data
def set_mode(self, state):
"""Set the mode of :class:`KnowledgeBasedDataLoader`, it can be set to three states:
@@ -202,11 +184,3 @@ def set_mode(self, state):
if state not in set(KGDataLoaderState):
raise NotImplementedError('kg data loader has no state named [{}]'.format(self.state))
self.state = state
- if self.state == KGDataLoaderState.RS:
- self.main_dataloader = self.general_dataloader
- elif self.state == KGDataLoaderState.KG:
- self.main_dataloader = self.kg_dataloader
- else: # RSKG
- kgpr = self.kg_dataloader.pr_end
- rspr = self.general_dataloader.pr_end
- self.main_dataloader = self.general_dataloader if rspr < kgpr else self.kg_dataloader
diff --git a/recbole/data/dataloader/neg_sample_mixin.py b/recbole/data/dataloader/neg_sample_mixin.py
index 4ea52a324..d95ff0465 100644
--- a/recbole/data/dataloader/neg_sample_mixin.py
+++ b/recbole/data/dataloader/neg_sample_mixin.py
@@ -49,11 +49,6 @@ def setup(self):
"""
self._batch_size_adaptation()
- def data_preprocess(self):
- """Do neg-sampling before training/evaluation.
- """
- raise NotImplementedError('Method [data_preprocess] should be implemented.')
-
def _batch_size_adaptation(self):
"""Adjust the batch size to ensure that each positive and negative interaction can be in a batch.
"""
@@ -117,13 +112,13 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
self.times = self.neg_sample_by
self.sampling_func = self._neg_sample_by_pair_wise_sampling
- neg_prefix = config['NEG_PREFIX']
+ self.neg_prefix = config['NEG_PREFIX']
iid_field = config['ITEM_ID_FIELD']
- self.neg_item_id = neg_prefix + iid_field
+ self.neg_item_id = self.neg_prefix + iid_field
columns = [iid_field] if dataset.item_feat is None else dataset.item_feat.columns
for item_feat_col in columns:
- neg_item_feat_col = neg_prefix + item_feat_col
+ neg_item_feat_col = self.neg_prefix + item_feat_col
dataset.copy_field_property(neg_item_feat_col, item_feat_col)
else:
raise ValueError('`neg sampling by` with dl_format [{}] not been implemented'.format(dl_format))
diff --git a/recbole/data/dataloader/sequential_dataloader.py b/recbole/data/dataloader/sequential_dataloader.py
index 8f2e435ee..aaca251f2 100644
--- a/recbole/data/dataloader/sequential_dataloader.py
+++ b/recbole/data/dataloader/sequential_dataloader.py
@@ -15,6 +15,7 @@
import numpy as np
import torch
+from recbole.data.interaction import Interaction, cat_interactions
from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.dataloader.neg_sample_mixin import NegSampleByMixin, NegSampleMixin
from recbole.utils import DataLoaderType, FeatureSource, FeatureType, InputType
@@ -50,22 +51,25 @@ def __init__(self, config, dataset,
self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH']
list_suffix = config['LIST_SUFFIX']
- self.item_list_field = self.iid_field + list_suffix
- self.time_list_field = self.time_field + list_suffix
- self.position_field = config['POSITION_FIELD']
- self.target_iid_field = self.iid_field
- self.target_time_field = self.time_field
- self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD']
+ for field in dataset.inter_feat:
+ if field != self.uid_field:
+ list_field = field + list_suffix
+ setattr(self, f'{field}_list_field', list_field)
+ ftype = dataset.field2type[field]
+
+ if ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]:
+ list_ftype = FeatureType.TOKEN_SEQ
+ else:
+ list_ftype = FeatureType.FLOAT_SEQ
+
+ if ftype in [FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ]:
+ list_len = (self.max_item_list_len, dataset.field2seqlen[field])
+ else:
+ list_len = self.max_item_list_len
- dataset.set_field_property(self.item_list_field, FeatureType.TOKEN_SEQ, FeatureSource.INTERACTION,
- self.max_item_list_len)
- dataset.set_field_property(self.time_list_field, FeatureType.FLOAT_SEQ, FeatureSource.INTERACTION,
- self.max_item_list_len)
- if self.position_field:
- dataset.set_field_property(self.position_field, FeatureType.TOKEN_SEQ, FeatureSource.INTERACTION,
- self.max_item_list_len)
- dataset.set_field_property(self.target_iid_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1)
- dataset.set_field_property(self.target_time_field, FeatureType.FLOAT, FeatureSource.INTERACTION, 1)
+ dataset.set_field_property(list_field, list_ftype, FeatureSource.INTERACTION, list_len)
+
+ self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD']
dataset.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1)
self.uid_list, self.item_list_index, self.target_index, self.item_list_length = \
@@ -78,45 +82,40 @@ def __init__(self, config, dataset,
def data_preprocess(self):
"""Do data augmentation before training/evaluation.
"""
- self.pre_processed_data = self.augmentation(self.uid_list, self.item_list_field,
- self.target_index, self.item_list_length)
+ self.pre_processed_data = self.augmentation(self.item_list_index, self.target_index, self.item_list_length)
@property
def pr_end(self):
return len(self.uid_list)
def _shuffle(self):
- new_index = np.random.permutation(len(self.item_list_index))
if self.real_time:
+ new_index = torch.randperm(self.pr_end)
self.uid_list = self.uid_list[new_index]
self.item_list_index = self.item_list_index[new_index]
self.target_index = self.target_index[new_index]
self.item_list_length = self.item_list_length[new_index]
else:
- new_data = {}
- for key, value in self.pre_processed_data.items():
- new_data[key] = value[new_index]
- self.pre_processed_data = new_data
+ self.pre_processed_data.shuffle()
def _next_batch_data(self):
- cur_index = slice(self.pr, self.pr + self.step)
+ cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step))
+ self.pr += self.step
+ return cur_data
+
+ def _get_processed_data(self, index):
if self.real_time:
- cur_data = self.augmentation(self.uid_list[cur_index],
- self.item_list_index[cur_index],
- self.target_index[cur_index],
- self.item_list_length[cur_index])
+ cur_data = self.augmentation(self.item_list_index[index],
+ self.target_index[index],
+ self.item_list_length[index])
else:
- cur_data = {}
- for key, value in self.pre_processed_data.items():
- cur_data[key] = value[cur_index]
- self.pr += self.step
- return self._dict_to_interaction(cur_data)
+ cur_data = self.pre_processed_data[index]
+ return cur_data
- def augmentation(self, uid_list, item_list_index, target_index, item_list_length):
+ def augmentation(self, item_list_index, target_index, item_list_length):
"""Data augmentation.
Args:
- uid_list (np.ndarray): user id list.
item_list_index (np.ndarray): the index of history items list in interaction.
target_index (np.ndarray): the index of items to be predicted in interaction.
item_list_length (np.ndarray): history list length.
@@ -125,26 +124,26 @@ def augmentation(self, uid_list, item_list_index, target_index, item_list_length
dict: the augmented data.
"""
new_length = len(item_list_index)
+ new_data = self.dataset.inter_feat[target_index]
new_dict = {
- self.uid_field: uid_list,
- self.item_list_field: np.zeros((new_length, self.max_item_list_len), dtype=np.int64),
- self.time_list_field: np.zeros((new_length, self.max_item_list_len), dtype=np.int64),
- self.target_iid_field: self.dataset.inter_feat[self.iid_field][target_index].values,
- self.target_time_field: self.dataset.inter_feat[self.time_field][target_index].values,
- self.item_list_length_field: item_list_length,
+ self.item_list_length_field: torch.tensor(item_list_length),
}
- for field in self.dataset.inter_feat:
- if field != self.iid_field and field != self.time_field:
- new_dict[field] = self.dataset.inter_feat[field][target_index].values
- if self.position_field:
- new_dict[self.position_field] = np.tile(np.arange(self.max_item_list_len), (new_length, 1))
- iid_value = self.dataset.inter_feat[self.iid_field].values
- time_value = self.dataset.inter_feat[self.time_field].values
- for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
- new_dict[self.item_list_field][i][:length] = iid_value[index]
- new_dict[self.time_list_field][i][:length] = time_value[index]
- return new_dict
+ for field in self.dataset.inter_feat:
+ if field != self.uid_field:
+ list_field = getattr(self, f'{field}_list_field')
+ list_len = self.dataset.field2seqlen[list_field]
+ shape = (new_length, list_len) if isinstance(list_len, int) else (new_length, ) + list_len
+ list_ftype = self.dataset.field2type[list_field]
+ dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64
+ new_dict[list_field] = torch.zeros(shape, dtype=dtype)
+
+ value = self.dataset.inter_feat[field]
+ for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
+ new_dict[list_field][i][:length] = value[index]
+
+ new_data.update(Interaction(new_dict))
+ return new_data
class SequentialNegSampleDataLoader(NegSampleByMixin, SequentialDataLoader):
@@ -169,40 +168,23 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
super().__init__(config, dataset, sampler, neg_sample_args,
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
- def data_preprocess(self):
- """Do data augmentation and neg-sampling before training/evaluation.
- """
- self.pre_processed_data = self.augmentation(self.uid_list, self.item_list_field,
- self.target_index, self.item_list_length)
- self.pre_processed_data = self._neg_sampling(self.pre_processed_data)
-
def _batch_size_adaptation(self):
batch_num = max(self.batch_size // self.times, 1)
new_batch_size = batch_num * self.times
- self.step = batch_num if self.real_time else new_batch_size
+ self.step = batch_num
self.upgrade_batch_size(new_batch_size)
def _next_batch_data(self):
- cur_index = slice(self.pr, self.pr + self.step)
- if self.real_time:
- cur_data = self.augmentation(self.uid_list[cur_index],
- self.item_list_index[cur_index],
- self.target_index[cur_index],
- self.item_list_length[cur_index])
- cur_data = self._neg_sampling(cur_data)
- else:
- cur_data = {}
- for key, value in self.pre_processed_data.items():
- cur_data[key] = value[cur_index]
+ cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step))
+ cur_data = self._neg_sampling(cur_data)
self.pr += self.step
if self.user_inter_in_one_batch:
cur_data_len = len(cur_data[self.uid_field])
pos_len_list = np.ones(cur_data_len // self.times, dtype=np.int64)
user_len_list = pos_len_list * self.times
- return self._dict_to_interaction(cur_data, list(pos_len_list), list(user_len_list))
- else:
- return self._dict_to_interaction(cur_data)
+ cur_data.set_additional_info(list(pos_len_list), list(user_len_list))
+ return cur_data
def _neg_sampling(self, data):
if self.user_inter_in_one_batch:
@@ -211,31 +193,26 @@ def _neg_sampling(self, data):
for i in range(data_len):
uids = data[self.uid_field][i: i + 1]
neg_iids = self.sampler.sample_by_user_ids(uids, self.neg_sample_by)
- cur_data = {field: data[field][i: i + 1] for field in data}
+ cur_data = data[i: i + 1]
data_list.append(self.sampling_func(cur_data, neg_iids))
- return {field: np.concatenate([d[field] for d in data_list])
- for field in data}
+ return cat_interactions(data_list)
else:
uids = data[self.uid_field]
neg_iids = self.sampler.sample_by_user_ids(uids, self.neg_sample_by)
return self.sampling_func(data, neg_iids)
def _neg_sample_by_pair_wise_sampling(self, data, neg_iids):
- new_data = {key: np.concatenate([value] * self.times) for key, value in data.items()}
- new_data[self.neg_item_id] = neg_iids
+ new_data = data.repeat(self.times)
+ new_data.update(Interaction({self.neg_item_id: neg_iids}))
return new_data
def _neg_sample_by_point_wise_sampling(self, data, neg_iids):
- new_data = {}
- for key, value in data.items():
- if key == self.target_iid_field:
- new_data[key] = np.concatenate([value, neg_iids])
- else:
- new_data[key] = np.concatenate([value] * self.times)
- pos_len = len(data[self.target_iid_field])
- total_len = len(new_data[self.target_iid_field])
- new_data[self.label_field] = np.zeros(total_len, dtype=np.int)
- new_data[self.label_field][:pos_len] = 1
+ pos_inter_num = len(data)
+ new_data = data.repeat(self.times)
+ new_data[self.iid_field][pos_inter_num:] = neg_iids
+ labels = torch.zeros(pos_inter_num * self.times)
+ labels[: pos_inter_num] = 1.0
+ new_data.update(Interaction({self.label_field: labels}))
return new_data
def get_pos_len_list(self):
@@ -275,9 +252,6 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
super().__init__(config, dataset, sampler, neg_sample_args,
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)
- def data_preprocess(self):
- pass
-
def _batch_size_adaptation(self):
pass
@@ -290,9 +264,12 @@ def _shuffle(self):
def _next_batch_data(self):
interaction = super()._next_batch_data()
inter_num = len(interaction)
+ pos_len_list = np.ones(inter_num, dtype=np.int64)
+ user_len_list = np.full(inter_num, self.item_num)
+ interaction.set_additional_info(pos_len_list, user_len_list)
scores_row = torch.arange(inter_num).repeat(2)
padding_idx = torch.zeros(inter_num, dtype=torch.int64)
- positive_idx = interaction[self.target_iid_field]
+ positive_idx = interaction[self.iid_field]
scores_col_after = torch.cat((padding_idx, positive_idx))
scores_col_before = torch.cat((positive_idx, padding_idx))
return interaction, None, scores_row, scores_col_after, scores_col_before
@@ -309,4 +286,4 @@ def get_user_len_list(self):
Returns:
np.ndarray: Number of all item for each user in a training/evaluating epoch.
"""
- return np.full(len(self.uid_list), self.item_num)
+ return np.full(self.pr_end, self.item_num)
diff --git a/recbole/data/dataloader/user_dataloader.py b/recbole/data/dataloader/user_dataloader.py
index 73d92aa51..9603ce7cb 100644
--- a/recbole/data/dataloader/user_dataloader.py
+++ b/recbole/data/dataloader/user_dataloader.py
@@ -53,9 +53,9 @@ def pr_end(self):
return len(self.dataset.user_feat)
def _shuffle(self):
- self.dataset.user_feat = self.dataset.user_feat.sample(frac=1).reset_index(drop=True)
+ self.dataset.user_feat.shuffle()
def _next_batch_data(self):
- cur_data = self.dataset.user_feat[[self.uid_field]][self.pr: self.pr + self.step]
+ cur_data = self.dataset.user_feat[self.pr: self.pr + self.step]
self.pr += self.step
- return self._dataframe_to_interaction(cur_data)
+ return cur_data
diff --git a/recbole/data/dataloader/xgboost_dataloader.py b/recbole/data/dataloader/xgboost_dataloader.py
new file mode 100644
index 000000000..dbe897934
--- /dev/null
+++ b/recbole/data/dataloader/xgboost_dataloader.py
@@ -0,0 +1,38 @@
+# @Time : 2020/11/19
+# @Author : Chen Yang
+# @Email : 254170321@qq.com
+
+# UPDATE:
+# @Time : 2020/11/19
+# @Author : Chen Yang
+# @Email : 254170321@qq.com
+
+"""
+recbole.data.dataloader.xgboost_dataloader
+################################################
+"""
+
+from recbole.data.dataloader.general_dataloader import GeneralDataLoader, GeneralNegSampleDataLoader, GeneralFullDataLoader
+
+
+class XgboostDataLoader(GeneralDataLoader):
+ """:class:`XgboostDataLoader` is inherit from :class:`~recbole.data.dataloader.general_dataloader.GeneralDataLoader`,
+ and didn't add/change anything at all.
+ """
+ pass
+
+
+class XgboostNegSampleDataLoader(GeneralNegSampleDataLoader):
+ """:class:`XgboostNegSampleDataLoader` is inherit from
+ :class:`~recbole.data.dataloader.general_dataloader.GeneralNegSampleDataLoader`,
+ and didn't add/change anything at all.
+ """
+ pass
+
+
+class XgboostFullDataLoader(GeneralFullDataLoader):
+ """:class:`XgboostFullDataLoader` is inherit from
+ :class:`~recbole.data.dataloader.general_dataloader.GeneralFullDataLoader`,
+ and didn't add/change anything at all.
+ """
+ pass
diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py
index dc14a1de7..78eb2fc63 100644
--- a/recbole/data/dataset/dataset.py
+++ b/recbole/data/dataset/dataset.py
@@ -23,7 +23,6 @@
import torch
import torch.nn.utils.rnn as rnn_utils
from scipy.sparse import coo_matrix
-from sklearn.impute import SimpleImputer
from recbole.utils import FeatureSource, FeatureType
from recbole.data.interaction import Interaction
@@ -78,16 +77,16 @@ class Dataset(object):
time_field (str or None): The same as ``config['TIME_FIELD']``.
- inter_feat (:class:`pandas.DataFrame`): Internal data structure stores the interaction features.
+ inter_feat (:class:`Interaction`): Internal data structure stores the interaction features.
It's loaded from file ``.inter``.
- user_feat (:class:`pandas.DataFrame` or None): Internal data structure stores the user features.
+ user_feat (:class:`Interaction` or None): Internal data structure stores the user features.
It's loaded from file ``.user`` if existed.
- item_feat (:class:`pandas.DataFrame` or None): Internal data structure stores the item features.
+ item_feat (:class:`Interaction` or None): Internal data structure stores the item features.
It's loaded from file ``.item`` if existed.
- feat_list (list): A list contains all the features (:class:`pandas.DataFrame`), including additional features.
+ feat_name_list (list): A list contains all the features' name (:class:`str`), including additional features.
"""
def __init__(self, config, saved_dataset=None):
self.config = config
@@ -111,12 +110,12 @@ def _from_scratch(self):
self._get_field_from_config()
self._load_data(self.dataset_name, self.dataset_path)
self._data_processing()
+ self._change_feat_format()
def _get_preset(self):
"""Initialization useful inside attributes.
"""
self.dataset_path = self.config['data_path']
- self._fill_nan_flag = self.config['fill_nan']
self.field2type = {}
self.field2source = {}
@@ -134,6 +133,10 @@ def _get_field_from_config(self):
self.label_field = self.config['LABEL_FIELD']
self.time_field = self.config['TIME_FIELD']
+ if (self.uid_field is None) ^ (self.iid_field is None):
+ raise ValueError('USER_ID_FIELD and ITEM_ID_FIELD need to be set at the same time '
+ 'or not set at the same time.')
+
self.logger.debug('uid_field: {}'.format(self.uid_field))
self.logger.debug('iid_field: {}'.format(self.iid_field))
@@ -147,7 +150,7 @@ def _data_processing(self):
- Normalization
- Preloading weights initialization
"""
- self.feat_list = self._build_feat_list()
+ self.feat_name_list = self._build_feat_name_list()
if self.benchmark_filename_list is None:
self._data_filtering()
@@ -175,23 +178,24 @@ def _data_filtering(self):
self._filter_by_inter_num()
self._reset_index()
- def _build_feat_list(self):
+ def _build_feat_name_list(self):
"""Feat list building.
- Any feat loaded by Dataset can be found in ``feat_list``
+ Any feat loaded by Dataset can be found in ``feat_name_list``
Returns:
- builded feature list.
+ built feature name list.
Note:
Subclasses can inherit this method to add new feat.
"""
- feat_list = [feat for feat in [self.inter_feat, self.user_feat, self.item_feat] if feat is not None]
+ feat_name_list = [feat_name for feat_name in ['inter_feat', 'user_feat', 'item_feat']
+ if getattr(self, feat_name, None) is not None]
if self.config['additional_feat_suffix'] is not None:
for suf in self.config['additional_feat_suffix']:
- if hasattr(self, '{}_feat'.format(suf)):
- feat_list.append(getattr(self, '{}_feat'.format(suf)))
- return feat_list
+ if getattr(self, '{}_feat'.format(suf), None) is not None:
+ feat_name_list.append('{}_feat'.format(suf))
+ return feat_name_list
def _restore_saved_dataset(self, saved_dataset):
"""Restore saved dataset from ``saved_dataset``.
@@ -310,7 +314,7 @@ def _load_additional_feat(self, token, dataset_path):
For those additional features, e.g. pretrained entity embedding, user can set them
as ``config['additional_feat_suffix']``, then they will be loaded and stored in
- :attr:`feat_list`. See :doc:`../user_guide/data/data_args` for details.
+ :attr:`feat_name_list`. See :doc:`../user_guide/data/data_args` for details.
Args:
token (str): dataset name.
@@ -431,24 +435,16 @@ def _load_feat(self, filepath, source):
def _user_item_feat_preparation(self):
"""Sort :attr:`user_feat` and :attr:`item_feat` by ``user_id`` or ``item_id``.
- Missing values will be filled.
+ Missing values will be filled later.
"""
- flag = False
if self.user_feat is not None:
new_user_df = pd.DataFrame({self.uid_field: np.arange(self.user_num)})
self.user_feat = pd.merge(new_user_df, self.user_feat, on=self.uid_field, how='left')
- flag = True
self.logger.debug('ordering user features by user id.')
if self.item_feat is not None:
new_item_df = pd.DataFrame({self.iid_field: np.arange(self.item_num)})
self.item_feat = pd.merge(new_item_df, self.item_feat, on=self.iid_field, how='left')
- flag = True
self.logger.debug('ordering item features by user id.')
- if flag:
- # CANNOT be removed
- # user/item feat has been updated, thus feat_list should be updated too.
- self.feat_list = self._build_feat_list()
- self._fill_nan_flag = True
def _preload_weight_matrix(self):
"""Transfer preload weight features into :class:`numpy.ndarray` with shape ``[id_token_length]``
@@ -457,11 +453,8 @@ def _preload_weight_matrix(self):
preload_fields = self.config['preload_weight']
if preload_fields is None:
return
- drop_flag = self.config['drop_preload_weight']
- if drop_flag is None:
- drop_flag = True
- self.logger.debug('preload weight matrix for {}, drop=[{}]'.format(preload_fields, drop_flag))
+ self.logger.debug('preload weight matrix for {}'.format(preload_fields))
for preload_id_field in preload_fields:
preload_value_field = preload_fields[preload_id_field]
@@ -476,7 +469,8 @@ def _preload_weight_matrix(self):
'while prelaod value field [{}] is from source [{}], which should be the same'.format(
preload_id_field, pid_source, preload_value_field, pv_source
))
- for feat in self.feat_list:
+ for feat_name in self.feat_name_list:
+ feat = getattr(self, feat_name)
if preload_id_field in feat:
id_ftype = self.field2type[preload_id_field]
if id_ftype != FeatureType.TOKEN:
@@ -508,9 +502,6 @@ def _preload_weight_matrix(self):
value_ftype))
continue
self._preloaded_weight[preload_id_field] = matrix
- if drop_flag:
- self._del_col(preload_id_field)
- self._del_col(preload_value_field)
def _fill_nan(self):
"""Missing value imputation.
@@ -521,23 +512,18 @@ def _fill_nan(self):
For fields with type :obj:`~recbole.utils.enum_type.FeatureType.FLOAT`, missing value will be filled by
the average of original data.
- For sequence features, missing value will be filled by ``[0]``.
+ For sequence features, missing value will be filled by ``[0]``.
"""
self.logger.debug('Filling nan')
- if not self._fill_nan_flag:
- return
-
- most_freq = SimpleImputer(missing_values=np.nan, strategy='most_frequent', copy=False)
- aveg = SimpleImputer(missing_values=np.nan, strategy='mean', copy=False)
-
- for feat in self.feat_list:
+ for feat_name in self.feat_name_list:
+ feat = getattr(self, feat_name)
for field in feat:
ftype = self.field2type[field]
if ftype == FeatureType.TOKEN:
- feat[field] = most_freq.fit_transform(feat[field].values.reshape(-1, 1))
+ feat[field].fillna(value=0, inplace=True)
elif ftype == FeatureType.FLOAT:
- feat[field] = aveg.fit_transform(feat[field].values.reshape(-1, 1))
+ feat[field].fillna(value=feat[field].mean(), inplace=True)
elif ftype.value.endswith('seq'):
feat[field] = feat[field].apply(lambda x: [0]
if (not isinstance(x, np.ndarray) and (not isinstance(x, list)))
@@ -571,7 +557,8 @@ def _normalize(self):
self.logger.debug('Normalized fields: {}'.format(fields))
- for feat in self.feat_list:
+ for feat_name in self.feat_name_list:
+ feat = getattr(self, feat_name)
for field in feat:
if field not in fields:
continue
@@ -643,16 +630,34 @@ def _filter_by_inter_num(self):
Lower bound is also called k-core filtering, which means this method will filter loops
until all the users and items has at least k interactions.
"""
+ if self.uid_field is None or self.iid_field is None:
+ return
+
+ max_user_inter_num = self.config['max_user_inter_num']
+ min_user_inter_num = self.config['min_user_inter_num']
+ max_item_inter_num = self.config['max_item_inter_num']
+ min_item_inter_num = self.config['min_item_inter_num']
+
+ if max_user_inter_num is None and min_user_inter_num is None:
+ user_inter_num = Counter()
+ else:
+ user_inter_num = Counter(self.inter_feat[self.uid_field].values)
+
+ if max_item_inter_num is None and min_item_inter_num is None:
+ item_inter_num = Counter()
+ else:
+ item_inter_num = Counter(self.inter_feat[self.iid_field].values)
+
while True:
ban_users = self._get_illegal_ids_by_inter_num(field=self.uid_field, feat=self.user_feat,
- max_num=self.config['max_user_inter_num'],
- min_num=self.config['min_user_inter_num'])
+ inter_num=user_inter_num,
+ max_num=max_user_inter_num, min_num=min_user_inter_num)
ban_items = self._get_illegal_ids_by_inter_num(field=self.iid_field, feat=self.item_feat,
- max_num=self.config['max_item_inter_num'],
- min_num=self.config['min_item_inter_num'])
+ inter_num=item_inter_num,
+ max_num=max_item_inter_num, min_num=min_item_inter_num)
if len(ban_users) == 0 and len(ban_items) == 0:
- return
+ break
if self.user_feat is not None:
dropped_user = self.user_feat[self.uid_field].isin(ban_users)
@@ -663,19 +668,25 @@ def _filter_by_inter_num(self):
self.item_feat.drop(self.item_feat.index[dropped_item], inplace=True)
dropped_inter = pd.Series(False, index=self.inter_feat.index)
- if self.uid_field:
- dropped_inter |= self.inter_feat[self.uid_field].isin(ban_users)
- if self.iid_field:
- dropped_inter |= self.inter_feat[self.iid_field].isin(ban_items)
- self.logger.debug('[{}] dropped interactions'.format(len(dropped_inter)))
- self.inter_feat.drop(self.inter_feat.index[dropped_inter], inplace=True)
-
- def _get_illegal_ids_by_inter_num(self, field, feat, max_num=None, min_num=None):
+ user_inter = self.inter_feat[self.uid_field]
+ item_inter = self.inter_feat[self.iid_field]
+ dropped_inter |= user_inter.isin(ban_users)
+ dropped_inter |= item_inter.isin(ban_items)
+
+ user_inter_num -= Counter(user_inter[dropped_inter].values)
+ item_inter_num -= Counter(item_inter[dropped_inter].values)
+
+ dropped_index = self.inter_feat.index[dropped_inter]
+ self.logger.debug('[{}] dropped interactions'.format(len(dropped_index)))
+ self.inter_feat.drop(dropped_index, inplace=True)
+
+ def _get_illegal_ids_by_inter_num(self, field, feat, inter_num, max_num=None, min_num=None):
"""Given inter feat, return illegal ids, whose inter num out of [min_num, max_num]
Args:
field (str): field name of user_id or item_id.
feat (pandas.DataFrame): interaction feature.
+ inter_num (Counter): interaction number counter.
max_num (int, optional): max number of interaction. Defaults to ``None``.
min_num (int, optional): min number of interaction. Defaults to ``None``.
@@ -686,16 +697,9 @@ def _get_illegal_ids_by_inter_num(self, field, feat, max_num=None, min_num=None)
field, max_num, min_num
))
- if field is None:
- return set()
- if max_num is None and min_num is None:
- return set()
-
max_num = max_num or np.inf
min_num = min_num or -1
- ids = self.inter_feat[field].values
- inter_num = Counter(ids)
ids = {id_ for id_ in inter_num if inter_num[id_] < min_num or inter_num[id_] > max_num}
if feat is not None:
@@ -716,14 +720,12 @@ def _filter_by_field_value(self):
if not filter_field:
return
- if self.config['drop_filter_field']:
- for field in set(filter_field):
- self._del_col(field)
def _reset_index(self):
- """Reset index for all feats in :attr:`feat_list`.
+ """Reset index for all feats in :attr:`feat_name_list`.
"""
- for feat in self.feat_list:
+ for feat_name in self.feat_name_list:
+ feat = getattr(self, feat_name)
if feat.empty:
raise ValueError('Some feat is empty, please check the filtering settings.')
feat.reset_index(drop=True, inplace=True)
@@ -748,23 +750,26 @@ def _drop_by_value(self, val, cmp):
raise ValueError('field [{}] not defined in dataset'.format(field))
if self.field2type[field] not in {FeatureType.FLOAT, FeatureType.FLOAT_SEQ}:
raise ValueError('field [{}] is not float-like field in dataset, which can\'t be filter'.format(field))
- for feat in self.feat_list:
+ for feat_name in self.feat_name_list:
+ feat = getattr(self, feat_name)
if field in feat:
feat.drop(feat.index[cmp(feat[field].values, val[field])], inplace=True)
filter_field.append(field)
return filter_field
- def _del_col(self, field):
+ def _del_col(self, feat, field):
"""Delete columns
Args:
- field (str): field name to be droped.
+ feat (pandas.DataFrame or Interaction): the feat contains field.
+ field (str): field name to be dropped.
"""
self.logger.debug('delete column [{}]'.format(field))
- for feat in self.feat_list:
- if field in feat:
- feat.drop(columns=field, inplace=True)
- for dct in [self.field2id_token, self.field2seqlen, self.field2source, self.field2type]:
+ if isinstance(feat, Interaction):
+ feat.drop(column=field)
+ else:
+ feat.drop(columns=field, inplace=True)
+ for dct in [self.field2id_token, self.field2token_id, self.field2seqlen, self.field2source, self.field2type]:
if field in dct:
del dct[field]
@@ -794,7 +799,7 @@ def _set_label_by_threshold(self):
self.inter_feat[self.label_field] = (self.inter_feat[field] >= value).astype(int)
else:
raise ValueError('field [{}] not in inter_feat'.format(field))
- self._del_col(field)
+ self._del_col(self.inter_feat, field)
def _get_fields_in_same_space(self):
"""Parsing ``config['fields_in_same_space']``. See :doc:`../user_guide/data/data_args` for detail arg setting.
@@ -916,6 +921,13 @@ def _remap(self, remap_list):
split_point = np.cumsum(feat[field].agg(len))[:-1]
feat[field] = np.split(new_ids, split_point)
+ def _change_feat_format(self):
+ """Change feat format from :class:`pandas.DataFrame` to :class:`Interaction`.
+ """
+ for feat_name in self.feat_name_list:
+ feat = getattr(self, feat_name)
+ setattr(self, feat_name, self._dataframe_to_interaction(feat))
+
@dlapi.set()
def num(self, field):
"""Given ``field``, for token-like fields, return the number of different tokens after remapping,
@@ -1096,7 +1108,7 @@ def avg_actions_of_users(self):
Returns:
numpy.float64: Average number of users' interaction records.
"""
- return np.mean(self.inter_feat.groupby(self.uid_field).size())
+ return np.mean(list(Counter(self.inter_feat[self.uid_field]).values()))
@property
def avg_actions_of_items(self):
@@ -1105,7 +1117,7 @@ def avg_actions_of_items(self):
Returns:
numpy.float64: Average number of items' interaction records.
"""
- return np.mean(self.inter_feat.groupby(self.iid_field).size())
+ return np.mean(list(Counter(self.inter_feat[self.iid_field]).values()))
@property
def sparsity(self):
@@ -1116,36 +1128,6 @@ def sparsity(self):
"""
return 1 - self.inter_num / self.user_num / self.item_num
- @property
- def uid2index(self):
- """Sort ``self.inter_feat``,
- and get the mapping of user_id and index of its interaction records.
-
- Returns:
- tuple:
- - :class:`numpy.ndarray` of int,
- user id list in interaction records.
- - :class:`numpy.ndarray` of :class:`slice`,
- interaction records between slice are all belong to the same uid, index represent user id.
- - :class:`numpy.ndarray` of int,
- representing number of interaction records of each user, index represent user id.
- """
- self._check_field('uid_field')
- self.sort(by=self.uid_field, ascending=True)
- uid_list = []
- start, end = dict(), dict()
- for i, uid in enumerate(self.inter_feat[self.uid_field].values):
- if uid not in start:
- uid_list.append(uid)
- start[uid] = i
- end[uid] = i
- uid2index = np.array([None] * self.user_num)
- uid2items_num = np.zeros(self.user_num, dtype=np.int64)
- for uid in uid_list:
- uid2index[uid] = slice(start[uid], end[uid] + 1)
- uid2items_num[uid] = end[uid] - start[uid] + 1
- return np.array(uid_list), uid2index, uid2items_num
-
def _check_field(self, *field_names):
"""Given a name of attribute, check if it's exist.
@@ -1160,15 +1142,15 @@ def join(self, df):
"""Given interaction feature, join user/item feature into it.
Args:
- df (pandas.DataFrame): Interaction feature to be joint.
+ df (Interaction): Interaction feature to be joint.
Returns:
- pandas.DataFrame: Interaction feature after joining operation.
+ Interaction: Interaction feature after joining operation.
"""
if self.user_feat is not None and self.uid_field in df:
- df = pd.merge(df, self.user_feat, on=self.uid_field, how='left', suffixes=('_inter', '_user'))
+ df.update(self.user_feat[df[self.uid_field]])
if self.item_feat is not None and self.iid_field in df:
- df = pd.merge(df, self.item_feat, on=self.iid_field, how='left', suffixes=('_inter', '_item'))
+ df.update(self.item_feat[df[self.iid_field]])
return df
def __getitem__(self, index, join=True):
@@ -1200,7 +1182,7 @@ def copy(self, new_inter_feat):
whose interaction feature is updated with ``new_inter_feat``, and all the other attributes the same.
Args:
- new_inter_feat (pandas.DataFrame): The new interaction feature need to be updated.
+ new_inter_feat (Interaction): The new interaction feature need to be updated.
Returns:
:class:`~Dataset`: the new :class:`~Dataset` object, whose interaction feature has been updated.
@@ -1209,6 +1191,31 @@ def copy(self, new_inter_feat):
nxt.inter_feat = new_inter_feat
return nxt
+ def _drop_unused_col(self):
+ """Drop columns which are loaded for data preparation but not used in model.
+ """
+ unused_col = self.config['unused_col']
+ if unused_col is None:
+ return
+
+ for feat_name, unused_fields in unused_col.items():
+ feat = getattr(self, feat_name + '_feat')
+ for field in unused_fields:
+ if field not in feat:
+ self.logger.warning('field [{}] is not in [{}_feat], which can not be set in `unused_col`'.format(
+ field, feat_name))
+ continue
+ self._del_col(feat, field)
+
+ def _grouped_index(self, group_by_list):
+ index = {}
+ for i, key in enumerate(group_by_list):
+ if key not in index:
+ index[key] = [i]
+ else:
+ index[key].append(i)
+ return index.values()
+
def _calcu_split_ids(self, tot, ratios):
"""Given split ratios, and total number, calculate the number of each part after splitting.
@@ -1249,7 +1256,7 @@ def split_by_ratio(self, ratios, group_by=None):
split_ids = self._calcu_split_ids(tot=tot_cnt, ratios=ratios)
next_index = [range(start, end) for start, end in zip([0] + split_ids, split_ids + [tot_cnt])]
else:
- grouped_inter_feat_index = self.inter_feat.groupby(by=group_by).groups.values()
+ grouped_inter_feat_index = self._grouped_index(self.inter_feat[group_by].numpy())
next_index = [[] for i in range(len(ratios))]
for grouped_index in grouped_inter_feat_index:
tot_cnt = len(grouped_index)
@@ -1257,7 +1264,8 @@ def split_by_ratio(self, ratios, group_by=None):
for index, start, end in zip(next_index, [0] + split_ids, split_ids + [tot_cnt]):
index.extend(grouped_index[start: end])
- next_df = [self.inter_feat.loc[index].reset_index(drop=True) for index in next_index]
+ self._drop_unused_col()
+ next_df = [self.inter_feat[index] for index in next_index]
next_ds = [self.copy(_) for _ in next_df]
return next_ds
@@ -1265,7 +1273,7 @@ def _split_index_by_leave_one_out(self, grouped_index, leave_one_num):
"""Split indexes by strategy leave one out.
Args:
- grouped_index (pandas.DataFrameGroupBy): Index to be splitted.
+ grouped_index (list of list of int): Index to be splitted.
leave_one_num (int): Number of parts whose length is expected to be ``1``.
Returns:
@@ -1298,26 +1306,28 @@ def leave_one_out(self, group_by, leave_one_num=1):
if group_by is None:
raise ValueError('leave one out strategy require a group field')
- grouped_inter_feat_index = self.inter_feat.groupby(by=group_by).groups.values()
+ grouped_inter_feat_index = self._grouped_index(self.inter_feat[group_by].numpy())
next_index = self._split_index_by_leave_one_out(grouped_inter_feat_index, leave_one_num)
- next_df = [self.inter_feat.loc[index].reset_index(drop=True) for index in next_index]
+
+ self._drop_unused_col()
+ next_df = [self.inter_feat[index] for index in next_index]
next_ds = [self.copy(_) for _ in next_df]
return next_ds
def shuffle(self):
"""Shuffle the interaction records inplace.
"""
- self.inter_feat = self.inter_feat.sample(frac=1).reset_index(drop=True)
+ self.inter_feat.shuffle()
def sort(self, by, ascending=True):
"""Sort the interaction records inplace.
Args:
- by (str): Field that as the key in the sorting process.
- ascending (bool, optional): Results are ascending if ``True``, otherwise descending.
+ by (str or list of str): Field that as the key in the sorting process.
+ ascending (bool or list of bool, optional): Results are ascending if ``True``, otherwise descending.
Defaults to ``True``
"""
- self.inter_feat.sort_values(by=by, ascending=ascending, inplace=True, ignore_index=True)
+ self.inter_feat.sort(by=by, ascending=ascending)
def build(self, eval_setting):
"""Processing dataset according to evaluation setting, including Group, Order and Split.
@@ -1379,22 +1389,22 @@ def save(self, filepath):
def get_user_feature(self):
"""
Returns:
- pandas.DataFrame: user features
+ Interaction: user features
"""
if self.user_feat is None:
self._check_field('uid_field')
- return pd.DataFrame({self.uid_field: np.arange(self.user_num)})
+ return Interaction({self.uid_field: torch.arange(self.user_num)})
else:
return self.user_feat
def get_item_feature(self):
"""
Returns:
- pandas.DataFrame: item features
+ Interaction: item features
"""
if self.item_feat is None:
self._check_field('iid_field')
- return pd.DataFrame({self.iid_field: np.arange(self.item_num)})
+ return Interaction({self.iid_field: torch.arange(self.item_num)})
else:
return self.item_feat
@@ -1409,7 +1419,7 @@ def _create_sparse_matrix(self, df_feat, source_field, target_field, form='coo',
else ``matrix[src, tgt] = df_feat[value_field][src, tgt]``.
Args:
- df_feat (pandas.DataFrame): Feature where src and tgt exist.
+ df_feat (Interaction): Feature where src and tgt exist.
source_field (str): Source field
target_field (str): Target field
form (str, optional): Sparse matrix format. Defaults to ``coo``.
@@ -1419,14 +1429,14 @@ def _create_sparse_matrix(self, df_feat, source_field, target_field, form='coo',
Returns:
scipy.sparse: Sparse matrix in form ``coo`` or ``csr``.
"""
- src = df_feat[source_field].values
- tgt = df_feat[target_field].values
+ src = df_feat[source_field]
+ tgt = df_feat[target_field]
if value_field is None:
data = np.ones(len(df_feat))
else:
- if value_field not in df_feat.columns:
+ if value_field not in df_feat:
raise ValueError('value_field [{}] should be one of `df_feat`\'s features.'.format(value_field))
- data = df_feat[value_field].values
+ data = df_feat[value_field]
mat = coo_matrix((data, (src, tgt)), shape=(self.num(source_field), self.num(target_field)))
if form == 'coo':
@@ -1436,7 +1446,7 @@ def _create_sparse_matrix(self, df_feat, source_field, target_field, form='coo',
else:
raise NotImplementedError('sparse matrix format [{}] has not been implemented.'.format(form))
- def _create_graph(self, df_feat, source_field, target_field, form='dgl', value_field=None):
+ def _create_graph(self, tensor_feat, source_field, target_field, form='dgl', value_field=None):
"""Get graph that describe relations between two fields.
Source and target should be token-like fields.
@@ -1447,7 +1457,7 @@ def _create_graph(self, df_feat, source_field, target_field, form='dgl', value_f
Currently, we support graph in `DGL`_ and `PyG`_.
Args:
- df_feat (pandas.DataFrame): Feature where src and tgt exist.
+ tensor_feat (Interaction): Feature where src and tgt exist.
source_field (str): Source field
target_field (str): Target field
form (str, optional): Library of graph data structure. Defaults to ``dgl``.
@@ -1463,7 +1473,6 @@ def _create_graph(self, df_feat, source_field, target_field, form='dgl', value_f
.. _PyG:
https://github.com/rusty1s/pytorch_geometric
"""
- tensor_feat = self._dataframe_to_interaction(df_feat)
src = tensor_feat[source_field]
tgt = tensor_feat[target_field]
@@ -1529,13 +1538,13 @@ def _history_matrix(self, row, value_field=None):
"""
self._check_field('uid_field', 'iid_field')
- user_ids, item_ids = self.inter_feat[self.uid_field].values, self.inter_feat[self.iid_field].values
+ user_ids, item_ids = self.inter_feat[self.uid_field].numpy(), self.inter_feat[self.iid_field].numpy()
if value_field is None:
values = np.ones(len(self.inter_feat))
else:
- if value_field not in self.inter_feat.columns:
+ if value_field not in self.inter_feat:
raise ValueError('value_field [{}] should be one of `inter_feat`\'s features.'.format(value_field))
- values = self.inter_feat[value_field].values
+ values = self.inter_feat[value_field].numpy()
if row == 'user':
row_num, max_col_num = self.user_num, self.item_num
@@ -1628,8 +1637,7 @@ def get_preload_weight(self, field):
raise ValueError('field [{}] not in preload_weight'.format(field))
return self._preloaded_weight[field]
- @dlapi.set()
- def _dataframe_to_interaction(self, data, *args):
+ def _dataframe_to_interaction(self, data):
"""Convert :class:`pandas.DataFrame` to :class:`~recbole.data.interaction.Interaction`.
Args:
@@ -1638,37 +1646,18 @@ def _dataframe_to_interaction(self, data, *args):
Returns:
:class:`~recbole.data.interaction.Interaction`: Converted data.
"""
- data = data.to_dict(orient='list')
- return self._dict_to_interaction(data, *args)
-
- @dlapi.set()
- def _dict_to_interaction(self, data, *args):
- """Convert :class:`dict` to :class:`~recbole.data.interaction.Interaction`.
-
- Args:
- data (dict): data to be converted.
-
- Returns:
- :class:`~recbole.data.interaction.Interaction`: Converted data.
- """
+ new_data = {}
for k in data:
+ value = data[k].values
ftype = self.field2type[k]
if ftype == FeatureType.TOKEN:
- data[k] = torch.LongTensor(data[k])
+ new_data[k] = torch.LongTensor(value)
elif ftype == FeatureType.FLOAT:
- data[k] = torch.FloatTensor(data[k])
+ new_data[k] = torch.FloatTensor(value)
elif ftype == FeatureType.TOKEN_SEQ:
- if isinstance(data[k], np.ndarray):
- data[k] = torch.LongTensor(data[k][:, :self.field2seqlen[k]])
- else:
- seq_data = [torch.LongTensor(d[:self.field2seqlen[k]]) for d in data[k]]
- data[k] = rnn_utils.pad_sequence(seq_data, batch_first=True)
+ seq_data = [torch.LongTensor(d[:self.field2seqlen[k]]) for d in value]
+ new_data[k] = rnn_utils.pad_sequence(seq_data, batch_first=True)
elif ftype == FeatureType.FLOAT_SEQ:
- if isinstance(data[k], np.ndarray):
- data[k] = torch.FloatTensor(data[k][:, :self.field2seqlen[k]])
- else:
- seq_data = [torch.FloatTensor(d[:self.field2seqlen[k]]) for d in data[k]]
- data[k] = rnn_utils.pad_sequence(seq_data, batch_first=True)
- else:
- raise ValueError('Illegal ftype [{}]'.format(ftype))
- return Interaction(data, *args)
+ seq_data = [torch.FloatTensor(d[:self.field2seqlen[k]]) for d in value]
+ new_data[k] = rnn_utils.pad_sequence(seq_data, batch_first=True)
+ return Interaction(new_data)
diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py
index 06855c310..ad01d7f38 100644
--- a/recbole/data/dataset/kg_dataset.py
+++ b/recbole/data/dataset/kg_dataset.py
@@ -123,11 +123,11 @@ def __str__(self):
'The number of items that have been linked to KG: {}'.format(len(self.item2entity))]
return '\n'.join(info)
- def _build_feat_list(self):
- feat_list = super()._build_feat_list()
+ def _build_feat_name_list(self):
+ feat_name_list = super()._build_feat_name_list()
if self.kg_feat is not None:
- feat_list.append(self.kg_feat)
- return feat_list
+ feat_name_list.append('kg_feat')
+ return feat_name_list
def _restore_saved_dataset(self, saved_dataset):
raise NotImplementedError()
@@ -382,7 +382,7 @@ def head_entities(self):
Returns:
numpy.ndarray: List of head entities of kg triplets.
"""
- return self.kg_feat[self.head_entity_field].values
+ return self.kg_feat[self.head_entity_field].numpy()
@property
@dlapi.set()
@@ -391,7 +391,7 @@ def tail_entities(self):
Returns:
numpy.ndarray: List of tail entities of kg triplets.
"""
- return self.kg_feat[self.tail_entity_field].values
+ return self.kg_feat[self.tail_entity_field].numpy()
@property
@dlapi.set()
@@ -400,7 +400,7 @@ def relations(self):
Returns:
numpy.ndarray: List of relations of kg triplets.
"""
- return self.kg_feat[self.relation_field].values
+ return self.kg_feat[self.relation_field].numpy()
@property
@dlapi.set()
@@ -447,11 +447,11 @@ def kg_graph(self, form='coo', value_field=None):
def _create_ckg_sparse_matrix(self, form='coo', show_relation=False):
user_num = self.user_num
- hids = self.kg_feat[self.head_entity_field].values + user_num
- tids = self.kg_feat[self.tail_entity_field].values + user_num
+ hids = self.head_entities + user_num
+ tids = self.tail_entities + user_num
- uids = self.inter_feat[self.uid_field].values
- iids = self.inter_feat[self.iid_field].values + user_num
+ uids = self.inter_feat[self.uid_field].numpy()
+ iids = self.inter_feat[self.iid_field].numpy() + user_num
ui_rel_num = len(uids)
ui_rel_id = self.relation_num - 1
@@ -463,7 +463,7 @@ def _create_ckg_sparse_matrix(self, form='coo', show_relation=False):
if not show_relation:
data = np.ones(len(src))
else:
- kg_rel = self.kg_feat[self.relation_field].values
+ kg_rel = self.kg_feat[self.relation_field].numpy()
ui_rel = np.full(2 * ui_rel_num, ui_rel_id, dtype=kg_rel.dtype)
data = np.concatenate([ui_rel, kg_rel])
node_num = self.entity_num + self.user_num
@@ -478,8 +478,8 @@ def _create_ckg_sparse_matrix(self, form='coo', show_relation=False):
def _create_ckg_graph(self, form='dgl', show_relation=False):
user_num = self.user_num
- kg_tensor = self._dataframe_to_interaction(self.kg_feat)
- inter_tensor = self._dataframe_to_interaction(self.inter_feat)
+ kg_tensor = self.kg_feat
+ inter_tensor = self.inter_feat
head_entity = kg_tensor[self.head_entity_field] + user_num
tail_entity = kg_tensor[self.tail_entity_field] + user_num
diff --git a/recbole/data/dataset/sequential_dataset.py b/recbole/data/dataset/sequential_dataset.py
index 4f2af4139..0d78b5644 100644
--- a/recbole/data/dataset/sequential_dataset.py
+++ b/recbole/data/dataset/sequential_dataset.py
@@ -75,7 +75,7 @@ def prepare_data_augmentation(self):
last_uid = None
uid_list, item_list_index, target_index, item_list_length = [], [], [], []
seq_start = 0
- for i, uid in enumerate(self.inter_feat[self.uid_field].values):
+ for i, uid in enumerate(self.inter_feat[self.uid_field].numpy()):
if last_uid != uid:
last_uid = uid
seq_start = i
@@ -99,8 +99,10 @@ def leave_one_out(self, group_by, leave_one_num=1):
raise ValueError('leave one out strategy require a group field')
self.prepare_data_augmentation()
- grouped_index = pd.DataFrame(self.uid_list).groupby(by=0).groups.values()
+ grouped_index = self._grouped_index(self.uid_list)
next_index = self._split_index_by_leave_one_out(grouped_index, leave_one_num)
+
+ self._drop_unused_col()
next_ds = []
for index in next_index:
ds = copy.copy(self)
diff --git a/recbole/data/dataset/social_dataset.py b/recbole/data/dataset/social_dataset.py
index 9cf4a93c9..dcf76f536 100644
--- a/recbole/data/dataset/social_dataset.py
+++ b/recbole/data/dataset/social_dataset.py
@@ -56,11 +56,11 @@ def _load_data(self, token, dataset_path):
super()._load_data(token, dataset_path)
self.net_feat = self._load_net(self.dataset_name, self.dataset_path)
- def _build_feat_list(self):
- feat_list = super()._build_feat_list()
+ def _build_feat_name_list(self):
+ feat_name_list = super()._build_feat_name_list()
if self.net_feat is not None:
- feat_list.append(self.net_feat)
- return feat_list
+ feat_name_list.append('net_feat')
+ return feat_name_list
def _load_net(self, dataset_name, dataset_path):
net_file_path = os.path.join(dataset_path, '{}.{}'.format(dataset_name, 'net'))
diff --git a/recbole/data/interaction.py b/recbole/data/interaction.py
index 1787bd30b..8adfada82 100644
--- a/recbole/data/interaction.py
+++ b/recbole/data/interaction.py
@@ -13,6 +13,7 @@
"""
import numpy as np
+import torch
class Interaction(object):
@@ -81,13 +82,20 @@ class Interaction(object):
def __init__(self, interaction, pos_len_list=None, user_len_list=None):
self.interaction = interaction
+ self.pos_len_list = self.user_len_list = None
+ self.set_additional_info(pos_len_list, user_len_list)
+ for k in self.interaction:
+ if not isinstance(self.interaction[k], torch.Tensor):
+ raise ValueError('interaction [{}] should only contains torch.Tensor'.format(interaction))
+ self.length = -1
+ for k in self.interaction:
+ self.length = max(self.length, self.interaction[k].shape[0])
+
+ def set_additional_info(self, pos_len_list=None, user_len_list=None):
self.pos_len_list = pos_len_list
self.user_len_list = user_len_list
if (self.pos_len_list is None) ^ (self.user_len_list is None):
raise ValueError('pos_len_list and user_len_list should be both None or valued.')
- for k in self.interaction:
- self.length = self.interaction[k].shape[0]
- break
def __iter__(self):
return self.interaction.__iter__()
@@ -101,13 +109,17 @@ def __getitem__(self, index):
ret[k] = self.interaction[k][index]
return Interaction(ret)
+ def __contains__(self, item):
+ return item in self.interaction
+
def __len__(self):
return self.length
def __str__(self):
info = ['The batch_size of interaction: {}'.format(self.length)]
for k in self.interaction:
- temp_str = " {}, {}, {}".format(k, self.interaction[k].shape, self.interaction[k].device.type)
+ inter = self.interaction[k]
+ temp_str = " {}, {}, {}, {}".format(k, inter.shape, inter.device.type, inter.dtype)
info.append(temp_str)
info.append('\n')
return '\n'.join(info)
@@ -115,6 +127,14 @@ def __str__(self):
def __repr__(self):
return self.__str__()
+ @property
+ def columns(self):
+ """
+ Returns:
+ list of str: The columns of interaction.
+ """
+ return list(self.interaction.keys())
+
def to(self, device, selected_field=None):
"""Transfer Tensors in this Interaction object to the specified device.
@@ -214,8 +234,114 @@ def repeat_interleave(self, repeats, dim=0):
def update(self, new_inter):
"""Similar to ``dict.update()``
+
+ Args:
+ new_inter (Interaction): current interaction will be updated by new_inter.
"""
for k in new_inter.interaction:
self.interaction[k] = new_inter.interaction[k]
- self.pos_len_list = new_inter.pos_len_list
- self.user_len_list = new_inter.user_len_list
+ if new_inter.pos_len_list is not None:
+ self.pos_len_list = new_inter.pos_len_list
+ if new_inter.user_len_list is not None:
+ self.user_len_list = new_inter.user_len_list
+
+ def drop(self, column):
+ """Drop column in interaction.
+
+ Args:
+ column (str): the column to be dropped.
+ """
+ if column not in self.interaction:
+ raise ValueError('column [{}] is not in [{}]'.format(column, self))
+ del self.interaction[column]
+
+ def _reindex(self, index):
+ """Reset the index of interaction inplace.
+
+ Args:
+ index: the new index of current interaction.
+ """
+ for k in self.interaction:
+ self.interaction[k] = self.interaction[k][index]
+ if self.pos_len_list is not None:
+ self.pos_len_list = self.pos_len_list[index]
+ if self.user_len_list is not None:
+ self.user_len_list = self.user_len_list[index]
+
+ def shuffle(self):
+ """Shuffle current interaction inplace.
+ """
+ index = torch.randperm(self.length)
+ self._reindex(index)
+
+ def sort(self, by, ascending=True):
+ """Sort the current interaction inplace.
+
+ Args:
+ by (str or list of str): Field that as the key in the sorting process.
+ ascending (bool or list of bool, optional): Results are ascending if ``True``, otherwise descending.
+ Defaults to ``True``
+ """
+ if isinstance(by, str):
+ if by not in self.interaction:
+ raise ValueError('[{}] is not exist in interaction [{}]'.format(by, self))
+ by = [by]
+ elif isinstance(by, (list, tuple)):
+ for b in by:
+ if b not in self.interaction:
+ raise ValueError('[{}] is not exist in interaction [{}]'.format(b, self))
+ else:
+ raise TypeError('wrong type of by [{}]'.format(by))
+
+ if isinstance(ascending, bool):
+ ascending = [ascending]
+ elif isinstance(ascending, (list, tuple)):
+ for a in ascending:
+ if not isinstance(a, bool):
+ raise TypeError('wrong type of ascending [{}]'.format(ascending))
+ else:
+ raise TypeError('wrong type of ascending [{}]'.format(ascending))
+
+ if len(by) != len(ascending):
+ if len(ascending) == 1:
+ ascending = ascending * len(by)
+ else:
+ raise ValueError('by [{}] and ascending [{}] should have same length'.format(by, ascending))
+
+ for b, a in zip(by[::-1], ascending[::-1]):
+ index = np.argsort(self.interaction[b], kind='stable')
+ if not a:
+ index = index[::-1]
+ self._reindex(index)
+
+ def add_prefix(self, prefix):
+ """Add prefix to current interaction's columns.
+
+ Args:
+ prefix (str): The prefix to be added.
+ """
+ self.interaction = {prefix + key: value for key, value in self.interaction.items()}
+
+
+def cat_interactions(interactions):
+ """Concatenate list of interactions to single interaction.
+
+ Args:
+ interactions (list of :class:`Interaction`): List of interactions to be concatenated.
+
+ Returns:
+ :class:`Interaction`: Concatenated interaction.
+ """
+ if not isinstance(interactions, (list, tuple)):
+ raise TypeError('interactions [{}] should be list or tuple'.format(interactions))
+ if len(interactions) == 0:
+ raise ValueError('interactions [{}] should have some interactions'.format(interactions))
+
+ columns_set = set(interactions[0].columns)
+ for inter in interactions:
+ if columns_set != set(inter.columns):
+ raise ValueError('interactions [{}] should have some interactions'.format(interactions))
+
+ new_inter = {col: torch.cat([inter[col] for inter in interactions])
+ for col in columns_set}
+ return Interaction(new_inter)
diff --git a/recbole/data/utils.py b/recbole/data/utils.py
index accf76dfc..f6e649daf 100644
--- a/recbole/data/utils.py
+++ b/recbole/data/utils.py
@@ -264,6 +264,13 @@ def get_data_loader(name, config, eval_setting):
return SequentialNegSampleDataLoader
elif neg_sample_strategy == 'full':
return SequentialFullDataLoader
+ elif model_type == ModelType.XGBOOST:
+ if neg_sample_strategy == 'none':
+ return XgboostDataLoader
+ elif neg_sample_strategy == 'by':
+ return XgboostNegSampleDataLoader
+ elif neg_sample_strategy == 'full':
+ return XgboostFullDataLoader
elif model_type == ModelType.KNOWLEDGE:
if neg_sample_strategy == 'by':
if name == 'train':
diff --git a/recbole/model/exlib_recommender/xgboost.py b/recbole/model/exlib_recommender/xgboost.py
new file mode 100644
index 000000000..345953e38
--- /dev/null
+++ b/recbole/model/exlib_recommender/xgboost.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/11/19
+# @Author : Chen Yang
+# @Email : 254170321@qq.com
+
+r"""
+recbole.model.exlib_recommender.xgboost
+#############################
+"""
+
+import xgboost as xgb
+from recbole.utils import ModelType, InputType, FeatureSource, FeatureType
+
+
+class xgboost(xgb.Booster):
+ r"""xgboost is inherited from xgb.Booster
+
+ """
+ type = ModelType.CONTEXT
+ input_type = InputType.POINTWISE
+
+ def __init__(self, config, dataset):
+ super().__init__(params=None, cache=(), model_file=None)
+
+ def to(self, device):
+ return self
diff --git a/recbole/model/knowledge_aware_recommender/kgnnls.py b/recbole/model/knowledge_aware_recommender/kgnnls.py
index 7907cbe7a..08fef392b 100644
--- a/recbole/model/knowledge_aware_recommender/kgnnls.py
+++ b/recbole/model/knowledge_aware_recommender/kgnnls.py
@@ -62,9 +62,9 @@ def __init__(self, config, dataset):
self.adj_entity, self.adj_relation = adj_entity.to(
self.device), adj_relation.to(self.device)
- inter_feat = dataset.dataset.inter_feat.values
- pos_users = torch.from_numpy(inter_feat[:, 0])
- pos_items = torch.from_numpy(inter_feat[:, 1])
+ inter_feat = dataset.dataset.inter_feat
+ pos_users = inter_feat[dataset.dataset.uid_field]
+ pos_items = inter_feat[dataset.dataset.iid_field]
pos_label = torch.ones(pos_items.shape)
pos_interaction_table, self.offset = self.get_interaction_table(
pos_users, pos_items, pos_label)
diff --git a/recbole/model/sequential_recommender/hgn.py b/recbole/model/sequential_recommender/hgn.py
new file mode 100644
index 000000000..a97c65988
--- /dev/null
+++ b/recbole/model/sequential_recommender/hgn.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/11/21 16:36
+# @Author : Shao Weiqi
+# @Reviewer : Lin Kun
+# @Email : shaoweiqi@ruc.edu.cn
+
+r"""
+HGN
+################################################
+
+Reference:
+ Chen Ma et al. "Hierarchical Gating Networks for Sequential Recommendation."in SIGKDD 2019
+
+
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn.init import xavier_uniform_, constant_, normal_
+from recbole.model.abstract_recommender import SequentialRecommender
+from recbole.model.loss import BPRLoss
+
+
+class HGN(SequentialRecommender):
+ r"""
+ HGN sets feature gating and instance gating to get the important feature and item for predicting the next item
+
+ """
+
+ def __init__(self, config, dataset):
+ super(HGN, self).__init__(config, dataset)
+
+ # load the dataset information
+ self.n_user = dataset.num(self.USER_ID)
+ self.device = config["device"]
+
+ # load the parameter information
+ self.embedding_size = config["embedding_size"]
+ self.reg_weight = config["reg_weight"]
+ self.pool_type = config["pooling_type"]
+
+ if self.pool_type not in ["max", "average"]:
+ raise NotImplementedError("Make sure 'loss_type' in ['max', 'average']!")
+
+ # define the layers and loss function
+ self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
+ self.user_embedding = nn.Embedding(self.n_user, self.embedding_size)
+
+ # define the module feature gating need
+ self.w1 = nn.Linear(self.embedding_size, self.embedding_size)
+ self.w2 = nn.Linear(self.embedding_size, self.embedding_size)
+ self.b = nn.Parameter(torch.zeros(self.embedding_size), requires_grad=True).to(self.device)
+
+ # define the module instance gating need
+ self.w3 = nn.Linear(self.embedding_size, 1, bias=False)
+ self.w4 = nn.Linear(self.embedding_size, self.max_seq_length, bias=False)
+
+ # define item_embedding for prediction
+ self.item_embedding_for_prediction = nn.Embedding(self.n_items, self.embedding_size)
+
+ self.sigmoid = nn.Sigmoid()
+
+ self.loss_type = config['loss_type']
+ if self.loss_type == 'BPR':
+ self.loss_fct = BPRLoss()
+ elif self.loss_type == 'CE':
+ self.loss_fct = nn.CrossEntropyLoss()
+ else:
+ raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
+
+ # init the parameters of the model
+ self.apply(self._init_weights)
+
+ def reg_loss(self, user_embedding, item_embedding, seq_item_embedding):
+
+ reg_1, reg_2 = self.reg_weight
+ loss_1 = reg_1 * torch.norm(self.w1.weight, p=2) + reg_1 * torch.norm(self.w2.weight, p=2) + reg_1 * torch.norm(
+ self.w3.weight, p=2) + reg_1 * torch.norm(self.w4.weight, p=2)
+ loss_2 = reg_2 * torch.norm(user_embedding, p=2) + reg_2 * torch.norm(item_embedding, p=2) + reg_2 * torch.norm(
+ seq_item_embedding, p=2)
+
+ return loss_1 + loss_2
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data, 0., 1 / self.embedding_size)
+ elif isinstance(module, nn.Linear):
+ xavier_uniform_(module.weight.data)
+ if module.bias is not None:
+ constant_(module.bias.data, 0)
+
+ def feature_gating(self, seq_item_embedding, user_embedding):
+ """
+
+ choose the features that will be sent to the next stage(more important feature, more focus)
+ """
+
+ batch_size, seq_len, embedding_size = seq_item_embedding.size()
+ seq_item_embedding_value = seq_item_embedding
+
+ seq_item_embedding = self.w1(seq_item_embedding)
+ # batch_size * seq_len * embedding_size
+ user_embedding = self.w2(user_embedding)
+ # batch_size * embedding_size
+ user_embedding = user_embedding.unsqueeze(1).repeat(1, seq_len, 1)
+ # batch_size * seq_len * embedding_size
+
+ user_item = self.sigmoid(seq_item_embedding + user_embedding + self.b)
+ # batch_size * seq_len * embedding_size
+
+ user_item = torch.mul(seq_item_embedding_value, user_item)
+ # batch_size * seq_len * embedding_size
+
+ return user_item
+
+ def instance_gating(self, user_item, user_embedding):
+ """
+
+ choose the last click items that will influence the prediction( more important more chance to get attention)
+ """
+
+ user_embedding_value = user_item
+
+ user_item = self.w3(user_item)
+ # batch_size * seq_len * 1
+
+ user_embedding = self.w4(user_embedding).unsqueeze(2)
+ # batch_size * seq_len * 1
+
+ instance_score = self.sigmoid(user_item + user_embedding).squeeze(-1)
+ # batch_size * seq_len * 1
+ output = torch.mul(instance_score.unsqueeze(2), user_embedding_value)
+ # batch_size * seq_len * embedding_size
+
+ if self.pool_type == "average":
+ output = torch.div(output.sum(dim=1), instance_score.sum(dim=1).unsqueeze(1))
+ # batch_size * embedding_size
+ else:
+ # for max_pooling
+ index = torch.max(instance_score, dim=1)[1]
+ # batch_size * 1
+ output = self.gather_indexes(output, index)
+ # batch_size * seq_len * embedding_size ==>> batch_size * embedding_size
+
+ return output
+
+ def forward(self, seq_item, user):
+
+ seq_item_embedding = self.item_embedding(seq_item)
+ user_embedding = self.user_embedding(user)
+ feature_gating = self.feature_gating(seq_item_embedding, user_embedding)
+ instance_gating = self.instance_gating(feature_gating, user_embedding)
+ # batch_size * embedding_size
+ item_item = torch.sum(seq_item_embedding, dim=1)
+ # batch_size * embedding_size
+
+ return user_embedding + instance_gating + item_item
+
+ def calculate_loss(self, interaction):
+
+ seq_item = interaction[self.ITEM_SEQ]
+ seq_item_embedding = self.item_embedding(seq_item)
+ user = interaction[self.USER_ID]
+ user_embedding = self.user_embedding(user)
+ seq_output = self.forward(seq_item, user)
+ pos_items = interaction[self.POS_ITEM_ID]
+ pos_items_emb = self.item_embedding_for_prediction(pos_items)
+ if self.loss_type == 'BPR':
+ neg_items = interaction[self.NEG_ITEM_ID]
+ neg_items_emb = self.item_embedding(neg_items)
+ pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
+ neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)
+ loss = self.loss_fct(pos_score, neg_score)
+ return loss + self.reg_loss(user_embedding, pos_items_emb, seq_item_embedding)
+ else: # self.loss_type = 'CE'
+ test_item_emb = self.item_embedding_for_prediction.weight
+ logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
+ loss = self.loss_fct(logits, pos_items)
+ return loss + self.reg_loss(user_embedding, pos_items_emb, seq_item_embedding)
+
+ def predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ test_item = interaction[self.ITEM_ID]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user)
+ test_item_emb = self.item_embedding_for_prediction(test_item)
+ scores = torch.mul(seq_output, test_item_emb).sum(dim=1)
+ return scores
+
+ def full_sort_predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user)
+ test_items_emb = self.item_embedding_for_prediction.weight
+ scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))
+ return scores
diff --git a/recbole/model/sequential_recommender/hrm.py b/recbole/model/sequential_recommender/hrm.py
new file mode 100644
index 000000000..0554a4ace
--- /dev/null
+++ b/recbole/model/sequential_recommender/hrm.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/11/22 12:08
+# @Author : Shao Weiqi
+# @Reviewer : Lin Kun
+# @Email : shaoweiqi@ruc.edu.cn
+
+r"""
+HRM
+################################################
+
+Reference:
+ Pengfei Wang et al. "Learning Hierarchical Representation Model for Next Basket Recommendation." in SIGIR 2015.
+
+Reference code:
+ https://github.com/wubinzzu/NeuRec
+
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn.init import xavier_normal_, constant_
+
+from recbole.model.abstract_recommender import SequentialRecommender
+from recbole.model.loss import BPRLoss
+
+
+class HRM(SequentialRecommender):
+ r"""
+ HRM can well capture both sequential behavior and users’ general taste by involving transaction and
+ user representations in prediction.
+
+ HRM user max- & average- pooling as a good helper.
+ """
+
+ def __init__(self, config, dataset):
+ super(HRM, self).__init__(config, dataset)
+
+ # load the dataset information
+ self.n_user = dataset.num(self.USER_ID)
+ self.device = config["device"]
+
+ # load the parameters information
+ self.embedding_size = config["embedding_size"]
+ self.pooling_type_layer_1 = config["pooling_type_layer_1"]
+ self.pooling_type_layer_2 = config["pooling_type_layer_2"]
+ self.high_order = config["high_order"]
+ assert self.high_order <= self.max_seq_length, "high_order can't longer than the max_seq_length"
+ self.reg_weight = config["reg_weight"]
+ self.dropout_prob = config["dropout_prob"]
+
+ # define the layers and loss type
+ self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
+ self.user_embedding = nn.Embedding(self.n_user, self.embedding_size)
+ self.dropout = nn.Dropout(self.dropout_prob)
+
+ self.loss_type = config['loss_type']
+ if self.loss_type == 'BPR':
+ self.loss_fct = BPRLoss()
+ elif self.loss_type == 'CE':
+ self.loss_fct = nn.CrossEntropyLoss()
+ else:
+ raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
+
+ # init the parameters of the model
+ self.apply(self._init_weights)
+
+ def inverse_seq_item(self, seq_item, seq_item_len):
+ """
+ inverse the seq_item, like this
+ [1,2,3,0,0,0,0] -- after inverse -->> [0,0,0,0,1,2,3]
+ """
+ seq_item = seq_item.cpu().numpy()
+ seq_item_len = seq_item_len.cpu().numpy()
+ new_seq_item = []
+ for items, length in zip(seq_item, seq_item_len):
+ item = list(items[:length])
+ zeros = list(items[length:])
+ seqs = zeros + item
+ new_seq_item.append(seqs)
+ seq_item = torch.tensor(new_seq_item, dtype=torch.long, device=self.device)
+
+ return seq_item
+
+ def _init_weights(self, module):
+
+ if isinstance(module, nn.Embedding):
+ xavier_normal_(module.weight.data)
+
+ def forward(self, seq_item, user, seq_item_len):
+
+ # seq_item=self.inverse_seq_item(seq_item)
+ seq_item = self.inverse_seq_item(seq_item, seq_item_len)
+
+ seq_item_embedding = self.item_embedding(seq_item)
+ # batch_size * seq_len * embedding_size
+
+ high_order_item_embedding = seq_item_embedding[:, -self.high_order:, :]
+ # batch_size * high_order * embedding_size
+
+ user_embedding = self.dropout(self.user_embedding(user))
+ # batch_size * embedding_size
+
+ # layer 1
+ if self.pooling_type_layer_1 == "max":
+ high_order_item_embedding = torch.max(high_order_item_embedding, dim=1).values
+ # batch_size * embedding_size
+ else:
+ for idx, len in enumerate(seq_item_len):
+ if len > self.high_order:
+ seq_item_len[idx] = self.high_order
+ high_order_item_embedding = torch.sum(seq_item_embedding, dim=1)
+ high_order_item_embedding = torch.div(high_order_item_embedding, seq_item_len.unsqueeze(1).float())
+ # batch_size * embedding_size
+ hybrid_user_embedding = self.dropout(
+ torch.cat([user_embedding.unsqueeze(dim=1), high_order_item_embedding.unsqueeze(dim=1)], dim=1))
+ # batch_size * 2_mul_embedding_size
+
+ # layer 2
+ if self.pooling_type_layer_2 == "max":
+ hybrid_user_embedding = torch.max(hybrid_user_embedding, dim=1).values
+ # batch_size * embedding_size
+ else:
+ hybrid_user_embedding = torch.mean(hybrid_user_embedding, dim=1)
+ # batch_size * embedding_size
+
+ return hybrid_user_embedding
+
+ def calculate_loss(self, interaction):
+
+ seq_item = interaction[self.ITEM_SEQ]
+ seq_item_len = interaction[self.ITEM_SEQ_LEN]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(seq_item, user, seq_item_len)
+ pos_items = interaction[self.POS_ITEM_ID]
+ pos_items_emb = self.item_embedding(pos_items)
+ if self.loss_type == 'BPR':
+ neg_items = interaction[self.NEG_ITEM_ID]
+ neg_items_emb = self.item_embedding(neg_items)
+ pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
+ neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)
+ loss = self.loss_fct(pos_score, neg_score)
+ return loss
+ else: # self.loss_type = 'CE'
+ test_item_emb = self.item_embedding.weight.t()
+ logits = torch.matmul(seq_output, test_item_emb)
+ loss = self.loss_fct(logits, pos_items)
+
+ return loss
+
+ def predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ seq_item_len = interaction[self.ITEM_SEQ_LEN]
+ test_item = interaction[self.ITEM_ID]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user, seq_item_len)
+ seq_output = seq_output.repeat(1, self.embedding_size)
+ test_item_emb = self.item_embedding(test_item)
+ scores = torch.mul(seq_output, test_item_emb).sum(dim=1)
+
+ return scores
+
+ def full_sort_predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ seq_item_len = interaction[self.ITEM_SEQ_LEN]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user, seq_item_len)
+ test_items_emb = self.item_embedding.weight
+ scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))
+
+ return scores
diff --git a/recbole/model/sequential_recommender/npe.py b/recbole/model/sequential_recommender/npe.py
new file mode 100644
index 000000000..f24374e44
--- /dev/null
+++ b/recbole/model/sequential_recommender/npe.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/11/22 14:56
+# @Author : Shao Weiqi
+# @Reviewer : Lin Kun
+# @Email : shaoweiqi@ruc.edu.cn
+
+r"""
+NPE
+################################################
+
+Reference:
+ ThaiBinh Nguyen, et al. "NPE: Neural Personalized Embedding for Collaborative Filtering" in ijcai2018
+
+Reference code:
+ https://github.com/wubinzzu/NeuRec
+
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn.init import xavier_normal_
+from recbole.model.abstract_recommender import SequentialRecommender
+from recbole.model.loss import BPRLoss
+
+
+class NPE(SequentialRecommender):
+ r"""
+ models a user’s click to an item in two terms: the personal preference of the user for the item,
+ and the relationships between this item and other items clicked by the user
+
+ """
+
+ def __init__(self, config, dataset):
+ super(NPE, self).__init__(config, dataset)
+
+ # load the dataset information
+ self.n_user = dataset.num(self.USER_ID)
+ self.device = config["device"]
+
+ # load the parameters information
+ self.embedding_size = config["embedding_size"]
+ self.dropout_prob = config["dropout_prob"]
+
+ # define layers and loss type
+ self.user_embedding = nn.Embedding(self.n_user, self.embedding_size)
+ self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
+ self.embedding_seq_item = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(self.dropout_prob)
+
+ self.loss_type = config['loss_type']
+ if self.loss_type == 'BPR':
+ self.loss_fct = BPRLoss()
+ elif self.loss_type == 'CE':
+ self.loss_fct = nn.CrossEntropyLoss()
+ else:
+ raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
+
+ # init the parameters of the module
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Embedding):
+ xavier_normal_(module.weight.data)
+
+ def forward(self, seq_item, user):
+
+ user_embedding = self.dropout(self.relu(self.user_embedding(user)))
+ # batch_size * embedding_size
+ seq_item_embedding = self.item_embedding(seq_item).sum(dim=1)
+ seq_item_embedding = self.dropout(self.relu(seq_item_embedding))
+ # batch_size * embedding_size
+
+ return user_embedding + seq_item_embedding
+
+ def calculate_loss(self, interaction):
+
+ seq_item = interaction[self.ITEM_SEQ]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(seq_item, user)
+ pos_items = interaction[self.POS_ITEM_ID]
+ pos_items_embs = self.item_embedding(pos_items)
+ if self.loss_type == 'BPR':
+ neg_items = interaction[self.NEG_ITEM_ID]
+ neg_items_emb = self.relu(self.item_embedding(neg_items))
+ pos_items_emb = self.relu(pos_items_embs)
+ pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
+ neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)
+ loss = self.loss_fct(pos_score, neg_score)
+ return loss
+ else: # self.loss_type = 'CE'
+ test_item_emb = self.relu(self.item_embedding.weight)
+ logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
+ loss = self.loss_fct(logits, pos_items)
+ return loss
+
+ def predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ test_item = interaction[self.ITEM_ID]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user)
+ test_item_emb = self.relu(self.item_embedding(test_item))
+ scores = torch.mul(seq_output, test_item_emb).sum(dim=1)
+ return scores
+
+ def full_sort_predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user)
+ test_items_emb = self.relu(self.item_embedding.weight)
+ scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))
+ return scores
diff --git a/recbole/model/sequential_recommender/shan.py b/recbole/model/sequential_recommender/shan.py
new file mode 100644
index 000000000..d970a624d
--- /dev/null
+++ b/recbole/model/sequential_recommender/shan.py
@@ -0,0 +1,212 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/11/20 22:33
+# @Author : Shao Weiqi
+# @Reviewer : Lin Kun
+# @Email : shaoweiqi@ruc.edu.cn
+
+r"""
+SHAN
+################################################
+
+Reference:
+ Ying, H et al. "Sequential Recommender System based on Hierarchical Attention Network."in IJCAI 2018
+
+
+"""
+import torch
+import torch.nn as nn
+import numpy as np
+from recbole.model.abstract_recommender import SequentialRecommender
+from recbole.model.loss import BPRLoss
+from torch.nn.init import normal_, uniform_
+
+
+class SHAN(SequentialRecommender):
+ r"""
+ SHAN exploit the Hierarchical Attention Network to get the long-short term preference
+ first get the long term purpose and then fuse the long-term with recent items to get long-short term purpose
+
+ """
+
+ def __init__(self, config, dataset):
+
+ super(SHAN, self).__init__(config, dataset)
+
+ # load the dataset information
+ self.n_users = dataset.num(self.USER_ID)
+ self.device = config['device']
+
+ # load the parameter information
+ self.embedding_size = config["embedding_size"]
+ self.short_item_length = config["short_item_length"] # the length of the short session items
+ assert self.short_item_length <= self.max_seq_length, "short_item_length can't longer than the max_seq_length"
+ self.reg_weight = config["reg_weight"]
+
+ # define layers and loss
+ self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
+ self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
+
+ self.long_w = nn.Linear(self.embedding_size, self.embedding_size)
+ self.long_b = nn.Parameter(
+ uniform_(tensor=torch.zeros(self.embedding_size), a=-np.sqrt(3 / self.embedding_size),
+ b=np.sqrt(3 / self.embedding_size)), requires_grad=True).to(self.device)
+ self.long_short_w = nn.Linear(self.embedding_size, self.embedding_size)
+ self.long_short_b = nn.Parameter(
+ uniform_(tensor=torch.zeros(self.embedding_size), a=-np.sqrt(3 / self.embedding_size),
+ b=np.sqrt(3 / self.embedding_size)), requires_grad=True).to(self.device)
+
+ self.relu = nn.ReLU()
+
+ self.loss_type = config['loss_type']
+ if self.loss_type == 'BPR':
+ self.loss_fct = BPRLoss()
+ elif self.loss_type == 'CE':
+ self.loss_fct = nn.CrossEntropyLoss()
+ else:
+ raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
+
+ # init the parameter of the model
+ self.apply(self.init_weights)
+
+ def reg_loss(self, user_embedding, item_embedding):
+
+ reg_1, reg_2 = self.reg_weight
+ loss_1 = reg_1 * torch.norm(self.long_w.weight, p=2) + reg_1 * torch.norm(self.long_short_w.weight, p=2)
+ loss_2 = reg_2 * torch.norm(user_embedding, p=2) + reg_2 * torch.norm(item_embedding, p=2)
+
+ return loss_1 + loss_2
+
+ def inverse_seq_item(self, seq_item, seq_item_len):
+ """
+ inverse the seq_item, like this
+ [1,2,3,0,0,0,0] -- after inverse -->> [0,0,0,0,1,2,3]
+ """
+ seq_item = seq_item.cpu().numpy()
+ seq_item_len = seq_item_len.cpu().numpy()
+ new_seq_item = []
+ for items, length in zip(seq_item, seq_item_len):
+ item = list(items[:length])
+ zeros = list(items[length:])
+ seqs = zeros + item
+ new_seq_item.append(seqs)
+ seq_item = torch.tensor(new_seq_item, dtype=torch.long, device=self.device)
+
+ return seq_item
+
+ def init_weights(self, module):
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data, 0., 0.01)
+ elif isinstance(module, nn.Linear):
+ uniform_(module.weight.data, -np.sqrt(3 / self.embedding_size), np.sqrt(3 / self.embedding_size))
+ elif isinstance(module, nn.Parameter):
+ uniform_(module.data, -np.sqrt(3 / self.embedding_size), np.sqrt(3 / self.embedding_size))
+ print(module.data)
+
+ def forward(self, seq_item, user, seq_item_len):
+
+ seq_item = self.inverse_seq_item(seq_item, seq_item_len)
+
+ seq_item_embedding = self.item_embedding(seq_item)
+ user_embedding = self.user_embedding(user)
+
+ # get the mask
+ mask = seq_item.data.eq(0)
+ long_term_attention_based_pooling_layer = self.long_term_attention_based_pooling_layer(seq_item_embedding,
+ user_embedding, mask)
+ # batch_size * 1 * embedding_size
+
+ short_item_embedding = seq_item_embedding[:, -self.short_item_length:, :]
+ mask_long_short = mask[:, -self.short_item_length:]
+ batch_size = mask_long_short.size(0)
+ x = torch.zeros(size=(batch_size, 1)).eq(1).to(self.device)
+ mask_long_short = torch.cat([x, mask_long_short], dim=1)
+ # batch_size * short_item_length * embedding_size
+ long_short_item_embedding = torch.cat([long_term_attention_based_pooling_layer, short_item_embedding], dim=1)
+ # batch_size * 1_plus_short_item_length * embedding_size
+
+ long_short_item_embedding = self.long_and_short_term_attention_based_pooling_layer(long_short_item_embedding,
+ user_embedding,
+ mask_long_short)
+ # batch_size * embedding_size
+
+ return long_short_item_embedding
+
+ def calculate_loss(self, interaction):
+
+ seq_item = interaction[self.ITEM_SEQ]
+ seq_item_len = interaction[self.ITEM_SEQ_LEN]
+ user = interaction[self.USER_ID]
+ user_embedding = self.user_embedding(user)
+ seq_output = self.forward(seq_item, user, seq_item_len)
+ pos_items = interaction[self.POS_ITEM_ID]
+ pos_items_emb = self.item_embedding(pos_items)
+ if self.loss_type == 'BPR':
+ neg_items = interaction[self.NEG_ITEM_ID]
+ neg_items_emb = self.item_embedding(neg_items)
+ pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
+ neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)
+ loss = self.loss_fct(pos_score, neg_score)
+ return loss + self.reg_loss(user_embedding, pos_items_emb)
+ else: # self.loss_type = 'CE'
+ test_item_emb = self.item_embedding.weight
+ logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
+ loss = self.loss_fct(logits, pos_items)
+ return loss + self.reg_loss(user_embedding, pos_items_emb)
+
+ def predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ test_item = interaction[self.ITEM_ID]
+ seq_item_len = interaction[self.ITEM_SEQ_LEN]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user, seq_item_len)
+ test_item_emb = self.item_embedding(test_item)
+ scores = torch.mul(seq_output, test_item_emb).sum(dim=1)
+ return scores
+
+ def full_sort_predict(self, interaction):
+
+ item_seq = interaction[self.ITEM_SEQ]
+ seq_item_len = interaction[self.ITEM_SEQ_LEN]
+ user = interaction[self.USER_ID]
+ seq_output = self.forward(item_seq, user, seq_item_len)
+ test_items_emb = self.item_embedding.weight
+ scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))
+ return scores
+
+ def long_and_short_term_attention_based_pooling_layer(self, long_short_item_embedding, user_embedding, mask=None):
+ """
+
+ fusing the long term purpose with the short-term preference
+ """
+ long_short_item_embedding_value = long_short_item_embedding
+
+ long_short_item_embedding = self.relu(self.long_short_w(long_short_item_embedding) + self.long_short_b)
+ long_short_item_embedding = torch.matmul(long_short_item_embedding, user_embedding.unsqueeze(2)).squeeze(-1)
+ # batch_size * seq_len
+ if mask is not None:
+ long_short_item_embedding.masked_fill_(mask, -1e9)
+ long_short_item_embedding = nn.Softmax(dim=-1)(long_short_item_embedding)
+ long_short_item_embedding = torch.mul(long_short_item_embedding_value,
+ long_short_item_embedding.unsqueeze(2)).sum(dim=1)
+
+ return long_short_item_embedding
+
+ def long_term_attention_based_pooling_layer(self, seq_item_embedding, user_embedding, mask=None):
+ """
+
+ get the long term purpose of user
+ """
+ seq_item_embedding_value = seq_item_embedding
+
+ seq_item_embedding = self.relu(self.long_w(seq_item_embedding) + self.long_b)
+ user_item_embedding = torch.matmul(seq_item_embedding, user_embedding.unsqueeze(2)).squeeze(-1)
+ # batch_size * seq_len
+ if mask is not None:
+ user_item_embedding.masked_fill_(mask, -1e9)
+ user_item_embedding = nn.Softmax(dim=1)(user_item_embedding)
+ user_item_embedding = torch.mul(seq_item_embedding_value, user_item_embedding.unsqueeze(2)).sum(dim=1,
+ keepdim=True)
+ # batch_size * 1 * embedding_size
+
+ return user_item_embedding
diff --git a/recbole/properties/dataset/ml-100k.yaml b/recbole/properties/dataset/ml-100k.yaml
index c513530cd..29449c0a5 100644
--- a/recbole/properties/dataset/ml-100k.yaml
+++ b/recbole/properties/dataset/ml-100k.yaml
@@ -37,12 +37,9 @@ lowest_val: ~
highest_val: ~
equal_val: ~
not_equal_val: ~
-drop_filter_field : False
# Preprocessing
fields_in_same_space: ~
-fill_nan: True
preload_weight: ~
-drop_preload_weight: True
normalize_field: ~
normalize_all: True
diff --git a/recbole/properties/dataset/sample.yaml b/recbole/properties/dataset/sample.yaml
index 8dcb9cc23..031a0049d 100644
--- a/recbole/properties/dataset/sample.yaml
+++ b/recbole/properties/dataset/sample.yaml
@@ -33,13 +33,10 @@ lowest_val: ~
highest_val: ~
equal_val: ~
not_equal_val: ~
-drop_filter_field : True
# Preprocessing
fields_in_same_space: ~
-fill_nan: True
preload_weight: ~
-drop_preload_weight: True
normalize_field: ~
normalize_all: True
diff --git a/recbole/properties/model/DMF.yaml b/recbole/properties/model/DMF.yaml
index 56040f182..d2ca2cfa4 100644
--- a/recbole/properties/model/DMF.yaml
+++ b/recbole/properties/model/DMF.yaml
@@ -1,6 +1,6 @@
# WARNING:
-# 1.if you set inter_matrix_type='rating', you must set drop_filter_field=False in your data config files.
-# 2.The dimensions of the last layer of users and items must be the same
+# 1. if you set inter_matrix_type='rating', you must set `unused_col: ~` in your data config files.
+# 2. The dimensions of the last layer of users and items must be the same
inter_matrix_type: '01'
user_embedding_size: 64
diff --git a/recbole/properties/model/HGN.yaml b/recbole/properties/model/HGN.yaml
new file mode 100644
index 000000000..69b95a440
--- /dev/null
+++ b/recbole/properties/model/HGN.yaml
@@ -0,0 +1,4 @@
+embedding_size: 64
+loss_type: 'BPR'
+pooling_type: "average"
+reg_weight: [0.00,0.00]
\ No newline at end of file
diff --git a/recbole/properties/model/HRM.yaml b/recbole/properties/model/HRM.yaml
new file mode 100644
index 000000000..531e93c3e
--- /dev/null
+++ b/recbole/properties/model/HRM.yaml
@@ -0,0 +1,6 @@
+embedding_size: 64
+high_order: 2
+loss_type: "CE"
+dropout_prob: 0.2
+pooling_type_layer_1: "max"
+pooling_type_layer_2: "max"
\ No newline at end of file
diff --git a/recbole/properties/model/NPE.yaml b/recbole/properties/model/NPE.yaml
new file mode 100644
index 000000000..cb93282a7
--- /dev/null
+++ b/recbole/properties/model/NPE.yaml
@@ -0,0 +1,3 @@
+embedding_size: 64
+loss_type: "CE"
+dropout_prob: 0.3
\ No newline at end of file
diff --git a/recbole/properties/model/SHAN.yaml b/recbole/properties/model/SHAN.yaml
new file mode 100644
index 000000000..9c83bf5c8
--- /dev/null
+++ b/recbole/properties/model/SHAN.yaml
@@ -0,0 +1,4 @@
+embedding_size: 64
+short_item_length: 2
+loss_type: "CE"
+reg_weight: [0.01,0.0001]
\ No newline at end of file
diff --git a/recbole/properties/model/xgboost.yaml b/recbole/properties/model/xgboost.yaml
new file mode 100644
index 000000000..3ad1e041f
--- /dev/null
+++ b/recbole/properties/model/xgboost.yaml
@@ -0,0 +1,46 @@
+# Type of training method
+train_or_cv: train
+
+# DMatrix
+
+xgb_weight: ~
+xgb_base_margin: ~
+xgb_missing: ~
+xgb_silent: ~
+xgb_feature_names: ~
+xgb_feature_types: ~
+xgb_nthread: ~
+
+# train or cv
+xgb_model: ~
+xgb_params:
+ booster: gbtree
+ objective: binary:logistic
+ gamma: 0.1
+ max_depth: 10
+ lambda: 3
+ subsample: 0.5
+ colsample_bytree: 0.7
+ min_child_weight: 3
+ eta: 0.1
+ seed: 100
+ nthread: 4
+xgb_num_boost_round: 10
+# xgb_evals: ~
+xgb_obj: ~
+xgb_feval: ~
+xgb_maximize: ~
+xgb_early_stopping_rounds: ~
+# xgb_evals_result: ~
+xgb_verbose_eval: False
+
+# cv
+xgb_cv_nfold: 3
+xgb_cv_stratified: False
+xgb_cv_folds: ~
+xgb_cv_fpreproc: ~
+xgb_cv_show_stdv: True
+xgb_cv_seed: 0
+xgb_cv_shuffle: True
+
+
diff --git a/recbole/properties/quick_start_config/context-aware.yaml b/recbole/properties/quick_start_config/context-aware.yaml
new file mode 100644
index 000000000..cdb71f098
--- /dev/null
+++ b/recbole/properties/quick_start_config/context-aware.yaml
@@ -0,0 +1,5 @@
+eval_setting: RO_RS
+group_by_user: False
+training_neg_sample_num: 0
+metrics: ['AUC', 'LogLoss']
+valid_metric: AUC
\ No newline at end of file
diff --git a/recbole/properties/quick_start_config/context-aware_ml-100k.yaml b/recbole/properties/quick_start_config/context-aware_ml-100k.yaml
new file mode 100644
index 000000000..150dd4d18
--- /dev/null
+++ b/recbole/properties/quick_start_config/context-aware_ml-100k.yaml
@@ -0,0 +1,5 @@
+threshold: {'rating': 4}
+load_col:
+ inter: ['user_id', 'item_id', 'rating', 'timestamp']
+ user: ['user_id', 'age', 'gender', 'occupation']
+ item: ['item_id', 'release_year', 'class']
\ No newline at end of file
diff --git a/recbole/properties/quick_start_config/knowledge_base.yaml b/recbole/properties/quick_start_config/knowledge_base.yaml
new file mode 100644
index 000000000..379341326
--- /dev/null
+++ b/recbole/properties/quick_start_config/knowledge_base.yaml
@@ -0,0 +1,4 @@
+load_col:
+ inter: ['user_id', 'item_id', 'rating', 'timestamp']
+ kg: ['head_id', 'relation_id', 'tail_id']
+ link: ['item_id', 'entity_id']
\ No newline at end of file
diff --git a/recbole/properties/quick_start_config/sequential.yaml b/recbole/properties/quick_start_config/sequential.yaml
new file mode 100644
index 000000000..87c0fa053
--- /dev/null
+++ b/recbole/properties/quick_start_config/sequential.yaml
@@ -0,0 +1 @@
+eval_setting: TO_LS,full
\ No newline at end of file
diff --git a/recbole/properties/quick_start_config/sequential_DIN.yaml b/recbole/properties/quick_start_config/sequential_DIN.yaml
new file mode 100644
index 000000000..58b8db955
--- /dev/null
+++ b/recbole/properties/quick_start_config/sequential_DIN.yaml
@@ -0,0 +1,3 @@
+eval_setting: TO_LS, uni100
+metrics: ['AUC', 'LogLoss']
+valid_metric: AUC
\ No newline at end of file
diff --git a/recbole/properties/quick_start_config/sequential_DIN_on_ml-100k.yaml b/recbole/properties/quick_start_config/sequential_DIN_on_ml-100k.yaml
new file mode 100644
index 000000000..702a7a862
--- /dev/null
+++ b/recbole/properties/quick_start_config/sequential_DIN_on_ml-100k.yaml
@@ -0,0 +1,4 @@
+load_col:
+ inter: ['user_id', 'item_id', 'rating', 'timestamp']
+ user: ['user_id', 'age', 'gender', 'occupation']
+ item: ['item_id', 'release_year']
\ No newline at end of file
diff --git a/recbole/properties/quick_start_config/sequential_embedding_model.yaml b/recbole/properties/quick_start_config/sequential_embedding_model.yaml
new file mode 100644
index 000000000..59b920994
--- /dev/null
+++ b/recbole/properties/quick_start_config/sequential_embedding_model.yaml
@@ -0,0 +1,4 @@
+load_col:
+ inter: ['user_id', 'item_id', 'rating', 'timestamp']
+ ent: ['ent_id', 'ent_emb']
+additional_feat_suffix: ent
\ No newline at end of file
diff --git a/recbole/properties/quick_start_config/special_sequential_on_ml-100k.yaml b/recbole/properties/quick_start_config/special_sequential_on_ml-100k.yaml
new file mode 100644
index 000000000..1fe509fe6
--- /dev/null
+++ b/recbole/properties/quick_start_config/special_sequential_on_ml-100k.yaml
@@ -0,0 +1,3 @@
+load_col:
+ inter: ['user_id', 'item_id', 'rating', 'timestamp']
+ item: ['item_id', 'release_year', 'class']
\ No newline at end of file
diff --git a/recbole/sampler/sampler.py b/recbole/sampler/sampler.py
index 2f4879e53..54f1fc37b 100644
--- a/recbole/sampler/sampler.py
+++ b/recbole/sampler/sampler.py
@@ -13,9 +13,9 @@
########################
"""
-import random
import copy
import numpy as np
+import torch
class AbstractSampler(object):
@@ -33,7 +33,10 @@ class AbstractSampler(object):
used_ids (numpy.ndarray): The result of :meth:`get_used_ids`.
"""
def __init__(self, distribution):
- self.distribution = None
+ self.distribution = ''
+ self.random_list = []
+ self.random_pr = 0
+ self.random_list_length = 0
self.set_distribution(distribution)
self.used_ids = self.get_used_ids()
@@ -47,7 +50,7 @@ def set_distribution(self, distribution):
return
self.distribution = distribution
self.random_list = self.get_random_list()
- random.shuffle(self.random_list)
+ np.random.shuffle(self.random_list)
self.random_pr = 0
self.random_list_length = len(self.random_list)
@@ -74,7 +77,21 @@ def random(self):
self.random_pr += 1
return value_id
- def sample_by_key_ids(self, key_ids, num, used_ids):
+ def random_num(self, num):
+ value_id = []
+ self.random_pr %= self.random_list_length
+ while True:
+ if self.random_pr + num <= self.random_list_length:
+ value_id.append(self.random_list[self.random_pr: self.random_pr + num])
+ self.random_pr += num
+ break
+ else:
+ value_id.append(self.random_list[self.random_pr:])
+ num -= self.random_list_length - self.random_pr
+ self.random_pr = 0
+ return np.concatenate(value_id)
+
+ def sample_by_key_ids(self, key_ids, num):
"""Sampling by key_ids.
Args:
@@ -83,22 +100,49 @@ def sample_by_key_ids(self, key_ids, num, used_ids):
used_ids (np.ndarray): Used ids. index is key_id, and element is a set of value_ids.
Returns:
- np.ndarray: Sampled value_ids.
+ torch.tensor: Sampled value_ids.
value_ids[0], value_ids[len(key_ids)], value_ids[len(key_ids) * 2], ..., value_id[len(key_ids) * (num - 1)]
is sampled for key_ids[0];
value_ids[1], value_ids[len(key_ids) + 1], value_ids[len(key_ids) * 2 + 1], ...,
value_id[len(key_ids) * (num - 1) + 1] is sampled for key_ids[1]; ...; and so on.
"""
+ key_ids = np.array(key_ids)
key_num = len(key_ids)
total_num = key_num * num
- value_ids = np.zeros(total_num, dtype=np.int64)
- used_id_list = np.repeat(used_ids, num)
- for i, used_ids in enumerate(used_id_list):
- cur = self.random()
- while cur in used_ids:
- cur = self.random()
- value_ids[i] = cur
- return value_ids
+ if (key_ids == key_ids[0]).all():
+ key_id = key_ids[0]
+ used = np.array(list(self.used_ids[key_id]))
+ value_ids = self.random_num(total_num)
+ check_list = np.arange(total_num)[np.isin(value_ids, used)]
+ while len(check_list) > 0:
+ value_ids[check_list] = value = self.random_num(len(check_list))
+ perm = value.argsort(kind='quicksort')
+ aux = value[perm]
+ mask = np.empty(aux.shape, dtype=np.bool_)
+ mask[:1] = True
+ mask[1:] = aux[1:] != aux[:-1]
+ value = aux[mask]
+ rev_idx = np.empty(mask.shape, dtype=np.intp)
+ rev_idx[perm] = np.cumsum(mask) - 1
+ ar = np.concatenate((value, used))
+ order = ar.argsort(kind='mergesort')
+ sar = ar[order]
+ bool_ar = (sar[1:] == sar[:-1])
+ flag = np.concatenate((bool_ar, [False]))
+ ret = np.empty(ar.shape, dtype=bool)
+ ret[order] = flag
+ mask = ret[rev_idx]
+ check_list = check_list[mask]
+ else:
+ value_ids = np.zeros(total_num, dtype=np.int64)
+ check_list = np.arange(total_num)
+ key_ids = np.tile(key_ids, num)
+ while len(check_list) > 0:
+ value_ids[check_list] = self.random_num(len(check_list))
+ check_list = np.array([i for i, used, v in
+ zip(check_list, self.used_ids[key_ids[check_list]], value_ids[check_list])
+ if v in used])
+ return torch.tensor(value_ids)
class Sampler(AbstractSampler):
@@ -140,11 +184,11 @@ def get_random_list(self):
np.ndarray or list: Random list of item_id.
"""
if self.distribution == 'uniform':
- return list(range(1, self.n_items))
+ return np.arange(1, self.n_items)
elif self.distribution == 'popularity':
random_item_list = []
for dataset in self.datasets:
- random_item_list.extend(dataset.inter_feat[self.iid_field].values)
+ random_item_list.extend(dataset.inter_feat[self.iid_field].numpy())
return random_item_list
else:
raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))
@@ -159,7 +203,7 @@ def get_used_ids(self):
last = [set() for i in range(self.n_users)]
for phase, dataset in zip(self.phases, self.datasets):
cur = np.array([set(s) for s in last])
- for uid, iid in dataset.inter_feat[[self.uid_field, self.iid_field]].values:
+ for uid, iid in zip(dataset.inter_feat[self.uid_field].numpy(), dataset.inter_feat[self.iid_field].numpy()):
cur[uid].add(iid)
last = used_item_id[phase] = cur
return used_item_id
@@ -189,14 +233,14 @@ def sample_by_user_ids(self, user_ids, num):
num (int): Number of sampled item_ids for each user_id.
Returns:
- np.ndarray: Sampled item_ids.
+ torch.tensor: Sampled item_ids.
item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ..., item_id[len(user_ids) * (num - 1)]
is sampled for user_ids[0];
item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
item_id[len(user_ids) * (num - 1) + 1] is sampled for user_ids[1]; ...; and so on.
"""
try:
- return self.sample_by_key_ids(user_ids, num, self.used_ids[user_ids])
+ return self.sample_by_key_ids(user_ids, num)
except IndexError:
for user_id in user_ids:
if user_id < 0 or user_id >= self.n_users:
@@ -229,7 +273,7 @@ def get_random_list(self):
np.ndarray or list: Random list of entity_id.
"""
if self.distribution == 'uniform':
- return list(range(1, self.entity_num))
+ return np.arange(1, self.entity_num)
elif self.distribution == 'popularity':
return list(self.hid_list) + list(self.tid_list)
else:
@@ -254,14 +298,14 @@ def sample_by_entity_ids(self, head_entity_ids, num=1):
num (int, optional): Number of sampled entity_ids for each head_entity_id. Defaults to ``1``.
Returns:
- np.ndarray: Sampled entity_ids.
+ torch.tensor: Sampled entity_ids.
entity_ids[0], entity_ids[len(head_entity_ids)], entity_ids[len(head_entity_ids) * 2], ...,
entity_id[len(head_entity_ids) * (num - 1)] is sampled for head_entity_ids[0];
entity_ids[1], entity_ids[len(head_entity_ids) + 1], entity_ids[len(head_entity_ids) * 2 + 1], ...,
entity_id[len(head_entity_ids) * (num - 1) + 1] is sampled for head_entity_ids[1]; ...; and so on.
"""
try:
- return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])
+ return self.sample_by_key_ids(head_entity_ids, num)
except IndexError:
for head_entity_id in head_entity_ids:
if head_entity_id not in self.head_entities:
@@ -287,8 +331,8 @@ def __init__(self, phases, dataset, distribution='uniform'):
self.dataset = dataset
self.iid_field = dataset.iid_field
- self.user_num = dataset.user_num
- self.item_num = dataset.item_num
+ self.n_users = dataset.user_num
+ self.n_items = dataset.item_num
super().__init__(distribution=distribution)
@@ -298,9 +342,9 @@ def get_random_list(self):
np.ndarray or list: Random list of item_id.
"""
if self.distribution == 'uniform':
- return list(range(1, self.item_num))
+ return np.arange(1, self.n_items)
elif self.distribution == 'popularity':
- return self.dataset.inter_feat[self.iid_field].values
+ return self.dataset.inter_feat[self.iid_field].numpy()
else:
raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))
@@ -310,7 +354,7 @@ def get_used_ids(self):
np.ndarray: Used item_ids is the same as positive item_ids.
Index is user_id, and element is a set of item_ids.
"""
- return np.array([set() for i in range(self.user_num)])
+ return np.array([set() for i in range(self.n_users)])
def sample_by_user_ids(self, user_ids, num):
"""Sampling by user_ids.
@@ -320,14 +364,14 @@ def sample_by_user_ids(self, user_ids, num):
num (int): Number of sampled item_ids for each user_id.
Returns:
- np.ndarray: Sampled item_ids.
+ torch.tensor: Sampled item_ids.
item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ..., item_id[len(user_ids) * (num - 1)]
is sampled for user_ids[0];
item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
item_id[len(user_ids) * (num - 1) + 1] is sampled for user_ids[1]; ...; and so on.
"""
try:
- return self.sample_by_key_ids(user_ids, num, self.used_ids[user_ids])
+ return self.sample_by_key_ids(user_ids, num)
except IndexError:
for user_id in user_ids:
if user_id < 0 or user_id >= self.n_users:
diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py
index 1dc7955d6..8907e13c0 100644
--- a/recbole/trainer/trainer.py
+++ b/recbole/trainer/trainer.py
@@ -3,9 +3,9 @@
# @Email : slmu@ruc.edu.cn
# UPDATE:
-# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15
-# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan
-# @Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn
+# @Time : 2020/8/7, 2020/9/26, 2020/9/26, 2020/10/01, 2020/9/16, 2020/10/8, 2020/10/15, 2020/11/20
+# @Author : Zihan Lin, Yupeng Hou, Yushuo Chen, Shanlei Mu, Xingyu Pan, Hui Wang, Xinyan Fan, Chen Yang
+# @Email : linzihan.super@foxmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn, hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com
r"""
recbole.trainer.trainer
@@ -19,6 +19,7 @@
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt
+import xgboost as xgb
from time import time
from logging import getLogger
@@ -218,7 +219,7 @@ def _check_nan(self, loss):
def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
train_loss_output = 'epoch %d training [time: %.2fs, ' % (epoch_idx, e_time - s_time)
if isinstance(losses, tuple):
- train_loss_output = ', '.join('train_loss%d: %.4f' % (idx + 1, loss) for idx, loss in enumerate(losses))
+ train_loss_output += ', '.join('train_loss%d: %.4f' % (idx + 1, loss) for idx, loss in enumerate(losses))
else:
train_loss_output += 'train loss: %.4f' % losses
return train_loss_output + ']'
@@ -545,3 +546,162 @@ class TraditionalTrainer(Trainer):
def __init__(self, config, model):
super(TraditionalTrainer, self).__init__(config, model)
self.epochs = 1 # Set the epoch to 1 when running memory based model
+
+
+class xgboostTrainer(AbstractTrainer):
+ """xgboostTrainer is designed for XGBOOST.
+
+ """
+ def __init__(self, config, model):
+ super(xgboostTrainer, self).__init__(config, model)
+
+ self.logger = getLogger()
+ self.label_field = config['LABEL_FIELD']
+
+ self.train_or_cv = config['train_or_cv']
+ self.xgb_model = config['xgb_model']
+
+ # DMatrix params
+ self.weight = config['xgb_weight']
+ self.base_margin = config['xgb_base_margin']
+ self.missing = config['xgb_missing']
+ self.silent = config['xgb_silent']
+ self.feature_names = config['xgb_feature_names']
+ self.feature_types = config['xgb_feature_types']
+ self.nthread = config['xgb_nthread']
+
+ # train params
+ self.params = config['xgb_params']
+ self.num_boost_round = config['xgb_num_boost_round']
+ self.evals = ()
+ self.obj = config['xgb_obj']
+ self.feval = config['xgb_feval']
+ self.maximize = config['xgb_maximize']
+ self.early_stopping_rounds = config['xgb_early_stopping_rounds']
+ self.evals_result = {}
+ self.verbose_eval = config['xgb_verbose_eval']
+ self.callbacks = None
+
+ # cv params
+ if self.train_or_cv == 'cv':
+ self.nfold = config['xgb_cv_nfold']
+ self.stratified = config['xgb_cv_stratified']
+ self.folds = config['xgb_cv_folds']
+ self.fpreproc = config['xgb_cv_freproc']
+ self.show_stdv = config['xgb_cv_show_stdv']
+ self.seed = config['xgb_cv_seed']
+ self.shuffle = config['xgb_cv_shuffle']
+
+ # evaluator
+ self.eval_type = config['eval_type']
+ self.epochs = config['epochs']
+ self.eval_step = min(config['eval_step'], self.epochs)
+ self.valid_metric = config['valid_metric'].lower()
+
+ if self.eval_type == EvaluatorType.INDIVIDUAL:
+ self.evaluator = LossEvaluator(config)
+ else:
+ self.evaluator = TopKEvaluator(config)
+
+ # model saved
+ self.checkpoint_dir = config['checkpoint_dir']
+ ensure_dir(self.checkpoint_dir)
+ saved_model_file = '{}-{}.pth'.format(self.config['model'], get_local_time())
+ self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)
+
+ def _interaction_to_DMatrix(self, interaction):
+ r"""Convert data format from interaction to DMatrix
+
+ Args:
+ interaction (Interaction): Data in the form of 'Interaction'.
+ Returns:
+ DMatrix: Data in the form of 'DMatrix'.
+ """
+ interaction_np = interaction.numpy()
+ cur_data = np.array([])
+ for key, value in interaction_np.items():
+ value = np.resize(value,(value.shape[0],1))
+ if key != self.label_field:
+ if cur_data.shape[0] == 0:
+ cur_data = value
+ else:
+ cur_data = np.hstack((cur_data, value))
+
+ return xgb.DMatrix(data = cur_data,
+ label = interaction_np[self.label_field],
+ weight = self.weight,
+ base_margin = self.base_margin,
+ missing = self.missing,
+ silent = self.silent,
+ feature_names = self.feature_names,
+ feature_types = self.feature_types,
+ nthread = self.nthread)
+
+ def _train_epoch(self, train_data, valid_data):
+ r"""
+
+ Args:
+ train_data (XgboostDataLoader): XgboostDataLoader, which is the same with GeneralDataLoader.
+ valid_data (XgboostDataLoader): XgboostDataLoader, which is the same with GeneralDataLoader.
+ """
+ for _, train_interaction in enumerate(train_data):
+ self.dtrain = self._interaction_to_DMatrix(train_interaction)
+ self.evals = [(self.dtrain,'train')]
+ self.model = xgb.train(self.params, self.dtrain, 1,
+ self.evals, self.obj, self.feval, self.maximize,
+ self.early_stopping_rounds, self.evals_result,
+ self.verbose_eval, self.xgb_model, self.callbacks)
+
+ self.model.save_model(self.saved_model_file)
+ self.xgb_model = self.saved_model_file
+
+ def _valid_epoch(self, valid_data):
+ r"""
+
+ Args:
+ valid_data (XgboostDataLoader): XgboostDataLoader, which is the same with GeneralDataLoader.
+ """
+ valid_result = self.evaluate(valid_data)
+ valid_score = calculate_valid_score(valid_result, self.valid_metric)
+ return valid_result, valid_score
+
+ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
+ self.best_valid_score = 0.
+ self.best_valid_result = 0.
+ if self.train_or_cv == 'train':
+ for epoch_idx in range(self.epochs):
+ train_loss = self._train_epoch(train_data, valid_data)
+
+ if (epoch_idx + 1) % self.eval_step == 0:
+ # evaluate
+ valid_start_time = time()
+ valid_result, valid_score = self._valid_epoch(valid_data)
+ valid_end_time = time()
+ valid_score_output = "epoch %d evaluating [time: %.2fs, valid_score: %f]" % \
+ (epoch_idx, valid_end_time - valid_start_time, valid_score)
+ valid_result_output = 'valid result: \n' + dict2str(valid_result)
+ if verbose:
+ self.logger.info(valid_score_output)
+ self.logger.info(valid_result_output)
+
+ self.best_valid_score = valid_score
+ self.best_valid_result = valid_result
+
+ return self.best_valid_score, self.best_valid_result
+
+ def evaluate(self, eval_data, load_best_model=True, model_file=None):
+ self.eval_pred = torch.Tensor()
+ self.eval_true = torch.Tensor()
+
+ for _, batched_data in enumerate(eval_data):
+ batched_data_DMatrix = self._interaction_to_DMatrix(batched_data)
+ batch_pred = torch.Tensor(self.model.predict(batched_data_DMatrix))
+ if self.params['objective'] == 'binary:logistic':
+ batch_pred = (batch_pred >= 0.5) * 1
+ self.eval_pred = torch.cat((self.eval_pred, batch_pred))
+ self.eval_true = torch.cat((self.eval_true, batched_data[self.label_field]))
+
+ matrix_list = [torch.stack((self.eval_pred, self.eval_true), 1)]
+
+ result = self.evaluator.evaluate(matrix_list, eval_data)
+ return result
diff --git a/recbole/utils/argument_list.py b/recbole/utils/argument_list.py
index a1af91f4a..05f6ad18a 100644
--- a/recbole/utils/argument_list.py
+++ b/recbole/utils/argument_list.py
@@ -29,9 +29,9 @@
'NEG_PREFIX',
'ITEM_LIST_LENGTH_FIELD', 'LIST_SUFFIX', 'MAX_ITEM_LIST_LENGTH', 'POSITION_FIELD',
'HEAD_ENTITY_ID_FIELD', 'TAIL_ENTITY_ID_FIELD', 'RELATION_ID_FIELD', 'ENTITY_ID_FIELD',
- 'load_col', 'unload_col', 'additional_feat_suffix',
+ 'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix',
'max_user_inter_num', 'min_user_inter_num', 'max_item_inter_num', 'min_item_inter_num',
- 'lowest_val', 'highest_val', 'equal_val', 'not_equal_val', 'drop_filter_field',
- 'fields_in_same_space', 'fill_nan',
- 'preload_weight', 'drop_preload_weight',
+ 'lowest_val', 'highest_val', 'equal_val', 'not_equal_val',
+ 'fields_in_same_space',
+ 'preload_weight',
'normalize_field', 'normalize_all']
diff --git a/recbole/utils/enum_type.py b/recbole/utils/enum_type.py
index 62a11ee7c..84e15b812 100644
--- a/recbole/utils/enum_type.py
+++ b/recbole/utils/enum_type.py
@@ -26,6 +26,7 @@ class ModelType(Enum):
KNOWLEDGE = 4
SOCIAL = 5
TRADITIONAL = 6
+ XGBOOST = 7
class DataLoaderType(Enum):
diff --git a/recbole/utils/utils.py b/recbole/utils/utils.py
index 3a44d1fba..d47ea0fae 100644
--- a/recbole/utils/utils.py
+++ b/recbole/utils/utils.py
@@ -53,7 +53,8 @@ def get_model(model_name):
'general_recommender',
'context_aware_recommender',
'sequential_recommender',
- 'knowledge_aware_recommender'
+ 'knowledge_aware_recommender',
+ 'exlib_recommender'
]
model_file_name = model_name.lower()
diff --git a/requirements.txt b/requirements.txt
index 7f4bff1bf..2b254e9ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,5 @@ hyperopt>=0.2.4
pandas>=1.0.5
tqdm>=4.48.2
scikit_learn>=0.23.2
-pyyaml>=5.1.0
\ No newline at end of file
+pyyaml>=5.1.0
+xgboost>=1.2.1
\ No newline at end of file
diff --git a/setup.py b/setup.py
index a8c0a358a..46c74cab9 100644
--- a/setup.py
+++ b/setup.py
@@ -36,7 +36,7 @@
setup(
name='recbole',
version=
- '0.1.1', # please remember to edit recbole/__init__.py in response, once updating the version
+ '0.1.2', # please remember to edit recbole/__init__.py in response, once updating the version
description='A unified, comprehensive and efficient recommendation library',
long_description=long_description,
long_description_content_type="text/markdown",
diff --git a/tests/model/test_model.yaml b/tests/model/test_model.yaml
index 46ffeb24a..77d78f37e 100644
--- a/tests/model/test_model.yaml
+++ b/tests/model/test_model.yaml
@@ -47,12 +47,9 @@ lowest_val: ~
highest_val: ~
equal_val: ~
not_equal_val: ~
-drop_filter_field : False
# Preprocessing
fields_in_same_space: ~
-fill_nan: True
preload_weight: ~
-drop_preload_weight: True
normalize_field: ~
normalize_all: True
diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py
index cd69e1dd6..0be3e9b31 100644
--- a/tests/model/test_model_auto.py
+++ b/tests/model/test_model_auto.py
@@ -317,6 +317,34 @@ def test_sasrecf(self):
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
+ def test_hrm(self):
+ config_dict = {
+ 'model': 'HRM',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
+ def test_npe(self):
+ config_dict = {
+ 'model': 'NPE',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
+ def test_shan(self):
+ config_dict = {
+ 'model': 'SHAN',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
+ def test_hgn(self):
+ config_dict = {
+ 'model': 'HGN',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
# def test_fdsa(self):
# config_dict = {
# 'model': 'FDSA',
@@ -574,6 +602,38 @@ def test_stamp(self):
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
+ def test_hrm(self):
+ config_dict = {
+ 'model': 'HRM',
+ 'loss_type': 'BPR',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
+ def test_npe(self):
+ config_dict = {
+ 'model': 'NPE',
+ 'loss_type': 'BPR',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
+ def test_shan(self):
+ config_dict = {
+ 'model': 'SHAN',
+ 'loss_type': 'BPR',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
+ def test_hgn(self):
+ config_dict = {
+ 'model': 'HGN',
+ 'loss_type': 'CE',
+ }
+ objective_function(config_dict=config_dict,
+ config_file_list=config_file_list, saved=False)
+
def test_caser(self):
config_dict = {
'model': 'Caser',