
Commit

Merge branch 'main' into features/1563-Add_DMD
mrfh92 authored Dec 3, 2024
2 parents 6a65fa4 + e6499cf commit fc084fc
Showing 10 changed files with 663 additions and 98 deletions.
17 changes: 17 additions & 0 deletions benchmarks/cb/decomposition.py
@@ -0,0 +1,17 @@
# flake8: noqa
import heat as ht
from mpi4py import MPI
from perun import monitor
from heat.decomposition import IncrementalPCA


@monitor()
def incremental_pca_split0(list_of_X, n_components):
ipca = IncrementalPCA(n_components=n_components)
for X in list_of_X:
ipca.partial_fit(X)


def run_decomposition_benchmarks():
list_of_X = [ht.random.rand(50000, 500, split=0) for _ in range(10)]
incremental_pca_split0(list_of_X, 50)
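
A minimal follow-up sketch of how the fitted model might be inspected after the partial fits, assuming Heat's IncrementalPCA exposes an sklearn-style components_ attribute (not shown in this diff); sizes are reduced for illustration:

    # illustrative only; the components_ attribute is assumed (sklearn-style), not taken from this diff
    import heat as ht
    from heat.decomposition import IncrementalPCA

    ipca = IncrementalPCA(n_components=10)
    for X in [ht.random.rand(1000, 100, split=0) for _ in range(5)]:
        ipca.partial_fit(X)              # update the estimate one batch at a time
    print(ipca.components_.shape)        # expected (10, 100) if the attribute exists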
2 changes: 2 additions & 0 deletions benchmarks/cb/main.py
@@ -10,8 +10,10 @@
from cluster import run_cluster_benchmarks
from manipulations import run_manipulation_benchmarks
from preprocessing import run_preprocessing_benchmarks
from decomposition import run_decomposition_benchmarks

run_linalg_benchmarks()
run_cluster_benchmarks()
run_manipulation_benchmarks()
run_preprocessing_benchmarks()
run_decomposition_benchmarks()
105 changes: 48 additions & 57 deletions heat/core/linalg/qr.py
@@ -1,5 +1,5 @@
"""
QR decomposition of (distributed) 2-D ``DNDarray``s.
QR decomposition of ``DNDarray``s.
"""

import collections
@@ -24,16 +24,19 @@ def qr(
Factor the matrix ``A`` as *QR*, where ``Q`` is orthonormal and ``R`` is upper-triangular.
If ``mode = "reduced``, function returns ``QR(Q=Q, R=R)``, if ``mode = "r"`` function returns ``QR(Q=None, R=R)``
This function also works for batches of matrices; in this case, the last two dimensions of the input array are considered as the matrix dimensions.
The output arrays have the same leading batch dimensions as the input array.
Parameters
----------
A : DNDarray of shape (M, N)
Array which will be decomposed. So far only 2D arrays with datatype float32 or float64 are supported
For split=0, the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
A : DNDarray of shape (M, N), of shape (...,M,N) in the batched case
Array which will be decomposed. So far only arrays with datatype float32 or float64 are supported
For split=0 (-2, in the batched case), the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
mode : str, optional
default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N), respectively.
default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N). Potential batch dimensions are not modified.
"r" returns only R, with dimensions (min(M,N), N).
procs_to_merge : int, optional
This parameter is only relevant for split=0 and determines the number of processes to be merged at one step during the so-called TS-QR algorithm.
This parameter is only relevant for split=0 (-2, in the batched case) and determines the number of processes to be merged at one step during the so-called TS-QR algorithm.
The default is 2. Higher choices might be faster, but will probably result in higher memory consumption. 0 corresponds to merging all processes at once.
We only recommend modifying this parameter if you are familiar with the TS-QR algorithm (see the references below).
@@ -49,7 +52,7 @@ def qr(
Unlike ``numpy.linalg.qr()``, ``ht.linalg.qr`` only supports ``mode="reduced"`` or ``mode="r"`` for the moment, since "complete" may result in heavy memory usage.
Heat's QR function is built on top of PyTorch's QR function, ``torch.linalg.qr()``, using LAPACK (CPU) and MAGMA (CUDA) on
the backend. For split=0, tall-skinny QR (TS-QR) is implemented, while for split=1 a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
the backend. For split=0 (-2, in the batched case), tall-skinny QR (TS-QR) is implemented, while for split=1 (-1, in the batched case) a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
References
-----------
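
A minimal usage sketch of the batched call described in this docstring (not part of the diff): a stack of tall-skinny matrices split along its row axis (split = ndim - 2) takes the TS-QR path, and the leading batch dimension is preserved in both factors.

    import heat as ht

    A = ht.random.rand(4, 1000, 50, split=1)   # batch of 4 tall-skinny 1000 x 50 matrices, split along the rows
    Q, R = ht.linalg.qr(A)                     # mode="reduced" is the default
    print(Q.shape, R.shape)                    # expected: (4, 1000, 50) and (4, 50, 50)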
@@ -87,65 +90,53 @@ def qr(
if procs_to_merge == 0:
procs_to_merge = A.comm.size

if A.ndim != 2:
raise ValueError(
f"Array 'A' must be 2 dimensional, buts has {A.ndim} dimensions. \n Please open an issue on GitHub if you require QR for batches of matrices similar to PyTorch."
)
if A.dtype not in [float32, float64]:
raise TypeError(f"Array 'A' must have a datatype of float32 or float64, but has {A.dtype}")

QR = collections.namedtuple("QR", "Q, R")

if not A.is_distributed():
if not A.is_distributed() or A.split < A.ndim - 2:
# handle the case of a single process or split=None: just PyTorch QR
Q, R = torch.linalg.qr(A.larray, mode=mode)
R = DNDarray(
R,
gshape=R.shape,
dtype=A.dtype,
split=A.split,
device=A.device,
comm=A.comm,
balanced=True,
)
R = factories.array(R, is_split=A.split)
if mode == "reduced":
Q = DNDarray(
Q,
gshape=Q.shape,
dtype=A.dtype,
split=A.split,
device=A.device,
comm=A.comm,
balanced=True,
)
Q = factories.array(Q, is_split=A.split)
else:
Q = None
return QR(Q, R)
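
As an aside on the branch above: when the split axis is a pure batch dimension, every process holds complete matrices, so the factorization can run process-locally because torch.linalg.qr already batches over leading dimensions. A minimal torch sketch:

    import torch

    A = torch.randn(4, 6, 3)                       # batch of 4 matrices of shape 6 x 3
    Q, R = torch.linalg.qr(A, mode="reduced")
    print(Q.shape, R.shape)                        # torch.Size([4, 6, 3]) torch.Size([4, 3, 3])
    print(torch.allclose(Q @ R, A, atol=1e-5))     # each batch entry is reproduced up to rounding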

if A.split == 1:
if A.split == A.ndim - 1:
# handle the case that A is split along the columns
# here, we apply a block-wise version of (stabilized) Gram-Schmidt orthogonalization
# instead of orthogonalizing each column of A individually, we orthogonalize blocks of columns (i.e. the local arrays) at once

lshapes = A.lshape_map[:, 1]
lshapes = A.lshape_map[:, -1]
lshapes_cum = torch.cumsum(lshapes, 0)
nprocs = A.comm.size

if A.shape[0] >= A.shape[1]:
if A.shape[-2] >= A.shape[-1]:
last_row_reached = nprocs
k = A.shape[1]
k = A.shape[-1]
else:
last_row_reached = min(torch.argwhere(lshapes_cum >= A.shape[0]))[0]
k = A.shape[0]
last_row_reached = min(torch.argwhere(lshapes_cum >= A.shape[-2]))[0]
k = A.shape[-2]

if mode == "reduced":
Q = factories.zeros(A.shape, dtype=A.dtype, split=1, device=A.device, comm=A.comm)
Q = factories.zeros(
A.shape, dtype=A.dtype, split=A.ndim - 1, device=A.device, comm=A.comm
)

R = factories.zeros((k, A.shape[1]), dtype=A.dtype, split=1, device=A.device, comm=A.comm)
R = factories.zeros(
(*A.shape[:-2], k, A.shape[-1]),
dtype=A.dtype,
split=A.ndim - 1,
device=A.device,
comm=A.comm,
)
R_shapes = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=A.device.torch_device),
torch.cumsum(R.lshape_map[:, 1], 0),
torch.cumsum(R.lshape_map[:, -1], 0),
]
)

@@ -154,10 +145,11 @@ def qr(
for i in range(last_row_reached + 1):
# this loop goes through all the column-blocks (i.e. local arrays) of the matrix
# this corresponds to the loop over all columns in classical Gram-Schmidt

if i < nprocs - 1:
k_loc_i = min(A.shape[0], A.lshape_map[i, 1])
k_loc_i = min(A.shape[-2], A.lshape_map[i, -1])
Q_buf = torch.zeros(
(A.shape[0], k_loc_i), dtype=A.larray.dtype, device=A.device.torch_device
(*A.shape[:-1], k_loc_i), dtype=A.larray.dtype, device=A.device.torch_device
)

if A.comm.rank == i:
@@ -167,32 +159,31 @@ def qr(
Q_buf = Q_curr
if mode == "reduced":
Q.larray = Q_curr
r_size = R.larray[R_shapes[i] : R_shapes[i + 1], :].shape[0]
R.larray[R_shapes[i] : R_shapes[i + 1], :] = R_loc[:r_size, :]
r_size = R.larray[..., R_shapes[i] : R_shapes[i + 1], :].shape[-2]
R.larray[..., R_shapes[i] : R_shapes[i + 1], :] = R_loc[..., :r_size, :]

if i < nprocs - 1:
# broadcast the orthogonalized block of columns to all other processes
req = A.comm.Ibcast(Q_buf, root=i)
req.Wait()
A.comm.Bcast(Q_buf, root=i)

if A.comm.rank > i:
# subtract the contribution of the current block of columns from the remaining columns
R_loc = Q_buf.T @ A_columns
R_loc = torch.transpose(Q_buf, -2, -1) @ A_columns
A_columns -= Q_buf @ R_loc
r_size = R.larray[R_shapes[i] : R_shapes[i + 1], :].shape[0]
R.larray[R_shapes[i] : R_shapes[i + 1], :] = R_loc[:r_size, :]
r_size = R.larray[..., R_shapes[i] : R_shapes[i + 1], :].shape[-2]
R.larray[..., R_shapes[i] : R_shapes[i + 1], :] = R_loc[..., :r_size, :]

if mode == "reduced":
Q = Q[:, :k].balance()
Q = Q[..., :, :k].balance()
else:
Q = None

return QR(Q, R)
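
A minimal torch sketch of the block-wise Gram-Schmidt update performed by the column-split branch above, written for two column blocks held in one process; the distributed version broadcasts the orthogonalized block (Q_buf) instead, and the stabilization details are omitted here:

    import torch

    A1, A2 = torch.randn(8, 3), torch.randn(8, 3)    # two blocks of columns of an 8 x 6 matrix
    A = torch.cat([A1, A2], dim=-1)                  # kept only to verify the result

    Q1, R11 = torch.linalg.qr(A1, mode="reduced")    # orthogonalize the first block
    R12 = Q1.transpose(-2, -1) @ A2                  # project the remaining columns onto Q1
    A2 = A2 - Q1 @ R12                               # remove the first block's contribution
    Q2, R22 = torch.linalg.qr(A2, mode="reduced")    # orthogonalize what is left

    Q = torch.cat([Q1, Q2], dim=-1)
    R = torch.cat([torch.cat([R11, R12], dim=-1),
                   torch.cat([torch.zeros(3, 3), R22], dim=-1)], dim=-2)
    print(torch.allclose(Q @ R, A, atol=1e-4))       # True: the blocks reassemble the input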

if A.split == 0:
if A.split == A.ndim - 2:
# implementation of TS-QR for split = 0
# check that data distribution is reasonable for TS-QR (i.e. tall-skinny matrix with also tall-skinny local chunks of data)
if A.lshape_map[:, 0].max().item() < A.shape[1]:
if A.lshape_map[:, -2].max().item() < A.shape[-1]:
raise ValueError(
"A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub."
)
@@ -209,10 +200,10 @@ def qr(
while len(current_procs) > 1:
if A.comm.rank in current_procs and local_comm.size > 1:
# create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes
shapes_R_loc = local_comm.gather(R_loc.shape[0], root=0)
shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0)
if local_comm.rank == 0:
gathered_R_loc = torch.zeros(
(sum(shapes_R_loc), R_loc.shape[1]),
(*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]),
device=R_loc.device,
dtype=R_loc.dtype,
)
@@ -225,7 +216,7 @@ def qr(
counts = None
displs = None
# gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=0)
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2)
# perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc
if local_comm.rank == 0:
previous_shape = R_loc.shape
@@ -242,7 +233,7 @@ def qr(
dtype=R_loc.dtype,
)
# scatter the Q_buf to all processes of the process group
local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=0)
local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2)
del gathered_R_loc, Q_buf

# for each process in the current processes, broadcast the scattered_Q_buf of this process
@@ -282,7 +273,7 @@ def qr(
leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank)
level += 1
# broadcast the final R_loc to all processes
R_gshape = (A.shape[1], A.shape[1])
R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1])
if A.comm.rank != 0:
R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device)
A.comm.Bcast(R_loc, root=0)
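
A minimal torch sketch of the TS-QR merge step in the loop above, simulated in one process: each group stacks its local R factors, takes one more QR of the stack, keeps the new R, and the small Q factor of the merge is later folded into the local Q:

    import torch

    A0, A1 = torch.randn(100, 5), torch.randn(100, 5)    # tall-skinny chunks of two processes
    Q0, R0 = torch.linalg.qr(A0, mode="reduced")
    Q1, R1 = torch.linalg.qr(A1, mode="reduced")

    # merge step: QR of the stacked 10 x 5 R factors yields the R of the full 200 x 5 matrix
    Q_merge, R = torch.linalg.qr(torch.cat([R0, R1], dim=-2), mode="reduced")

    # the global Q is obtained by applying the matching 5 x 5 blocks of Q_merge to the local Q factors
    Q_glob = torch.cat([Q0 @ Q_merge[:5], Q1 @ Q_merge[5:]], dim=-2)
    print(torch.allclose(Q_glob @ R, torch.cat([A0, A1], dim=-2), atol=1e-4))   # True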
@@ -302,7 +293,7 @@ def qr(
Q_loc,
gshape=A.shape,
dtype=A.dtype,
split=0,
split=A.split,
device=A.device,
comm=A.comm,
balanced=True,
