diff --git a/PopPUNK/__init__.py b/PopPUNK/__init__.py
index ee6b5324..80b46880 100644
--- a/PopPUNK/__init__.py
+++ b/PopPUNK/__init__.py
@@ -3,7 +3,7 @@
 
 '''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''
 
-__version__ = '2.4.5'
+__version__ = '2.4.6'
 
 # Minimum sketchlib version
 SKETCHLIB_MAJOR = 1
diff --git a/PopPUNK/assign.py b/PopPUNK/assign.py
index 2823d8c4..8537d251 100644
--- a/PopPUNK/assign.py
+++ b/PopPUNK/assign.py
@@ -49,10 +49,89 @@ def assign_query(dbFuncs,
                  gpu_dist,
                  gpu_graph,
                  deviceid,
-                 web,
-                 json_sketch,
                  save_partial_query_graph):
-    """Code for assign query mode. Written as a separate function so it can be called
+    """Code for assign query mode for CLI"""
+
+    createDatabaseDir = dbFuncs['createDatabaseDir']
+    constructDatabase = dbFuncs['constructDatabase']
+    readDBParams = dbFuncs['readDBParams']
+
+    if ref_db == output:
+        sys.stderr.write("--output and --ref-db must be different to "
+                         "prevent overwrite.\n")
+        sys.exit(1)
+
+    # Find distances to reference db
+    kmers, sketch_sizes, codon_phased = readDBParams(ref_db)
+
+    # construct database
+    createDatabaseDir(output, kmers)
+    qNames = constructDatabase(q_files,
+                               kmers,
+                               sketch_sizes,
+                               output,
+                               threads,
+                               overwrite,
+                               codon_phased = codon_phased,
+                               calc_random = False,
+                               use_gpu = gpu_sketch,
+                               deviceid = deviceid)
+
+    isolateClustering = assign_query_hdf5(dbFuncs,
+                                          ref_db,
+                                          qNames,
+                                          output,
+                                          qc_dict,
+                                          update_db,
+                                          write_references,
+                                          distances,
+                                          threads,
+                                          overwrite,
+                                          plot_fit,
+                                          graph_weights,
+                                          max_a_dist,
+                                          max_pi_dist,
+                                          type_isolate,
+                                          model_dir,
+                                          strand_preserved,
+                                          previous_clustering,
+                                          external_clustering,
+                                          core,
+                                          accessory,
+                                          gpu_sketch,
+                                          gpu_dist,
+                                          gpu_graph,
+                                          deviceid,
+                                          save_partial_query_graph)
+    return(isolateClustering)
+
+def assign_query_hdf5(dbFuncs,
+                      ref_db,
+                      qNames,
+                      output,
+                      qc_dict,
+                      update_db,
+                      write_references,
+                      distances,
+                      threads,
+                      overwrite,
+                      plot_fit,
+                      graph_weights,
+                      max_a_dist,
+                      max_pi_dist,
+                      type_isolate,
+                      model_dir,
+                      strand_preserved,
+                      previous_clustering,
+                      external_clustering,
+                      core,
+                      accessory,
+                      gpu_sketch,
+                      gpu_dist,
+                      gpu_graph,
+                      deviceid,
+                      save_partial_query_graph):
+    """Code for assign query mode taking hdf5 as input. Written as a separate function so it can be called
     by web APIs"""
     # Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
@@ -80,10 +159,6 @@ def assign_query(dbFuncs,
     from .utils import update_distance_matrices
     from .utils import createOverallLineage
 
-    from .web import sketch_to_hdf5
-
-    createDatabaseDir = dbFuncs['createDatabaseDir']
-    constructDatabase = dbFuncs['constructDatabase']
     joinDBs = dbFuncs['joinDBs']
     queryDatabase = dbFuncs['queryDatabase']
     readDBParams = dbFuncs['readDBParams']
@@ -115,7 +190,7 @@ def assign_query(dbFuncs,
         prev_clustering = model_prefix
 
     # Find distances to reference db
-    kmers, sketch_sizes, codon_phased = readDBParams(ref_db)
+    kmers = readDBParams(ref_db)[0]
 
     # Iterate through different types of model fit with a refined model when specified
     # Core and accessory assignments use the same model and same overall set of distances
@@ -150,22 +225,7 @@ def assign_query(dbFuncs,
             sys.exit(1)
         else:
             rNames = getSeqsInDb(os.path.join(ref_db, os.path.basename(ref_db) + ".h5"))
-        # construct database - use a single database directory for all query outputs
-        if (web and json_sketch is not None):
-            qNames = sketch_to_hdf5(json_sketch, output)
-        elif (fit_type == 'default'):
-            # construct database
-            createDatabaseDir(output, kmers)
-            qNames = constructDatabase(q_files,
-                                       kmers,
-                                       sketch_sizes,
-                                       output,
-                                       threads,
-                                       overwrite,
-                                       codon_phased = codon_phased,
-                                       calc_random = False,
-                                       use_gpu = gpu_sketch,
-                                       deviceid = deviceid)
+
         if (fit_type == 'default' or (fit_type != 'default' and use_ref_graph)):
             # run query
             qrDistMat = queryDatabase(rNames = rNames,
@@ -610,8 +670,6 @@ def main():
                  args.gpu_dist,
                  args.gpu_graph,
                  args.deviceid,
-                 web=False,
-                 json_sketch=None,
                  save_partial_query_graph=False)
 
     sys.stderr.write("\nDone\n")
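Note on the assign.py change: assign_query now keeps only the CLI behaviour (sketch the query files, then assign) and delegates to the new assign_query_hdf5, which a web service can call directly once query sketches are already in an HDF5 database. A minimal sketch of the web-side call path; dbFuncs is assumed to have been built elsewhere (e.g. with PopPUNK's setupDBFuncs helper), and every path and value below is a placeholder, not a PopPUNK default:

    from PopPUNK.assign import assign_query_hdf5

    # dbFuncs: dict of database helper functions, assumed built elsewhere
    # (e.g. PopPUNK.utils.setupDBFuncs); qNames would come from
    # PopPUNK.web.sketch_to_hdf5, as in the web.py changes below.
    isolateClustering = assign_query_hdf5(
        dbFuncs, ref_db="species_db", qNames=["sample1"], output="query_out",
        qc_dict={"run_qc": False}, update_db=False, write_references=False,
        distances="species_db/species_db.dists", threads=1, overwrite=True,
        plot_fit=0, graph_weights=False, max_a_dist=0.5, max_pi_dist=0.5,
        type_isolate=None, model_dir="species_db", strand_preserved=False,
        previous_clustering=None, external_clustering=None, core=False,
        accessory=False, gpu_sketch=False, gpu_dist=False, gpu_graph=False,
        deviceid=0, save_partial_query_graph=True)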
diff --git a/PopPUNK/web.py b/PopPUNK/web.py
index 37d70229..cd24eb07 100644
--- a/PopPUNK/web.py
+++ b/PopPUNK/web.py
@@ -170,47 +170,49 @@ def get_colours(query, clusters):
             colours.append('blue')
     return colours
 
-def sketch_to_hdf5(sketch, output):
-    """Convert JSON sketch to query hdf5 database"""
-    kmers = []
-    dists = []
-
-    sketch_dict = json.loads(sketch)
-    qNames = ["query"]
+def sketch_to_hdf5(sketches_dict, output):
+    """Convert dict of JSON sketches to query hdf5 database"""
+    qNames = []
     queryDB = h5py.File(os.path.join(output, os.path.basename(output) + '.h5'), 'w')
     sketches = queryDB.create_group("sketches")
-    sketch_props = sketches.create_group(qNames[0])
-
-    for key, value in sketch_dict.items():
-        try:
-            kmers.append(int(key))
-            dists.append(np.array(value, dtype='uint64'))
-        except (TypeError, ValueError):
-            if key == "version":
-                sketches.attrs['sketch_version'] = value
-            elif key == "codon_phased":
-                sketches.attrs['codon_phased'] = value
-            elif key == "densified":
-                sketches.attrs['densified'] = value
-            elif key == "bases":
-                sketch_props.attrs['base_freq'] = value
-            elif key == "bbits":
-                sketch_props.attrs['bbits'] = value
-            elif key == "length":
-                sketch_props.attrs['length'] = value
-            elif key == "missing_bases":
-                sketch_props.attrs['missing_bases'] = value
-            elif key == "sketchsize64":
-                sketch_props.attrs['sketchsize64'] = value
-            elif key == "species":
-                pass
-            else:
-                sys.stderr.write(key + " not recognised")
-
-    sketch_props.attrs['kmers'] = kmers
-    for k_index in range(len(kmers)):
-        k_spec = sketch_props.create_dataset(str(kmers[k_index]), data=dists[k_index], dtype='uint64')
-        k_spec.attrs['kmer-size'] = kmers[k_index]
+
+    for top_key, top_value in sketches_dict.items():
+        qNames.append(top_key)
+        kmers = []
+        dists = []
+        sketch_dict = json.loads(top_value)
+        sketch_props = sketches.create_group(top_key)
+
+        for key, value in sketch_dict.items():
+            try:
+                kmers.append(int(key))
+                dists.append(np.array(value, dtype='uint64'))
+            except (TypeError, ValueError):
+                if key == "version":
+                    sketches.attrs['sketch_version'] = value
+                elif key == "codon_phased":
+                    sketches.attrs['codon_phased'] = value
+                elif key == "densified":
+                    pass
+                elif key == "bases":
+                    sketch_props.attrs['base_freq'] = value
+                elif key == "bbits":
+                    sketch_props.attrs['bbits'] = value
+                elif key == "length":
+                    sketch_props.attrs['length'] = value
+                elif key == "missing_bases":
+                    sketch_props.attrs['missing_bases'] = value
+                elif key == "sketchsize64":
+                    sketch_props.attrs['sketchsize64'] = value
+                elif key == "species":
+                    pass
+                else:
+                    sys.stderr.write(key + " not recognised")
+
+        sketch_props.attrs['kmers'] = kmers
+        for k_index in range(len(kmers)):
+            k_spec = sketch_props.create_dataset(str(kmers[k_index]), data=dists[k_index], dtype='uint64')
+            k_spec.attrs['kmer-size'] = kmers[k_index]
     queryDB.close()
     return qNames
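Note on sketch_to_hdf5: it now takes a dict mapping each query name to its sketch JSON string, writes one group per query under /sketches, and returns the query names; the 'densified' attribute is now dropped rather than stored. A toy illustration of the expected input shape (the field names are the ones handled above; all values are made up, and a real sketch would come from pp-sketchlib):

    import json, os
    from PopPUNK.web import sketch_to_hdf5

    toy_sketch = {
        "version": "1.7.0",                 # stored once on the sketches group
        "codon_phased": False,
        "densified": False,                 # now ignored rather than stored
        "bases": [0.25, 0.25, 0.25, 0.25],
        "bbits": 14,
        "length": 2000000,
        "missing_bases": 0,
        "sketchsize64": 156,
        "14": [12345, 67890],               # k-mer size -> uint64 sketch values (toy)
    }
    os.makedirs("query_out", exist_ok=True)
    qNames = sketch_to_hdf5({"sample1": json.dumps(toy_sketch),
                             "sample2": json.dumps(toy_sketch)}, "query_out")
    # qNames == ["sample1", "sample2"]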
@@ -290,13 +292,14 @@ def get_aliases(aliasDF, clusterLabels, species):
         alias_dict = {"GPSC":str(GPS_name)}
     return alias_dict
 
-def summarise_clusters(output, species, species_db):
+def summarise_clusters(output, species, species_db, qNames):
     """Retreieve assigned query and all cluster prevalences.
        Write list of all isolates in cluster for tree subsetting"""
     totalDF = pd.read_csv(os.path.join(output, os.path.basename(output) + "_clusters.csv"))
-    queryDF = totalDF.loc[totalDF['Taxon'] == "query"]
+    queryDF = totalDF[totalDF['Taxon'].isin(qNames)]
     queryDF = queryDF.reset_index(drop=True)
-    query = str(queryDF["Cluster"][0])
+    queries_names = list(queryDF["Taxon"])
+    queries_clusters = list(queryDF["Cluster"])
     num_samples = len(totalDF["Taxon"])
     totalDF["Cluster"] = totalDF["Cluster"].astype(str)
     cluster_list = list(totalDF["Cluster"])
@@ -307,19 +310,21 @@ def summarise_clusters(output, species, species_db):
     uniquetotalDF = totalDF.drop_duplicates(subset=['Cluster'])
     clusters = list(uniquetotalDF['Cluster'])
     prevalences = list(uniquetotalDF["Prevalence"])
-    query_prevalence = prevalences[clusters.index(query)]
-    # write list of all isolates in cluster
-    clusterDF = totalDF.loc[totalDF['Cluster'] == query]
-    to_include = list(clusterDF['Taxon'])
-    with open(os.path.join(output, "include.txt"), "w") as i:
-        i.write("\n".join(to_include))
+    queries_prevalence = []
+    for query in queries_clusters:
+        queries_prevalence.append(prevalences[clusters.index(str(query))])
+        # write list of all isolates in cluster
+        clusterDF = totalDF.loc[totalDF['Cluster'] == str(query)]
+        to_include = list(clusterDF['Taxon'])
+        with open(os.path.join(output, "include" + str(query) + ".txt"), "w") as i:
+            i.write("\n".join(to_include))
     # get aliases
     if os.path.isfile(os.path.join(species_db, "aliases.csv")):
         aliasDF = pd.read_csv(os.path.join(species_db, "aliases.csv"))
         alias_dict = get_aliases(aliasDF, list(clusterDF['Taxon']), species)
     else:
         alias_dict = {"Aliases": "NA"}
-    return query, query_prevalence, clusters, prevalences, alias_dict, to_include
+    return queries_names, queries_clusters, queries_prevalence, clusters, prevalences, alias_dict, to_include
 
 @scheduler.task('interval', id='clean_tmp', hours=1, misfire_grace_time=900)
 def clean_tmp():
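Note on summarise_clusters: it is now plural throughout, taking the query names and returning parallel lists (one cluster and one prevalence per query) instead of a single query/prevalence pair, and it writes one include<cluster>.txt per assigned cluster. A hedged usage sketch with placeholder arguments:

    from PopPUNK.web import summarise_clusters

    (queries_names, queries_clusters, queries_prevalence,
     clusters, prevalences, alias_dict, to_include) = summarise_clusters(
         "query_out", "Streptococcus pneumoniae", "species_db",
         qNames=["sample1", "sample2"])
    # queries_names[i] was assigned queries_clusters[i], whose prevalence is
    # queries_prevalence[i]; include<cluster>.txt lists that cluster's members.

One behavioural caveat worth flagging: alias_dict and to_include are still derived from clusterDF, which the loop overwrites on each iteration, so they describe only the last query's cluster.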