Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nan in reverse mode gradient caused by rotation_matrix #1457

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

from desc.backend import jnp, sign

from ..utils import cross, dot, safenormalize
from ..utils import cross, dot
from .data_index import register_compute_fun
from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .geom_utils import (
rotation_matrix_vector_vector,
rpz2xyz,
rpz2xyz_vec,
xyz2rpz,
xyz2rpz_vec,
)


@register_compute_fun(
Expand Down Expand Up @@ -208,9 +214,7 @@
coords = jnp.array([X, Y, Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 217 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L217

Added line #L217 was not covered by tests
coords = jnp.matmul(coords, A.T) + center
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
# convert back to rpz
Expand Down Expand Up @@ -248,9 +252,7 @@
coords = jnp.array([dX, dY, dZ]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 255 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L255

Added line #L255 was not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
# convert back to rpz
Expand Down Expand Up @@ -293,9 +295,7 @@
coords = jnp.array([d2X, d2Y, d2Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 298 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L298

Added line #L298 was not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
# convert back to rpz
Expand Down Expand Up @@ -345,9 +345,7 @@
coords = jnp.array([d3X, d3Y, d3Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another way to rotate a into b would be to rotate by pi around the axis (a+b)/2, which might be easier.

a = Zaxis
b = safenormalize(normal)
axis = safenormalize((a+b)/2)
A = rotation_matrix(axis, angle=np.pi)

That avoids the arccos, and I think it should be fine since it's all safenormalized?

Would still need to special case if they are exactly antiparallel but seems to work ok if they are ~nearly antiparallel

angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 348 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L348

Added line #L348 was not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
# convert back to rpz
Expand Down
45 changes: 44 additions & 1 deletion desc/compute/geom_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from desc.backend import jnp

from ..utils import safenorm, safenormalize
from ..utils import cross, dot, safediv, safenorm, safenormalize


def reflection_matrix(normal):
Expand All @@ -28,6 +28,10 @@
def rotation_matrix(axis, angle=None):
"""Matrix to rotate points about axis by given angle.

NOTE: not correct if a and b are antiparallel, will
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this comment applies here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am certain it does: I checked, and the rotation matrix is wrong for this case with this function too.

simply return identity when in reality negative identity
is correct.

Parameters
----------
axis : array-like, shape(3,)
Expand All @@ -53,6 +57,45 @@
return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation


def _skew_matrix(a):
return jnp.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])

Check warning on line 61 in desc/compute/geom_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/geom_utils.py#L61

Added line #L61 was not covered by tests


def rotation_matrix_vector_vector(a, b):
    """Matrix to rotate vector a onto b.

    Uses the trigonometry-free form of Rodrigues' rotation formula,
    ``R = I + K + K @ K / (1 + a.b)``, where ``a`` and ``b`` are first
    normalized and ``K`` is the skew-symmetric matrix of the (un-normalized)
    cross product ``a x b``. No ``arccos`` or ``sqrt`` of the angle is needed.

    NOTE: if a and b are (numerically) antiparallel the rotation axis is not
    unique. In that case a rotation by pi about an axis perpendicular to a is
    returned, which still maps a onto b. If a and b are parallel, or either
    one is the zero vector, the identity is returned.

    Parameters
    ----------
    a,b : array-like, shape(3,)
        Vectors, in cartesian (X,Y,Z) coordinates
        Matrix will correspond to rotating a onto b

    Returns
    -------
    rotmat : ndarray, shape(3,3)
        Matrix to rotate points in cartesian (X,Y,Z) coordinates.

    """
    a = safenormalize(jnp.asarray(a))
    b = safenormalize(jnp.asarray(b))
    eps = 1e2 * jnp.finfo(a.dtype).eps
    identity = jnp.eye(3)

    # The cross product must NOT be normalized here: the formula below relies
    # on |a x b| = sin(angle). Normalizing it only gives a valid rotation when
    # a and b happen to be perpendicular.
    skew = _skew_matrix(cross(a, b))
    one_plus_cos = 1 + dot(a, b)
    antiparallel = one_plus_cos < eps

    # Guard the denominator before dividing so the branch that is discarded by
    # the final ``where`` stays finite and cannot leak nan into gradients.
    safe_denominator = jnp.where(antiparallel, 1.0, one_plus_cos)
    general = identity + skew + (skew @ skew) / safe_denominator

    # Antiparallel case: rotate by pi about any axis perpendicular to a.
    # Crossing a with the coordinate axis it is least aligned with keeps the
    # perpendicular vector well away from zero length.
    least_aligned = identity[jnp.argmin(jnp.abs(a))]
    perpendicular = safenormalize(cross(a, least_aligned))
    half_turn = 2 * jnp.outer(perpendicular, perpendicular) - identity

    return jnp.where(antiparallel, half_turn, general)

Check warning on line 96 in desc/compute/geom_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/geom_utils.py#L84-L96

Added lines #L84 - L96 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this is correct, I seem to get weird answers in a few cases I've tried.
The determinant of a rotation matrix should always be 1, but for example

np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([1,1,0])))

gives 1.17

and if they are almost antiparallel:

np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([-1,1e-4,0])))

gives roughly 4e16



def xyz2rpz(pts):
"""Transform points from cartesian (X,Y,Z) to polar (R,phi,Z) form.

Expand Down
6 changes: 3 additions & 3 deletions tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,15 +615,15 @@ def test_basis(self):

xs_xyz = cxyz.compute("x_s")["x_s"]
xs_rpz = crpz.compute("x_s")["x_s"]
np.testing.assert_allclose(xs_xyz, xs_rpz, atol=2e-15)
np.testing.assert_allclose(xs_xyz, xs_rpz, atol=3e-15)

xss_xyz = cxyz.compute("x_ss")["x_ss"]
xss_rpz = crpz.compute("x_ss")["x_ss"]
np.testing.assert_allclose(xss_xyz, xss_rpz, atol=2e-15)
np.testing.assert_allclose(xss_xyz, xss_rpz, atol=3e-15)

xsss_xyz = cxyz.compute("x_sss")["x_sss"]
xsss_rpz = crpz.compute("x_sss")["x_sss"]
np.testing.assert_allclose(xsss_xyz, xsss_rpz, atol=2e-15)
np.testing.assert_allclose(xsss_xyz, xsss_rpz, atol=3e-15)

@pytest.mark.unit
def test_misc(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from desc.compute.geom_utils import (
rotation_matrix,
rotation_matrix_vector_vector,
rpz2xyz,
rpz2xyz_vec,
xyz2rpz,
Expand All @@ -19,6 +20,9 @@ def test_rotation_matrix():
At = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
np.testing.assert_allclose(A, At, atol=1e-10)

A = rotation_matrix_vector_vector([1, 0, 0], [0, 1, 0])
np.testing.assert_allclose(A, At, atol=1e-10)


@pytest.mark.unit
def test_xyz2rpz():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3076,7 +3076,7 @@ def test_objective_no_nangrad_quadratic_flux(self):
@pytest.mark.unit
def test_objective_no_nangrad_quadratic_flux_minimizing(self):
"""SurfaceQuadraticFlux."""
ext_field = ToroidalMagneticField(1.0, 1.0)
ext_field = FourierPlanarCoil(normal=[0, 0, 1])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this intersect the plasma?


surf = FourierRZToroidalSurface(
R_lmn=[4.0, 1.0],
Expand Down
Loading