Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix curvy bugs #840

Merged
merged 23 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8e59800
fix bug in PlanarCurve rotation matrix
daniel-dudt Jan 26, 2024
6a37dc4
replace if statement with nan_to_num call
daniel-dudt Jan 26, 2024
3887d6d
update tests
daniel-dudt Jan 27, 2024
3cfad71
make shift & rotmat optimizeable parameters of Curve
daniel-dudt Jan 29, 2024
e21202f
noting that fix does not work with AD
daniel-dudt Jan 29, 2024
7c8a040
replace unnecessary jnp with np in Curve class
daniel-dudt Jan 29, 2024
d25b289
update examples with private axis curve attributes
daniel-dudt Jan 29, 2024
983aeda
rotation_matrix work with AD at axis=[0, 0, 0]
daniel-dudt Jan 30, 2024
03a0002
ensure shift and rotmat are float arrays
daniel-dudt Jan 30, 2024
a88c45a
resolve merge conflicts with master
daniel-dudt Feb 1, 2024
93ce772
repairing tests
daniel-dudt Feb 1, 2024
dd432e4
repair plot_coils test
daniel-dudt Feb 1, 2024
e236ff8
Merge branch 'master' into dd/curve
ddudt Feb 2, 2024
0b08b9d
Merge branch 'master' into dd/curve
dpanici Feb 4, 2024
1a5c1f2
safenormalize util function to make rotation_matrix work with angle i…
daniel-dudt Feb 6, 2024
8edb49d
revert changes to use angle input again
daniel-dudt Feb 6, 2024
eb393ff
merge with master
daniel-dudt Feb 6, 2024
d7d0e07
increase test tolerance
daniel-dudt Feb 6, 2024
66b92a9
forgot to revert one change in a test
daniel-dudt Feb 6, 2024
7a6cd48
add FixCurve constraints to maybe_add_self_consistency
daniel-dudt Feb 7, 2024
8f1d3fd
change eq reference to thing
daniel-dudt Feb 7, 2024
ac0fc14
better checking for class types based on attrs
daniel-dudt Feb 7, 2024
18ac783
revert change to self-consistent constraint order
daniel-dudt Feb 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 63 additions & 36 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
from desc.backend import jnp

from .data_index import register_compute_fun
from .geom_utils import (
_rotation_matrix_from_normal,
rpz2xyz,
rpz2xyz_vec,
xyz2rpz,
xyz2rpz_vec,
)
from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .utils import cross, dot


Expand Down Expand Up @@ -184,16 +178,19 @@
basis="basis",
)
def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
# create planar curve at z==0
# create planar curve at Z==0
r = transforms["r"].transform(params["r_n"], dz=0)
Z = jnp.zeros_like(r)
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
coords = jnp.array([X, Y, Z]).T
# rotate into place
R = _rotation_matrix_from_normal(params["normal"])
coords = jnp.matmul(coords, R.T) + params["center"]
coords = jnp.matmul(coords, params["rotmat"].T) + params["shift"]
Zaxis = jnp.array([0, 0, 1]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, params["normal"]) / jnp.linalg.norm(params["normal"]))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T) + params["center"]
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) + params["shift"]

Check warning on line 193 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L188-L193

Added lines #L188 - L193 were not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -224,16 +221,21 @@
dY = dr * jnp.sin(data["s"]) + r * jnp.cos(data["s"])
dZ = jnp.zeros_like(dX)
coords = jnp.array([dX, dY, dZ]).T
A = _rotation_matrix_from_normal(params["normal"])
Zaxis = jnp.array([0, 0, 1]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, params["normal"]) / jnp.linalg.norm(params["normal"]))
A = rotation_matrix(axis=axis, angle=angle)

Check warning on line 227 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L224-L227

Added lines #L224 - L227 were not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)

Check warning on line 229 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L229

Added line #L229 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"]
xyzcoords = (

Check warning on line 236 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L236

Added line #L236 was not covered by tests
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
x, y, z = xyzcoords.T
coords = xyz2rpz_vec(coords, x=x, y=y)
data["x_s"] = coords
Expand Down Expand Up @@ -269,16 +271,21 @@
)
d2Z = jnp.zeros_like(d2X)
coords = jnp.array([d2X, d2Y, d2Z]).T
A = _rotation_matrix_from_normal(params["normal"])
Zaxis = jnp.array([0, 0, 1]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, params["normal"]) / jnp.linalg.norm(params["normal"]))
A = rotation_matrix(axis=axis, angle=angle)

Check warning on line 277 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L274-L277

Added lines #L274 - L277 were not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)

Check warning on line 279 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L279

Added line #L279 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"]
xyzcoords = (

Check warning on line 286 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L286

Added line #L286 was not covered by tests
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
x, y, z = xyzcoords.T
coords = xyz2rpz_vec(coords, x=x, y=y)
data["x_ss"] = coords
Expand Down Expand Up @@ -321,16 +328,21 @@
)
d3Z = jnp.zeros_like(d3X)
coords = jnp.array([d3X, d3Y, d3Z]).T
A = _rotation_matrix_from_normal(params["normal"])
Zaxis = jnp.array([0, 0, 1]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, params["normal"]) / jnp.linalg.norm(params["normal"]))
A = rotation_matrix(axis=axis, angle=angle)

Check warning on line 334 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L331-L334

Added lines #L331 - L334 were not covered by tests
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)

Check warning on line 336 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L336

Added line #L336 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"]
xyzcoords = (

Check warning on line 343 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L343

Added line #L343 was not covered by tests
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
x, y, z = xyzcoords.T
coords = xyz2rpz_vec(coords, x=x, y=y)
data["x_sss"] = coords
Expand Down Expand Up @@ -363,7 +375,9 @@
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz(coords)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (

Check warning on line 378 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L378

Added line #L378 was not covered by tests
coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :]
)
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -397,7 +411,7 @@
coords = jnp.stack([dR, dphi, dZ], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T

Check warning on line 414 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L414

Added line #L414 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
data["x_s"] = coords
Expand Down Expand Up @@ -435,7 +449,7 @@
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T

Check warning on line 452 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L452

Added line #L452 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
data["x_ss"] = coords
Expand Down Expand Up @@ -473,7 +487,7 @@
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T

Check warning on line 490 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L490

Added line #L490 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
data["x_sss"] = coords
Expand Down Expand Up @@ -504,7 +518,9 @@
Y = transforms["Y"].transform(params["Y_n"], dz=0)
Z = transforms["Z"].transform(params["Z_n"], dz=0)
coords = jnp.stack([X, Y, Z], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (

Check warning on line 521 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L521

Added line #L521 was not covered by tests
coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :]
)
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -535,7 +551,7 @@
dY = transforms["Y"].transform(params["Y_n"], dz=1)
dZ = transforms["Z"].transform(params["Z_n"], dz=1)
coords = jnp.stack([dX, dY, dZ], axis=1)
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T

Check warning on line 554 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L554

Added line #L554 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(
coords,
Expand Down Expand Up @@ -570,7 +586,7 @@
d2Y = transforms["Y"].transform(params["Y_n"], dz=2)
d2Z = transforms["Z"].transform(params["Z_n"], dz=2)
coords = jnp.stack([d2X, d2Y, d2Z], axis=1)
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T

Check warning on line 589 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L589

Added line #L589 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(
coords,
Expand Down Expand Up @@ -605,7 +621,7 @@
d3Y = transforms["Y"].transform(params["Y_n"], dz=3)
d3Z = transforms["Z"].transform(params["Z_n"], dz=3)
coords = jnp.stack([d3X, d3Y, d3Z], axis=1)
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T

Check warning on line 624 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L624

Added line #L624 was not covered by tests
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(
coords,
Expand Down Expand Up @@ -662,7 +678,9 @@
)

coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (

Check warning on line 681 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L681

Added line #L681 was not covered by tests
coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :]
)
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -715,7 +733,7 @@
)

coords_s = jnp.stack([dXq, dYq, dZq], axis=1)
coords_s = coords_s @ params["rotmat"].T
coords_s = coords_s @ params["rotmat"].reshape((3, 3)).T

Check warning on line 736 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L736

Added line #L736 was not covered by tests

if kwargs.get("basis", "rpz").lower() == "rpz":
# calculate the xy coordinates to rotate to rpz
Expand Down Expand Up @@ -745,7 +763,10 @@
)

coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (

Check warning on line 766 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L766

Added line #L766 was not covered by tests
coords @ params["rotmat"].reshape((3, 3)).T
+ params["shift"][jnp.newaxis, :]
)

coords_s = xyz2rpz_vec(coords_s, x=coords[:, 0], y=coords[:, 1])
data["x_s"] = coords_s
Expand Down Expand Up @@ -798,7 +819,7 @@
)

coords_ss = jnp.stack([d2Xq, d2Yq, d2Zq], axis=1)
coords_ss = coords_ss @ params["rotmat"].T
coords_ss = coords_ss @ params["rotmat"].reshape((3, 3)).T

Check warning on line 822 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L822

Added line #L822 was not covered by tests

if kwargs.get("basis", "rpz").lower() == "rpz":
# calculate the xy coordinates to rotate to rpz
Expand Down Expand Up @@ -827,7 +848,10 @@
period=2 * jnp.pi,
)
coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (

Check warning on line 851 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L851

Added line #L851 was not covered by tests
coords @ params["rotmat"].reshape((3, 3)).T
+ params["shift"][jnp.newaxis, :]
)

coords_ss = xyz2rpz_vec(coords_ss, x=coords[:, 0], y=coords[:, 1])
data["x_ss"] = coords_ss
Expand Down Expand Up @@ -880,7 +904,7 @@
)

coords_sss = jnp.stack([d3Xq, d3Yq, d3Zq], axis=1)
coords_sss = coords_sss @ params["rotmat"].T
coords_sss = coords_sss @ params["rotmat"].reshape((3, 3)).T

Check warning on line 907 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L907

Added line #L907 was not covered by tests

if kwargs.get("basis", "rpz").lower() == "rpz":
# calculate the xy coordinates to rotate to rpz
Expand Down Expand Up @@ -909,7 +933,10 @@
period=2 * jnp.pi,
)
coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (

Check warning on line 936 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L936

Added line #L936 was not covered by tests
coords @ params["rotmat"].reshape((3, 3)).T
+ params["shift"][jnp.newaxis, :]
)

coords_sss = xyz2rpz_vec(coords_sss, x=coords[:, 0], y=coords[:, 1])
data["x_sss"] = coords_sss
Expand Down Expand Up @@ -976,7 +1003,7 @@
def _frenet_binormal(params, transforms, profiles, data, **kwargs):
data["frenet_binormal"] = cross(
data["frenet_tangent"], data["frenet_normal"]
) * jnp.linalg.det(params["rotmat"])
) * jnp.linalg.det(params["rotmat"].reshape((3, 3)))
return data


Expand Down
21 changes: 5 additions & 16 deletions desc/compute/geom_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@
-------
rot : ndarray, shape(3,3)
Matrix to rotate points in cartesian (X,Y,Z) coordinates

"""
# FIXME: does not work with AD
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could probably add a cond here that checks if axis is all zero (or otherwise small) and just returns the identity matrix in that case? I think that might solve the AD issue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I got that to work, but I had to get rid of the optional angle argument and make it always use the norm of axis as the angle. The issue is that the derivative of jax.linalg.norm(x) is NaN when x = 0, in both forward and reverse mode. So we can't call norm(axis) outside of the cond branch where it is guaranteed to be nonzero.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is what you want:

DESC/desc/compute/utils.py

Lines 492 to 493 in c837cc5

def safenorm(x, ord=None, axis=None, fill=0, threshold=0):
"""Like jnp.linalg.norm, but without nan gradient at x=0.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can use safediv and keep the angle argument, I prefer that as it is much clearer than having the angle be the norm of the axis

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I have this working now with the angle argument. I also created a new util function safenormalize to normalize vectors to unit length

axis = jnp.asarray(axis)
norm = jnp.linalg.norm(axis)

Check warning on line 44 in desc/compute/geom_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/geom_utils.py#L44

Added line #L44 was not covered by tests
if angle is None:
angle = jnp.linalg.norm(axis)
axis = axis / jnp.linalg.norm(axis)
angle = norm
axis = jnp.nan_to_num(axis / norm)

Check warning on line 47 in desc/compute/geom_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/geom_utils.py#L46-L47

Added lines #L46 - L47 were not covered by tests
R1 = jnp.cos(angle) * jnp.eye(3)
R2 = jnp.sin(angle) * jnp.cross(axis, jnp.identity(axis.shape[0]) * -1)
R3 = (1 - jnp.cos(angle)) * jnp.outer(axis, axis)
Expand Down Expand Up @@ -152,17 +155,3 @@
return cart

return inner(vec, phi)


def _rotation_matrix_from_normal(normal):
nx, ny, nz = normal
nxny = jnp.sqrt(nx**2 + ny**2)
R = jnp.array(
[
[ny / nxny, -nx / nxny, 0],
[nx * nx / nxny, ny * nz / nxny, -nxny],
[nx, ny, nz],
]
).T
R = jnp.where(nxny == 0, jnp.eye(3), R)
return R
Binary file modified desc/examples/ARIES-CS_output.h5
Binary file not shown.
Binary file modified desc/examples/ATF_output.h5
Binary file not shown.
Binary file modified desc/examples/DSHAPE_CURRENT_output.h5
Binary file not shown.
Binary file modified desc/examples/DSHAPE_output.h5
Binary file not shown.
Binary file modified desc/examples/ESTELL_output.h5
Binary file not shown.
Binary file modified desc/examples/HELIOTRON_output.h5
Binary file not shown.
Binary file modified desc/examples/HSX_output.h5
Binary file not shown.
Binary file modified desc/examples/NCSX_output.h5
Binary file not shown.
Binary file modified desc/examples/SOLOVEV_output.h5
Binary file not shown.
Binary file modified desc/examples/W7-X_output.h5
Binary file not shown.
Binary file modified desc/examples/WISTELL-A_output.h5
Binary file not shown.
Binary file modified desc/examples/precise_QA_output.h5
Binary file not shown.
Binary file modified desc/examples/precise_QH_output.h5
Binary file not shown.
39 changes: 32 additions & 7 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np

from desc.backend import jnp
from desc.compute import compute as compute_fun
from desc.compute import data_index
from desc.compute.geom_utils import reflection_matrix, rotation_matrix
Expand All @@ -17,18 +16,44 @@
)
from desc.grid import LinearGrid, QuadratureGrid, _Grid
from desc.io import IOAble
from desc.optimizable import Optimizable
from desc.optimizable import Optimizable, optimizable_parameter


class Curve(IOAble, Optimizable, ABC):
"""Abstract base class for 1D curves in 3D space."""

_io_attrs_ = ["_name", "shift", "rotmat"]
_io_attrs_ = ["_name", "_shift", "_rotmat"]

def __init__(self, name=""):
self.shift = jnp.array([0, 0, 0]).astype(float)
self.rotmat = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).astype(float)
self.name = name
self._shift = np.array([0, 0, 0])
self._rotmat = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).flatten()
self._name = name

@optimizable_parameter
f0uriest marked this conversation as resolved.
Show resolved Hide resolved
@property
def shift(self):
"""Displacement of curve in X, Y, Z."""
return self._shift

Check warning on line 36 in desc/geometry/core.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/core.py#L36

Added line #L36 was not covered by tests

@shift.setter
def shift(self, new):
if len(new) == 3:
self._shift = np.asarray(new)

Check warning on line 41 in desc/geometry/core.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/core.py#L40-L41

Added lines #L40 - L41 were not covered by tests
else:
raise ValueError("shift should be a 3 element vector, got {}".format(new))

Check warning on line 43 in desc/geometry/core.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/core.py#L43

Added line #L43 was not covered by tests

@optimizable_parameter
@property
def rotmat(self):
"""Rotation matrix of curve in X, Y, Z."""
return self._rotmat

Check warning on line 49 in desc/geometry/core.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/core.py#L49

Added line #L49 was not covered by tests

@rotmat.setter
def rotmat(self, new):
if len(new) == 9:
self._rotmat = np.asarray(new)

Check warning on line 54 in desc/geometry/core.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/core.py#L53-L54

Added lines #L53 - L54 were not covered by tests
else:
self._rotmat = np.asarray(new.flatten())

Check warning on line 56 in desc/geometry/core.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/core.py#L56

Added line #L56 was not covered by tests

@property
def name(self):
Expand Down Expand Up @@ -144,7 +169,7 @@

def translate(self, displacement=[0, 0, 0]):
"""Translate the curve by a rigid displacement in x, y, z."""
self.shift += jnp.asarray(displacement)
self.shift += np.asarray(displacement)

Check warning on line 172 in desc/geometry/core.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/core.py#L172

Added line #L172 was not covered by tests

def rotate(self, axis=[0, 0, 1], angle=0):
f0uriest marked this conversation as resolved.
Show resolved Hide resolved
"""Rotate the curve by a fixed angle about axis in xyz coordinates."""
Expand Down
Binary file modified tests/inputs/DSHAPE_output_saved_without_current.h5
Binary file not shown.
Binary file modified tests/inputs/LandremanPaul2022_QA_reactorScale_lowRes.h5
Binary file not shown.
Binary file modified tests/inputs/LandremanPaul2022_QH_reactorScale_lowRes.h5
Binary file not shown.
Binary file modified tests/inputs/circular_model_tokamak_output.h5
Binary file not shown.
Binary file modified tests/inputs/iotest_HELIOTRON.h5
Binary file not shown.
Binary file modified tests/inputs/master_compute_data.pkl
Binary file not shown.
8 changes: 4 additions & 4 deletions tests/test_coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ def test_save_and_load_makegrid_coils_rotated(tmpdir_factory):
Z2 = coords2[:, 2]

np.testing.assert_allclose(c1.current, c2.current, err_msg=f"Coil {i}")
np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}")
np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}")
np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}", atol=1e-16)
np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}", atol=1e-16)
np.testing.assert_allclose(Z1, Z2, atol=2e-7, err_msg=f"Coil {i}")

# check Bnormal on torus and ensure is near zero
Expand Down Expand Up @@ -564,8 +564,8 @@ def test_save_and_load_makegrid_coils_rotated_int_grid(tmpdir_factory):
Z2 = coords2[:, 2]

np.testing.assert_allclose(c1.current, c2.current, err_msg=f"Coil {i}")
np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}")
np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}")
np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}", atol=1e-16)
np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}", atol=1e-16)
np.testing.assert_allclose(Z1, Z2, atol=2e-7, err_msg=f"Coil {i}")

# check Bnormal on torus and ensure is near zero
Expand Down
Loading
Loading