Allow saving weights of a very deep model into an HDF5 file. #7508

Closed · wants to merge 3 commits
77 changes: 70 additions & 7 deletions keras/engine/topology.py
@@ -22,6 +22,7 @@

try:
    import h5py
    HDF5_OBJECT_HEADER_LIMIT = 64512
except ImportError:
    h5py = None
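For context on why this constant exists: HDF5 keeps attribute data in the object header, which has a hard size limit of roughly 64 KiB, so writing a single oversized attribute fails outright rather than degrading gracefully. Below is a minimal sketch of that failure mode using plain h5py (the file name is illustrative, and the exact exception type and message vary across h5py/HDF5 versions):

import h5py
import numpy as np

# Build ~110 KB worth of layer names; as a fixed-width numpy bytes array this
# comfortably exceeds the ~64 KiB object header limit.
names = np.array([('layer_%d' % i).encode('utf8') for i in range(10000)])

with h5py.File('limit_demo.h5', 'w') as f:
    try:
        f.attrs['layer_names'] = names   # a single attribute that is too large
    except (RuntimeError, OSError) as e:
        print('h5py refused the oversized attribute:', e)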

@@ -2825,10 +2826,72 @@ def _collect_input_shape(input_tensors):
return shapes


def _save_attributes_to_hdf5_group(group, name, data):
    """Saves attributes (data) of the specified name into the HDF5 group.

    This method deals with an inherent problem of HDF5 files, which cannot
    store attribute data larger than HDF5_OBJECT_HEADER_LIMIT bytes.

    # Arguments
        group: A pointer to an HDF5 group.
        name: The name of the attributes to save.
        data: The attribute data to store.
    """
    # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`,
    # because in that case even chunking the array would not make saving
    # possible.
    bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]

    # Expecting this to never be true.
    if len(bad_attributes) > 0:
        raise RuntimeError("The following attributes cannot be saved to HDF5 "
                           "file because they are larger than %d bytes: '%s'"
                           % (HDF5_OBJECT_HEADER_LIMIT,
                              "', '".join([x for x in bad_attributes])))

    data_npy = np.asarray(data)

    n_chunks = 1
    chunked_data = np.array_split(data_npy, n_chunks)

    # This will never loop forever thanks to the check above.
    while any(map(lambda x: x.nbytes > HDF5_OBJECT_HEADER_LIMIT, chunked_data)):
        n_chunks += 1
        chunked_data = np.array_split(data_npy, n_chunks)

    if n_chunks > 1:
        for chunk_id, chunk_data in enumerate(chunked_data):
            group.attrs['%s%d' % (name, chunk_id)] = chunk_data
    else:
        group.attrs[name] = data


def _load_attributes_from_hdf5_group(group, name):
    """Loads attributes of the specified name from the HDF5 group.

    This method deals with an inherent problem of HDF5 files, which cannot
    store attribute data larger than HDF5_OBJECT_HEADER_LIMIT bytes.

    # Arguments
        group: A pointer to an HDF5 group.
        name: The name of the attributes to load.

    # Returns
        data: The attribute data.
    """
    if name in group.attrs:
        data = [n.decode('utf8') for n in group.attrs[name]]
    else:
        data = []
        chunk_id = 0
        while ('%s%d' % (name, chunk_id)) in group.attrs:
            data.extend([n.decode('utf8')
                         for n in group.attrs['%s%d' % (name, chunk_id)]])
            chunk_id += 1
    return data


def save_weights_to_hdf5_group(f, layers):
from .. import __version__ as keras_version

f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
_save_attributes_to_hdf5_group(f, 'layer_names', [layer.name.encode('utf8') for layer in layers])
f.attrs['backend'] = K.backend().encode('utf8')
f.attrs['keras_version'] = str(keras_version).encode('utf8')

@@ -2843,7 +2906,7 @@ def save_weights_to_hdf5_group(f, layers):
else:
name = 'param_' + str(i)
weight_names.append(name.encode('utf8'))
g.attrs['weight_names'] = weight_names
_save_attributes_to_hdf5_group(g, 'weight_names', weight_names)
for name, val in zip(weight_names, weight_values):
param_dset = g.create_dataset(name, val.shape,
dtype=val.dtype)
@@ -3042,11 +3105,11 @@ def load_weights_from_hdf5_group(f, layers):
if weights:
filtered_layers.append(layer)

layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
layer_names = _load_attributes_from_hdf5_group(f, 'layer_names')
filtered_layer_names = []
for name in layer_names:
g = f[name]
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
weight_names = _load_attributes_from_hdf5_group(g, 'weight_names')
if weight_names:
filtered_layer_names.append(name)
layer_names = filtered_layer_names
@@ -3061,7 +3124,7 @@ def load_weights_from_hdf5_group(f, layers):
weight_value_tuples = []
for k, name in enumerate(layer_names):
g = f[name]
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
weight_names = _load_attributes_from_hdf5_group(g, 'weight_names')
weight_values = [g[weight_name] for weight_name in weight_names]
layer = filtered_layers[k]
symbolic_weights = layer.weights
@@ -3109,7 +3172,7 @@ def load_weights_from_hdf5_group_by_name(f, layers):
original_backend = None

# New file format.
layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
layer_names = _load_attributes_from_hdf5_group(f, 'layer_names')

# Reverse index of layer name to list of layers with name.
index = {}
@@ -3122,7 +3185,7 @@ def load_weights_from_hdf5_group_by_name(f, layers):
weight_value_tuples = []
for k, name in enumerate(layer_names):
g = f[name]
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
weight_names = _load_attributes_from_hdf5_group(g, 'weight_names')
weight_values = [g[weight_name] for weight_name in weight_names]

for layer in index.get(name, []):
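To make the behaviour of the two helpers above concrete, here is a standalone round-trip sketch of the same chunking scheme written directly against h5py. The function names (`save_chunked`, `load_chunked`) and the demo file name are illustrative, not part of this PR; the logic mirrors `_save_attributes_to_hdf5_group` and `_load_attributes_from_hdf5_group`:

import h5py
import numpy as np

HDF5_OBJECT_HEADER_LIMIT = 64512


def save_chunked(group, name, data):
    # Split the array into as few chunks as possible, each small enough to
    # fit into a single attribute, stored as name0, name1, ...
    data_npy = np.asarray(data)
    n_chunks = 1
    chunked = np.array_split(data_npy, n_chunks)
    while any(chunk.nbytes > HDF5_OBJECT_HEADER_LIMIT for chunk in chunked):
        n_chunks += 1
        chunked = np.array_split(data_npy, n_chunks)
    if n_chunks > 1:
        for i, chunk in enumerate(chunked):
            group.attrs['%s%d' % (name, i)] = chunk
    else:
        group.attrs[name] = data


def load_chunked(group, name):
    # Reassemble either the single attribute or the numbered chunks.
    if name in group.attrs:
        return [n.decode('utf8') for n in group.attrs[name]]
    data, i = [], 0
    while ('%s%d' % (name, i)) in group.attrs:
        data.extend(n.decode('utf8') for n in group.attrs['%s%d' % (name, i)])
        i += 1
    return data


# ~210 KB of names in total: too large for one attribute, fine once chunked.
names = [('layer_' + 'x' * 200 + '_%d' % i).encode('utf8') for i in range(1000)]
with h5py.File('chunking_demo.h5', 'w') as f:
    save_chunked(f, 'layer_names', names)
with h5py.File('chunking_demo.h5', 'r') as f:
    assert load_chunked(f, 'layer_names') == [n.decode('utf8') for n in names]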
89 changes: 89 additions & 0 deletions tests/test_model_saving.py
@@ -1,5 +1,6 @@
import pytest
import os
import h5py
import tempfile
import numpy as np
from numpy.testing import assert_allclose
@@ -337,5 +338,93 @@ def test_saving_custom_activation_function():
assert_allclose(out, out2, atol=1e-05)


@keras_test
def test_saving_model_with_long_layer_names():
    # This layer name will make the `layer_names` HDF5 attribute blow
    # out of proportion. Note that it fits into the internal HDF5
    # attribute memory limit on its own, but because h5py converts
    # the list of layer names into a numpy array, which uses the same
    # amount of memory for every item, it increases the memory
    # requirements substantially.
    x = Input(shape=(2,), name='input_' + ('x' * (2**15)))
    f = x
    for i in range(4):
        f = Dense(2, name='dense_%d' % (i,))(f)

    model = Model(inputs=[x], outputs=[f])

    model.compile(loss='mse', optimizer='adam', metrics=['acc'])

    x = np.random.random((1, 2))
    y = np.random.random((1, 2))
    model.train_on_batch(x, y)

    out = model.predict(x)

    _, fname = tempfile.mkstemp('.h5')
    save_model(model, fname)

    model = load_model(fname)

    # Check that the HDF5 file contains a chunked array
    # of layer names.
    with h5py.File(fname, 'r') as h5file:
        n_layer_names_arrays = len([attr for attr in h5file['model_weights'].attrs
                                    if attr.startswith('layer_names')])

    os.remove(fname)

    # The chunking of the layer names array should have happened.
    assert n_layer_names_arrays > 0

    out2 = model.predict(x)
    assert_allclose(out, out2, atol=1e-05)


@keras_test
def test_saving_model_with_long_weights_names():
    x = Input(shape=(2,), name='nested_model_input')
    f = x
    for i in range(4):
        f = Dense(2, name='nested_model_dense_%d' % (i,))(f)
    # This layer name will make the `weight_names`
    # HDF5 attribute blow out of proportion.
    f = Dense(2, name='nested_model_output' + ('x' * (2**15)))(f)
    nested_model = Model(inputs=[x], outputs=[f], name='nested_model')

    x = Input(shape=(2,), name='outer_model_input')
    f = nested_model(x)
    f = Dense(2, name='outer_model_output')(f)

    model = Model(inputs=[x], outputs=[f])

    model.compile(loss='mse', optimizer='adam', metrics=['acc'])

    x = np.random.random((1, 2))
    y = np.random.random((1, 2))
    model.train_on_batch(x, y)

    out = model.predict(x)

    _, fname = tempfile.mkstemp('.h5')
    save_model(model, fname)

    model = load_model(fname)

    # Check that the HDF5 file contains a chunked array
    # of weight names.
    with h5py.File(fname, 'r') as h5file:
        attrs = h5file['model_weights']['nested_model'].attrs
        n_weight_names_arrays = len([attr for attr in attrs
                                     if attr.startswith('weight_names')])

    os.remove(fname)

    # The chunking of the weight names array should have happened.
    assert n_weight_names_arrays > 0

    out2 = model.predict(x)
    assert_allclose(out, out2, atol=1e-05)


if __name__ == '__main__':
pytest.main([__file__])