fix curvy bugs #840

Merged
merged 23 commits into from
Feb 8, 2024

Collaborator

@ddudt ddudt commented Jan 26, 2024

This PR fixes two issues related to the Curve class that are necessary for optimizing modular coils:

  1. Curve.shift and Curve.rotmat need to be "optimizable_param"s because they are params that get used in the compute functions.
  2. The _rotation_matrix_from_normal function was giving incorrect results and is now replaced with calls to rotation_matrix.

Things to note when reviewing:

  • master_compute_data.pkl was changed because all of the old data for the FourierPlanarCurve class was incorrect.
  • All of the examples had to be updated to replace eq.axis.shift with eq.axis._shift, etc. so they can be loaded without I/O warnings.

Contributor

github-actions bot commented Jan 26, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres         |     +2.74 +/- 2.90     | +3.50e-04 +/- 3.70e-04 |  1.31e-02 +/- 2.6e-04  |  1.28e-02 +/- 2.7e-04  |
| test_build_transform_fft_midres         |     -0.33 +/- 3.43     | -3.24e-04 +/- 3.33e-03 |  9.67e-02 +/- 1.9e-03  |  9.71e-02 +/- 2.7e-03  |
| test_build_transform_fft_highres        |     +1.00 +/- 2.01     | +4.78e-03 +/- 9.61e-03 |  4.83e-01 +/- 5.7e-03  |  4.78e-01 +/- 7.8e-03  |
| test_equilibrium_init_lowres            |     +0.91 +/- 2.11     | +7.59e-03 +/- 1.76e-02 |  8.45e-01 +/- 1.3e-02  |  8.37e-01 +/- 1.2e-02  |
| test_equilibrium_init_medres            |     -1.25 +/- 1.86     | -1.91e-02 +/- 2.84e-02 |  1.51e+00 +/- 1.9e-02  |  1.53e+00 +/- 2.1e-02  |
| test_equilibrium_init_highres           |     -0.74 +/- 1.61     | -3.36e-02 +/- 7.29e-02 |  4.48e+00 +/- 6.3e-02  |  4.52e+00 +/- 3.7e-02  |
| test_objective_compile_dshape_current   |     -1.41 +/- 12.56    | -6.86e-02 +/- 6.13e-01 |  4.81e+00 +/- 4.7e-01  |  4.88e+00 +/- 3.9e-01  |
| test_objective_compile_atf              |     -8.35 +/- 11.09    | -1.04e+00 +/- 1.38e+00 |  1.14e+01 +/- 9.4e-01  |  1.24e+01 +/- 1.0e+00  |
| test_objective_compute_dshape_current   |     +1.10 +/- 2.52     | +2.41e-05 +/- 5.50e-05 |  2.20e-03 +/- 3.7e-05  |  2.18e-03 +/- 4.0e-05  |
| test_objective_compute_atf              |     -0.63 +/- 2.75     | -5.09e-05 +/- 2.21e-04 |  7.98e-03 +/- 2.0e-04  |  8.03e-03 +/- 9.9e-05  |
| test_objective_jac_dshape_current       |     -8.03 +/- 9.66     | -5.06e-03 +/- 6.10e-03 |  5.80e-02 +/- 4.4e-03  |  6.31e-02 +/- 4.2e-03  |
| test_objective_jac_atf                  |     -8.65 +/- 8.35     | -3.53e-01 +/- 3.40e-01 |  3.73e+00 +/- 2.6e-01  |  4.08e+00 +/- 2.2e-01  |
| test_perturb_1                          |     +1.44 +/- 21.31    | +1.33e-01 +/- 1.98e+00 |  9.41e+00 +/- 1.5e+00  |  9.28e+00 +/- 1.3e+00  |
| test_perturb_2                          |     +0.09 +/- 4.96     | +1.44e-02 +/- 7.71e-01 |  1.56e+01 +/- 7.1e-01  |  1.56e+01 +/- 2.9e-01  |

@ddudt ddudt marked this pull request as ready for review January 27, 2024 03:22
Collaborator Author

ddudt commented Jan 27, 2024

Patch coverage is 100% and I added a new test, but project coverage decreased because I deleted code.

@ddudt ddudt requested review from f0uriest and dpanici January 27, 2024 03:23
@@ -45,7 +46,7 @@ def rotation_matrix(axis, angle=None):
     R1 = jnp.cos(angle) * jnp.eye(3)
     R2 = jnp.sin(angle) * jnp.cross(axis, jnp.identity(axis.shape[0]) * -1)
     R3 = (1 - jnp.cos(angle)) * jnp.outer(axis, axis)
-    return R1 + R2 + R3
+    return R1 + jnp.nan_to_num(R2 + R3)  # if axis=[0, 0, 0], R2 & R3 are NaN
Collaborator Author

FourierPlanarCurve is a 2D curve in the X-Y plane that gets rotated into another plane in 3D space. If the curve stays in the X-Y plane (as it does with the default parameters), then this function gets called with axis=[0, 0, 0] and angle=0.
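
To make the degenerate case concrete, here is a minimal NumPy sketch of a Rodrigues-style rotation. It assumes the axis is normalized inside the function, a step that is not visible in the diff hunk; `rotation_matrix_sketch` is a hypothetical stand-in for illustration, not DESC's `rotation_matrix`.

```python
import numpy as np


def rotation_matrix_sketch(axis, angle):
    """Rodrigues rotation about `axis` by `angle` (hypothetical stand-in)."""
    axis = np.asarray(axis, dtype=float)
    # Assumption: the axis is normalized here; a zero axis gives 0/0 = NaN.
    with np.errstate(invalid="ignore", divide="ignore"):
        unit = axis / np.linalg.norm(axis)
    R1 = np.cos(angle) * np.eye(3)
    R2 = np.sin(angle) * np.cross(unit, np.identity(3) * -1)
    R3 = (1 - np.cos(angle)) * np.outer(unit, unit)
    # Zero out the NaN terms so the zero-axis case falls back to R1.
    return R1 + np.nan_to_num(R2 + R3)


R_default = rotation_matrix_sketch([0, 0, 0], 0.0)        # curve stays in the X-Y plane
R_quarter = rotation_matrix_sketch([0, 0, 1], np.pi / 2)  # quarter turn about Z
```

With the zero axis and zero angle the sketch returns the identity, so the forward value is well defined even though the intermediate terms are NaN.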

Member

Can you check that reverse-mode AD works for this? It sometimes has weirdness with overwriting NaNs.

Collaborator Author

This actually doesn't work with forward or reverse mode AD. I spent a bunch of time thinking about this and couldn't come up with a good solution, but I welcome help if you have an idea.

I still think this PR is worth merging even if we don't fix that issue because:

  1. The old code was returning incorrect values
  2. The old code also did not properly handle the case axis=[0, 0, 0]
  3. In practice this will only arise for vertical field coils, and I don't think that is a high-priority use case
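
The AD problem described in this comment can be reproduced in isolation. The sketch below is only an illustration under the assumption that the NaN originates in the derivative of a vector norm at zero; `norm_with_nan_patch` is a hypothetical function, not code from this PR.

```python
import jax
import jax.numpy as jnp


def norm_with_nan_patch(x):
    # The forward value is fine (0.0), but the NaN arises in the derivative
    # of the norm itself, so patching the output does not repair it.
    return jnp.nan_to_num(jnp.linalg.norm(x))


x0 = jnp.zeros(3)
value = norm_with_nan_patch(x0)
grad_rev = jax.grad(norm_with_nan_patch)(x0)                       # reverse mode
_, grad_fwd = jax.jvp(norm_with_nan_patch, (x0,), (jnp.ones(3),))  # forward mode
```

Both `grad_rev` and `grad_fwd` come out as NaN while `value` is 0, which is consistent with the statement that wrapping the result in `nan_to_num` fixes the values but not the derivatives.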

@ddudt ddudt marked this pull request as draft January 29, 2024 22:06

codecov bot commented Jan 29, 2024

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (3707e08) 94.91% compared to head (66b92a9) 94.91%.

❗ Current head 66b92a9 differs from pull request most recent head ac0fc14. Consider uploading reports for the commit ac0fc14 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #840      +/-   ##
==========================================
- Coverage   94.91%   94.91%   -0.01%     
==========================================
  Files          80       80              
  Lines       19621    19653      +32     
==========================================
+ Hits        18624    18654      +30     
- Misses        997      999       +2     
Files Coverage Δ
desc/coils.py 98.12% <100.00%> (ø)
desc/compute/_curve.py 99.42% <100.00%> (+0.02%) ⬆️
desc/compute/geom_utils.py 100.00% <100.00%> (ø)
desc/compute/utils.py 96.10% <100.00%> (+0.06%) ⬆️
desc/geometry/core.py 97.20% <92.85%> (-0.93%) ⬇️

... and 2 files with indirect coverage changes

@ddudt ddudt changed the title fix bug in PlanarCurve rotation matrix fix curvy bugs Jan 29, 2024
"""
# FIXME: does not work with AD
Member

You could probably add a cond here that checks if axis is all zero (or otherwise small) and just returns the identity matrix in that case? I think that might solve the AD issue

Collaborator Author

OK I got that to work, but I had to get rid of the optional angle argument and make it always use the norm of axis as the angle. The issue is that the derivative of jnp.linalg.norm(x) is NaN when x = 0, in both forward and reverse mode. So we can't call norm(axis) outside of the cond branch where it is guaranteed to be nonzero.
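
A rough sketch of the cond-based approach described here, assuming the rotation angle is taken from the norm of the input vector. The name `rotation_from_rotvec_sketch` and the exact branch layout are hypothetical; the code in the PR may differ.

```python
import jax
import jax.numpy as jnp


def rotation_from_rotvec_sketch(rotvec):
    """Rotation about `rotvec` by an angle equal to its norm (hypothetical)."""

    def identity_branch(v):
        return jnp.eye(3, dtype=v.dtype)

    def rodrigues_branch(v):
        angle = jnp.linalg.norm(v)  # only evaluated when v is nonzero
        unit = v / angle
        R1 = jnp.cos(angle) * jnp.eye(3, dtype=v.dtype)
        R2 = jnp.sin(angle) * jnp.cross(unit, -jnp.eye(3, dtype=v.dtype))
        R3 = (1 - jnp.cos(angle)) * jnp.outer(unit, unit)
        return R1 + R2 + R3

    is_zero = jnp.all(rotvec == 0)
    return jax.lax.cond(is_zero, identity_branch, rodrigues_branch, rotvec)


zero = jnp.zeros(3)
grad_at_zero = jax.grad(lambda v: rotation_from_rotvec_sketch(v).sum())(zero)
R_quarter = rotation_from_rotvec_sketch(jnp.array([0.0, 0.0, jnp.pi / 2]))
```

Because the norm is computed only inside the nonzero branch, differentiating at the zero vector goes through the identity branch and yields finite (zero) derivatives in this sketch.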

Member

I think this is what you want:

DESC/desc/compute/utils.py

Lines 492 to 493 in c837cc5

def safenorm(x, ord=None, axis=None, fill=0, threshold=0):
"""Like jnp.linalg.norm, but without nan gradient at x=0.

Collaborator

If we can use safediv and keep the angle argument, I prefer that as it is much clearer than having the angle be the norm of the axis

Collaborator Author

I think I have this working now with the angle argument. I also created a new util function safenormalize to normalize vectors to unit length
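
One way such helpers could look is sketched below, using the common "double where" pattern to keep the square root away from zero. The names `safe_norm_sketch`, `safe_normalize_sketch`, and `rotation_matrix_sketch` are hypothetical and the bodies are assumptions; they are not the `safenorm` or `safenormalize` implementations in DESC.

```python
import jax
import jax.numpy as jnp


def safe_norm_sketch(x):
    """Euclidean norm whose gradient at x = 0 is 0 instead of NaN."""
    is_zero = jnp.all(x == 0)
    x_safe = jnp.where(is_zero, jnp.ones_like(x), x)  # keep sqrt away from 0
    return jnp.where(is_zero, 0.0, jnp.linalg.norm(x_safe))


def safe_normalize_sketch(x):
    """Scale x to unit length, returning zeros for the zero vector."""
    n = safe_norm_sketch(x)
    n_safe = jnp.where(n == 0, 1.0, n)
    return jnp.where(n == 0, jnp.zeros_like(x), x / n_safe)


def rotation_matrix_sketch(axis, angle):
    """Rodrigues rotation that keeps a separate `angle` argument."""
    unit = safe_normalize_sketch(axis)
    R1 = jnp.cos(angle) * jnp.eye(3)
    R2 = jnp.sin(angle) * jnp.cross(unit, -jnp.eye(3))
    R3 = (1 - jnp.cos(angle)) * jnp.outer(unit, unit)
    return R1 + R2 + R3


zero_axis = jnp.zeros(3)
jac_fwd = jax.jacfwd(rotation_matrix_sketch)(zero_axis, 0.0)
jac_rev = jax.jacrev(rotation_matrix_sketch)(zero_axis, 0.0)
```

In this sketch the axis can have any length, including zero, and the angle stays an explicit argument, which matches the interface the reviewers asked for.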

desc/geometry/core.py (resolved)
@ddudt ddudt marked this pull request as ready for review February 2, 2024 04:34
Collaborator Author

ddudt commented Feb 2, 2024

@f0uriest @dpanici This is ready for review. We still need to make a decision about how to better handle the Curve.shift and Curve.rotmat attributes, but that can wait for a future PR. They are already treated as optimizable compute params on the master branch, and the bug fixes in this PR are the minimal changes necessary to get the existing API functional.

R = rotation_matrix(axis, angle)
self.rotmat = R @ self.rotmat
def rotate(self, axis):
"""Rotate the curve about axis in X, Y, Z coordinates."""
Member

should mention that the angle is determined by the norm of axis

Collaborator

I am slightly confused: if I now want to rotate about the vector (1,1,0) by 90 degrees (pi/2), should I pass in (1,1,0) * (pi/2) / np.sqrt(2) for the axis?
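
As a worked check of this convention, assuming the angle is read off as the Euclidean norm of the vector passed in: the input would be the unit axis scaled by the angle in radians, and 90 degrees is pi/2. The variable names below are illustrative only.

```python
import numpy as np

direction = np.array([1.0, 1.0, 0.0])
angle = np.pi / 2  # 90 degrees in radians

# Unit axis scaled by the angle: the norm of the result encodes the rotation.
rotvec = direction / np.linalg.norm(direction) * angle
recovered_angle = np.linalg.norm(rotvec)
```

Since the norm of (1,1,0) is sqrt(2), this is the same as multiplying (1,1,0) by (pi/2) / sqrt(2).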

self.rotmat = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).astype(float)
self.name = name
self._shift = np.array([0, 0, 0], dtype=float)
self._rotmat = np.eye(3, dtype=float).flatten()
Member

any reason for making these all regular np arrays rather than jnp?

Collaborator Author

I reverted back to jnp but I don't think it matters and we are inconsistent about np vs jnp in other classes. We should probably try to consistently use jnp everywhere.

Collaborator

Sometimes np is needed, though, to ensure certain things are treated as static by JAX, so blanket jnp might not be what we want. In a lot of classes' build methods it does not matter, but for the init and some quantities from build I think it can.
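
A small illustration of the static-versus-traced distinction raised here. The example is generic and not taken from DESC; the names `sizes_np`, `sizes_jnp`, and the two functions are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

sizes_np = np.array([2, 3])    # plain NumPy: stays concrete during tracing
sizes_jnp = jnp.array([2, 3])  # JAX array: operations on it are traced under jit


@jax.jit
def scaled_ones_np(x):
    n = int(sizes_np.sum())  # ordinary Python int, usable as a static shape
    return x * jnp.ones(n)


@jax.jit
def scaled_ones_jnp(x):
    n = int(sizes_jnp.sum())  # the sum is a tracer here, so int() raises
    return x * jnp.ones(n)


out = scaled_ones_np(2.0)

try:
    scaled_ones_jnp(2.0)
    jnp_version_failed = False
except TypeError:  # JAX concretization errors subclass TypeError
    jnp_version_failed = True
```

Quantities that determine shapes or control flow are the typical case where keeping a NumPy array avoids this kind of tracing error.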

desc/compute/geom_utils.py (outdated, resolved)
dpanici previously requested changes Feb 4, 2024
Collaborator

@dpanici dpanici left a comment

We should use safediv and allow the angle argument again for the rotation matrix; it is cleaner to have both axis and angle as arguments.

@dpanici dpanici mentioned this pull request Feb 5, 2024
@ddudt ddudt requested review from dpanici and f0uriest February 6, 2024 18:03
desc/compute/utils.py (resolved)
desc/geometry/core.py (resolved)
Collaborator

dpanici commented Feb 6, 2024

Add automatic constraint to fix rotmat and shift as another self consistency check

)

if self._normalize:
self._normalization = 1
Collaborator Author

Really we should add normalization options for Curve to the compute_scaling_factors function, but for these objectives the normalization will never really matter.

Collaborator

#863 can be easily done in a separate PR, this one is getting large

desc/objectives/getters.py (outdated, resolved)
@ddudt ddudt requested a review from f0uriest February 7, 2024 22:56
@ddudt ddudt dismissed dpanici’s stale review February 7, 2024 22:57

made requested changes

Collaborator Author

ddudt commented Feb 7, 2024

Code coverage is low because of the new self-consistency curve constraints that were added but are not able to be tested yet.

@ddudt ddudt merged commit f52b1ae into master Feb 8, 2024
15 of 17 checks passed
@ddudt ddudt deleted the dd/curve branch February 8, 2024 16:23

4 participants