Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make factorize_linear_constraints faster #1374

Merged
merged 27 commits into from
Dec 3, 2024
Merged
Changes from 3 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
56f5585
make degenerate constraint check faster
YigitElma Nov 18, 2024
8486bfc
use jnp instead of np
YigitElma Nov 18, 2024
1dcbd93
Merge branch 'master' into yge/degenerate
dpanici Nov 18, 2024
487c580
back to numpy version
YigitElma Nov 18, 2024
7ac7bb3
Merge branch 'yge/degenerate' of github.com:PlasmaControl/DESC into y…
YigitElma Nov 18, 2024
0bb978e
add benchmark, update unmarked test check
YigitElma Nov 18, 2024
27260e4
try qr for null-space
YigitElma Nov 19, 2024
6f618b9
remove redundant array conversion
YigitElma Nov 19, 2024
3102b27
use solve_triangular instead
YigitElma Nov 19, 2024
70bdc81
add check for rank 0 case
YigitElma Nov 19, 2024
0c0fddb
refactor solve_fixed_iter test to compiled and first to prevent huge …
YigitElma Nov 19, 2024
474c628
add setup to benchmarks to prevent high standard deviation
YigitElma Nov 19, 2024
14f3018
fix setup return type
YigitElma Nov 19, 2024
586de4d
try fix
YigitElma Nov 19, 2024
c59b58a
update docs
YigitElma Nov 19, 2024
aa82a2c
Merge branch 'master' into yge/degenerate
dpanici Nov 19, 2024
5a707df
Merge branch 'master' into yge/degenerate
dpanici Nov 20, 2024
3190e41
update tolerance
YigitElma Nov 21, 2024
859f067
use jax svd in svd_inv_null
YigitElma Nov 21, 2024
a02adde
remove qr version, jax svd version is fast enough
YigitElma Nov 21, 2024
ca3d967
fix typo
YigitElma Nov 21, 2024
b116a8a
clean up jax svd version
YigitElma Nov 21, 2024
b149090
Merge branch 'master' into yge/degenerate
dpanici Nov 22, 2024
72dc319
Merge branch 'master' into yge/degenerate
YigitElma Nov 22, 2024
27945d6
Merge branch 'master' into yge/degenerate
YigitElma Dec 3, 2024
c8a4077
Merge branch 'master' into yge/degenerate
YigitElma Dec 3, 2024
b23b5e0
Merge branch 'master' into yge/degenerate
YigitElma Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 12 additions & 25 deletions desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,38 +97,25 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
# which are duplicate rows of A that also have duplicate entries of b,
# if the entries of b aren't the same then the constraints are actually
# incompatible and so we will leave those to be caught later.
A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))])
row_idx_to_delete = np.array([], dtype=int)
for row_idx in range(A_augmented.shape[0]):
# find all rows equal to this row
rows_equal_to_this_row = np.where(
np.all(A_augmented[row_idx, :] == A_augmented, axis=1)
)[0]
# find the rows equal to this row that are not this row
rows_equal_to_this_row_but_not_this_row = rows_equal_to_this_row[
rows_equal_to_this_row != row_idx
]
# if there are rows equal to this row that aren't this row, AND this particular
# row has not already been detected as a duplicate of an earlier one and slated
# for deletion, add the duplicate row indices to the array of
# rows to be deleted
if (
rows_equal_to_this_row_but_not_this_row.size
and row_idx not in row_idx_to_delete
):
row_idx_to_delete = np.append(row_idx_to_delete, rows_equal_to_this_row[1:])
# delete the affected rows, and also the corresponding rows of b
A_augmented = np.delete(A_augmented, row_idx_to_delete, axis=0)
A_augmented = jnp.hstack([A, jnp.reshape(b, (A.shape[0], 1))])

# Find unique rows of A_augmented
unique_rows, unique_indices = jnp.unique(A_augmented, axis=0, return_index=True)

# Sort the indices to preserve the order of appearance
unique_indices = jnp.sort(unique_indices)

# while loop has problems updating JAX arrays, convert them to numpy arrays
A_augmented = np.array(A_augmented)
YigitElma marked this conversation as resolved.
Show resolved Hide resolved
# Extract the unique rows
A_augmented = A_augmented[unique_indices]
A = A_augmented[:, :-1]
b = np.atleast_1d(A_augmented[:, -1].squeeze())

# will store the global index of the unfixed rows, idx
indices_row = np.arange(A.shape[0])
indices_idx = np.arange(A.shape[1])

# while loop has problems updating JAX arrays, convert them to numpy arrays
A = np.array(A)
b = np.array(b)
while len(np.where(np.count_nonzero(A, axis=1) == 1)[0]):
# fixed just means there is a single element in A, so A_ij*x_j = b_i
fixed_rows = np.where(np.count_nonzero(A, axis=1) == 1)[0]
Expand Down
Loading