diff --git a/MANIFEST.in b/MANIFEST.in index b510d8bd..ad9e8edf 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ recursive-include scripts *.py -recursive-include PopPUNK/data *.json *.gz *.txt \ No newline at end of file +recursive-include PopPUNK/data *.gz \ No newline at end of file diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 2ce2ee04..402b2460 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -94,6 +94,10 @@ def get_options(): 'separate database [default = False]', default=False, action='store_true') qcGroup.add_argument('--max-a-dist', help='Maximum accessory distance to permit [default = 0.5]', default = 0.5, type = float) + qcGroup.add_argument('--max-pi-dist', help='Maximum core distance to permit [default = 0.5]', + default = 0.5, type = float) + qcGroup.add_argument('--type-isolate', help='Isolate from which distances will be calculated for pruning [default = None]', + default = None, type = str) qcGroup.add_argument('--length-sigma', help='Number of standard deviations of length distribution beyond ' 'which sequences will be excluded [default = 5]', default = 5, type = int) qcGroup.add_argument('--length-range', help='Allowed length range, outside of which sequences will be excluded ' @@ -121,8 +125,6 @@ def get_options(): type=float, default = None) refinementGroup.add_argument('--manual-start', help='A file containing information for a start point. ' 'See documentation for help.', default=None) - refinementGroup.add_argument('--no-local', help='Do not perform the local optimization step (speed up on very large datasets)', - default=False, action='store_true') refinementGroup.add_argument('--model-dir', help='Directory containing model to use for assigning queries ' 'to clusters [default = reference database directory]', type = str) refinementGroup.add_argument('--score-idx', @@ -150,7 +152,12 @@ def get_options(): other.add_argument('--threads', default=1, type=int, help='Number of threads to use [default = 1]') other.add_argument('--gpu-sketch', default=False, action='store_true', help='Use a GPU when calculating sketches (read data only) [default = False]') other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]') + other.add_argument('--gpu-graph', default=False, action='store_true', help='Use a GPU when calculating networks [default = False]') other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]') + other.add_argument('--no-plot', help='Switch off model plotting, which can be slow for large datasets', + default=False, action='store_true') + other.add_argument('--no-local', help='Do not perform the local optimization step in model refinement (speed up on very large datasets)', + default=False, action='store_true') other.add_argument('--version', action='version', version='%(prog)s '+__version__) @@ -200,6 +207,9 @@ def main(): from .network import constructNetwork from .network import extractReferences from .network import printClusters + from .network import get_vertex_list + from .network import save_network + from .network import checkNetworkVertexCount from .plot import writeClusterCsv from .plot import plot_scatter @@ -233,7 +243,10 @@ def main(): 'length_sigma': args.length_sigma, 'length_range': args.length_range, 'prop_n': args.prop_n, - 'upper_n': args.upper_n + 'upper_n': args.upper_n, + 'max_pi_dist': args.max_pi_dist, + 'max_a_dist': args.max_a_dist, + 'type_isolate': args.type_isolate } # Dict of DB 
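For reference, the QC options assembled above travel through the pipeline as a plain dictionary. A minimal sketch of its shape, using the key names from this patch; the values shown are the documented argparse defaults and are illustrative only:

qc_dict = {
    'run_qc': True,              # whether sequence/distance QC is applied at all
    'qc_filter': 'stop',         # 'stop', 'prune' or 'continue' on QC failure
    'retain_failures': False,    # keep sketches of failing genomes in a separate database
    'length_sigma': 5,           # genome length outlier threshold (standard deviations)
    'length_range': [None, None],
    'prop_n': 0.1,               # maximum proportion of ambiguous bases
    'upper_n': None,
    'max_pi_dist': 0.5,          # maximum core distance
    'max_a_dist': 0.5,           # maximum accessory distance
    'type_isolate': None         # isolate used as the reference point for pruning
}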
access functions @@ -288,38 +301,42 @@ def main(): sys.stderr.write("--create-db requires --r-files and --output") sys.exit(1) - # generate sketches and QC sequences + # generate sketches and QC sequences to identify sequences not matching specified criteria createDatabaseDir(args.output, kmers) - seq_names = constructDatabase( - args.r_files, - kmers, - sketch_sizes, - args.output, - args.threads, - args.overwrite, - codon_phased = args.codon_phased, - calc_random = True) - - rNames = seq_names - qNames = seq_names - refList, queryList, distMat = queryDatabase(rNames = rNames, - qNames = qNames, - dbPrefix = args.output, - queryPrefix = args.output, - klist = kmers, - self = True, - number_plot_fits = args.plot_fit, - threads = args.threads) - qcDistMat(distMat, refList, queryList, args.max_a_dist) - - # Save results - dists_out = args.output + "/" + os.path.basename(args.output) + ".dists" - storePickle(refList, queryList, True, distMat, dists_out) + seq_names_passing = \ + constructDatabase( + args.r_files, + kmers, + sketch_sizes, + args.output, + args.threads, + args.overwrite, + codon_phased = args.codon_phased, + calc_random = True) + + # calculate distances between sequences + distMat = queryDatabase(rNames = seq_names_passing, + qNames = seq_names_passing, + dbPrefix = args.output, + queryPrefix = args.output, + klist = kmers, + self = True, + number_plot_fits = args.plot_fit, + threads = args.threads) + + # QC pairwise distances to identify long distances indicative of anomalous sequences in the collection + seq_names_passing, distMat = qcDistMat(distMat, + seq_names_passing, + seq_names_passing, + args.output, + args.output, + qc_dict) # Plot results - plot_scatter(distMat, - args.output + "/" + os.path.basename(args.output) + "_distanceDistribution", - args.output + " distances") + if not args.no_plot: + plot_scatter(distMat, + args.output + "/" + os.path.basename(args.output) + "_distanceDistribution", + args.output + " distances") #******************************# #* *# @@ -340,7 +357,7 @@ def main(): sys.stderr.write("Need to provide --ref-db where .h5 and .dists from " "--create-db mode were output") if args.distances is None: - distances = os.path.basename(args.ref_db) + "/" + args.ref_db + ".dists" + distances = args.ref_db + "/" + os.path.basename(args.ref_db) + ".dists" else: distances = args.distances if args.output is None: @@ -365,8 +382,9 @@ def main(): # Load the distances refList, queryList, self, distMat = readPickle(distances, enforce_self=True) - if qcDistMat(distMat, refList, queryList, args.max_a_dist) == False \ - and args.qc_filter == "stop": + seq_names = set(set(refList) | set(queryList)) + seq_names_passing, distMat = qcDistMat(distMat, refList, queryList, args.ref_db, output, qc_dict) + if len(set(seq_names_passing).difference(seq_names)) > 0 and args.qc_filter == "stop": sys.stderr.write("Distances failed quality control (change QC options to run anyway)\n") sys.exit(1) @@ -382,13 +400,11 @@ def main(): model = DBSCANFit(output) model.set_threads(args.threads) assignments = model.fit(distMat, args.D, args.min_cluster_prop) - model.plot() # Run Gaussian model elif args.fit_model == "bgmm": model = BGMMFit(output) model.set_threads(args.threads) assignments = model.fit(distMat, args.K) - model.plot(distMat, assignments) elif args.fit_model == "refine": new_model = RefineFit(output) new_model.set_threads(args.threads) @@ -398,15 +414,14 @@ def main(): args.indiv_refine, args.unconstrained, args.score_idx, - args.no_local) - new_model.plot(distMat) + 
args.no_local, + args.gpu_graph) model = new_model elif args.fit_model == "threshold": new_model = RefineFit(output) new_model.set_threads(args.threads) assignments = new_model.apply_threshold(distMat, args.threshold) - new_model.plot(distMat) model = new_model elif args.fit_model == "lineage": # run lineage clustering. Sparsity & low rank should keep memory @@ -414,7 +429,6 @@ def main(): model = LineageFit(output, rank_list) model.set_threads(args.threads) model.fit(distMat, args.use_accessory) - model.plot(distMat) assignments = {} for rank in rank_list: @@ -423,6 +437,10 @@ def main(): # save model model.save() + + # plot model + if not args.no_plot: + model.plot(distMat, assignments) # use model else: @@ -443,7 +461,8 @@ def main(): queryList, assignments, model.within_label, - weights=weights) + weights = weights, + use_gpu = args.gpu_graph) else: # Lineage fit requires some iteration indivNetworks = {} @@ -459,13 +478,15 @@ def main(): refList, assignments[rank], 0, - edge_list=True, - weights=weights + edge_list = True, + weights = weights, + use_gpu = args.gpu_graph ) lineage_clusters[rank] = \ printClusters(indivNetworks[rank], refList, - printCSV = False) + printCSV = False, + use_gpu = args.gpu_graph) # print output of each rank as CSV overall_lineage = createOverallLineage(rank_list, lineage_clusters) @@ -480,16 +501,14 @@ def main(): genomeNetwork = indivNetworks[min(rank_list)] # Ensure all in dists are in final network - networkMissing = set(map(str,set(range(len(refList))).difference(list(genomeNetwork.vertices())))) - if len(networkMissing) > 0: - missing_isolates = [refList[m] for m in networkMissing] - sys.stderr.write("WARNING: Samples " + ", ".join(missing_isolates) + " are missing from the final network\n") + checkNetworkVertexCount(refList, genomeNetwork, use_gpu = args.gpu_graph) fit_type = model.type isolateClustering = {fit_type: printClusters(genomeNetwork, refList, output + "/" + os.path.basename(output), - externalClusterCSV = args.external_clustering)} + externalClusterCSV = args.external_clustering, + use_gpu = args.gpu_graph)} # Write core and accessory based clusters, if they worked if model.indiv_fitted: @@ -517,9 +536,7 @@ def main(): fit_type = 'accessory' genomeNetwork = indivNetworks['accessory'] - genomeNetwork.save(output + "/" + \ - os.path.basename(output) + '_graph.gt', - fmt = 'gt') + save_network(genomeNetwork, prefix = output, suffix = "_graph", use_gpu = args.gpu_graph) #******************************# #* *# @@ -530,7 +547,12 @@ def main(): # (this no longer loses information and should generally be kept on) if model.type != "lineage": newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = \ - extractReferences(genomeNetwork, refList, output, threads = args.threads) + extractReferences(genomeNetwork, + refList, + output, + type_isolate = qc_dict['type_isolate'], + threads = args.threads, + use_gpu = args.gpu_graph) nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices) names_to_remove = [refList[n] for n in nodes_to_remove] @@ -539,9 +561,8 @@ def main(): prune_distance_matrix(refList, names_to_remove, distMat, output + "/" + os.path.basename(output) + ".refs.dists") # Save reference network - genomeNetwork.save(output + "/" + \ - os.path.basename(output) + '.refs_graph.gt', - fmt = 'gt') + save_network(genomeNetwork, prefix = output, suffix = ".refs_graph", + use_gpu = args.gpu_graph) removeFromDB(args.ref_db, output, names_to_remove) os.rename(output + "/" + os.path.basename(output) + ".tmp.h5", 
output + "/" + os.path.basename(output) + ".refs.h5") diff --git a/PopPUNK/assign.py b/PopPUNK/assign.py index bc5002dc..dcd1f276 100644 --- a/PopPUNK/assign.py +++ b/PopPUNK/assign.py @@ -9,12 +9,15 @@ import numpy as np import subprocess from collections import defaultdict +import scipy.optimize +from scipy.sparse import coo_matrix, bmat, find # required from v2.1.1 onwards (no mash support) import pp_sketchlib # import poppunk package from .__init__ import __version__ +from .models import rankFile #*******************************# #* *# @@ -25,6 +28,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, + qc_dict, update_db, write_references, distances, @@ -33,12 +37,18 @@ def assign_query(dbFuncs, plot_fit, graph_weights, max_a_dist, + max_pi_dist, + type_isolate, model_dir, strand_preserved, previous_clustering, external_clustering, core_only, accessory_only, + gpu_sketch, + gpu_dist, + gpu_graph, + deviceid, web, json_sketch, save_partial_query_graph): @@ -55,12 +65,12 @@ def assign_query(dbFuncs, from .network import extractReferences from .network import addQueryToNetwork from .network import printClusters + from .network import save_network from .plot import writeClusterCsv from .prune_db import prune_distance_matrix - from .sketchlib import calculateQueryQueryDistances from .sketchlib import addRandom from .utils import storePickle @@ -116,7 +126,13 @@ def assign_query(dbFuncs, for reference in refFile: rNames.append(reference.rstrip()) else: - rNames = getSeqsInDb(ref_db + "/" + os.path.basename(ref_db) + ".h5") + if os.path.isfile(distances + ".pkl"): + rNames = readPickle(distances, enforce_self = True, distances=False)[0] + elif update_db: + sys.stderr.write("Reference distances missing, cannot use --update-db\n") + sys.exit(1) + else: + rNames = getSeqsInDb(ref_db + "/" + os.path.basename(ref_db) + ".h5") # construct database if (web and json_sketch): qNames = sketch_to_hdf5(json_sketch, output) @@ -130,39 +146,47 @@ def assign_query(dbFuncs, threads, overwrite, codon_phased = codon_phased, - calc_random = False) + calc_random = False, + use_gpu = gpu_sketch, + deviceid = deviceid) # run query - refList, queryList, qrDistMat = queryDatabase(rNames = rNames, - qNames = qNames, - dbPrefix = ref_db, - queryPrefix = output, - klist = kmers, - self = False, - number_plot_fits = plot_fit, - threads = threads) + qrDistMat = queryDatabase(rNames = rNames, + qNames = qNames, + dbPrefix = ref_db, + queryPrefix = output, + klist = kmers, + self = False, + number_plot_fits = plot_fit, + threads = threads, + use_gpu = gpu_dist) # QC distance matrix - qcPass = qcDistMat(qrDistMat, refList, queryList, max_a_dist) + if qc_dict['run_qc']: + seq_names_passing = qcDistMat(qrDistMat, rNames, qNames, ref_db, output, qc_dict)[0] + else: + seq_names_passing = rNames + qNames # Load the network based on supplied options genomeNetwork, old_cluster_file = \ fetchNetwork(prev_clustering, model, - refList, + rNames, ref_graph = use_ref_graph, core_only = core_only, - accessory_only = accessory_only) + accessory_only = accessory_only, + use_gpu = gpu_graph) if model.type == 'lineage': # Assign lineages by calculating query-query information addRandom(output, qNames, kmers, strand_preserved, overwrite, threads) - qlist1, qlist2, qqDistMat = queryDatabase(rNames = qNames, - qNames = qNames, - dbPrefix = output, - queryPrefix = output, - klist = kmers, - self = True, - number_plot_fits = 0, - threads = threads) + qqDistMat = queryDatabase(rNames = qNames, + qNames = qNames, + dbPrefix = output, + 
queryPrefix = output, + klist = kmers, + self = True, + number_plot_fits = 0, + threads = threads, + use_gpu = gpu_dist) model.extend(qqDistMat, qrDistMat) genomeNetwork = {} @@ -179,22 +203,24 @@ def assign_query(dbFuncs, assignment, 0, edge_list = True, - weights=weights) + weights=weights, + use_gpu = gpu_graph) isolateClustering[rank] = \ printClusters(genomeNetwork[rank], - refList + queryList, - printCSV = False) + rNames + qNames, + printCSV = False, + use_gpu = gpu_graph) overall_lineage = createOverallLineage(model.ranks, isolateClustering) writeClusterCsv( output + "/" + os.path.basename(output) + '_lineages.csv', - refList + queryList, - refList + queryList, + rNames + qNames, + rNames + qNames, overall_lineage, output_format = 'phandango', epiCsv = None, - queryNames = queryList, + queryNames = qNames, suffix = '_Lineage') else: @@ -206,25 +232,27 @@ def assign_query(dbFuncs, weights = qrDistMat else: weights = None - qqDistMat = \ - addQueryToNetwork(dbFuncs, refList, queryList, + + genomeNetwork, qqDistMat = \ + addQueryToNetwork(dbFuncs, rNames, qNames, genomeNetwork, kmers, queryAssignments, model, output, update_db, strand_preserved, - weights = weights, threads = threads) + weights = weights, threads = threads, use_gpu = gpu_graph) isolateClustering = \ - {'combined': printClusters(genomeNetwork, refList + queryList, + {'combined': printClusters(genomeNetwork, rNames + qNames, output + "/" + os.path.basename(output), old_cluster_file, external_clustering, - write_references or update_db)} + write_references or update_db, + use_gpu = gpu_graph)} # Update DB as requested dists_out = output + "/" + os.path.basename(output) + ".dists" if update_db: # Check new sequences pass QC before adding them - if not qcPass: + if len(set(seq_names_passing).difference(rNames + qNames)) > 0: sys.stderr.write("Queries contained outlier distances, " "not updating database\n") else: @@ -234,24 +262,26 @@ def assign_query(dbFuncs, joinDBs(ref_db, output, output, {"threads": threads, "strand_preserved": strand_preserved}) if model.type == 'lineage': - genomeNetwork[min(model.ranks)].save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt') - else: - genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt') - - # Update distance matrices with all calculated distances - if distances == None: - distanceFiles = ref_db + "/" + os.path.basename(ref_db) + ".dists" + save_network(genomeNetwork[min(model.ranks)], prefix = output, suffix = '_graph', use_gpu = gpu_graph) + # Save sparse distance matrices and updated model + model.outPrefix = os.path.basename(output) + model.save() else: - distanceFiles = distances + save_network(genomeNetwork, prefix = output, suffix = '_graph', use_gpu = gpu_graph) - refList, refList_copy, self, rrDistMat = readPickle(distanceFiles, - enforce_self = True) + # Load the previous distances + refList_loaded, refList_copy, self, rrDistMat = \ + readPickle(distances, + enforce_self = True) + # This should now always be true, otherwise both qrDistMat and sparse matrix + # may need reordering + assert(refList_loaded == rNames) combined_seq, core_distMat, acc_distMat = \ - update_distance_matrices(refList, rrDistMat, - queryList, qrDistMat, - qqDistMat, threads = threads) - assert combined_seq == refList + queryList + update_distance_matrices(rNames, rrDistMat, + qNames, qrDistMat, + qqDistMat, threads = threads) + assert combined_seq == rNames + qNames # Get full distance matrix and save complete_distMat = \ @@ -259,14 +289,24 @@ def 
assign_query(dbFuncs, pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1))) storePickle(combined_seq, combined_seq, True, complete_distMat, dists_out) + # Copy model if needed + if output != model.outPrefix: + model.copy(output) + # Clique pruning if model.type != 'lineage': - dbOrder = refList + queryList + dbOrder = rNames + qNames newRepresentativesIndices, newRepresentativesNames, \ newRepresentativesFile, genomeNetwork = \ - extractReferences(genomeNetwork, dbOrder, output, refList, threads = threads) + extractReferences(genomeNetwork, + dbOrder, + output, + existingRefs = rNames, + type_isolate = qc_dict['type_isolate'], + threads = threads, + use_gpu = gpu_graph) # intersection that maintains order - newQueries = [x for x in queryList if x in frozenset(newRepresentativesNames)] + newQueries = [x for x in qNames if x in frozenset(newRepresentativesNames)] # could also have newRepresentativesNames in this diff (should be the same) - but want # to ensure consistency with the network in case of bad input/bugs @@ -278,20 +318,20 @@ def assign_query(dbFuncs, postpruning_combined_seq, newDistMat = \ prune_distance_matrix(combined_seq, names_to_remove, complete_distMat, output + "/" + os.path.basename(output) + ".refs.dists") - genomeNetwork.save(output + "/" + os.path.basename(output) + '.refs_graph.gt', fmt = 'gt') + save_network(genomeNetwork, prefix = output, suffix = 'refs_graph', use_gpu = gpu_graph) removeFromDB(output, output, names_to_remove) os.rename(output + "/" + os.path.basename(output) + ".tmp.h5", - output + "/" + os.path.basename(output) + ".refs.h5") + output + "/" + os.path.basename(output) + ".refs.h5") # ensure sketch and distMat order match - assert postpruning_combined_seq == refList + newQueries + assert postpruning_combined_seq == rNames + newQueries else: - storePickle(refList, queryList, False, qrDistMat, dists_out) + storePickle(rNames, qNames, False, qrDistMat, dists_out) if save_partial_query_graph: if model.type == 'lineage': - genomeNetwork[min(model.ranks)].save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt') + save_network(genomeNetwork[min(model.ranks)], prefix = output, suffix = '_graph', use_gpu = gpu_graph) else: - genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt') + save_network(genomeNetwork, prefix = output, suffix = '_graph', use_gpu = gpu_graph) return(isolateClustering) @@ -337,9 +377,28 @@ def get_options(): 'k-mers [default = use canonical k-mers]') # qc options - qcGroup = parser.add_argument_group('Quality control options') + qcGroup = parser.add_argument_group('Quality control options for distances') + qcGroup.add_argument('--qc-filter', help='Behaviour following sequence QC step: "stop" [default], "prune"' + ' (analyse data passing QC), or "continue" (analyse all data)', + default='stop', type = str, choices=['stop', 'prune', 'continue']) + qcGroup.add_argument('--retain-failures', help='Retain sketches of genomes that do not pass QC filters in ' + 'separate database [default = False]', default=False, action='store_true') qcGroup.add_argument('--max-a-dist', help='Maximum accessory distance to permit [default = 0.5]', default = 0.5, type = float) + qcGroup.add_argument('--max-pi-dist', help='Maximum core distance to permit [default = 0.5]', + default = 0.5, type = float) + qcGroup.add_argument('--type-isolate', help='Isolate from which distances can be calculated for pruning [default = None]', + default = None, type = str) + qcGroup.add_argument('--length-sigma', help='Number 
of standard deviations of length distribution beyond ' + 'which sequences will be excluded [default = 5]', default = None, type = int) + qcGroup.add_argument('--length-range', help='Allowed length range, outside of which sequences will be excluded ' + '[two values needed - lower and upper bounds]', default=[None,None], + type = int, nargs = 2) + qcGroup.add_argument('--prop-n', help='Threshold ambiguous base proportion above which sequences will be excluded' + ' [default = 0.1]', default = None, + type = float) + qcGroup.add_argument('--upper-n', help='Threshold ambiguous base count above which sequences will be excluded', + default=None, type = int) # sequence querying queryingGroup = parser.add_argument_group('Database querying options') @@ -360,6 +419,7 @@ def get_options(): other.add_argument('--threads', default=1, type=int, help='Number of threads to use [default = 1]') other.add_argument('--gpu-sketch', default=False, action='store_true', help='Use a GPU when calculating sketches (read data only) [default = False]') other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]') + other.add_argument('--gpu-graph', default=False, action='store_true', help='Use a GPU when constructing networks [default = False]') other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]') other.add_argument('--version', action='version', version='%(prog)s '+__version__) @@ -402,7 +462,37 @@ def main(): from .utils import setupDBFuncs # Dict of QC options for passing to database construction and querying functions - qc_dict = {'run_qc': False } + if args.length_sigma is None and None in args.length_range and args.prop_n is None \ + and args.upper_n is None and args.max_a_dist is None and args.max_pi_dist is None: + qc_dict = {'run_qc': False, 'type_isolate': None } + else: + # define defaults if one QC parameter given + # length_sigma + if args.length_sigma is not None: + length_sigma = args.length_sigma + elif None in args.length_range: + length_sigma = 5 # default used in __main__ + else: + length_sigma = None + # prop_n + if args.prop_n is not None: + prop_n = args.prop_n + elif args.upper_n is None: + prop_n = 0.1 # default used in __main__ + else: + prop_n = None + qc_dict = { + 'run_qc': True, + 'qc_filter': args.qc_filter, + 'retain_failures': args.retain_failures, + 'length_sigma': length_sigma, + 'length_range': args.length_range, + 'prop_n': prop_n, + 'upper_n': args.upper_n, + 'max_pi_dist': args.max_pi_dist, + 'max_a_dist': args.max_a_dist, + 'type_isolate': args.type_isolate + } # Dict of DB access functions for assign_query (which is out of scope) dbFuncs = setupDBFuncs(args, args.min_kmer_count, qc_dict) @@ -416,7 +506,7 @@ def main(): setGtThreads(args.threads) if args.distances is None: - distances = os.path.basename(args.db) + "/" + args.db + ".dists" + distances = args.db + "/" + os.path.basename(args.db) + ".dists" else: distances = args.distances @@ -430,6 +520,7 @@ def main(): args.db, args.query, args.output, + qc_dict, args.update_db, args.write_references, distances, @@ -438,12 +529,18 @@ def main(): args.plot_fit, args.graph_weights, args.max_a_dist, + args.max_pi_dist, + args.type_isolate, args.model_dir, args.strand_preserved, args.previous_clustering, args.external_clustering, args.core_only, args.accessory_only, + args.gpu_sketch, + args.gpu_dist, + args.gpu_graph, + args.deviceid, web=False, json_sketch=None, save_partial_query_graph=False) diff --git 
a/PopPUNK/models.py b/PopPUNK/models.py index 0e4d77c5..8a68917f 100644 --- a/PopPUNK/models.py +++ b/PopPUNK/models.py @@ -33,6 +33,14 @@ sys.stderr.write("This version of PopPUNK requires python v3.8 or higher\n") sys.exit(0) +# GPU support +try: + import cugraph + import cudf + gpu_lib = True +except ImportError as e: + gpu_lib = False + import pp_sketchlib import poppunk_refine @@ -252,6 +260,12 @@ def no_scale(self): is done in the scaled space). ''' self.scale = np.array([1, 1], dtype = self.default_dtype) + + def copy(self, prefix): + """Copy the model to a new directory + """ + self.outPrefix = prefix + self.save() class BGMMFit(ClusterFit): @@ -677,7 +691,7 @@ def __init__(self, outPrefix): self.unconstrained = False def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indiv_refine = False, - unconstrained = False, score_idx = 0, no_local = False): + unconstrained = False, score_idx = 0, no_local = False, use_gpu = False): '''Extends :func:`~ClusterFit.fit` Fits the distances by optimising network score, by calling @@ -700,11 +714,9 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi startFile (str) A file defining an initial fit, rather than one from ``--fit-model``. See documentation for format. - (default = None). indiv_refine (bool) Run refinement for core and accessory distances separately - (default = False). unconstrained (bool) If True, search in 2D and change the slope of the boundary @@ -714,6 +726,9 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi no_local (bool) Turn off the local optimisation step. Quicker, but may be less well refined. + use_gpu (bool) + Whether to use cugraph for graph analyses + Returns: y (numpy.array) Cluster assignments of samples in X @@ -724,6 +739,11 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi self.min_move = min_move self.unconstrained = unconstrained + # load CUDA libraries + if use_gpu and not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + # Get starting point model.no_scale() if startFile: @@ -761,7 +781,7 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi refineFit(X/self.scale, sample_names, self.start_s, self.mean0, self.mean1, self.max_move, self.min_move, slope = 2, score_idx = score_idx, unconstrained = unconstrained, - no_local = no_local, num_processes = self.threads) + no_local = no_local, num_processes = self.threads, use_gpu = use_gpu) self.fitted = True # Try and do a 1D refinement for both core and accessory @@ -779,7 +799,8 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi start_point, acc_core, self.accessory_boundary, self.min_move, self.max_move = \ refineFit(X/self.scale, sample_names, self.start_s,self.mean0, self.mean1, self.max_move, self.min_move, - slope = 1, score_idx = score_idx, no_local = no_local, num_processes = self.threads) + slope = 1, score_idx = score_idx, no_local = no_local, num_processes = self.threads, + use_gpu = use_gpu) self.indiv_fitted = True except RuntimeError as e: print(e) @@ -1028,7 +1049,7 @@ def load(self, fit_npz, fit_obj): self.nn_dists = fit_npz self.fitted = True - def plot(self, X): + def plot(self, X, y = None): '''Extends :func:`~ClusterFit.plot` Write a summary of the fit, and plot the results using @@ -1037,6 +1058,9 @@ def plot(self, X): Args: X (numpy.array) Core and accessory distances + y (any) + Unused variable for compatibility with other + 
plotting functions ''' ClusterFit.plot(self, X) for rank in self.ranks: diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 4a094728..34354c1f 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -19,9 +19,18 @@ from collections import defaultdict, Counter from functools import partial from multiprocessing import Pool +import pickle import graph_tool.all as gt import dendropy +# GPU support +try: + import cugraph + import cudf + gpu_lib = True +except ImportError as e: + gpu_lib = False + from .__main__ import accepted_weights_types from .sketchlib import addRandom @@ -36,7 +45,7 @@ from .unwords import gen_unword def fetchNetwork(network_dir, model, refList, ref_graph = False, - core_only = False, accessory_only = False): + core_only = False, accessory_only = False, use_gpu = False): """Load the network based on input options Returns the network as a graph-tool format graph, and sets @@ -54,12 +63,12 @@ def fetchNetwork(network_dir, model, refList, ref_graph = False, [default = False] core_only (bool) Return the network created using only core distances - [default = False] accessory_only (bool) Return the network created using only accessory distances - [default = False] + use_gpu (bool) + Use cugraph library to load graph Returns: genomeNetwork (graph) @@ -69,33 +78,93 @@ def fetchNetwork(network_dir, model, refList, ref_graph = False, """ # If a refined fit, may use just core or accessory distances dir_prefix = network_dir + "/" + os.path.basename(network_dir) + + # load CUDA libraries + if use_gpu and not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + + if use_gpu: + graph_suffix = '.csv.gz' + else: + graph_suffix = '.gt' + if core_only and model.type == 'refine': model.slope = 0 - network_file = dir_prefix + '_core_graph.gt' + network_file = dir_prefix + '_core_graph' + graph_suffix cluster_file = dir_prefix + '_core_clusters.csv' elif accessory_only and model.type == 'refine': model.slope = 1 - network_file = dir_prefix + '_accessory_graph.gt' + network_file = dir_prefix + '_accessory_graph' + graph_suffix cluster_file = dir_prefix + '_accessory_clusters.csv' else: - if ref_graph and os.path.isfile(dir_prefix + '.refs_graph.gt'): - network_file = dir_prefix + '.refs_graph.gt' + if ref_graph and os.path.isfile(dir_prefix + '.refs_graph' + graph_suffix): + network_file = dir_prefix + '.refs_graph' + graph_suffix else: - network_file = dir_prefix + '_graph.gt' + network_file = dir_prefix + '_graph' + graph_suffix cluster_file = dir_prefix + '_clusters.csv' if core_only or accessory_only: sys.stderr.write("Can only do --core-only or --accessory-only fits from " "a refined fit. Using the combined distances.\n") - genomeNetwork = gt.load_graph(network_file) - sys.stderr.write("Network loaded: " + str(len(list(genomeNetwork.vertices()))) + " samples\n") + # Load network file + genomeNetwork = load_network_file(network_file, use_gpu = use_gpu) # Ensure all in dists are in final network - networkMissing = set(map(str,set(range(len(refList))).difference(list(genomeNetwork.vertices())))) - if len(networkMissing) > 0: - sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n") + checkNetworkVertexCount(refList, genomeNetwork, use_gpu) + + return genomeNetwork, cluster_file + +def load_network_file(fn, use_gpu = False): + """Load the network based on input options + + Returns the network as a graph-tool format graph, and sets + the slope parameter of the passed model object. 
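The file convention introduced here is that cugraph networks are stored as gzipped edge-list CSVs ('.csv.gz') while graph-tool networks keep the binary '.gt' format. A condensed, illustrative sketch of the two load paths wrapped by load_network_file below; it assumes graph-tool and, for the GPU path, cudf/cugraph are importable:

import graph_tool.all as gt

def load_network_sketch(fn, use_gpu=False):
    # Hypothetical helper mirroring load_network_file(): rebuild a network from disk
    if use_gpu:
        import cudf, cugraph
        edge_df = cudf.read_csv(fn, compression='gzip')
        G = cugraph.Graph()
        if 'weights' in edge_df.columns:
            edge_df.columns = ['source', 'destination', 'weights']
            G.from_cudf_edgelist(edge_df, edge_attr='weights', renumber=False)
        else:
            edge_df.columns = ['source', 'destination']
            G.from_cudf_edgelist(edge_df, renumber=False)
    else:
        G = gt.load_graph(fn)   # binary .gt file written by graph-tool
    return G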
+ + Args: + fn (str) + Network file name + use_gpu (bool) + Use cugraph library to load graph + + Returns: + genomeNetwork (graph) + The loaded network + """ + # Load the network from the specified file + if use_gpu: + G_df = cudf.read_csv(fn, compression = 'gzip') + genomeNetwork = cugraph.Graph() + if 'weights' in G_df.columns: + G_df.columns = ['source','destination','weights'] + genomeNetwork.from_cudf_edgelist(G_df, edge_attr='weights', renumber=False) + else: + G_df.columns = ['source','destination'] + genomeNetwork.from_cudf_edgelist(G_df,renumber=False) + sys.stderr.write("Network loaded: " + str(genomeNetwork.number_of_vertices()) + " samples\n") + else: + genomeNetwork = gt.load_graph(fn) + sys.stderr.write("Network loaded: " + str(len(list(genomeNetwork.vertices()))) + " samples\n") + + return genomeNetwork - return (genomeNetwork, cluster_file) +def checkNetworkVertexCount(seq_list, G, use_gpu): + """Checks the number of network vertices matches the number + of sequence names. + + Args: + seq_list (list) + The list of sequence names + G (graph) + The network of sequences + use_gpu (bool) + Whether to use cugraph for graph analyses + """ + vertex_list = set(get_vertex_list(G, use_gpu = use_gpu)) + networkMissing = set(set(range(len(seq_list))).difference(vertex_list)) + if len(networkMissing) > 0: + sys.stderr.write("ERROR: " + str(len(networkMissing)) + " samples are missing from the final network\n") + sys.exit(1) def getCliqueRefs(G, reference_indices = set()): """Recursively prune a network of its cliques. Returns one vertex from @@ -140,7 +209,8 @@ def cliquePrune(component, graph, reference_indices, components_list): ref_list = getCliqueRefs(subgraph, refs) return(list(ref_list)) -def extractReferences(G, dbOrder, outPrefix, existingRefs = None, threads = 1): +def extractReferences(G, dbOrder, outPrefix, type_isolate = None, + existingRefs = None, threads = 1, use_gpu = False): """Extract references for each cluster based on cliques Writes chosen references to file by calling :func:`~writeReferences` @@ -152,8 +222,12 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None, threads = 1): The order of files in the sketches, so returned references are in the same order outPrefix (str) Prefix for output file (.refs will be appended) + type_isolate (str) + Isolate to be included in set of references existingRefs (list) References that should be used for each clique + use_gpu (bool) + Use cugraph for graph analysis (default = False) Returns: refFileName (str) @@ -169,80 +243,168 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None, threads = 1): index_lookup = {v:k for k,v in enumerate(dbOrder)} reference_indices = set([index_lookup[r] for r in references]) - # Each component is independent, so can be multithreaded - components = gt.label_components(G)[0].a - - # Turn gt threading off and on again either side of the parallel loop - if gt.openmp_enabled(): - gt.openmp_set_num_threads(1) - - # Cliques are pruned, taking one reference from each, until none remain - with Pool(processes=threads) as pool: - ref_lists = pool.map(partial(cliquePrune, - graph=G, - reference_indices=reference_indices, - components_list=components), - set(components)) - # Returns nested lists, which need to be flattened - reference_indices = set([entry for sublist in ref_lists for entry in sublist]) - - if gt.openmp_enabled(): - gt.openmp_set_num_threads(threads) - - # Use a vertex filter to extract the subgraph of refences - # as a graphview - reference_vertex = 
G.new_vertex_property('bool') - for n, vertex in enumerate(G.vertices()): - if n in reference_indices: - reference_vertex[vertex] = True + # Add type isolate, if necessary + type_isolate_index = None + if type_isolate is not None: + if type_isolate in dbOrder: + type_isolate_index = dbOrder.index(type_isolate) else: - reference_vertex[vertex] = False - G_ref = gt.GraphView(G, vfilt = reference_vertex) - G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object - - # Find any clusters which are represented by >1 references - # This creates a dictionary: cluster_id: set(ref_idx in cluster) - clusters_in_full_graph = printClusters(G, dbOrder, printCSV=False) - reference_clusters_in_full_graph = defaultdict(set) - for reference_index in reference_indices: - reference_clusters_in_full_graph[clusters_in_full_graph[dbOrder[reference_index]]].add(reference_index) - - # Calculate the component membership within the reference graph - ref_order = [name for idx, name in enumerate(dbOrder) if idx in frozenset(reference_indices)] - clusters_in_reference_graph = printClusters(G_ref, ref_order, printCSV=False) - # Record the components/clusters the references are in the reference graph - # dict: name: ref_cluster - reference_clusters_in_reference_graph = {} - for reference_name in ref_order: - reference_clusters_in_reference_graph[reference_name] = clusters_in_reference_graph[reference_name] - - # Check if multi-reference components have been split as a validation test - # First iterate through clusters - network_update_required = False - for cluster_id, ref_idxs in reference_clusters_in_full_graph.items(): - # Identify multi-reference clusters by this length - if len(ref_idxs) > 1: - check = list(ref_idxs) - # check if these are still in the same component in the reference graph - for i in range(len(check)): - component_i = reference_clusters_in_reference_graph[dbOrder[check[i]]] - for j in range(i + 1, len(check)): - # Add intermediate nodes - component_j = reference_clusters_in_reference_graph[dbOrder[check[j]]] - if component_i != component_j: - network_update_required = True - vertex_list, edge_list = gt.shortest_path(G, check[i], check[j]) - # update reference list - for vertex in vertex_list: - reference_vertex[vertex] = True - reference_indices.add(int(vertex)) - - # update reference graph if vertices have been added - if network_update_required: + sys.stderr.write('Type isolate ' + type_isolate + ' not found\n') + sys.exit(1) + + if use_gpu: + if not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + + # For large network, use more approximate method for extracting references + reference = {} + # Record the original components to which sequences belonged + component_assignments = cugraph.components.connectivity.connected_components(G) + # Leiden method has resolution parameter - higher values give greater precision + partition_assignments, score = cugraph.leiden(G, resolution = 0.1) + # group by partition, which becomes the first column, so retrieve second column + reference_index_df = partition_assignments.groupby('partition').nth(0) + reference_indices = reference_index_df['vertex'].to_arrow().to_pylist() + + # Add type isolate if necessary - before edges are added + if type_isolate_index is not None and type_isolate_index not in reference_indices: + reference_indices.append(type_isolate_index) + + # Order found references as in sketchlib database + reference_names = [dbOrder[int(x)] for x in 
sorted(reference_indices)] + refFileName = writeReferences(reference_names, outPrefix) + + # Extract reference edges + G_df = G.view_edge_list() + if 'src' in G_df.columns: + G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True) + G_ref_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] + # Add self-loop if needed + max_in_vertex_labels = max(reference_indices) + G_ref = add_self_loop(G_ref_df,max_in_vertex_labels, renumber = False) + + # Check references in same component in overall graph are connected in the reference graph + # First get components of original reference graph + reference_component_assignments = cugraph.components.connectivity.connected_components(G_ref) + reference_component_assignments.rename(columns={'labels': 'ref_labels'}, inplace=True) + # Merge with component assignments from overall graph + combined_vertex_assignments = reference_component_assignments.merge(component_assignments, + on = 'vertex', + how = 'left') + combined_vertex_assignments = combined_vertex_assignments[combined_vertex_assignments['vertex'].isin(reference_indices)] + # Find the number of components in the reference graph associated with each component in the overall graph - + # should be one if there is a one-to-one mapping of components - else links need to be added + max_ref_comp_count = combined_vertex_assignments.groupby(['labels'], sort = False)['ref_labels'].nunique().max() + if max_ref_comp_count > 1: + # Iterate through components + for component, component_df in combined_vertex_assignments.groupby(['labels'], sort = False): + # Find components in the overall graph matching multiple components in the reference graph + if component_df.groupby(['labels'], sort = False)['ref_labels'].nunique().iloc[0] > 1: + # Make a graph of the component from the overall graph + vertices_in_component = component_assignments[component_assignments['labels']==component]['vertex'] + references_in_component = vertices_in_component[vertices_in_component.isin(reference_indices)].values + G_component_df = G_df[G_df['source'].isin(vertices_in_component) & G_df['destination'].isin(vertices_in_component)] + G_component = cugraph.Graph() + G_component.from_cudf_edgelist(G_component_df) + # Find single shortest path from a reference to all other nodes in the component + traversal = cugraph.traversal.sssp(G_component,source = references_in_component[0]) + reference_index_set = set(reference_indices) + # Add predecessors to reference sequences on the SSSPs + predecessor_list = traversal[traversal['vertex'].isin(reference_indices)]['predecessor'].values + predecessors = set(predecessor_list[predecessor_list >= 0].flatten().tolist()) + # Add predecessors to reference set and check whether this results in complete paths + # where complete paths are indicated by references' predecessors being within the set of + # references + while len(predecessors) > 0 and len(predecessors - reference_index_set) > 0: + reference_index_set = reference_index_set.union(predecessors) + predecessor_list = traversal[traversal['vertex'].isin(reference_indices)]['predecessor'].values + predecessors = set(predecessor_list[predecessor_list >= 0].flatten().tolist()) + # Add expanded reference set to the overall list + reference_indices = list(reference_index_set) + # Create new reference graph + G_ref_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] + G_ref = add_self_loop(G_ref_df, max_in_vertex_labels, renumber = False) + + else: + # Each 
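The GPU branch above replaces exact clique pruning with a faster approximation: one representative is kept per Leiden partition, and shortest-path predecessors are then added until references in the same original component are reconnected. A minimal sketch of the partition-representative step, assuming cudf/cugraph are available:

import cugraph

def leiden_representatives(G, resolution=0.1):
    # One vertex per Leiden partition, as in the cugraph branch of extractReferences()
    partition_assignments, _ = cugraph.leiden(G, resolution=resolution)
    first_per_partition = partition_assignments.groupby('partition').nth(0)
    return first_per_partition['vertex'].to_arrow().to_pylist()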
component is independent, so can be multithreaded + components = gt.label_components(G)[0].a + + # Turn gt threading off and on again either side of the parallel loop + if gt.openmp_enabled(): + gt.openmp_set_num_threads(1) + + # Cliques are pruned, taking one reference from each, until none remain + with Pool(processes=threads) as pool: + ref_lists = pool.map(partial(cliquePrune, + graph=G, + reference_indices=reference_indices, + components_list=components), + set(components)) + # Returns nested lists, which need to be flattened + reference_indices = set([entry for sublist in ref_lists for entry in sublist]) + + # Add type isolate if necessary - before edges are added + if type_isolate_index is not None and type_isolate_index not in reference_indices: + reference_indices.add(type_isolate_index) + + if gt.openmp_enabled(): + gt.openmp_set_num_threads(threads) + + # Use a vertex filter to extract the subgraph of refences + # as a graphview + reference_vertex = G.new_vertex_property('bool') + for n, vertex in enumerate(G.vertices()): + if n in reference_indices: + reference_vertex[vertex] = True + else: + reference_vertex[vertex] = False G_ref = gt.GraphView(G, vfilt = reference_vertex) G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object - # Order found references as in mash sketch files + # Find any clusters which are represented by >1 references + # This creates a dictionary: cluster_id: set(ref_idx in cluster) + clusters_in_full_graph = printClusters(G, dbOrder, printCSV=False) + reference_clusters_in_full_graph = defaultdict(set) + for reference_index in reference_indices: + reference_clusters_in_full_graph[clusters_in_full_graph[dbOrder[reference_index]]].add(reference_index) + + # Calculate the component membership within the reference graph + ref_order = [name for idx, name in enumerate(dbOrder) if idx in frozenset(reference_indices)] + clusters_in_reference_graph = printClusters(G_ref, ref_order, printCSV=False) + # Record the components/clusters the references are in the reference graph + # dict: name: ref_cluster + reference_clusters_in_reference_graph = {} + for reference_name in ref_order: + reference_clusters_in_reference_graph[reference_name] = clusters_in_reference_graph[reference_name] + + # Check if multi-reference components have been split as a validation test + # First iterate through clusters + network_update_required = False + for cluster_id, ref_idxs in reference_clusters_in_full_graph.items(): + # Identify multi-reference clusters by this length + if len(ref_idxs) > 1: + check = list(ref_idxs) + # check if these are still in the same component in the reference graph + for i in range(len(check)): + component_i = reference_clusters_in_reference_graph[dbOrder[check[i]]] + for j in range(i + 1, len(check)): + # Add intermediate nodes + component_j = reference_clusters_in_reference_graph[dbOrder[check[j]]] + if component_i != component_j: + network_update_required = True + vertex_list, edge_list = gt.shortest_path(G, check[i], check[j]) + # update reference list + for vertex in vertex_list: + reference_vertex[vertex] = True + reference_indices.add(int(vertex)) + + # update reference graph if vertices have been added + if network_update_required: + G_ref = gt.GraphView(G, vfilt = reference_vertex) + G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object + + # Order found references as in sketch files reference_names = [dbOrder[int(x)] for x in 
sorted(reference_indices)] refFileName = writeReferences(reference_names, outPrefix) return reference_indices, reference_names, refFileName, G_ref @@ -268,9 +430,83 @@ def writeReferences(refList, outPrefix): return refFileName +def network_to_edges(prev_G_fn, rlist, previous_pkl = None, weights = False, + use_gpu = False): + """Load previous network, extract the edges to match the + vertex order specified in rlist, and also return weights if specified. + + Args: + prev_G_fn (str) + Path of file containing existing network. + rlist (list) + List of reference sequence labels in new network + previous_pkl (str) + Path of pkl file containing names of sequences in + previous network + weights (bool) + Whether to return edge weights + (default = False) + use_gpu (bool) + Whether to use cugraph for graph analyses + + Returns: + source_ids (list) + Source nodes for each edge + target_ids (list) + Target nodes for each edge + edge_weights (list) + Weights for each new edge + """ + # get list for translating node IDs to rlist + prev_G = load_network_file(prev_G_fn, use_gpu = use_gpu) + + # load list of names in previous network + if previous_pkl is not None: + with open(previous_pkl, 'rb') as pickle_file: + old_rlist, old_qlist, self = pickle.load(pickle_file) + if self: + old_ids = old_rlist + else: + old_ids = old_rlist + old_qlist + else: + sys.stderr.write('Missing .pkl file containing names of sequences in ' + 'previous network\n') + sys.exit(1) + + # Get edges as lists of source,destination,weight using original IDs + if use_gpu: + G_df = prev_G.view_edge_list() + if weights: + G_df.columns = ['source','destination','weight'] + edge_weights = G_df['weight'].to_arrow().to_pylist() + else: + G_df.columns = ['source','destination'] + old_source_ids = G_df['source'].to_arrow().to_pylist() + old_target_ids = G_df['destination'].to_arrow().to_pylist() + else: + # get the source and target nodes + old_source_ids = gt.edge_endpoint_property(prev_G, prev_G.vertex_index, "source") + old_target_ids = gt.edge_endpoint_property(prev_G, prev_G.vertex_index, "target") + # get the weights + if weights: + edge_weights = list(prev_G.ep['weight']) + + # Update IDs to new versions + old_id_indices = [rlist.index(x) for x in old_ids] + # translate to indices + source_ids = [old_id_indices[x] for x in old_source_ids] + target_ids = [old_id_indices[x] for x in old_target_ids] + + # return values + if weights: + return source_ids, target_ids, edge_weights + else: + return source_ids, target_ids + def constructNetwork(rlist, qlist, assignments, within_label, summarise = True, edge_list = False, weights = None, - weights_type = 'euclidean', sparse_input = None): + weights_type = 'euclidean', sparse_input = None, + previous_network = None, previous_pkl = None, use_gpu = False): """Construct an unweighted, undirected network without self-loops. 
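A small worked example of the index translation performed by network_to_edges() above: vertex indices from the previous network (numbered by its own vertex order) are mapped onto positions in the new rlist ordering before the edges are re-added. All names and values here are invented for illustration:

old_ids = ['s1', 's2', 's3']         # names in the previous network, in its vertex order
rlist = ['s3', 's1', 's2', 's4']     # ordering of names in the new network
old_id_indices = [rlist.index(x) for x in old_ids]            # -> [1, 2, 0]
old_source_ids = [0, 1]              # edge endpoints, old numbering
old_target_ids = [2, 2]
source_ids = [old_id_indices[x] for x in old_source_ids]      # -> [1, 2]
target_ids = [old_id_indices[x] for x in old_target_ids]      # -> [0, 0]
# i.e. the old edges (s1, s3) and (s2, s3) become (1, 0) and (2, 0) in the new ordering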
Nodes are samples and edges where samples are within the same cluster @@ -299,6 +535,13 @@ def constructNetwork(rlist, qlist, assignments, within_label, accessory or euclidean distance sparse_input (numpy.array) Sparse distance matrix from lineage fit + previous_network (str) + Name of file containing a previous network to be integrated into this new + network + previous_pkl (str) + Name of file containing the names of the sequences in the previous_network + use_gpu (bool) + Whether to use GPUs for network construction Returns: G (graph) @@ -350,25 +593,77 @@ def constructNetwork(rlist, qlist, assignments, within_label, edge_tuple = (ref, query) connections.append(edge_tuple) - # build the graph - G = gt.Graph(directed = False) - G.add_vertex(len(vertex_labels)) + # read previous graph + if previous_network is not None: + if previous_pkl is not None: + if weights is not None or sparse_input is not None: + extra_sources, extra_targets, extra_weights = network_to_edges(previous_network, + rlist, + previous_pkl = previous_pkl, + weights = True, + use_gpu = use_gpu) + for (ref, query, weight) in zip(extra_sources, extra_targets, extra_weights): + edge_tuple = (ref, query, weight) + connections.append(edge_tuple) + else: + extra_sources, extra_targets = network_to_edges(previous_network, + rlist, + previous_pkl = previous_pkl, + weights = False, + use_gpu = use_gpu) + for (ref, query) in zip(extra_sources, extra_targets): + edge_tuple = (ref, query) + connections.append(edge_tuple) + else: + sys.stderr.write('A distance pkl corresponding to ' + previous_network + ' is required for loading\n') + sys.exit(1) + + # load GPU libraries if necessary + if use_gpu: + + if not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + + # Set memory management for large networks + cudf.set_allocator("managed") + + # create DataFrame using edge tuples + if weights is not None or sparse_input is not None: + G_df = cudf.DataFrame(connections, columns =['source', 'destination', 'weights']) + else: + G_df = cudf.DataFrame(connections, columns =['source', 'destination']) + + # ensure the highest-integer node is included in the edge list + # by adding a self-loop if necessary; see https://github.com/rapidsai/cugraph/issues/1206 + max_in_df = np.amax([G_df['source'].max(),G_df['destination'].max()]) + max_in_vertex_labels = len(vertex_labels)-1 + use_weights = False + if weights is not None: + use_weights = True + G = add_self_loop(G_df, max_in_vertex_labels, weights = use_weights, renumber = False) - if weights is not None or sparse_input is not None: - eweight = G.new_ep("float") - G.add_edge_list(connections, eprops = [eweight]) - G.edge_properties["weight"] = eweight else: - G.add_edge_list(connections) - # add isolate ID to network - vid = G.new_vertex_property('string', - vals = vertex_labels) - G.vp.id = vid + # build the graph + G = gt.Graph(directed = False) + G.add_vertex(len(vertex_labels)) + + if weights is not None or sparse_input is not None: + eweight = G.new_ep("float") + G.add_edge_list(connections, eprops = [eweight]) + G.edge_properties["weight"] = eweight + else: + G.add_edge_list(connections) + + # add isolate ID to network + vid = G.new_vertex_property('string', + vals = vertex_labels) + G.vp.id = vid # print some summaries if summarise: - (metrics, scores) = networkSummary(G) + (metrics, scores) = networkSummary(G, use_gpu = use_gpu) sys.stderr.write("Network summary:\n" + "\n".join(["\tComponents\t\t\t\t" + str(metrics[0]), "\tDensity\t\t\t\t\t" +
"{:.4f}".format(metrics[1]), "\tTransitivity\t\t\t\t" + "{:.4f}".format(metrics[2]), @@ -381,7 +676,7 @@ def constructNetwork(rlist, qlist, assignments, within_label, return G -def networkSummary(G, calc_betweenness=True): +def networkSummary(G, calc_betweenness=True, use_gpu = False): """Provides summary values about the network Args: @@ -389,6 +684,8 @@ def networkSummary(G, calc_betweenness=True): The network of strains from :func:`~constructNetwork` calc_betweenness (bool) Whether to calculate betweenness stats + use_gpu (bool) + Whether to use cugraph for graph analysis Returns: metrics (list) @@ -397,27 +694,62 @@ def networkSummary(G, calc_betweenness=True): scores (list) List of scores """ - component_assignments, component_frequencies = gt.label_components(G) - components = len(component_frequencies) - density = len(list(G.edges()))/(0.5 * len(list(G.vertices())) * (len(list(G.vertices())) - 1)) - transitivity = gt.global_clustering(G)[0] + if use_gpu: + + if not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + + component_assignments = cugraph.components.connectivity.connected_components(G) + component_nums = component_assignments['labels'].unique().astype(int) + components = len(component_nums) + density = G.number_of_edges()/(0.5 * G.number_of_vertices() * G.number_of_vertices() - 1) + triangle_count = cugraph.community.triangle_count.triangles(G) + degree_df = G.in_degree() + triad_count = sum([d * (d - 1) for d in degree_df['degree'].to_pandas()]) + transitivity = 2 * triangle_count/triad_count + else: + component_assignments, component_frequencies = gt.label_components(G) + components = len(component_frequencies) + density = len(list(G.edges()))/(0.5 * len(list(G.vertices())) * (len(list(G.vertices())) - 1)) + transitivity = gt.global_clustering(G)[0] mean_bt = 0 weighted_mean_bt = 0 if calc_betweenness: betweenness = [] sizes = [] - for component, size in enumerate(component_frequencies): - if size > 3: - vfilt = component_assignments.a == component - subgraph = gt.GraphView(G, vfilt=vfilt) - betweenness.append(max(gt.betweenness(subgraph, norm = True)[0].a)) - sizes.append(size) + + if use_gpu: + component_frequencies = component_assignments['labels'].value_counts(sort = True, ascending = False) + for component in component_nums.to_pandas(): + size = component_frequencies[component_frequencies.index == component].iloc[0].astype(int) + if size > 3: + component_vertices = component_assignments['vertex'][component_assignments['labels']==component] + subgraph = cugraph.subgraph(G, component_vertices) + max_betweeness_k = 1000 + if len(component_vertices) >= max_betweeness_k: + component_betweenness = cugraph.betweenness_centrality(subgraph, k = max_betweeness_k) + else: + component_betweenness = cugraph.betweenness_centrality(subgraph) + betweenness.append(component_betweenness['betweenness_centrality'].max()) + sizes.append(size) + else: + for component, size in enumerate(component_frequencies): + if size > 3: + vfilt = component_assignments.a == component + subgraph = gt.GraphView(G, vfilt=vfilt) + betweenness.append(max(gt.betweenness(subgraph, norm = True)[0].a)) + sizes.append(size) if len(betweenness) > 1: mean_bt = np.mean(betweenness) weighted_mean_bt = np.average(betweenness, weights=sizes) + elif len(betweenness) == 1: + mean_bt = betweenness[0] + weighted_mean_bt = betweenness[0] + # Calculate scores metrics = [components, density, transitivity, mean_bt, weighted_mean_bt] base_score = transitivity * (1 - density) scores = 
[base_score, base_score * (1 - metrics[3]), base_score * (1 - metrics[4])] @@ -425,7 +757,8 @@ def networkSummary(G, calc_betweenness=True): def addQueryToNetwork(dbFuncs, rList, qList, G, kmers, assignments, model, queryDB, queryQuery = False, - strand_preserved = False, weights = None, threads = 1): + strand_preserved = False, weights = None, threads = 1, + use_gpu = False): """Finds edges between queries and items in the reference database, and modifies the network to include them. @@ -458,6 +791,8 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers, be annotated as an edge attribute threads (int) Number of threads to use if new db created + use_gpu (bool) + Whether to use cugraph for analysis (default = 1) Returns: @@ -494,14 +829,14 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers, else: sys.stderr.write("Calculating all query-query distances\n") addRandom(queryDB, qList, kmers, strand_preserved, threads = threads) - qlist1, qlist2, qqDistMat = queryDatabase(rNames = qList, - qNames = qList, - dbPrefix = queryDB, - queryPrefix = queryDB, - klist = kmers, - self = True, - number_plot_fits = 0, - threads = threads) + qqDistMat = queryDatabase(rNames = qList, + qNames = qList, + dbPrefix = queryDB, + queryPrefix = queryDB, + klist = kmers, + self = True, + number_plot_fits = 0, + threads = threads) queryAssignation = model.assign(qqDistMat) for row_idx, (assignment, (ref, query)) in enumerate(zip(queryAssignation, listDistInts(qList, qList, self = True))): @@ -524,21 +859,21 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers, # use database construction methods to find links between unassigned queries addRandom(queryDB, qList, kmers, strand_preserved, threads = threads) - qlist1, qlist2, qqDistMat = queryDatabase(rNames = list(unassigned), - qNames = list(unassigned), - dbPrefix = queryDB, - queryPrefix = queryDB, - klist = kmers, - self = True, - number_plot_fits = 0, - threads = threads) + qqDistMat = queryDatabase(rNames = list(unassigned), + qNames = list(unassigned), + dbPrefix = queryDB, + queryPrefix = queryDB, + klist = kmers, + self = True, + number_plot_fits = 0, + threads = threads) queryAssignation = model.assign(qqDistMat) # identify any links between queries and store in the same links dict # links dict now contains lists of links both to original database and new queries # have to use names and link to query list in order to match to node indices - for row_idx, (assignment, (query1, query2)) in enumerate(zip(queryAssignation, iterDistRows(qlist1, qlist2, self = True))): + for row_idx, (assignment, (query1, query2)) in enumerate(zip(queryAssignation, iterDistRows(qList, qList, self = True))): if assignment == model.within_label: if weights is not None: dist = np.linalg.norm(qqDistMat[row_idx, :]) @@ -548,24 +883,89 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers, new_edges.append(edge_tuple) # finish by updating the network - G.add_vertex(len(qList)) + if use_gpu: + + if not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + + # construct updated graph + G_current_df = G.view_edge_list() + if weights is not None: + G_current_df.columns = ['source','destination','weights'] + G_extra_df = cudf.DataFrame(new_edges, columns =['source','destination','weights']) + G_df = cudf.concat([G_current_df,G_extra_df], ignore_index = True) + else: + G_current_df.columns = ['source','destination'] + G_extra_df = cudf.DataFrame(new_edges, columns =['source','destination']) + G_df = cudf.concat([G_current_df,G_extra_df], 
ignore_index = True) + + # use self-loop to ensure all nodes are present + max_in_vertex_labels = ref_count + len(qList) - 1 + include_weights = False + if weights is not None: + include_weights = True + G = add_self_loop(G_df, max_in_vertex_labels, weights = include_weights) - if weights is not None: - eweight = G.new_ep("float") - G.add_edge_list(new_edges, eprops = [eweight]) - G.edge_properties["weight"] = eweight else: - G.add_edge_list(new_edges) + G.add_vertex(len(qList)) - # including the vertex ID property map - for i, q in enumerate(qList): - G.vp.id[i + len(rList)] = q + if weights is not None: + eweight = G.new_ep("float") + G.add_edge_list(new_edges, eprops = [eweight]) + G.edge_properties["weight"] = eweight + else: + G.add_edge_list(new_edges) + + # including the vertex ID property map + for i, q in enumerate(qList): + G.vp.id[i + len(rList)] = q + + return G, qqDistMat + +def add_self_loop(G_df, seq_num, weights = False, renumber = True): + """Adds self-loop to cugraph graph to ensure all nodes are included in + the graph, even if singletons. + + Args: + G_df (cudf) + cudf data frame containing edge list + seq_num (int) + The expected number of nodes in the graph + renumber (bool) + Whether to renumber the vertices when added to the graph + + Returns: + G_new (graph) + Dictionary of cluster assignments (keys are sequence names) + """ + # use self-loop to ensure all nodes are present + min_in_df = np.amin([G_df['source'].min(), G_df['destination'].min()]) + if min_in_df.item() > 0: + G_self_loop = cudf.DataFrame() + G_self_loop['source'] = [0] + G_self_loop['destination'] = [0] + if weights: + G_self_loop['weights'] = 0.0 + G_df = cudf.concat([G_df,G_self_loop], ignore_index = True) + max_in_df = np.amax([G_df['source'].max(),G_df['destination'].max()]) + if max_in_df.item() != seq_num: + G_self_loop = cudf.DataFrame() + G_self_loop['source'] = [seq_num] + G_self_loop['destination'] = [seq_num] + if weights: + G_self_loop['weights'] = 0.0 + G_df = cudf.concat([G_df,G_self_loop], ignore_index = True) + # Construct graph + G_new = cugraph.Graph() + G_new.from_cudf_edgelist(G_df, renumber = renumber) + return G_new - return qqDistMat def printClusters(G, rlist, outPrefix=None, oldClusterFile=None, externalClusterCSV=None, printRef=True, printCSV=True, - clustering_type='combined', write_unwords=True): + clustering_type='combined', write_unwords=True, + use_gpu = False): """Get cluster assignments Also writes assignments to a CSV file @@ -597,6 +997,8 @@ def printClusters(G, rlist, outPrefix=None, oldClusterFile=None, write_unwords (bool) Write clusters with a pronouncable name rather than numerical index Default = True + use_gpu (bool) + Whether to use cugraph for network analysis Returns: clustering (dict) Dictionary of cluster assignments (keys are sequence names) @@ -608,13 +1010,28 @@ def printClusters(G, rlist, outPrefix=None, oldClusterFile=None, write_unwords = False # get a sorted list of component assignments - component_assignments, component_frequencies = gt.label_components(G) - component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int) - newClusters = [set() for rank in range(len(component_frequency_ranks))] - for isolate_index, isolate_name in enumerate(rlist): - component = component_assignments.a[isolate_index] - component_rank = component_frequency_ranks[component] - newClusters[component_rank].add(isolate_name) + if use_gpu: + if not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; 
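# Illustrative sketch of why add_self_loop pads the edge list: cugraph derives
# its vertex set from the edges, so singleton nodes at either end of the index
# range would otherwise be dropped. Assumes a RAPIDS install (cudf/cugraph);
# the calls are the ones used in this patch, the toy edges are made up.
import cudf
import cugraph
edges = cudf.DataFrame({'source': [1, 2], 'destination': [2, 3]})    # vertices 0 and 4 absent
padding = cudf.DataFrame({'source': [0, 4], 'destination': [0, 4]})  # self-loops
G_toy = cugraph.Graph()
G_toy.from_cudf_edgelist(cudf.concat([edges, padding], ignore_index=True), renumber=False)
print(G_toy.number_of_vertices())  # 5: the padded singletons are now present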
exiting\n') + sys.exit(1) + + component_assignments = cugraph.components.connectivity.connected_components(G) + component_frequencies = component_assignments['labels'].value_counts(sort = True, ascending = False) + newClusters = [set() for rank in range(component_frequencies.size)] + for isolate_index, isolate_name in enumerate(rlist): # assume sorted at the moment + component = component_assignments['labels'].iloc[isolate_index].item() + component_rank_bool = component_frequencies.index == component + component_rank = np.argmax(component_rank_bool.to_array()) + newClusters[component_rank].add(isolate_name) + else: + component_assignments, component_frequencies = gt.label_components(G) + component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int) + # use components to determine new clusters + newClusters = [set() for rank in range(len(component_frequency_ranks))] + for isolate_index, isolate_name in enumerate(rlist): + component = component_assignments.a[isolate_index] + component_rank = component_frequency_ranks[component] + newClusters[component_rank].add(isolate_name) oldNames = set() @@ -805,6 +1222,7 @@ def generate_minimum_spanning_tree(G, from_cugraph = False): if "weight" in G.edge_properties: mst_edge_prop_map = gt.min_spanning_tree(G, weights = G.ep["weight"]) mst_network = gt.GraphView(G, efilt = mst_edge_prop_map) + mst_network = gt.Graph(mst_network, prune = True) else: sys.stderr.write("generate_minimum_spanning_tree requires a weighted graph\n") raise RuntimeError("MST passed unweighted graph") @@ -854,3 +1272,48 @@ def generate_minimum_spanning_tree(G, from_cugraph = False): sys.stderr.write("Completed calculation of minimum-spanning tree\n") return mst_network + +def get_vertex_list(G, use_gpu = False): + """Generate a list of node indices + + Args: + G (network) + Graph tool network + use_gpu (bool) + Whether graph is a cugraph or not + [default = False] + + Returns: + vlist (list) + List of integers corresponding to nodes + """ + + if use_gpu: + vlist = range(G.number_of_vertices()) + else: + vlist = list(G.vertices()) + + return vlist + +def save_network(G, prefix = None, suffix = None, use_gpu = False): + """Save a network to disc + + Args: + G (network) + Graph tool network + prefix (str) + Prefix for output file + use_gpu (bool) + Whether graph is a cugraph or not + [default = False] + + """ + file_name = prefix + "/" + os.path.basename(prefix) + if suffix is not None: + file_name = file_name + suffix + if use_gpu: + G.to_pandas_edgelist().to_csv(file_name + '.csv.gz', + compression='gzip', index = False) + else: + G.save(file_name + '.gt', + fmt = 'gt') diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 744e1054..868e2b79 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -409,7 +409,7 @@ def distHistogram(dists, rank, outPrefix): "_rank_" + str(rank) + "_histogram.png") plt.close() -def drawMST(mst, outPrefix, isolate_clustering, overwrite): +def drawMST(mst, outPrefix, isolate_clustering, clustering_name, overwrite): """Plot a layout of the minimum spanning tree Args: @@ -419,6 +419,8 @@ def drawMST(mst, outPrefix, isolate_clustering, overwrite): Output prefix for save files isolate_clustering (dict) Dictionary of ID: cluster, used for colouring vertices + clustering_name (str) + Name of clustering scheme to be used for colouring overwrite (bool) Overwrite existing output files """ @@ -441,12 +443,12 @@ def drawMST(mst, outPrefix, isolate_clustering, overwrite): output=graph1_file_name, 
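# Illustrative sketch of the rankdata step printClusters uses on the CPU path:
# component frequencies become 0-based ranks, largest component first, so the
# biggest component is written out as cluster 1. Made-up frequencies; not part
# of the patch.
import numpy as np
from scipy.stats import rankdata
component_frequencies = np.array([2, 5, 1, 5])
component_frequency_ranks = len(component_frequencies) - \
    rankdata(component_frequencies, method='ordinal').astype(int)
print(component_frequency_ranks)  # [2 1 3 0]: the two size-5 components take the lowest ranks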
output_size=(3000, 3000)) if overwrite or not os.path.isfile(graph2_file_name): cluster_fill = {} - for cluster in set(isolate_clustering['Cluster'].values()): + for cluster in set(isolate_clustering[clustering_name].values()): cluster_fill[cluster] = list(np.random.rand(3)) + [0.9] plot_color = mst.new_vertex_property('vector') mst.vertex_properties['plot_color'] = plot_color for v in mst.vertices(): - plot_color[v] = cluster_fill[isolate_clustering['Cluster'][mst.vp.id[v]]] + plot_color[v] = cluster_fill[isolate_clustering[clustering_name][mst.vp.id[v]]] gt.graph_draw(mst, pos=pos, vertex_fill_color=mst.vertex_properties['plot_color'], output=graph2_file_name, output_size=(3000, 3000)) diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py index ee8e6151..fc3b752b 100644 --- a/PopPUNK/refine.py +++ b/PopPUNK/refine.py @@ -24,6 +24,14 @@ import poppunk_refine import graph_tool.all as gt +# GPU support +try: + import cugraph + import cudf + gpu_lib = True +except ImportError as e: + gpu_lib = False + from .network import constructNetwork from .network import networkSummary @@ -32,7 +40,7 @@ def refineFit(distMat, sample_names, start_s, mean0, mean1, max_move, min_move, slope = 2, score_idx = 0, - unconstrained = False, no_local = False, num_processes = 1): + unconstrained = False, no_local = False, num_processes = 1, use_gpu = False): """Try to refine a fit by maximising a network score based on transitivity and density. Iteratively move the decision boundary to do this, using starting point from existing model. @@ -65,8 +73,10 @@ def refineFit(distMat, sample_names, start_s, mean0, mean1, Quicker, but may be less well refined. num_processes (int) Number of threads to use in the global optimisation step. - (default = 1) + use_gpu (bool) + Whether to use cugraph for graph analyses + Returns: start_point (tuple) (x, y) co-ordinates of starting point @@ -102,29 +112,41 @@ def refineFit(distMat, sample_names, start_s, mean0, mean1, x_max = np.linspace(x_max_start, x_max_end, global_grid_resolution, dtype=np.float32) y_max = np.linspace(y_max_start, y_max_end, global_grid_resolution, dtype=np.float32) - if gt.openmp_enabled(): - gt.openmp_set_num_threads(1) - - with SharedMemoryManager() as smm: - shm_distMat = smm.SharedMemory(size = distMat.nbytes) - distances_shared_array = np.ndarray(distMat.shape, dtype = distMat.dtype, buffer = shm_distMat.buf) - distances_shared_array[:] = distMat[:] - distances_shared = NumpyShared(name = shm_distMat.name, shape = distMat.shape, dtype = distMat.dtype) - - with Pool(processes = num_processes) as pool: - global_s = pool.map(partial(newNetwork2D, - sample_names = sample_names, - distMat = distances_shared, - x_range = x_max, - y_range = y_max, - score_idx = score_idx), - range(global_grid_resolution)) - - if gt.openmp_enabled(): - gt.openmp_set_num_threads(num_processes) - - global_s = list(chain.from_iterable(global_s)) - min_idx = np.argmin(np.array(global_s)) + if use_gpu: + global_s = map(partial(newNetwork2D, + sample_names = sample_names, + distMat = distMat, + x_range = x_max, + y_range = y_max, + score_idx = score_idx, + use_gpu = True), + range(global_grid_resolution)) + else: + if gt.openmp_enabled(): + gt.openmp_set_num_threads(1) + + with SharedMemoryManager() as smm: + shm_distMat = smm.SharedMemory(size = distMat.nbytes) + distances_shared_array = np.ndarray(distMat.shape, dtype = distMat.dtype, buffer = shm_distMat.buf) + distances_shared_array[:] = distMat[:] + distances_shared = NumpyShared(name = shm_distMat.name, shape = distMat.shape, dtype 
= distMat.dtype) + + with Pool(processes = num_processes) as pool: + global_s = pool.map(partial(newNetwork2D, + sample_names = sample_names, + distMat = distances_shared, + x_range = x_max, + y_range = y_max, + score_idx = score_idx, + use_gpu = False), + range(global_grid_resolution)) + + if gt.openmp_enabled(): + gt.openmp_set_num_threads(num_processes) + + global_s = np.array(list(chain.from_iterable(global_s))) + global_s[np.isnan(global_s)] = 1 + min_idx = np.argmin(global_s) optimal_x = x_max[min_idx % global_grid_resolution] optimal_y = y_max[min_idx // global_grid_resolution] @@ -148,7 +170,7 @@ def refineFit(distMat, sample_names, start_s, mean0, mean1, poppunk_refine.thresholdIterate1D(distMat, s_range, slope, start_point[0], start_point[1], mean1[0], mean1[1], num_processes) - global_s = np.array(growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx)) + global_s = np.array(growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx, use_gpu = use_gpu)) global_s[np.isnan(global_s)] = 1 min_idx = np.argmin(np.array(global_s)) if min_idx > 0 and min_idx < len(s_range) - 1: @@ -163,7 +185,9 @@ def refineFit(distMat, sample_names, start_s, mean0, mean1, local_s = scipy.optimize.minimize_scalar(newNetwork, bounds=bounds, method='Bounded', options={'disp': True}, - args = (sample_names, distMat, start_point, mean1, gradient, slope, score_idx)) + args = (sample_names, distMat, start_point, mean1, gradient, + slope, score_idx, num_processes, use_gpu), + ) optimised_s = local_s.x # Convert to x_max, y_max if needed @@ -181,7 +205,7 @@ def refineFit(distMat, sample_names, start_s, mean0, mean1, return start_point, optimal_x, optimal_y, min_move, max_move -def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx, thread_idx = 0): +def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx, thread_idx = 0, use_gpu = False): """Construct a network, then add edges to it iteratively. Input is from ``pp_sketchlib.iterateBoundary1D`` or``pp_sketchlib.iterateBoundary2D`` @@ -202,11 +226,19 @@ def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx, thread_ [default = 0] thread_idx (int) Optional thread idx (if multithreaded) to offset progress bar by + use_gpu (bool) + Whether to use cugraph for graph analyses + Returns: scores (list) -1 * network score for each of x_range. 
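# Illustrative sketch of how the optimum is read back off the flattened score
# grid above: newNetwork2D is mapped over y indices, each returning scores
# across x_range, so index % resolution recovers x and index // resolution
# recovers y. Random stand-in scores; not part of the patch.
import numpy as np
global_grid_resolution = 4
x_max = np.linspace(0.1, 0.4, global_grid_resolution)
y_max = np.linspace(0.2, 0.8, global_grid_resolution)
global_s = np.random.default_rng(0).random(global_grid_resolution ** 2)
min_idx = np.argmin(global_s)
print(x_max[min_idx % global_grid_resolution], y_max[min_idx // global_grid_resolution])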
Where network score is from :func:`~PopPUNK.network.networkSummary` """ + # load CUDA libraries + if use_gpu and not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + scores = [] edge_list = [] prev_idx = 0 @@ -220,12 +252,22 @@ def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx, thread_ # At first offset, make a new network, otherwise just add the new edges if prev_idx == 0: G = constructNetwork(sample_names, sample_names, edge_list, -1, - summarise=False, edge_list=True) + summarise=False, edge_list=True, use_gpu = use_gpu) else: - G.add_edge_list(edge_list) + if use_gpu: + G_current_df = G.view_edge_list() + G_current_df.columns = ['source','destination'] + G_extra_df = cudf.DataFrame(edge_list, columns =['source','destination']) + G_df = cudf.concat([G_current_df,G_extra_df], ignore_index = True) + G = cugraph.Graph() + G.from_cudf_edgelist(G_df) + else: + # Adding edges to network not currently possible with GPU - https://github.com/rapidsai/cugraph/issues/805 + # We add to the cuDF, and then reconstruct the network instead + G.add_edge_list(edge_list) # Add score into vector for any offsets passed (should usually just be one) for s in range(prev_idx, idx): - scores.append(-networkSummary(G, score_idx > 0)[1][score_idx]) + scores.append(-networkSummary(G, score_idx > 0, use_gpu = use_gpu)[1][score_idx]) pbar.update(1) prev_idx = idx edge_list = [] @@ -234,18 +276,23 @@ def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx, thread_ # Add score for final offset(s) at end of loop if prev_idx == 0: G = constructNetwork(sample_names, sample_names, edge_list, -1, - summarise=False, edge_list=True) + summarise=False, edge_list=True, use_gpu = use_gpu) else: - G.add_edge_list(edge_list) + if use_gpu: + G = constructNetwork(sample_names, sample_names, edge_list, -1, + summarise=False, edge_list=True, use_gpu = use_gpu) + else: + # Not currently possible with GPU - https://github.com/rapidsai/cugraph/issues/805 + G.add_edge_list(edge_list) for s in range(prev_idx, len(s_range)): - scores.append(-networkSummary(G, score_idx > 0)[1][score_idx]) + scores.append(-networkSummary(G, score_idx > 0, use_gpu = use_gpu)[1][score_idx]) pbar.update(1) return(scores) def newNetwork(s, sample_names, distMat, start_point, mean1, gradient, - slope=2, score_idx=0, cpus=1): + slope=2, score_idx=0, cpus=1, use_gpu = False): """Wrapper function for :func:`~PopPUNK.network.constructNetwork` which is called by optimisation functions moving a triangular decision boundary. @@ -274,6 +321,9 @@ def newNetwork(s, sample_names, distMat, start_point, mean1, gradient, [default = 0] cpus (int) Number of CPUs to use for calculating assignment + use_gpu (bool) + Whether to use cugraph for graph analysis + Returns: score (float) -1 * network score. 
Where network score is from :func:`~PopPUNK.network.networkSummary` @@ -295,13 +345,14 @@ def newNetwork(s, sample_names, distMat, start_point, mean1, gradient, # Make network boundary_assignments = poppunk_refine.assignThreshold(distMat, slope, x_max, y_max, cpus) - G = constructNetwork(sample_names, sample_names, boundary_assignments, -1, summarise = False) + G = constructNetwork(sample_names, sample_names, boundary_assignments, -1, summarise = False, + use_gpu = use_gpu) # Return score - score = networkSummary(G, score_idx > 0)[1][score_idx] + score = networkSummary(G, score_idx > 0, use_gpu = use_gpu)[1][score_idx] return(-score) -def newNetwork2D(y_idx, sample_names, distMat, x_range, y_range, score_idx=0): +def newNetwork2D(y_idx, sample_names, distMat, x_range, y_range, score_idx=0, use_gpu = False): """Wrapper function for thresholdIterate2D and :func:`growNetwork`. For a given y_max, constructs networks across x_range and returns a list @@ -321,6 +372,9 @@ def newNetwork2D(y_idx, sample_names, distMat, x_range, y_range, score_idx=0): score_idx (int) Index of score from :func:`~PopPUNK.network.networkSummary` to use [default = 0] + use_gpu (bool) + Whether to use cugraph for graph analysis + Returns: scores (list) -1 * network score for each of x_range. @@ -335,7 +389,7 @@ def newNetwork2D(y_idx, sample_names, distMat, x_range, y_range, score_idx=0): y_max = y_range[y_idx] i_vec, j_vec, idx_vec = \ poppunk_refine.thresholdIterate2D(distMat, x_range, y_max) - scores = growNetwork(sample_names, i_vec, j_vec, idx_vec, x_range, score_idx, y_idx) + scores = growNetwork(sample_names, i_vec, j_vec, idx_vec, x_range, score_idx, y_idx, use_gpu = use_gpu) return(scores) def readManualStart(startFile): @@ -415,4 +469,3 @@ def likelihoodBoundary(s, model, start, end, within, between): X = transformLine(s, start, end).reshape(1, -1) responsibilities = model.assign(X, progress = False, values = True) return(responsibilities[0, within] - responsibilities[0, between]) - diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py index 37ead86d..b5099e59 100644 --- a/PopPUNK/sketchlib.py +++ b/PopPUNK/sketchlib.py @@ -388,6 +388,10 @@ def constructDatabase(assemblyList, klist, sketch_size, oPrefix, deviceid (int) GPU device id (default = 0) + Returns: + names (list) + List of names included in the database (some may be pruned due + to QC) """ # read file names names, sequences = readRfile(assemblyList) @@ -417,6 +421,7 @@ def constructDatabase(assemblyList, klist, sketch_size, oPrefix, # QC sequences if qc_dict['run_qc']: filtered_names = sketchlibAssemblyQC(oPrefix, + names, klist, qc_dict, strand_preserved, @@ -517,10 +522,6 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num (default = 0) Returns: - refList (list) - Names of reference sequences - queryList (list) - Names of query sequences distMat (numpy.array) Core distances (column 0) and accessory distances (column 1) between refList and queryList @@ -568,54 +569,58 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num distMat = pp_sketchlib.queryDatabase(ref_db, query_db, rNames, qNames, klist, True, False, threads, use_gpu, deviceid) - return(rNames, qNames, distMat) + return distMat + -def calculateQueryQueryDistances(dbFuncs, qlist, kmers, - queryDB, threads = 1): - """Calculates distances between queries. +def pickTypeIsolate(prefix, names): + """Selects a type isolate as that with a minimal proportion + of missing data. 
Args: - dbFuncs (list) - List of backend functions from :func:`~PopPUNK.utils.setupDBFuncs` - rlist (list) - List of reference names - qlist (list) - List of query names - kmers (list) - List of k-mer sizes - queryDB (str) - Query database location - threads (int) - Number of threads to use if new db created - (default = 1) + prefix (str) + Prefix of output files + names (list) + Names of samples to QC Returns: - qlist1 (list) - Ordered list of queries - distMat (numpy.array) - Query-query distances + type_isolate (str) + Name of isolate selected as reference """ + # open databases + db_name = prefix + '/' + os.path.basename(prefix) + '.h5' + hdf_in = h5py.File(db_name, 'r+') - queryDatabase = dbFuncs['queryDatabase'] + min_prop_n = 1.0 + type_isolate = None - qlist1, qlist2, distMat = queryDatabase(rNames = qlist, - qNames = qlist, - dbPrefix = queryDB, - queryPrefix = queryDB, - klist = kmers, - self = True, - number_plot_fits = 0, - threads = threads) + try: + # process data structures + read_grp = hdf_in['sketches'] + # iterate through sketches + for dataset in read_grp: + if hdf_in['sketches'][dataset].attrs['missing_bases']/hdf_in['sketches'][dataset].attrs['length'] < min_prop_n: + min_prop_n = hdf_in['sketches'][dataset].attrs['missing_bases']/hdf_in['sketches'][dataset].attrs['length'] + type_isolate = dataset + if min_prop_n == 0.0: + break + # if failure still close files to avoid corruption + except: + hdf_in.close() + sys.stderr.write('Problem processing h5 databases during QC - aborting\n') + print("Unexpected error:", sys.exc_info()[0], file = sys.stderr) + raise - return qlist1, distMat + return type_isolate -def sketchlibAssemblyQC(prefix, klist, qc_dict, strand_preserved, threads): +def sketchlibAssemblyQC(prefix, names, klist, qc_dict, strand_preserved, threads): """Calculates random match probability based on means of genomes in assemblyList, and looks for length outliers. 
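# Illustrative sketch of the selection pickTypeIsolate performs: the sample
# whose sketch has the lowest missing-base proportion becomes the type isolate.
# Assumes a PopPUNK HDF5 database laid out as in this patch ('sketches' group
# with 'missing_bases' and 'length' attributes); 'example_db/example_db.h5' is
# a placeholder path, not taken from the patch.
import h5py
type_isolate, min_prop_n = None, 1.0
with h5py.File('example_db/example_db.h5', 'r') as hdf_in:
    for name, sample in hdf_in['sketches'].items():
        prop_n = sample.attrs['missing_bases'] / sample.attrs['length']
        if prop_n < min_prop_n:
            type_isolate, min_prop_n = name, prop_n
print(type_isolate, min_prop_n)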
Args: prefix (str) Prefix of output files + names (list) + Names of samples to QC klist (list) List of k-mer sizes to sketch qc_dict (dict) @@ -647,10 +652,11 @@ def sketchlibAssemblyQC(prefix, klist, qc_dict, strand_preserved, threads): # iterate through sketches for dataset in read_grp: - # test thresholds - remove = False - seq_length[dataset] = hdf_in['sketches'][dataset].attrs['length'] - seq_ambiguous[dataset] = hdf_in['sketches'][dataset].attrs['missing_bases'] + if dataset in names: + # test thresholds + remove = False + seq_length[dataset] = hdf_in['sketches'][dataset].attrs['length'] + seq_ambiguous[dataset] = hdf_in['sketches'][dataset].attrs['missing_bases'] # calculate thresholds # get mean length @@ -734,6 +740,15 @@ def sketchlibAssemblyQC(prefix, klist, qc_dict, strand_preserved, threads): del hdf_in['random'] hdf_in.close() + # This gives back retained in the same order as names + retained = [x for x in names if x in frozenset(retained)] + + # stop if type sequence does not pass QC or is absent + if qc_dict['type_isolate'] is not None and qc_dict['type_isolate'] not in retained: + sys.stderr.write('Type isolate ' + qc_dict['type_isolate'] + ' not found in isolates after QC; check ' + 'name of type isolate and QC options\n') + sys.exit(1) + return retained def fitKmerCurve(pairwise, klist, jacobian): diff --git a/PopPUNK/sparse_mst.py b/PopPUNK/sparse_mst.py index 50a9e4b1..68f8e321 100755 --- a/PopPUNK/sparse_mst.py +++ b/PopPUNK/sparse_mst.py @@ -8,13 +8,22 @@ import pickle import re +import numpy as np import pandas as pd from scipy import sparse +# GPU support +try: + import cugraph + import cudf + gpu_lib = True +except ImportError as e: + gpu_lib = False + # import poppunk package from .__init__ import __version__ -from .network import constructNetwork, generate_minimum_spanning_tree +from .network import constructNetwork, generate_minimum_spanning_tree, network_to_edges from .plot import drawMST from .trees import mst_to_phylogeny, write_tree from .utils import setGtThreads, readIsolateTypeFromCsv @@ -29,9 +38,12 @@ def get_options(): # input options iGroup = parser.add_argument_group('Input files') - iGroup.add_argument('--distances', required=True, help='Prefix of input pickle of pre-calculated distances (required)') iGroup.add_argument('--rank-fit', required=True, help='Location of rank fit, a sparse matrix (*_rank*_fit.npz)') iGroup.add_argument('--previous-clustering', help='CSV file with cluster definitions') + iGroup.add_argument('--previous-mst', help='Graph tool file from which previous MST can be loaded', + default=None) + iGroup.add_argument('--distance-pkl', help='Input pickle from distances, which contains sample names') + iGroup.add_argument('--display-cluster', default=None, help='Column of clustering CSV to use for plotting') # output options oGroup = parser.add_argument_group('Output options') @@ -57,13 +69,25 @@ def main(): args = get_options() import graph_tool.all as gt - try: - import cugraph - import cudf - except ImportError as e: - if args.gpu_graph: - sys.stderr.write("cugraph and cudf unavailable\n") - raise ImportError(e) + # load CUDA libraries + if args.gpu_graph and not gpu_lib: + sys.stderr.write('Unable to load GPU libraries; exiting\n') + sys.exit(1) + + # Read in sample names + if (args.distance_pkl is not None) ^ (args.previous_clustering is not None): + sys.stderr.write("To label strains, both --distance-pkl and --previous-clustering" + " must be provided\n") + sys.exit(1) + elif os.path.exists(args.distance_pkl): + with 
open(args.distance_pkl, 'rb') as pickle_file: + rlist, qlist, self = pickle.load(pickle_file) + if not self: + sys.stderr.write("This script must be run on a full all-v-all model\n") + sys.exit(1) + else: + sys.stderr.write("Cannot find file " + args.distance_pkl + "\n") + sys.exit(1) # Check output path ok if not os.path.isdir(args.output): @@ -74,23 +98,32 @@ def main(): sys.exit(1) setGtThreads(args.threads) - # Read in sample names - with open(args.distances + ".pkl", 'rb') as pickle_file: - rlist, qlist, self = pickle.load(pickle_file) - if not self: - sys.stderr.write("This script must be run on a full all-v-all model\n") - sys.exit(1) - # Create network with sparse dists sys.stderr.write("Loading distances into graph\n") sparse_mat = sparse.load_npz(args.rank_fit) if args.gpu_graph: - G_df = cudf.DataFrame({'source': sparse_mat.row, - 'destination': sparse_mat.col, - 'weights': sparse_mat.data}) + # Load previous MST if specified + if args.previous_mst is not None: + print("Previous: " + str(args.previous_mst)) + extra_sources, extra_targets, extra_weights = network_to_edges(args.previous_mst, + rlist, + previous_pkl = args.distance_pkl, + weights = True, + use_gpu = use_gpu) + sources = np.append(sparse_mat.row, np.asarray(extra_sources)) + targets = np.append(sparse_mat.col, np.asarray(extra_targets)) + weights = np.append(sparse_mat.data, np.asarray(extra_weights)) + else: + sources = sparse_mat.row + targets = sparse_mat.col + weights = sparse_mat.data + G_df = cudf.DataFrame({'source': sources, + 'destination': targets, + 'weights': weights}) G_cu = cugraph.Graph() G_cu.from_cudf_edgelist(G_df, edge_attr='weights', renumber=False) + # Generate minimum spanning tree sys.stderr.write("Calculating MST (GPU part)\n") G_mst = cugraph.minimum_spanning_tree(G_cu, weight='weights') edge_df = G_mst.view_edge_list() @@ -102,11 +135,19 @@ def main(): weights=edge_df['weights'].values_host, summarise=False) else: - G = constructNetwork(rlist, rlist, None, 0, - sparse_input=sparse_mat, summarise=False) + # Load previous MST if specified + if args.previous_mst is not None: + G = constructNetwork(rlist, rlist, None, 0, + sparse_input=sparse_mat, summarise=False, + previous_network = args.previous_mst) + else: + G = constructNetwork(rlist, rlist, None, 0, + sparse_input=sparse_mat, summarise=False) sys.stderr.write("Calculating MST (CPU)\n") mst = generate_minimum_spanning_tree(G, args.gpu_graph) + + # Save output sys.stderr.write("Generating output\n") mst.save(args.output + "/" + os.path.basename(args.output) + ".graphml", fmt="graphml") mst_as_tree = mst_to_phylogeny(mst, rlist) @@ -116,7 +157,7 @@ def main(): if not args.no_plot: if args.previous_clustering != None: mode = "clusters" - if re.match(r"_lineages\.csv$", args.previous_clustering): + if args.previous_clustering.endswith('_lineages.csv'): mode = "lineages" isolateClustering = readIsolateTypeFromCsv(args.previous_clustering, mode = mode, @@ -127,7 +168,20 @@ def main(): for v in mst.vertices: isolateClustering['Cluster'][mst.vp.id[v]] = '0' - drawMST(mst, args.output, isolateClustering, True) + # Check selecting clustering type is in CSV + clustering_name = 'Cluster' + if args.display_cluster != None and args.previous_clustering != None: + if args.display_cluster not in isolateClustering.keys(): + sys.stderr.write('Unable to find clustering column ' + args.display_cluster + ' in file ' + + args.previous_clustering + '\n') + sys.exit() + else: + clustering_name = args.display_cluster + else: + clustering_name = 
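# Illustrative sketch of building an MST from the sparse rank fit: the .npz
# file is a scipy COO matrix of retained neighbour distances. The patch itself
# uses graph-tool or cugraph for this step; scipy's csgraph is shown here only
# as a compact CPU stand-in, with a made-up 4-sample matrix.
import numpy as np
from scipy import sparse
from scipy.sparse.csgraph import minimum_spanning_tree
sparse_mat = sparse.coo_matrix((np.array([0.01, 0.05, 0.02, 0.04]),
                                (np.array([0, 0, 1, 2]), np.array([1, 2, 3, 3]))),
                               shape=(4, 4))
print(minimum_spanning_tree(sparse_mat).toarray())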
list(isolateClustering.keys())[0] + + # Draw MST + drawMST(mst, args.output, isolateClustering, clustering_name, True) sys.exit(0) diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py index 9842fffc..eb880604 100644 --- a/PopPUNK/utils.py +++ b/PopPUNK/utils.py @@ -123,7 +123,7 @@ def storePickle(rlist, qlist, self, X, pklName): np.save(pklName + ".npy", X) -def readPickle(pklName, enforce_self = False): +def readPickle(pklName, enforce_self=False, distances=True): """Loads core and accessory distances saved by :func:`~storePickle` Called during ``--fit-model`` @@ -134,6 +134,10 @@ def readPickle(pklName, enforce_self=False): enforce_self (bool) Error if self == False + [default = False] + distances (bool) + Read the distance matrix + [default = True] Returns: @@ -151,7 +155,10 @@ def readPickle(pklName, enforce_self = False): if enforce_self and not self: sys.stderr.write("Old distances " + pklName + ".npy not complete\n") sys.stderr.exit(1) - X = np.load(pklName + ".npy") + if distances: + X = np.load(pklName + ".npy") + else: + X = None return rlist, qlist, self, X @@ -219,9 +226,8 @@ def listDistInts(refSeqs, querySeqs, self=True): return comparisons -def qcDistMat(distMat, refList, queryList, a_max): - """Checks distance matrix for outliers. At the moment - just a threshold for accessory distance +def qcDistMat(distMat, refList, queryList, ref_db, prefix, qc_dict): + """Checks distance matrix for outliers. Args: distMat (np.array) @@ -230,25 +236,88 @@ def qcDistMat(distMat, refList, queryList, a_max): Reference labels queryList (list) Query labels (or refList if self) - a_max (float) - Maximum accessory distance to allow + ref_db (str) + Prefix of reference database + prefix (str) + Prefix of output files + qc_dict (dict) + Dict of QC options Returns: - passed (bool) - False if any samples failed + seq_names_passing (list) + List of isolates passing QC distance filters + distMat ([n,2] numpy ndarray) + Filtered long form distance matrix """ - passed = True + + # avoid circular import + from .prune_db import prune_distance_matrix + from .sketchlib import removeFromDB + from .sketchlib import pickTypeIsolate + + # Create overall list of sequences + if refList == queryList: + seq_names_passing = refList + else: + seq_names_passing = refList + queryList + + # Sequences to remove + to_prune = [] + + # Create output directory if it does not exist already + if not os.path.isdir(prefix): + try: + os.makedirs(prefix) + except OSError: + sys.stderr.write("Cannot create output directory " + prefix + "\n") + sys.exit(1) + + # Pick type isolate if not supplied + if qc_dict['type_isolate'] is None: + qc_dict['type_isolate'] = pickTypeIsolate(ref_db, seq_names_passing) + sys.stderr.write('Selected type isolate for distance QC is ' + qc_dict['type_isolate'] + '\n') # First check with numpy, which is quicker than iterating over everything - if np.any(distMat[:,1] > a_max): - passed = False - names = iterDistRows(refList, queryList, refList == queryList) - for i, (ref, query) in enumerate(names): - if distMat[i,1] > a_max: - sys.stderr.write("WARNING: Accessory outlier at a=" + str(distMat[i,1]) + - " 1:" + ref + " 2:" + query + "\n") + long_distance_rows = np.where([(distMat[:, 0] > qc_dict['max_pi_dist']) | (distMat[:, 1] > qc_dict['max_a_dist'])])[1].tolist() + if len(long_distance_rows) > 0: + names = list(iterDistRows(refList, queryList, refList == queryList)) + # Prune sequences based on reference sequence + for i in long_distance_rows: + if names[i][0] == qc_dict['type_isolate']: + 
to_prune.append(names[i][1]) + elif names[i][1] == qc_dict['type_isolate']: + to_prune.append(names[i][0]) + + # prune based on distance from reference if provided + if qc_dict['qc_filter'] == 'stop' and len(to_prune) > 0: + sys.stderr.write('Outlier distances exceed QC thresholds; prune sequences or raise thresholds\n') + sys.stderr.write('Problem distances involved sequences ' + ';'.join(to_prune) + '\n') + sys.exit(1) + elif qc_dict['qc_filter'] == 'prune' and len(to_prune) > 0: + if qc_dict['type_isolate'] is None: + sys.stderr.write('Distances exceeded QC thresholds but no reference isolate supplied\n') + sys.stderr.write('Problem distances involved sequences ' + ';'.join(to_prune) + '\n') + sys.exit(1) + else: + # Remove sketches + db_name = ref_db + '/' + os.path.basename(ref_db) + '.h5' + filtered_db_name = prefix + '/' + 'filtered.' + os.path.basename(prefix) + '.h5' + removeFromDB(db_name, + filtered_db_name, + to_prune, + full_names = True) + os.rename(filtered_db_name, db_name) + # Remove from distance matrix + seq_names_passing, distMat = prune_distance_matrix(seq_names_passing, + to_prune, + distMat, + prefix + "/" + os.path.basename(prefix) + ".dists") + # Remove from reflist + sys.stderr.write('Pruned from the database after failing distance QC: ' + ';'.join(to_prune) + '\n') + else: + storePickle(seq_names_passing, seq_names_passing, True, distMat, prefix + "/" + os.path.basename(prefix) + ".dists") - return passed + return seq_names_passing, distMat def readIsolateTypeFromCsv(clustCSV, mode = 'clusters', return_dict = False): @@ -296,7 +365,7 @@ def readIsolateTypeFromCsv(clustCSV, mode = 'clusters', return_dict = False): cluster_name = clustersCsv.columns[cls_idx] cluster_name = cluster_name.replace('__autocolour','') if return_dict: - clusters[cluster_name][row.Index] = str(row[cls_idx + 1]) + clusters[cluster_name][str(row.Index)] = str(row[cls_idx + 1]) else: if cluster_name not in clusters.keys(): clusters[cluster_name] = defaultdict(set) @@ -409,6 +478,9 @@ def readRfile(rFile, oneSeq=False): "Must contain sample name and file, tab separated\n") sys.exit(1) + if "/" in rFields[0]: + sys.stderr.write("Sample names may not contain slashes\n") + sys.exit(1) names.append(rFields[0]) sample_files = [] for sequence in rFields[1:]: @@ -430,6 +502,14 @@ def readRfile(rFile, oneSeq=False): sys.stderr.write("Non-unique names are " + ",".join(dupes) + "\n") sys.exit(1) + # Names are sorted on return + # We have had issues (though they should be fixed) with unordered input + # not matching the database. 
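# Illustrative sketch of the distance QC introduced above: any row of the
# long-form [core, accessory] matrix exceeding --max-pi-dist or --max-a-dist is
# flagged, and only the non-type-isolate member of a flagged pair is pruned.
# Made-up distances and thresholds; not part of the patch.
import numpy as np
distMat = np.array([[0.01, 0.10],
                    [0.60, 0.20],   # core distance too large
                    [0.02, 0.75]])  # accessory distance too large
max_pi_dist, max_a_dist = 0.5, 0.5
long_distance_rows = np.where((distMat[:, 0] > max_pi_dist) |
                              (distMat[:, 1] > max_a_dist))[0].tolist()
print(long_distance_rows)  # [1, 2]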
This should help simplify things + list_iterable = zip(names, sequences) + sorted_names = sorted(list_iterable) + tuples = zip(*sorted_names) + names, sequences = [list(r_tuple) for r_tuple in tuples] + return (names, sequences) def isolateNameToLabel(names): @@ -509,4 +589,4 @@ def decisionBoundary(intercept, gradient): """ x = intercept[0] + intercept[1] * gradient y = intercept[1] + intercept[0] / gradient - return(x, y) \ No newline at end of file + return(x, y) diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py index 8114770f..3369177f 100644 --- a/PopPUNK/visualise.py +++ b/PopPUNK/visualise.py @@ -53,18 +53,21 @@ def get_options(): 'to clusters [default = reference database directory]', type = str) iGroup.add_argument('--previous-clustering', - help='Directory containing previous cluster definitions ' + help='File containing previous cluster definitions ' 'and network [default = use that in the directory ' 'containing the model]', type = str) iGroup.add_argument('--previous-query-clustering', - help='Directory containing previous cluster definitions ' + help='File containing previous cluster definitions ' 'from poppunk_assign [default = use that in the directory ' - 'containing the model]', + 'of the query database]', type = str) - iGroup.add_argument('--use-network', - help='Specify a directory containing a .gt file to use for any graph visualisations', + iGroup.add_argument('--network-file', + help='Specify a file to use for any graph visualisations', type = str) + iGroup.add_argument('--display-cluster', + help='Column of clustering CSV to use for plotting', + default=None) # output options oGroup = parser.add_argument_group('Output options') @@ -92,7 +95,7 @@ def get_options(): faGroup.add_argument('--phandango', help='Generate phylogeny and TSV for Phandango visualisation', default=False, action='store_true') faGroup.add_argument('--grapetree', help='Generate phylogeny and CSV for grapetree visualisation', default=False, action='store_true') faGroup.add_argument('--tree', help='Type of tree to calculate [default = nj]', type=str, default='nj', - choices=['nj', 'mst', 'both']) + choices=['nj', 'mst', 'both', 'none']) faGroup.add_argument('--mst-distances', help='Distances used to calculate a minimum spanning tree [default = core]', type=str, default='core', choices=accepted_weights_types) faGroup.add_argument('--rapidnj', help='Path to rapidNJ binary to build NJ tree for Microreact', default='rapidnj') @@ -106,6 +109,7 @@ def get_options(): other = parser.add_argument_group('Other options') other.add_argument('--threads', default=1, type=int, help='Number of threads to use [default = 1]') other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]') + other.add_argument('--gpu-graph', default=False, action='store_true', help='Use a GPU when calculating graphs [default = False]') other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]') other.add_argument('--strand-preserved', default=False, action='store_true', help='If distances being calculated, treat strand as known when calculating random ' @@ -146,20 +150,24 @@ def generate_visualisations(query_db, model_dir, previous_clustering, previous_query_clustering, - use_network, + network_file, + gpu_graph, info_csv, rapidnj, tree, mst_distances, overwrite, core_only, - accessory_only): + accessory_only, + display_cluster, + web): from .models import loadClusterFit from .network import constructNetwork from 
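# Illustrative sketch of the sort readRfile now applies: names and their file
# lists are zipped, sorted by name, and unzipped again so the two lists stay
# aligned. Toy inputs; not part of the patch.
names = ['s3', 's1', 's2']
sequences = [['s3.fa'], ['s1.fa'], ['s2.fa']]
names, sequences = [list(r_tuple) for r_tuple in zip(*sorted(zip(names, sequences)))]
print(names, sequences)  # ['s1', 's2', 's3'] [['s1.fa'], ['s2.fa'], ['s3.fa']]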
.network import fetchNetwork from .network import generate_minimum_spanning_tree + from .network import load_network_file from .plot import drawMST from .plot import outputsForMicroreact @@ -299,43 +307,37 @@ def generate_visualisations(query_db, sys.exit(1) # Load previous clusters - mode = "clusters" - suffix = "_clusters.csv" - if model.type == "lineage": - mode = "lineages" - suffix = "_lineages.csv" - if model.indiv_fitted: - sys.stderr.write("Note: Individual (core/accessory) fits found, but " - "visualisation only supports combined boundary fit\n") - - # Set directories of previous fit if previous_clustering is not None: prev_clustering = previous_clustering + mode = "clusters" + suffix = "_clusters.csv" + if prev_clustering.endswith('_lineages.csv'): + mode = "lineages" + suffix = "_lineages.csv" else: - prev_clustering = os.path.dirname(model_file) - cluster_file = prev_clustering + '/' + os.path.basename(prev_clustering) + suffix - isolateClustering = readIsolateTypeFromCsv(cluster_file, + # Identify type of clustering based on model + mode = "clusters" + suffix = "_clusters.csv" + if model.type == "lineage": + mode = "lineages" + suffix = "_lineages.csv" + if model.indiv_fitted: + sys.stderr.write("Note: Individual (core/accessory) fits found, but " + "visualisation only supports combined boundary fit\n") + prev_clustering = os.path.basename(model_file) + '/' + os.path.basename(model_file) + suffix + isolateClustering = readIsolateTypeFromCsv(prev_clustering, mode = mode, return_dict = True) - # Set graph location - if use_network is not None: - graph_dir = use_network - if graph_dir != prev_clustering: - sys.stderr.write("WARNING: Loading graph from a different directory to clusters\n") - sys.stderr.write("WARNING: Ensure that they are consistent\n") - else: - graph_dir = prev_clustering - # Join clusters with query clusters if required if not self: if previous_query_clustering is not None: - prev_query_clustering = previous_query_clustering + '/' + os.path.basename(previous_query_clustering) + prev_query_clustering = previous_query_clustering else: - prev_query_clustering = query_db_loc + prev_query_clustering = os.path.basename(query_db) + '/' + os.path.basename(query_db) + suffix queryIsolateClustering = readIsolateTypeFromCsv( - prev_query_clustering + suffix, + prev_query_clustering, mode = mode, return_dict = True) isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering) @@ -351,6 +353,18 @@ def generate_visualisations(query_db, if not overwrite: existing_tree = load_tree(output, "MST", distances=mst_distances) if existing_tree is None: + # Check selecting clustering type is in CSV + clustering_name = 'Cluster' + if display_cluster != None: + if display_cluster not in isolateClustering.keys(): + clustering_name = list(isolateClustering.keys())[0] + sys.stderr.write('Unable to find clustering column ' + display_cluster + ' in file ' + + prev_clustering + '; instead using ' + clustering_name + '\n') + else: + clustering_name = display_cluster + else: + clustering_name = list(isolateClustering.keys())[0] + # Get distance matrix complete_distMat = \ np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1), pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1))) @@ -364,7 +378,7 @@ def generate_visualisations(query_db, weights_type=mst_distances, summarise=False) mst_graph = generate_minimum_spanning_tree(G) - drawMST(mst_graph, output, isolateClustering, overwrite) + drawMST(mst_graph, output, isolateClustering, 
clustering_name, overwrite) mst_tree = mst_to_phylogeny(mst_graph, isolateNameToLabel(combined_seq)) else: mst_tree = existing_tree @@ -423,7 +437,7 @@ def generate_visualisations(query_db, if cytoscape: sys.stderr.write("Writing cytoscape output\n") - genomeNetwork, cluster_file = fetchNetwork(graph_dir, model, rlist, False, core_only, accessory_only) + genomeNetwork = load_network_file(network_file, use_gpu = gpu_graph) outputsForCytoscape(genomeNetwork, mst_graph, isolateClustering, output, info_csv, viz_subset = viz_subset) if model.type == 'lineage': sys.stderr.write("Note: Only support for output of cytoscape graph at lowest rank\n") @@ -453,14 +467,17 @@ def main(): args.model_dir, args.previous_clustering, args.previous_query_clustering, - args.use_network, + args.network_file, + args.gpu_graph, args.info_csv, args.rapidnj, args.tree, args.mst_distances, args.overwrite, args.core_only, - args.accessory_only) + args.accessory_only, + args.display_cluster, + web = False) if __name__ == '__main__': main() diff --git a/PopPUNK/web.py b/PopPUNK/web.py index c1f6060f..6b013f61 100644 --- a/PopPUNK/web.py +++ b/PopPUNK/web.py @@ -66,6 +66,7 @@ def sketchAssign(): args.assign.ref_db, args.assign.q_files, outdir, + qc_dict, args.assign.update_db, args.assign.write_references, args.assign.distances, @@ -74,12 +75,18 @@ def sketchAssign(): args.assign.plot_fit, args.assign.graph_weights, args.assign.max_a_dist, + args.assign.max_pi_dist, + args.assign.type_isolate, args.assign.model_dir, args.assign.strand_preserved, args.assign.previous_clustering, args.assign.external_clustering, args.assign.core_only, args.assign.accessory_only, + args.assign.gpu_sketch, + args.assign.gpu_dist, + args.assign.gpu_graph, + args.assign.deviceid, args.assign.web, sketch_dict["sketch"], args.assign.save_partial_query_graph) @@ -107,16 +114,18 @@ def sketchAssign(): args.visualise.strand_preserved, outdir + "/include.txt", species_db, - species_db, + species_db + "/" + os.path.basename(species_db) + "_clusters.csv", args.visualise.previous_query_clustering, - outdir, + outdir + "/" + os.path.basename(outdir) + "_graph.gt", args.visualise.gpu_graph, args.visualise.info_csv, args.visualise.rapidnj, args.visualise.tree, args.visualise.mst_distances, args.visualise.overwrite, args.visualise.core_only, - args.visualise.accessory_only) + args.visualise.accessory_only, + args.visualise.display_cluster, + web=True) networkJson = graphml_to_json(outdir) if len(to_include) >= 3: with open(os.path.join(outdir, os.path.basename(outdir) + "_core_NJ.nwk"), "r") as p: @@ -323,4 +332,4 @@ def main(): scheduler.init_app(app) scheduler.start() atexit.register(lambda: scheduler.shutdown()) - app.run(debug=False,use_reloader=False) \ No newline at end of file + app.run(debug=False,use_reloader=False) diff --git a/docs/visualisation.rst b/docs/visualisation.rst index 666daa7a..43d655e8 100644 --- a/docs/visualisation.rst +++ b/docs/visualisation.rst @@ -44,7 +44,7 @@ Visualisation after query assignment:: Visualisation when sketches and models are in different folders:: - poppunk_visualise --ref-db example_db --previous-clustering example_lineages \ + poppunk_visualise --ref-db example_db --previous-clustering example_lineages/example_lineages_lineages.csv \ --model-dir example_lineages --output example_viz --microreact Visualisation with a lineage model, which has been queried (query-query distances must be provided):: diff --git a/environment.yml b/environment.yml index 0c525fe0..76b01588 100644 --- a/environment.yml +++ 
b/environment.yml @@ -17,7 +17,7 @@ dependencies: - hdbscan - rapidnj - h5py - - pp-sketchlib >=1.6.2 + - pp-sketchlib >=1.7.0 - graph-tool >=2.35 - requests - flask diff --git a/scripts/poppunk_batch_mst.py b/scripts/poppunk_batch_mst.py new file mode 100755 index 00000000..50fa152e --- /dev/null +++ b/scripts/poppunk_batch_mst.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python +# vim: set fileencoding= : +# Copyright 2018-2021 John Lees and Nick Croucher + +# universal +import os +import sys +import argparse +import subprocess +import shutil +import glob +import tempfile +from collections import defaultdict +import pandas as pd + +rfile_names = "rlist.txt" + +# command line parsing +def get_options(): + + parser = argparse.ArgumentParser(description='Batch MST mode (create db + lineage model fit + assign + sparse_mst)', + prog='poppunk_batch_mst') + + # input options + ioGroup = parser.add_argument_group('Input and output file options') + ioGroup.add_argument('--r-files', help='Sample names and locations (as for poppunk --r-files)', + required=True) + ioGroup.add_argument('--batch-file', help='Single column list of batches to process samples in --r-files in') + ioGroup.add_argument('--n-batches', help='Number of batches for process if --batch-file is not specified', + type=int, + default=10) + ioGroup.add_argument('--info-csv', help='CSV containing information about sequences', default=None) + ioGroup.add_argument('--output', help='Prefix for output files', + required=True) + ioGroup.add_argument('--previous-clustering', help='CSV file with previous clusters in MST drawing', + default=None) + ioGroup.add_argument('--iterative-mst', help='Re-calculate the MST for each batch', + default=False, + action='store_true') + ioGroup.add_argument('--keep-intermediates', help='Retain the outputs of each batch', + default=False, + action='store_true') + ioGroup.add_argument('--use-batch-names', help='Name the stored outputs of each batch', + default=False, + action='store_true') + # analysis options + aGroup = parser.add_argument_group('Analysis options') + aGroup.add_argument('--rank', help='Comma separated ranks used to fit lineage model (list of ints)', + type = str, + default = "10") + aGroup.add_argument('--threads', help='Number of threads for parallelisation (int)', + type = int, + default = 1) + aGroup.add_argument('--gpu-dist', help='Use GPU for distance calculations', + default=False, + action='store_true') + aGroup.add_argument('--gpu-graph', help='Use GPU for network analysis', + default=False, + action='store_true') + aGroup.add_argument('--deviceid', help='GPU device ID (int)', + type = int, + default = 0) + aGroup.add_argument('--db-args', help="Other arguments to pass to poppunk. e.g. 
" + "'--min-k 13 --max-k 29'", + default = "") + aGroup.add_argument('--model-args', help="Other arguments to pass to lineage model fit", + default = "") + aGroup.add_argument('--assign-args', help="Other arguments to pass to poppunk_assign", + default = "") + + # QC options + qcGroup = parser.add_argument_group('Quality control options for distances') + qcGroup.add_argument('--qc-filter', help='Behaviour following sequence QC step: "stop" [default], "prune"' + ' (analyse data passing QC), or "continue" (analyse all data)', + default='stop', type = str, choices=['stop', 'prune', 'continue']) + qcGroup.add_argument('--retain-failures', help='Retain sketches of genomes that do not pass QC filters in ' + 'separate database [default = False]', default=False, action='store_true') + qcGroup.add_argument('--max-a-dist', help='Maximum accessory distance to permit [default = 0.5]', + default = 0.5, type = float) + qcGroup.add_argument('--length-sigma', help='Number of standard deviations of length distribution beyond ' + 'which sequences will be excluded [default = 5]', default = None, type = int) + qcGroup.add_argument('--length-range', help='Allowed length range, outside of which sequences will be excluded ' + '[two values needed - lower and upper bounds]', default=[None,None], + type = int, nargs = 2) + qcGroup.add_argument('--prop-n', help='Threshold ambiguous base proportion above which sequences will be excluded' + ' [default = 0.1]', default = None, + type = float) + qcGroup.add_argument('--upper-n', help='Threshold ambiguous base count above which sequences will be excluded', + default=None, type = int) + + # Executable options + eGroup = parser.add_argument_group('Executable locations') + eGroup.add_argument('--poppunk-exe', help="Location of poppunk executable. Use " + "'python poppunk-runner.py' to run from source tree", + default="poppunk") + eGroup.add_argument('--assign-exe', help="Location of poppunk_assign executable. Use " + "'python poppunk_assign-runner.py' to run from source tree", + default="poppunk_assign") + eGroup.add_argument('--mst-exe', help="Location of poppunk executable. 
Use " + "'python poppunk_mst-runner.py' to run from source tree", + default="poppunk_mst") + + return parser.parse_args() + +def writeBatch(rlines, batches, batch_selected, use_names = False): + tmpdir = "" + if use_names: + tmpdir = "./pp_mst_" + str(batch_selected) + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir) + os.mkdir(tmpdir) + else: + tmpdir = tempfile.mkdtemp(prefix="pp_mst", dir="./") + with open(tmpdir + "/" + rfile_names, 'w') as outfile: + for rline, batch in zip(rlines, batches): + if batch == batch_selected: + outfile.write(rline) + + return tmpdir + +def runCmd(cmd_string): + sys.stderr.write("Running command:\n") + sys.stderr.write(cmd_string + '\n') + subprocess.run(cmd_string, shell=True, check=True) + +def readLineages(clustCSV): + clusters = defaultdict(dict) + # read CSV + clustersCsv = pd.read_csv(clustCSV, index_col = 0, quotechar='"') + # select relevant columns + type_columns = [n for n,col in enumerate(clustersCsv.columns) if ('Rank_' in col or 'overall' in col)] + # read file + for row in clustersCsv.itertuples(): + for cls_idx in type_columns: + cluster_name = clustersCsv.columns[cls_idx] + cluster_name = cluster_name.replace('__autocolour','') + clusters[cluster_name][row.Index] = str(row[cls_idx + 1]) + # return data structure + return clusters + +def isolateNameToLabel(names): + labels = [name.split('/')[-1].split('.')[0] for name in names] + return labels + +def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, + epiCsv = None, suffix = '_Lineage'): + # set order of column names + colnames = ['ID'] + for cluster_type in clustering: + col_name = cluster_type + suffix + colnames.append(col_name) + # process epidemiological data + d = defaultdict(list) + # process epidemiological data without duplicating names + # used by PopPUNK + columns_to_be_omitted = ['id', 'Id', 'ID', 'combined_Cluster__autocolour', + 'core_Cluster__autocolour', 'accessory_Cluster__autocolour', + 'overall_Lineage'] + if epiCsv is not None: + epiData = pd.read_csv(epiCsv, index_col = False, quotechar='"') + epiData.index = isolateNameToLabel(epiData.iloc[:,0]) + for e in epiData.columns.values: + if e not in columns_to_be_omitted: + colnames.append(str(e)) + # get example clustering name for validation + example_cluster_title = list(clustering.keys())[0] + for name, label in zip(nodeNames, isolateNameToLabel(nodeLabels)): + if name in clustering[example_cluster_title]: + d['ID'].append(label) + for cluster_type in clustering: + col_name = cluster_type + suffix + d[col_name].append(clustering[cluster_type][name]) + if epiCsv is not None: + if label in epiData.index: + for col, value in zip(epiData.columns.values, epiData.loc[label].values): + if col not in columns_to_be_omitted: + d[col].append(str(value)) + else: + for col in epiData.columns.values: + if col not in columns_to_be_omitted: + d[col].append('nan') + else: + sys.stderr.write("Cannot find " + name + " in clustering\n") + sys.exit(1) + # print CSV + sys.stderr.write("Parsed data, now writing to CSV\n") + try: + pd.DataFrame(data=d).to_csv(outfile, columns = colnames, index = False) + except subprocess.CalledProcessError as e: + sys.stderr.write("Problem with epidemiological data CSV; returned code: " + str(e.returncode) + "\n") + # check CSV + prev_col_items = -1 + prev_col_name = "unknown" + for col in d: + this_col_items = len(d[col]) + if prev_col_items > -1 and prev_col_items != this_col_items: + sys.stderr.write("Discrepant length between " + prev_col_name + \ + " (length of " + prev_col_items + ") and " + \ 
+ col + "(length of " + this_col_items + ")\n") + sys.exit(1) + +# main code +if __name__ == "__main__": + + ########### + # Prepare # + ########### + + # Check input ok + args = get_options() + if args.previous_clustering is not None and \ + not os.path.isfile(args.previous_clustering): + sys.stderr.write("Provided --previous-clustering file cannot be found\n") + sys.exit(1) + + # Extract ranks + ranks = [int(rank) for rank in args.rank.split(',')] + max_rank = max(ranks) + + # Check input file + rlines = [] + nodeNames = [] + nodeLabels = [] + with open(args.r_files,'r') as r_file: + for r_line in r_file: + rlines.append(r_line) + node_info = r_line.rstrip().split() + nodeNames.append(node_info[0]) + nodeLabels.append(node_info[1]) + + # Check batching + batches = [] + if args.batch_file: + # Read specified batches + with open(args.batch_file,'r') as batch_file: + batches = [batch_line.rstrip() for batch_line in batch_file.readlines()] + else: + # Generate arbitrary batches + x = 0 + n = 1 + while x < len(rlines): + if n > args.n_batches: + n = 1 + batches.append(n) + n = n + 1 + x = x + 1 + # Validate batches + batch_names = sorted(set(batches)) + if len(batch_names) < 2: + sys.stderr.write("You must supply multiple batches\n") + sys.exit(1) + first_batch = batch_names.pop(0) + + # try/except block to clean up tmp files + wd = writeBatch(rlines, batches, first_batch, args.use_batch_names) + tmp_dirs = [wd] + try: + + ############### + # First batch # + ############### + + # First batch is create DB + lineage + create_db_cmd = args.poppunk_exe + " --create-db --r-files " + \ + wd + "/" + rfile_names + \ + " --output " + wd + " " + \ + args.db_args + " --threads " + \ + str(args.threads) + " " + \ + args.db_args + # QC options + if None not in args.length_range: + create_db_cmd += " --length-range " + str(args.length_range[0]) + " " + str(args.length_range[1]) + elif args.length_sigma is not None: + create_db_cmd += " --length-sigma " + str(args.length_sigma) + if args.upper_n is not None: + create_db_cmd += " --upper-n " + str(args.upper_n) + elif args.prop_n is not None: + create_db_cmd += " --prop-n " + str(args.prop_n) + create_db_cmd += " --qc-filter " + args.qc_filter + # GPU options + if args.gpu_dist: + create_db_cmd += " --gpu-dist --deviceid " + str(args.deviceid) + runCmd(create_db_cmd) + + # Fit lineage model + fit_model_cmd = args.poppunk_exe + " --fit-model lineage --ref-db " + \ + wd + " --rank " + \ + args.rank + " --threads " + \ + str(args.threads) + " " + \ + args.model_args + runCmd(fit_model_cmd) + + # Calculate MST if operating iteratively + if args.iterative_mst: + + mst_command = args.mst_exe + " --distance-pkl " + wd + \ + "/" + os.path.basename(wd) + ".dists.pkl --rank-fit " + \ + wd + "/" + os.path.basename(wd) + "_rank" + \ + str(max_rank) + "_fit.npz " + \ + " --output " + wd + \ + " --threads " + str(args.threads) + \ + " --previous-clustering " + wd + \ + "/" + os.path.basename(wd) + "_lineages.csv" + # GPU options + if args.gpu_graph: + mst_command = mst_command + " --gpu-graph" + runCmd(mst_command) + + ########### + # Iterate # + ########### + + for batch_idx, batch in enumerate(batch_names): + prev_wd = tmp_dirs[-1] + batch_wd = writeBatch(rlines, batches, batch, args.use_batch_names) + tmp_dirs.append(batch_wd) + + assign_cmd = args.assign_exe + " --db " + prev_wd + \ + " --query " + batch_wd + "/" + rfile_names + \ + " --model-dir " + prev_wd + " --output " + batch_wd + \ + " --threads " + str(args.threads) + " --update-db " + \ + args.assign_args + # 
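# Illustrative sketch of the batch assignment above: without --batch-file,
# samples are dealt out cyclically into --n-batches groups, which is the same
# as a modulo over the sample index. Toy sample list; not part of the patch.
n_batches = 3
rlines = ['s1\tf1.fa\n', 's2\tf2.fa\n', 's3\tf3.fa\n', 's4\tf4.fa\n', 's5\tf5.fa\n']
batches = [(i % n_batches) + 1 for i in range(len(rlines))]
print(batches)  # [1, 2, 3, 1, 2]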
+ if None not in args.length_range:
+ assign_cmd += " --length-range " + str(args.length_range[0]) + " " + str(args.length_range[1])
+ elif args.length_sigma is not None:
+ assign_cmd += " --length-sigma " + str(args.length_sigma)
+ else:
+ assign_cmd += " --length-sigma 5" # default from __main__
+ if args.upper_n is not None:
+ assign_cmd += " --upper-n " + str(args.upper_n)
+ elif args.prop_n is not None:
+ assign_cmd += " --prop-n " + str(args.prop_n)
+ else:
+ assign_cmd += " --prop-n 0.1" # default from __main__
+ assign_cmd += " --qc-filter " + args.qc_filter
+ # GPU options
+ if args.gpu_dist:
+ assign_cmd = assign_cmd + " --gpu-dist --deviceid " + str(args.deviceid)
+ runCmd(assign_cmd)
+
+ # Calculate MST if operating iteratively
+ if args.iterative_mst:
+
+ mst_command = args.mst_exe + " --distance-pkl " + batch_wd + \
+ "/" + os.path.basename(batch_wd) + ".dists.pkl --rank-fit " + \
+ batch_wd + "/" + os.path.basename(batch_wd) + "_rank" + \
+ str(max_rank) + "_fit.npz " + \
+ " --output " + batch_wd + \
+ " --threads " + str(args.threads) + \
+ " --previous-mst " + \
+ prev_wd + "/" + os.path.basename(prev_wd) + ".graphml" + \
+ " --previous-clustering " + batch_wd + \
+ "/" + os.path.basename(batch_wd) + "_lineages.csv"
+ if args.gpu_graph:
+ mst_command = mst_command + " --gpu-graph"
+ runCmd(mst_command)
+
+ # Remove the previous batch
+ if batch_idx > 0 and args.keep_intermediates == False:
+ shutil.rmtree(tmp_dirs[batch_idx - 1])
+
+ ##########
+ # Finish #
+ ##########
+
+ # Calculate MST
+ output_dir = tmp_dirs[-1]
+ if args.iterative_mst:
+ # Create directory
+ if os.path.exists(args.output):
+ if os.path.isdir(args.output):
+ shutil.rmtree(args.output)
+ else:
+ os.remove(args.output)
+ os.mkdir(args.output)
+ # Copy over final MST
+ shutil.copy(os.path.join(output_dir,os.path.basename(output_dir) + ".graphml"),
+ os.path.join(args.output,os.path.basename(args.output) + ".graphml"))
+ shutil.copy(os.path.join(output_dir,os.path.basename(output_dir) + "_MST.nwk"),
+ os.path.join(args.output,os.path.basename(args.output) + "_MST.nwk"))
+ else:
+ # Calculate MST
+ mst_command = args.mst_exe + " --distance-pkl " + output_dir + \
+ "/" + os.path.basename(output_dir) + ".dists.pkl --rank-fit " + \
+ output_dir + "/" + os.path.basename(output_dir) + "_rank" + \
+ str(max_rank) + "_fit.npz " + \
+ " --output " + args.output + \
+ " --threads " + str(args.threads)
+ if args.previous_clustering is not None:
+ mst_command = mst_command + " --previous-clustering " + args.previous_clustering
+ else:
+ mst_command = mst_command + " --previous-clustering " + \
+ os.path.join(output_dir,os.path.basename(output_dir) + "_lineages.csv")
+ if args.gpu_graph:
+ mst_command = mst_command + " --gpu-graph"
+ runCmd(mst_command)
+
+ # Move distances, lineages and rank fits from the final batch into the output directory
+ os.rename(os.path.join(output_dir,os.path.basename(output_dir) + ".dists.pkl"),
+ os.path.join(args.output,os.path.basename(args.output) + ".dists.pkl"))
+ os.rename(os.path.join(output_dir,os.path.basename(output_dir) + "_lineages.csv"),
+ os.path.join(args.output,os.path.basename(args.output) + "_lineages.csv"))
+ for rank in ranks:
+ os.rename(os.path.join(output_dir, os.path.basename(output_dir) + "_rank" + str(rank) + "_fit.npz"),
+ os.path.join(args.output, os.path.basename(args.output) + "_rank" + str(rank) + "_fit.npz"))
+
+ # Merge with epidemiological data if requested
+ if args.info_csv is not None:
+ lineage_clustering = readLineages(os.path.join(args.output,
+
os.path.basename(args.output) + "_lineages.csv")) + writeClusterCsv(os.path.join(args.output, + os.path.basename(args.output) + "_info.csv"), + nodeNames, + nodeLabels, + lineage_clustering, + epiCsv = args.info_csv) + + except: + if args.keep_intermediates == False: + for tmpdir in tmp_dirs: + try: + shutil.rmtree(tmpdir) + except: + sys.stderr.write("Unable to remove " + tmpdir + "\n") + print("Unexpected error:", sys.exc_info()[0]) + raise + + if args.keep_intermediates == False: + shutil.rmtree(output_dir) diff --git a/scripts/poppunk_easy_run.py b/scripts/poppunk_easy_run.py index ccefca10..28d53e87 100755 --- a/scripts/poppunk_easy_run.py +++ b/scripts/poppunk_easy_run.py @@ -13,7 +13,7 @@ def get_options(): prog='easy_run') # input options - parser.add_argument('--r-files', help='List of sequence names and files (as for --r-files') + parser.add_argument('--r-files', help='List of sequence names and files (as for --r-files)') parser.add_argument('--output', help='Prefix for output files') parser.add_argument('--analysis-args', help="Other arguments to pass to poppunk. e.g. " diff --git a/scripts/poppunk_extract_distances.py b/scripts/poppunk_extract_distances.py index 6552fd03..eb4805f1 100755 --- a/scripts/poppunk_extract_distances.py +++ b/scripts/poppunk_extract_distances.py @@ -7,6 +7,7 @@ import numpy as np import argparse import dendropy +from scipy import sparse # command line parsing def get_options(): @@ -14,15 +15,25 @@ def get_options(): parser = argparse.ArgumentParser(description='Extract tab-separated file of distances from pkl and npy files', prog='extract_distances') # input options - parser.add_argument('--distances', required=True, help='Prefix of input pickle and numpy file of pre-calculated distances (required)') - parser.add_argument('--tree', required=False, help='Newick file containing phylogeny of isolates', default = None) - parser.add_argument('--output', required=True, help='Name of output file') + parser.add_argument('--distances', help='Prefix of input pickle (and optionally,' + ' numpy file) of pre-calculated distances (required)', + required=True) + parser.add_argument('--sparse', help='Sparse distance matrix file name', + default = None, + required = False) + parser.add_argument('--tree', help='Newick file containing phylogeny of isolates', + required = False, + default = None) + parser.add_argument('--output', help='Name of output file', + required = True) return parser.parse_args() -def iterDistRows(refSeqs, querySeqs, self=True): +def listDistInts(refSeqs, querySeqs, self=True): """Gets the ref and query ID for each row of the distance matrix + Returns an iterable with ref and query ID pairs by row. + Args: refSeqs (list) List of reference sequence names. @@ -36,15 +47,19 @@ def iterDistRows(refSeqs, querySeqs, self=True): ref, query (str, str) Iterable of tuples with ref and query names for each distMat row. 
""" + num_ref = len(refSeqs) + num_query = len(querySeqs) if self: - assert refSeqs == querySeqs - for i, ref in enumerate(refSeqs): - for j in range(i + 1, len(refSeqs)): - yield(refSeqs[j], ref) + if refSeqs != querySeqs: + raise RuntimeError('refSeqs must equal querySeqs for db building (self = true)') + for i in range(num_ref): + for j in range(i + 1, num_ref): + yield(j, i) else: - for query in querySeqs: - for ref in refSeqs: - yield(ref, query) + comparisons = [(0,0)] * (len(refSeqs) * len(querySeqs)) + for i in range(num_query): + for j in range(num_ref): + yield(j, i) def isolateNameToLabel(names): """Function to process isolate names to labels @@ -71,7 +86,6 @@ def isolateNameToLabel(names): # open stored distances with open(args.distances + ".pkl", 'rb') as pickle_file: rlist, qlist, self = pickle.load(pickle_file) - X = np.load(args.distances + ".npy") # get names order r_names = isolateNameToLabel(rlist) @@ -91,14 +105,32 @@ def isolateNameToLabel(names): taxon_name = t.label.replace(' ','_') tip_index[r_names.index(taxon_name)] = t + # Load sparse matrix + if args.sparse is not None: + sparse_mat = sparse.load_npz(args.sparse) + else: + X = np.load(args.distances + ".npy") + # open output file with open(args.output, 'w') as oFile: - oFile.write("\t".join(['Query', 'Reference', 'Core', 'Accessory'])) + # Write header of output file + if args.sparse is not None: + oFile.write("\t".join(['Query', 'Reference', 'Core'])) + else: + oFile.write("\t".join(['Query', 'Reference', 'Core', 'Accessory'])) if args.tree is not None: oFile.write("\t" + 'Patristic') oFile.write("\n") - for i, (r_index, q_index) in enumerate(iterDistRows(r_names, q_names, r_names == q_names)): - oFile.write("\t".join([q_names[q_index], r_names[r_index], str(X[i,0]), str(X[i,1])])) - if args.tree is not None: - oFile.write("\t" + str(pdc(tip_index[r_index], tip_index[q_index]))) - oFile.write("\n") + # Write distances + if args.sparse is not None: + for (r_index, q_index, dist) in zip(sparse_mat.col, sparse_mat.row, sparse_mat.data): + oFile.write("\t".join([q_names[q_index], r_names[r_index], str(dist)])) + if args.tree is not None: + oFile.write("\t" + str(pdc(tip_index[r_index], tip_index[q_index]))) + oFile.write("\n") + else: + for i, (r_index, q_index) in enumerate(listDistInts(r_names, q_names, r_names == q_names)): + oFile.write("\t".join([q_names[q_index], r_names[r_index], str(X[i,0]), str(X[i,1])])) + if args.tree is not None: + oFile.write("\t" + str(pdc(tip_index[r_index], tip_index[q_index]))) + oFile.write("\n") diff --git a/setup.py b/setup.py index 0465044b..fb3cd788 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ def build_extension(self, ext): scripts=['scripts/poppunk_calculate_rand_indices.py', 'scripts/poppunk_extract_components.py', 'scripts/poppunk_calculate_silhouette.py', + 'scripts/poppunk_batch_mst.py', 'scripts/poppunk_extract_distances.py', 'scripts/poppunk_add_weights.py', 'scripts/poppunk_db_info.py', diff --git a/test/clean_test.py b/test/clean_test.py index 3ecc96a1..b1923144 100755 --- a/test/clean_test.py +++ b/test/clean_test.py @@ -40,7 +40,12 @@ def deleteDir(dirname): "example_tsne", "example_prune", "example_refs", - "example_api" + "example_api", + "batch1", + "batch2", + "batch3", + "batch12", + "batch123" ] for outDir in outputDirs: deleteDir(outDir) diff --git a/test/rfile1.txt b/test/rfile1.txt new file mode 100644 index 00000000..4f388da2 --- /dev/null +++ b/test/rfile1.txt @@ -0,0 +1,3 @@ +7 12673_8#24.contigs_velvet.fa +1 12673_8#34.contigs_velvet.fa +2 
12673_8#43.contigs_velvet.fa diff --git a/test/rfile12.txt b/test/rfile12.txt new file mode 100644 index 00000000..e4f63584 --- /dev/null +++ b/test/rfile12.txt @@ -0,0 +1,6 @@ +7 12673_8#24.contigs_velvet.fa +1 12673_8#34.contigs_velvet.fa +2 12673_8#43.contigs_velvet.fa +6 12754_4#79.contigs_velvet.fa +4 12754_4#85.contigs_velvet.fa +5 12754_4#89.contigs_velvet.fa diff --git a/test/rfile123.txt b/test/rfile123.txt new file mode 100644 index 00000000..af5a0ead --- /dev/null +++ b/test/rfile123.txt @@ -0,0 +1,9 @@ +7 12673_8#24.contigs_velvet.fa +1 12673_8#34.contigs_velvet.fa +2 12673_8#43.contigs_velvet.fa +6 12754_4#79.contigs_velvet.fa +4 12754_4#85.contigs_velvet.fa +5 12754_4#89.contigs_velvet.fa +8 12754_5#73.contigs_velvet.fa +3 12754_5#78.contigs_velvet.fa +9 12754_5#71.contigs_velvet.fa diff --git a/test/rfile2.txt b/test/rfile2.txt new file mode 100644 index 00000000..5f6e9a24 --- /dev/null +++ b/test/rfile2.txt @@ -0,0 +1,3 @@ +6 12754_4#79.contigs_velvet.fa +4 12754_4#85.contigs_velvet.fa +5 12754_4#89.contigs_velvet.fa diff --git a/test/rfile3.txt b/test/rfile3.txt new file mode 100644 index 00000000..23104358 --- /dev/null +++ b/test/rfile3.txt @@ -0,0 +1,3 @@ +8 12754_5#73.contigs_velvet.fa +3 12754_5#78.contigs_velvet.fa +9 12754_5#71.contigs_velvet.fa diff --git a/test/run_test.py b/test/run_test.py index e0389ef6..8ab75776 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -25,6 +25,10 @@ sys.stderr.write("Running database QC test (--create-db)\n") subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files references.txt --min-k 13 --k-step 3 --output example_qc --qc-filter continue --length-range 2000000 3000000 --overwrite", shell=True, check=True) +# test updating order is correct +sys.stderr.write("Running distance matrix order check (--update-db)\n") +subprocess.run(python_cmd + " test-update.py", shell=True, check=True) + #fit GMM sys.stderr.write("Running GMM model fit (--fit-model gmm)\n") subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model bgmm --ref-db example_db --K 4 --overwrite", shell=True, check=True) @@ -35,11 +39,11 @@ #refine model with GMM sys.stderr.write("Running model refinement (--fit-model refine)\n") -subprocess.run("python ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite", shell=True, check=True) -subprocess.run("python ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite --indiv-refine both", shell=True, check=True) -subprocess.run("python ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite --score-idx 1", shell=True, check=True) -subprocess.run("python ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite --score-idx 2", shell=True, check=True) -subprocess.run("python ../poppunk-runner.py --fit-model threshold --threshold 0.003 --ref-db example_db --output example_threshold", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite --indiv-refine both", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 
--overwrite --score-idx 1", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite --score-idx 2", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model threshold --threshold 0.003 --ref-db example_db --output example_threshold", shell=True, check=True) # lineage clustering sys.stderr.write("Running lineage clustering test (--fit-model lineage)\n") @@ -51,7 +55,7 @@ # tests of other command line programs sys.stderr.write("Testing C++ extension\n") -subprocess.run("python test-refine.py", shell=True, check=True) +subprocess.run(python_cmd + " test-refine.py", shell=True, check=True) #assign query sys.stderr.write("Running query assignment\n") @@ -63,18 +67,18 @@ # viz sys.stderr.write("Running visualisations (poppunk_visualise)\n") subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --output example_viz --microreact", shell=True, check=True) -subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --output example_viz --cytoscape", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --output example_viz --cytoscape --network-file example_db/example_db_graph.gt", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --output example_viz --phandango", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --output example_viz --grapetree", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --output example_viz_subset --microreact --include-files subset.txt", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --query-db example_query --output example_viz_query --microreact", shell=True, check=True) -subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --previous-clustering example_lineages --model-dir example_lineages --output example_lineage_viz --microreact", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --previous-clustering example_lineages/example_lineages_lineages.csv --model-dir example_lineages --output example_lineage_viz --microreact", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --distances example_query/example_query.dists --ref-db example_db --model-dir example_lineages --query-db example_lineage_query --output example_viz_query_lineages --microreact", shell=True, check=True) # MST sys.stderr.write("Running MST\n") -subprocess.run("python ../poppunk_visualise-runner.py --ref-db example_db --output example_mst --microreact --tree mst", shell=True, check=True) -subprocess.run("python ../poppunk_mst-runner.py --distances example_db/example_db.dists --rank-fit example_lineages/example_lineages_rank5_fit.npz --previous-clustering example_dbscan/example_dbscan_clusters.csv --output example_sparse_mst --no-plot", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db example_db --output example_mst --microreact --tree mst", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_mst-runner.py --distance-pkl example_db/example_db.dists.pkl --rank-fit example_lineages/example_lineages_rank5_fit.npz --previous-clustering 
example_dbscan/example_dbscan_clusters.csv --output example_sparse_mst --no-plot", shell=True, check=True) # t-sne sys.stderr.write("Running tsne viz\n") @@ -95,7 +99,7 @@ # web API sys.stderr.write("Running API tests\n") -subprocess.run(python_cmd + " test_web.py", shell=True, check=True) +subprocess.run(python_cmd + " test-web.py", shell=True, check=True) sys.stderr.write("Tests completed\n") diff --git a/test/test-update.py b/test/test-update.py new file mode 100755 index 00000000..22aa630e --- /dev/null +++ b/test/test-update.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# Copyright 2018-2021 John Lees and Nick Croucher + +"""Tests for PopPUNK --update-db order""" + +import subprocess +import os, sys +import sys +import shutil +import pickle + +import numpy as np +from scipy import stats +import h5py +import scipy.sparse + +import pp_sketchlib + +if os.environ.get("POPPUNK_PYTHON"): + python_cmd = os.environ.get("POPPUNK_PYTHON") +else: + python_cmd = "python" + +def run_regression(x, y, threshold = 0.99): + res = stats.linregress(x, y) + print("R^2: " + str(res.rvalue**2)) + if res.rvalue**2 < threshold: + sys.stderr.write("Distance matrix order failed!\n") + sys.exit(1) + +def compare_sparse_matrices(d1,d2,r1,r2): + d1_pairs = get_seq_tuples(d1.row,d1.col,r1) + d2_pairs = get_seq_tuples(d2.row,d2.col,r2) + d1_dists = [] + d2_dists = [] + + for (pair1,dist1) in zip(d1_pairs,d1.data): + for (pair2,dist2) in zip(d2_pairs,d2.data): + if pair1 == pair2: + d1_dists.append(dist1) + d2_dists.append(dist2) + break + + run_regression(np.asarray(d1_dists),np.asarray(d2_dists)) + +def get_seq_tuples(rows,cols,names): + tuple_list = [] + for (i,j) in zip(rows,cols): + sorted_pair = tuple(sorted((names[i],names[j]))) + tuple_list.append(sorted_pair) + return tuple_list + +def old_get_seq_tuples(rows,cols): + max_seqs = np.maximum(rows,cols) + min_seqs = np.minimum(rows,cols) + concat_seqs = np.vstack((max_seqs,min_seqs)) + seq_pairs = concat_seqs.T + seq_tuples = [tuple(row) for row in seq_pairs] + return seq_tuples + +# Check distances after one query + +# Check that order is the same after doing 1 + 2 with --update-db, as doing all of 1 + 2 together +subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile12.txt --output batch12 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch12 --ranks 1,2", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile1.txt --output batch1 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch1 --ranks 1,2", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch1 --query rfile2.txt --output batch2 --update-db --overwrite", shell=True, check=True) + +# Load updated distances +X2 = np.load("batch2/batch2.dists.npy") +with open("batch2/batch2.dists.pkl", 'rb') as pickle_file: + rlist2, qlist, self = pickle.load(pickle_file) + +# Get same distances from the full database +ref_db = "batch12/batch12" +ref_h5 = h5py.File(ref_db + ".h5", 'r') +db_kmers = sorted(ref_h5['sketches/' + rlist2[0]].attrs['kmers']) +ref_h5.close() +X1 = pp_sketchlib.queryDatabase(ref_db, ref_db, rlist2, rlist2, db_kmers, + True, False, 1, False, 0) + +# Check distances match +run_regression(X1[:, 0], X2[:, 0]) +run_regression(X1[:, 1], X2[:, 1]) + +# Check sparse distances after one query +with open("batch12/batch12.dists.pkl", 'rb') as pickle_file: 
+ rlist1, qlist1, self = pickle.load(pickle_file) +S1 = scipy.sparse.load_npz("batch12/batch12_rank2_fit.npz") +S2 = scipy.sparse.load_npz("batch2/batch2_rank2_fit.npz") +compare_sparse_matrices(S1,S2,rlist1,rlist2) + +# Check distances after second query + +# Check that order is the same after doing 1 + 2 + 3 with --update-db, as doing all of 1 + 2 + 3 together +subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile123.txt --output batch123 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch123 --ranks 1,2", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch2 --query rfile3.txt --output batch3 --update-db --overwrite", shell=True, check=True) + +# Load updated distances +X2 = np.load("batch3/batch3.dists.npy") +with open("batch3/batch3.dists.pkl", 'rb') as pickle_file: + rlist4, qlist, self = pickle.load(pickle_file) + +# Get same distances from the full database +ref_db = "batch123/batch123" +ref_h5 = h5py.File(ref_db + ".h5", 'r') +db_kmers = sorted(ref_h5['sketches/' + rlist4[0]].attrs['kmers']) +ref_h5.close() +X1 = pp_sketchlib.queryDatabase(ref_db, ref_db, rlist4, rlist4, db_kmers, + True, False, 1, False, 0) + +# Check distances match +run_regression(X1[:, 0], X2[:, 0]) +run_regression(X1[:, 1], X2[:, 1]) + +# Check sparse distances after second query +with open("batch123/batch123.dists.pkl", 'rb') as pickle_file: + rlist3, qlist, self = pickle.load(pickle_file) +S3 = scipy.sparse.load_npz("batch123/batch123_rank2_fit.npz") +S4 = scipy.sparse.load_npz("batch3/batch3_rank2_fit.npz") + +compare_sparse_matrices(S3,S4,rlist3,rlist4) diff --git a/test/test-web.py b/test/test-web.py new file mode 100644 index 00000000..a69505c1 --- /dev/null +++ b/test/test-web.py @@ -0,0 +1,131 @@ +import os +import sys +import subprocess +from shutil import copyfile + +# testing without install +#sys.path.insert(0, '..') +from PopPUNK.assign import assign_query +from PopPUNK.web import default_options, summarise_clusters, get_colours, api, graphml_to_json +from PopPUNK.utils import setupDBFuncs +from PopPUNK.visualise import generate_visualisations + +def main(): + # Copy and move args and sketch files into example dirs + copyfile("web_args.txt", "example_db/args.txt") + copyfile("example_viz/example_viz_core_NJ.nwk", "example_viz/example_viz.nwk") + + # Test the output of the PopPUNk-web upload route for incorrect data types + sys.stderr.write('\nTesting assign for PopPUNK-web\n') + with open("json_sketch.txt", "r") as s: + sketch = s.read() + species = "Listeria monocytogenes" + species_db = "example_db" + outdir = "example_api" + if not os.path.exists(outdir): + os.mkdir(outdir) + args = default_options(species_db) + qc_dict = {'run_qc': False } + dbFuncs = setupDBFuncs(args.assign, args.assign.min_kmer_count, qc_dict) + ClusterResult = assign_query(dbFuncs, + args.assign.ref_db, + args.assign.q_files, + outdir, + qc_dict, + args.assign.update_db, + args.assign.write_references, + args.assign.distances, + args.assign.threads, + args.assign.overwrite, + args.assign.plot_fit, + args.assign.graph_weights, + args.assign.max_a_dist, + args.assign.max_pi_dist, + args.assign.type_isolate, + args.assign.model_dir, + args.assign.strand_preserved, + args.assign.previous_clustering, + args.assign.external_clustering, + args.assign.core_only, + args.assign.accessory_only, + args.assign.gpu_sketch, + args.assign.gpu_dist, + args.assign.gpu_graph, + args.assign.deviceid, + 
args.assign.web, + sketch, + args.assign.save_partial_query_graph) + query, query_prevalence, clusters, prevalences, alias_dict, to_include = \ + summarise_clusters(outdir, species, species_db) + colours = get_colours(query, clusters) + url = api(query, "example_viz") + sys.stderr.write('PopPUNK-web assign test successful\n') + + # Test generate_visualisations() for PopPUNK-web + sys.stderr.write('\nTesting visualisations for PopPUNK-web\n') + if len(to_include) < 3: + args.visualise.microreact = False + generate_visualisations(outdir, + species_db, + None, + args.visualise.threads, + outdir, + args.visualise.gpu_dist, + args.visualise.deviceid, + args.visualise.external_clustering, + args.visualise.microreact, + args.visualise.phandango, + args.visualise.grapetree, + args.visualise.cytoscape, + args.visualise.perplexity, + args.visualise.strand_preserved, + outdir + "/include.txt", + species_db, + species_db + "/" + os.path.basename(species_db) + "_clusters.csv", + args.visualise.previous_query_clustering, + outdir + "/" + os.path.basename(outdir) + "_graph.gt", + args.visualise.gpu_graph, + args.visualise.info_csv, + args.visualise.rapidnj, + args.visualise.tree, + args.visualise.mst_distances, + args.visualise.overwrite, + args.visualise.core_only, + args.visualise.accessory_only, + args.visualise.display_cluster, + web=True) + networkJson = graphml_to_json(outdir) + if len(to_include) >= 3: + with open(os.path.join(outdir, os.path.basename(outdir) + "_core_NJ.nwk"), "r") as p: + phylogeny = p.read() + else: + phylogeny = "A tree cannot be built with fewer than 3 samples." + + # ensure web api outputs are of the correct type + if not isinstance(species, str): + raise TypeError('"Species" datatype is incorrect, should be string.\n') + if not (isinstance(query_prevalence, float) or isinstance(query_prevalence, int)): + raise TypeError('"query_prevalence" datatype is incorrect, should be float/integer.\n') + if not isinstance(query, str): + raise TypeError('"query" datatype is incorrect, should be string.\n') + if not isinstance(clusters, list) and not isinstance(clusters[0], str): + raise TypeError('"clusters" datatype is incorrect, should be list of strings.\n') + if not isinstance(prevalences, list) and not (isinstance(prevalences[0], float) or isinstance(prevalences[0], int)): + raise TypeError('"prevalences" datatype is incorrect, should be list of floats/integers.\n') + if not isinstance(colours, list) and not isinstance(colours[0], str): + raise TypeError('"colours" datatype is incorrect, should be list of strings.\n') + if not isinstance(url, str): + raise TypeError('"url" datatype is incorrect, should be string.\n') + if not isinstance(alias_dict, dict): + raise TypeError('"alias_dict" datatype is incorrect, should be dictionary.\n') + if not isinstance(outdir, str): + raise TypeError('"outdir" datatype is incorrect, should be string.\n') + if not isinstance(networkJson, dict): + raise TypeError('"networkJson" datatype is incorrect, should be dict.\n') + if not isinstance(phylogeny, str): + raise TypeError('"phylogeny" datatype is incorrect, should be str.\n') + + sys.stderr.write('\nAPI tests complete\n') + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/test_web.py b/test/test_web.py deleted file mode 100644 index dd0f7390..00000000 --- a/test/test_web.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import sys -import subprocess -from shutil import copyfile - -# testing without install -#sys.path.insert(0, '..') -from PopPUNK.assign import 
assign_query -from PopPUNK.web import default_options, summarise_clusters, get_colours, api, graphml_to_json -from PopPUNK.utils import setupDBFuncs -from PopPUNK.visualise import generate_visualisations - -# Copy and move args and sketch files into example dirs -copyfile("web_args.txt", "example_db/args.txt") -copyfile("example_viz/example_viz_core_NJ.nwk", "example_viz/example_viz.nwk") - -# Test the output of the PopPUNk-web upload route for incorrect data types -sys.stderr.write('\nTesting assign for PopPUNK-web\n') -with open("json_sketch.txt", "r") as s: - sketch = s.read() -species = "Listeria monocytogenes" -species_db = "example_db" -outdir = "example_api" -if not os.path.exists(outdir): - os.mkdir(outdir) -args = default_options(species_db) -qc_dict = {'run_qc': False } -dbFuncs = setupDBFuncs(args.assign, args.assign.min_kmer_count, qc_dict) -ClusterResult = assign_query(dbFuncs, - args.assign.ref_db, - args.assign.q_files, - outdir, - args.assign.update_db, - args.assign.write_references, - args.assign.distances, - args.assign.threads, - args.assign.overwrite, - args.assign.plot_fit, - args.assign.graph_weights, - args.assign.max_a_dist, - args.assign.model_dir, - args.assign.strand_preserved, - args.assign.previous_clustering, - args.assign.external_clustering, - args.assign.core_only, - args.assign.accessory_only, - args.assign.web, - sketch, - args.assign.save_partial_query_graph) -query, query_prevalence, clusters, prevalences, alias_dict, to_include = \ - summarise_clusters(outdir, species, species_db) -colours = get_colours(query, clusters) -url = api(query, "example_viz") -sys.stderr.write('PopPUNK-web assign test successful\n') - -# Test generate_visualisations() for PopPUNK-web -sys.stderr.write('\nTesting visualisations for PopPUNK-web\n') -if len(to_include) < 3: - args.visualise.microreact = False -generate_visualisations(outdir, - species_db, - None, - args.visualise.threads, - outdir, - args.visualise.gpu_dist, - args.visualise.deviceid, - args.visualise.external_clustering, - args.visualise.microreact, - args.visualise.phandango, - args.visualise.grapetree, - args.visualise.cytoscape, - args.visualise.perplexity, - args.visualise.strand_preserved, - outdir + "/include.txt", - species_db, - species_db, - args.visualise.previous_query_clustering, - outdir, - args.visualise.info_csv, - args.visualise.rapidnj, - args.visualise.tree, - args.visualise.mst_distances, - args.visualise.overwrite, - args.visualise.core_only, - args.visualise.accessory_only) -networkJson = graphml_to_json(outdir) -if len(to_include) >= 3: - with open(os.path.join(outdir, os.path.basename(outdir) + "_core_NJ.nwk"), "r") as p: - phylogeny = p.read() -else: - phylogeny = "A tree cannot be built with fewer than 3 samples." 
- -# ensure web api outputs are of the correct type -if not isinstance(species, str): - raise TypeError('"Species" datatype is incorrect, should be string.\n') -if not (isinstance(query_prevalence, float) or isinstance(query_prevalence, int)): - raise TypeError('"query_prevalence" datatype is incorrect, should be float/integer.\n') -if not isinstance(query, str): - raise TypeError('"query" datatype is incorrect, should be string.\n') -if not isinstance(clusters, list) and not isinstance(clusters[0], str): - raise TypeError('"clusters" datatype is incorrect, should be list of strings.\n') -if not isinstance(prevalences, list) and not (isinstance(prevalences[0], float) or isinstance(prevalences[0], int)): - raise TypeError('"prevalences" datatype is incorrect, should be list of floats/integers.\n') -if not isinstance(colours, list) and not isinstance(colours[0], str): - raise TypeError('"colours" datatype is incorrect, should be list of strings.\n') -if not isinstance(url, str): - raise TypeError('"url" datatype is incorrect, should be string.\n') -if not isinstance(alias_dict, dict): - raise TypeError('"alias_dict" datatype is incorrect, should be dictionary.\n') -if not isinstance(outdir, str): - raise TypeError('"outdir" datatype is incorrect, should be string.\n') -if not isinstance(networkJson, dict): - raise TypeError('"networkJson" datatype is incorrect, should be dict.\n') -if not isinstance(phylogeny, str): - raise TypeError('"phylogeny" datatype is incorrect, should be str.\n') - -sys.stderr.write('\nAPI tests complete\n') diff --git a/test/web_args.txt b/test/web_args.txt index 2eb7f0fb..dc397c11 100644 --- a/test/web_args.txt +++ b/test/web_args.txt @@ -8,6 +8,8 @@ "plot_fit":0, "graph_weights":true, "max_a_dist":0.5, + "max_pi_dist":0.5, + "type_isolate":null, "strand_preserved":false, "external_clustering":null, "core_only":false, @@ -19,6 +21,7 @@ "gpu_sketch":false, "deviceid":0, "gpu_dist":false, + "gpu_graph":false, "min_kmer_count":0, "min_k":14, "max_k":29, @@ -36,6 +39,7 @@ "visualise":{ "threads":1, "gpu_dist":false, + "gpu_graph":false, "deviceid":0, "external_clustering":null, "microreact":true, @@ -51,6 +55,7 @@ "mst_distances":"core", "overwrite":true, "core_only":false, - "accessory_only":false + "accessory_only":false, + "display_cluster":null } }
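A usage sketch for the new --sparse branch of poppunk_extract_distances.py above: the rank-fit .npz written by the lineage model is a scipy sparse matrix whose rows the script pairs with query names and whose columns it pairs with reference names from the .dists.pkl file. The snippet below shows that pairing done directly in Python; it is a minimal sketch, and the example_db paths are placeholders following the <prefix>_rank<k>_fit.npz naming used in the tests, not files created by this patch.

# Sketch: print query, reference and core distance from a rank-fit sparse matrix.
# "example_db/example_db.dists.pkl" and "example_db/example_db_rank2_fit.npz" are
# placeholder paths; substitute the outputs of an actual run.
import pickle
from scipy import sparse

# Pickle stores (rlist, qlist, self) as in the scripts above
with open("example_db/example_db.dists.pkl", "rb") as pkl:
    rlist, qlist, self_dists = pickle.load(pkl)

# Convert to COO so row/col/data triplets can be iterated directly
S = sparse.load_npz("example_db/example_db_rank2_fit.npz").tocoo()
for q_idx, r_idx, core_dist in zip(S.row, S.col, S.data):
    print(qlist[q_idx], rlist[r_idx], core_dist, sep="\t")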