diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index fc5ed6b4d..4ce46e91a 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -119,21 +119,24 @@ where eps: T::RealField, max_niter: usize, ) -> Option { + // Ensure that the matrix is not empty before proceeding assert!( !matrix.is_empty(), "Cannot compute the SVD of an empty matrix." ); + let (nrows, ncols) = matrix.shape_generic(); let min_nrows_ncols = nrows.min(ncols); + // Special cases for 2x2 and 3x3 matrices, handled by predefined methods if Self::use_special_always_ordered_svd2() { - // SAFETY: the reference transmutes are OK since we checked that the types match exactly. + // SAFETY: Reference transmutes are OK because the types match exactly let matrix: &Matrix2 = unsafe { std::mem::transmute(&matrix) }; let result = super::svd2::svd_ordered2(matrix, compute_u, compute_v); let typed_result: &Self = unsafe { std::mem::transmute(&result) }; return Some(typed_result.clone()); } else if Self::use_special_always_ordered_svd3() { - // SAFETY: the reference transmutes are OK since we checked that the types match exactly. + // SAFETY: Reference transmutes are OK because the types match exactly let matrix: &Matrix3 = unsafe { std::mem::transmute(&matrix) }; let result = super::svd3::svd_ordered3(matrix, compute_u, compute_v, eps, max_niter); let typed_result: &Self = unsafe { std::mem::transmute(&result) }; @@ -142,8 +145,10 @@ where let dim = min_nrows_ncols.value(); + // Get the maximum absolute value of the matrix for normalization let m_amax = matrix.camax(); + // If the max value is not zero, unscale the matrix by dividing by m_amax if !m_amax.is_zero() { matrix.unscale_mut(m_amax.clone()); } @@ -158,6 +163,16 @@ where let mut diagonal = bi_matrix.diagonal(); let mut off_diagonal = bi_matrix.off_diagonal(); + // **Check for NaN values in the diagonal elements** + // We check whether any singular value in the diagonal is NaN by comparing each value to itself. + // This works because NaN is the only value in Rust that is not equal to itself. + if diagonal.iter().any(|s| s != s) { + // If any singular value is NaN, return None early + // Explanation: + // NaN != NaN is always true, so if `s != s`, we know `s` is NaN. + return None; // Return early if NaN found + } + let mut niter = 0; let (mut start, mut end) = Self::delimit_subproblem( &mut diagonal, @@ -169,10 +184,11 @@ where eps.clone(), ); + // Iterative SVD computation with Givens rotations while end != start { let subdim = end - start + 1; - // Solve the subproblem. + // Solve subproblem for larger subdimensions (> 2) #[allow(clippy::comparison_chain)] if subdim > 2 { let m = end - 1; @@ -184,19 +200,23 @@ where let dn = diagonal[n].clone(); let fm = off_diagonal[m].clone(); + // Perform calculations to determine the shift value for Givens rotation let tmm = dm.clone() * dm.clone() + off_diagonal[m - 1].clone() * off_diagonal[m - 1].clone(); let tmn = dm * fm.clone(); let tnn = dn.clone() * dn + fm.clone() * fm; + // Compute Wilkinson's shift let shift = symmetric_eigen::wilkinson_shift(tmm, tnn, tmn); + // Create vector for subsequent Givens rotations vec = Vector2::new( diagonal[start].clone() * diagonal[start].clone() - shift, diagonal[start].clone() * off_diagonal[start].clone(), ); } + // Perform Givens rotations to reduce the bidiagonal matrix for k in start..n { let m12 = if k == n - 1 { T::RealField::zero() @@ -224,7 +244,6 @@ where } let v = Vector2::new(subm[(0, 0)].clone(), subm[(1, 0)].clone()); - // TODO: does the case `v.y == 0` ever happen? let (rot2, norm2) = GivensRotation::cancel_y(&v) .unwrap_or((GivensRotation::identity(), subm[(0, 0)].clone())); @@ -264,7 +283,7 @@ where } } } else if subdim == 2 { - // Solve the remaining 2x2 subproblem. + // Solve the 2x2 subproblem if subdim == 2 let (u2, s, v2) = compute_2x2_uptrig_svd( diagonal[start].clone(), off_diagonal[start].clone(), @@ -321,9 +340,10 @@ where } } + // Unscale the singular values after SVD computation diagonal *= m_amax; - // Ensure all singular value are non-negative. + // Ensure all singular values are non-negative for i in 0..dim { let sval = diagonal[i].clone();