From b1fc982a0d3cee78b14899d83d7b6fd203e7a6db Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 12 Jun 2021 10:01:56 -0700 Subject: [PATCH] add picklists to selectors --- src/sourmash/index.py | 5 ++++- src/sourmash/lca/lca_db.py | 27 ++++++++++++++++++++++++--- src/sourmash/sbt.py | 26 +++++++++++++++++++++++--- src/sourmash/sig/__main__.py | 14 ++++++-------- src/sourmash/sourmash_args.py | 5 +++-- 5 files changed, 60 insertions(+), 17 deletions(-) diff --git a/src/sourmash/index.py b/src/sourmash/index.py index 2aec59283c..b344a3cabc 100644 --- a/src/sourmash/index.py +++ b/src/sourmash/index.py @@ -291,7 +291,7 @@ def select(self, ksize=None, moltype=None, scaled=None, num=None, def select_signature(ss, ksize=None, moltype=None, scaled=0, num=0, - containment=False): + containment=False, picklist=None): "Check that the given signature matches the specificed requirements." # ksize match? if ksize and ksize != ss.minhash.ksize: @@ -318,6 +318,9 @@ def select_signature(ss, ksize=None, moltype=None, scaled=0, num=0, if ss.minhash.scaled or num != ss.minhash.num: return False + if picklist is not None and ss not in picklist: + return False + return True diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py index a3d90ffd5d..69b776aacc 100644 --- a/src/sourmash/lca/lca_db.py +++ b/src/sourmash/lca/lca_db.py @@ -71,6 +71,7 @@ def __init__(self, ksize, scaled, moltype='DNA'): self.lineage_to_lid = {} self.lid_to_lineage = {} self.hashval_to_idx = defaultdict(set) + self.picklists = [] @property def location(self): @@ -176,7 +177,7 @@ def signatures(self): yield v def select(self, ksize=None, moltype=None, num=0, scaled=0, - containment=False): + containment=False, picklist=None): """Make sure this database matches the requested requirements. As with SBTs, queries with higher scaled values than the database @@ -197,6 +198,9 @@ def select(self, ksize=None, moltype=None, num=0, scaled=0, if moltype is not None and moltype != self.moltype: raise ValueError(f"moltype on this database is {self.moltype}; this is different from requested moltype of {moltype}") + if picklist is not None: + self.picklists.append(picklist) + return self @classmethod @@ -416,7 +420,16 @@ def _signatures(self): for idx, mh in mhd.items(): ident = self.idx_to_ident[idx] name = self.ident_to_name[ident] - sigd[idx] = SourmashSignature(mh, name=name) + ss = SourmashSignature(mh, name=name) + + keep = True + for picklist in self.picklists: + if ss not in picklist: + keep = False + break + + if keep: + sigd[idx] = SourmashSignature(mh, name=name) debug('=> {} signatures!', len(sigd)) return sigd @@ -478,7 +491,15 @@ def find(self, search_fn, query, **kwargs): # signal that it is done, or something. if search_fn.passes(score): if search_fn.collect(score, subj): - yield IndexSearchResult(score, subj, self.location) + + # filter on picklists + keep = True + for picklist in self.picklists: + if subj not in picklist: + keep = False + + if keep: + yield IndexSearchResult(score, subj, self.location) @cached_property def lid_to_idx(self): diff --git a/src/sourmash/sbt.py b/src/sourmash/sbt.py index c31a689621..3cbc762e2d 100644 --- a/src/sourmash/sbt.py +++ b/src/sourmash/sbt.py @@ -148,6 +148,7 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None): cache_size = sys.maxsize self._nodescache = _NodesCache(maxsize=cache_size) self._location = None + self.picklists = [] @property def location(self): @@ -155,10 +156,17 @@ def location(self): def signatures(self): for k in self.leaves(): - yield k.data + ss = k.data + keep = True + for picklist in self.picklists: + if ss not in picklist: + keep = False + + if keep: + yield k.data def select(self, ksize=None, moltype=None, num=0, scaled=0, - containment=False): + containment=False, picklist=None): """Make sure this database matches the requested requirements. Will always raise ValueError if a requirement cannot be met. @@ -210,6 +218,9 @@ def select(self, ksize=None, moltype=None, num=0, scaled=0, if scaled > db_mh.scaled and not containment: raise ValueError(f"search scaled value {scaled} is less than database scaled value of {db_mh.scaled}") + if picklist is not None: + self.picklists.append(picklist) + return self def new_node_pos(self, node): @@ -450,7 +461,16 @@ def node_search(node, *args, **kwargs): # & execute! for n in self._find_nodes(node_search, **kwargs): - yield IndexSearchResult(results[n.data], n.data, self.location) + ss = n.data + + # filter on picklists + keep = True + for picklist in self.picklists: + if ss not in picklist: + keep = False + + if keep: + yield IndexSearchResult(results[ss], ss, self.location) def _rebuild_node(self, pos=0): """Recursively rebuilds an internal node (if it is not present). diff --git a/src/sourmash/sig/__main__.py b/src/sourmash/sig/__main__.py index d75dc66685..718c799744 100644 --- a/src/sourmash/sig/__main__.py +++ b/src/sourmash/sig/__main__.py @@ -556,16 +556,11 @@ def extract(args): notify(f"WARNING: {n_empty_val} empty values in column '{picklist.column_name}' in CSV file") if dup_vals: notify(f"WARNING: {len(dup_vals)} values in column '{picklist.column_name}' were not distinct") - picklist_filter_fn = picklist.filter - else: - def picklist_filter_fn(it): - for ss in it: - yield ss # further filtering on md5 or name? if args.md5 is not None or args.name is not None: def filter_fn(it): - for ss in picklist_filter_fn(it): + for ss in it: # match? keep = False if args.name and args.name in str(ss): @@ -576,8 +571,10 @@ def filter_fn(it): if keep: yield ss else: - # whatever comes out of the picklist is fine - filter_fn = picklist_filter_fn + # whatever comes out of the database is fine + def filter_fn(it): + for ss in it: + yield ss # ok! filtering defined, let's go forward progress = sourmash_args.SignatureLoadingProgress() @@ -589,6 +586,7 @@ def filter_fn(it): siglist = sourmash_args.load_file_as_signatures(filename, ksize=args.ksize, select_moltype=moltype, + picklist=picklist, progress=progress) for ss in filter_fn(siglist): save_sigs.add(ss) diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py index 4085fa5bea..40c7d35444 100644 --- a/src/sourmash/sourmash_args.py +++ b/src/sourmash/sourmash_args.py @@ -351,6 +351,7 @@ def load_file_as_index(filename, *, yield_all_files=False): def load_file_as_signatures(filename, *, select_moltype=None, ksize=None, + picklist=None, yield_all_files=False, progress=None): """Load 'filename' as a collection of signatures. Return an iterable. @@ -367,13 +368,13 @@ def load_file_as_signatures(filename, *, select_moltype=None, ksize=None, underneath this directory into a list of signatures. If yield_all_files=True, will attempt to load all files. - Applies selector function if select_moltype and/or ksize are given. + Applies selector function if select_moltype, ksize or picklist are given. """ if progress: progress.notify(filename) db = _load_database(filename, yield_all_files) - db = db.select(moltype=select_moltype, ksize=ksize) + db = db.select(moltype=select_moltype, ksize=ksize, picklist=picklist) loader = db.signatures() if progress is not None: