diff --git a/desc/compute/_curve.py b/desc/compute/_curve.py index 1be68dfe5..0db4e07a1 100644 --- a/desc/compute/_curve.py +++ b/desc/compute/_curve.py @@ -2,9 +2,15 @@ from desc.backend import jnp, sign -from ..utils import cross, dot, safenormalize +from ..utils import cross, dot from .data_index import register_compute_fun -from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec +from .geom_utils import ( + rotation_matrix_vector_vector, + rpz2xyz, + rpz2xyz_vec, + xyz2rpz, + xyz2rpz_vec, +) @register_compute_fun( @@ -208,9 +214,7 @@ def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): coords = jnp.array([X, Y, Z]).T # rotate into place Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis - axis = cross(Zaxis, normal) - angle = jnp.arccos(dot(Zaxis, safenormalize(normal))) - A = rotation_matrix(axis=axis, angle=angle) + A = rotation_matrix_vector_vector(Zaxis, normal) coords = jnp.matmul(coords, A.T) + center coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) + params["shift"] # convert back to rpz @@ -248,9 +252,7 @@ def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): coords = jnp.array([dX, dY, dZ]).T # rotate into place Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis - axis = cross(Zaxis, normal) - angle = jnp.arccos(dot(Zaxis, safenormalize(normal))) - A = rotation_matrix(axis=axis, angle=angle) + A = rotation_matrix_vector_vector(Zaxis, normal) coords = jnp.matmul(coords, A.T) coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) # convert back to rpz @@ -293,9 +295,7 @@ def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): coords = jnp.array([d2X, d2Y, d2Z]).T # rotate into place Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis - axis = cross(Zaxis, normal) - angle = jnp.arccos(dot(Zaxis, safenormalize(normal))) - A = rotation_matrix(axis=axis, angle=angle) + A = rotation_matrix_vector_vector(Zaxis, normal) coords = jnp.matmul(coords, A.T) coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) # convert back to rpz @@ -345,9 +345,7 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): coords = jnp.array([d3X, d3Y, d3Z]).T # rotate into place Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis - axis = cross(Zaxis, normal) - angle = jnp.arccos(dot(Zaxis, safenormalize(normal))) - A = rotation_matrix(axis=axis, angle=angle) + A = rotation_matrix_vector_vector(Zaxis, normal) coords = jnp.matmul(coords, A.T) coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) # convert back to rpz diff --git a/desc/compute/geom_utils.py b/desc/compute/geom_utils.py index eeda658b6..a51b88e41 100644 --- a/desc/compute/geom_utils.py +++ b/desc/compute/geom_utils.py @@ -4,7 +4,7 @@ from desc.backend import jnp -from ..utils import safenorm, safenormalize +from ..utils import cross, dot, safediv, safenorm, safenormalize def reflection_matrix(normal): @@ -28,6 +28,10 @@ def reflection_matrix(normal): def rotation_matrix(axis, angle=None): """Matrix to rotate points about axis by given angle. + NOTE: not correct if a and b are antiparallel, will + simply return identity when in reality negative identity + is correct. + Parameters ---------- axis : array-like, shape(3,) @@ -53,6 +57,45 @@ def rotation_matrix(axis, angle=None): return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation +def _skew_matrix(a): + return jnp.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]]) + + +def rotation_matrix_vector_vector(a, b): + """Matrix to rotate vector a onto b. + + NOTE: not correct if a and b are antiparallel, will + simply return identity when in reality negative identity + is correct. + + Parameters + ---------- + a,b : array-like, shape(3,) + Vectors, in cartesian (X,Y,Z) coordinates + Matrix will correspond to rotating a onto b + + + Returns + ------- + rotmat : ndarray, shape(3,3) + Matrix to rotate points in cartesian (X,Y,Z) coordinates. + + """ + a = jnp.asarray(a) + b = jnp.asarray(b) + a = safenormalize(a) + b = safenormalize(b) + axis = cross(a, b) + norm = safenorm(axis) + axis = safenormalize(axis) + eps = 1e2 * jnp.finfo(axis.dtype).eps + skew = _skew_matrix(axis) + R1 = jnp.eye(3) + R2 = skew + R3 = (skew @ skew) * safediv(1, 1 + dot(a, b)) + return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation + + def xyz2rpz(pts): """Transform points from cartesian (X,Y,Z) to polar (R,phi,Z) form. diff --git a/tests/test_curves.py b/tests/test_curves.py index 0b1af1d63..f56f816a2 100644 --- a/tests/test_curves.py +++ b/tests/test_curves.py @@ -615,15 +615,15 @@ def test_basis(self): xs_xyz = cxyz.compute("x_s")["x_s"] xs_rpz = crpz.compute("x_s")["x_s"] - np.testing.assert_allclose(xs_xyz, xs_rpz, atol=2e-15) + np.testing.assert_allclose(xs_xyz, xs_rpz, atol=3e-15) xss_xyz = cxyz.compute("x_ss")["x_ss"] xss_rpz = crpz.compute("x_ss")["x_ss"] - np.testing.assert_allclose(xss_xyz, xss_rpz, atol=2e-15) + np.testing.assert_allclose(xss_xyz, xss_rpz, atol=3e-15) xsss_xyz = cxyz.compute("x_sss")["x_sss"] xsss_rpz = crpz.compute("x_sss")["x_sss"] - np.testing.assert_allclose(xsss_xyz, xsss_rpz, atol=2e-15) + np.testing.assert_allclose(xsss_xyz, xsss_rpz, atol=3e-15) @pytest.mark.unit def test_misc(self): diff --git a/tests/test_geometry.py b/tests/test_geometry.py index f95cd9b23..a5bbbd243 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -5,6 +5,7 @@ from desc.compute.geom_utils import ( rotation_matrix, + rotation_matrix_vector_vector, rpz2xyz, rpz2xyz_vec, xyz2rpz, @@ -19,6 +20,9 @@ def test_rotation_matrix(): At = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]]) np.testing.assert_allclose(A, At, atol=1e-10) + A = rotation_matrix_vector_vector([1, 0, 0], [0, 1, 0]) + np.testing.assert_allclose(A, At, atol=1e-10) + @pytest.mark.unit def test_xyz2rpz(): diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 26d966150..bda8829a6 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -3076,7 +3076,7 @@ def test_objective_no_nangrad_quadratic_flux(self): @pytest.mark.unit def test_objective_no_nangrad_quadratic_flux_minimizing(self): """SurfaceQuadraticFlux.""" - ext_field = ToroidalMagneticField(1.0, 1.0) + ext_field = FourierPlanarCoil(normal=[0, 0, 1]) surf = FourierRZToroidalSurface( R_lmn=[4.0, 1.0],