Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

linalg: least squares #801

Merged
merged 27 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -599,3 +599,47 @@ Specifically, upper Hessenberg matrices satisfy `a_ij = 0` when `j < i-1`, and l
```fortran
{!example/linalg/example_is_hessenberg.f90!}
```

## `lstsq` - Computes the least squares solution to a linear matrix equation.

### Status

Experimental

### Description

This function computes the least-squares solution to a linear matrix equation \( A \cdot x = b \).

Result vector `x` returns the approximate solution that minimizes the 2-norm \( || A \cdot x - b ||_2 \), i.e., it contains the least-squares solution to the problem. Matrix `A` may be full-rank, over-determined, or under-determined. The solver is based on LAPACK's `*GELSD` backends.

### Syntax

`x = ` [[stdlib_linalg(module):lstsq(interface)]] `(a, b, [, cond, overwrite_a, rank, err])`

### Arguments

`a`: Shall be a rank-2 square array containing the coefficient matrix. It is an `intent(inout)` argument.

`b`: Shall be a rank-1 array containing the right-hand-side vector. It is an `intent(in)` argument.
perazz marked this conversation as resolved.
Show resolved Hide resolved

`cond` (optional): Singular value cut-off threshold for rank evaluation: `s_i >= cond*maxval(s), i=1:rank`. Shall be a scalar, `intent(in)` argument.
perazz marked this conversation as resolved.
Show resolved Hide resolved

`overwrite_a` (optional): Shall be an input logical flag. if `.true.`, input matrix a will be used as temporary storage and overwritten. This avoids internal data allocation. This is an `intent(in)` argument.
perazz marked this conversation as resolved.
Show resolved Hide resolved

`rank` (optional): Shall be an `integer` scalar value, that contains the rank of input matrix `A`. This is an `intent(out)` argument.

`err` (optional): Shall be a `type(linalg_state_type)` value. This is an `intent(out)` argument.

### Return value

Returns an array value that represents the solution to the least squares system.

Raises `LINALG_ERROR` if the underlying SVD process did not converge.
perazz marked this conversation as resolved.
Show resolved Hide resolved
Raises `LINALG_VALUE_ERROR` if the matrix and rhs vectors have invalid/incompatible sizes.
perazz marked this conversation as resolved.
Show resolved Hide resolved
Exceptions trigger an `error stop`.

### Example

```fortran
{!example/linalg/example_lstsq.f90!}
```
1 change: 1 addition & 0 deletions example/linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ ADD_EXAMPLE(state1)
ADD_EXAMPLE(state2)
ADD_EXAMPLE(blas_gemv)
ADD_EXAMPLE(lapack_getrf)
ADD_EXAMPLE(lstsq)
27 changes: 27 additions & 0 deletions example/linalg/example_lstsq.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
program example_lstsq
use stdlib_linalg_constants, only: dp
use stdlib_linalg, only: lstsq
use stdlib_linalg_state, only: linalg_state_type
perazz marked this conversation as resolved.
Show resolved Hide resolved
implicit none
type(linalg_state_type) :: err

integer, allocatable :: x(:),y(:)
real(dp), allocatable :: A(:,:),b(:),coef(:)

! Data set
x = [1, 2, 2]
y = [5, 13, 25]

! Fit three points using a parabola, least squares method
! A = [1 x x**2]
A = reshape([[1,1,1],x,x**2],[3,3])
b = y

! Get coefficients of y = coef(1) + x*coef(2) + x^2*coef(3)
coef = lstsq(A,b)

print *, 'parabola: ',coef
! parabola: -0.42857142857141695 1.1428571428571503 4.2857142857142811


end program example_lstsq
35 changes: 32 additions & 3 deletions include/common.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,20 @@
#:set REAL_KINDS = REAL_KINDS + ["qp"]
#:endif

#! BLAS/LAPACK initials for each real kind
#:set REAL_INIT = ["s", "d"]
#:if WITH_XDP
#:set REAL_INIT = REAL_INIT + ["x"]
#:endif
#:if WITH_QP
#:set REAL_INIT = REAL_INIT + ["q"]
#:endif

#! Real types to be considered during templating
#:set REAL_TYPES = ["real({})".format(k) for k in REAL_KINDS]

#! Collected (kind, type) tuples for real types
#:set REAL_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES))
#:set REAL_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_INIT))

#! Complex kinds to be considered during templating
#:set CMPLX_KINDS = ["sp", "dp"]
Expand All @@ -42,11 +51,20 @@
#:set CMPLX_KINDS = CMPLX_KINDS + ["qp"]
#:endif

#! BLAS/LAPACK initials for each complex kind
#:set CMPLX_INIT = ["c", "z"]
#:if WITH_XDP
#:set CMPLX_INIT = CMPLX_INIT + ["y"]
#:endif
#:if WITH_QP
#:set CMPLX_INIT = CMPLX_INIT + ["w"]
#:endif

#! Complex types to be considered during templating
#:set CMPLX_TYPES = ["complex({})".format(k) for k in CMPLX_KINDS]

#! Collected (kind, type) tuples for complex types
#:set CMPLX_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES))
#! Collected (kind, type, initial) tuples for complex types
#:set CMPLX_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_INIT))

#! Integer kinds to be considered during templating
#:set INT_KINDS = ["int8", "int16", "int32", "int64"]
Expand Down Expand Up @@ -109,6 +127,17 @@
#{if rank > 0}#(${":" + ",:" * (rank - 1)}$)#{endif}#
#:enddef

#! Generates an empty array rank suffix.
#!
#! Args:
#! rank (int): Rank of the variable
#!
#! Returns:
#! Empty array rank suffix string (e.g. (0,0) if rank = 2)
#!
#:def emptyranksuffix(rank)
#{if rank > 0}#(${"0" + ",0" * (rank - 1)}$)#{endif}#
#:enddef

#! Joins stripped lines with given character string
#!
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(fppFiles
stdlib_kinds.fypp
stdlib_linalg.fypp
stdlib_linalg_diag.fypp
stdlib_linalg_least_squares.fypp
stdlib_linalg_outer_product.fypp
stdlib_linalg_kronecker.fypp
stdlib_linalg_cross_product.fypp
Expand Down
55 changes: 52 additions & 3 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
#:include "common.fypp"
#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
#:set RCI_KINDS_TYPES = RC_KINDS_TYPES + INT_KINDS_TYPES
#:set RHS_SUFFIX = ["one","many"]
#:set RHS_SYMBOL = [ranksuffix(r) for r in [1,2]]
#:set RHS_EMPTY = [emptyranksuffix(r) for r in [1,2]]
#:set ALL_RHS = list(zip(RHS_SYMBOL,RHS_SUFFIX,RHS_EMPTY))
module stdlib_linalg
!!Provides a support for various linear algebra procedures
!! ([Specification](../page/specs/stdlib_linalg.html))
use stdlib_kinds, only: sp, dp, xdp, qp, &
int8, int16, int32, int64
use stdlib_kinds, only: xdp, int8, int16, int32, int64
use stdlib_linalg_constants, only: sp, dp, qp, lk, ilp
use stdlib_error, only: error_stop
use stdlib_optval, only: optval
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling
perazz marked this conversation as resolved.
Show resolved Hide resolved
implicit none
private

public :: diag
public :: eye
public :: lstsq
public :: trace
public :: outer_product
public :: kronecker_product
Expand Down Expand Up @@ -214,6 +221,48 @@ module stdlib_linalg
#:endfor
end interface is_hessenberg

! Least squares solution to system Ax=b, i.e. such that the 2-norm abs(b-Ax) is minimized.
interface lstsq
!! version: experimental
!!
!! Computes the squares solution to system \( A \cdot x = b \).
!! ([Specification](../page/specs/stdlib_linalg.html#det-computes-the-determinant-of-a-square-matrix))
!!
!!### Summary
!! Interface for computing least squares, i.e. the 2-norm \( || (b-A \cdot x ||_2 \) minimizing solution.
!!
!!### Description
!!
!! This interface provides methods for computing the least squares of a linear matrix system.
!! Supported data types include `real` and `complex`.
!!
!!@note The solution is based on LAPACK's singular value decomposition `*GELSD` methods.
!!@note BLAS/LAPACK backends do not currently support extended precision (``xdp``).
!!
#:for nd,ndsuf,nde in ALL_RHS
#:for rk,rt,ri in RC_KINDS_TYPES
#:if rk!="xdp"
module function stdlib_linalg_${ri}$_lstsq_${ndsuf}$(a,b,cond,overwrite_a,rank,err) result(x)
!> Input matrix a[n,n]
${rt}$, intent(inout), target :: a(:,:)
!> Right hand side vector or array, b[n] or b[n,nrhs]
${rt}$, intent(in) :: b${nd}$
!> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
real(${rk}$), optional, intent(in) :: cond
!> [optional] Can A,b data be overwritten and destroyed?
logical(lk), optional, intent(in) :: overwrite_a
!> [optional] Return rank of A
integer(ilp), optional, intent(out) :: rank
!> [optional] state return flag. On error if not requested, the code will stop
type(linalg_state_type), optional, intent(out) :: err
!> Result array/matrix x[n] or x[n,nrhs]
${rt}$, allocatable, target :: x${nd}$
end function stdlib_linalg_${ri}$_lstsq_${ndsuf}$
#:endif
#:endfor
#:endfor
end interface lstsq

contains


Expand Down
Loading
Loading