From 4db8968b32b5404767ba95d01b0c3a77077cacf2 Mon Sep 17 00:00:00 2001
From: ljleb
Date: Sun, 17 Dec 2023 23:15:06 -0500
Subject: [PATCH] cpu error case

---
 sd_meh/merge_methods.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index 6961f55..6d0bdec 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -232,8 +232,8 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
     a_centroid = a_neurons.mean(0)
     b_centroid = b_neurons.mean(0)
     new_centroid = weighted_sum(a_centroid, b_centroid, alpha)
-    if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1:
-        return new_centroid.reshape_as(a)
+    if 1 in a_neurons.shape:
+        return (a_neurons + new_centroid).reshape_as(a)
 
     a_neurons -= a_centroid
     b_neurons -= b_centroid
@@ -245,13 +245,11 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
     if alpha_is_float:
         # cancel reflection. without this, eigenvalues often have a complex component
         # and then we can't obtain a valid dtype for the merge
-        u[:, -1] *= torch.nan_to_num(1 / (torch.det(u) * torch.det(v_t)))
+        u[:, -1] /= (torch.det(u) * torch.det(v_t))
 
     transform = rotation = u @ v_t
-    print(f"shape: {a.shape} -> {a_neurons.shape} -> {transform.shape}")
-    det = torch.det(transform)
-    if torch.abs(det.abs() - 1) > 1e-6:
-        print("determinant error:", det)
+    if not torch.isfinite(rotation).all():
+        raise ValueError(f"determinant error: {torch.det(rotation)}")
 
     if alpha_is_float:
         transform = fractional_matrix_power(transform, alpha)
@@ -277,6 +275,4 @@ def fractional_matrix_power(matrix: Tensor, power: float):
     eigenvalues, eigenvectors = torch.linalg.eig(matrix)
     eigenvalues.pow_(power)
     result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
-    if ((error := result.imag) > 1e-4).any():
-        print("image error:", error)
     return result.real.to(dtype=matrix.dtype)