Skip to content

Commit

Permalink
MAINT: fix tensorflow bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Nov 13, 2017
1 parent 1e95b3b commit cecad04
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@


# Add empty axes for batch and channel
x = x[None, ..., None]
z = z[None, ..., None]
x_reshaped = x[None, ..., None]
z_reshaped = z[None, ..., None]

# Lazily apply operator in tensorflow
y = odl_op_layer(x)
y = odl_op_layer(x_reshaped)

# Evaluate using tensorflow
print(y.eval())
Expand All @@ -44,7 +44,7 @@
# We need to scale by cell size to get correct value since the derivative
# in tensorflow uses unweighted spaces.
scale = ray_transform.range.cell_volume / ray_transform.domain.cell_volume
print(tf.gradients(y, [x], z)[0].eval() * scale)
print(tf.gradients(y, [x_reshaped], z_reshaped)[0].eval() * scale)

# Compare result with pure ODL
print(ray_transform.derivative(x.eval()).adjoint(z.eval()))
8 changes: 4 additions & 4 deletions odl/contrib/tensorflow/examples/tensorflow_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
y = tf.constant(np.asarray(noisy_data))

# Add empty axes for batch and channel
x = x[None, ..., None]
y = y[None, ..., None]
x_reshaped = x[None, ..., None]
y_reshaped = y[None, ..., None]

# Define loss function
loss = (tf.reduce_sum((ray_transform_layer(x) - y) ** 2) +
50 * tf.reduce_sum(tf.abs(grad_layer(x))))
loss = (tf.reduce_sum((ray_transform_layer(x_reshaped) - y_reshaped) ** 2) +
50 * tf.reduce_sum(tf.abs(grad_layer(x_reshaped))))

# Train using the ADAM optimizer
optimizer = tf.train.AdamOptimizer(1e-1).minimize(loss)
Expand Down
31 changes: 22 additions & 9 deletions odl/contrib/tensorflow/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from __future__ import print_function, division, absolute_import
import numpy as np
import odl
import uuid
import tensorflow as tf
from tensorflow.python.framework import ops
Expand Down Expand Up @@ -41,12 +42,12 @@ def as_tensorflow_layer(odl_op, name='ODLOperator',
-------
tensorflow_layer : callable
Callable that, when called with a `tensorflow.Tensor` of shape
``(n, *odl_op.domain.shape, 1)`` where ``n`` is the batch size,
``(n,) + odl_op.domain.shape + (1,)`` where ``n`` is the batch size,
returns another `tensorflow.Tensor` which is a lazy evaluation of
``odl_op``.
If ``odl_op`` is an `Operator`, the shape of the returned tensor is
``(n, *odl_op.range.shape, 1)``.
``(n,) + odl_op.range.shape + (1,)``.
If ``odl_op`` is an `Functional`, the shape of the returned tensor is
``(n,)``.
Expand Down Expand Up @@ -158,14 +159,14 @@ def tensorflow_layer_grad_impl(x, dy, name):
if odl_op.is_functional:
in_shape = (n_x,)
else:
in_shape = (n_x,) + odl_op.range.shape + (1,)
out_shape = (n_x,) + odl_op.domain.shape + (1,)
in_shape = (n_x,) + space_shape(odl_op.range) + (1,)
out_shape = (n_x,) + space_shape(odl_op.domain) + (1,)

assert x_shape[1:] == odl_op.domain.shape + (1,)
assert x_shape[1:] == space_shape(odl_op.domain) + (1,)
if odl_op.is_functional:
assert dy_shape[1:] == ()
else:
assert dy_shape[1:] == odl_op.range.shape + (1,)
assert dy_shape[1:] == space_shape(odl_op.range) + (1,)

def _impl(x, dy):
"""Implementation of the adjoint of the derivative.
Expand Down Expand Up @@ -300,13 +301,13 @@ def tensorflow_layer(x, name=None):
n_x = x_shape[0]
fixed_size = False

in_shape = (n_x,) + odl_op.domain.shape + (1,)
in_shape = (n_x,) + space_shape(odl_op.domain) + (1,)
if odl_op.is_functional:
out_shape = (n_x,)
else:
out_shape = (n_x,) + odl_op.range.shape + (1,)
out_shape = (n_x,) + space_shape(odl_op.range) + (1,)

assert x_shape[1:] == odl_op.domain.shape + (1,)
assert x_shape[1:] == space_shape(odl_op.domain) + (1,)

out_dtype = getattr(odl_op.range, 'dtype',
odl_op.domain.dtype)
Expand Down Expand Up @@ -377,6 +378,18 @@ def tensorflow_layer_grad(op, grad):
return tensorflow_layer


def space_shape(space):
"""Return ``space.shape``, including power space base shape.
If ``space`` is a power space, return ``(len(space),) + space[0].shape``,
otherwise return ``space.shape``.
"""
if isinstance(space, odl.ProductSpace) and space.is_power_space:
return (len(space),) + space[0].shape
else:
return space.shape


if __name__ == '__main__':
from odl.util.testutils import run_doctests
run_doctests()
2 changes: 1 addition & 1 deletion odl/space/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def __len__(self):

@property
def shape(self):
"""Total spaces per axis spaces, computed recursively.
"""Total spaces per axis, computed recursively.
The recursion ends at the fist level that does not comprise
a *power* space, i.e., which is not made of equal spaces.
Expand Down

0 comments on commit cecad04

Please sign in to comment.