Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Remove min_n_below from search code #1137

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/sourmash.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ void kmerminhash_slice_free(uint64_t *ptr, uintptr_t insize);

bool kmerminhash_track_abundance(const SourmashKmerMinHash *ptr);

double nodegraph_angular_similarity_upper_bound(const SourmashNodegraph *ptr,
const SourmashKmerMinHash *mh_ptr);

void nodegraph_buffer_free(uint8_t *ptr, uintptr_t insize);

bool nodegraph_count(SourmashNodegraph *ptr, uint64_t h);
Expand Down
11 changes: 11 additions & 0 deletions src/core/src/ffi/nodegraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ pub unsafe extern "C" fn nodegraph_matches(
ng.matches(mh)
}

ffi_fn! {
    // FFI entry point for `angular_similarity_upper_bound`.
    //
    // Converts both raw pointers with `as_rust` and forwards the call; the
    // `Result` is handed back to the `ffi_fn!` machinery unchanged.
    // NOTE(review): presumably both pointers must be valid and non-null for
    // `as_rust` -- confirm against its contract.
    unsafe fn nodegraph_angular_similarity_upper_bound(
        ptr: *const SourmashNodegraph,
        mh_ptr: *const SourmashKmerMinHash,
    ) -> Result<f64> {
        SourmashNodegraph::as_rust(ptr)
            .angular_similarity_upper_bound(SourmashKmerMinHash::as_rust(mh_ptr))
    }
}

#[no_mangle]
pub unsafe extern "C" fn nodegraph_update(
ptr: *mut SourmashNodegraph,
Expand Down
39 changes: 39 additions & 0 deletions src/core/src/sketch/nodegraph.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::f64::consts::PI;
use std::fs::File;
use std::io;
use std::path::Path;
Expand Down Expand Up @@ -154,6 +155,44 @@ impl Nodegraph {
mh.iter_mins().filter(|x| self.get(**x) == 1).count()
}

/// Upper-bound estimate of the angular similarity with an abundance MinHash.
///
/// Every hash of `other` that is present in this nodegraph is treated as if
/// it matched with exactly the abundance recorded in `other`; hashes that are
/// absent contribute nothing. The cosine of the resulting vectors is turned
/// into an angular similarity (`1 - 2 * acos(cos) / PI`).
///
/// Note that this is not a *tight* bound: it might overestimate a lot.
/// The current goal is guaranteeing it doesn't *underestimate*
/// (and, if possible, it doesn't overestimate all the time).
///
/// Returns `Ok(0.0)` when no hash of `other` is present in the nodegraph (or
/// all abundances are zero). Otherwise a small epsilon (`1e-3`) is added as
/// leeway, so the returned value can be slightly above `1.0`.
///
/// # Panics
///
/// Panics if `other` does not track abundance.
pub fn angular_similarity_upper_bound(&self, other: &KmerMinHash) -> Result<f64, Error> {
    if !other.track_abundance() {
        // TODO: return a proper error instead of panicking, we need abundance for this
        unimplemented!("angular similarity upper bound requires a MinHash that tracks abundance")
    }

    // Squared abundances are accumulated as f64: `abund * abund` (and the
    // running sums) overflow u64 once abundances get large, which panics in
    // debug builds and silently wraps in release builds.
    let mut matched_sq = 0_f64; // sum of abund^2 over hashes present in the nodegraph
    let mut b_sq = 0_f64; // sum of abund^2 over all hashes of `other`

    for (hash, abund) in other.to_vec_abunds() {
        let sq = (abund as f64) * (abund as f64);
        b_sq += sq;
        if self.get(hash) == 1 {
            // TODO: which one overestimate less? (abund^2, as here, or 1 per match)
            matched_sq += sq;
        }
    }

    let norm_a = matched_sq.sqrt();
    let norm_b = b_sq.sqrt();

    if norm_a == 0. || norm_b == 0. {
        return Ok(0.0);
    }

    // The dot product equals `matched_sq` because the nodegraph side is
    // assumed to carry the same abundance as `other` for every match.
    let cosine = f64::min(matched_sq / (norm_a * norm_b), 1.);
    let distance = 2. * cosine.acos() / PI;
    // Adding some leeway with a small epsilon
    Ok(1. - distance + 1e-3)
}

pub fn ntables(&self) -> usize {
self.bs.len()
}
Expand Down
8 changes: 8 additions & 0 deletions src/sourmash/nodegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ def matches(self, mh):

return self._methodcall(lib.nodegraph_matches, mh._objptr)

def angular_similarity_upper_bound(self, mh):
    """Return the upper bound on angular similarity between this nodegraph and `mh`.

    `mh` must be a MinHash; the computation is delegated to
    `lib.nodegraph_angular_similarity_upper_bound`.

    Raises ValueError when `mh` is not a MinHash.
    """
    if isinstance(mh, MinHash):
        return self._methodcall(lib.nodegraph_angular_similarity_upper_bound, mh._objptr)

    # FIXME: we could take sets here too (or anything that can be
    # converted to a list of ints...)
    raise ValueError("mh must be a MinHash")

def to_khmer_nodegraph(self):
import khmer
try:
Expand Down
34 changes: 0 additions & 34 deletions src/sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,8 +922,6 @@ def _load_v3(cls, info, leaf_loader, dirname, storage, *, print_version_warning=
error("WARNING: this is an old index version, please run `sourmash migrate` to update it.")
error("WARNING: proceeding with execution, but it will take longer to finish!")

tree._fill_min_n_below()

return tree

@classmethod
Expand Down Expand Up @@ -1042,32 +1040,6 @@ def _load_v6(cls, info, leaf_loader, dirname, storage, *, print_version_warning=

return tree

def _fill_min_n_below(self):
"""\
Propagate the smallest hash size below each node up the tree from
the leaves.
"""
# Callback handed to self._fill_up; it recomputes 'min_n_below' for one node
# from the children passed in kwargs['children'].
def fill_min_n_below(node, *args, **kwargs):
# Start from the value already stored on the node (sys.maxsize when unset).
original_min_n_below = node.metadata.get('min_n_below', sys.maxsize)
min_n_below = original_min_n_below

children = kwargs['children']
for child in children:
# Children without a loaded node are ignored.
if child.node is not None:
if isinstance(child.node, Leaf):
# Leaves contribute the size of their own minhash.
min_n_below = min(len(child.node.data.minhash), min_n_below)
else:
# Internal nodes contribute the value recorded in their metadata.
child_n = child.node.metadata.get('min_n_below', sys.maxsize)
min_n_below = min(child_n, min_n_below)

# A minimum of 0 is stored as 1 -- presumably so the value can be used as a
# divisor by the search code; confirm against callers.
if min_n_below == 0:
min_n_below = 1

node.metadata['min_n_below'] = min_n_below
# True when the stored value changed for this node.
return original_min_n_below != min_n_below

self._fill_up(fill_min_n_below)

def _fill_internal(self):

def fill_nodegraphs(node, *args, **kwargs):
Expand Down Expand Up @@ -1266,12 +1238,6 @@ def load(info, storage=None):

def update(self, parent):
    """Merge this node's data into `parent` and carry `min_n_below` upward.

    `parent.data` is always updated with `self.data`. When this node records
    a 'min_n_below' value, the parent's value becomes the smaller of the two
    (a result of 0 is stored as 1).
    """
    parent.data.update(self.data)
    if 'min_n_below' not in self.metadata:
        return

    inherited = parent.metadata.get('min_n_below', sys.maxsize)
    own = self.metadata.get('min_n_below')
    smallest = min(inherited, own)
    parent.metadata['min_n_below'] = 1 if smallest == 0 else smallest


class Leaf(object):
Expand Down
43 changes: 19 additions & 24 deletions src/sourmash/sbtmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,6 @@ def save(self, path):
def update(self, parent):
    """Merge this leaf's minhash into `parent` and refresh its `min_n_below`.

    The parent's 'min_n_below' becomes the smaller of its current value
    (sys.maxsize when unset) and the size of this leaf's minhash; a result
    of 0 is stored as 1.
    """
    sketch = self.data.minhash
    parent.data.update(sketch)

    inherited = parent.metadata.get('min_n_below', sys.maxsize)
    smallest = min(len(sketch), inherited)
    parent.metadata['min_n_below'] = 1 if smallest == 0 else smallest

@property
def data(self):
Expand All @@ -73,29 +66,33 @@ def data(self, new_data):

### Search functionality.

def _max_jaccard_underneath_internal_node(node, mh):
def _max_jaccard_underneath_internal_node(node, query):
"""\
calculate the maximum possible similarity score below
this node, based on the number of matches in 'hashes' at this node,
divided by the smallest minhash size below this node.
divided by the size of the query.

This should yield an upper bound on the Jaccard similarity
for any signature below this point.
"""
mh = query.minhash

if len(mh) == 0:
return 0.0

# count the maximum number of hash matches beneath this node
matches = node.data.matches(mh)
if mh.track_abundance:
# In this case we need to use the upper bound for angular similarity
max_score = node.data.angular_similarity_upper_bound(mh)
else:
# In this case we are working with similarity/containment:
# J(A, B) = |A intersection B| / |A union B|
# If we use only |A| as denominator, it is the containment
# Because |A| <= |A union B|, it is also an upper bound on the max jaccard

# get the size of the smallest collection of hashes below this point
min_n_below = node.metadata.get('min_n_below', -1)

if min_n_below == -1:
raise Exception('cannot do similarity search on this SBT; need to rebuild.')
# count the maximum number of hash matches beneath this node
matches = node.data.matches(mh)

# max of numerator divided by min of denominator => max Jaccard
max_score = float(matches) / min_n_below
max_score = float(matches) / len(mh)

return max_score

Expand All @@ -106,13 +103,12 @@ def search_minhashes(node, sig, threshold, results=None):
"""
assert results is None

sig_mh = sig.minhash
score = 0

if isinstance(node, SigLeaf):
score = node.data.minhash.similarity(sig_mh)
score = node.data.minhash.similarity(sig.minhash)
else: # Node minhash comparison
score = _max_jaccard_underneath_internal_node(node, sig_mh)
score = _max_jaccard_underneath_internal_node(node, sig)

if score >= threshold:
return 1
Expand All @@ -126,13 +122,12 @@ def __init__(self):

def search(self, node, sig, threshold, results=None):
assert results is None
sig_mh = sig.minhash
score = 0

if isinstance(node, SigLeaf):
score = node.data.minhash.similarity(sig_mh)
score = node.data.minhash.similarity(sig.minhash)
else: # internal object, not leaf.
score = _max_jaccard_underneath_internal_node(node, sig_mh)
score = _max_jaccard_underneath_internal_node(node, sig)

if score >= threshold:
# have we done better than this elsewhere? if yes, truncate.
Expand Down
1 change: 1 addition & 0 deletions tests/test-data/min_n_below/HSMA33OT.fastq.gz.sig

Large diffs are not rendered by default.

Binary file added tests/test-data/min_n_below/index.sbt.zip
Binary file not shown.
29 changes: 29 additions & 0 deletions tests/test_sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,3 +940,32 @@ def test_sbt_node_cache():

assert tree._nodescache.currsize == 1
assert tree._nodescache.currsize == 1


@utils.in_thisdir
def test_sbt_min_n_below_removal_abundance(c):
    """Search and gather with an abundance query against the min_n_below test SBT."""
    query = utils.get_test_data('min_n_below/HSMA33OT.fastq.gz.sig')
    index = utils.get_test_data('min_n_below/index.sbt.zip')

    # similarity search
    c.run_sourmash('search', query, index, '--threshold', '0.085', '-k', '51')
    search_out = c.last_result.out
    assert '17 matches;' in search_out

    # gather
    c.run_sourmash('gather', query, index, '-k', '51')
    gather_out = c.last_result.out
    assert 'found 8 matches total' in gather_out
    assert 'the recovered matches hit 28.0% of the query' in gather_out


@utils.in_thisdir
def test_sbt_min_n_below_removal_noabundance(c):
    """Same checks as the abundance test, but with the query flattened first."""
    abund_query = utils.get_test_data('min_n_below/HSMA33OT.fastq.gz.sig')
    flat_query = c.output("HSMA_flat.sig")
    index = utils.get_test_data('min_n_below/index.sbt.zip')

    # drop abundances from the query signature
    c.run_sourmash("sig", "flatten", "-o", flat_query, abund_query)

    # similarity search
    c.run_sourmash('search', flat_query, index, '--threshold', '0.085', '-k', '51')
    search_out = c.last_result.out
    assert '10 matches;' in search_out

    # gather
    c.run_sourmash('gather', flat_query, index, '-k', '51')
    gather_out = c.last_result.out
    assert 'found 8 matches total' in gather_out
    assert 'the recovered matches hit 15.1% of the query' in gather_out
5 changes: 0 additions & 5 deletions tests/test_sourmash.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,6 @@ def test_do_sourmash_sbt_search_check_bug():
assert '1 matches:' in out

tree = load_sbt_index(os.path.join(location, 'zzz.sbt.zip'))
assert tree._nodes[0].metadata['min_n_below'] == 431


def test_do_sourmash_sbt_search_empty_sig():
Expand All @@ -1262,7 +1261,6 @@ def test_do_sourmash_sbt_search_empty_sig():
assert '1 matches:' in out

tree = load_sbt_index(os.path.join(location, 'zzz.sbt.zip'))
assert tree._nodes[0].metadata['min_n_below'] == 1


def test_do_sourmash_sbt_move_and_search_output():
Expand Down Expand Up @@ -4433,9 +4431,6 @@ def test_migrate():
sorted(identity)))

assert "this is an old index version" not in err
assert all('min_n_below' in node.metadata
for node in identity
if isinstance(node, Node))


def test_license_cc0():
Expand Down