Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Fix axis error in normalization layer when loading model from tf backend saved h5 #258

Merged
merged 4 commits into from
Apr 7, 2020
Merged

Conversation

leondgarse
Copy link

@leondgarse leondgarse commented Apr 2, 2020

Fix axis error in normalization layer when loading model from tf backend saved h5

Summary

When loading model saved by tensorflow backend keras h5 file, met an error:

/opt/anaconda3/lib/python3.7/site-packages/keras/layers/normalization.py in build(self, input_shape)
     98
     99     def build(self, input_shape):
--> 100         dim = input_shape[self.axis]
    101         print(input_shape, self.axis, dim)
    102         if dim is None
TypeError: tuple indices must be integers or slices, not list

I write a little demo to reproduce it:

  • Save a tensorflow backend keras h5 model file:
from tensorflow import keras
mm = keras.models.Sequential([
    keras.layers.Dense(10, kernel_initializer="zeros"),
    keras.layers.BatchNormalization()])
mm.build(input_shape=(1, 10))
mm.summary()
mm.layers[-1].axis  # This is `1` if backend is MXNet, and data_format is 'channels_first'
# ListWrapper([1])
mm.save('mm.h5')
  • Load from MXNet backend keras:
# $ KERAS_BACKEND='mxnet' ipython
import keras
# Using MXNet backend
mm = keras.models.load_model('mm.h5', compile=False)
# /opt/anaconda3/lib/python3.7/site-packages/keras/layers/normalization.py in build(self, input_shape)
# TypeError: tuple indices must be integers or slices, not list

Related Issues

None

PR Overview

It seems in tensorflow backend keras, axis in BatchNormalization is a list, so I add an isinstance test to the self.axis init. Then the load_model function passed.

  • [y] This PR requires new unit tests [y/n] (make sure tests are included)
  • [n] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
  • [y] This PR is backwards compatible [y/n]
  • [n] This PR changes the current API [y/n]

@roywei
Copy link

roywei commented Apr 2, 2020

Hi @leondgarse , thank you so much for your contribution.

The build error you see in CI may not due to your code change, I will double check and get back to you. We are having some problem in the CI system.

In the meantime, could you add a unit test here? https://github.com/awslabs/keras-apache-mxnet/tree/master/tests/keras/backend
You can use your reproducible code as a test and add a new file, something like mxnet_tf_model_test.py

Thanks!

@leondgarse
Copy link
Author

leondgarse commented Apr 2, 2020

Hi @roywei, I added a unit test file tests/keras/backend/mxnet_tf_model_test.py, and my local test is ok:

import mxnet_tf_model_test
aa = mxnet_tf_model_test.TestMXNetTfModel()
aa.test_batchnorm_layer_reload()
# Using MXNet backend
# axis =  [1]
# (1, 10) 1 10
# axis =  -1
# (1, 10) 1 10
# axis =  1
# (1, 10) 1 10

It tests loading a tf backend saved model, and then loading a mxnet backend saved model, to make sure everything alright.

Copy link

@roywei roywei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, could you rebase and trigger CI? it should pass now. Thanks for your contribution!

@leondgarse
Copy link
Author

Thanks for your work. Is my triggering the right method? What's the failure this time?

@roywei
Copy link

roywei commented Apr 6, 2020

@leondgarse could you try push an empty commit and trigger CI? git commit --allow-empty -m "trigger ci"

It seems a multi-threaded test failed in our test environment(docker), and it seems random. So re-trigger should work. Our nightly tests have been passing for a few days.


=================================== FAILURES ===================================
--
699 | ________________________________ test_warnings _________________________________
700 | [gw1] linux -- Python 3.7.6 /root/.pyenv/versions/3.7.6/bin/python3.7
701 |  
702 | @pytest.mark.skipif(sys.version_info < (3,),
703 | reason='Cannot catch warnings in python 2')
704 | def test_warnings():
705 | a = Input(shape=(3,), name='input_a')
706 | b = Input(shape=(3,), name='input_b')
707 |  
708 | a_2 = Dense(4, name='dense_1')(a)
709 | dp = Dropout(0.5, name='dropout')
710 | b_2 = dp(b)
711 |  
712 | model = Model([a, b], [a_2, b_2])
713 |  
714 | optimizer = 'rmsprop'
715 | loss = 'mse'
716 | loss_weights = [1., 0.5]
717 | model.compile(optimizer, loss, metrics=[], loss_weights=loss_weights,
718 | sample_weight_mode=None)
719 |  
720 | @threadsafe_generator
721 | def gen_data(batch_sz):
722 | while True:
723 | yield ([np.random.random((batch_sz, 3)),
724 | np.random.random((batch_sz, 3))],
725 | [np.random.random((batch_sz, 4)),
726 | np.random.random((batch_sz, 3))])
727 |  
728 | with pytest.warns(Warning) as w:
729 | out = model.fit_generator(gen_data(4),
730 | steps_per_epoch=10,
731 | use_multiprocessing=True,
732 | >                                     workers=2)
733 |  
734 | tests/keras/engine/test_training.py:607:
735 | _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
736 | keras/legacy/interfaces.py:91: in wrapper
737 | return func(*args, **kwargs)
738 | keras/engine/training.py:1433: in fit_generator
739 | initial_epoch=initial_epoch)
740 | keras/engine/training_generator.py:181: in fit_generator
741 | generator_output = next(output_generator)
742 | keras/utils/data_utils.py:695: in get
743 | inputs = self.queue.get(block=True).get()
744 | /root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py:651: in get
745 | self.wait(timeout)
746 | /root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py:648: in wait
747 | self._event.wait(timeout)
748 | /root/.pyenv/versions/3.7.6/lib/python3.7/threading.py:552: in wait
749 | signaled = self._cond.wait(timeout)
750 | _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
751 |  
752 | self = <Condition(<unlocked _thread.lock object at 0x7fba4402c990>, 0)>
753 | timeout = None
754 |  
755 | def wait(self, timeout=None):
756 | """Wait until notified or until a timeout occurs.
757 |  
758 | If the calling thread has not acquired the lock when this method is
759 | called, a RuntimeError is raised.
760 |  
761 | This method releases the underlying lock, and then blocks until it is
762 | awakened by a notify() or notify_all() call for the same condition
763 | variable in another thread, or until the optional timeout occurs. Once
764 | awakened or timed out, it re-acquires the lock and returns.
765 |  
766 | When the timeout argument is present and not None, it should be a
767 | floating point number specifying a timeout for the operation in seconds
768 | (or fractions thereof).
769 |  
770 | When the underlying lock is an RLock, it is not released using its
771 | release() method, since this may not actually unlock the lock when it
772 | was acquired multiple times recursively. Instead, an internal interface
773 | of the RLock class is used, which really unlocks it even when it has
774 | been recursively acquired several times. Another internal interface is
775 | then used to restore the recursion level when the lock is reacquired.
776 |  
777 | """
778 | if not self._is_owned():
779 | raise RuntimeError("cannot wait on un-acquired lock")
780 | waiter = _allocate_lock()
781 | waiter.acquire()
782 | self._waiters.append(waiter)
783 | saved_state = self._release_save()
784 | gotit = False
785 | try:    # restore state no matter what (e.g., KeyboardInterrupt)
786 | if timeout is None:
787 | >               waiter.acquire()
788 | E               Failed: Timeout >1200.0s
789 |  
790 | /root/.pyenv/versions/3.7.6/lib/python3.7/threading.py:296: Failed
791 | ----------------------------- Captured stdout call -----------------------------
792 | Epoch 1/1
793 | ----------------------------- Captured stderr call -----------------------------
794 |  
795 | +++++++++++++++++++++++++++++++++++ Timeout ++++++++++++++++++++++++++++++++++++
796 |  
797 | ~~~~~~~~~~~~~~~~~~~~ Stack of Thread-272 (140435986855680) ~~~~~~~~~~~~~~~~~~~~~
798 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
799 | self._bootstrap_inner()
800 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
801 | self.run()
802 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
803 | self._target(*self._args, **self._kwargs)
804 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py", line 470, in _handle_results
805 | task = get()
806 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/connection.py", line 250, in recv
807 | buf = self._recv_bytes()
808 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
809 | buf = self._recv(4)
810 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/connection.py", line 379, in _recv
811 | chunk = read(handle, remaining)
812 |  
813 | ~~~~~~~~~~~~~~~~~~~~ Stack of Thread-271 (140436114863872) ~~~~~~~~~~~~~~~~~~~~~
814 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
815 | self._bootstrap_inner()
816 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
817 | self.run()
818 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
819 | self._target(*self._args, **self._kwargs)
820 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py", line 422, in _handle_tasks
821 | for taskseq, set_length in iter(taskqueue.get, None):
822 |  
823 | ~~~~~~~~~~~~~~~~~~~~ Stack of Thread-270 (140438099777280) ~~~~~~~~~~~~~~~~~~~~~
824 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
825 | self._bootstrap_inner()
826 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
827 | self.run()
828 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
829 | self._target(*self._args, **self._kwargs)
830 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py", line 413, in _handle_workers
831 | time.sleep(0.1)
832 |  
833 | ~~~~~~~~~~~~~~~~~~~~ Stack of Thread-269 (140435410818816) ~~~~~~~~~~~~~~~~~~~~~
834 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
835 | self._bootstrap_inner()
836 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
837 | self.run()
838 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
839 | self._target(*self._args, **self._kwargs)
840 | File "/codebuild/output/src242250753/src/github.com/awslabs/keras-apache-mxnet/keras/utils/data_utils.py", line 681, in _run
841 | executor.apply_async(next_sample, (self.uid,)), block=True)
842 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/queue.py", line 139, in put
843 | self.not_full.wait()
844 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 296, in wait
845 | waiter.acquire()
846 |  
847 | ~~~~~~~~~~~~~~~~~~~~~ Stack of <unknown> (140440431216384) ~~~~~~~~~~~~~~~~~~~~~
848 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 285, in _perform_spawn
849 | reply.run()
850 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 220, in run
851 | self._result = func(*args, **kwargs)
852 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 967, in _thread_receiver
853 | msg = Message.from_io(io)
854 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 432, in from_io
855 | header = io.read(9)  # type 1, channel 4, payload 4
856 | File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 400, in read
857 | data = self._read(numbytes - len(buf))
858 |  
859 | +++++++++++++++++++++++++++++++++++ Timeout ++++++++++++++++++++++++++++++++++++


@roywei
Copy link

roywei commented Apr 7, 2020

@leondgarse It seems you PR is constantly failing with the same multi-threaded test above. I'm not sure why it's failing. Nightly test all passed. You PR is not affect that test.

Could you try reset to commit 6e230a9 and do a git pull --rebase instead of merge? Ideally your PR should not contain my changes.

I will also double check why the test constantly fails on your case.

Sorry for the inconvenience caused.

@leondgarse
Copy link
Author

Ya! This is much more like it, here is my commands:

git reset --hard 6e230a9
git pull upstream master --rebase
git push --force

@roywei
Copy link

roywei commented Apr 7, 2020

@leondgarse Awsome! merging now. Thanks for your contribution!

@roywei roywei merged commit 5e5e74c into awslabs:master Apr 7, 2020
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants