Skip to content

Commit

Permalink
cpu error case
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb committed Dec 18, 2023
1 parent ac54d26 commit 4db8968
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions sd_meh/merge_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
a_centroid = a_neurons.mean(0)
b_centroid = b_neurons.mean(0)
new_centroid = weighted_sum(a_centroid, b_centroid, alpha)
if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1:
return new_centroid.reshape_as(a)
if 1 in a_neurons.shape:
return (a_neurons + new_centroid).reshape_as(a)

a_neurons -= a_centroid
b_neurons -= b_centroid
Expand All @@ -245,13 +245,11 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
if alpha_is_float:
# cancel reflection. without this, eigenvalues often have a complex component
# and then we can't obtain a valid dtype for the merge
u[:, -1] *= torch.nan_to_num(1 / (torch.det(u) * torch.det(v_t)))
u[:, -1] /= (torch.det(u) * torch.det(v_t))

transform = rotation = u @ v_t
print(f"shape: {a.shape} -> {a_neurons.shape} -> {transform.shape}")
det = torch.det(transform)
if torch.abs(det.abs() - 1) > 1e-6:
print("determinant error:", det)
if not torch.isfinite(rotation).all():
raise ValueError(f"determinant error: {torch.det(rotation)}")

if alpha_is_float:
transform = fractional_matrix_power(transform, alpha)
Expand All @@ -277,6 +275,4 @@ def fractional_matrix_power(matrix: Tensor, power: float):
eigenvalues, eigenvectors = torch.linalg.eig(matrix)
eigenvalues.pow_(power)
result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
if ((error := result.imag) > 1e-4).any():
print("image error:", error)
return result.real.to(dtype=matrix.dtype)

0 comments on commit 4db8968

Please sign in to comment.