diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index c472d74c..cccef1af 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -7,7 +7,7 @@ import sys
 # additional
 import numpy as np
-import networkx as nx
+import graph_tool.all as gt
 import subprocess

 # required from v2.1.1 onwards (no mash support)
@@ -106,7 +106,6 @@ def get_options():
                         default=False, action='store_true')

-
     # input options
     iGroup = parser.add_argument_group('Input files')
     iGroup.add_argument('--ref-db',type = str, help='Location of built reference database')
@@ -249,7 +248,7 @@ def main():
     if args.min_k >= args.max_k:
         sys.stderr.write("Minimum kmer size " + str(args.min_k) + " must be smaller than maximum kmer size\n")
         sys.exit(1)
-    elif args.k_step < 2:
+    elif args.k_step < 1:
         sys.stderr.write("Kmer size step must be at least one\n")
         sys.exit(1)
     elif no_sketchlib and (args.min_k < 9 or args.max_k > 31):
@@ -300,6 +299,12 @@ def main():
     if args.ref_db is not None and args.ref_db.endswith('/'):
         args.ref_db = args.ref_db[:-1]

+    # Check on parallelisation of graph-tools
+    if gt.openmp_enabled():
+        gt.openmp_set_num_threads(args.threads)
+        sys.stderr.write('\nGraph-tools OpenMP parallelisation enabled:')
+        sys.stderr.write(' with ' + str(gt.openmp_get_num_threads()) + ' threads\n')
+
     # run according to mode
     sys.stderr.write("PopPUNK (POPulation Partitioning Using Nucleotide Kmers)\n")
     sys.stderr.write("\t(with backend: " + dbFuncs['backend'] + " v" + dbFuncs['backend_version'] + "\n")
@@ -358,7 +363,7 @@ def main():
     #*                            *#
     #******************************#
     # refine model also needs to run all model steps
-    if args.fit_model or args.use_model or args.refine_model or args.threshold or args.easy_run:
+    if args.fit_model or args.use_model or args.refine_model or args.threshold or args.easy_run or args.lineage_clustering:
         # Set up saved data from first step, if easy_run mode
         if args.easy_run:
             distances = dists_out
@@ -370,9 +375,10 @@ def main():
                 sys.stderr.write("Mode: Using previous model with a reference database\n\n")
             elif args.threshold:
                 sys.stderr.write("Mode: Applying a core distance threshold\n\n")
-            else:
+            elif args.refine_model:
                 sys.stderr.write("Mode: Refining model fit using network properties\n\n")
-
+            elif args.lineage_clustering:
+                sys.stderr.write("Mode: Identifying lineages from neighbouring isolates\n\n")
             if args.distances is not None and args.ref_db is not None:
                 distances = args.distances
                 ref_db = args.ref_db
@@ -419,6 +425,8 @@ def main():
                 model = BGMMFit(args.output)
                 assignments = model.fit(distMat, args.K)
                 model.plot(distMat, assignments)
+            # save model
+            model.save()

         # Run model refinement
         if args.refine_model or args.threshold or args.easy_run:
@@ -436,43 +444,85 @@ def main():
         if args.use_model:
             assignments = model.assign(distMat)
             model.plot(distMat, assignments)
-
-        fit_type = 'combined'
-        model.save()
-
+
         #******************************#
         #*                            *#
         #*  network construction      *#
         #*                            *#
         #******************************#
-        genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
+
+        if not args.lineage_clustering:
+            genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
+            # Ensure all in dists are in final network
+            networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
+            if len(networkMissing) > 0:
+                missing_isolates = [refList[m] for m in networkMissing]
+                sys.stderr.write("WARNING: Samples " + ", ".join(missing_isolates) + " are missing from the final network\n")
+
+            fit_type = None
+            isolateClustering = {fit_type: printClusters(genomeNetwork,
+                                                         refList,
+                                                         args.output + "/" + os.path.basename(args.output),
+                                                         externalClusterCSV = args.external_clustering)}
+
+            # Write core and accessory based clusters, if they worked
+            if model.indiv_fitted:
+                indivNetworks = {}
+                for dist_type, slope in zip(['core', 'accessory'], [0, 1]):
+                    indivAssignments = model.assign(distMat, slope)
+                    indivNetworks[dist_type] = constructNetwork(refList, queryList, indivAssignments, model.within_label)
+                    isolateClustering[dist_type] = printClusters(indivNetworks[dist_type],
+                                                                 refList,
+                                                                 args.output + "/" + os.path.basename(args.output) + "_" + dist_type,
+                                                                 externalClusterCSV = args.external_clustering)
+                    indivNetworks[dist_type].save(args.output + "/" + os.path.basename(args.output) +
+                                                  "_" + dist_type + '_graph.gt', fmt = 'gt')
+
+                if args.core_only:
+                    fit_type = 'core'
+                    genomeNetwork = indivNetworks['core']
+                elif args.accessory_only:
+                    fit_type = 'accessory'
+                    genomeNetwork = indivNetworks['accessory']

-        # Ensure all in dists are in final network
-        networkMissing = set(refList).difference(list(genomeNetwork.nodes()))
-        if len(networkMissing) > 0:
-            sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")

-        isolateClustering = {fit_type: printClusters(genomeNetwork,
-                                                     args.output + "/" + os.path.basename(args.output),
-                                                     externalClusterCSV = args.external_clustering)}
+        #******************************#
+        #*                            *#
+        #*  lineages analysis         *#
+        #*                            *#
+        #******************************#
+
+        if args.lineage_clustering:
+
+            # load distances
+            if args.distances is not None:
+                distances = args.distances
+            else:
+                sys.stderr.write("Need to provide an input set of distances with --distances\n\n")
+                sys.exit(1)

-        # Write core and accessory based clusters, if they worked
-        if model.indiv_fitted:
+            refList, queryList, self, distMat = readPickle(distances)
+
+            # make directory for new output files
+            if not os.path.isdir(args.output):
+                try:
+                    os.makedirs(args.output)
+                except OSError:
+                    sys.stderr.write("Cannot create output directory\n")
+                    sys.exit(1)
+
+            # run lineage clustering
+            if self:
+                isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
+            else:
+                isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, qlist = queryList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
+
+            # load networks
             indivNetworks = {}
-            for dist_type, slope in zip(['core', 'accessory'], [0, 1]):
-                indivAssignments = model.assign(distMat, slope)
-                indivNetworks[dist_type] = constructNetwork(refList, queryList, indivAssignments, model.within_label)
-                isolateClustering[dist_type] = printClusters(indivNetworks[dist_type],
-                                                 args.output + "/" + os.path.basename(args.output) + "_" + dist_type,
-                                                 externalClusterCSV = args.external_clustering)
-                nx.write_gpickle(indivNetworks[dist_type], args.output + "/" + os.path.basename(args.output) +
-                                 "_" + dist_type + '_graph.gpickle')
-            if args.core_only:
-                fit_type = 'core'
-                genomeNetwork = indivNetworks['core']
-            elif args.accessory_only:
-                fit_type = 'accessory'
-                genomeNetwork = indivNetworks['accessory']
+            for rank in rank_list:
+                indivNetworks[rank] = gt.load_graph(args.output + "/" + os.path.basename(args.output) + '_rank_' + str(rank) + '_lineages.gt')
+                if rank == min(rank_list):
+                    genomeNetwork = indivNetworks[rank]
        #******************************#
        #*                            *#
@@ -500,16 +550,21 @@ def main():
                                     overwrite = args.overwrite, microreact = args.microreact)
             if args.cytoscape:
                 sys.stderr.write("Writing cytoscape output\n")
-                outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
-                if model.indiv_fitted:
-                    sys.stderr.write("Writing individual cytoscape networks\n")
-                    for dist_type in ['core', 'accessory']:
-                        outputsForCytoscape(indivNetworks[dist_type], isolateClustering, args.output,
-                                    args.info_csv, suffix = dist_type, writeCsv = False)
+                if args.lineage_clustering:
+                    for rank in rank_list:
+                        outputsForCytoscape(indivNetworks[rank], isolateClustering, args.output,
+                                    args.info_csv, suffix = 'rank_' + str(rank), writeCsv = False)
+                else:
+                    outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
+                    if model.indiv_fitted:
+                        sys.stderr.write("Writing individual cytoscape networks\n")
+                        for dist_type in ['core', 'accessory']:
+                            outputsForCytoscape(indivNetworks[dist_type], isolateClustering, args.output,
+                                        args.info_csv, suffix = dist_type, writeCsv = False)

         except:
             # Go ahead with final steps even if visualisations fail
             # (e.g. rapidnj not found)
-            sys.stderr.write("Error creating files for visualisation:" + str(sys.exc_info()[0]))
+            sys.stderr.write("Error creating files for visualisation: " + str(sys.exc_info()[0]))

         #******************************#
         #*                            *#
@@ -518,10 +573,11 @@ def main():
         #******************************#
         # extract limited references from clique by default
         if not args.full_db:
-            newReferencesNames, newReferencesFile = extractReferences(genomeNetwork, refList, args.output)
-            nodes_to_remove = set(refList).difference(newReferencesNames)
-            genomeNetwork.remove_nodes_from(nodes_to_remove)
-            prune_distance_matrix(refList, nodes_to_remove, distMat,
+            newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = extractReferences(genomeNetwork, refList, args.output)
+            nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
+#
+            names_to_remove = [refList[n] for n in nodes_to_remove]
+            prune_distance_matrix(refList, names_to_remove, distMat,
                                   args.output + "/" + os.path.basename(args.output) + ".dists")

             # With mash, the sketches are actually removed from the database
@@ -534,38 +590,7 @@ def main():
                                   args.estimated_length, True, args.threads, True) # overwrite old db
                 os.remove(dummyRefFile)
-            nx.write_gpickle(genomeNetwork, args.output + "/" + os.path.basename(args.output) + '_graph.gpickle')
-
-    #******************************#
-    #*                            *#
-    #*  within-strain analysis    *#
-    #*                            *#
-    #******************************#
-    if args.lineage_clustering:
-        sys.stderr.write("Mode: Identifying lineages within a clade\n\n")
-
-        # load distances
-        if args.distances is not None:
-            distances = args.distances
-        else:
-            sys.stderr.write("Need to provide an input set of distances with --distances\n\n")
-            sys.exit(1)
-
-        refList, queryList, self, distMat = readPickle(distances)
-
-        # make directory for new output files
-        if not os.path.isdir(args.output):
-            try:
-                os.makedirs(args.output)
-            except OSError:
-                sys.stderr.write("Cannot create output directory\n")
-                sys.exit(1)
-
-        # run lineage clustering
-        if self:
-            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
-        else:
-            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, qlist = queryList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
+            genomeNetwork.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')

     #*******************************#
     #*                             *#
@@ -632,9 +657,8 @@ def main():
             postpruning_combined_seq, newDistMat = prune_distance_matrix(rlist, isolates_to_remove,
                                                                          complete_distMat, dists_out)

-            rlist = viz_subset
             combined_seq, core_distMat, acc_distMat = \
-                update_distance_matrices(viz_subset, newDistMat,
+                    update_distance_matrices(viz_subset, newDistMat,
                                          threads = args.threads)

             # reorder subset to ensure list orders match
@@ -666,14 +690,11 @@ def main():
                 prev_clustering = args.previous_clustering
             else:
                 prev_clustering = os.path.dirname(args.distances + ".pkl")
-
-            # Read in network and cluster assignment
-            genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, args.core_only, args.accessory_only)
+
+            # load clustering
+            cluster_file = args.ref_db + '/' + args.ref_db + '_clusters.csv'
             isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
-
-            # prune the network and dictionary of assignments
-            genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(viz_subset))
-
+            cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
             # generate selected visualisations
             if args.microreact:
                 sys.stderr.write("Writing microreact output\n")
@@ -688,11 +709,22 @@ def main():
                 outputsForGrapetree(viz_subset, core_distMat, isolateClustering, args.output, args.info_csv, args.rapidnj,
                                     overwrite = args.overwrite, microreact = args.microreact)
             if args.cytoscape:
-                if args.viz_lineages or args.external_clustering:
-                    sys.stderr.write("Can only generate a network output for fitted models\n")
+                sys.stderr.write("Writing cytoscape output\n")
+                if args.viz_lineages:
+                    for rank in isolateClustering.keys():
+                        numeric_rank = rank.split('_')[1]
+                        if numeric_rank.isdigit():
+                            genomeNetwork = gt.load_graph(args.ref_db + '/' + os.path.basename(args.ref_db) + '_rank_' + str(numeric_rank) + '_lineages.gt')
+                            outputsForCytoscape(genomeNetwork, isolateClustering, args.output,
+                                        args.info_csv, suffix = 'rank_' + str(rank), viz_subset = viz_subset)
                 else:
-                    sys.stderr.write("Writing cytoscape output\n")
-                    outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
+                    genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, args.core_only, args.accessory_only)
+                    outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv, viz_subset = viz_subset)
+                    if model.indiv_fitted:
+                        sys.stderr.write("Writing individual cytoscape networks\n")
+                        for dist_type in ['core', 'accessory']:
+                            outputsForCytoscape(indivNetworks[dist_type], isolateClustering, args.output,
+                                        args.info_csv, suffix = dist_type, viz_subset = viz_subset)

         else:
             # Cannot read input files
@@ -813,7 +845,8 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             print_full_clustering = False
             if update_db:
                 print_full_clustering = True
-            isolateClustering = {'combined': printClusters(genomeNetwork, output + "/" + os.path.basename(output),
+            isolateClustering = {'combined': printClusters(genomeNetwork, refList + ordered_queryList,
+                                                           output + "/" + os.path.basename(output),
                                                            old_cluster_file, external_clustering, print_full_clustering)}

             # Update DB as requested
@@ -828,11 +861,11 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances

                 # Update the network + ref list
                 # only update network if assigning to strains
                if full_db is False and assign_lineage is False:
-                    mashOrder = refList + ordered_queryList
-                    newRepresentativesNames, newRepresentativesFile = extractReferences(genomeNetwork, mashOrder, output, refList)
-                    genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(newRepresentativesNames))
+                    dbOrder = refList + ordered_queryList
+                    newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, dbOrder, output, refList)
+                    isolates_to_remove = set(dbOrder).difference(newRepresentativesNames)
                     newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
-                    nx.write_gpickle(genomeNetwork, output + "/" + os.path.basename(output) + '_graph.gpickle')
+                    genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
                 else:
                     newQueries = ordered_queryList

@@ -840,7 +873,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                 if newQueries != queryList and use_mash:
                     tmpRefFile = writeTmpFile(newQueries)
                     constructDatabase(tmpRefFile, kmers, sketch_sizes, output,
-                                      args.estimated_length, True, threads, True) # overwrite old db
+                                      estimated_length, True, threads, True) # overwrite old db
                     os.remove(tmpRefFile)
                 # With mash, this is the reduced DB constructed,
                 # with sketchlib, all sketches
@@ -877,7 +910,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                 if full_db is False and assign_lineage is False:
                     # could also have newRepresentativesNames in this diff (should be the same) - but want
                     # to ensure consistency with the network in case of bad input/bugs
-                    nodes_to_remove = set(combined_seq).difference(genomeNetwork.nodes)
+                    nodes_to_remove = set(combined_seq).difference(newRepresentativesNames)
                     # This function also writes out the new distance matrix
                     postpruning_combined_seq, newDistMat = prune_distance_matrix(combined_seq, nodes_to_remove,
                                                                                  complete_distMat, dists_out)
@@ -905,9 +938,6 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                     outputsForGrapetree(combined_seq, core_distMat, isolateClustering, output, info_csv, rapidnj,
                                         queryList = ordered_queryList, overwrite = overwrite, microreact = microreact)
                 if cytoscape:
-                    if assign_lineage:
-                        sys.stderr.write("Cannot generate a cytoscape network from a lineage assignment")
-                    else:
                         sys.stderr.write("Writing cytoscape output\n")
                         outputsForCytoscape(genomeNetwork, isolateClustering, output, info_csv, ordered_queryList)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index 2223e480..7afe57c6 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -12,7 +12,7 @@ from collections import defaultdict
 import pickle
 import collections
-import networkx as nx
+import graph_tool.all as gt
 from multiprocessing import Pool, RawArray, shared_memory, managers
 try:
     from multiprocessing import Pool, shared_memory
@@ -91,17 +91,16 @@ def get_nearest_neighbours(rank, isolates = None, ranks = None):
             frozen set of nearest neighbours.
     """
     # data structure
-    nn = {}
+    nn = set()
     # load shared ranks
     ranks_shm = shared_memory.SharedMemory(name = ranks.name)
     ranks = np.ndarray(ranks.shape, dtype = ranks.dtype, buffer = ranks_shm.buf)
     # apply along axis
     for i in isolates:
-        nn[i] = defaultdict(frozenset)
         isolate_ranks = ranks[i,:]
         closest_ranked = np.ravel(np.where(isolate_ranks <= rank))
-        neighbours = frozenset(closest_ranked.tolist())
-        nn[i] = neighbours
+        for j in closest_ranked.tolist():
+            nn.add((i,j))
     # return dict
     return nn

@@ -213,12 +212,17 @@ def cluster_into_lineages(distMat, rank_list = None, output = None,
     """

     # data structures
-    lineage_clustering = defaultdict(dict)
+    lineage_assignation = defaultdict(dict)
     overall_lineage_seeds = defaultdict(dict)
     overall_lineages = defaultdict(dict)
+    max_existing_cluster = {rank:1 for rank in rank_list}
+
+    # load existing scheme if supplied
     if existing_scheme is not None:
         with open(existing_scheme, 'rb') as pickle_file:
-            lineage_clustering, overall_lineage_seeds, rank_list = pickle.load(pickle_file)
+            lineage_assignation, overall_lineage_seeds, rank_list = pickle.load(pickle_file)
+        for rank in rank_list:
+            max_existing_cluster[rank] = max(lineage_assignation[rank].values()) + 1

     # generate square distance matrix
     seqLabels, coreMat, accMat = \
@@ -263,25 +267,75 @@ def cluster_into_lineages(distMat, rank_list = None, output = None,
         distance_ranks_shared_array = np.ndarray(distance_ranks.shape, dtype = distance_ranks.dtype, buffer = distance_ranks_raw.buf)
         distance_ranks_shared_array[:] = distance_ranks[:]
         distance_ranks_shared_array = NumpyShared(name = distance_ranks_raw.name, shape = distance_ranks.shape, dtype = distance_ranks.dtype)
-
+
+        # build a graph framework for network outputs
+        # create graph structure with internal vertex property map
+        # storing lineage assignation cannot load boost.python within spawned
+        # processes so have to run network analysis separately
+        G = gt.Graph(directed = False)
+        G.add_vertex(len(isolate_list))
+        # add sequence labels for visualisation
+        vid = G.new_vertex_property('string',
+                                    vals = isolate_list)
+        G.vp.id = vid
+
         # parallelise neighbour identification for each rank
         with Pool(processes = num_processes) as pool:
-            results = pool.map(partial(run_clustering_for_rank,
-                                distances_input = distances_shared_array,
-                                distance_ranks_input = distance_ranks_shared_array,
-                                isolates = isolate_list_shared,
-                                previous_seeds = overall_lineage_seeds),
+            results = pool.map(partial(get_nearest_neighbours,
+                                ranks = distance_ranks_shared_array,
+                                isolates = isolate_list_shared),
                                 rank_list)

-        # extract results from multiprocessing pool
+        # extract results from multiprocessing pool and save output network
+        nn = defaultdict(dict)
+
         for n,result in enumerate(results):
+            # get results per rank
             rank = rank_list[n]
-            lineage_clustering[rank], overall_lineage_seeds[rank] = result
+            # get neighbours
+            edges_to_add = result
+            # store results in network
+            G.add_edge_list(edges_to_add)
+            # calculate connectivity of each vertex
+            vertex_out_degrees = G.get_out_degrees(G.get_vertices())
+            # identify components and rank by frequency
+            components, component_frequencies = gt.label_components(G)
+            component_frequency_ranks = (len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)).tolist()
+            # construct a name translation table
+            # begin with previously defined clusters
+            component_name = [None] * len(component_frequencies)
+            for seed in overall_lineage_seeds[rank]:
+                isolate_index = isolate_list.index(seed)
+                component_number = components[isolate_index]
+                if component_name[component_number] is None or component_name[component_number] > overall_lineage_seeds[rank][seed]:
+                    component_name[component_number] = overall_lineage_seeds[rank][seed]
+            # name remaining components in rank order
+            for component_rank in range(len(component_frequency_ranks)):
+#
+                component_number = component_frequency_ranks.index(component_rank)
+                if component_name[component_number] is None:
+                    component_name[component_number] = max_existing_cluster[rank]
+                    # find seed isolate
+                    component_max_degree = np.amax(vertex_out_degrees[np.where(components.a == component_number)])
+                    seed_isolate_index = int(np.where((components.a == component_number) & (vertex_out_degrees == component_max_degree))[0][0])
+                    seed_isolate = isolate_list[seed_isolate_index]
+                    overall_lineage_seeds[rank][seed_isolate] = max_existing_cluster[rank]
+                    # increment
+                    max_existing_cluster[rank] = max_existing_cluster[rank] + 1
+            # store assignments
+            for isolate_index,isolate_name in enumerate(isolate_list):
+                original_component = components.a[isolate_index]
+                renamed_component = component_name[original_component]
+                lineage_assignation[rank][isolate_name] = renamed_component
+            # save network
+            G.save(file_name = output + "/" + os.path.basename(output) + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
+            # clear edges - nodes in graph can be reused but edges differ between ranks
+            G.clear_edges()

     # store output
     with open(output + "/" + output + '_lineages.pkl', 'wb') as pickle_file:
-        pickle.dump([lineage_clustering, overall_lineage_seeds, rank_list], pickle_file)
-
+        pickle.dump([lineage_assignation, overall_lineage_seeds, rank_list], pickle_file)
+
     # process multirank lineages
     overall_lineages = {}
     overall_lineages = {'Rank_' + str(rank):{} for rank in rank_list}
@@ -289,11 +343,11 @@ def cluster_into_lineages(distMat, rank_list = None, output = None,
     for index,isolate in enumerate(isolate_list):
         overall_lineage = None
         for rank in rank_list:
-            overall_lineages['Rank_' + str(rank)][isolate] = lineage_clustering[rank][index]
+            overall_lineages['Rank_' + str(rank)][isolate] = lineage_assignation[rank][isolate]
             if overall_lineage is None:
-                overall_lineage = str(lineage_clustering[rank][index])
+                overall_lineage = str(lineage_assignation[rank][isolate])
             else:
-                overall_lineage = overall_lineage + '-' + str(lineage_clustering[rank][index])
+                overall_lineage = overall_lineage + '-' + str(lineage_assignation[rank][isolate])
         overall_lineages['overall'][isolate] = overall_lineage

     # print output as CSV
@@ -326,13 +380,13 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
             Whether to extend a previously generated analysis or not.

     Returns:
-        lineage_clustering (dict)
+        lineage_assignation (dict)
             Assignment of each isolate to a cluster.
         lineage_seed (dict)
             Seed isolate used to initiate each cluster.
-        neighbours (nested dict)
-            Neighbour relationships between isolates for R.
-    """
+        connections (set of tuples)
+            Edges to add to network describing lineages.
+    """

     # load shared memory objects
     distances_shm = shared_memory.SharedMemory(name = distances_input.name)
@@ -347,11 +401,6 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
     if previous_seeds is not None:
         seeds = previous_seeds[rank]

-    # create graph structure
-    G = nx.Graph()
-    G.add_nodes_from(isolate_indices)
-    G.nodes.data('lineage', default = 0)
-
     # identify nearest neighbours
     nn = get_nearest_neighbours(rank,
                                 ranks = distance_ranks_input,
@@ -359,19 +408,20 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =

     # iteratively identify lineages
     lineage_index = 1
-    while nx.number_of_isolates(G) > 0:
+    connections = set()
+    lineage_assignation = {isolate:None for isolate in isolate_list}
+
+    while None in lineage_assignation.values():
         if lineage_index in seeds.keys():
             seed_isolate = seeds[lineage_index]
         else:
-            seed_isolate = pick_seed_isolate(G, distances = distances_input)
+            seed_isolate = pick_seed_isolate(lineage_assignation, distances = distances_input)
         # skip over previously-defined seeds if amalgamated into different lineage now
-        if nx.is_isolate(G, seed_isolate):
+        if lineage_assignation[seed_isolate] is None:
             seeds[lineage_index] = seed_isolate
-            G = get_lineage(G, nn, seed_isolate, lineage_index)
+            lineage_assignation, added_connections = get_lineage(lineage_assignation, nn, seed_isolate, lineage_index)
+            connections.update(added_connections)
         lineage_index = lineage_index + 1
-
-    # identify components and name lineages
-    lineage_clustering = {node:nodedata for (node, nodedata) in G.nodes(data='lineage')}
-
+
     # return clustering
-    return lineage_clustering, seeds
+    return lineage_assignation, seeds, nn, connections

diff --git a/PopPUNK/mash.py b/PopPUNK/mash.py
index f8e9bc3d..c153195a 100644
--- a/PopPUNK/mash.py
+++ b/PopPUNK/mash.py
@@ -18,7 +18,6 @@ from glob import glob
 from random import sample

 import numpy as np
-import networkx as nx
 from scipy import optimize

 try:
     from multiprocessing import Pool, shared_memory
@@ -542,10 +541,10 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
         # Check mash output is consistent with expected order
         # This is ok in all tests, but best to check and exit in case something changes between mash versions
         expected_names = iterDistRows(refList, qNames, self)
-        prev_ref = ""

         skip = 0
         skipped = 0
+
         for line in mashOut:
             # Skip the first row with self and symmetric elements
             if skipped < skip:
@@ -602,17 +601,20 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num

     # run pairwise analyses across kmer lengths, mutating distMat
     # Create range of rows that each thread will work with
+    # if there is only one pair, apply_along_axis will not work
+    if threads > number_pairs:
+        threads = number_pairs
     rows_per_thread = int(number_pairs / threads)
     big_threads = number_pairs % threads
     start = 0
     mat_chunks = []
+
     for thread in range(threads):
         end = start + rows_per_thread
         if thread < big_threads:
             end += 1
         mat_chunks.append((start, end))
         start = end
-
     # create empty distMat that can be shared with multiple processes
     distMat = np.zeros((number_pairs, 2), dtype=raw.dtype)
     with SharedMemoryManager() as smm:
@@ -624,7 +626,6 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
         shm_distMat = smm.SharedMemory(size = distMat.nbytes)
         distMat_shared = NumpyShared(name = shm_distMat.name, shape = (number_pairs, 2), dtype = raw.dtype)

-        # Run regressions
         with Pool(processes = threads) as pool:
            pool.map(partial(fitKmerBlock,
@@ -668,7 +669,10 @@ def fitKmerBlock(idxRanges, distMat, raw, klist, jacobian):

     # analyse
     (start, end) = idxRanges
-    distMat[start:end, :] = np.apply_along_axis(fitKmerCurve, 1, raw[start:end, :], klist, jacobian)
+    if raw.shape[0] == 1:
+        distMat[start:end, :] = fitKmerCurve(raw[0,:], klist, jacobian)
+    else:
+        distMat[start:end, :] = np.apply_along_axis(fitKmerCurve, 1, raw[start:end, :], klist, jacobian)


 def fitKmerCurve(pairwise, klist, jacobian):
@@ -707,4 +711,3 @@ def fitKmerCurve(pairwise, klist, jacobian):

     # Return core, accessory
     return(np.flipud(transformed_params))
-
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 3d4fadfd..41fe4150 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -12,24 +12,28 @@ import operator
 import shutil
 import subprocess

-import networkx as nx
+import graph_tool.all as gt
 import numpy as np
 import pandas as pd
+from scipy.stats import rankdata
 from tempfile import mkstemp, mkdtemp
 from collections import defaultdict, Counter

 from .sketchlib import calculateQueryQueryDistances

 from .utils import iterDistRows
+from .utils import listDistInts
 from .utils import readIsolateTypeFromCsv
 from .utils import readRfile
+from .utils import setupDBFuncs
+from .utils import isolateNameToLabel

 def fetchNetwork(network_dir, model, refList,
                   core_only = False, accessory_only = False):
     """Load the network based on input options

-       Returns the network as a networkx, and sets the slope parameter of
-       the passed model object.
+       Returns the network as a graph-tool format graph, and sets
+       the slope parameter of the passed model object.

        Args:
             network_dir (str)
@@ -48,7 +52,7 @@ def fetchNetwork(network_dir, model, refList,
                 [default = False]

        Returns:
-            genomeNetwork (nx.Graph)
+            genomeNetwork (graph)
                 The loaded network
             cluster_file (str)
                 The CSV of cluster assignments corresponding to this network
@@ -56,39 +60,39 @@ def fetchNetwork(network_dir, model, refList,
     # If a refined fit, may use just core or accessory distances
     if core_only and model.type == 'refine':
         model.slope = 0
-        network_file = network_dir + "/" + os.path.basename(network_dir) + '_core_graph.gpickle'
+        network_file = network_dir + "/" + os.path.basename(network_dir) + '_core_graph.gt'
         cluster_file = network_dir + "/" + os.path.basename(network_dir) + '_core_clusters.csv'
     elif accessory_only and model.type == 'refine':
         model.slope = 1
-        network_file = network_dir + "/" + os.path.basename(network_dir) + '_accessory_graph.gpickle'
+        network_file = network_dir + "/" + os.path.basename(network_dir) + '_accessory_graph.gt'
         cluster_file = network_dir + "/" + os.path.basename(network_dir) + '_accessory_clusters.csv'
     else:
-        network_file = network_dir + "/" + os.path.basename(network_dir) + '_graph.gpickle'
+        network_file = network_dir + "/" + os.path.basename(network_dir) + '_graph.gt'
         cluster_file = network_dir + "/" + os.path.basename(network_dir) + '_clusters.csv'
         if core_only or accessory_only:
             sys.stderr.write("Can only do --core-only or --accessory-only fits from "
                              "a refined fit. Using the combined distances.\n")

-    genomeNetwork = nx.read_gpickle(network_file)
-    sys.stderr.write("Network loaded: " + str(genomeNetwork.number_of_nodes()) + " samples\n")
+    genomeNetwork = gt.load_graph(network_file)
+    sys.stderr.write("Network loaded: " + str(len(list(genomeNetwork.vertices()))) + " samples\n")

     # Ensure all in dists are in final network
-    networkMissing = set(refList).difference(list(genomeNetwork.nodes()))
+    networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
     if len(networkMissing) > 0:
         sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")

     return (genomeNetwork, cluster_file)

-def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
+def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
     """Extract references for each cluster based on cliques

        Writes chosen references to file by calling :func:`~writeReferences`

        Args:
-           G (networkx.Graph)
+           G (graph)
                A network used to define clusters from :func:`~constructNetwork`
-           mashOrder (list)
+           dbOrder (list)
                The order of files in the sketches, so returned references are in the same order
            outPrefix (str)
                Prefix for output file (.refs will be appended)
@@ -103,60 +107,84 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
     """
     if existingRefs == None:
         references = set()
+        reference_indices = []
     else:
         references = set(existingRefs)
-
+        index_lookup = {v:k for k,v in enumerate(dbOrder)}
+        reference_indices = [index_lookup[r] for r in references]
+
     # extract cliques from network
-    cliques = list(nx.find_cliques(G))
+    cliques_in_overall_graph = [c.tolist() for c in gt.max_cliques(G)]
     # order list by size of clique
-    cliques.sort(key = len, reverse=True)
+    cliques_in_overall_graph.sort(key = len, reverse = True)
     # iterate through cliques
-    for clique in cliques:
+    for clique in cliques_in_overall_graph:
         alreadyRepresented = 0
         for node in clique:
-            if node in references:
+            if node in reference_indices:
                 alreadyRepresented = 1
                 break
         if alreadyRepresented == 0:
-            references.add(clique[0])
+            reference_indices.append(clique[0])

     # Find any clusters which are represented by multiple references
-    clusters = printClusters(G, printCSV=False)
-    ref_clusters = set()
-    multi_ref_clusters = set()
-    for reference in references:
-        if clusters[reference] in ref_clusters:
-            multi_ref_clusters.add(clusters[reference])
+    # First get cluster assignments
+    clusters_in_overall_graph = printClusters(G, dbOrder, printCSV=False)
+    # Construct a dict containing one empty set for each cluster
+    reference_clusters_in_overall_graph = [set() for c in set(clusters_in_overall_graph.items())]
+    # Iterate through references
+    for reference_index in reference_indices:
+        # Add references to the originally empty set for the appropriate cluster
+        # Allows enumeration of the number of references per cluster
+        reference_clusters_in_overall_graph[clusters_in_overall_graph[dbOrder[reference_index]]].add(reference_index)
+
+    # Use a vertex filter to extract the subgraph of references
+    # as a graphview
+    reference_vertex = G.new_vertex_property('bool')
+    for n,vertex in enumerate(G.vertices()):
+        if n in reference_indices:
+            reference_vertex[vertex] = True
         else:
-            ref_clusters.add(clusters[reference])
-
-    # Check if these multi reference components have been split
-    if len(multi_ref_clusters) > 0:
-        # Initial reference graph
-        ref_G = G.copy()
-        ref_G.remove_nodes_from(set(ref_G.nodes).difference(references))
-
-        for multi_ref_cluster in multi_ref_clusters:
-            # Get a list of nodes that need to be in the same component
-            check = []
-            for reference in references:
-                if clusters[reference] == multi_ref_cluster:
-                    check.append(reference)
-
-            # Pairwise check that nodes are in same component
+            reference_vertex[vertex] = False
+    G_ref = gt.GraphView(G, vfilt = reference_vertex)
+    G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
+    # Calculate component membership for reference graph
+    clusters_in_reference_graph = printClusters(G, dbOrder, printCSV=False)
+    # Record to which components references belong in the reference graph
+    reference_clusters_in_reference_graph = {}
+    for reference_index in reference_indices:
+        reference_clusters_in_reference_graph[dbOrder[reference_index]] = clusters_in_reference_graph[dbOrder[reference_index]]
+
+    # Check if multi-reference components have been split as a validation test
+    # First iterate through clusters
+    network_update_required = False
+    for cluster in reference_clusters_in_overall_graph:
+        # Identify multi-reference clusters by this length
+        if len(cluster) > 1:
+            check = list(cluster)
+            # check if these are still in the same component in the reference graph
             for i in range(len(check)):
-                component = nx.node_connected_component(ref_G, check[i])
+                component_i = reference_clusters_in_reference_graph[dbOrder[check[i]]]
                 for j in range(i, len(check)):
                     # Add intermediate nodes
-                    if check[j] not in component:
-                        new_path = nx.shortest_path(G, check[i], check[j])
-                        for node in new_path:
-                            references.add(node)
-
+                    component_j = reference_clusters_in_reference_graph[dbOrder[check[j]]]
+                    if component_i != component_j:
+                        network_update_required = True
+                        vertex_list, edge_list = gt.shortest_path(G, check[i], check[j])
+                        # update reference list
+                        for vertex in vertex_list:
+                            reference_vertex[vertex] = True
+                            reference_indices.add(int(vertex))
+
+    # update reference graph if vertices have been added
+    if network_update_required:
+        G_ref = gt.GraphView(G, vfilt = reference_vertex)
+        G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
+
     # Order found references as in mash sketch files
-    references = [x for x in mashOrder if x in references]
-    refFileName = writeReferences(references, outPrefix)
-    return references, refFileName
+    reference_names = [dbOrder[int(x)] for x in sorted(reference_indices)]
+    refFileName = writeReferences(reference_names, outPrefix)
+    return reference_indices, reference_names, refFileName, G_ref

 def writeReferences(refList, outPrefix):
     """Writes chosen references to file
@@ -223,25 +251,35 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
             (default = True)

     Returns:
-        G (networkx.Graph)
+        G (graph)
             The resulting network
     """
+    # data structures
     connections = []
-    for assignment, (ref, query) in zip(assignments, iterDistRows(rlist, qlist, self=True)):
+    self_comparison = True
+    vertex_labels = rlist
+
+    # check if self comparison
+    if rlist != qlist:
+        self_comparison = False
+        vertex_labels.append(qlist)
+
+    # identify edges
+    for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qlist, self = self_comparison)):
         if assignment == within_label:
             connections.append((ref, query))

-    density_proportion = len(connections) / (0.5 * (len(rlist) * (len(rlist) + 1)))
-    if density_proportion > 0.4 or len(connections) > 500000:
-        sys.stderr.write("Warning: trying to create very dense network\n")
-
     # build the graph
-    G = nx.Graph()
-    G.add_nodes_from(rlist)
-    for connection in connections:
-        G.add_edge(*connection)
+    G = gt.Graph(directed = False)
+    G.add_vertex(len(vertex_labels))
+    G.add_edge_list(connections)
+
+    # add isolate ID to network
+    vid = G.new_vertex_property('string',
+                                vals = vertex_labels)
+    G.vp.id = vid

-    # give some summaries
+    # print some summaries
     if summarise:
         (components, density, transitivity, score) = networkSummary(G)
         sys.stderr.write("Network summary:\n" + "\n".join(["\tComponents\t" + str(components),
@@ -256,7 +294,7 @@ def networkSummary(G):
     """Provides summary values about the network

     Args:
-        G (networkx.Graph)
+        G (graph)
             The network of strains from :func:`~constructNetwork`

     Returns:
@@ -269,9 +307,10 @@ def networkSummary(G):
         score (float)
             A score of network fit, given by :math:`\mathrm{transitivity} * (1-\mathrm{density})`
     """
-    components = nx.number_connected_components(G)
-    density = nx.density(G)
-    transitivity = nx.transitivity(G)
+    component_assignments, component_frequencies = gt.label_components(G)
+    components = len(component_frequencies)
+    density = len(list(G.edges()))/(0.5 * len(list(G.vertices())) * (len(list(G.vertices())) - 1))
+    transitivity = gt.global_clustering(G)[0]
     score = transitivity * (1-density)

     return(components, density, transitivity, score)
@@ -289,7 +328,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
             List of reference names
         qfile (str)
             File containing queries
-        G (networkx.Graph)
+        G (graph)
             Network to add to (mutated)
         kmers (list)
             List of k-mer sizes
@@ -321,6 +360,11 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
         distMat (numpy.array)
             Query-query distances
     """
+    # initialise functions
+    readDBParams = dbFuncs['readDBParams']
+    constructDatabase = dbFuncs['constructDatabase']
+    queryDatabase = dbFuncs['queryDatabase']
+    readDBParams = dbFuncs['readDBParams']

     # initialise links data structure
     new_edges = []
@@ -332,19 +376,23 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,

     # Set up query names
     qList, qSeqs = readRfile(qfile, oneSeq = use_mash)
-    queryFiles = dict(zip(qList, qSeqs))
     if use_mash == True:
+        # mash must use sequence file names for both testing for
+        # assignment and for generating a new database
         rNames = None
-        qNames = qSeqs
+        qNames = isolateNameToLabel(qSeqs)
     else:
         rNames = qList
         qNames = rNames
+    queryFiles = dict(zip(qNames, qSeqs))

     # store links for each query in a list of edge tuples
-    for assignment, (ref, query) in zip(assignments, iterDistRows(rlist, qList, self=False)):
+    ref_count = len(rlist)
+    for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qNames, self = False)):
         if assignment == model.within_label:
-            new_edges.append((ref, query))
-            assigned.add(query)
+            # query index needs to be adjusted for existing vertices in network
+            new_edges.append((ref, query + ref_count))
+            assigned.add(qNames[query])

     # Calculate all query-query distances too, if updating database
     if queryQuery:
@@ -359,15 +407,15 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
                                                      threads)
         queryAssignation = model.assign(distMat)
-        for assignment, (ref, query) in zip(queryAssignation, iterDistRows(qlist1, qlist1, self=True)):
+        for assignment, (ref, query) in zip(queryAssignation, listDistInts(qNames, qNames, self = True)):
             if assignment == model.within_label:
-                new_edges.append((ref, query))
+                new_edges.append((ref + ref_count, query + ref_count))

     # Otherwise only calculate query-query distances for new clusters
     else:
         # identify potentially new lineages in list: unassigned is a list of queries with no hits
-        unassigned = set(qNames).difference(assigned)
-
+        unassigned = set(qSeqs).difference(assigned)
+        query_indices = {k:v+ref_count for v,k in enumerate(qSeqs)}
         # process unassigned query sequences, if there are any
         if len(unassigned) > 1:
             sys.stderr.write("Found novel query clusters. Calculating distances between them:\n")
@@ -387,8 +435,8 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,

             # use database construction methods to find links between unassigned queries
             sketchSize = readDBParams(queryDB, kmers, None)[1]
-
             constructDatabase(tmpFile, kmers, sketchSize, tmpDirName, estimated_length, True, threads, False)
+
             qlist1, qlist2, distMat = queryDatabase(rNames = list(unassigned),
                                                     qNames = list(unassigned),
                                                     dbPrefix = tmpDirName,
@@ -397,31 +445,37 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
                                                     self = True,
                                                     number_plot_fits = 0,
                                                     threads = threads)
+
             queryAssignation = model.assign(distMat)

             # identify any links between queries and store in the same links dict
             # links dict now contains lists of links both to original database and new queries
-            for assignment, (query1, query2) in zip(queryAssignation, iterDistRows(qlist1, qlist2, self=True)):
+            # have to use names and link to query list in order to match to node indices
+            for assignment, (query1, query2) in zip(queryAssignation, iterDistRows(qlist1, qlist2, self = True)):
                 if assignment == model.within_label:
-                    new_edges.append((query1, query2))
+                    new_edges.append((query_indices[query1], query_indices[query2]))

             # remove directory
             shutil.rmtree(tmpDirName)

     # finish by updating the network
-    G.add_nodes_from(qNames)
-    G.add_edges_from(new_edges)
+    G.add_vertex(len(qNames))
+    G.add_edge_list(new_edges)
+
+    # including the vertex ID property map
+    for i,q in enumerate(qSeqs):
+        G.vp.id[i + len(rlist)] = q

     return qlist1, distMat

-def printClusters(G, outPrefix = "_clusters.csv", oldClusterFile = None,
+def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
                   externalClusterCSV = None, printRef = True, printCSV = True,
                   clustering_type = 'combined'):
     """Get cluster assignments

     Also writes assignments to a CSV file

     Args:
-        G (networkx.Graph)
+        G (graph)
             Network used to define clusters (from :func:`~constructNetwork` or
             :func:`~addQueryToNetwork`)
         outPrefix (str)
@@ -453,7 +507,15 @@ def printClusters(G, outPrefix = "_clusters.csv", oldClusterFile = None,
     if oldClusterFile == None and printRef == False:
         raise RuntimeError("Trying to print query clusters with no query sequences")

-    newClusters = sorted(nx.connected_components(G), key=len, reverse=True)
+    # get a sorted list of component assignments
+    component_assignments, component_frequencies = gt.label_components(G)
+    component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)
+    newClusters = [set() for rank in range(len(component_frequency_ranks))]
+    for isolate_index, isolate_name in enumerate(rlist):
+        component = component_assignments.a[isolate_index]
+        component_rank = component_frequency_ranks[component]
+        newClusters[component_rank].add(isolate_name)
+
     oldNames = set()

     if oldClusterFile != None:
@@ -569,7 +631,6 @@ def printExternalClusters(newClusters, extClusterFile, outPrefix,
     d = defaultdict(list)

     # Read in external clusters
-#    extClusters = readExternalClusters(extClusterFile)
     readIsolateTypeFromCsv(clustCSV, mode = 'external', return_dict = False)

     # Go through each cluster (as defined by poppunk) and find the external
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index 13441f0c..279a89fe 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -24,7 +24,8 @@ except ImportError:
     from sklearn.neighbors.kde import KernelDensity
 import dendropy

-import networkx as nx
+
+from .utils import isolateNameToLabel

 def plot_scatter(X, scale, out_prefix, title, kde = True):
     """Draws a 2D scatter plot (png) of the core and accessory distances
@@ -357,11 +358,11 @@ def get_grid(minimum, maximum, resolution):

     return(xx, yy, xy)

-def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suffix = None, writeCsv = True):
+def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suffix = None, writeCsv = True, viz_subset = None):
     """Write outputs for cytoscape. A graphml of the network, and CSV with metadata

     Args:
-        G (networkx.Graph)
+        G (graph)
             The network to write from :func:`~PopPUNK.network.constructNetwork`
         clustering (dict)
             Dictionary of cluster assignments (keys are nodeNames).
@@ -380,19 +381,31 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
             Whether to print CSV file to accompany network
     """

+    # get list of isolate names
+    isolate_names = list(G.vp.id)
+
+    # mask network if subsetting
+    if viz_subset is not None:
+        viz_vertex = G.new_vertex_property('bool')
+        for n,vertex in enumerate(G.vertices()):
+            if isolate_names[n] in viz_subset:
+                viz_vertex[vertex] = True
+            else:
+                viz_vertex[vertex] = False
+        G.set_vertex_filter(viz_vertex)
+
     # write graph file
     if suffix is None:
         graph_file_name = os.path.basename(outPrefix) + "_cytoscape.graphml"
     else:
         graph_file_name = os.path.basename(outPrefix) + "_" + suffix + "_cytoscape.graphml"
-    nx.write_graphml(G, outPrefix + "/" + graph_file_name)
+    G.save(outPrefix + "/" + graph_file_name, fmt = 'graphml')

     # Write CSV of metadata
     if writeCsv:
-        refNames = G.nodes(data=False)
-        seqLabels = [r.split('/')[-1].split('.')[0] for r in refNames]
+        seqLabels = isolateNameToLabel(isolate_names)
         writeClusterCsv(outPrefix + "/" + outPrefix + "_cytoscape.csv",
-                        refNames,
+                        isolate_names,
                         seqLabels,
                         clustering,
                         'cytoscape',
@@ -468,15 +481,12 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
     d = defaultdict(list)
     if epiCsv is not None:
         epiData = pd.read_csv(epiCsv, index_col = 0, quotechar='"')
-        epiData.index = [i.split('/')[-1].split('.')[0] for i in epiData.index]
+        epiData.index = isolateNameToLabel(epiData.index)
         for e in epiData.columns.values:
             colnames.append(str(e))

     columns_to_be_omitted = []

-    # process clustering data
-    nodeLabels = [r.split('/')[-1].split('.')[0] for r in nodeNames]
-
     # get example clustering name for validation
     example_cluster_title = list(clustering.keys())[0]

@@ -517,7 +527,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
                 else:
                     d['Status'].append("Reference")
         elif output_format == 'cytoscape':
-            d['id'].append(name)
+            d['id'].append(label)
            for cluster_type in clustering:
                 col_name = cluster_type + suffix
                 d[col_name].append(clustering[cluster_type][name])
@@ -657,7 +667,7 @@ def outputsForMicroreact(combined_list, coreMat, accMat, clustering, perplexity,
     from .tsne import generate_tsne

     # generate sequence labels
-    seqLabels = [r.split('/')[-1].split('.')[0] for r in combined_list]
+    seqLabels = isolateNameToLabel(combined_list)

     # check CSV before calculating other outputs
     writeClusterCsv(outPrefix + "/" + os.path.basename(outPrefix) + "_microreact_clusters.csv",
@@ -751,7 +761,7 @@ def outputsForPhandango(combined_list, coreMat, clustering, outPrefix, epiCsv, r
        Avoid regenerating tree if already built for microreact (default = False)
     """
     # generate sequence labels
-    seqLabels = [r.split('/')[-1].split('.')[0] for r in combined_list]
+    seqLabels = isolateNameToLabel(combined_list)

     # print clustering file
     writeClusterCsv(outPrefix + "/" + os.path.basename(outPrefix) + "_phandango_clusters.csv",
@@ -802,7 +812,7 @@ def outputsForGrapetree(combined_list, coreMat, clustering, outPrefix, epiCsv, r
         Avoid regenerating tree if already built for microreact (default = False).
     """
     # generate sequence labels
-    seqLabels = [r.split('/')[-1].split('.')[0] for r in combined_list]
+    seqLabels = isolateNameToLabel(combined_list)

     # print clustering file
     writeClusterCsv(outPrefix + "/" + os.path.basename(outPrefix) + "_grapetree_clusters.csv",
diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py
index f2bc04c4..a2165e1f 100644
--- a/PopPUNK/sketchlib.py
+++ b/PopPUNK/sketchlib.py
@@ -18,7 +18,6 @@ from glob import glob
 from random import sample

 import numpy as np
-import networkx as nx
 from scipy import optimize

 # Try to import sketchlib
diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index faecdd56..997bd638 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -161,6 +161,40 @@ def iterDistRows(refSeqs, querySeqs, self=True):
             for ref in refSeqs:
                 yield(ref, query)

+def listDistInts(refSeqs, querySeqs, self=True):
+    """Gets the ref and query ID for each row of the distance matrix
+
+    Returns an iterable with ref and query ID pairs by row.
+
+    Args:
+        refSeqs (list)
+            List of reference sequence names.
+        querySeqs (list)
+            List of query sequence names.
+        self (bool)
+            Whether a self-comparison, used when constructing a database.
+            Requires refSeqs == querySeqs
+            Default is True
+    Returns:
+        ref, query (int, int)
+            Iterable of tuples with ref and query indices for each distMat row.
+    """
+    num_ref = len(refSeqs)
+    num_query = len(querySeqs)
+    if self:
+        if refSeqs != querySeqs:
+            raise RuntimeError('refSeqs must equal querySeqs for db building (self = true)')
+        for i in range(num_ref):
+            for j in range(i + 1, num_ref):
+                yield(j, i)
+    else:
+        comparisons = [(0,0)] * (len(refSeqs) * len(querySeqs))
+        for i in range(num_query):
+            for j in range(num_ref):
+                yield(j, i)
+
+    return comparisons
+
 def writeTmpFile(fileList):
     """Writes a list to a temporary file. Used for turning variable into mash
     input.
@@ -175,7 +209,7 @@ def writeTmpFile(fileList):
     tmpName = mkstemp(suffix=".tmp", dir=".")[1]
     with open(tmpName, 'w') as tmpFile:
         for fileName in fileList:
-            tmpFile.write(fileName + "\n")
+            tmpFile.write(fileName + '\t' + fileName + "\n")
     return tmpName

@@ -447,3 +481,19 @@ def readRfile(rFile, oneSeq=False):
         sys.exit(1)

     return (names, sequences)
+
+def isolateNameToLabel(names):
+    """Function to process isolate names to labels
+    appropriate for visualisation.
+
+    Args:
+        names (list)
+            List of isolate names.
+    Returns:
+        labels (list)
+            List of isolate labels.
+    """
+    # useful to have as a function in case we
+    # want to remove certain characters
+    labels = [name.split('/')[-1].split('.')[0] for name in names]
+    return labels
diff --git a/docs/conf.py b/docs/conf.py
index b986e1b9..f19a9d93 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -38,7 +38,7 @@
 # Causes a problem with rtd: https://github.com/pypa/setuptools/issues/1694
 autodoc_mock_imports = ["hdbscan",
                         "numpy",
-                        "networkx",
+                        "graph-tool",
                         "pandas",
                         "scipy",
                         "sklearn",
diff --git a/docs/installation.rst b/docs/installation.rst
index 74ff2091..645ccd92 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -59,7 +59,7 @@ We tested PopPUNK with the following packages:
 * ``DendroPy`` (4.3.0)
 * ``hdbscan`` (0.8.13)
 * ``matplotlib`` (2.1.2)
-* ``networkx`` (2.1)
+* ``graph-tool`` (2.31)
 * ``numpy`` (1.14.1)
 * ``pandas`` (0.22.0)
 * ``scikit-learn`` (0.19.1)
@@ -69,4 +69,3 @@ We tested PopPUNK with the following packages:

 Optionally, you can use `rapidnj <http://birc.au.dk/software/rapidnj/>`__ if producing output with ``--microreact`` and ``--rapidnj`` options. We used v2.3.2.
-
diff --git a/docs/troubleshooting.rst b/docs/troubleshooting.rst
index 6418ce05..b392f21b 100644
--- a/docs/troubleshooting.rst
+++ b/docs/troubleshooting.rst
@@ -10,27 +10,6 @@ installing or running the software please raise an issue on github.
 Error/warning messages
 ----------------------

-Errors in graph.py
-^^^^^^^^^^^^^^^^^^
-If you get an ``AttributeError``::
-
-    AttributeError: 'Graph' object has no attribute 'node'
-
-Then your ``networkx`` package is out of date. Its version needs to be at >=v2.0.
-
-Trying to create a very large network
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-When using ``--refine-model`` you may see the message::
-
-    Warning: trying to create very large network
-
-One or more times. This is triggered if :math:`5 \times 10^5` edges or greater than 40%
-of the maximum possible number of edges have been added into the network. This suggests that
-the boundary is too large including too many links as within sample. This isn't necessarily a
-problem as it can occur at the edge of the optimisation range, so will not be the final optimised
-result. However, if you have a large number of samples it may make this step run very slowly
-and/or use a lot of memory. If that is the case, decrease ``--pos-shift``.
-
 Row name mismatch
 ^^^^^^^^^^^^^^^^^
 PopPUNK may throw::
@@ -236,7 +215,7 @@ Finding which isolates contribute to these distances reveals a clear culprit::
     1       14412_4_10
     28      14412_4_15

-In this case it is sufficent to increase the number of mixture components to four,
+In this case it is sufficient to increase the number of mixture components to four,
 which no longer includes these inflated distances. This gives a score of 0.9401 and 28 components:

 .. image:: images/contam_DPGMM_better_fit.png
@@ -301,4 +280,3 @@ resources. Here are some tips based on these experiences:

 Another option for scaling is to run ``--create-db`` with a smaller initial set (not using
 the ``--full-db`` command), then use ``--assign-query`` to add to this.
-
diff --git a/environment.yml b/environment.yml
index 73eb72a5..8ec247c9 100644
--- a/environment.yml
+++ b/environment.yml
@@ -20,3 +20,4 @@ dependencies:
   - rapidnj
   - h5py
   - pp-sketchlib
+  - graph-tool
diff --git a/requirements.txt b/requirements.txt
index 4a73dcbe..fbdb5eec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ DendroPy>=4.3.0
 h5py>=2.10.0
 hdbscan>=0.8.13
 matplotlib>=2.1.2
-networkx>=2.1
+graph-tool>=2.31
 numpy>=1.14.1
 pandas>=0.22.0
 scikit-learn>=0.19.1
diff --git a/scripts/poppunk_extract_components.py b/scripts/poppunk_extract_components.py
index e8c3be38..842b5675 100755
--- a/scripts/poppunk_extract_components.py
+++ b/scripts/poppunk_extract_components.py
@@ -3,7 +3,8 @@
 # Copyright 2018 John Lees and Nick Croucher

 import sys
-import networkx as nx
+import graph_tool.all as gt
+from scipy.stats import rankdata
 import argparse

 # command line parsing
@@ -14,8 +15,8 @@ def get_options():
                                      prog='extract_components')

     # input options
-    parser.add_argument('graph', help='Input graph pickle (.gpickle)')
-    parser.add_argument('output', help='Prefix for output files')
+    parser.add_argument('--graph', help='Input graph (.gt)')
+    parser.add_argument('--output', help='Prefix for output files')

     return parser.parse_args()

@@ -25,13 +26,20 @@ def get_options():
     # Check input ok
     args = get_options()

-    # open stored distances
-    G = nx.read_gpickle(args.graph)
-    sys.stderr.write("Writing " + str(nx.number_connected_components(G)) + " components "
+    # open stored graph
+    G = gt.load_graph(args.graph)
+
+    # extract individual components
+    component_assignments, component_frequencies = gt.label_components(G)
+    component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)
+    sys.stderr.write("Writing " + str(len(component_frequencies)) + " components "
                      "in reverse order of size\n")
-    components = sorted(nx.connected_components(G), key=len, reverse=True)
-    for component_idx, component in enumerate(components):
-        nx.write_graphml(G.subgraph(component), args.output + ".component_" + str(component_idx + 1) + ".graphml")
-
+
+    # extract as GraphView objects and print
+    for component_index in range(len(component_frequency_ranks)):
+        component_gv = gt.GraphView(G, vfilt = component_assignments.a == component_index)
+        component_G = gt.Graph(component_gv, prune = True)
+        component_fn = args.output + ".component_" + str(component_frequency_ranks[component_index]) + ".graphml"
+        component_G.save(component_fn, fmt = 'graphml')
+
     sys.exit(0)
diff --git a/setup.py b/setup.py
index 1c268b1a..0a38235f 100644
--- a/setup.py
+++ b/setup.py
@@ -69,7 +69,7 @@ def find_version(*file_paths):
         'scikit-learn',
         'DendroPy',
         'pandas',
-        'networkx>=2.0',
+        'graph-tool',
         'matplotlib',
         'hdbscan'],
     test_suite="test",
diff --git a/test/run_test.py b/test/run_test.py
index b2507a86..13f889b5 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -14,35 +14,38 @@

 #easy run
 sys.stderr.write("Running database creation + DBSCAN model fit + fit refinement (--easy-run)\n")
-subprocess.run("python ../poppunk-runner.py --easy-run --r-files references.txt --min-k 13 --k-step 3 --output example_db --full-db", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --easy-run --r-files references.txt --min-k 13 --k-step 3 --output example_db --full-db --overwrite", shell=True, check=True)

 #fit GMM
 sys.stderr.write("Running GMM model fit (--fit-model)\n")
-subprocess.run("python ../poppunk-runner.py --fit-model --distances example_db/example_db.dists --ref-db example_db --output example_db --full-db --K 4 --microreact --cytoscape", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --fit-model --distances example_db/example_db.dists --ref-db example_db --output example_db --full-db --K 4 --microreact --cytoscape --overwrite", shell=True, check=True)

 #refine model with GMM
 sys.stderr.write("Running model refinement (--refine-model)\n")
-subprocess.run("python ../poppunk-runner.py --refine-model --distances example_db/example_db.dists --ref-db example_db --output example_refine --neg-shift 0.8", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --refine-model --distances example_db/example_db.dists --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite", shell=True, check=True)

 #assign query
 sys.stderr.write("Running query assignment (--assign-query)\n")
-subprocess.run("python ../poppunk-runner.py --assign-query --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --output example_query --update-db", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --assign-query --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --output example_query --update-db --overwrite", shell=True, check=True)

 #use model
 sys.stderr.write("Running with an existing model (--use-model)\n")
-subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db --model-dir example_db --distances example_db/example_db.dists --output example_use", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db --model-dir example_db --distances example_db/example_db.dists --output example_use --overwrite", shell=True, check=True)

 #generate viz
 sys.stderr.write("Running microreact visualisations (--generate-viz)\n")
 subprocess.run("python ../poppunk-runner.py --generate-viz --distances example_db/example_db.dists --ref-db example_db --output example_viz --microreact --subset subset.txt", shell=True, check=True)

+# general tests
+sys.stderr.write("Running general tests\n\n")
+
 # lineage clustering
 sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")
-subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5 --ref-db example_db --overwrite", shell=True, check=True)

 # assign query to lineages
 sys.stderr.write("Running query assignment (--assign-lineages)\n")
-subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --existing-scheme example_lineages/example_lineages_lineages.pkl --output example_lineage_query --update-db", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --existing-scheme example_lineages/example_lineages_lineages.pkl --output example_lineage_query --update-db --overwrite", shell=True, check=True)

 # tests of other command line programs (TODO)