Skip to content

Commit

Permalink
Merge pull request #83 from karllark/fit_model_updates
Browse files Browse the repository at this point in the history
Updates to make fitting using a model more robust
  • Loading branch information
karllark authored Aug 25, 2021
2 parents b801194 + a7a3858 commit d88e8e8
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 116 deletions.
215 changes: 118 additions & 97 deletions measure_extinction/extdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,34 +414,22 @@ def calc_elx(self, redstar, compstar, rel_band="V"):
else:
self.calc_elx_spectra(redstar, compstar, cursrc, rel_band=rel_band)

def trans_elv_elvebv(self):
def calc_EBV(self):
"""
Transform E(lambda-V) to E(lambda -V)/E(B-V) by
normalizing by E(lambda-B).
Parameters
----------
Calculate E(B-V) from the observed extinction curve
Returns
-------
Updates self.(exts, uncs)
Updates self.columns["EBV"]
"""
if self.type_rel_band != "V":
warnings.warn("attempt to normalize a non-elv curve with ebv", UserWarning)
# determine the index for the B band
dwaves = np.absolute(self.waves["BAND"] - 0.438 * u.micron)
sindxs = np.argsort(dwaves)
bindx = sindxs[0]
if dwaves[bindx] > 0.02 * u.micron:
warnings.warn("no B band measurement in E(l-V)", UserWarning)
else:
# determine the index for the B band
dwaves = np.absolute(self.waves["BAND"] - 0.438 * u.micron)
sindxs = np.argsort(dwaves)
bindx = sindxs[0]
if dwaves[bindx] > 0.02 * u.micron:
warnings.warn("no B band measurement in E(l-V)", UserWarning)
else:
# normalize each portion of the extinction curve
ebv = self.exts["BAND"][bindx]
for curname in self.exts.keys():
self.exts[curname] /= ebv
self.uncs[curname] /= ebv
self.type = "elvebv"
self.columns["EBV"] = self.exts["BAND"][bindx]

def calc_AV(self, akav=0.112):
"""
Expand All @@ -468,7 +456,7 @@ def calc_AV(self, akav=0.112):
else:
dwaves = np.absolute(self.waves["BAND"] - 2.19 * u.micron)
kindx = dwaves.argmin()
if dwaves[kindx] > 0.02 * u.micron:
if dwaves[kindx] > 0.04 * u.micron:
warnings.warn(
"No K band measurement available in E(lambda-V)!", stacklevel=2
)
Expand All @@ -492,16 +480,46 @@ def calc_RV(self):
self.calc_AV()
av = _get_column_val(self.columns["AV"])

# obtain E(B-V)
dwaves = np.absolute(self.waves["BAND"] - 0.438 * u.micron)
bindx = dwaves.argmin()
if dwaves[bindx] > 0.02 * u.micron:
# obtain or calculate E(B-V)
if "EBV" not in self.columns.keys():
self.calc_EBV()
ebv = _get_column_val(self.columns["EBV"])

self.columns["RV"] = av / ebv

def trans_elv_elvebv(self, ebv=None):
"""
Transform E(lambda-V) to E(lambda -V)/E(B-V) by
normalizing by E(B-V)).
Parameters
----------
ebv : float [default = None]
value of E(B-V) to use - otherwise take it from the columns of the object
or calculate it from the E(lambda-V) curve
Returns
-------
Updates self.(exts, uncs)
"""
if self.type_rel_band != "V":
warnings.warn(
"attempt to normalize a non E(lambda-V) curve with E(B-V)", UserWarning
)
elif self.type != "elx":
warnings.warn(
"No B band measurement available in E(lambda-V)!", stacklevel=2
"attempt to normalize a non E(lambda-V) curve with E(B-V)", UserWarning
)
else:
ebv = self.exts["BAND"][bindx]
self.columns["RV"] = av / ebv
if ebv is None:
if "EBV" not in self.columns.keys():
self.calc_EBV()
ebv = _get_column_val(self.columns["EBV"])

for curname in self.exts.keys():
self.exts[curname] /= ebv
self.uncs[curname] /= ebv
self.type = "elvebv"

def trans_elv_alav(self, av=None, akav=0.112):
"""
Expand All @@ -525,6 +543,10 @@ def trans_elv_alav(self, av=None, akav=0.112):
warnings.warn(
"attempt to normalize a non-E(lambda-V) curve with A(V)", UserWarning
)
elif self.type != "elx":
warnings.warn(
"attempt to normalize a non E(lambda-V) curve with A(V)", UserWarning
)
else:
if av is None:
if "AV" not in self.columns.keys():
Expand All @@ -534,7 +556,7 @@ def trans_elv_alav(self, av=None, akav=0.112):
self.exts[curname] = (self.exts[curname] / av) + 1
self.uncs[curname] /= av
# update the extinction curve type
self.type = "alav"
self.type = "alax"

def get_fitdata(
self,
Expand Down Expand Up @@ -681,60 +703,32 @@ def save(
else:
print(ckey + " not supported for saving extcurves")
else: # save the column info if available in the extdata object
if "AV" in self.columns.keys():
hname.append("AV")
hcomment.append("V-band extinction A(V)")
if isinstance(self.columns["AV"], tuple):
hval.append(self.columns["AV"][0])
if len(self.columns["AV"]) == 2:
hname.append("AV_UNC")
hcomment.append("A(V) uncertainty")
hval.append(self.columns["AV"][1])
elif len(self.columns["AV"]) == 3:
hname.append("AV_MUNC")
hcomment.append("A(V) lower uncertainty")
hval.append(self.columns["AV"][1])
hname.append("AV_PUNC")
hcomment.append("A(V) upper uncertainty")
hval.append(self.columns["AV"][2])
else:
hval.append(self.columns["AV"])
if "RV" in self.columns.keys():
hname.append("RV")
hcomment.append("total-to-selective extintion R(V)")
if isinstance(self.columns["RV"], tuple):
hval.append(self.columns["RV"][0])
if len(self.columns["RV"]) == 2:
hname.append("RV_UNC")
hcomment.append("R(V) uncertainty")
hval.append(self.columns["RV"][1])
elif len(self.columns["RV"]) == 3:
hname.append("RV_MUNC")
hcomment.append("R(V) lower uncertainty")
hval.append(self.columns["RV"][1])
hname.append("RV_PUNC")
hcomment.append("R(V) upper uncertainty")
hval.append(self.columns["RV"][2])
else:
hval.append(self.columns["RV"])
if "EBV" in self.columns.keys():
hname.append("EBV")
hcomment.append("color excess E(B-V)")
if isinstance(self.columns["EBV"], tuple):
hval.append(self.columns["EBV"][0])
if len(self.columns["EBV"]) == 2:
hname.append("EBV_UNC")
hcomment.append("E(B-V) uncertainty")
hval.append(self.columns["EBV"][1])
elif len(self.columns["EBV"]) == 3:
hname.append("EBV_MUNC")
hcomment.append("E(B-V) lower uncertainty")
hval.append(self.columns["EBV"][1])
hname.append("EBV_PUNC")
hcomment.append("E(B-V) upper uncertainty")
hval.append(self.columns["EBV"][2])
else:
hval.append(self.columns["EBV"])
colkeys = ["AV", "RV", "EBV", "LOGHI"]
colinfo = [
"V-band extinction A(V)",
"total-to-selective extintion R(V)",
"color excess E(B-V)",
"log10 of the HI column density N(HI)",
]
for i, ckey in enumerate(colkeys):
if ckey in self.columns.keys():
hname.append(f"{ckey}")
hcomment.append(f"{colinfo[i]}")
if isinstance(self.columns[f"{ckey}"], tuple):
hval.append(self.columns[f"{ckey}"][0])
if len(self.columns[f"{ckey}"]) == 2:
hname.append(f"{ckey}_UNC")
hcomment.append(f"{ckey} uncertainty")
hval.append(self.columns[f"{ckey}"][1])
elif len(self.columns[f"{ckey}"]) == 3:
hname.append(f"{ckey}_MUNC")
hcomment.append(f"{ckey} lower uncertainty")
hval.append(self.columns[f"{ckey}"][1])
hname.append(f"{ckey}_PUNC")
hcomment.append(f"{ckey} upper uncertainty")
hval.append(self.columns[f"{ckey}"][2])
else:
hval.append(self.columns[f"{ckey}"])

# legacy save param keywords
if fm90_best_params is not None:
Expand Down Expand Up @@ -805,7 +799,9 @@ def save(
# write the portions of the extinction curve from each dataset
# individual extensions so that the full info is perserved
for curname in self.exts.keys():
col1 = fits.Column(name="WAVELENGTH", format="E", array=self.waves[curname])
col1 = fits.Column(
name="WAVELENGTH", format="E", array=self.waves[curname].to(u.micron)
)
col2 = fits.Column(name="EXT", format="E", array=self.exts[curname])
col3 = fits.Column(name="UNC", format="E", array=self.uncs[curname])
col4 = fits.Column(name="NPTS", format="E", array=self.npts[curname])
Expand All @@ -827,7 +823,11 @@ def save(

# write the fitted model if available
if self.model:
col1 = fits.Column(name="MOD_WAVE", format="E", array=self.model["waves"])
if isinstance(self.model["waves"], u.Quantity):
outvals = self.model["waves"].to(u.micron)
else:
outvals = self.model["waves"]
col1 = fits.Column(name="MOD_WAVE", format="E", array=outvals)
col2 = fits.Column(name="MOD_EXT", format="E", array=self.model["exts"])
col3 = fits.Column(
name="RESIDUAL", format="E", array=self.model["residuals"]
Expand Down Expand Up @@ -914,8 +914,9 @@ def read(self, ext_filename):
for curkey in column_keys:
if pheader.get(curkey):
if pheader.get("%s_UNC" % curkey):
tunc = float(pheader.get(curkey)), float(
pheader.get("%s_UNC" % curkey)
tunc = (
float(pheader.get(curkey)),
float(pheader.get("%s_UNC" % curkey)),
)
elif pheader.get("%s_PUNC" % curkey):
tunc = (
Expand Down Expand Up @@ -1090,6 +1091,7 @@ def plot(
legend_key=None,
legend_label=None,
fontsize=None,
model=False,
):
"""
Plot an extinction curve
Expand Down Expand Up @@ -1149,6 +1151,9 @@ def plot(
fontsize : int [default=None]
fontsize for plot
model : boolean
if set and the model exists, plot it
"""
if alax:
# transform the extinctions from E(lambda-V) to A(lambda)/A(V)
Expand All @@ -1159,7 +1164,8 @@ def plot(
if curtype in exclude:
continue
# replace extinction values by NaNs for wavelength regions that need to be excluded from the plot
self.exts[curtype][self.npts[curtype] == 0] = np.nan
if np.sum(self.npts[curtype] == 0) > 0:
self.exts[curtype][self.npts[curtype] == 0] = np.nan
x = self.waves[curtype].to(u.micron).value
y = self.exts[curtype]
yu = self.uncs[curtype]
Expand Down Expand Up @@ -1206,6 +1212,24 @@ def plot(
fontsize=fontsize,
)

# plot the model if desired
if model:
x = self.model["waves"]
if wavenum:
x = 1.0 / x
y = self.model["exts"]

y = y / normval + yoffset

pltax.plot(x, y, "-", color=color, alpha=alpha)

if wavenum:
xtitle = r"$1/\lambda$ $[\mu m^{-1}]$"
else:
xtitle = r"$\lambda$ $[\mu m]$"
pltax.set_xlabel(xtitle)
pltax.set_ylabel(self._get_ext_ytitle())

def fit_band_ext(self):
"""
Fit the observed NIR extinction curve with a powerlaw model, based on the band data between 1 and 40 micron
Expand Down Expand Up @@ -1296,13 +1320,10 @@ def fit_spex_ext(
bounds={"amplitude": amp_bounds, "alpha": index_bounds},
)
else:
func = (
PowerLaw1D(
fixed={"x_0": True},
bounds={"amplitude": amp_bounds, "alpha": index_bounds},
)
| AxAvToExv(bounds={"Av": AV_bounds})
)
func = PowerLaw1D(
fixed={"x_0": True},
bounds={"amplitude": amp_bounds, "alpha": index_bounds},
) | AxAvToExv(bounds={"Av": AV_bounds})

fit = LevMarLSQFitter()
fit_result = fit(func, waves, exts, weights=1 / exts_unc)
Expand Down
4 changes: 2 additions & 2 deletions measure_extinction/tests/test_plot_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_plot_extinction():
message + "with the default settings, has failed."
)

assert os.path.isfile(data_path + starpair.lower() + "_ext_alav.pdf"), (
assert os.path.isfile(data_path + starpair.lower() + "_ext_alax.pdf"), (
message + "in A(lambda)/A(V), has failed."
)

Expand All @@ -92,7 +92,7 @@ def test_plot_extinction():
message + "in one figure, with the default settings, has failed."
)

assert os.path.isfile(data_path + "all_ext_alav.pdf"), (
assert os.path.isfile(data_path + "all_ext_alax.pdf"), (
message + "in one figure, in A(lambda)/A(V), has failed."
)

Expand Down
Loading

0 comments on commit d88e8e8

Please sign in to comment.