Skip to content

Commit

Permalink
Merge pull request #712 from mgullik/work_on_mask
Browse files Browse the repository at this point in the history
Fix apply_mask
  • Loading branch information
matteobachetti authored Apr 13, 2023
2 parents 43a86cc + 68fcb42 commit e789ddf
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/changes/712.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The method apply_gtis of the class Lightcurve is applied to all the attributes of the class Lightcurve.
This works for both inplace=True and inplace=False
21 changes: 19 additions & 2 deletions stingray/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,18 +1897,35 @@ def apply_mask(self, mask, inplace=False):
self._mask = self._n = None
if inplace:
new_ev = self
# If they don't exist, they get set
self.counts, self.counts_err
# eliminate possible conflicts
self._countrate = self._countrate_err = None
# Set time, counts and errors
self._time = self._time[mask]
self._counts = self._counts[mask]
if self._counts_err is not None:
self._counts_err = self._counts_err[mask]
else:
new_ev = Lightcurve(
time=self.time[mask], counts=self.counts[mask], skip_checks=True, gti=self.gti
)
if self._counts_err is not None:
new_ev.counts_err = self.counts_err[mask]
for attr in self.meta_attrs():
try:
setattr(new_ev, attr, copy.deepcopy(getattr(self, attr)))
except AttributeError:
continue

for attr in array_attrs:
if hasattr(self, "_" + attr) or attr in ["time", "counts"]:
if hasattr(self, "_" + attr) or attr in [
"time",
"counts",
"counts_err",
"_time",
"_counts",
"_counts_err",
]:
continue
if hasattr(self, attr) and getattr(self, attr) is not None:
setattr(new_ev, attr, copy.deepcopy(np.asarray(getattr(self, attr))[mask]))
Expand Down
36 changes: 29 additions & 7 deletions stingray/tests/test_lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,17 +1392,39 @@ def test_change_mjdref(self):
lc_new = self.lc.change_mjdref(57000)
assert lc_new.mjdref == 57000

def testapply_gtis(self):
@pytest.mark.parametrize("inplace", [True, False])
def test_apply_gtis(self, inplace):
time = np.arange(150)
count = np.zeros_like(time) + 3
lc = Lightcurve(time, count, gti=[[-0.5, 150.5]])
lc.gti = [[-0.5, 2.5], [12.5, 14.5]]
lc.apply_gtis()
assert lc.n == 5
assert np.allclose(lc.time, np.array([0, 1, 2, 13, 14]))
lc.gti = [[-0.5, 10.5]]
lc.apply_gtis()
assert np.allclose(lc.time, np.array([0, 1, 2]))
lc_new = lc.apply_gtis(inplace=inplace)
if inplace == True:
assert lc_new is lc
assert lc_new.n == 5
for attr in lc_new.array_attrs():
assert len(getattr(lc_new, attr)) == 5
assert np.allclose(lc_new.time, np.array([0, 1, 2, 13, 14]))

lc_new.gti = [[-0.5, 10.5]]
lc_new2 = lc_new.apply_gtis(inplace=inplace)
assert np.allclose(lc_new2.time, np.array([0, 1, 2]))

@pytest.mark.parametrize("inplace", [True, False])
def test_apply_gtis_lc_rate(self, inplace):
dt = 1
time = np.arange(1, 10, dt)
countrate = np.zeros_like(time) + 5
# create the lightcurve from countrare
lc_rate = Lightcurve(time, counts=countrate, input_counts=False, gti=[[-0.5, 10.5]])
lc_rate.gti = [[-0.5, 2.5]]
lc_rate_new = lc_rate.apply_gtis(inplace=inplace)
if inplace == True:
assert lc_rate_new is lc_rate
assert lc_rate_new.n == 2
for attr in lc_rate_new.array_attrs():
assert len(getattr(lc_rate_new, attr)) == 2
assert np.allclose(lc_rate_new.time, np.array([1, 2]))

def test_eq_operator(self):
time = [1, 2, 3]
Expand Down

0 comments on commit e789ddf

Please sign in to comment.