Skip to content

Commit

Permalink
Raise error in prune_low_magnitude when unsupported to_prune object i…
Browse files Browse the repository at this point in the history
…s passed in.

PiperOrigin-RevId: 284044063
  • Loading branch information
alanchiao authored and tensorflower-gardener committed Dec 5, 2019
1 parent cc05be8 commit 089fadb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
23 changes: 18 additions & 5 deletions tensorflow_model_optimization/python/core/sparsity/keras/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def prune_low_magnitude(to_prune,
**kwargs):
"""Modify a keras layer or model to be pruned during training.
This function wraps a keras model or layer with pruning functionality which
This function wraps a tf.keras model or layer with pruning functionality which
sparsifies the layer's weights during training. For example, using this with
50% sparsity will ensure that 50% of the layer's weights are zero.
The function accepts either a single keras layer
(subclass of `keras.layers.Layer`), list of keras layers or a keras model
(instance of `keras.models.Model`) and handles them appropriately.
(subclass of `tf.keras.layers.Layer`), list of keras layers or a Sequential
or Functional keras model and handles them appropriately.
If it encounters a layer it does not know how to handle, it will throw an
error. While pruning an entire model, even a single unknown layer would lead
Expand Down Expand Up @@ -144,15 +144,28 @@ def _add_pruning_wrapper(layer):
'block_size': block_size,
'block_pooling_type': block_pooling_type
}
is_sequential_or_functional = isinstance(
to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential) or
to_prune._is_graph_network)

# A subclassed model is also a subclass of keras.layers.Layer.
is_keras_layer = isinstance(
to_prune, keras.layers.Layer) and not isinstance(to_prune, keras.Model)

if isinstance(to_prune, list):
return _prune_list(to_prune, **params)
elif isinstance(to_prune, keras.Model):
elif is_sequential_or_functional:
return keras.models.clone_model(
to_prune, input_tensors=None, clone_function=_add_pruning_wrapper)
elif isinstance(to_prune, keras.layers.Layer):
elif is_keras_layer:
params.update(kwargs)
return pruning_wrapper.PruneLowMagnitude(to_prune, **params)
else:
raise ValueError(
'`prune_low_magnitude` can only prune an object of the following '
'types: tf.keras.models.Sequential, tf.keras functional model, '
'tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed '
'an object of type: {input}.'.format(input=to_prune.__class__.__name__))


def strip_pruning(model):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper


class TestModel(keras.Model):
class TestSubclassedModel(keras.Model):
"""A model subclass."""

def __init__(self):
"""A test subclass model with one dense layer."""
super(TestModel, self).__init__(name='test_model')
super(TestSubclassedModel, self).__init__(name='test_model')
self.layer1 = keras.layers.Dense(10, activation='relu')

def call(self, inputs):
Expand All @@ -55,6 +55,13 @@ class CustomNonPrunableLayer(layers.Dense):

class PruneTest(test.TestCase, parameterized.TestCase):

INVALID_TO_PRUNE_PARAM_ERROR = ('`prune_low_magnitude` can only prune an '
'object of the following types: '
'tf.keras.models.Sequential, tf.keras '
'functional model, tf.keras.layers.Layer, '
'list of tf.keras.layers.Layer. You passed an'
' object of type: {input}.')

def setUp(self):
super(PruneTest, self).setUp()

Expand Down Expand Up @@ -319,9 +326,21 @@ def testPruneFunctionalModelPreservesBuiltState(self):
self.assertEqual(loaded_model.built, True)

def testPruneSubclassModel(self):
model = TestModel()
with self.assertRaises(ValueError):
model = TestSubclassedModel()
with self.assertRaises(ValueError) as e:
_ = prune.prune_low_magnitude(model, **self.params)
self.assertEqual(
str(e.exception),
self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='TestSubclassedModel'))

def testPruneMiscObject(self):

model = object()
with self.assertRaises(ValueError) as e:
_ = prune.prune_low_magnitude(model, **self.params)
self.assertEqual(
str(e.exception),
self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='object'))

def testStripPruningSequentialModel(self):
model = keras.Sequential([
Expand Down

0 comments on commit 089fadb

Please sign in to comment.