fix curvy bugs #840
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres        | +2.74 +/- 2.90         | +3.50e-04 +/- 3.70e-04 | 1.31e-02 +/- 2.6e-04   | 1.28e-02 +/- 2.7e-04   |
| test_build_transform_fft_midres        | -0.33 +/- 3.43         | -3.24e-04 +/- 3.33e-03 | 9.67e-02 +/- 1.9e-03   | 9.71e-02 +/- 2.7e-03   |
| test_build_transform_fft_highres       | +1.00 +/- 2.01         | +4.78e-03 +/- 9.61e-03 | 4.83e-01 +/- 5.7e-03   | 4.78e-01 +/- 7.8e-03   |
| test_equilibrium_init_lowres           | +0.91 +/- 2.11         | +7.59e-03 +/- 1.76e-02 | 8.45e-01 +/- 1.3e-02   | 8.37e-01 +/- 1.2e-02   |
| test_equilibrium_init_medres           | -1.25 +/- 1.86         | -1.91e-02 +/- 2.84e-02 | 1.51e+00 +/- 1.9e-02   | 1.53e+00 +/- 2.1e-02   |
| test_equilibrium_init_highres          | -0.74 +/- 1.61         | -3.36e-02 +/- 7.29e-02 | 4.48e+00 +/- 6.3e-02   | 4.52e+00 +/- 3.7e-02   |
| test_objective_compile_dshape_current  | -1.41 +/- 12.56        | -6.86e-02 +/- 6.13e-01 | 4.81e+00 +/- 4.7e-01   | 4.88e+00 +/- 3.9e-01   |
| test_objective_compile_atf             | -8.35 +/- 11.09        | -1.04e+00 +/- 1.38e+00 | 1.14e+01 +/- 9.4e-01   | 1.24e+01 +/- 1.0e+00   |
| test_objective_compute_dshape_current  | +1.10 +/- 2.52         | +2.41e-05 +/- 5.50e-05 | 2.20e-03 +/- 3.7e-05   | 2.18e-03 +/- 4.0e-05   |
| test_objective_compute_atf             | -0.63 +/- 2.75         | -5.09e-05 +/- 2.21e-04 | 7.98e-03 +/- 2.0e-04   | 8.03e-03 +/- 9.9e-05   |
| test_objective_jac_dshape_current      | -8.03 +/- 9.66         | -5.06e-03 +/- 6.10e-03 | 5.80e-02 +/- 4.4e-03   | 6.31e-02 +/- 4.2e-03   |
| test_objective_jac_atf                 | -8.65 +/- 8.35         | -3.53e-01 +/- 3.40e-01 | 3.73e+00 +/- 2.6e-01   | 4.08e+00 +/- 2.2e-01   |
| test_perturb_1                         | +1.44 +/- 21.31        | +1.33e-01 +/- 1.98e+00 | 9.41e+00 +/- 1.5e+00   | 9.28e+00 +/- 1.3e+00   |
| test_perturb_2                         | +0.09 +/- 4.96         | +1.44e-02 +/- 7.71e-01 | 1.56e+01 +/- 7.1e-01   | 1.56e+01 +/- 2.9e-01   |
Patch coverage is 100% and I added a new test, but project coverage decreased because I deleted code.
desc/compute/geom_utils.py
Outdated
@@ -45,7 +46,7 @@ def rotation_matrix(axis, angle=None):
 R1 = jnp.cos(angle) * jnp.eye(3)
 R2 = jnp.sin(angle) * jnp.cross(axis, jnp.identity(axis.shape[0]) * -1)
 R3 = (1 - jnp.cos(angle)) * jnp.outer(axis, axis)
-return R1 + R2 + R3
+return R1 + jnp.nan_to_num(R2 + R3)  # if axis=[0, 0, 0], R2 & R3 are NaN
`FourierPlanarCurve` is a 2D curve in the X-Y plane that gets rotated into another plane in 3D space. If the curve stays in the X-Y plane (as it does with the default parameters), then this function gets called with `axis=[0, 0, 0]` and `angle=0`.
Can you check that reverse-mode AD works for this? It sometimes has weirdness with overwriting NaNs.
This actually doesn't work with forward or reverse mode AD. I spent a bunch of time thinking about this and couldn't come up with a good solution, but I welcome help if you have an idea.
I still think this PR is worth merging even if we don't fix that issue because:
- The old code was returning incorrect values
- The old code also did not properly handle the case `axis=[0, 0, 0]`
- In practice this will only arise for vertical field coils, and I don't think that is a high-priority use case
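The AD problem described in this comment can be reproduced on a toy function. The sketch below is not the DESC code: `direction_sum` is a hypothetical stand-in that normalizes a vector and masks the resulting NaNs with `jnp.nan_to_num`. The forward value is finite, but the reverse-mode gradient at the zero vector is still NaN, because the division by zero happens before the mask is applied.

```python
import jax
import jax.numpy as jnp


def direction_sum(axis):
    """Hypothetical toy function: sum of the normalized components of axis."""
    unit = axis / jnp.linalg.norm(axis)   # 0 / 0 = NaN when axis is all zeros
    return jnp.sum(jnp.nan_to_num(unit))  # NaNs are replaced by 0 in the value


zero_axis = jnp.zeros(3)
value = direction_sum(zero_axis)               # finite forward value
gradient = jax.grad(direction_sum)(zero_axis)  # reverse mode: NaN components
print(value, gradient)
```

Masking the output only repairs the value. The backward pass still divides a zero cotangent by a zero norm, so the NaN reappears in the gradient.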
Codecov Report
Attention:
Additional details and impacted files
@@ Coverage Diff @@
## master #840 +/- ##
==========================================
- Coverage 94.91% 94.91% -0.01%
==========================================
Files 80 80
Lines 19621 19653 +32
==========================================
+ Hits 18624 18654 +30
- Misses 997 999 +2
desc/compute/geom_utils.py
Outdated
"""
# FIXME: does not work with AD
You could probably add a cond here that checks if axis is all zero (or otherwise small) and just returns the identity matrix in that case? I think that might solve the AD issue
OK I got that to work, but I had to get rid of the optional `angle` argument and make it always use the norm of `axis` as the angle. The issue is that the derivative of `jnp.linalg.norm(x)` is NaN when `x = 0`, in both forward and reverse mode. So we can't call `norm(axis)` outside of the `cond` branch where it is guaranteed to be nonzero.
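The NaN derivative of the norm at zero, and the reason the `norm` call has to live inside the `cond` branch, can be checked directly. This is a minimal sketch that assumes an exactly zero axis is the only problematic input; `guarded_angle` is a hypothetical name, not a DESC function.

```python
import jax
import jax.numpy as jnp


def guarded_angle(axis):
    """Hypothetical helper: norm of axis, evaluated only when axis is nonzero."""
    return jax.lax.cond(
        jnp.all(axis == 0),
        lambda a: jnp.zeros((), dtype=a.dtype),  # zero axis: angle 0, zero gradient
        lambda a: jnp.linalg.norm(a),            # nonzero axis: norm is differentiable
        axis,
    )


zero_axis = jnp.zeros(3)
grad_naive = jax.grad(jnp.linalg.norm)(zero_axis)  # norm differentiated at x = 0
grad_guarded = jax.grad(guarded_angle)(zero_axis)  # norm branch is never evaluated
grad_nonzero = jax.grad(guarded_angle)(jnp.array([3.0, 0.0, 4.0]))  # axis / |axis|
```

Because only the selected branch is differentiated, the zero-axis case never touches the derivative of the norm.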
I think this is what you want:
Lines 492 to 493 in c837cc5
def safenorm(x, ord=None, axis=None, fill=0, threshold=0):
    """Like jnp.linalg.norm, but without nan gradient at x=0.
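For context, a common way to get a norm without a NaN gradient at zero is the "double where" pattern. The sketch below illustrates that pattern only. It is not the DESC `safenorm` implementation, it ignores the `ord`, `axis`, and `threshold` arguments shown in the signature above, and `safenorm_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def safenorm_sketch(x, fill=0.0):
    """Hypothetical: 2-norm of x with a zero (not NaN) gradient at x = 0."""
    is_zero = jnp.all(x == 0)
    # Inner where: swap in a harmless nonzero input so norm is never called at 0.
    safe_x = jnp.where(is_zero, jnp.ones_like(x), x)
    # Outer where: return the fill value instead of the norm of the dummy input.
    return jnp.where(is_zero, fill, jnp.linalg.norm(safe_x))


grad_at_zero = jax.grad(safenorm_sketch)(jnp.zeros(3))
grad_elsewhere = jax.grad(safenorm_sketch)(jnp.array([3.0, 0.0, 4.0]))
```

Both `where` calls are needed: the outer one fixes the value, and the inner one keeps the backward pass away from the singular point.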
If we can use `safediv` and keep the angle argument, I prefer that as it is much clearer than having the angle be the norm of the axis.
I think I have this working now with the angle argument. I also created a new util function `safenormalize` to normalize vectors to unit length.
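One way the pieces discussed in this thread could fit together is sketched below. This is an assumption-laden illustration, not the code merged in the PR: `safenormalize_sketch` and `rotation_matrix_sketch` are hypothetical names, and the body simply reuses the `R1 + R2 + R3` form from the diff hunk with a normalization that is safe at `axis=[0, 0, 0]`.

```python
import jax
import jax.numpy as jnp


def safenormalize_sketch(x):
    """Hypothetical: unit vector along x, or zeros if x is exactly zero."""
    is_zero = jnp.all(x == 0)
    safe_x = jnp.where(is_zero, jnp.ones_like(x), x)  # never divide by a zero norm
    return jnp.where(is_zero, jnp.zeros_like(x), safe_x / jnp.linalg.norm(safe_x))


def rotation_matrix_sketch(axis, angle):
    """Hypothetical: rotation about axis by angle (radians), finite at axis = 0."""
    u = safenormalize_sketch(axis)
    R1 = jnp.cos(angle) * jnp.eye(3)
    R2 = jnp.sin(angle) * jnp.cross(u, -jnp.eye(3))  # cross-product matrix of u
    R3 = (1 - jnp.cos(angle)) * jnp.outer(u, u)
    return R1 + R2 + R3


# Rotating the x unit vector by 90 degrees about z should give the y unit vector.
R = rotation_matrix_sketch(jnp.array([0.0, 0.0, 2.0]), jnp.pi / 2)
rotated = R @ jnp.array([1.0, 0.0, 0.0])

# Default planar-curve case from the discussion: zero axis, zero angle.
R_default = rotation_matrix_sketch(jnp.zeros(3), 0.0)
jac_default = jax.jacrev(rotation_matrix_sketch)(jnp.zeros(3), 0.0)
```

In this sketch a zero axis with a nonzero angle returns `cos(angle)` times the identity, which is not a rotation, so a real implementation would need to decide how to treat that input.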
@f0uriest @dpanici This is ready for review. We still need to make a decision about how to better handle the
desc/geometry/core.py
Outdated
R = rotation_matrix(axis, angle)
self.rotmat = R @ self.rotmat
def rotate(self, axis):
    """Rotate the curve about axis in X, Y, Z coordinates."""
should mention that the angle is determined by the norm of axis
I am slightly confused. So now, if I want to rotate about the vector (1,1,0) by 90 degrees (pi/2), I should pass in (1,1,0) * (pi/2) / np.sqrt(2) for the axis?
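Under the norm-as-angle convention being discussed, the rotation angle in radians is the length of the vector passed as `axis`. The scaling can be checked with a short NumPy sketch, which assumes exactly that convention and nothing else about the DESC API:

```python
import numpy as np

direction = np.array([1.0, 1.0, 0.0])  # rotate about this vector
theta = np.pi / 2                      # 90 degrees expressed in radians

# Rescale so that the length of the vector equals the rotation angle.
axis = direction * theta / np.linalg.norm(direction)
print(axis, np.linalg.norm(axis))
```

The direction is divided by its own length, `np.sqrt(2)` here, and then multiplied by the angle.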
desc/geometry/core.py
Outdated
self.rotmat = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).astype(float)
self.name = name
self._shift = np.array([0, 0, 0], dtype=float)
self._rotmat = np.eye(3, dtype=float).flatten()
any reason for making these all regular np arrays rather than jnp?
I reverted back to jnp but I don't think it matters and we are inconsistent about np vs jnp in other classes. We should probably try to consistently use jnp everywhere.
Sometimes np is needed, though, to ensure certain things are treated as static by JAX, so blanket jnp might not be what we want. In a lot of classes' build methods it does not matter, but for the init and some quantities from build I think it can matter.
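The static-versus-traced distinction mentioned here can be seen in a small example that is unrelated to the DESC classes. Inside a jitted function, NumPy operations run eagerly at trace time and yield concrete values, while `jnp` operations are staged out and yield tracers. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def flatten_np(x):
    # x.shape is a static tuple and np.prod runs eagerly at trace time,
    # so size is a concrete integer that reshape can use.
    size = int(np.prod(x.shape))
    return x.reshape(size)


@jax.jit
def flatten_jnp(x):
    # jnp.prod is staged out under jit, so its result is a tracer and
    # int() cannot turn it into a concrete value.
    size = int(jnp.prod(jnp.array(x.shape)))
    return x.reshape(size)


x = jnp.ones((2, 3))
flat = flatten_np(x)  # works: the size is known while tracing

try:
    flatten_jnp(x)
    jnp_failed = False
except TypeError:  # JAX's concretization errors subclass TypeError
    jnp_failed = True
```

Quantities that determine shapes or control flow are the ones where the choice between `np` and `jnp` changes behavior.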
We should use `safediv` and allow the angle argument again for the rotation matrix; it is cleaner to have the arguments be both axis and angle.
Add automatic constraint to fix
)

if self._normalize:
    self._normalization = 1
Really we should add normalization options for `Curve` to the `compute_scaling_factors` function, but for these objectives the normalization will never really matter.
#863 can be easily done in a separate PR, this one is getting large
Code coverage is low because of the new self-consistency curve constraints that were added but are not able to be tested yet.
This PR fixes two issues related to the `Curve` class that are necessary for optimizing modular coils:
- `Curve.shift` and `Curve.rotmat` need to be "optimizable_param"s because they are params that get used in the compute functions.
- `_rotation_matrix_from_normal` function was giving incorrect results and is now replaced with calls to `rotation_matrix`.

Things to note when reviewing:
- `master_compute_data.pkl` was changed because all of the old data for the `FourierPlanarCurve` class was incorrect.
- `eq.axis.shift` with `eq.axis._shift`, etc. so they can be loaded without I/O warnings.