Skip to content

Commit

Permalink
Deprecate block_diag from math module in favor of PyTensor (#7132)
Browse files Browse the repository at this point in the history
  • Loading branch information
AryanNanda17 authored Feb 8, 2024
1 parent 8745974 commit 627a8dd
Showing 1 changed file with 10 additions and 53 deletions.
63 changes: 10 additions & 53 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import pytensor.sparse
import pytensor.tensor as pt
import pytensor.tensor.slinalg
import scipy as sp
import scipy.sparse

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
Expand Down Expand Up @@ -93,9 +91,8 @@
from pytensor.tensor.linalg import solve_triangular
from pytensor.tensor.nlinalg import matrix_inverse
from pytensor.tensor.special import log_softmax, softmax
from scipy.linalg import block_diag as scipy_block_diag

from pymc.pytensorf import floatX, ix_, largest_common_dtype
from pymc.pytensorf import floatX

__all__ = [
"abs",
Expand Down Expand Up @@ -513,55 +510,9 @@ def batched_diag(C):
raise ValueError("Input should be 2 or 3 dimensional")


class BlockDiagonalMatrix(Op):
__props__ = ("sparse", "format")

def __init__(self, sparse=False, format="csr"):
if format not in ("csr", "csc"):
raise ValueError(f"format must be one of: 'csr', 'csc', got {format}")
self.sparse = sparse
self.format = format

def make_node(self, *matrices):
if not matrices:
raise ValueError("no matrices to allocate")
matrices = list(map(pt.as_tensor, matrices))
if any(mat.type.ndim != 2 for mat in matrices):
raise TypeError("all data arguments must be matrices")
if self.sparse:
out_type = pytensor.sparse.matrix(self.format, dtype=largest_common_dtype(matrices))
else:
out_type = pytensor.tensor.matrix(dtype=largest_common_dtype(matrices))
return Apply(self, matrices, [out_type])

def perform(self, node, inputs, output_storage, params=None):
dtype = largest_common_dtype(inputs)
if self.sparse:
output_storage[0][0] = sp.sparse.block_diag(inputs, self.format, dtype)
else:
output_storage[0][0] = scipy_block_diag(*inputs).astype(dtype)

def grad(self, inputs, gout):
shapes = pt.stack([i.shape for i in inputs])
index_end = shapes.cumsum(0)
index_begin = index_end - shapes
slices = [
ix_(
pt.arange(index_begin[i, 0], index_end[i, 0]),
pt.arange(index_begin[i, 1], index_end[i, 1]),
)
for i in range(len(inputs))
]
return [gout[0][slc] for slc in slices]

def infer_shape(self, fgraph, nodes, shapes):
first, second = zip(*shapes)
return [(pt.add(*first), pt.add(*second))]


def block_diagonal(matrices, sparse=False, format="csr"):
r"""See scipy.sparse.block_diag or
scipy.linalg.block_diag for reference
r"""See pt.slinalg.block_diag or
pytensor.sparse.basic.block_diag for reference
Parameters
----------
Expand All @@ -575,6 +526,12 @@ def block_diagonal(matrices, sparse=False, format="csr"):
-------
matrix
"""
warnings.warn(
"pymc.math.block_diagonal is deprecated in favor of `pytensor.tensor.linalg.block_diag` and `pytensor.sparse.block_diag` functions. This function will be removed in a future release",
)
if len(matrices) == 1: # graph optimization
return matrices[0]
return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices)
if sparse:
return pytensor.sparse.basic.block_diag(*matrices, format=format)
else:
return pt.slinalg.block_diag(*matrices)

0 comments on commit 627a8dd

Please sign in to comment.