diff --git a/sc2ts/inference.py b/sc2ts/inference.py index c22d358..387f7d8 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -568,12 +568,12 @@ def extend( f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}" ) - metadata_matches = list(metadata_db.get(date)) + metadata_matches = {md["strain"]: md for md in metadata_db.get(date)} logger.info(f"Got {len(metadata_matches)} metadata matches") preprocessed_samples = preprocess( - strains=[md["strain"] for md in metadata_matches], + strains=list(metadata_matches.keys()), alignment_store_path=alignment_store.path, keep_sites=base_ts.sites_position.astype(int), progress_title=date, @@ -584,10 +584,11 @@ def extend( pango_lineage_key = "Viridian_pangolin" samples = [] - for s, md in zip(preprocessed_samples, metadata_matches): + for s in preprocessed_samples: if s.haplotype is None: logger.debug(f"No alignment stored for {s.strain}") continue + md = metadata_matches[s.strain] s.metadata = md s.pango = md.get(pango_lineage_key, "Unknown") s.date = date @@ -1270,7 +1271,6 @@ def mutation_summary(self): return "[" + ", ".join(str(mutation) for mutation in self.mutations) + "]" - def get_match_info(ts, sample_paths, sample_mutations): tables = ts.tables assert np.all(tables.sites.ancestral_state_offset == np.arange(ts.num_sites + 1))