-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix nan
in reverse mode gradient caused by rotation_matrix
#1457
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
|
||
from desc.backend import jnp | ||
|
||
from ..utils import safenorm, safenormalize | ||
from ..utils import cross, dot, safediv, safenorm, safenormalize | ||
|
||
|
||
def reflection_matrix(normal): | ||
|
@@ -28,6 +28,10 @@ | |
def rotation_matrix(axis, angle=None): | ||
"""Matrix to rotate points about axis by given angle. | ||
|
||
NOTE: not correct if a and b are antiparallel, will | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dont think this comment applies here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am certain it does, I checked it and the rotation matrix is wrong for this case with this function too |
||
simply return identity when in reality negative identity | ||
is correct. | ||
|
||
Parameters | ||
---------- | ||
axis : array-like, shape(3,) | ||
|
@@ -53,6 +57,45 @@ | |
return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation | ||
|
||
|
||
def _skew_matrix(a): | ||
return jnp.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]]) | ||
|
||
|
||
def rotation_matrix_vector_vector(a, b): | ||
"""Matrix to rotate vector a onto b. | ||
|
||
NOTE: not correct if a and b are antiparallel, will | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could add something like |
||
simply return identity when in reality negative identity | ||
is correct. | ||
|
||
Parameters | ||
---------- | ||
a,b : array-like, shape(3,) | ||
Vectors, in cartesian (X,Y,Z) coordinates | ||
Matrix will correspond to rotating a onto b | ||
|
||
|
||
Returns | ||
------- | ||
rotmat : ndarray, shape(3,3) | ||
Matrix to rotate points in cartesian (X,Y,Z) coordinates. | ||
|
||
""" | ||
a = jnp.asarray(a) | ||
b = jnp.asarray(b) | ||
a = safenormalize(a) | ||
b = safenormalize(b) | ||
axis = cross(a, b) | ||
norm = safenorm(axis) | ||
axis = safenormalize(axis) | ||
eps = 1e2 * jnp.finfo(axis.dtype).eps | ||
skew = _skew_matrix(axis) | ||
R1 = jnp.eye(3) | ||
R2 = skew | ||
R3 = (skew @ skew) * safediv(1, 1 + dot(a, b)) | ||
return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure this is correct, I seem to get weird answers in a few cases I've tried. np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([1,1,0]))) gives 1.17 and if they are almost antiparallel: np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([-1,1e-4,0]))) gives like 4e16 |
||
|
||
|
||
def xyz2rpz(pts): | ||
"""Transform points from cartesian (X,Y,Z) to polar (R,phi,Z) form. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3076,7 +3076,7 @@ def test_objective_no_nangrad_quadratic_flux(self): | |
@pytest.mark.unit | ||
def test_objective_no_nangrad_quadratic_flux_minimizing(self): | ||
"""SurfaceQuadraticFlux.""" | ||
ext_field = ToroidalMagneticField(1.0, 1.0) | ||
ext_field = FourierPlanarCoil(normal=[0, 0, 1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this intersect the plasma? |
||
|
||
surf = FourierRZToroidalSurface( | ||
R_lmn=[4.0, 1.0], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
another way to rotate
a
intob
would be to rotate bypi
around the axis(a+b)/2
which might be easier.That avoids the arccos and I think should be fine since its all safenormalized?
Would still need to special case if they are exactly antiparallel but seems to work ok if they are ~nearly antiparallel