Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix curvy bugs #840

Merged
merged 23 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8e59800
fix bug in PlanarCurve rotation matrix
daniel-dudt Jan 26, 2024
6a37dc4
replace if statement with nan_to_num call
daniel-dudt Jan 26, 2024
3887d6d
update tests
daniel-dudt Jan 27, 2024
3cfad71
make shift & rotmat optimizeable parameters of Curve
daniel-dudt Jan 29, 2024
e21202f
noting that fix does not work with AD
daniel-dudt Jan 29, 2024
7c8a040
replace unnecessary jnp with np in Curve class
daniel-dudt Jan 29, 2024
d25b289
update examples with private axis curve attributes
daniel-dudt Jan 29, 2024
983aeda
rotation_matrix work with AD at axis=[0, 0, 0]
daniel-dudt Jan 30, 2024
03a0002
ensure shift and rotmat are float arrays
daniel-dudt Jan 30, 2024
a88c45a
resolve merge conflicts with master
daniel-dudt Feb 1, 2024
93ce772
repairing tests
daniel-dudt Feb 1, 2024
dd432e4
repair plot_coils test
daniel-dudt Feb 1, 2024
e236ff8
Merge branch 'master' into dd/curve
ddudt Feb 2, 2024
0b08b9d
Merge branch 'master' into dd/curve
dpanici Feb 4, 2024
1a5c1f2
safenormalize util function to make rotation_matrix work with angle i…
daniel-dudt Feb 6, 2024
8edb49d
revert changes to use angle input again
daniel-dudt Feb 6, 2024
eb393ff
merge with master
daniel-dudt Feb 6, 2024
d7d0e07
increase test tolerance
daniel-dudt Feb 6, 2024
66b92a9
forgot to revert one change in a test
daniel-dudt Feb 6, 2024
7a6cd48
add FixCurve constraints to maybe_add_self_consistency
daniel-dudt Feb 7, 2024
8f1d3fd
change eq reference to thing
daniel-dudt Feb 7, 2024
ac0fc14
better checking for class types based on attrs
daniel-dudt Feb 7, 2024
18ac783
revert change to self-consistent constraint order
daniel-dudt Feb 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,21 +883,22 @@ def linspaced_angular(
axis : array-like, shape(3,)
axis to rotate about
angle : float
total rotational extend of coil set.
total rotational extent of coil set
n : int
number of copies of original coil
endpoint : bool
whether to include a coil at final angle

"""
assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet)
if current is None:
current = coil.current
currents = jnp.broadcast_to(current, (n,))
phi = jnp.linspace(0, angle, n, endpoint=endpoint)
coils = []
phis = jnp.linspace(0, angle, n, endpoint=endpoint)
for i in range(n):
coili = coil.copy()
coili.rotate(axis, angle=phis[i])
coili.rotate(axis=axis, angle=phi[i])
coili.current = currents[i]
coils.append(coili)
return cls(*coils)
Expand All @@ -920,14 +921,15 @@ def linspaced_linear(
number of copies of original coil
endpoint : bool
whether to include a coil at final point

"""
assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet)
if current is None:
current = coil.current
currents = jnp.broadcast_to(current, (n,))
displacement = jnp.asarray(displacement)
coils = []
a = jnp.linspace(0, 1, n, endpoint=endpoint)
coils = []
for i in range(n):
coili = coil.copy()
coili.translate(a[i] * displacement)
Expand All @@ -953,6 +955,7 @@ def from_symmetry(cls, coils, NFP, sym=False):
number of field periods
sym : bool
whether coils should be stellarator symmetric

"""
if not isinstance(coils, CoilSet):
coils = CoilSet(coils)
Expand Down
104 changes: 67 additions & 37 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,8 @@
from desc.backend import jnp

from .data_index import register_compute_fun
from .geom_utils import (
_rotation_matrix_from_normal,
rpz2xyz,
rpz2xyz_vec,
xyz2rpz,
xyz2rpz_vec,
)
from .utils import cross, dot
from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .utils import cross, dot, safenormalize


@register_compute_fun(
Expand Down Expand Up @@ -184,16 +178,19 @@ def _Z_Curve(params, transforms, profiles, data, **kwargs):
basis="basis",
)
def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
# create planar curve at z==0
# create planar curve at Z==0
r = transforms["r"].transform(params["r_n"], dz=0)
Z = jnp.zeros_like(r)
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
coords = jnp.array([X, Y, Z]).T
# rotate into place
R = _rotation_matrix_from_normal(params["normal"])
coords = jnp.matmul(coords, R.T) + params["center"]
coords = jnp.matmul(coords, params["rotmat"].T) + params["shift"]
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T) + params["center"]
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -224,16 +221,22 @@ def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
dY = dr * jnp.sin(data["s"]) + r * jnp.cos(data["s"])
dZ = jnp.zeros_like(dX)
coords = jnp.array([dX, dY, dZ]).T
A = _rotation_matrix_from_normal(params["normal"])
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
if kwargs.get("basis", "rpz").lower() == "rpz":
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"]
xyzcoords = (
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
x, y, z = xyzcoords.T
coords = xyz2rpz_vec(coords, x=x, y=y)
data["x_s"] = coords
Expand Down Expand Up @@ -269,16 +272,22 @@ def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
)
d2Z = jnp.zeros_like(d2X)
coords = jnp.array([d2X, d2Y, d2Z]).T
A = _rotation_matrix_from_normal(params["normal"])
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
if kwargs.get("basis", "rpz").lower() == "rpz":
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"]
xyzcoords = (
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
x, y, z = xyzcoords.T
coords = xyz2rpz_vec(coords, x=x, y=y)
data["x_ss"] = coords
Expand Down Expand Up @@ -321,16 +330,22 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
)
d3Z = jnp.zeros_like(d3X)
coords = jnp.array([d3X, d3Y, d3Z]).T
A = _rotation_matrix_from_normal(params["normal"])
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
if kwargs.get("basis", "rpz").lower() == "rpz":
X = r * jnp.cos(data["s"])
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"]
xyzcoords = (
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
x, y, z = xyzcoords.T
coords = xyz2rpz_vec(coords, x=x, y=y)
data["x_sss"] = coords
Expand Down Expand Up @@ -363,7 +378,9 @@ def _x_FourierRZCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz(coords)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (
coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :]
)
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -397,7 +414,7 @@ def _x_s_FourierRZCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.stack([dR, dphi, dZ], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
data["x_s"] = coords
Expand Down Expand Up @@ -435,7 +452,7 @@ def _x_ss_FourierRZCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
data["x_ss"] = coords
Expand Down Expand Up @@ -473,7 +490,7 @@ def _x_sss_FourierRZCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
data["x_sss"] = coords
Expand Down Expand Up @@ -504,7 +521,9 @@ def _x_FourierXYZCurve(params, transforms, profiles, data, **kwargs):
Y = transforms["Y"].transform(params["Y_n"], dz=0)
Z = transforms["Z"].transform(params["Z_n"], dz=0)
coords = jnp.stack([X, Y, Z], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (
coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :]
)
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -535,7 +554,7 @@ def _x_s_FourierXYZCurve(params, transforms, profiles, data, **kwargs):
dY = transforms["Y"].transform(params["Y_n"], dz=1)
dZ = transforms["Z"].transform(params["Z_n"], dz=1)
coords = jnp.stack([dX, dY, dZ], axis=1)
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(
coords,
Expand Down Expand Up @@ -570,7 +589,7 @@ def _x_ss_FourierXYZCurve(params, transforms, profiles, data, **kwargs):
d2Y = transforms["Y"].transform(params["Y_n"], dz=2)
d2Z = transforms["Z"].transform(params["Z_n"], dz=2)
coords = jnp.stack([d2X, d2Y, d2Z], axis=1)
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(
coords,
Expand Down Expand Up @@ -605,7 +624,7 @@ def _x_sss_FourierXYZCurve(params, transforms, profiles, data, **kwargs):
d3Y = transforms["Y"].transform(params["Y_n"], dz=3)
d3Z = transforms["Z"].transform(params["Z_n"], dz=3)
coords = jnp.stack([d3X, d3Y, d3Z], axis=1)
coords = coords @ params["rotmat"].T
coords = coords @ params["rotmat"].reshape((3, 3)).T
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz_vec(
coords,
Expand Down Expand Up @@ -662,7 +681,9 @@ def _x_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
)

coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (
coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :]
)
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
data["x"] = coords
Expand Down Expand Up @@ -715,7 +736,7 @@ def _x_s_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
)

coords_s = jnp.stack([dXq, dYq, dZq], axis=1)
coords_s = coords_s @ params["rotmat"].T
coords_s = coords_s @ params["rotmat"].reshape((3, 3)).T

if kwargs.get("basis", "rpz").lower() == "rpz":
# calculate the xy coordinates to rotate to rpz
Expand Down Expand Up @@ -745,7 +766,10 @@ def _x_s_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
)

coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (
coords @ params["rotmat"].reshape((3, 3)).T
+ params["shift"][jnp.newaxis, :]
)

coords_s = xyz2rpz_vec(coords_s, x=coords[:, 0], y=coords[:, 1])
data["x_s"] = coords_s
Expand Down Expand Up @@ -798,7 +822,7 @@ def _x_ss_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
)

coords_ss = jnp.stack([d2Xq, d2Yq, d2Zq], axis=1)
coords_ss = coords_ss @ params["rotmat"].T
coords_ss = coords_ss @ params["rotmat"].reshape((3, 3)).T

if kwargs.get("basis", "rpz").lower() == "rpz":
# calculate the xy coordinates to rotate to rpz
Expand Down Expand Up @@ -827,7 +851,10 @@ def _x_ss_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
period=2 * jnp.pi,
)
coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (
coords @ params["rotmat"].reshape((3, 3)).T
+ params["shift"][jnp.newaxis, :]
)

coords_ss = xyz2rpz_vec(coords_ss, x=coords[:, 0], y=coords[:, 1])
data["x_ss"] = coords_ss
Expand Down Expand Up @@ -880,7 +907,7 @@ def _x_sss_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
)

coords_sss = jnp.stack([d3Xq, d3Yq, d3Zq], axis=1)
coords_sss = coords_sss @ params["rotmat"].T
coords_sss = coords_sss @ params["rotmat"].reshape((3, 3)).T

if kwargs.get("basis", "rpz").lower() == "rpz":
# calculate the xy coordinates to rotate to rpz
Expand Down Expand Up @@ -909,7 +936,10 @@ def _x_sss_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
period=2 * jnp.pi,
)
coords = jnp.stack([Xq, Yq, Zq], axis=1)
coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :]
coords = (
coords @ params["rotmat"].reshape((3, 3)).T
+ params["shift"][jnp.newaxis, :]
)

coords_sss = xyz2rpz_vec(coords_sss, x=coords[:, 0], y=coords[:, 1])
data["x_sss"] = coords_sss
Expand Down Expand Up @@ -976,7 +1006,7 @@ def _frenet_normal(params, transforms, profiles, data, **kwargs):
def _frenet_binormal(params, transforms, profiles, data, **kwargs):
data["frenet_binormal"] = cross(
data["frenet_tangent"], data["frenet_normal"]
) * jnp.linalg.det(params["rotmat"])
) * jnp.linalg.det(params["rotmat"].reshape((3, 3)))
return data


Expand Down
29 changes: 10 additions & 19 deletions desc/compute/geom_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from desc.backend import jnp

from .utils import safenorm, safenormalize


def reflection_matrix(normal):
"""Matrix to reflect points across plane through origin with specified normal.
Expand Down Expand Up @@ -35,17 +37,20 @@ def rotation_matrix(axis, angle=None):

Returns
-------
rot : ndarray, shape(3,3)
Matrix to rotate points in cartesian (X,Y,Z) coordinates
rotmat : ndarray, shape(3,3)
Matrix to rotate points in cartesian (X,Y,Z) coordinates.

"""
axis = jnp.asarray(axis)
norm = safenorm(axis)
axis = safenormalize(axis)
if angle is None:
angle = jnp.linalg.norm(axis)
axis = axis / jnp.linalg.norm(axis)
angle = norm
eps = 1e2 * jnp.finfo(axis.dtype).eps
R1 = jnp.cos(angle) * jnp.eye(3)
R2 = jnp.sin(angle) * jnp.cross(axis, jnp.identity(axis.shape[0]) * -1)
R3 = (1 - jnp.cos(angle)) * jnp.outer(axis, axis)
return R1 + R2 + R3
return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation


def xyz2rpz(pts):
Expand Down Expand Up @@ -152,17 +157,3 @@ def inner(vec, phi):
return cart

return inner(vec, phi)


def _rotation_matrix_from_normal(normal):
nx, ny, nz = normal
nxny = jnp.sqrt(nx**2 + ny**2)
R = jnp.array(
[
[ny / nxny, -nx / nxny, 0],
[nx * nx / nxny, ny * nz / nxny, -nxny],
[nx, ny, nz],
]
).T
R = jnp.where(nxny == 0, jnp.eye(3), R)
return R
Loading
Loading