Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SurfaceRZFourier cache isn't invalidated by setting rc, zs, rs, zc arrays #465

Draft
wants to merge 29 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2eb196d
Fix: raise ValueError for invalid ProfileSpec inputs, correct typos i…
missing-user Nov 12, 2024
e41d792
Merge branch 'hiddenSymmetries:master' into master
missing-user Nov 13, 2024
b47f9c1
Fix: as_spec setter shape check typos
missing-user Nov 13, 2024
d8c63ce
iota has size nvol+1 not mvol+1
missing-user Nov 17, 2024
7e2afad
np.NINF to -np.inf for np2.0 compatibility
missing-user Nov 18, 2024
8b57e42
Revert "np.NINF to -np.inf for np2.0 compatibility"
missing-user Nov 18, 2024
f9eb4ae
iota error message updated
missing-user Nov 18, 2024
4cf1f6b
fix typo
missing-user Nov 18, 2024
5022aac
Merge branch 'hiddenSymmetries:master' into master
missing-user Nov 18, 2024
dca6c45
boundary object is kepy in sync in freeboundary mode
missing-user Nov 30, 2024
a99aa05
normal field refactor so vnc and vns and dofs are always in sync
missing-user Nov 30, 2024
403e841
initial guess and recompute trigger for freeboundary
missing-user Dec 2, 2024
37a9356
implemented some of smiets comments
missing-user Dec 3, 2024
9099e0a
boundary access forces freeboundary SPEC to run correctly
missing-user Dec 3, 2024
eb5177e
Merge branch 'hiddenSymmetries:master' into master
missing-user Dec 3, 2024
a572b8b
Fix https://github.com/hiddenSymmetries/simsopt/issues/389
missing-user Dec 4, 2024
201df39
unused import
missing-user Dec 4, 2024
1ac401e
Merge branch 'fix-mpi-logging'
missing-user Dec 4, 2024
30b1f1b
respect prob.bounds in all solver wrappers
missing-user Dec 4, 2024
32afc9d
Merge branch 'warn-unused-bounds'
missing-user Dec 4, 2024
6634053
extract conversion of results to simsopt surface to a separate function
missing-user Dec 5, 2024
4f4e214
array setters for surface
missing-user Dec 18, 2024
29bc558
docstrings for zs, rc, ... setters
missing-user Dec 18, 2024
557a7ed
removed print
missing-user Dec 18, 2024
e11205b
array setter indexing error
missing-user Dec 20, 2024
d473a37
removal fails sometimes
missing-user Dec 22, 2024
bf9d8ca
Typo in docstring
missing-user Jan 8, 2025
5f5c6b5
Merge remote-tracking branch 'upstream/master'
landreman Jan 27, 2025
9099e0b
Merge branch 'master' into surfaceRZfourier-setter
landreman Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/1_Simple/logger_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

import logging
from simsopt.util.log import initialize_logging
from simsopt.util import initialize_logging

"""
Example file for transparently logging both MPI and serial jobs
Expand All @@ -24,7 +24,7 @@

if comm is not None:
initialize_logging(mpi=True, filename='mpi.log')
for i in range(2):
for i in range(5):
logging.warning("Hello (times %i) from mpi job" % (i+1))
print("End of 1_Simple/logger_example.py")
print("==================================")
1 change: 1 addition & 0 deletions examples/run_parallel_examples
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ echo MPI_OPTIONS=$MPI_OPTIONS
mpiexec $MPI_OPTIONS -n 2 ./1_Simple/tracing_fieldlines_NCSX.py
mpiexec $MPI_OPTIONS -n 2 ./1_Simple/tracing_fieldlines_QA.py
mpiexec $MPI_OPTIONS -n 2 ./1_Simple/tracing_particle.py
mpiexec $MPI_OPTIONS -n 2 ./1_Simple/logger_example.py
141 changes: 83 additions & 58 deletions src/simsopt/field/normal_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,6 @@ def __init__(self, nfp=1, stellsym=True, mpol=1, ntor=0,
self.stellsym = stellsym
self.mpol = mpol
self.ntor = ntor

if vns is None:
vns = np.zeros((self.mpol + 1, 2 * self.ntor + 1))

if not self.stellsym and vnc is None:
vnc = np.zeros((self.mpol + 1, 2 * self.ntor + 1))

if surface is None:
surface = SurfaceRZFourier(nfp=nfp, stellsym=stellsym, mpol=mpol, ntor=ntor)
Expand All @@ -71,10 +65,33 @@ def __init__(self, nfp=1, stellsym=True, mpol=1, ntor=0,
else:
self.ndof = 2 * (self.ntor + self.mpol * (2 * self.ntor + 1)) + 1

self._vns = vns
self._vnc = vnc

dofs = self.get_dofs()
if vns is None:
vns = np.zeros((self.mpol + 1, 2 * self.ntor + 1))
if vnc is None:
vnc = np.zeros((self.mpol + 1, 2 * self.ntor + 1))

dofs = np.zeros((self.ndof,))

# Populate dofs array
vns_shape = vns.shape
input_mpol = int(vns_shape[0]-1)
input_ntor = (vns_shape[1]-1)//2

if not self.stellsym:
assert vns.shape == vnc.shape
for mm in range(0, self.mpol+1):
for nn in range(-self.ntor, self.ntor+1):
if mm == 0 and nn < 0: continue
if mm > input_mpol: continue
if nn > input_ntor: continue

if not (mm == 0 and nn == 0):
ii = self.get_index_in_dofs(mm, nn, even=False)
dofs[ii] = vns[mm, input_ntor+nn]

if not self.stellsym:
ii = self.get_index_in_dofs(mm, nn, even=True)
dofs[ii] = vnc[mm, input_ntor+nn]

Optimizable.__init__(
self,
Expand All @@ -83,15 +100,48 @@ def __init__(self, nfp=1, stellsym=True, mpol=1, ntor=0,

@property
def vns(self):
return self._vns
vns_local = np.zeros((self.mpol + 1, 2 * self.ntor + 1))

input_mpol = int(vns_local.shape[0]-1)
input_ntor = (vns_local.shape[1]-1)//2
for mm in range(0, self.mpol+1):
for nn in range(-self.ntor, self.ntor+1):
if mm == 0 and nn < 0: continue
if mm > input_mpol: continue
if nn > input_ntor: continue

if not (mm == 0 and nn == 0):
ii = self.get_index_in_dofs(mm, nn, even=False)
vns_local[mm, input_ntor+nn] = self.local_full_x[ii]

# Don';'t allow changes to vns. Use set_vns() instead
vns_local.flags.writeable = False
return vns_local

@vns.setter
def vns(self, value):
raise AttributeError('Change Vns using set_vns() or set_vns_asarray()')

@property
def vnc(self):
return self._vnc
if self.stellsym:
raise AttributeError('Vnc is not available for stellarator symmetric fields')
vnc_local = np.zeros((self.mpol + 1, 2 * self.ntor + 1))

input_mpol = int(vnc_local.shape[0]-1)
input_ntor = (vnc_local.shape[1]-1)//2
for mm in range(0, self.mpol+1):
for nn in range(-self.ntor, self.ntor+1):
if mm == 0 and nn < 0: continue
if mm > input_mpol: continue
if nn > input_ntor: continue

ii = self.get_index_in_dofs(mm, nn, even=True)
vnc_local[mm, input_ntor+nn] = self.local_full_x[ii]

# Don';'t allow changes to vnc. Use set_vnc() instead
vnc_local.flags.writeable = False
return vnc_local

@vnc.setter
def vnc(self, value):
Expand Down Expand Up @@ -163,32 +213,6 @@ def from_spec_object(cls, spec):
normal_field = cls(**input_dict)

return normal_field

def get_dofs(self):
"""
get DOFs from vns and vnc
"""
# Pack in a single array
dofs = np.zeros((self.ndof,))

# Populate dofs array
vns_shape = self.vns.shape
input_mpol = int(vns_shape[0]-1)
input_ntor = int((vns_shape[1]-1)/2)
for mm in range(0, self.mpol+1):
for nn in range(-self.ntor, self.ntor+1):
if mm == 0 and nn < 0: continue
if mm > input_mpol: continue
if nn > input_ntor: continue

if not (mm == 0 and nn == 0):
ii = self.get_index_in_dofs(mm, nn, even=False)
dofs[ii] = self.vns[mm, input_ntor+nn]

if not self.stellsym:
ii = self.get_index_in_dofs(mm, nn, even=True)
dofs[ii] = self.vnc[mm, input_ntor+nn]
return dofs

def get_index_in_array(self, m, n, mpol=None, ntor=None):
"""
Expand Down Expand Up @@ -264,10 +288,9 @@ def get_vns(self, m, n):

def set_vns(self, m, n, value):
self.check_mn(m, n)
i,j = self.get_index_in_array(m, n)
self._vns[i,j] = value
dofs = self.get_dofs()
self.local_full_x = dofs
ii = self.get_index_in_dofs(m, n)
self.local_full_x[ii] = value
self.recompute_bell()

def get_vnc(self, m, n):
self.check_mn(m, n)
Expand All @@ -279,13 +302,13 @@ def get_vnc(self, m, n):

def set_vnc(self, m, n, value):
self.check_mn(m, n)
i,j = self.get_index_in_array(m, n)
if self.stellsym:
raise ValueError('Stellarator symmetric has no vnc')
else:
self._vnc[i,j] = value
dofs = self.get_dofs()
self.local_full_x = dofs
ii = self.get_index_in_dofs(m, n, even=True)
self.local_full_x[ii] = value
self.recompute_bell()


def check_mn(self, m, n):
if m < 0 or m > self.mpol:
Expand Down Expand Up @@ -409,7 +432,8 @@ def get_vns_asarray(self, mpol=None, ntor=None):
elif ntor > self.ntor:
raise ValueError('ntor out of bound')

vns = self.vns
vns = self.vns.copy()
vns.flags.writeable = True

return vns[0:mpol, self.ntor-ntor:self.ntor+ntor+1]

Expand All @@ -427,9 +451,8 @@ def get_vnc_asarray(self, mpol=None, ntor=None):
elif ntor > self.ntor:
raise ValueError('ntor out of bound')

vnc = self.vnc
if vnc is None:
vnc = np.zeros((mpol, 2*ntor+1))
vnc = self.vnc.copy()
vnc.flags.writeable = True

return vnc[0:mpol, self.ntor-ntor:self.ntor+ntor+1]

Expand Down Expand Up @@ -465,9 +488,10 @@ def set_vns_asarray(self, vns, mpol=None, ntor=None):
elif ntor > self.ntor:
raise ValueError('ntor out of bound')

self._vns = vns[0:mpol, self.ntor-ntor:self.ntor+ntor+1]
dofs = self.get_dofs()
self.local_full_x = dofs
for i in range(mpol):
for j in range(-ntor, ntor+1):
if i == 0 and j <= 0: continue
self.set_vns(i, j, vns[i, self.ntor+j])

def set_vnc_asarray(self, vnc, mpol=None, ntor=None):
"""
Expand All @@ -482,10 +506,11 @@ def set_vnc_asarray(self, vnc, mpol=None, ntor=None):
ntor = self.ntor
elif ntor > self.ntor:
raise ValueError('ntor out of bound')

self._vnc = vnc[0:mpol, self.ntor-ntor:self.ntor+ntor+1]
dofs = self.get_dofs()
self.local_full_x = dofs

for i in range(mpol):
for j in range(-ntor, ntor+1):
if i == 0 and j < 0: continue
self.set_vnc(i, j, vnc[i, self.ntor+j])

def set_vns_vnc_asarray(self, vns, vnc, mpol=None, ntor=None):
"""
Expand All @@ -503,7 +528,7 @@ def set_vns_vnc_asarray(self, vns, vnc, mpol=None, ntor=None):

self.set_vns_asarray(vns, mpol, ntor)
self.set_vnc_asarray(vnc, mpol, ntor)

def get_real_space_field(self):
"""
Fourier transform the field and get the real-space values of the normal component of the externally
Expand Down
88 changes: 88 additions & 0 deletions src/simsopt/geo/surfacerzfourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,90 @@ def _validate_mn(self, m, n):
raise IndexError('n must be <= ntor')
if n < -self.ntor:
raise IndexError('n must be >= -ntor')

@property
def rc_array(self):
"""
rc (np.ndarray): Array of cosine coefficients for the radial coordiante.
"""
return self.rc

@property
def zs_array(self):
"""
zs (np.ndarray): Array of sine coefficients for the Z coordiante.
"""
return self.zs

@property
def zc_array(self):
"""
zc (np.ndarray): Array of cosine coefficients for the Z coordiante.
"""
if self.stellsym:
raise ValueError(
'zc does not exist for this stellarator-symmetric surface.')
return self.zc

@property
def rs_array(self):
"""
rs (np.ndarray): Array of sine coefficients for the R coordiante.
"""
if self.stellsym:
raise ValueError(
'rs does not exist for this stellarator-symmetric surface.')
return self.rs

@rc_array.setter
def rc_array(self, rc):
"""
Setter to overwrite the entire rc array, triggering the recompute_bell.
To overwrite individual coefficients, use the set_rc method instead.
"""
if rc.shape != (self.mpol + 1, self.ntor * 2 + 1):
raise ValueError('rc must have shape (mpol+1, 2*ntor+1)')
self.rc = rc
self.local_full_x = self.get_dofs()

@zs_array.setter
def zs_array(self, zs):
"""
Setter to overwrite the entire zs array, triggering the recompute_bell.
To overwrite individual coefficients, use the set_zs method instead.
"""
if zs.shape != (self.mpol + 1, self.ntor * 2 + 1):
raise ValueError('zs must have shape (mpol+1, 2*ntor+1)')
self.zs = zs
self.local_full_x = self.get_dofs()

@rs_array.setter
def rs_array(self, rs):
"""
Setter to overwrite the entire rs array, triggering the recompute_bell.
To overwrite individual coefficients, use the set_rs method instead.
"""
if rs.shape != (self.mpol + 1, self.ntor * 2 + 1):
raise ValueError('rs must have shape (mpol+1, 2*ntor+1)')
if self.stellsym:
raise ValueError(
'rs does not exist for this stellarator-symmetric surface.')
self.rs = rs
self.local_full_x = self.get_dofs()

@zc_array.setter
def zc_array(self, zc):
"""
Setter to overwrite the entire zc array, triggering the recompute_bell.
To overwrite individual coefficients, use the set_zc method instead.
"""
if zc.shape != (self.mpol + 1, self.ntor * 2 + 1):
raise ValueError('zc must have shape (mpol+1, 2*ntor+1)')
if self.stellsym:
raise ValueError(
'zc does not exist for this stellarator-symmetric surface.')
self.zc = zc
self.local_full_x = self.get_dofs()

def get_rc(self, m, n):
"""
Expand Down Expand Up @@ -586,6 +670,7 @@ def get_zs(self, m, n):
def set_rc(self, m, n, val):
"""
Set a particular `rc` Parameter.
Modifying the `rc` array directly is discouraged, since it doesn't trigger the recompute_bell().
"""
self._validate_mn(m, n)
self.rc[m, n + self.ntor] = val
Expand All @@ -594,6 +679,7 @@ def set_rc(self, m, n, val):
def set_rs(self, m, n, val):
"""
Set a particular `rs` Parameter.
Modifying the `rs` array directly is discouraged, since it doesn't trigger the recompute_bell().
"""
if self.stellsym:
return ValueError(
Expand All @@ -605,6 +691,7 @@ def set_rs(self, m, n, val):
def set_zc(self, m, n, val):
"""
Set a particular `zc` Parameter.
Modifying the `zc` array directly is discouraged, since it doesn't trigger the recompute_bell().
"""
if self.stellsym:
return ValueError(
Expand All @@ -616,6 +703,7 @@ def set_zc(self, m, n, val):
def set_zs(self, m, n, val):
"""
Set a particular `zs` Parameter.
Modifying the `zs` array directly is discouraged, since it doesn't trigger the recompute_bell().
"""
self._validate_mn(m, n)
self.zs[m, n + self.ntor] = val
Expand Down
3 changes: 2 additions & 1 deletion src/simsopt/mhd/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def f(self, lvol: int):
if (lvol < 0).any():
raise ValueError('lvol should be larger or equal than zero')
if (lvol >= self.local_full_x.size).any():
raise ValueError('lvol should be smaller than Mvol')
raise ValueError('lvol out of bounds for the size of this profile. \
Attempted to access index {} of {}'.format(lvol, self.local_full_x.size))

# Return value
return self.local_full_x[lvol]
Expand Down
Loading
Loading