Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Devel #25

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions EXPtools/utils/indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
This module provides functions to work with index-based calculations for spherical harmonics coef series.
"""
import numpy as np, math

def total_terms(l_max):
"""
Calculate the total number of terms up to a given maximum angular number.

Parameters:
l_max (int): The maximum angular quantum number.

Returns:
float: The total number of terms.
"""
return l_max * (l_max + 1) / 2 + l_max + 1 # Sigma(lmax+1) + 1 (the extra one for the 0th term)

def I(l, m):
"""
Calculate the index of a spherical harmonic element given the angular numbers l and m .

Parameters:
l (int): The angular number.
m (int): The magnetic quantum number, ranging from 0 to l.

Returns:
int: The index corresponding to the specified angular numbers.
"""
import math
assert isinstance(l, int) and isinstance(m, int), "l and m must be integers"
assert l >= 0, "l must be greater than 0"
assert abs(m) <= l, "m must be less than or equal to l"
return int(l * (l + 1) / 2) + abs(m)

def inverse_I(I):
"""
Calculate the angular numbers l and m given the index of a spherical harmonic element.

Parameters:
I (int): The index of the spherical harmonic element.

Returns:
tuple: A tuple containing the angular numbers (l, m).
"""
import math
assert isinstance(I, int) and I >=0, "I must be an interger greater than or equal to 0"
l = math.floor((-1 + math.sqrt(1 + 8 * I)) / 2) # Calculate l using the inverse of the formula
m = I - int(l * (l + 1) / 2) # Calculate m using the given formula
return l, m

def set_specific_lm_non_zero(data, lm_pairs_to_set):
"""
Sets specific (l, m) pairs in the input data to non-zero values.

Parameters:
data (np.ndarray): Input data, an array of complex numbers.
lm_pairs_to_set (list): List of tuples representing (l, m) pairs to set to non-zero.

Returns:
np.ndarray: An array with selected (l, m) pairs set to non-zero values.

Raises:
ValueError: If any of the provided (l, m) pairs are out of bounds for the input data.
"""
assert isinstance(lm_pairs_to_set, list), "lm_pairs_to_set must be a list"
for pair in lm_pairs_to_set:
assert isinstance(pair, tuple), "Each element in lm_pairs_to_set must be a tuple"
assert len(pair) == 2, "Each tuple in lm_pairs_to_set must contain two elements"
assert all(isinstance(x, int) for x in pair), "Each element in each tuple must be an integer"

# Get a zeros array of the same shape as the input data
arr_filt = np.zeros(data.shape, dtype=complex)

# Check if the provided (l, m) pairs are within the valid range
for l, m in lm_pairs_to_set:
if l < 0 or l >= data.shape[0] or m < 0 or m > l:
raise ValueError(f"Invalid (l, m) pair: ({l}, {m}). Out of bounds for the input data shape.")
arr_filt[I(l, m), :, :] = data[I(l, m), :, :]

return arr_filt
26 changes: 17 additions & 9 deletions EXPtools/visuals/visualize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os, sys, pickle, pyEXP
import os, sys, pickle, pyEXP
import numpy as np
import matplotlib.pyplot as plt

def make_basis_plot(basis, rvals, basis_props, **kwargs):
def make_basis_plot(basis, lmax=6, nmax=20,
savefile=None, nsnap='mean', y=0.92, dpi=200,
lrmin=0.5, lrmax=2.7, rnum=100):
Expand All @@ -24,16 +25,27 @@ def make_basis_plot(basis, lmax=6, nmax=20,
tuple: A tuple containing fig and ax.

"""
# Set up grid for plotting potential

lrmin = basis_props['rmin']
lrmax = basis_props['rmax']
rnum = basis_props['nbins']
lmax = basis_props['lmax']
nmax = basis_props['nmax']

halo_grid = basis.getBasis(lrmin, lrmax, rnum)
r = np.linspace(lrmin, lrmax, rnum)
r = np.power(10.0, r)

# Create subplots and plot potential for each l and n

ncols = (lmax-1)//5 + 1

# Create subplots and plot potential for each l and n
fig, ax = plt.subplots(lmax, 1, figsize=(10, 3*lmax), dpi=dpi,
sharex='col', sharey='row')
plt.subplots_adjust(wspace=0, hspace=0)
ax = ax.flatten()

ax = ax.flatten()

for l in range(lmax):
ax[l].set_title(f"$\ell = {l}$", y=0.8, fontsize=16)
for n in range(nmax):
Expand Down Expand Up @@ -174,11 +186,7 @@ def make_grid(gridtype, gridspecs, rgrid, representation='cartesian'):

else:
print('gridtype {} not implemented'.format(gridtype))







def return_fields_in_grid(basis, coefficients, times=[0],
projection='3D', proj_plane=0,
Expand Down
Loading