
Dimension mismatch between the loaded model checkpoint and the initialized actor/critic during backtesting #13

Open
wuxiawei opened this issue May 8, 2024 · 4 comments

Comments

@wuxiawei

wuxiawei commented May 8, 2024

When running:
results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac',environment=test_trade_gym, cwd=model_path)
I get the following error:
RuntimeError: Error(s) in loading state_dict for SACPolicy:
size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
What could be causing this?
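The shapes in the traceback can be decoded a bit further. Assuming the policy follows the Stable-Baselines3 SACPolicy layout (the parameter names actor.mu, actor.log_std and critic.qf0 suggest it does), actor.mu has one output row per action dimension, and the first layer of each critic takes the observation features concatenated with the action. Under that assumption, this small sketch (the helper name is hypothetical) backs out both dimensions for the checkpoint and for the freshly built model:

```python
# Hypothetical helper: recover observation-feature and action dimensions from the
# parameter shapes printed in the size-mismatch error.
# Assumption: SB3-style SACPolicy, where actor.mu is Linear(hidden, action_dim)
# and the first layer of each critic.qfN is Linear(obs_features_dim + action_dim, hidden).

# (out_features, in_features) weight shapes copied from the traceback
SHAPES = {
    "checkpoint": {"actor.mu.weight": (88, 256), "critic.qf0.0.weight": (256, 11440)},
    "current": {"actor.mu.weight": (31064, 256), "critic.qf0.0.weight": (256, 42416)},
}


def infer_dims(mu_shape, qf_shape):
    """Return (obs_features_dim, action_dim) for one set of weight shapes."""
    action_dim = mu_shape[0]  # one output unit per action dimension
    critic_in = qf_shape[1]   # critic input width = obs features + action
    return critic_in - action_dim, action_dim


for name, shapes in SHAPES.items():
    obs_dim, action_dim = infer_dims(shapes["actor.mu.weight"], shapes["critic.qf0.0.weight"])
    print(f"{name}: obs_features_dim={obs_dim}, action_dim={action_dim}")
# checkpoint: obs_features_dim=11352, action_dim=88
# current: obs_features_dim=11352, action_dim=31064
```

Both cases give an observation feature size of 11352 (11440 - 88 and 42416 - 31064), so the observation side agrees and only the action dimension differs (88 in the checkpoint vs. 31064 in the rebuilt model). If the layout assumption holds, that suggests the action space is being reconstructed incorrectly at load time rather than the test environment being different.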

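@wuxiawei wuxiawei changed the title from "Dimension mismatch between the loaded model and the checkpoint during backtesting" to "Dimension mismatch between the loaded model checkpoint and the initialized actor/critic during backtesting" on May 9, 2024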
@xbkaishui

Has this issue been resolved?

@wuxiawei
Author

Not yet. Have you managed to solve it?

@yo-yoo

yo-yoo commented Jun 23, 2024

Has this issue been resolved?

I ran into the same problem. Has it been solved? Could it be caused by running on a single GPU?

@mzliu2017

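After a long investigation, I finally traced the error to how the data is saved to and read back from the zip file. When saving, the data is still in the correct format before data_to_json(), but when it is read back later, json_to_data() emits a warning and the data is loaded incorrectly.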
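To see which saved entry fails to deserialize, one option is to open the archive directly. This sketch assumes a Stable-Baselines3-style zip, with a `data` member holding JSON in which non-JSON objects are stored as base64-encoded pickles under a `:serialized:` key; whether the `maesac` checkpoint uses exactly this layout is an assumption, and `inspect_saved_data` is a hypothetical helper. It tries to unpickle each serialized entry with the standard library, roughly the step json_to_data() performs, and reports any that fail with the currently installed packages. Only run it on checkpoints you created yourself, since unpickling can execute arbitrary code.

```python
import base64
import json
import pickle
import zipfile
from typing import BinaryIO, Dict, Union


def inspect_saved_data(zip_file: Union[str, BinaryIO]) -> Dict[str, str]:
    """Report, for each entry in the archive's 'data' JSON, whether it deserializes.

    Assumption: SB3-style layout, where objects that are not plain JSON are stored
    as {":type:": ..., ":serialized:": <base64-encoded pickle>, ...}.
    """
    with zipfile.ZipFile(zip_file) as archive:
        data = json.loads(archive.read("data").decode("utf-8"))

    report = {}
    for key, item in data.items():
        if isinstance(item, dict) and ":serialized:" in item:
            try:
                obj = pickle.loads(base64.b64decode(item[":serialized:"]))
            except Exception as exc:  # broad on purpose: report any failure
                report[key] = f"FAILED ({type(exc).__name__}: {exc})"
            else:
                report[key] = f"ok -> {obj!r}"
        else:
            report[key] = "plain JSON"
    return report
```

Running it on the same checkpoint once under numpy 1.24.0 and once under 1.23.5 would show whether entries such as the action space only deserialize under the older version.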
Changing the NumPy version fixes this bug: I was using numpy=1.24.0, and after switching to 1.23.5 it works (python=3.8.19). I didn't dig into the underlying cause, though. Hope this helps!
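If you bake this workaround into an environment, a small guard can flag a NumPy release at or above 1.24 before loading a checkpoint. This is only a sketch: the thread confirms just that 1.24.0 failed and 1.23.5 worked with Python 3.8.19, so treating every version from 1.24 onward as suspect is an assumption, and the helper names are hypothetical.

```python
from importlib import metadata

# Assumption drawn from this thread only: numpy 1.24.0 broke checkpoint loading,
# numpy 1.23.5 worked, so anything at or above (1, 24) is treated as suspect.
SUSPECT_FROM = (1, 24)


def parse_major_minor(version: str) -> tuple:
    """Return (major, minor) from a version string such as '1.23.5'."""
    parts = version.split(".")
    return int(parts[0]), int(parts[1])


def numpy_matches_workaround(version: str) -> bool:
    """True if the version is older than 1.24, in line with the downgrade that worked."""
    return parse_major_minor(version) < SUSPECT_FROM


def check_installed_numpy() -> None:
    """Print a warning if the installed numpy is in the suspect range."""
    try:
        installed = metadata.version("numpy")
    except metadata.PackageNotFoundError:
        print("numpy is not installed")
        return
    if numpy_matches_workaround(installed):
        print(f"numpy {installed}: older than 1.24, consistent with the workaround")
    else:
        print(f"numpy {installed}: 1.24 or newer; consider numpy==1.23.5 before loading")


if __name__ == "__main__":
    check_installed_numpy()
```

Comparing only (major, minor) keeps the check simple and ignores patch or pre-release suffixes.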


4 participants