From e2dbe24fa6afd79ea82a79e2e1bc2fb6112a9119 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Fri, 23 Apr 2021 05:47:42 -0700 Subject: [PATCH] [MRG] Adjust `Index.find` search protocol to support selective collection of matches (#1477) * begin refactoring 'categorize' * have the 'find' function for SBTs return signatures * fix majority of tests * comment & then fix test * torture the tests into working * split find and _find_nodes to take different kinds of functions * redo 'find' on index * refactor lca_db to use new find * refactor SBT to use new find * comment/cleanup * refactor out common code * fix up gather * use 'passes' properly * attempted cleanup * minor fixes * get a start on correct downsampling * adjust tree downsampling for regular minhashes, too * remove now-unused search functions in sbtmh * refactor categorize to use new find * cleanup and removal * remove redundant code in lca_db * remove redundant code in SBT * add notes * remove more unused code * refactor most of the test_sbt tests * fix one minor issue * fix jaccard calculation in sbt * check for compatibility of search fn and query signature * switch tests over to jaccard similarity, not containment * fix test * remove test for unimplemented LCA_Database.find method * document threshold change; update test * refuse to run abund signatures * flatten sigs internally for gather * reinflate abundances for saving * fix problem where sbt indices coudl be created with abund signatures * more * split flat and abund search * make ignore_abundance work again for categorize * turn off best-only, since it triggers on self-hits. * add test: 'sourmash index' flattens sigs * add note about something to test * fix typo; still broken tho * location is now a property * move search code into search.py * remove redundant scaled checking code * best-only now works properly for two tests * 'fix' tests by removing v1 and v2 SBT compatibility * simplify (?) downsampling code * require keyword args in MinHash.downsample(...) * fix bug with downsample * require keyword args in MinHash.downsample(...) * fix test to use proper downsampling, reverse order to match scaled * add test for revealed bug * remove unnecessary comment * flatten subject MinHash, too * add testme comment * clean up sbt find * clean up lca find * add IndexSearchResult namedtuple for search and gather results * add more tests for Index classes * add tests for subj & query num downsampling * tests for Index.search_abund * refactor a bit * refactor make_jaccard_search_query; start tests * even more tests * test collect, best_only * more search tests * remove unnec space * add minor comment * deal with status == None on SystemExit * upgrade and simplify categorize * restore test * merge * fix abundance search in SBT for categorize * code cleanup and refactoring; check for proper error messages * add explicit test for incompatible num * refactor MinHash.downsample * deal with status == None on SystemExit * fix test * fix comment mispelling * properly pass kwargs; fix search_sbt_index * add simple tests for SBT load and search API * allow arbitrary kwargs for LCA_DAtabase.find * add testing of passthru-kwargs * re-enable test * add notes to update docstrings * docstring updates * fix test * better tests for gather --save-unassigned * remove unnecessary check-me comment * clear out docstring * SBT search doesn't work on v1 and v2 SBTs b/c no min_n_below * fix my dumb mistake with gather * have the JaccardSearch.collect function take the matching signature * adjust search protocol to permit ignoring after finding * comment me * update comments, threshold setting Co-authored-by: Luiz Irber --- src/sourmash/index.py | 4 +- src/sourmash/lca/lca_db.py | 9 ++- src/sourmash/sbt.py | 9 ++- src/sourmash/search.py | 20 ++++-- tests/test_index.py | 126 +++++++++++++++++++++++++++++++++++++ tests/test_search.py | 4 +- 6 files changed, 157 insertions(+), 15 deletions(-) diff --git a/src/sourmash/index.py b/src/sourmash/index.py index 7992d7fe30..3fb131cb3d 100644 --- a/src/sourmash/index.py +++ b/src/sourmash/index.py @@ -125,8 +125,8 @@ def prepare_query(query_mh, subj_mh): if search_fn.passes(score): # note: here we yield the original signature, not the # downsampled minhash. - search_fn.collect(score) - yield subj, score + if search_fn.collect(score, subj): + yield subj, score def search_abund(self, query, *, threshold=None, **kwargs): """Return set of matches with angular similarity above 'threshold'. diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py index 0a5fd8a57b..4af77b5a5b 100644 --- a/src/sourmash/lca/lca_db.py +++ b/src/sourmash/lca/lca_db.py @@ -462,9 +462,14 @@ def find(self, search_fn, query, **kwargs): score = search_fn.score_fn(query_size, shared_size, subj_size, total_size) + + # note to self: even with JaccardSearchBestOnly, this will + # still iterate over & score all signatures. We should come + # up with a protocol by which the JaccardSearch object can + # signal that it is done, or something. if search_fn.passes(score): - search_fn.collect(score) - yield subj, score + if search_fn.collect(score, subj): + yield subj, score @cached_property def lid_to_idx(self): diff --git a/src/sourmash/sbt.py b/src/sourmash/sbt.py index fed0bb7a62..af9617235e 100644 --- a/src/sourmash/sbt.py +++ b/src/sourmash/sbt.py @@ -436,9 +436,12 @@ def node_search(node, *args, **kwargs): if search_fn.passes(score): if is_leaf: # terminal node? keep. - results[node.data] = score - search_fn.collect(score) - return True + if search_fn.collect(score, node.data): + results[node.data] = score + return True + else: # it's a good internal node, keep. + return True + return False # & execute! diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 461f0e7d88..0106e8de95 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -69,7 +69,8 @@ def make_gather_query(query_mh, threshold_bp): if threshold > 1.0: return None - search_obj = JaccardSearchBestOnly(SearchType.CONTAINMENT, threshold=threshold) + search_obj = JaccardSearchBestOnly(SearchType.CONTAINMENT, + threshold=threshold) return search_obj @@ -111,14 +112,20 @@ def check_is_compatible(self, sig): raise TypeError("this search cannot be done with an abund signature") def passes(self, score): - "Return True if this score meets or exceeds the threshold." + """Return True if this score meets or exceeds the threshold. + + Note: this can be used whenever a score or estimate is available + (e.g. internal nodes on an SBT). `collect(...)`, below, decides + whether a particular signature should be collected, and/or can + update the threshold (used for BestOnly behavior). + """ if score and score >= self.threshold: return True return False - def collect(self, score): - "Is this a potential match?" - pass + def collect(self, score, match_sig): + "Return True if this match should be collected." + return True def score_jaccard(self, query_size, shared_size, subject_size, total_size): "Calculate Jaccard similarity." @@ -142,9 +149,10 @@ def score_max_containment(self, query_size, shared_size, subject_size, class JaccardSearchBestOnly(JaccardSearch): "A subclass of JaccardSearch that implements best-only." - def collect(self, score): + def collect(self, score, match): "Raise the threshold to the best match found so far." self.threshold = max(self.threshold, score) + return True # generic SearchResult tuple. diff --git a/tests/test_index.py b/tests/test_index.py index 01cadb6cec..2227010eaa 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -14,6 +14,7 @@ from sourmash.sbt import SBT, GraphFactory, Leaf from sourmash.sbtmh import SigLeaf from sourmash import sourmash_args +from sourmash.search import JaccardSearch, SearchType import sourmash_tst_utils as utils @@ -1081,3 +1082,128 @@ def test_multi_index_load_from_pathlist_3_zipfile(c): mi = MultiIndex.load_from_pathlist(file_list) assert len(mi) == 7 + +## +## test a slightly outre version of JaccardSearch - this is a test of the +## JaccardSearch 'collect' protocol, in particular... +## + +class JaccardSearchBestOnly_ButIgnore(JaccardSearch): + "A class that ignores certain results, but still does all the pruning." + def __init__(self, ignore_list): + super().__init__(SearchType.JACCARD, threshold=0.1) + self.ignore_list = ignore_list + + # a collect function that _ignores_ things in the ignore_list + def collect(self, score, match): + print('in collect; current threshold:', self.threshold) + for q in self.ignore_list: + print('ZZZ', match, match.similarity(q)) + if match.similarity(q) == 1.0: + print('yes, found.') + return False + + # update threshold if not perfect match, which could help prune. + self.threshold = score + return True + + +def test_linear_index_gather_ignore(): + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47, ksize=31) + ss63 = sourmash.load_one_signature(sig63, ksize=31) + + # construct an index... + lidx = LinearIndex([ss2, ss47, ss63]) + + # ...now search with something that should ignore sig47, the exact match. + search_fn = JaccardSearchBestOnly_ButIgnore([ss47]) + + results = list(lidx.find(search_fn, ss47)) + results = [ ss for (ss, score) in results ] + + def is_found(ss, xx): + for q in xx: + print(ss, ss.similarity(q)) + if ss.similarity(q) == 1.0: + return True + return False + + assert not is_found(ss47, results) + assert not is_found(ss2, results) + assert is_found(ss63, results) + + +def test_lca_index_gather_ignore(): + from sourmash.lca import LCA_Database + + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47, ksize=31) + ss63 = sourmash.load_one_signature(sig63, ksize=31) + + # construct an index... + db = LCA_Database(ksize=31, scaled=1000) + db.insert(ss2) + db.insert(ss47) + db.insert(ss63) + + # ...now search with something that should ignore sig47, the exact match. + search_fn = JaccardSearchBestOnly_ButIgnore([ss47]) + + results = list(db.find(search_fn, ss47)) + results = [ ss for (ss, score) in results ] + + def is_found(ss, xx): + for q in xx: + print(ss, ss.similarity(q)) + if ss.similarity(q) == 1.0: + return True + return False + + assert not is_found(ss47, results) + assert not is_found(ss2, results) + assert is_found(ss63, results) + + +def test_sbt_index_gather_ignore(): + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47, ksize=31) + ss63 = sourmash.load_one_signature(sig63, ksize=31) + + # construct an index... + factory = GraphFactory(5, 100, 3) + db = SBT(factory, d=2) + + db.insert(ss2) + db.insert(ss47) + db.insert(ss63) + + # ...now search with something that should ignore sig47, the exact match. + print(f'\n** trying to ignore {ss47}') + search_fn = JaccardSearchBestOnly_ButIgnore([ss47]) + + results = list(db.find(search_fn, ss47)) + results = [ ss for (ss, score) in results ] + + def is_found(ss, xx): + for q in xx: + print('is found?', ss, ss.similarity(q)) + if ss.similarity(q) == 1.0: + return True + return False + + assert not is_found(ss47, results) + assert not is_found(ss2, results) + assert is_found(ss63, results) diff --git a/tests/test_search.py b/tests/test_search.py index efe61ea809..d52582b0cc 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -118,13 +118,13 @@ def test_score_jaccard_max_containment_zero_query_size(): def test_collect(): search_obj = make_jaccard_search_query(threshold=0) - search_obj.collect(1.0) + search_obj.collect(1.0, None) assert search_obj.threshold == 0 def test_collect_best_only(): search_obj = make_jaccard_search_query(threshold=0, best_only=True) - search_obj.collect(1.0) + search_obj.collect(1.0, None) assert search_obj.threshold == 1.0