Skip to content

Commit

Permalink
add vector-vector based rotation matrix to correct reverse mode bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Dec 11, 2024
1 parent 4f05b37 commit e8726db
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 40 deletions.
26 changes: 12 additions & 14 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

from desc.backend import jnp, sign

from ..utils import cross, dot, safenormalize
from ..utils import cross, dot
from .data_index import register_compute_fun
from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .geom_utils import (
rotation_matrix_vector_vector,
rpz2xyz,
rpz2xyz_vec,
xyz2rpz,
xyz2rpz_vec,
)


@register_compute_fun(
Expand Down Expand Up @@ -208,9 +214,7 @@ def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([X, Y, Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 217 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L217

Added line #L217 was not covered by tests
coords = jnp.matmul(coords, A.T) + center
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
# convert back to rpz
Expand Down Expand Up @@ -248,9 +252,7 @@ def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([dX, dY, dZ]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 255 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L255

Added line #L255 was not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
# convert back to rpz
Expand Down Expand Up @@ -293,9 +295,7 @@ def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([d2X, d2Y, d2Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 298 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L298

Added line #L298 was not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
# convert back to rpz
Expand Down Expand Up @@ -345,9 +345,7 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([d3X, d3Y, d3Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
A = rotation_matrix_vector_vector(Zaxis, normal)

Check warning on line 348 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L348

Added line #L348 was not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
# convert back to rpz
Expand Down
45 changes: 44 additions & 1 deletion desc/compute/geom_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from desc.backend import jnp

from ..utils import safenorm, safenormalize
from ..utils import cross, dot, safediv, safenorm, safenormalize


def reflection_matrix(normal):
Expand All @@ -28,6 +28,10 @@ def reflection_matrix(normal):
def rotation_matrix(axis, angle=None):
"""Matrix to rotate points about axis by given angle.
NOTE: not correct if a and b are antiparallel, will
simply return identity when in reality negative identity
is correct.
Parameters
----------
axis : array-like, shape(3,)
Expand All @@ -53,6 +57,45 @@ def rotation_matrix(axis, angle=None):
return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation


def _skew_matrix(a):
return jnp.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])

Check warning on line 61 in desc/compute/geom_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/geom_utils.py#L61

Added line #L61 was not covered by tests


def rotation_matrix_vector_vector(a, b):
"""Matrix to rotate vector a onto b.
NOTE: not correct if a and b are antiparallel, will
simply return identity when in reality negative identity
is correct.
Parameters
----------
a,b : array-like, shape(3,)
Vectors, in cartesian (X,Y,Z) coordinates
Matrix will correspond to rotating a onto b
Returns
-------
rotmat : ndarray, shape(3,3)
Matrix to rotate points in cartesian (X,Y,Z) coordinates.
"""
a = jnp.asarray(a)
b = jnp.asarray(b)
a = safenormalize(a)
b = safenormalize(b)
axis = cross(a, b)
norm = safenorm(axis)
axis = safenormalize(axis)
eps = 1e2 * jnp.finfo(axis.dtype).eps
skew = _skew_matrix(axis)
R1 = jnp.eye(3)
R2 = skew
R3 = (skew @ skew) * safediv(1, 1 + dot(a, b))
return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation

Check warning on line 96 in desc/compute/geom_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/geom_utils.py#L84-L96

Added lines #L84 - L96 were not covered by tests


def xyz2rpz(pts):
"""Transform points from cartesian (X,Y,Z) to polar (R,phi,Z) form.
Expand Down
6 changes: 3 additions & 3 deletions tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,15 +615,15 @@ def test_basis(self):

xs_xyz = cxyz.compute("x_s")["x_s"]
xs_rpz = crpz.compute("x_s")["x_s"]
np.testing.assert_allclose(xs_xyz, xs_rpz, atol=2e-15)
np.testing.assert_allclose(xs_xyz, xs_rpz, atol=3e-15)

xss_xyz = cxyz.compute("x_ss")["x_ss"]
xss_rpz = crpz.compute("x_ss")["x_ss"]
np.testing.assert_allclose(xss_xyz, xss_rpz, atol=2e-15)
np.testing.assert_allclose(xss_xyz, xss_rpz, atol=3e-15)

xsss_xyz = cxyz.compute("x_sss")["x_sss"]
xsss_rpz = crpz.compute("x_sss")["x_sss"]
np.testing.assert_allclose(xsss_xyz, xsss_rpz, atol=2e-15)
np.testing.assert_allclose(xsss_xyz, xsss_rpz, atol=3e-15)

@pytest.mark.unit
def test_misc(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from desc.compute.geom_utils import (
rotation_matrix,
rotation_matrix_vector_vector,
rpz2xyz,
rpz2xyz_vec,
xyz2rpz,
Expand All @@ -19,6 +20,9 @@ def test_rotation_matrix():
At = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
np.testing.assert_allclose(A, At, atol=1e-10)

A = rotation_matrix_vector_vector([1, 0, 0], [0, 1, 0])
np.testing.assert_allclose(A, At, atol=1e-10)


@pytest.mark.unit
def test_xyz2rpz():
Expand Down
23 changes: 1 addition & 22 deletions tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3024,27 +3024,6 @@ def test_objective_no_nangrad_boundary_error(self):
g = obj.grad(obj.x(eq, ext_field))
assert not np.any(np.isnan(g)), "boundary error"

@pytest.mark.unit
def test_objective_no_nangrad_boundary_error_with_coil(self):
"""BoundaryError."""
ext_field = FourierPlanarCoil(center=[0, 0, 10], normal=[0, 0, 1])

pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1])
iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1])
surf = FourierRZToroidalSurface(
R_lmn=[4.0, 1.0],
modes_R=[[0, 0], [1, 0]],
Z_lmn=[-1.0],
modes_Z=[[-1, 0]],
NFP=1,
)

eq = Equilibrium(M=6, N=0, Psi=1.0, surface=surf, pressure=pres, iota=iota)
obj = ObjectiveFunction(BoundaryError(eq, ext_field), use_jit=False)
obj.build()
g = obj.grad(obj.x(eq, ext_field))
assert not np.any(np.isnan(g)), "boundary error"

@pytest.mark.unit
def test_objective_no_nangrad_vacuum_boundary_error(self):
"""VacuumBoundaryError."""
Expand Down Expand Up @@ -3097,7 +3076,7 @@ def test_objective_no_nangrad_quadratic_flux(self):
@pytest.mark.unit
def test_objective_no_nangrad_quadratic_flux_minimizing(self):
"""SurfaceQuadraticFlux."""
ext_field = ToroidalMagneticField(1.0, 1.0)
ext_field = FourierPlanarCoil(normal=[0, 0, 1])

surf = FourierRZToroidalSurface(
R_lmn=[4.0, 1.0],
Expand Down

0 comments on commit e8726db

Please sign in to comment.