Skip to content

Commit

Permalink
[MRG] Adjust Index.find search protocol to support selective collec…
Browse files Browse the repository at this point in the history
…tion of matches (#1477)

* begin refactoring 'categorize'

* have the 'find' function for SBTs return signatures

* fix majority of tests

* comment & then fix test

* torture the tests into working

* split find and _find_nodes to take different kinds of functions

* redo 'find' on index

* refactor lca_db to use new find

* refactor SBT to use new find

* comment/cleanup

* refactor out common code

* fix up gather

* use 'passes' properly

* attempted cleanup

* minor fixes

* get a start on correct downsampling

* adjust tree downsampling for regular minhashes, too

* remove now-unused search functions in sbtmh

* refactor categorize to use new find

* cleanup and removal

* remove redundant code in lca_db

* remove redundant code in SBT

* add notes

* remove more unused code

* refactor most of the test_sbt tests

* fix one minor issue

* fix jaccard calculation in sbt

* check for compatibility of search fn and query signature

* switch tests over to jaccard similarity, not containment

* fix test

* remove test for unimplemented LCA_Database.find method

* document threshold change; update test

* refuse to run abund signatures

* flatten sigs internally for gather

* reinflate abundances for saving

* fix problem where sbt indices coudl be created with abund signatures

* more

* split flat and abund search

* make ignore_abundance work again for categorize

* turn off best-only, since it triggers on self-hits.

* add test: 'sourmash index' flattens sigs

* add note about something to test

* fix typo; still broken tho

* location is now a property

* move search code into search.py

* remove redundant scaled checking code

* best-only now works properly for two tests

* 'fix' tests by removing v1 and v2 SBT compatibility

* simplify (?) downsampling code

* require keyword args in MinHash.downsample(...)

* fix bug with downsample

* require keyword args in MinHash.downsample(...)

* fix test to use proper downsampling, reverse order to match scaled

* add test for revealed bug

* remove unnecessary comment

* flatten subject MinHash, too

* add testme comment

* clean up sbt find

* clean up lca find

* add IndexSearchResult namedtuple for search and gather results

* add more tests for Index classes

* add tests for subj & query num downsampling

* tests for Index.search_abund

* refactor a bit

* refactor make_jaccard_search_query; start tests

* even more tests

* test collect, best_only

* more search tests

* remove unnec space

* add minor comment

* deal with status == None on SystemExit

* upgrade and simplify categorize

* restore test

* merge

* fix abundance search in SBT for categorize

* code cleanup and refactoring; check for proper error messages

* add explicit test for incompatible num

* refactor MinHash.downsample

* deal with status == None on SystemExit

* fix test

* fix comment mispelling

* properly pass kwargs; fix search_sbt_index

* add simple tests for SBT load and search API

* allow arbitrary kwargs for LCA_DAtabase.find

* add testing of passthru-kwargs

* re-enable test

* add notes to update docstrings

* docstring updates

* fix test

* better tests for gather --save-unassigned

* remove unnecessary check-me comment

* clear out docstring

* SBT search doesn't work on v1 and v2 SBTs b/c no min_n_below

* fix my dumb mistake with gather

* have the JaccardSearch.collect function take the matching signature

* adjust search protocol to permit ignoring after finding

* comment me

* update comments, threshold setting

Co-authored-by: Luiz Irber <[email protected]>
  • Loading branch information
ctb and luizirber authored Apr 23, 2021
1 parent f02e250 commit e2dbe24
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/sourmash/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def prepare_query(query_mh, subj_mh):
if search_fn.passes(score):
# note: here we yield the original signature, not the
# downsampled minhash.
search_fn.collect(score)
yield subj, score
if search_fn.collect(score, subj):
yield subj, score

def search_abund(self, query, *, threshold=None, **kwargs):
"""Return set of matches with angular similarity above 'threshold'.
Expand Down
9 changes: 7 additions & 2 deletions src/sourmash/lca/lca_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,9 +462,14 @@ def find(self, search_fn, query, **kwargs):

score = search_fn.score_fn(query_size, shared_size, subj_size,
total_size)

# note to self: even with JaccardSearchBestOnly, this will
# still iterate over & score all signatures. We should come
# up with a protocol by which the JaccardSearch object can
# signal that it is done, or something.
if search_fn.passes(score):
search_fn.collect(score)
yield subj, score
if search_fn.collect(score, subj):
yield subj, score

@cached_property
def lid_to_idx(self):
Expand Down
9 changes: 6 additions & 3 deletions src/sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,12 @@ def node_search(node, *args, **kwargs):

if search_fn.passes(score):
if is_leaf: # terminal node? keep.
results[node.data] = score
search_fn.collect(score)
return True
if search_fn.collect(score, node.data):
results[node.data] = score
return True
else: # it's a good internal node, keep.
return True

return False

# & execute!
Expand Down
20 changes: 14 additions & 6 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def make_gather_query(query_mh, threshold_bp):
if threshold > 1.0:
return None

search_obj = JaccardSearchBestOnly(SearchType.CONTAINMENT, threshold=threshold)
search_obj = JaccardSearchBestOnly(SearchType.CONTAINMENT,
threshold=threshold)

return search_obj

Expand Down Expand Up @@ -111,14 +112,20 @@ def check_is_compatible(self, sig):
raise TypeError("this search cannot be done with an abund signature")

def passes(self, score):
"Return True if this score meets or exceeds the threshold."
"""Return True if this score meets or exceeds the threshold.
Note: this can be used whenever a score or estimate is available
(e.g. internal nodes on an SBT). `collect(...)`, below, decides
whether a particular signature should be collected, and/or can
update the threshold (used for BestOnly behavior).
"""
if score and score >= self.threshold:
return True
return False

def collect(self, score):
"Is this a potential match?"
pass
def collect(self, score, match_sig):
"Return True if this match should be collected."
return True

def score_jaccard(self, query_size, shared_size, subject_size, total_size):
"Calculate Jaccard similarity."
Expand All @@ -142,9 +149,10 @@ def score_max_containment(self, query_size, shared_size, subject_size,

class JaccardSearchBestOnly(JaccardSearch):
"A subclass of JaccardSearch that implements best-only."
def collect(self, score):
def collect(self, score, match):
"Raise the threshold to the best match found so far."
self.threshold = max(self.threshold, score)
return True


# generic SearchResult tuple.
Expand Down
126 changes: 126 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sourmash.sbt import SBT, GraphFactory, Leaf
from sourmash.sbtmh import SigLeaf
from sourmash import sourmash_args
from sourmash.search import JaccardSearch, SearchType

import sourmash_tst_utils as utils

Expand Down Expand Up @@ -1081,3 +1082,128 @@ def test_multi_index_load_from_pathlist_3_zipfile(c):

mi = MultiIndex.load_from_pathlist(file_list)
assert len(mi) == 7

##
## test a slightly outre version of JaccardSearch - this is a test of the
## JaccardSearch 'collect' protocol, in particular...
##

class JaccardSearchBestOnly_ButIgnore(JaccardSearch):
"A class that ignores certain results, but still does all the pruning."
def __init__(self, ignore_list):
super().__init__(SearchType.JACCARD, threshold=0.1)
self.ignore_list = ignore_list

# a collect function that _ignores_ things in the ignore_list
def collect(self, score, match):
print('in collect; current threshold:', self.threshold)
for q in self.ignore_list:
print('ZZZ', match, match.similarity(q))
if match.similarity(q) == 1.0:
print('yes, found.')
return False

# update threshold if not perfect match, which could help prune.
self.threshold = score
return True


def test_linear_index_gather_ignore():
sig2 = utils.get_test_data('2.fa.sig')
sig47 = utils.get_test_data('47.fa.sig')
sig63 = utils.get_test_data('63.fa.sig')

ss2 = sourmash.load_one_signature(sig2, ksize=31)
ss47 = sourmash.load_one_signature(sig47, ksize=31)
ss63 = sourmash.load_one_signature(sig63, ksize=31)

# construct an index...
lidx = LinearIndex([ss2, ss47, ss63])

# ...now search with something that should ignore sig47, the exact match.
search_fn = JaccardSearchBestOnly_ButIgnore([ss47])

results = list(lidx.find(search_fn, ss47))
results = [ ss for (ss, score) in results ]

def is_found(ss, xx):
for q in xx:
print(ss, ss.similarity(q))
if ss.similarity(q) == 1.0:
return True
return False

assert not is_found(ss47, results)
assert not is_found(ss2, results)
assert is_found(ss63, results)


def test_lca_index_gather_ignore():
from sourmash.lca import LCA_Database

sig2 = utils.get_test_data('2.fa.sig')
sig47 = utils.get_test_data('47.fa.sig')
sig63 = utils.get_test_data('63.fa.sig')

ss2 = sourmash.load_one_signature(sig2, ksize=31)
ss47 = sourmash.load_one_signature(sig47, ksize=31)
ss63 = sourmash.load_one_signature(sig63, ksize=31)

# construct an index...
db = LCA_Database(ksize=31, scaled=1000)
db.insert(ss2)
db.insert(ss47)
db.insert(ss63)

# ...now search with something that should ignore sig47, the exact match.
search_fn = JaccardSearchBestOnly_ButIgnore([ss47])

results = list(db.find(search_fn, ss47))
results = [ ss for (ss, score) in results ]

def is_found(ss, xx):
for q in xx:
print(ss, ss.similarity(q))
if ss.similarity(q) == 1.0:
return True
return False

assert not is_found(ss47, results)
assert not is_found(ss2, results)
assert is_found(ss63, results)


def test_sbt_index_gather_ignore():
sig2 = utils.get_test_data('2.fa.sig')
sig47 = utils.get_test_data('47.fa.sig')
sig63 = utils.get_test_data('63.fa.sig')

ss2 = sourmash.load_one_signature(sig2, ksize=31)
ss47 = sourmash.load_one_signature(sig47, ksize=31)
ss63 = sourmash.load_one_signature(sig63, ksize=31)

# construct an index...
factory = GraphFactory(5, 100, 3)
db = SBT(factory, d=2)

db.insert(ss2)
db.insert(ss47)
db.insert(ss63)

# ...now search with something that should ignore sig47, the exact match.
print(f'\n** trying to ignore {ss47}')
search_fn = JaccardSearchBestOnly_ButIgnore([ss47])

results = list(db.find(search_fn, ss47))
results = [ ss for (ss, score) in results ]

def is_found(ss, xx):
for q in xx:
print('is found?', ss, ss.similarity(q))
if ss.similarity(q) == 1.0:
return True
return False

assert not is_found(ss47, results)
assert not is_found(ss2, results)
assert is_found(ss63, results)
4 changes: 2 additions & 2 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ def test_score_jaccard_max_containment_zero_query_size():

def test_collect():
search_obj = make_jaccard_search_query(threshold=0)
search_obj.collect(1.0)
search_obj.collect(1.0, None)
assert search_obj.threshold == 0


def test_collect_best_only():
search_obj = make_jaccard_search_query(threshold=0, best_only=True)
search_obj.collect(1.0)
search_obj.collect(1.0, None)
assert search_obj.threshold == 1.0


Expand Down

0 comments on commit e2dbe24

Please sign in to comment.