Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nan in reverse mode gradient caused by rotation_matrix #1457

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

dpanici
Copy link
Collaborator

@dpanici dpanici commented Dec 11, 2024

Resolves #1456 by adding different method of computing rotation_matrix (from here)

@dpanici dpanici changed the title Fix nan in BoundaryError when deriv_mode="rev" Fix nan in reverse mode gradient caused by rotation_matrix Dec 11, 2024
@dpanici dpanici requested review from ddudt and f0uriest December 11, 2024 21:35
@dpanici dpanici added the skip_changelog No need to update changelog on this PR label Dec 11, 2024
Copy link
Contributor

github-actions bot commented Dec 11, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +1.54 +/- 4.53     | +8.45e-03 +/- 2.48e-02 |  5.57e-01 +/- 2.0e-02  |  5.48e-01 +/- 1.5e-02  |
 test_equilibrium_init_medres            |     +2.41 +/- 2.31     | +1.04e-01 +/- 9.99e-02 |  4.43e+00 +/- 7.1e-02  |  4.33e+00 +/- 7.0e-02  |
 test_equilibrium_init_highres           |     +2.89 +/- 2.06     | +1.62e-01 +/- 1.15e-01 |  5.75e+00 +/- 8.6e-02  |  5.59e+00 +/- 7.7e-02  |
 test_objective_compile_dshape_current   |     +0.26 +/- 1.39     | +1.04e-02 +/- 5.60e-02 |  4.05e+00 +/- 4.3e-02  |  4.04e+00 +/- 3.6e-02  |
 test_objective_compute_dshape_current   |     +3.08 +/- 8.79     | +1.64e-04 +/- 4.69e-04 |  5.49e-03 +/- 2.6e-04  |  5.33e-03 +/- 3.9e-04  |
 test_objective_jac_dshape_current       |     +1.69 +/- 7.98     | +7.32e-04 +/- 3.45e-03 |  4.39e-02 +/- 1.7e-03  |  4.32e-02 +/- 3.0e-03  |
 test_perturb_2                          |     +0.35 +/- 3.32     | +7.25e-02 +/- 6.84e-01 |  2.07e+01 +/- 6.6e-01  |  2.06e+01 +/- 1.9e-01  |
 test_proximal_freeb_jac                 |     -0.25 +/- 1.85     | -1.87e-02 +/- 1.39e-01 |  7.47e+00 +/- 7.1e-02  |  7.49e+00 +/- 1.2e-01  |
 test_solve_fixed_iter                   |     -0.22 +/- 3.92     | -6.49e-02 +/- 1.17e+00 |  2.97e+01 +/- 8.8e-01  |  2.98e+01 +/- 7.7e-01  |
 test_LinearConstraintProjection_build   |     +0.06 +/- 1.18     | +1.55e-02 +/- 2.85e-01 |  2.42e+01 +/- 2.1e-01  |  2.42e+01 +/- 1.9e-01  |
 test_build_transform_fft_midres         |     -0.10 +/- 5.10     | -6.07e-04 +/- 3.22e-02 |  6.30e-01 +/- 2.2e-02  |  6.31e-01 +/- 2.4e-02  |
 test_build_transform_fft_highres        |     -0.16 +/- 2.79     | -1.57e-03 +/- 2.78e-02 |  9.93e-01 +/- 1.4e-02  |  9.95e-01 +/- 2.4e-02  |
 test_equilibrium_init_lowres            |     -1.02 +/- 2.06     | -4.17e-02 +/- 8.41e-02 |  4.03e+00 +/- 6.3e-02  |  4.07e+00 +/- 5.6e-02  |
 test_objective_compile_atf              |     -0.96 +/- 3.18     | -7.90e-02 +/- 2.63e-01 |  8.18e+00 +/- 2.4e-01  |  8.26e+00 +/- 1.1e-01  |
 test_objective_compute_atf              |     +4.22 +/- 4.47     | +6.86e-04 +/- 7.26e-04 |  1.69e-02 +/- 4.0e-04  |  1.63e-02 +/- 6.1e-04  |
 test_objective_jac_atf                  |     +0.28 +/- 3.58     | +5.49e-03 +/- 7.03e-02 |  1.97e+00 +/- 5.2e-02  |  1.96e+00 +/- 4.7e-02  |
 test_perturb_1                          |     -1.55 +/- 1.99     | -2.37e-01 +/- 3.05e-01 |  1.51e+01 +/- 1.8e-01  |  1.53e+01 +/- 2.5e-01  |
 test_proximal_jac_atf                   |     +0.13 +/- 1.07     | +1.09e-02 +/- 8.84e-02 |  8.30e+00 +/- 6.1e-02  |  8.29e+00 +/- 6.4e-02  |
 test_proximal_freeb_compute             |     +0.37 +/- 1.69     | +7.33e-04 +/- 3.40e-03 |  2.01e-01 +/- 1.7e-03  |  2.01e-01 +/- 3.0e-03  |
 test_solve_fixed_iter_compiled          |     +0.76 +/- 2.12     | +1.31e-01 +/- 3.64e-01 |  1.73e+01 +/- 2.4e-01  |  1.71e+01 +/- 2.7e-01  |

@@ -28,6 +28,10 @@ def reflection_matrix(normal):
def rotation_matrix(axis, angle=None):
"""Matrix to rotate points about axis by given angle.

NOTE: not correct if a and b are antiparallel, will
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dont think this comment applies here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am certain it does, I checked it and the rotation matrix is wrong for this case with this function too

def rotation_matrix_vector_vector(a, b):
"""Matrix to rotate vector a onto b.

NOTE: not correct if a and b are antiparallel, will
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could add something like jnp.where(a+b==0, -R, R)

@@ -3076,7 +3076,7 @@ def test_objective_no_nangrad_quadratic_flux(self):
@pytest.mark.unit
def test_objective_no_nangrad_quadratic_flux_minimizing(self):
"""SurfaceQuadraticFlux."""
ext_field = ToroidalMagneticField(1.0, 1.0)
ext_field = FourierPlanarCoil(normal=[0, 0, 1])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this intersect the plasma?

@f0uriest
Copy link
Member

That existing test passes because the issue isnt with the rotation matrix itself, but in the arccos that is needed to compute the angle

R1 = jnp.eye(3)
R2 = skew
R3 = (skew @ skew) * safediv(1, 1 + dot(a, b))
return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this is correct, I seem to get weird answers in a few cases I've tried.
The determinant of a rotation matrix should always be 1, but for example

np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([1,1,0])))

gives 1.17

and if they are almost antiparallel:

np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([-1,1e-4,0])))

gives like 4e16

@@ -345,9 +345,7 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([d3X, d3Y, d3Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, normal)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another way to rotate a into b would be to rotate by pi around the axis (a+b)/2 which might be easier.

a = Zaxis
b = safenormalize(normal)
axis = safenormalize((a+b)/2)
A = rotation_matrix(axis, angle=np.pi)

That avoids the arccos and I think should be fine since its all safenormalized?

Would still need to special case if they are exactly antiparallel but seems to work ok if they are ~nearly antiparallel

Copy link
Collaborator

@ddudt ddudt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like any method for computing the rotation matrix will have edge cases that require special logic to handle. Therefore I vote to keep the existing method, and just add the logic to make the reverse mode work when angle = arccos(1) = 0

@dpanici
Copy link
Collaborator Author

dpanici commented Dec 16, 2024

  • safe arccos
  • check for sign of dot product of Zaxis and normal to assign correct sign
  • check for normal near-Z-axis works as well (when almost but not quite aligned with normal)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
skip_changelog No need to update changelog on this PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

nan reverse mode gradient when FourierPlanarCurve normal is parallel or anti-parallel to Zaxis
4 participants