diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py index 41439a2acf..ff309a208c 100644 --- a/src/sourmash/commands.py +++ b/src/sourmash/commands.py @@ -631,6 +631,10 @@ def gather(args): query.minhash.scaled, int(args.scaled)) query.minhash = query.minhash.downsample(scaled=args.scaled) + # flatten if needed @CTB do we need this here? + if query.minhash.track_abundance: + query.minhash = query.minhash.flatten() + # empty? if not len(query.minhash): error('no query hashes!? exiting.') diff --git a/src/sourmash/index.py b/src/sourmash/index.py index 917a87b196..df4157daeb 100644 --- a/src/sourmash/index.py +++ b/src/sourmash/index.py @@ -280,6 +280,9 @@ def location(self): def signatures(self): return iter(self._signatures) + def __bool__(self): + return bool(self._signatures) + def __len__(self): return len(self._signatures) @@ -329,6 +332,9 @@ def __init__(self, zf, selection_dict=None, self.selection_dict = selection_dict self.traverse_yield_all = traverse_yield_all + def __bool__(self): + return bool(self.zf) + def __len__(self): return len(list(self.signatures())) @@ -464,7 +470,8 @@ def gather(self, query, *args, **kwargs): result = IndexSearchResult(cont, match, location) # calculate intersection of this "best match" with query, for removal. - match_mh = match.minhash.downsample(scaled=scaled) + # @CTB note flatten + match_mh = match.minhash.downsample(scaled=scaled).flatten() intersect_mh = query_mh.intersection(match_mh) # Prepare counter for finding the next match by decrementing diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py index 4af77b5a5b..7d57def47c 100644 --- a/src/sourmash/lca/lca_db.py +++ b/src/sourmash/lca/lca_db.py @@ -212,6 +212,10 @@ def load(cls, db_name): xopen = gzip.open with xopen(db_name, 'rt') as fp: + if fp.read(1) != '{': + raise ValueError(f"'{db_name}' is not an LCA database file.") + fp.seek(0) + load_d = {} try: load_d = json.load(fp) diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py index 5ad52587a4..3769a9cdd9 100644 --- a/src/sourmash/sourmash_args.py +++ b/src/sourmash/sourmash_args.py @@ -306,7 +306,7 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None): debug_literal(f"_load_databases: FAIL on fn {desc}.") debug_literal(traceback.format_exc()) - if db: + if db is not None: loaded = True break