Skip to content

Commit

Permalink
Add tfp.math.pinv which calculates the Moore-Penrose pseudo-inverse.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 193056877
  • Loading branch information
Joshua V. Dillon authored and Copybara-Service committed Apr 16, 2018
1 parent e8091c2 commit 748af84
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 1 deletion.
24 changes: 24 additions & 0 deletions tensorflow_probability/python/math/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":custom_gradient",
":linalg",
],
)

Expand All @@ -56,3 +57,26 @@ py_test(
"//tensorflow_probability",
],
)

py_library(
name = "linalg",
srcs = [
"linalg.py",
],
srcs_version = "PY2AND3",
deps = [
# numpy dep,
# tensorflow dep,
],
)

py_test(
name = "linalg_test",
size = "small",
srcs = ["linalg_test.py"],
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability",
],
)
6 changes: 5 additions & 1 deletion tensorflow_probability/python/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
from __future__ import division
from __future__ import print_function

from tensorflow_probability.python.math.linalg import pinv

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = []
_allowed_symbols = [
'pinv',
]

remove_undocumented(__name__, _allowed_symbols)
164 changes: 164 additions & 0 deletions tensorflow_probability/python/math/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions for common linear algebra operations.
Note: Many of these functions will eventually be migrated to core Tensorflow.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import numpy as np

import tensorflow as tf


__all__ = [
'pinv',
]


def pinv(a, rcond=None, validate_args=False, name=None):
"""Compute the Moore-Penrose pseudo-inverse of a matrix.
Calculate the [generalized inverse of a matrix](
https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
singular-value decomposition (SVD) and including all large singular values.
The pseudo-inverse of a matrix `A`, is defined as: "the matrix that 'solves'
[the least-squares problem] `A @ x = b`," i.e., if `x_hat` is a solution, then
`A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if
`U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
`A_pinv = V @ inv(Sigma) U^T`. [(Strang, 1980)][1]
This function is analogous to [`numpy.linalg.pinv`](
https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
default `rcond` is `1e-15`. Here the default is
`10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.
Args:
a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
pseudo-inverted.
rcond: `Tensor` of small singular value cutoffs. Singular values smaller
(in modulus) than `rcond` * largest_singular_value (again, in modulus) are
set to zero. Must broadcast against `tf.shape(a)[:-2]`.
Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
validate_args: When `True`, additional assertions might be embedded in the
graph.
Default value: `False` (i.e., no graph assertions are added).
name: Python `str` prefixed to ops created by this function.
Default value: "pinv".
Returns:
a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except
rightmost two dimensions are transposed.
Raises:
TypeError: if input `a` does not have `float`-like `dtype`.
ValueError: if input `a` has fewer than 2 dimensions.
#### Examples
```python
import tensorflow as tf
import tensorflow_probability as tfp
a = tf.constant([[1., 0.4, 0.5],
[0.4, 0.2, 0.25],
[0.5, 0.25, 0.35]])
tf.matmul(tfp.math.pinv(a), a)
# ==> array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
a = tf.constant([[1., 0.4, 0.5, 1.],
[0.4, 0.2, 0.25, 2.],
[0.5, 0.25, 0.35, 3.]])
tf.matmul(tfp.math.pinv(a), a)
# ==> array([[ 0.76, 0.37, 0.21, -0.02],
[ 0.37, 0.43, -0.33, 0.02],
[ 0.21, -0.33, 0.81, 0.01],
[-0.02, 0.02, 0.01, 1. ]], dtype=float32)
```
#### References
[1]: G. Strang. "Linear Algebra and Its Applications, 2nd Ed." Academic Press,
Inc., 1980, pp. 139-142.
"""
with tf.name_scope(name, 'pinv', [a, rcond]):
a = tf.convert_to_tensor(a, name='a')

if not a.dtype.is_floating:
raise TypeError('Input `a` must have `float`-like `dtype` '
'(saw {}).'.format(a.dtype.name))
if a.shape.ndims is not None and a.shape.ndims < 2:
raise ValueError('Input `a` must have at least 2 dimensions '
'(saw: {}).'.format(a.shape.ndims))
elif validate_args:
assert_rank_at_least_2 = tf.assert_rank_at_least(
a, rank=2,
message='Input `a` must have at least 2 dimensions.')
with tf.control_dependencies([assert_rank_at_least_2]):
a = tf.identity(a)

dtype = a.dtype.as_numpy_dtype

if rcond is None:
def get_dim_size(dim):
if a.shape.ndims is not None and a.shape[dim].value is not None:
return a.shape[dim].value
return tf.shape(a)[dim]
num_rows = get_dim_size(-2)
num_cols = get_dim_size(-1)
if isinstance(num_rows, int) and isinstance(num_cols, int):
max_rows_cols = float(max(num_rows, num_cols))
else:
max_rows_cols = tf.cast(tf.maximum(num_rows, num_cols), dtype)
rcond = 10. * max_rows_cols * np.finfo(dtype).eps

rcond = tf.convert_to_tensor(rcond, dtype=dtype, name='rcond')

# Calculate pseudo inverse via SVD.
# Note: if a is symmetric then u == v. (We might observe additional
# performance by explicitly setting `v = u` in such cases.)
[
singular_values, # Sigma
left_singular_vectors, # U
right_singular_vectors, # V
] = tf.linalg.svd(a, full_matrices=False, compute_uv=True)

# Saturate small singular values to inf. This has the effect of make
# `1. / s = 0.` while not resulting in `NaN` gradients.
max_singular_value = tf.reduce_max(singular_values, axis=-1, keepdims=True)
cutoff = rcond[..., tf.newaxis] * max_singular_value
inf = tf.fill(tf.shape(singular_values), np.array(np.inf, dtype))
singular_values = tf.where(singular_values > cutoff, singular_values, inf)

# Although `a == tf.matmul(u, s * v, transpose_b=True)` we swap
# `u` and `v` here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e.,
# a matrix inverse has "transposed" semantics.
a_pinv = tf.matmul(
right_singular_vectors / singular_values[..., tf.newaxis, :],
left_singular_vectors,
adjoint_b=True)

if a.shape.ndims is not None:
a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))

return a_pinv
117 changes: 117 additions & 0 deletions tensorflow_probability/python/math/linalg_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for linear algebra."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import numpy as np

import tensorflow as tf

from tensorflow_probability.python.math import pinv as pinv
from tensorflow.python.framework import test_util


class _PinvTest(object):

def expected_pinv(self, a, rcond):
"""Calls `np.linalg.pinv` but corrects its broken batch semantics."""
if a.ndim < 3:
return np.linalg.pinv(a, rcond)
if rcond is None:
rcond = 10. * max(a.shape[-2], a.shape[-1]) * np.finfo(a.dtype).eps
s = np.concatenate([a.shape[:-2], [a.shape[-1], a.shape[-2]]])
a_pinv = np.zeros(s, dtype=a.dtype)
for i in np.ndindex(a.shape[:(a.ndim - 2)]):
a_pinv[i] = np.linalg.pinv(
a[i],
rcond=rcond if isinstance(rcond, float) else rcond[i])
return a_pinv

@test_util.run_in_graph_and_eager_modes()
def test_symmetric(self):
a_ = self.dtype([[1., .4, .5],
[.4, .2, .25],
[.5, .25, .35]])
a_ = np.stack([a_ + 1., a_], axis=0) # Batch of matrices.
a = tf.placeholder_with_default(
input=a_,
shape=a_.shape if self.use_static_shape else None)
if self.use_default_rcond:
rcond = None
else:
rcond = self.dtype([0., 0.01]) # Smallest 1 component is forced to zero.
expected_a_pinv_ = self.expected_pinv(a_, rcond)
a_pinv = pinv(a, rcond, validate_args=True)
a_pinv_ = self.evaluate(a_pinv)
self.assertAllClose(expected_a_pinv_, a_pinv_,
atol=1e-5, rtol=1e-5)
if not self.use_static_shape:
return
self.assertAllEqual(expected_a_pinv_.shape, a_pinv.shape)

@test_util.run_in_graph_and_eager_modes()
def test_nonsquare(self):
a_ = self.dtype([[1., .4, .5, 1.],
[.4, .2, .25, 2.],
[.5, .25, .35, 3.]])
a_ = np.stack([a_ + 0.5, a_], axis=0) # Batch of matrices.
a = tf.placeholder_with_default(
input=a_,
shape=a_.shape if self.use_static_shape else None)
if self.use_default_rcond:
rcond = None
else:
# Smallest 2 components are forced to zero.
rcond = self.dtype([0., 0.25])
expected_a_pinv_ = self.expected_pinv(a_, rcond)
a_pinv = pinv(a, rcond, validate_args=True)
a_pinv_ = self.evaluate(a_pinv)
self.assertAllClose(expected_a_pinv_, a_pinv_,
atol=1e-5, rtol=1e-4)
if not self.use_static_shape:
return
self.assertAllEqual(expected_a_pinv_.shape, a_pinv.shape)


class PinvTestDynamic32DefaultRcond(tf.test.TestCase, _PinvTest):
dtype = np.float32
use_static_shape = False
use_default_rcond = True


class PinvTestStatic64DefaultRcond(tf.test.TestCase, _PinvTest):
dtype = np.float64
use_static_shape = True
use_default_rcond = True


class PinvTestDynamic32CustomtRcond(tf.test.TestCase, _PinvTest):
dtype = np.float32
use_static_shape = False
use_default_rcond = False


class PinvTestStatic64CustomRcond(tf.test.TestCase, _PinvTest):
dtype = np.float64
use_static_shape = True
use_default_rcond = False


if __name__ == '__main__':
tf.test.main()

0 comments on commit 748af84

Please sign in to comment.