Questions about the dataset #78
Same question here. Have you solved this problem?
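The raw data has not been preprocessed. You can write a YAML file for the dataset, such as the yelp.yaml file below,
and add the corresponding argument when running.
For the meaning of each parameter, see the documentation: https://recbole.io/docs/user_guide/config/data_settings.html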
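For illustration, the same kind of filtering can also be written as a Python dict and passed through the `config_dict` argument of RecBole's `run_recbole` instead of a YAML file. This is a minimal sketch: the keys come from the data-settings page linked above, but the variable name `data_settings` and every threshold are placeholders, not the settings behind the paper's statistics.

```python
# Hypothetical data settings: key names follow RecBole's data-settings docs,
# while the thresholds are placeholders, NOT the values used in the paper.
data_settings = {
    "USER_ID_FIELD": "user_id",
    "ITEM_ID_FIELD": "item_id",
    # Load only the user and item columns from the .inter atomic file.
    "load_col": {"inter": ["user_id", "item_id"]},
    # Keep users/items with at least 10 interactions (k-core style lower bound).
    "user_inter_num_interval": "[10,inf)",
    "item_inter_num_interval": "[10,inf)",
}

print(data_settings)
```

It could then be used as `run_recbole(model="LightGCN", dataset="yelp", config_dict=data_settings)` after `from recbole.quick_start import run_recbole`; for RecBole-GNN models, `recbole_gnn.quick_start.run_recbole_gnn` is assumed to accept the same `config_dict` argument.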
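Thank you very much for your answer. Here is my situation: I downloaded the datasets from the RecBole cloud drive and set up data filtering in the YAML file, but the dataset statistics I get after filtering differ from those in the paper. For example, for the amazonbooks dataset, I likewise set load_col: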
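One reason filtered statistics can drift is that interaction-count filtering cascades: dropping rare items can push some users below their threshold, which can in turn push more items below theirs. RecBole's own API docs describe the lower bound of `user_inter_num_interval` / `item_inter_num_interval` as k-core filtering that loops until every remaining user and item meets it. The plain-Python sketch below illustrates that idea on `(user, item)` pairs; `k_core_filter` is a hypothetical helper, not RecBole's code, and the result still depends on which raw data version you start from.

```python
from collections import Counter


def k_core_filter(interactions, k_user, k_item):
    """Repeatedly drop pairs whose user or item has too few interactions.

    interactions: list of (user_id, item_id) pairs (duplicates are counted).
    Returns the stable subset in which every user has >= k_user and every
    item has >= k_item interactions.
    """
    data = list(interactions)
    while True:
        user_counts = Counter(u for u, _ in data)
        item_counts = Counter(i for _, i in data)
        kept = [
            (u, i)
            for u, i in data
            if user_counts[u] >= k_user and item_counts[i] >= k_item
        ]
        if len(kept) == len(data):  # nothing removed in this pass: stable
            return kept
        data = kept


# Removing the rare items "b" and "c" leaves u1 and u2 with one interaction
# each, so the next pass removes them as well.
pairs = [("u1", "a"), ("u1", "b"), ("u2", "a"), ("u2", "c")]
print(k_core_filter(pairs, k_user=2, k_item=2))  # → []
```

Each pass either returns or strictly shrinks the data, so the loop always terminates.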
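Neither the yelp nor the amazonbook dataset there should be the 2018 version. If you simply pass the dataset name, it should be downloaded automatically; if you download it yourself, I used Google Drive at the time.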
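Thank you very much for your answer, and sorry for having so many questions. I'd like to ask about one more thing: I want to take a model already trained with RecBole and recompute new evaluation metrics on the test data. For example, I previously ran NDCG, and now I want to evaluate Recall or other metrics using the parameters of the previously trained model. So I ran:

```python
from recbole.quick_start import load_data_and_model

config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
    model_file=r"D:\RecBole-1.2.0\saved\LightGCN-Jan-16-2024_20-55-38.pth"
)
config['metrics'] = ["Recall", "NDCG", "ItemCoverage", "Novelty"]
```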
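The first path error is probably because, in the `evaluate` function, the model file path has to be passed in explicitly again so that the checkpoint can be loaded. With the default `load_best_model=True`, `evaluate` loads `model_file` if one is given and otherwise the new trainer's own checkpoint path, which has not been written yet. The function signature is:

```python
def evaluate(
    self, eval_data, load_best_model=True, model_file=None, show_progress=False
):
    r"""Evaluate the model based on the eval data.

    Args:
        eval_data (DataLoader): the eval data
        load_best_model (bool, optional): whether load the best model in the training process, default: True.
                                          It should be set True, if users want to test the model after training.
        model_file (str, optional): the saved model file, default: None. If users want to test the previously
                                    trained model file, they can set this parameter.
        show_progress (bool): Show the progress of evaluate epoch. Defaults to ``False``.

    Returns:
        collections.OrderedDict: eval result, key is the eval metric and value in the corresponding metric value.
    """
    ...  # body omitted: this is only the signature of RecBole's Trainer.evaluate
```

As for the second problem, have you set the path in your config correctly? You can check it by printing `config['data_path']`, for example with the following script:

```python
from logging import getLogger

import torch
from recbole.trainer import Trainer
from recbole.utils import init_logger, init_seed
from recbole_gnn.model.general_recommender import LightGCN
from recbole_gnn.utils import create_dataset, data_preparation


def load_data_and_model(model_file, mymodel):
    """Rebuild config, dataset, dataloaders and model from a saved checkpoint."""

    def compatibility_settings():
        # Restore NumPy type aliases that newer NumPy versions removed.
        # np.float_, np.complex_ and np.unicode_ no longer exist in NumPy 2.0,
        # so map to equivalents that exist in both 1.x and 2.x.
        import numpy as np
        np.bool = np.bool_
        np.int = np.int_
        np.float = np.float64
        np.complex = np.complex128
        np.object = np.object_
        np.str = np.str_
        np.long = np.int_
        np.unicode = np.str_

    compatibility_settings()
    # The checkpoint pickles the whole Config object, not only tensors, so full
    # unpickling is needed (PyTorch >= 2.6 defaults to weights_only=True;
    # the weights_only argument itself exists since PyTorch 1.13).
    checkpoint = torch.load(model_file, weights_only=False)
    config = checkpoint['config']
    init_seed(config['seed'], config['reproducibility'])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # Rebuild the dataset from the saved config; this reads config['data_path'].
    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # Rebuild the model and load the trained weights.
    init_seed(config['seed'], config['reproducibility'])
    model = mymodel(config, train_data.dataset).to(config['device'])
    model.load_state_dict(checkpoint['state_dict'])
    model.load_other_parameter(checkpoint.get('other_parameter'))
    return config, model, dataset, train_data, valid_data, test_data


model_path = 'your_path'  # path to the saved .pth checkpoint
config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
    model_file=model_path, mymodel=LightGCN
)
print(config['data_path'])  # check that the dataset path is correct

# Metrics to compute on the test set with the already trained model.
config['metrics'] = ["Recall", "NDCG"]
config['topk'] = [10, 20, 50]
print(config)

trainer = Trainer(config, model)
# Pass model_file explicitly so evaluate() reloads this checkpoint. Since the
# weights are already loaded above, load_best_model=False would also work.
test_result = trainer.evaluate(test_data, model_file=model_path)
print(test_result)
```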
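If you want to sanity-check what `trainer.evaluate` reports, here is a plain-Python sketch of the common per-user definitions of Recall@K and NDCG@K with binary relevance. It illustrates the standard formulas, not RecBole's implementation; the assumption is that the reported numbers are means of such per-user values over the test users, and the helper names are hypothetical.

```python
import math


def recall_at_k(ranked_items, relevant, k):
    """Share of the user's relevant test items that appear in the top-k list."""
    if not relevant:
        return 0.0
    hits = sum(1 for item in ranked_items[:k] if item in relevant)
    return hits / len(relevant)


def ndcg_at_k(ranked_items, relevant, k):
    """Binary-relevance NDCG@k with a log2 position discount."""
    # rank is 0-based, so the first position is discounted by log2(2) = 1.
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, item in enumerate(ranked_items[:k])
        if item in relevant
    )
    # Ideal DCG: all relevant items ranked first, capped at k positions.
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


ranked = ["a", "b", "c", "d"]  # model's ranking for one user (hypothetical)
relevant = {"a", "c"}          # that user's held-out test items (hypothetical)
print(recall_at_k(ranked, relevant, k=2))  # → 0.5
print(ndcg_at_k(ranked, relevant, k=2))
```

Capping the ideal DCG at `min(len(relevant), k)` keeps a perfect top-k ranking at 1.0 even when a user has more test items than `k`.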
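Hello, the Yelp2018, Amazon-Book, and Gowalla datasets I downloaded from Baidu Netdisk are all much larger than the statistics reported in the paper. How can I get the same versions of the datasets as in the paper? Many thanks!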
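To compare a downloaded file against the paper's statistics, a rough check is to count users, items and interactions directly. The sketch below assumes RecBole's atomic `.inter` format (tab-separated, with a header such as `user_id:token` and `item_id:token` columns); `inter_stats` is a hypothetical helper. It counts raw rows, so duplicate user-item pairs are included, and density is interactions / (users × items).

```python
def inter_stats(lines, user_field="user_id", item_field="item_id"):
    """Count users, items, interactions and density in RecBole-style .inter lines.

    Assumes the first line is a tab-separated header such as
    'user_id:token<TAB>item_id:token<TAB>rating:float'.
    """
    rows = iter(lines)
    header = next(rows).rstrip("\r\n").split("\t")
    names = [col.split(":")[0] for col in header]  # drop the ':type' suffix
    u_idx, i_idx = names.index(user_field), names.index(item_field)

    users, items, n_inter = set(), set(), 0
    for line in rows:
        line = line.rstrip("\r\n")
        if not line:  # skip blank lines
            continue
        cols = line.split("\t")
        users.add(cols[u_idx])
        items.add(cols[i_idx])
        n_inter += 1

    density = n_inter / (len(users) * len(items)) if users and items else 0.0
    return {"users": len(users), "items": len(items),
            "interactions": n_inter, "density": density}


sample = [
    "user_id:token\titem_id:token\trating:float\n",
    "1\t10\t5\n",
    "1\t11\t4\n",
    "2\t10\t3\n",
]
print(inter_stats(sample))
# → {'users': 2, 'items': 2, 'interactions': 3, 'density': 0.75}
```

On a real file, something like `with open("yelp2018.inter", encoding="utf-8") as f: print(inter_stats(f))` would work (the file name is hypothetical), since a text-mode file object yields lines.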