diff --git a/proseco/guide.py b/proseco/guide.py index 42884d96..b39f1645 100644 --- a/proseco/guide.py +++ b/proseco/guide.py @@ -385,6 +385,16 @@ def select_catalog(self, stage_cands): """ self.log(f"Selecting catalog from {len(stage_cands)} stage-selected stars") + forced = np.in1d(stage_cands["id"], self.include_ids) + n_forced = np.count_nonzero(forced) + if self.n_guide == n_forced: + self.log("All guide stars are force-included") + return stage_cands[forced] + + # This subtracts the number of unique force-included stars as they are not + # part of the combination selection process. + choose_m = min(len(stage_cands), self.n_guide) - n_forced + def index_combinations(n, m): seen = set() for n_tmp in range(m, n + 1): @@ -393,21 +403,18 @@ def index_combinations(n, m): seen.add(comb) yield comb - # Set a dictionary to save the first combination of that satisfies N tests. + # Set a dictionary to save the first combination that satisfies N tests. # This will be used if no combination satisfies all the tests. select_results = {} - # I should come back to this and see if the "min" is needed or if the combinations - # code runs fine even if we have fewer-than-expected stage_cands to start - choose_m = min(len(stage_cands), self.n_guide) - n_tries = 0 + cands_not_forced = stage_cands[~forced] for n_tries, comb in enumerate( - index_combinations(len(stage_cands), choose_m), start=1 + index_combinations(len(cands_not_forced), choose_m), start=1 ): - cands = stage_cands[ - list(comb) - ] # (note that [(1,2)] is not the same as list((1,2)) + all_ids = list(self.include_ids) + list(cands_not_forced["id"][list(comb)]) + cands = stage_cands[np.in1d(stage_cands["id"], all_ids)] + # (note that [(1,2)] is not the same as list((1,2)) n_pass, n_tests = run_select_checks( cands ) # This function knows how many tests get run diff --git a/proseco/tests/test_guide.py b/proseco/tests/test_guide.py index 0c916be9..e3317727 100644 --- a/proseco/tests/test_guide.py +++ b/proseco/tests/test_guide.py @@ -22,6 +22,7 @@ get_ax_range, get_guide_catalog, get_pixmag_for_offset, + run_select_checks, ) from ..report_guide import make_report from .test_common import DARK40, OBS_INFO, STD_INFO, mod_std_info @@ -609,6 +610,45 @@ def test_guides_include_bad(): assert "cannot include star id=20" in str(err) +def test_guides_include_close(): + """ + Test force include stars where they would not be selected due to + clustering. + """ + stars = StarsTable.empty() + + stars.add_fake_constellation( + mag=[7.0, 7.0, 7.0, 7.0, 7.0], id=[25, 26, 27, 28, 29], size=2000, n_stars=5 + ) + + stars.add_fake_star(mag=11.0, yang=100, zang=100, id=21) + stars.add_fake_star(mag=11.0, yang=-100, zang=-100, id=22) + stars.add_fake_star(mag=11.0, yang=100, zang=-100, id=23) + stars.add_fake_star(mag=11.0, yang=-100, zang=100, id=24) + + cat1 = get_guide_catalog(**mod_std_info(n_guide=5), stars=stars) + + # Run the cluster checks and confirm all 3 pass + cat1_pass, _ = run_select_checks(cat1) + assert cat1_pass == 3 + + # Confirm that only bright stars are used + assert np.count_nonzero(cat1["mag"] == 7.0) == 5 + + # Force include the faint 4 stars that are also close together + include_ids = [21, 22, 23, 24] + cat2 = get_guide_catalog( + **mod_std_info(n_guide=5), stars=stars, include_ids_guide=include_ids + ) + + # Run the cluster checks and confirm all 3 fail + cat2_pass, _ = run_select_checks(cat2) + assert cat2_pass == 0 + assert np.all(np.in1d(include_ids, cat2["id"])) + # And confirm that only one of the bright stars is used + assert np.count_nonzero(cat2["mag"] == 7.0) == 1 + + @pytest.mark.parametrize("dither", dither_cases) def test_edge_star(dither): """