Skip to content

Commit

Permalink
add picklists to selectors
Browse files Browse the repository at this point in the history
  • Loading branch information
ctb committed Jun 12, 2021
1 parent 74f31f5 commit b1fc982
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 17 deletions.
5 changes: 4 additions & 1 deletion src/sourmash/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def select(self, ksize=None, moltype=None, scaled=None, num=None,


def select_signature(ss, ksize=None, moltype=None, scaled=0, num=0,
containment=False):
containment=False, picklist=None):
"Check that the given signature matches the specificed requirements."
# ksize match?
if ksize and ksize != ss.minhash.ksize:
Expand All @@ -318,6 +318,9 @@ def select_signature(ss, ksize=None, moltype=None, scaled=0, num=0,
if ss.minhash.scaled or num != ss.minhash.num:
return False

if picklist is not None and ss not in picklist:
return False

return True


Expand Down
27 changes: 24 additions & 3 deletions src/sourmash/lca/lca_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, ksize, scaled, moltype='DNA'):
self.lineage_to_lid = {}
self.lid_to_lineage = {}
self.hashval_to_idx = defaultdict(set)
self.picklists = []

@property
def location(self):
Expand Down Expand Up @@ -176,7 +177,7 @@ def signatures(self):
yield v

def select(self, ksize=None, moltype=None, num=0, scaled=0,
containment=False):
containment=False, picklist=None):
"""Make sure this database matches the requested requirements.
As with SBTs, queries with higher scaled values than the database
Expand All @@ -197,6 +198,9 @@ def select(self, ksize=None, moltype=None, num=0, scaled=0,
if moltype is not None and moltype != self.moltype:
raise ValueError(f"moltype on this database is {self.moltype}; this is different from requested moltype of {moltype}")

if picklist is not None:
self.picklists.append(picklist)

return self

@classmethod
Expand Down Expand Up @@ -416,7 +420,16 @@ def _signatures(self):
for idx, mh in mhd.items():
ident = self.idx_to_ident[idx]
name = self.ident_to_name[ident]
sigd[idx] = SourmashSignature(mh, name=name)
ss = SourmashSignature(mh, name=name)

keep = True
for picklist in self.picklists:
if ss not in picklist:
keep = False
break

if keep:
sigd[idx] = SourmashSignature(mh, name=name)

debug('=> {} signatures!', len(sigd))
return sigd
Expand Down Expand Up @@ -478,7 +491,15 @@ def find(self, search_fn, query, **kwargs):
# signal that it is done, or something.
if search_fn.passes(score):
if search_fn.collect(score, subj):
yield IndexSearchResult(score, subj, self.location)

# filter on picklists
keep = True
for picklist in self.picklists:
if subj not in picklist:
keep = False

if keep:
yield IndexSearchResult(score, subj, self.location)

@cached_property
def lid_to_idx(self):
Expand Down
26 changes: 23 additions & 3 deletions src/sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,25 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None):
cache_size = sys.maxsize
self._nodescache = _NodesCache(maxsize=cache_size)
self._location = None
self.picklists = []

@property
def location(self):
return self._location

def signatures(self):
for k in self.leaves():
yield k.data
ss = k.data
keep = True
for picklist in self.picklists:
if ss not in picklist:
keep = False

if keep:
yield k.data

def select(self, ksize=None, moltype=None, num=0, scaled=0,
containment=False):
containment=False, picklist=None):
"""Make sure this database matches the requested requirements.
Will always raise ValueError if a requirement cannot be met.
Expand Down Expand Up @@ -210,6 +218,9 @@ def select(self, ksize=None, moltype=None, num=0, scaled=0,
if scaled > db_mh.scaled and not containment:
raise ValueError(f"search scaled value {scaled} is less than database scaled value of {db_mh.scaled}")

if picklist is not None:
self.picklists.append(picklist)

return self

def new_node_pos(self, node):
Expand Down Expand Up @@ -450,7 +461,16 @@ def node_search(node, *args, **kwargs):

# & execute!
for n in self._find_nodes(node_search, **kwargs):
yield IndexSearchResult(results[n.data], n.data, self.location)
ss = n.data

# filter on picklists
keep = True
for picklist in self.picklists:
if ss not in picklist:
keep = False

if keep:
yield IndexSearchResult(results[ss], ss, self.location)

def _rebuild_node(self, pos=0):
"""Recursively rebuilds an internal node (if it is not present).
Expand Down
14 changes: 6 additions & 8 deletions src/sourmash/sig/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,16 +556,11 @@ def extract(args):
notify(f"WARNING: {n_empty_val} empty values in column '{picklist.column_name}' in CSV file")
if dup_vals:
notify(f"WARNING: {len(dup_vals)} values in column '{picklist.column_name}' were not distinct")
picklist_filter_fn = picklist.filter
else:
def picklist_filter_fn(it):
for ss in it:
yield ss

# further filtering on md5 or name?
if args.md5 is not None or args.name is not None:
def filter_fn(it):
for ss in picklist_filter_fn(it):
for ss in it:
# match?
keep = False
if args.name and args.name in str(ss):
Expand All @@ -576,8 +571,10 @@ def filter_fn(it):
if keep:
yield ss
else:
# whatever comes out of the picklist is fine
filter_fn = picklist_filter_fn
# whatever comes out of the database is fine
def filter_fn(it):
for ss in it:
yield ss

# ok! filtering defined, let's go forward
progress = sourmash_args.SignatureLoadingProgress()
Expand All @@ -589,6 +586,7 @@ def filter_fn(it):
siglist = sourmash_args.load_file_as_signatures(filename,
ksize=args.ksize,
select_moltype=moltype,
picklist=picklist,
progress=progress)
for ss in filter_fn(siglist):
save_sigs.add(ss)
Expand Down
5 changes: 3 additions & 2 deletions src/sourmash/sourmash_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def load_file_as_index(filename, *, yield_all_files=False):


def load_file_as_signatures(filename, *, select_moltype=None, ksize=None,
picklist=None,
yield_all_files=False,
progress=None):
"""Load 'filename' as a collection of signatures. Return an iterable.
Expand All @@ -367,13 +368,13 @@ def load_file_as_signatures(filename, *, select_moltype=None, ksize=None,
underneath this directory into a list of signatures. If
yield_all_files=True, will attempt to load all files.
Applies selector function if select_moltype and/or ksize are given.
Applies selector function if select_moltype, ksize or picklist are given.
"""
if progress:
progress.notify(filename)

db = _load_database(filename, yield_all_files)
db = db.select(moltype=select_moltype, ksize=ksize)
db = db.select(moltype=select_moltype, ksize=ksize, picklist=picklist)
loader = db.signatures()

if progress is not None:
Expand Down

0 comments on commit b1fc982

Please sign in to comment.