-
Notifications
You must be signed in to change notification settings - Fork 65
Fix axis error in normalization layer when loading model from tf backend saved h5 #258
Conversation
Hi @leondgarse , thank you so much for your contribution. The build error you see in CI may not due to your code change, I will double check and get back to you. We are having some problem in the CI system. In the meantime, could you add a unit test here? https://github.com/awslabs/keras-apache-mxnet/tree/master/tests/keras/backend Thanks! |
Hi @roywei, I added a unit test file import mxnet_tf_model_test
aa = mxnet_tf_model_test.TestMXNetTfModel()
aa.test_batchnorm_layer_reload()
# Using MXNet backend
# axis = [1]
# (1, 10) 1 10
# axis = -1
# (1, 10) 1 10
# axis = 1
# (1, 10) 1 10 It tests loading a tf backend saved model, and then loading a mxnet backend saved model, to make sure everything alright. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, could you rebase and trigger CI? it should pass now. Thanks for your contribution!
Thanks for your work. Is my triggering the right method? What's the failure this time? |
@leondgarse could you try push an empty commit and trigger CI? It seems a multi-threaded test failed in our test environment(docker), and it seems random. So re-trigger should work. Our nightly tests have been passing for a few days.
|
@leondgarse It seems you PR is constantly failing with the same multi-threaded test above. I'm not sure why it's failing. Nightly test all passed. You PR is not affect that test. Could you try reset to commit I will also double check why the test constantly fails on your case. Sorry for the inconvenience caused. |
Ya! This is much more like it, here is my commands: git reset --hard 6e230a9
git pull upstream master --rebase
git push --force |
@leondgarse Awsome! merging now. Thanks for your contribution! |
Fix axis error in normalization layer when loading model from tf backend saved h5
Summary
When loading model saved by tensorflow backend keras h5 file, met an error:
I write a little demo to reproduce it:
Related Issues
None
PR Overview
It seems in tensorflow backend keras,
axis
inBatchNormalization
is a list, so I add anisinstance
test to theself.axis
init. Then theload_model
function passed.