Fix nan in reverse mode gradient caused by rotation_matrix #1457
base: master
nan in BoundaryError when deriv_mode="rev"
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +1.54 +/- 4.53 | +8.45e-03 +/- 2.48e-02 | 5.57e-01 +/- 2.0e-02 | 5.48e-01 +/- 1.5e-02 |
| test_equilibrium_init_medres | +2.41 +/- 2.31 | +1.04e-01 +/- 9.99e-02 | 4.43e+00 +/- 7.1e-02 | 4.33e+00 +/- 7.0e-02 |
| test_equilibrium_init_highres | +2.89 +/- 2.06 | +1.62e-01 +/- 1.15e-01 | 5.75e+00 +/- 8.6e-02 | 5.59e+00 +/- 7.7e-02 |
| test_objective_compile_dshape_current | +0.26 +/- 1.39 | +1.04e-02 +/- 5.60e-02 | 4.05e+00 +/- 4.3e-02 | 4.04e+00 +/- 3.6e-02 |
| test_objective_compute_dshape_current | +3.08 +/- 8.79 | +1.64e-04 +/- 4.69e-04 | 5.49e-03 +/- 2.6e-04 | 5.33e-03 +/- 3.9e-04 |
| test_objective_jac_dshape_current | +1.69 +/- 7.98 | +7.32e-04 +/- 3.45e-03 | 4.39e-02 +/- 1.7e-03 | 4.32e-02 +/- 3.0e-03 |
| test_perturb_2 | +0.35 +/- 3.32 | +7.25e-02 +/- 6.84e-01 | 2.07e+01 +/- 6.6e-01 | 2.06e+01 +/- 1.9e-01 |
| test_proximal_freeb_jac | -0.25 +/- 1.85 | -1.87e-02 +/- 1.39e-01 | 7.47e+00 +/- 7.1e-02 | 7.49e+00 +/- 1.2e-01 |
| test_solve_fixed_iter | -0.22 +/- 3.92 | -6.49e-02 +/- 1.17e+00 | 2.97e+01 +/- 8.8e-01 | 2.98e+01 +/- 7.7e-01 |
| test_LinearConstraintProjection_build | +0.06 +/- 1.18 | +1.55e-02 +/- 2.85e-01 | 2.42e+01 +/- 2.1e-01 | 2.42e+01 +/- 1.9e-01 |
| test_build_transform_fft_midres | -0.10 +/- 5.10 | -6.07e-04 +/- 3.22e-02 | 6.30e-01 +/- 2.2e-02 | 6.31e-01 +/- 2.4e-02 |
| test_build_transform_fft_highres | -0.16 +/- 2.79 | -1.57e-03 +/- 2.78e-02 | 9.93e-01 +/- 1.4e-02 | 9.95e-01 +/- 2.4e-02 |
| test_equilibrium_init_lowres | -1.02 +/- 2.06 | -4.17e-02 +/- 8.41e-02 | 4.03e+00 +/- 6.3e-02 | 4.07e+00 +/- 5.6e-02 |
| test_objective_compile_atf | -0.96 +/- 3.18 | -7.90e-02 +/- 2.63e-01 | 8.18e+00 +/- 2.4e-01 | 8.26e+00 +/- 1.1e-01 |
| test_objective_compute_atf | +4.22 +/- 4.47 | +6.86e-04 +/- 7.26e-04 | 1.69e-02 +/- 4.0e-04 | 1.63e-02 +/- 6.1e-04 |
| test_objective_jac_atf | +0.28 +/- 3.58 | +5.49e-03 +/- 7.03e-02 | 1.97e+00 +/- 5.2e-02 | 1.96e+00 +/- 4.7e-02 |
| test_perturb_1 | -1.55 +/- 1.99 | -2.37e-01 +/- 3.05e-01 | 1.51e+01 +/- 1.8e-01 | 1.53e+01 +/- 2.5e-01 |
| test_proximal_jac_atf | +0.13 +/- 1.07 | +1.09e-02 +/- 8.84e-02 | 8.30e+00 +/- 6.1e-02 | 8.29e+00 +/- 6.4e-02 |
| test_proximal_freeb_compute | +0.37 +/- 1.69 | +7.33e-04 +/- 3.40e-03 | 2.01e-01 +/- 1.7e-03 | 2.01e-01 +/- 3.0e-03 |
| test_solve_fixed_iter_compiled | +0.76 +/- 2.12 | +1.31e-01 +/- 3.64e-01 | 1.73e+01 +/- 2.4e-01 | 1.71e+01 +/- 2.7e-01 |
@@ -28,6 +28,10 @@ def reflection_matrix(normal):
def rotation_matrix(axis, angle=None):
    """Matrix to rotate points about axis by given angle.

    NOTE: not correct if a and b are antiparallel, will
I don't think this comment applies here.
I am certain it does: I checked, and the rotation matrix is wrong for this case with this function too.
def rotation_matrix_vector_vector(a, b):
    """Matrix to rotate vector a onto b.

    NOTE: not correct if a and b are antiparallel, will
could add something like jnp.where(a+b==0, -R, R)
@@ -3076,7 +3076,7 @@ def test_objective_no_nangrad_quadratic_flux(self):
     @pytest.mark.unit
     def test_objective_no_nangrad_quadratic_flux_minimizing(self):
         """SurfaceQuadraticFlux."""
-        ext_field = ToroidalMagneticField(1.0, 1.0)
+        ext_field = FourierPlanarCoil(normal=[0, 0, 1])
does this intersect the plasma?
That existing test passes because the issue isn't with the rotation matrix itself, but in the
    R1 = jnp.eye(3)
    R2 = skew
    R3 = (skew @ skew) * safediv(1, 1 + dot(a, b))
    return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3)  # if axis=0, no rotation
Not sure this is correct, I seem to get weird answers in a few cases I've tried.
The determinant of a rotation matrix should always be 1, but for example
np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([1,1,0])))
gives 1.17,
and if they are almost antiparallel,
np.linalg.det(rotation_matrix_vector_vector(np.array([1,0,0]), np.array([-1,1e-4,0])))
gives about 4e16.
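A standalone check along these lines may help narrow down where the determinant goes wrong. The sketch below uses plain NumPy and a hypothetical helper `rotate_a_onto_b` (not the function in this PR). It assumes the same formula, R = I + K + K @ K / (1 + a.b) with K the cross-product matrix of a x b, which only holds for unit vectors that are not antiparallel, so it normalizes both inputs first. Whether missing normalization explains the numbers above is an assumption; the thread does not confirm the cause.

```python
import numpy as np


def skew(v):
    """Cross-product matrix K such that K @ x == np.cross(v, x)."""
    return np.array([
        [0.0, -v[2], v[1]],
        [v[2], 0.0, -v[0]],
        [-v[1], v[0], 0.0],
    ])


def rotate_a_onto_b(a, b):
    """Hypothetical helper: rotation taking unit(a) to unit(b).

    Uses R = I + K + K @ K / (1 + a.b) with K = skew(a x b).
    Only valid for unit vectors that are not antiparallel.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a / np.linalg.norm(a)  # the formula assumes |a| = |b| = 1
    b = b / np.linalg.norm(b)
    K = skew(np.cross(a, b))
    return np.eye(3) + K + K @ K / (1.0 + a @ b)


def is_rotation(R, tol=1e-8):
    """True if R is orthogonal with determinant +1."""
    orthogonal = np.allclose(R @ R.T, np.eye(3), atol=tol)
    return orthogonal and abs(np.linalg.det(R) - 1.0) < tol


# First case from the comment above, with both inputs normalized.
R = rotate_a_onto_b([1, 0, 0], [1, 1, 0])
print(is_rotation(R))
```

The division by 1 + a.b still loses precision as the vectors approach antiparallel, so a check like `is_rotation` is best run on a sweep of angles rather than a single pair.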
@@ -345,9 +345,7 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
    coords = jnp.array([d3X, d3Y, d3Z]).T
    # rotate into place
    Zaxis = jnp.array([0.0, 0.0, 1.0])  # 2D curve in X-Y plane has normal = +Z axis
    axis = cross(Zaxis, normal)
another way to rotate a into b would be to rotate by pi around the axis (a+b)/2, which might be easier.
a = Zaxis
b = safenormalize(normal)
axis = safenormalize((a+b)/2)
A = rotation_matrix(axis, angle=np.pi)
That avoids the arccos and I think should be fine since it's all safenormalized?
Would still need to special case if they are exactly antiparallel, but it seems to work ok if they are nearly antiparallel.
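The half-turn idea can be tried in isolation. The sketch below is plain NumPy with a hypothetical helper `half_turn_a_onto_b`; it does not call the `rotation_matrix` or `safenormalize` functions from this repository. It relies on the fact that a rotation by pi about a unit axis n reduces to R = 2 n n^T - I, so no arccos or sine is evaluated.

```python
import numpy as np


def half_turn_a_onto_b(a, b):
    """Hypothetical helper: rotate unit(a) onto unit(b) by pi about their bisector.

    A rotation by pi about a unit axis n is R = 2 n n^T - I.
    Undefined when a and b are exactly antiparallel (a + b = 0).
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    n = a + b
    n = n / np.linalg.norm(n)  # bisector direction; divides by zero if a + b == 0
    return 2.0 * np.outer(n, n) - np.eye(3)


a = np.array([0.0, 0.0, 1.0])        # Zaxis, as in the snippet above
b = np.array([1.0, 2.0, 2.0]) / 3.0  # an arbitrary unit normal, for illustration only
R = half_turn_a_onto_b(a, b)
print(np.allclose(R @ a, b), np.isclose(np.linalg.det(R), 1.0))
```

One design point to keep in mind: this matrix maps a onto b but is not the minimal-angle rotation between them, so it acts differently on vectors other than a. For a planar curve, that would change the in-plane orientation of the rotated curve relative to the axis-angle construction.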
It seems like any method for computing the rotation matrix will have edge cases that require special logic to handle. Therefore I vote to keep the existing method, and just add the logic to make the reverse mode work when angle = arccos(1) = 0
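To illustrate why angle = arccos(1) is a problem for gradients, the sketch below evaluates the analytic derivative of arccos by hand in NumPy rather than through autodiff. The derivative is infinite at 1, and multiplying it by a zero factor in the chain rule gives nan. That this is the exact mechanism behind #1456 is an assumption, and the function names are hypothetical. The guarded version substitutes a harmless input before differentiating and then masks the result, similar in spirit to the "double where" pattern commonly used with JAX.

```python
import numpy as np


def arccos_grad(x):
    """Analytic derivative of arccos: -1 / sqrt(1 - x**2)."""
    x = np.float64(x)
    with np.errstate(divide="ignore", invalid="ignore"):
        return -1.0 / np.sqrt(1.0 - x * x)


def chained_grad(x, other_factor):
    """Chain rule product: d(arccos)/dx times another derivative factor."""
    with np.errstate(invalid="ignore"):
        return arccos_grad(x) * np.float64(other_factor)


def safe_chained_grad(x, other_factor, eps=1e-12):
    """Same product, but avoid evaluating the derivative near |x| = 1."""
    x = np.float64(x)
    at_edge = 1.0 - x * x < eps
    x_safe = np.where(at_edge, 0.0, x)   # any interior point works here
    g = arccos_grad(x_safe) * np.float64(other_factor)
    return np.where(at_edge, 0.0, g)     # choice made here: edge gradient is 0


print(arccos_grad(1.0))             # infinite slope at x = 1
print(chained_grad(1.0, 0.0))       # inf * 0 is nan
print(safe_chained_grad(1.0, 0.0))  # finite by construction
```

Setting the edge gradient to zero is one possible convention, not necessarily the one this PR adopts.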
Resolves #1456 by adding a different method of computing rotation_matrix (from here)
rotation_matrix still?
FourierPlanarCurve normal is parallel or anti-parallel to Zaxis
#1456 ?