Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #1475: SVD now tests for NaN values #1477

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions src/linalg/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,24 @@ where
eps: T::RealField,
max_niter: usize,
) -> Option<Self> {
// Ensure that the matrix is not empty before proceeding
assert!(
!matrix.is_empty(),
"Cannot compute the SVD of an empty matrix."
);

let (nrows, ncols) = matrix.shape_generic();
let min_nrows_ncols = nrows.min(ncols);

// Special cases for 2x2 and 3x3 matrices, handled by predefined methods
if Self::use_special_always_ordered_svd2() {
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
// SAFETY: Reference transmutes are OK because the types match exactly
let matrix: &Matrix2<T::RealField> = unsafe { std::mem::transmute(&matrix) };
let result = super::svd2::svd_ordered2(matrix, compute_u, compute_v);
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
return Some(typed_result.clone());
} else if Self::use_special_always_ordered_svd3() {
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
// SAFETY: Reference transmutes are OK because the types match exactly
let matrix: &Matrix3<T::RealField> = unsafe { std::mem::transmute(&matrix) };
let result = super::svd3::svd_ordered3(matrix, compute_u, compute_v, eps, max_niter);
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
Expand All @@ -142,8 +145,10 @@ where

let dim = min_nrows_ncols.value();

// Get the maximum absolute value of the matrix for normalization
let m_amax = matrix.camax();

// If the max value is not zero, unscale the matrix by dividing by m_amax
if !m_amax.is_zero() {
matrix.unscale_mut(m_amax.clone());
}
Expand All @@ -158,6 +163,16 @@ where
let mut diagonal = bi_matrix.diagonal();
let mut off_diagonal = bi_matrix.off_diagonal();

// **Check for NaN values in the diagonal elements**
// We check whether any singular value in the diagonal is NaN by comparing each value to itself.
// This works because NaN is the only value in Rust that is not equal to itself.
if diagonal.iter().any(|s| s != s) {
// If any singular value is NaN, return None early
// Explanation:
// NaN != NaN is always true, so if `s != s`, we know `s` is NaN.
return None; // Return early if NaN found
}

let mut niter = 0;
let (mut start, mut end) = Self::delimit_subproblem(
&mut diagonal,
Expand All @@ -169,10 +184,11 @@ where
eps.clone(),
);

// Iterative SVD computation with Givens rotations
while end != start {
let subdim = end - start + 1;

// Solve the subproblem.
// Solve subproblem for larger subdimensions (> 2)
#[allow(clippy::comparison_chain)]
if subdim > 2 {
let m = end - 1;
Expand All @@ -184,19 +200,23 @@ where
let dn = diagonal[n].clone();
let fm = off_diagonal[m].clone();

// Perform calculations to determine the shift value for Givens rotation
let tmm = dm.clone() * dm.clone()
+ off_diagonal[m - 1].clone() * off_diagonal[m - 1].clone();
let tmn = dm * fm.clone();
let tnn = dn.clone() * dn + fm.clone() * fm;

// Compute Wilkinson's shift
let shift = symmetric_eigen::wilkinson_shift(tmm, tnn, tmn);

// Create vector for subsequent Givens rotations
vec = Vector2::new(
diagonal[start].clone() * diagonal[start].clone() - shift,
diagonal[start].clone() * off_diagonal[start].clone(),
);
}

// Perform Givens rotations to reduce the bidiagonal matrix
for k in start..n {
let m12 = if k == n - 1 {
T::RealField::zero()
Expand Down Expand Up @@ -224,7 +244,6 @@ where
}

let v = Vector2::new(subm[(0, 0)].clone(), subm[(1, 0)].clone());
// TODO: does the case `v.y == 0` ever happen?
let (rot2, norm2) = GivensRotation::cancel_y(&v)
.unwrap_or((GivensRotation::identity(), subm[(0, 0)].clone()));

Expand Down Expand Up @@ -264,7 +283,7 @@ where
}
}
} else if subdim == 2 {
// Solve the remaining 2x2 subproblem.
// Solve the 2x2 subproblem if subdim == 2
let (u2, s, v2) = compute_2x2_uptrig_svd(
diagonal[start].clone(),
off_diagonal[start].clone(),
Expand Down Expand Up @@ -321,9 +340,10 @@ where
}
}

// Unscale the singular values after SVD computation
diagonal *= m_amax;

// Ensure all singular value are non-negative.
// Ensure all singular values are non-negative
for i in 0..dim {
let sval = diagonal[i].clone();

Expand Down