diff --git a/pymc/math.py b/pymc/math.py index 7c9ceaa9ecb..7fe8d1e5e52 100644 --- a/pymc/math.py +++ b/pymc/math.py @@ -22,8 +22,6 @@ import pytensor.sparse import pytensor.tensor as pt import pytensor.tensor.slinalg -import scipy as sp -import scipy.sparse from pytensor.graph.basic import Apply from pytensor.graph.op import Op @@ -93,9 +91,8 @@ from pytensor.tensor.linalg import solve_triangular from pytensor.tensor.nlinalg import matrix_inverse from pytensor.tensor.special import log_softmax, softmax -from scipy.linalg import block_diag as scipy_block_diag -from pymc.pytensorf import floatX, ix_, largest_common_dtype +from pymc.pytensorf import floatX __all__ = [ "abs", @@ -513,55 +510,9 @@ def batched_diag(C): raise ValueError("Input should be 2 or 3 dimensional") -class BlockDiagonalMatrix(Op): - __props__ = ("sparse", "format") - - def __init__(self, sparse=False, format="csr"): - if format not in ("csr", "csc"): - raise ValueError(f"format must be one of: 'csr', 'csc', got {format}") - self.sparse = sparse - self.format = format - - def make_node(self, *matrices): - if not matrices: - raise ValueError("no matrices to allocate") - matrices = list(map(pt.as_tensor, matrices)) - if any(mat.type.ndim != 2 for mat in matrices): - raise TypeError("all data arguments must be matrices") - if self.sparse: - out_type = pytensor.sparse.matrix(self.format, dtype=largest_common_dtype(matrices)) - else: - out_type = pytensor.tensor.matrix(dtype=largest_common_dtype(matrices)) - return Apply(self, matrices, [out_type]) - - def perform(self, node, inputs, output_storage, params=None): - dtype = largest_common_dtype(inputs) - if self.sparse: - output_storage[0][0] = sp.sparse.block_diag(inputs, self.format, dtype) - else: - output_storage[0][0] = scipy_block_diag(*inputs).astype(dtype) - - def grad(self, inputs, gout): - shapes = pt.stack([i.shape for i in inputs]) - index_end = shapes.cumsum(0) - index_begin = index_end - shapes - slices = [ - ix_( - pt.arange(index_begin[i, 0], index_end[i, 0]), - pt.arange(index_begin[i, 1], index_end[i, 1]), - ) - for i in range(len(inputs)) - ] - return [gout[0][slc] for slc in slices] - - def infer_shape(self, fgraph, nodes, shapes): - first, second = zip(*shapes) - return [(pt.add(*first), pt.add(*second))] - - def block_diagonal(matrices, sparse=False, format="csr"): - r"""See scipy.sparse.block_diag or - scipy.linalg.block_diag for reference + r"""See pt.slinalg.block_diag or + pytensor.sparse.basic.block_diag for reference Parameters ---------- @@ -575,6 +526,12 @@ def block_diagonal(matrices, sparse=False, format="csr"): ------- matrix """ + warnings.warn( + "pymc.math.block_diagonal is deprecated in favor of `pytensor.tensor.linalg.block_diag` and `pytensor.sparse.block_diag` functions. This function will be removed in a future release", + ) if len(matrices) == 1: # graph optimization return matrices[0] - return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices) + if sparse: + return pytensor.sparse.basic.block_diag(*matrices, format=format) + else: + return pt.slinalg.block_diag(*matrices) diff --git a/requirements-dev.txt b/requirements-dev.txt index b21437de01f..8aff7d60c9e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -30,4 +30,4 @@ sphinx>=1.5 sphinxext-rediraffe types-cachetools typing-extensions>=3.7.4 -watermark \ No newline at end of file +watermark