Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested layer customization and enable better 'in_channels' exception raise #1015

Merged
merged 12 commits into from
Jul 8, 2019
Merged
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ To release a new version, please update the changelog as followed:
## [Unreleased]

### Added
- Support nested layer customization (PR #1015)

### Changed

Expand All @@ -83,15 +84,15 @@ To release a new version, please update the changelog as followed:

### Fixed
- Fix `tf.models.Model._construct_graph` for list of outputs, e.g. STN case (PR #1010)

- Enable better `in_channels` exception raise. (PR #1015)
### Removed

### Security

### Contributors

- @zsdonghao
- @ChrisWu1997: #1010
- @ChrisWu1997: #1010 #1015

## [2.1.0]

Expand Down
19 changes: 15 additions & 4 deletions examples/database/task_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# load dataset from database
X_train, y_train, X_val, y_val, X_test, y_test = db.find_top_dataset('mnist')


# define the network
def mlp():
ni = tl.layers.Input([None, 784], name='input')
Expand All @@ -24,15 +25,18 @@ def mlp():
M = tl.models.Model(inputs=ni, outputs=net)
return M


network = mlp()

# cost and accuracy
cost = tl.cost.cross_entropy


def acc(y, y_):
    """Return the mean rate at which the arg-max of ``y`` matches ``y_``.

    The arg-max of ``y`` is taken along axis 1 and compared element-wise with
    ``y_`` after ``y_`` has been converted to an int64 tensor.  The matches are
    cast to float32 and averaged.
    """
    # NOTE(review): presumably ``y`` is (batch, n_classes) and ``y_`` holds
    # integer class ids -- confirm against the caller.
    predicted = tf.argmax(y, 1)
    labels = tf.convert_to_tensor(y_, tf.int64)
    hits = tf.cast(tf.equal(predicted, labels), tf.float32)
    return tf.reduce_mean(hits)


# define the optimizer
train_op = tf.optimizers.Adam(learning_rate=0.0001)

Expand All @@ -43,8 +47,17 @@ def acc(y, y_):
# )

# Train the network.  Every keyword is passed exactly once: repeating a
# keyword argument in a call is a SyntaxError.
tl.utils.fit(
    network,
    train_op=tf.optimizers.Adam(learning_rate=0.0001),
    cost=tl.cost.cross_entropy,
    X_train=X_train,
    y_train=y_train,
    acc=acc,
    batch_size=256,
    n_epoch=20,
    X_val=X_val,
    y_val=y_val,
    eval_train=False,
)

# evaluation and save result that match the result_key
Expand All @@ -55,5 +68,3 @@ def acc(y, y_):
db.save_model(network, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy)
# in other script, you can load the model as follow
# net = db.find_model(sess=sess, model_name=str(n_units1)+'-'+str(n_units2)

tf.python.keras.layers.BatchNormalization
48 changes: 39 additions & 9 deletions tensorlayer/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,11 @@ def __init__(self, name=None, act=None, *args, **kwargs):

# Layer weight state
self._all_weights = None
self._trainable_weights = None
self._nontrainable_weights = None
self._trainable_weights = []
self._nontrainable_weights = []

# nested layers
self._layers = None

# Layer training state
self.is_train = True
Expand Down Expand Up @@ -179,27 +182,40 @@ def all_weights(self):
if self._all_weights is not None and len(self._all_weights) > 0:
pass
else:
self._all_weights = list()
if self._trainable_weights is not None:
self._all_weights.extend(self._trainable_weights)
if self._nontrainable_weights is not None:
self._all_weights.extend(self._nontrainable_weights)
self._all_weights = self.trainable_weights + self.nontrainable_weights
return self._all_weights

@property
def trainable_weights(self):
    """Return this layer's trainable weights plus those of its nested layers.

    The layer's own ``_trainable_weights`` come first, followed by whatever
    ``_collect_sublayers_attr('trainable_weights')`` gathers.  A new list is
    returned on every access.
    """
    nested = self._collect_sublayers_attr('trainable_weights')
    # Tolerate ``None`` for the layer's own list so that the concatenation
    # cannot fail with a TypeError.
    own = self._trainable_weights or []
    return own + nested

@property
def nontrainable_weights(self):
    """Return this layer's non-trainable weights plus those of its nested layers.

    The layer's own ``_nontrainable_weights`` come first, followed by whatever
    ``_collect_sublayers_attr('nontrainable_weights')`` gathers.  A new list
    is returned on every access.
    """
    nested = self._collect_sublayers_attr('nontrainable_weights')
    # Tolerate ``None`` for the layer's own list so that the concatenation
    # cannot fail with a TypeError.
    own = self._nontrainable_weights or []
    return own + nested

@property
def weights(self):
    """Always raise: ``.weights`` is deliberately not provided.

    The error message names the three properties that exist instead.
    """
    hint = "no property .weights exists, do you mean .all_weights, .trainable_weights, or .nontrainable_weights ?"
    raise Exception(hint)

def _collect_sublayers_attr(self, attr):
    """Gather one weight attribute from every nested layer into a flat list.

    Parameters
    ----------
    attr : str
        Either ``'trainable_weights'`` or ``'nontrainable_weights'``.

    Returns
    -------
    list
        The values of ``attr`` of each entry in ``self._layers``, concatenated
        in registration order.  Empty when ``self._layers`` is ``None``.

    Raises
    ------
    ValueError
        If ``attr`` is not one of the two supported names.
    """
    if attr not in ('trainable_weights', 'nontrainable_weights'):
        # The trailing space keeps the two adjacent literals from fusing into
        # "layers,e.g." in the rendered message.
        raise ValueError(
            "Only support to collect some certain attributes of nested layers, "
            "e.g. 'trainable_weights', 'nontrainable_weights', but got {}".format(attr)
        )
    if self._layers is None:
        return []
    nested = []
    for layer in self._layers:
        value = getattr(layer, attr)
        # A nested layer may expose None instead of a list; skip it.
        if value is not None:
            nested.extend(value)
    return nested

def __call__(self, inputs, *args, **kwargs):
"""
(1) Build the Layer if necessary.
Expand Down Expand Up @@ -326,6 +342,20 @@ def __setitem__(self, key, item):
def __delitem__(self, key):
    """Reject item deletion: always raises ``TypeError``."""
    reason = "The Layer API does not allow to use the method: `__delitem__`"
    raise TypeError(reason)

def __setattr__(self, key, value):
    """Assign an attribute and keep the nested-layer registry in sync.

    When ``value`` is a ``Layer`` its ``_nodes_fixed`` flag is set to ``True``
    and it is appended to ``self._layers`` (created on first use).  If the
    name being assigned currently holds a ``Layer``, that layer's registration
    is released first, so a replaced layer no longer contributes weights and
    re-assigning the same layer does not register it twice.
    """
    # Read the instance dict directly so that properties are not triggered.
    previous = self.__dict__.get(key)
    if isinstance(previous, Layer) and self._layers is not None:
        # Drop one registration by identity; each attribute binding accounts
        # for exactly one entry in ``_layers``.
        for index, registered in enumerate(self._layers):
            if registered is previous:
                del self._layers[index]
                break
    if isinstance(value, Layer):
        value._nodes_fixed = True
        if self._layers is None:
            self._layers = []
        self._layers.append(value)
    super().__setattr__(key, value)

def __delattr__(self, name):
    """Delete an attribute, unregistering it first when it is a nested layer.

    If the attribute currently resolves to a ``Layer``, one matching entry is
    removed from ``self._layers`` before the attribute itself is deleted.
    """
    target = getattr(self, name, None)
    holds_layer = isinstance(target, Layer)
    if holds_layer:
        self._layers.remove(target)
    super().__delattr__(name)

@protected_method
def get_args(self):
init_args = {"layer_type": "normal"}
Expand Down
9 changes: 9 additions & 0 deletions tensorlayer/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,15 @@ def _fix_nodes_for_layers(self):
layer._fix_nodes_for_layers()
self._nodes_fixed = True

def __setattr__(self, key, value):
    """Assign an attribute, refusing layers that have not been built yet.

    Raises
    ------
    AttributeError
        If ``value`` is a ``Layer`` whose ``_built`` flag is ``False``.
    """
    unbuilt_layer = isinstance(value, Layer) and value._built is False
    if unbuilt_layer:
        template = (
            "The registered layer `{}` should be built in advance. "
            "Do you forget to pass the keyword argument 'in_channels'? "
        )
        raise AttributeError(template.format(value.name))
    super().__setattr__(key, value)

def __repr__(self):
# tmpstr = self.__class__.__name__ + '(\n'
tmpstr = self.name + '(\n'
Expand Down
123 changes: 123 additions & 0 deletions tests/layers/test_layers_core_nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-\
import os
import unittest

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import tensorlayer as tl
import numpy as np

from tests.utils import CustomTestCase


class Layer_nested(CustomTestCase):
    """Checks that weights of a layer nested inside a custom layer are collected."""

    @classmethod
    def setUpClass(cls):
        print("##### begin testing nested layer #####")

    @classmethod
    def tearDownClass(cls):
        pass

    def _check_nested_layer(self, dense_kwargs, static_layer_name):
        """Run the shared dynamic-model and static-model assertions.

        Parameters
        ----------
        dense_kwargs : dict
            Extra keyword arguments for the nested ``Dense`` layer (for
            example ``{'in_channels': 50}`` or ``{}``).
        static_layer_name : str
            Name given to the custom layer in the static model, used to look
            it up again through ``get_layer``.
        """

        class MyLayer(tl.layers.Layer):

            def __init__(self, name=None):
                super(MyLayer, self).__init__(name=name)
                self.input_layer = tl.layers.Dense(n_units=20, **dense_kwargs)
                self.build(None)
                self._built = True

            def build(self, inputs_shape=None):
                self.W = self._get_weights('weights', shape=(20, 10))

            def forward(self, inputs):
                inputs = self.input_layer(inputs)
                output = tf.matmul(inputs, self.W)
                return output

        class model(tl.models.Model):

            def __init__(self, name=None):
                super(model, self).__init__(name=name)
                self.layer = MyLayer()

            def forward(self, inputs):
                return self.layer(inputs)

        samples = tf.random.normal(shape=(100, 50))
        expected_bias = tf.ones(20, ).numpy()

        # Dynamic model: the custom layer is an attribute of a Model subclass.
        model_dynamic = model()
        model_dynamic.train()
        self.assertEqual(model_dynamic(samples).shape, (100, 10))
        self.assertEqual(len(model_dynamic.all_weights), 3)
        self.assertEqual(len(model_dynamic.trainable_weights), 3)
        model_dynamic.layer.input_layer.b.assign_add(tf.ones((20, )))
        self.assertEqual(np.sum(model_dynamic.all_weights[-1].numpy() - expected_bias), 0)

        # Static model: the custom layer is called on an Input placeholder.
        ni = tl.layers.Input(shape=(100, 50))
        nn = MyLayer(name=static_layer_name)(ni)
        model_static = tl.models.Model(inputs=ni, outputs=nn)
        model_static.eval()
        self.assertEqual(model_static(samples).shape, (100, 10))
        self.assertEqual(len(model_static.all_weights), 3)
        self.assertEqual(len(model_static.trainable_weights), 3)
        model_static.get_layer(static_layer_name).input_layer.b.assign_add(tf.ones((20, )))
        self.assertEqual(np.sum(model_static.all_weights[-1].numpy() - expected_bias), 0)

    def test_nested_layer_with_inchannels(self):
        self._check_nested_layer({'in_channels': 50}, 'mylayer1')

    def test_nested_layer_without_inchannels(self):
        # no need for in_channels here
        self._check_nested_layer({}, 'mylayer2')


if __name__ == '__main__':

tl.logging.set_verbosity(tl.logging.DEBUG)

unittest.main()
19 changes: 19 additions & 0 deletions tests/models/test_model_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,25 @@ def test_model_weights_copy(self):
new_len = len(model_weights)
self.assertEqual(new_len - 1, ori_len)

def test_inchannels_exception(self):
    """Constructing a model that registers an unbuilt layer must raise AttributeError."""
    print('-' * 20, 'test_inchannels_exception', '-' * 20)

    class my_model(Model):

        def __init__(self):
            super(my_model, self).__init__()
            self.dense = Dense(64)
            self.vgg = tl.models.vgg16()

        def forward(self, x):
            return x

    # assertRaises fails the test when nothing is raised, which a bare
    # try/except would silently let through.
    with self.assertRaises(AttributeError) as caught:
        my_model()
    print(caught.exception)


if __name__ == '__main__':

Expand Down