Skip to content

Commit

Permalink
Fix 1D convolution layers under Theano backend (#2938)
Browse files Browse the repository at this point in the history
This issue is due to an unexpected loss of dimensionality when
composing the backend tensor operations "reshape" and "squeeze"
when there are dimensions of length 1.

For example, using a Theano backend the following fails with a
complaint about dimension mismatch:

UpSampling1D(2)(MaxPooling1D(2)(Reshape((2,1))(Input(shape=(2,)))))

The issue arises due to the conflict of two behaviors specific
to the Theano backend:

-   Reshape uses Theano's reshape function. Theano's reshape
    automatically makes dimensions with length 1 "broadcastable"

-   MaxPooling1D's implementation class _Pooling1D has a call method
    which uses a dummy dimension which it has to remove. The manner
    in which this dummy method is removed it to call "squeeze(x, axis)"
    from the backend. The squeeze implementation tells Theano to make
    the dummy dimension broadcastable, and then calls Theano's "squeeze",
    which removes ALL the broadcastable dimensions; not just the dummy
    dimension, but also the length 1 dimension flagged as broadcastable
    by reshape. This causes the problem observed above. This behavior
    is distinct from the behavior of the TensorFlow backend, which
    removes only the requested dimension.

This PR addresses this issue in two ways:

First, it introduces a test which checks the composition of "reshape"
and "squeeze" to make sure we get the same result using both Theano
and TensorFlow backends.

Second, it changes the implementation of squeeze(x,axis) so that the
Theano backend should behave similarly to the TensorFlow backend. With
this change the introduced test passes and the above example works.
  • Loading branch information
shaunharker authored and fchollet committed Jun 9, 2016
1 parent 4e0c8cf commit ab4bf44
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
7 changes: 5 additions & 2 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,11 @@ def expand_dims(x, dim=-1):
def squeeze(x, axis):
'''Remove a 1-dimension from the tensor at index "axis".
'''
x = T.addbroadcast(x, axis)
return T.squeeze(x)
broadcastable = x.broadcastable[:axis] + x.broadcastable[axis+1:]
x = T.patternbroadcast(x, [i == axis for i in range(x.type.ndim)])
x = T.squeeze(x)
x = T.patternbroadcast(x, broadcastable)
return x


def temporal_padding(x, padding=1):
Expand Down
23 changes: 23 additions & 0 deletions tests/keras/backend/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,26 @@ def check_two_tensor_operation(function_name, x_input_shape,
assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-05)

def check_composed_tensor_operations(first_function_name, first_function_args,
second_function_name, second_function_args,
input_shape):
''' Creates a random tensor t0 with shape input_shape and compute
t1 = first_function_name(t0, **first_function_args)
t2 = second_function_name(t1, **second_function_args)
with both Theano and TensorFlow backends and ensures the answers match.
'''
val = np.random.random(input_shape) - 0.5
xth = KTH.variable(val)
xtf = KTF.variable(val)

yth = getattr(KTH, first_function_name)(xth, **first_function_args)
ytf = getattr(KTF, first_function_name)(xtf, **first_function_args)

zth = KTH.eval(getattr(KTH, second_function_name)(yth, **second_function_args))
ztf = KTF.eval(getattr(KTF, second_function_name)(ytf, **second_function_args))

assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-05)

class TestBackend(object):

Expand Down Expand Up @@ -70,6 +90,9 @@ def test_shape_operations(self):
check_single_tensor_operation('expand_dims', (4, 3), dim=-1)
check_single_tensor_operation('expand_dims', (4, 3, 2), dim=1)
check_single_tensor_operation('squeeze', (4, 3, 1), axis=2)
check_composed_tensor_operations('reshape', {'shape':(4,3,1,1)},
'squeeze', {'axis':2},
(4, 3, 1, 1))

def test_repeat_elements(self):
reps = 3
Expand Down

0 comments on commit ab4bf44

Please sign in to comment.