Skip to content

Commit

Permalink
Refactor to utils function
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSkyentist committed Jun 29, 2024
1 parent df3e9ba commit b50ae4d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
5 changes: 1 addition & 4 deletions grizli/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2516,7 +2516,6 @@ def xfit_redshift(self, prior=None,
from numpy.polynomial import Polynomial
from scipy.stats import t as student_t
from scipy.special import huber
from scipy.signal import find_peaks

if isinstance(zr, int):
if zr == 0:
Expand Down Expand Up @@ -2657,9 +2656,7 @@ def xfit_redshift(self, prior=None,

if len(zgrid) > 1:
chi2_rev[chi2_rev < 0] = 0
peak_threshold = 0.4 # Threshold for peak finding
peak_height = peak_threshold*(chi2_rev.max()-chi2_rev.min())+chi2_rev.min() # Get absolute height
indexes,_ = find_peaks(chi2_rev,height=peak_height,distance=9)
indexes = utils.find_peaks(chi2_rev,threshold=0.4,min_dist=9)
num_peaks = len(indexes)
so = np.argsort(chi2_rev[indexes])
indexes = indexes[so[::-1]]
Expand Down
6 changes: 1 addition & 5 deletions grizli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5255,12 +5255,8 @@ def init_poly_coeffs(self, poly_order=1, fit_background=True):
# print(utils.NO_NEWLINE + '{0:.4f} {1:9.1f}'.format(zgrid[i], chi2[i]))
#
# # peaks
# from scipy.signal import find_peaks
# chi2nu = (chi2.min()-chi2)/self.DoF
# chi2nu_mask = (chi2nu+0.01)*(chi2nu > -0.004)
# peak_threshold = 0.003 # Threshold for peak finding
# peak_height = peak_threshold*(chi2nu_mask.max()-chi2nu_mask.min())+chi2nu_mask.min() # Get absolute height
# indexes,_ = find_peaks(chi2nu_mask, height=peak_height, distance=21)
# indexes = utils.find_peaks((chi2nu+0.01)*(chi2nu > -0.004), threshold=0.003, min_dist=21)
# num_peaks = len(indexes)
# # plt.plot(zgrid, (chi2-chi2.min())/ self.DoF)
# # plt.scatter(zgrid[indexes], (chi2-chi2.min())[indexes]/ self.DoF, color='r')
Expand Down
10 changes: 2 additions & 8 deletions grizli/multifit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,21 +2812,15 @@ def fit_redshift(self, prior=None, poly_order=1, fwhm=1200,
print('First iteration: z_best={0:.4f}\n'.format(zgrid[iz]))

# peaks
from scipy.signal import find_peaks
# chi2nu = (chi2.min()-chi2)/self.DoF
# chi2nu_mask = (chi2nu+delta_chi2_threshold)*(chi2nu > -delta_chi2_threshold)
# peak_threshold = 0.3 # Threshold for peak finding
# height = peak_threshold*(chi2nu_mask.max()-chi2nu_mask.min())+chi2nu_mask.min() # Get absolute height
# indexes,_ = find_peaks(chi2nu_mask, height=peak_height, distance=21)
# indexes = utils.find_peaks((chi2nu+delta_chi2_threshold)*(chi2nu > - delta_chi2_threshold), threshold=0.3, min_dist=21)

chi2_rev = (chi2_poly - chi2)/self.DoF
if chi2_poly < (chi2.min() + 9):
chi2_rev = (chi2.min() + 16 - chi2)/self.DoF

chi2_rev[chi2_rev < 0] = 0
peak_threshold = 0.4 # Threshold for peak finding
peak_height = peak_threshold*(chi2_rev.max()-chi2_rev.min())+chi2_rev.min() # Get absolute height
indexes,_ = find_peaks(chi2_rev, height=peak_height, distance=9)
indexes = utils.find_peaks(chi2_rev, threshold=0.4, min_dist=9)
num_peaks = len(indexes)

if False:
Expand Down
35 changes: 35 additions & 0 deletions grizli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12898,3 +12898,38 @@ def patch_photutils():
fp.writelines(lines)

print(f"Patch applied to {the_file}!")

def find_peaks(signal,threshold=0.5,min_dist=1):
"""
Find peaks in a signal using `scipy.signal.find_peaks`
Rescales input based on PeakUtils implementation
Parameters
----------
signal : array-like
The input signal
threshold : float
The relative threshold for peak detection
min_dist : int
The minimum distance between peaks
e.g. a distance of 1 are adjacent peaks
differs from the PeakUtils implementation where adjacent peaks are seperated by 0
Returns
-------
peaks : array-like
The indices of the peaks in the signal
"""

# Import required packages
import scipy.signal

# Calculate absolute height
smin = signal.min() # Only calculate this once
height = threshold*(signal.max()-smin)+smin

# Find peaks
peaks,_ = scipy.signal.find_peaks(signal,height=height,distance=min_dist)

return peaks

0 comments on commit b50ae4d

Please sign in to comment.