diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index c9895f17..40909770 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -17,6 +17,8 @@
 from .sketchlib import no_sketchlib, checkSketchlibLibrary
 
+from .lineage_clustering import cluster_into_lineages
+
 from .network import fetchNetwork
 from .network import constructNetwork
 from .network import extractReferences
@@ -31,15 +33,17 @@
 from .prune_db import prune_distance_matrix
 
+from .sketchlib import calculateQueryQueryDistances
+
 from .utils import setupDBFuncs
 from .utils import storePickle
 from .utils import readPickle
 from .utils import writeTmpFile
 from .utils import qcDistMat
-from .utils import readClusters
 from .utils import translate_distMat
 from .utils import update_distance_matrices
 from .utils import readRfile
+from .utils import readIsolateTypeFromCsv
 
 # Minimum sketchlib version
 SKETCHLIB_MAJOR = 1
@@ -87,6 +91,14 @@ def get_options():
                         help='Create model at this core distance threshold',
                         default=None, type=float)
+    mode.add_argument('--lineage-clustering',
+                        help='Identify lineages within a strain',
+                        default=False,
+                        action='store_true')
+    mode.add_argument('--assign-lineages',
+                        help='Assign isolates to an existing lineage scheme',
+                        default=False,
+                        action='store_true')
     mode.add_argument('--use-model',
                         help='Apply a fitted model to a reference database to restore database files',
                         default=False,
@@ -101,6 +113,8 @@
     iGroup.add_argument('--distances', help='Prefix of input pickle of pre-calculated distances')
     iGroup.add_argument('--external-clustering', help='File with cluster definitions or other labels '
                                                       'generated with any other method.', default=None)
+    iGroup.add_argument('--viz-lineages', help='CSV with lineage definitions to use for visualisation '
+                                               'rather than strain definitions.', default=None)
 
     # output options
     oGroup = parser.add_argument_group('Output options')
@@ -167,6 +181,13 @@
     queryingGroup.add_argument('--accessory-only', help='Use an accessory-distance only model for assigning queries '
                                                         '[default = False]', default=False, action='store_true')
 
+    # lineage clustering within strains
+    lineagesGroup = parser.add_argument_group('Lineage analysis options')
+    lineagesGroup.add_argument('--ranks', help='Comma separated list of ranks used in lineage clustering [default = 1,2,3]', type = str, default = "1,2,3")
+    lineagesGroup.add_argument('--use-accessory', help='Use accessory distances for lineage definitions [default = use core distances]', action = 'store_true', default = False)
+    lineagesGroup.add_argument('--existing-scheme', help='Name of pickle file storing existing lineage definitions, '
+                                                         'required with "--assign-lineages"', type = str, default = None)
+
     # plot output
     faGroup = parser.add_argument_group('Further analysis options')
     faGroup.add_argument('--subset', help='File with list of sequences to include in visualisation (with --generate-viz only)', default=None)
@@ -257,6 +278,17 @@ def main():
         if not args.use_mash:
             sketch_sizes = int(round(max(sketch_sizes.values())/64))
 
+    # check if working with lineages
+    rank_list = []
+    if args.lineage_clustering or args.assign_lineages:
+        rank_list = sorted([int(x) for x in args.ranks.split(',')])
+        if min(rank_list) == 0 or max(rank_list) > 100:
+            sys.stderr.write('Ranks should be small non-zero integers for sensible results\n')
+            exit(1)
+        if args.assign_lineages and args.existing_scheme is None:
+            sys.stderr.write('Must provide an existing scheme (--existing-scheme) if assigning to lineages\n')
+            exit(1)
+
     # check on file paths and whether files will be overwritten
     # confusing to overwrite command line parameter
     #if not args.full_db and not (args.create_db or args.easy_run or args.assign_query):
@@ -276,7 +308,7 @@ def main():
             sys.exit(1)
         else:
            sys.stderr.write('\t sketchlib: ' + checkSketchlibLibrary() + ')\n')
-
+
     #******************************#
     #*                            *#
     #* Create database            *#
@@ -501,19 +533,50 @@ def main():
         nx.write_gpickle(genomeNetwork, args.output + "/" + os.path.basename(args.output) + '_graph.gpickle')
+    #******************************#
+    #*                            *#
+    #* within-strain analysis     *#
+    #*                            *#
+    #******************************#
+    if args.lineage_clustering:
+        sys.stderr.write("Mode: Identifying lineages within a clade\n\n")
+
+        # load distances
+        if args.distances is not None:
+            distances = args.distances
+        else:
+            sys.stderr.write("Need to provide an input set of distances with --distances\n\n")
+            sys.exit(1)
+
+        refList, queryList, self, distMat = readPickle(distances)
+
+        # make directory for new output files
+        if not os.path.isdir(args.output):
+            try:
+                os.makedirs(args.output)
+            except OSError:
+                sys.stderr.write("Cannot create output directory\n")
+                sys.exit(1)
+
+        # run lineage clustering
+        if self:
+            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
+        else:
+            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, qlist = queryList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
+
     #*******************************#
     #*                             *#
     #* query assignment (function  *#
     #* below)                      *#
     #*                             *#
     #*******************************#
-    elif args.assign_query:
+    elif args.assign_query or args.assign_lineages:
         assign_query(dbFuncs, args.ref_db, args.q_files, args.output, args.update_db, args.full_db, args.distances,
                      args.microreact, args.cytoscape, kmers, sketch_sizes, args.ignore_length, args.estimated_length,
                      args.threads, args.use_mash, args.mash, args.overwrite, args.plot_fit, args.no_stream,
                      args.max_a_dist, args.model_dir, args.previous_clustering, args.external_clustering,
                      args.core_only, args.accessory_only, args.phandango, args.grapetree, args.info_csv,
-                     args.rapidnj, args.perplexity)
+                     args.rapidnj, args.perplexity, args.assign_lineages, args.existing_scheme, rank_list, args.use_accessory)
 
     #******************************#
     #*                            *#
@@ -529,6 +592,8 @@ def main():
             sys.exit(1)
 
         if args.distances is not None and args.ref_db is not None:
+
+            # Initial processing
             # Load original distances
             with open(args.distances + ".pkl", 'rb') as pickle_file:
                 rlist, qlist, self = pickle.load(pickle_file)
@@ -542,28 +607,8 @@ def main():
                 except OSError:
                     sys.stderr.write("Cannot create output directory\n")
                     sys.exit(1)
-
-            # identify existing analysis files
-            model_prefix = args.ref_db
-            if args.model_dir is not None:
-                model_prefix = args.model_dir
-            model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
-                                   model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
-
-            # Set directories of previous fit
-            if args.previous_clustering is not None:
-                prev_clustering = args.previous_clustering
-            else:
-                prev_clustering = os.path.dirname(args.distances + ".pkl")
-
-            # Read in network and cluster assignment
-            genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, args.core_only, args.accessory_only)
-
-            # Use external clustering if specified
-            if args.external_clustering:
-                cluster_file = args.external_clustering
-                isolateClustering = {'combined': readClusters(cluster_file, return_dict=True)}
-
+
+            # Define set/subset to be visualised
             # extract subset of distances if requested
             viz_subset = rlist
             if args.subset is not None:
@@ -574,8 +619,12 @@
                 # Use the same code as no full_db in assign_query to take a subset
                 dists_out = args.output + "/" + os.path.basename(args.output) + ".dists"
-                nodes_to_remove = set(genomeNetwork.nodes).difference(viz_subset)
-                postpruning_combined_seq, newDistMat = prune_distance_matrix(rlist, nodes_to_remove,
+                #nodes_to_remove = set(genomeNetwork.nodes).difference(viz_subset)
+                isolates_to_remove = set(combined_seq).difference(viz_subset)
+                postpruning_combined_seq = viz_subset
+                newDistMat = complete_distMat
+                if len(isolates_to_remove) > 0:
+                    postpruning_combined_seq, newDistMat = prune_distance_matrix(rlist, isolates_to_remove,
                                                                                  complete_distMat, dists_out)
                 rlist = viz_subset
@@ -587,13 +636,37 @@
                 except:
                     sys.stderr.write("Isolates in subset not found in existing database\n")
                 assert postpruning_combined_seq == viz_subset
-
-            # prune the network and dictionary of assignments
-            genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(viz_subset))
-            for clustering_type in isolateClustering:
-                isolateClustering[clustering_type] = {viz_key: isolateClustering[clustering_type][viz_key]
-                                                      for viz_key in viz_subset}
-
+
+            # Either use strain definitions, lineage assignments or external clustering
+            isolateClustering = {}
+            # Use external clustering if specified
+            if args.external_clustering:
+                cluster_file = args.external_clustering
+                isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'external', return_dict = True)
+            elif args.viz_lineages:
+                cluster_file = args.viz_lineages
+                isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'lineages', return_dict = True)
+            else:
+                # identify existing analysis files
+                model_prefix = args.ref_db
+                if args.model_dir is not None:
+                    model_prefix = args.model_dir
+                model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
+                                       model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
+
+                # Set directories of previous fit
+                if args.previous_clustering is not None:
+                    prev_clustering = args.previous_clustering
+                else:
+                    prev_clustering = os.path.dirname(args.distances + ".pkl")
+
+                # Read in network and cluster assignment
+                genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, args.core_only, args.accessory_only)
+                isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
+
+                # prune the network and dictionary of assignments
+                genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(viz_subset))
+
             # generate selected visualisations
             if args.microreact:
                 sys.stderr.write("Writing microreact output\n")
@@ -608,8 +681,11 @@
                 outputsForGrapetree(viz_subset, core_distMat, isolateClustering, args.output, args.info_csv, args.rapidnj,
                                     overwrite = args.overwrite, microreact = args.microreact)
             if args.cytoscape:
-                sys.stderr.write("Writing cytoscape output\n")
-                outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
+                if args.viz_lineages or args.external_clustering:
+                    sys.stderr.write("Can only generate a network output for fitted models\n")
+                else:
+                    sys.stderr.write("Writing cytoscape output\n")
+                    outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
     else:
         # Cannot read input files
@@ -628,7 +704,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                  kmers, sketch_sizes, ignore_length, estimated_length, threads, use_mash, mash, overwrite,
                  plot_fit, no_stream, max_a_dist, model_dir, previous_clustering,
                  external_clustering, core_only, accessory_only, phandango, grapetree,
-                 info_csv, rapidnj, perplexity):
+                 info_csv, rapidnj, perplexity, assign_lineage, existing_scheme, rank_list, use_accessory):
     """Code for assign query mode. Written as a separate function so it can be called
     by pathogen.watch API
     """
@@ -687,57 +763,71 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                                                 threads = threads)
         qcPass = qcDistMat(distMat, refList, queryList, max_a_dist)
 
-        # Assign these distances as within or between
-        model_prefix = ref_db
-        if model_dir is not None:
-            model_prefix = model_dir
-        model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
-                               model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
-        queryAssignments = model.assign(distMat)
-
-        # set model prefix
-        model_prefix = ref_db
-        if model_dir is not None:
-            model_prefix = model_dir
-
-        # Set directories of previous fit
-        if previous_clustering is not None:
-            prev_clustering = previous_clustering
+        # Calculate query-query distances
+        ordered_queryList = []
+
+        # Assign to strains or lineages, as requested
+        if assign_lineage:
+
+            # Assign lineages by calculating query-query information
+            ordered_queryList, query_distMat = calculateQueryQueryDistances(dbFuncs, refList, q_files,
+                                                                            kmers, estimated_length, output, use_mash, threads)
+
         else:
-            prev_clustering = model_prefix
-
-        # Load the network based on supplied options
-        genomeNetwork, old_cluster_file = fetchNetwork(prev_clustering, model, refList,
-                                                       core_only, accessory_only)
-
-        # Assign clustering by adding to network
-        ordered_queryList, query_distMat = addQueryToNetwork(dbFuncs, refList, q_files,
-                genomeNetwork, kmers, estimated_length, queryAssignments, model, output, update_db,
-                use_mash, threads)
-
-        # if running simple query
-        print_full_clustering = False
-        if update_db:
-            print_full_clustering = True
-        isolateClustering = {'combined': printClusters(genomeNetwork, output + "/" + os.path.basename(output),
-                                                       old_cluster_file, external_clustering, print_full_clustering)}
-
-        # update_db like no full_db
-        if update_db:
+            # Assign these distances as within or between strain
+            model_prefix = ref_db
+            if model_dir is not None:
+                model_prefix = model_dir
+            model = loadClusterFit(model_prefix + "/" + os.path.basename(model_prefix) + '_fit.pkl',
+                                   model_prefix + "/" + os.path.basename(model_prefix) + '_fit.npz')
+            queryAssignments = model.assign(distMat)
+
+            # set model prefix
+            model_prefix = ref_db
+            if model_dir is not None:
+                model_prefix = model_dir
+
+            # Set directories of previous fit
+            if previous_clustering is not None:
+                prev_clustering = previous_clustering
+            else:
+                prev_clustering = model_prefix
+
+            # Load the network based on supplied options
+            genomeNetwork, old_cluster_file = fetchNetwork(prev_clustering, model, refList,
+                                                           core_only, accessory_only)
+
+            # Assign clustering by adding to network
+            ordered_queryList, query_distMat = addQueryToNetwork(dbFuncs, refList, q_files,
+                    genomeNetwork, kmers, estimated_length, queryAssignments, model, output, update_db,
+                    use_mash, threads)
+
+            # if running simple query
+            print_full_clustering = False
+            if update_db:
+                print_full_clustering = True
+            isolateClustering = {'combined': printClusters(genomeNetwork, output + "/" + os.path.basename(output),
+                                                           old_cluster_file, external_clustering, print_full_clustering)}
+
+        # Update DB as requested
+        if update_db or assign_lineage:
+
+            # Check new sequences pass QC before adding them
             if not qcPass:
                 sys.stderr.write("Queries contained outlier distances, not updating database\n")
             else:
                 sys.stderr.write("Updating reference database to " + output + "\n")
 
             # Update the network + ref list
-            if full_db is False:
+            # only update network if assigning to strains
+            if full_db is False and assign_lineage is False:
                 mashOrder = refList + ordered_queryList
                 newRepresentativesNames, newRepresentativesFile = extractReferences(genomeNetwork, mashOrder, output, refList)
                 genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(newRepresentativesNames))
                 newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
+                nx.write_gpickle(genomeNetwork, output + "/" + os.path.basename(output) + '_graph.gpickle')
             else:
                 newQueries = ordered_queryList
-            nx.write_gpickle(genomeNetwork, output + "/" + os.path.basename(output) + '_graph.gpickle')
 
             # Update the sketch database
             if newQueries != queryList and use_mash:
@@ -747,7 +837,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                 os.remove(tmpRefFile)
             # With mash, this is the reduced DB constructed,
             # with sketchlib, all sketches
-            joinDBs(ref_db, output, output)
+            joinDBs(ref_db, output, output)
 
             # Update distance matrices with all calculated distances
             if distances == None:
@@ -758,10 +848,16 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             combined_seq, core_distMat, acc_distMat = update_distance_matrices(refList, ref_distMat,
                                                                                ordered_queryList, distMat, query_distMat)
             complete_distMat = translate_distMat(combined_seq, core_distMat, acc_distMat)
+
+            if assign_lineage:
+                expected_lineage_name = ref_db + '/' + ref_db + '_lineages.pkl'
+                if existing_scheme is not None:
+                    expected_lineage_name = existing_scheme
+                isolateClustering = cluster_into_lineages(complete_distMat, rank_list, output, combined_seq, ordered_queryList, expected_lineage_name, use_accessory, threads)
 
             # Prune distances to references only, if not full db
             dists_out = output + "/" + os.path.basename(output) + ".dists"
-            if full_db is False:
+            if full_db is False and assign_lineage is False:
                 # could also have newRepresentativesNames in this diff (should be the same) - but want
                 # to ensure consistency with the network in case of bad input/bugs
                 nodes_to_remove = set(combined_seq).difference(genomeNetwork.nodes)
@@ -792,8 +888,11 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             outputsForGrapetree(combined_seq, core_distMat, isolateClustering, output, info_csv, rapidnj,
                                 queryList = ordered_queryList, overwrite = overwrite, microreact = microreact)
         if cytoscape:
-            sys.stderr.write("Writing cytoscape output\n")
-            outputsForCytoscape(genomeNetwork, isolateClustering, output, info_csv, ordered_queryList)
+            if assign_lineage:
+                sys.stderr.write("Cannot generate a cytoscape network from a lineage assignment\n")
+            else:
+                sys.stderr.write("Writing cytoscape output\n")
+                outputsForCytoscape(genomeNetwork, isolateClustering, output, info_csv, ordered_queryList)
 
     else:
         sys.stderr.write("Need to provide both a reference database with --ref-db and "
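For illustration only (not part of the patch): the new --lineage-clustering branch above amounts to loading a distances pickle and handing it to cluster_into_lineages(). A minimal sketch of driving that entry point directly, with hypothetical paths and rank choices:

    import os
    from PopPUNK.utils import readPickle
    from PopPUNK.lineage_clustering import cluster_into_lineages

    # hypothetical database prefix and output directory
    refList, queryList, self_dists, distMat = readPickle("example_db/example_db.dists")
    os.makedirs("example_lineages", exist_ok=True)  # output directory must exist, as main() ensures
    if self_dists:
        lineages = cluster_into_lineages(distMat, rank_list = [1, 2, 3], output = "example_lineages",
                                         isolate_list = refList, use_accessory = False,
                                         existing_scheme = None, num_processes = 1)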
diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
new file mode 100644
index 00000000..d553801a
--- /dev/null
+++ b/PopPUNK/lineage_clustering.py
@@ -0,0 +1,372 @@
+#!/usr/bin/env python
+# vim: set fileencoding=<encoding name> :
+# Copyright 2018-2020 John Lees and Nick Croucher
+
+# universal
+import os
+import sys
+import re
+# additional
+import numpy as np
+from scipy.stats import rankdata
+from collections import defaultdict
+import pickle
+import collections
+import networkx as nx
+try:
+    from multiprocessing import Pool, shared_memory
+    from multiprocessing.managers import SharedMemoryManager
+    NumpyShared = collections.namedtuple('NumpyShared', ('name', 'shape', 'dtype'))
+except ImportError as e:
+    sys.stderr.write("This version of PopPUNK requires python v3.8 or higher\n")
+    sys.exit(1)
+from functools import partial
+
+# import poppunk package
+from .plot import writeClusterCsv
+
+from .utils import iterDistRows
+from .utils import update_distance_matrices
+
+def get_chunk_ranges(N, nb):
+    """ Calculates boundaries for dividing distances array
+    into chunks for parallelisation.
+
+    Args:
+        N (int)
+            Number of rows in array
+        nb (int)
+            Number of blocks into which to divide array.
+
+    Returns:
+        range_sizes (list of tuples)
+            Limits of blocks for dividing array.
+    """
+    step = N / nb
+    range_sizes = [(round(step*i), round(step*(i+1))) for i in range(nb)]
+    # extend to end of distMat
+    range_sizes[len(range_sizes) - 1] = (range_sizes[len(range_sizes) - 1][0], N)
+    # return ranges
+    return range_sizes
+
+def rank_distance_matrix(bounds, distances = None):
+    """ Ranks distances between isolates for each index (row)
+    isolate.
+
+    Args:
+        bounds (2-tuple)
+            Range of rows to process in this thread.
+        distances (ndarray in shared memory)
+            Shared memory object storing pairwise distances.
+
+    Returns:
+        ranks (numpy ndarray)
+            Ranks of distances for each row.
+    """
+    # load distance matrix from shared memory
+    distances_shm = shared_memory.SharedMemory(name = distances.name)
+    distances = np.ndarray(distances.shape, dtype = distances.dtype, buffer = distances_shm.buf)
+    # rank relevant slice of distance matrix
+    ranks = np.apply_along_axis(rankdata, 1, distances[slice(*bounds), :], method = 'ordinal')
+    return ranks
+
+def get_nearest_neighbours(rank, isolates = None, ranks = None):
+    """ Identifies sets of nearest neighbours for each isolate.
+
+    Args:
+        rank (int)
+            Rank used in analysis.
+        isolates (int list)
+            List of isolate indices.
+        ranks (ndarray in shared memory)
+            Shared memory object pointing to ndarray of
+            ranked pairwise distances.
+
+    Returns:
+        nn (dict of frozensets)
+            Dict indexed by isolates, values are a
+            frozen set of nearest neighbours.
+    """
+    # data structure
+    nn = {}
+    # load shared ranks
+    ranks_shm = shared_memory.SharedMemory(name = ranks.name)
+    ranks = np.ndarray(ranks.shape, dtype = ranks.dtype, buffer = ranks_shm.buf)
+    # iterate over isolates
+    for i in isolates:
+        nn[i] = defaultdict(frozenset)
+        isolate_ranks = ranks[i, :]
+        closest_ranked = np.ravel(np.where(isolate_ranks <= rank))
+        neighbours = frozenset(closest_ranked.tolist())
+        nn[i] = neighbours
+    # return dict
+    return nn
+
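To make the rank logic above concrete (illustrative values, not part of the patch): each row of the square distance matrix is ranked with scipy's rankdata, and every column whose rank is at or below the chosen threshold becomes a neighbour, which is what rank_distance_matrix() and get_nearest_neighbours() do once the shared-memory plumbing is stripped away.

    import numpy as np
    from scipy.stats import rankdata

    # toy core distances; the diagonal is set to infinity, as cluster_into_lineages() does
    distances = np.array([[np.inf, 0.1,    0.4,    0.3],
                          [0.1,    np.inf, 0.2,    0.5],
                          [0.4,    0.2,    np.inf, 0.6],
                          [0.3,    0.5,    0.6,    np.inf]])
    ranks = np.apply_along_axis(rankdata, 1, distances, method = 'ordinal')
    rank = 2
    nn = {i: frozenset(np.where(ranks[i, :] <= rank)[0].tolist()) for i in range(distances.shape[0])}
    # nn[0] == frozenset({1, 3}): isolate 0 keeps its two closest neighbours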
+ """ + # load distances from shared memory + distances_shm = shared_memory.SharedMemory(name = distances.name) + distances = np.ndarray(distances.shape, dtype = distances.dtype, buffer = distances_shm.buf) + # identify unclustered isolates + unclustered_isolates = list(nx.isolates(G)) + # select minimum distance between unclustered isolates + minimum_distance_between_unclustered_isolates = np.amin(distances[unclustered_isolates,unclustered_isolates],axis = 0) + # select occurrences of this distance + minimum_distance_coordinates = np.where(distances == minimum_distance_between_unclustered_isolates) + # identify case where both isolates are unclustered + for i in range(len(minimum_distance_coordinates[0])): + if minimum_distance_coordinates[0][i] in unclustered_isolates and minimum_distance_coordinates[1][i] in unclustered_isolates: + seed_isolate = minimum_distance_coordinates[0][i] + break + # return unclustered isolate with minimum distance to another isolate + return seed_isolate + +def get_lineage(G, neighbours, seed_isolate, lineage_index): + """ Identifies isolates corresponding to a particular + lineage given a cluster seed. + + Args: + G (network) + Network with one node per isolate. + neighbours (dict of frozen sets) + Pre-calculated neighbour relationships. + seed_isolate (int) + Index of isolate selected as seed. + lineage_index (int) + Label of current lineage. + + Returns: + G (network) + Network modified with new edges. + """ + # initiate lineage as the seed isolate and immediate unclustered neighbours + in_lineage = {seed_isolate} + G.nodes[seed_isolate]['lineage'] = lineage_index + for seed_neighbour in neighbours[seed_isolate]: + if nx.is_isolate(G, seed_neighbour): + G.add_edge(seed_isolate, seed_neighbour) + G.nodes[seed_neighbour]['lineage'] = lineage_index + in_lineage.add(seed_neighbour) + # iterate through other isolates until converged on a stable clustering + alterations = len(neighbours.keys()) + while alterations > 0: + alterations = 0 + for isolate in neighbours.keys(): + if nx.is_isolate(G, isolate): + intersection_size = in_lineage.intersection(neighbours[isolate]) + if intersection_size is not None and len(intersection_size) > 0: + for i in intersection_size: + G.add_edge(isolate, i) + G.nodes[isolate]['lineage'] = lineage_index + in_lineage.add(isolate) + alterations = alterations + 1 + # return final clustering + return G + +def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list = None, qlist = None, existing_scheme = None, use_accessory = False, num_processes = 1): + """ Clusters isolates into lineages based on their + relative distances. + + Args: + distMat (np.array) + n x 2 array of core and accessory distances for n samples. + This should not be subsampled. + rank_list (list of int) + Integers specifying the maximum rank of neighbours used + for clustering. + output (string) + Prefix used for printing output files. + isolate_list (list) + List of reference sequences. + qlist (list) + List of query sequences being added to an existing clustering. + Should be included within isolate_list. + existing_scheme (str) + Path to pickle file containing lineage scheme to which isolates + should be added. + use_accessory (bool) + Option to use accessory distances rather than core distances. + num_processes (int) + Number of CPUs to use for calculations. + + Returns: + overall_lineages (nested dict) + Dict for each rank listing the lineage of each isolate. 
+ """ + + # data structures + lineage_clustering = defaultdict(dict) + overall_lineage_seeds = defaultdict(dict) + overall_lineages = defaultdict(dict) + if existing_scheme is not None: + with open(existing_scheme, 'rb') as pickle_file: + lineage_clustering, overall_lineage_seeds, rank_list = pickle.load(pickle_file) + + # generate square distance matrix + seqLabels, coreMat, accMat = update_distance_matrices(isolate_list, distMat) + if use_accessory: + distances = accMat + else: + distances = coreMat + try: + assert seqLabels == isolate_list + except: + sys.stderr.write('Isolates in wrong order?') + exit(1) + + # list indices and set self-self to Inf + isolate_indices = [n for n,i in enumerate(isolate_list)] + for i in isolate_indices: + distances[i,i] = np.Inf + + # get ranks of distances per row + chunk_boundaries = get_chunk_ranges(distances.shape[0], num_processes) + with SharedMemoryManager() as smm: + + # share isolate list + isolate_list_shared = smm.ShareableList(isolate_indices) + + # create shared memory object for distances + distances_raw = smm.SharedMemory(size = distances.nbytes) + distances_shared_array = np.ndarray(distances.shape, dtype = distances.dtype, buffer = distances_raw.buf) + distances_shared_array[:] = distances[:] + distances_shared_array = NumpyShared(name = distances_raw.name, shape = distances.shape, dtype = distances.dtype) + + # parallelise ranking of distances across CPUs + with Pool(processes = num_processes) as pool: + ranked_array = pool.map(partial(rank_distance_matrix, + distances = distances_shared_array), + chunk_boundaries) + + # concatenate ranks into shared memory + distance_ranks = np.concatenate(ranked_array) + distance_ranks_raw = smm.SharedMemory(size = distance_ranks.nbytes) + distance_ranks_shared_array = np.ndarray(distance_ranks.shape, dtype = distance_ranks.dtype, buffer = distance_ranks_raw.buf) + distance_ranks_shared_array[:] = distance_ranks[:] + distance_ranks_shared_array = NumpyShared(name = distance_ranks_raw.name, shape = distance_ranks.shape, dtype = distance_ranks.dtype) + + # parallelise neighbour identification for each rank + with Pool(processes = num_processes) as pool: + results = pool.map(partial(run_clustering_for_rank, + distances_input = distances_shared_array, + distance_ranks_input = distance_ranks_shared_array, + isolates = isolate_list_shared, + previous_seeds = overall_lineage_seeds), + rank_list) + + # extract results from multiprocessing pool + for n,result in enumerate(results): + rank = rank_list[n] + lineage_clustering[rank], overall_lineage_seeds[rank] = result + + # store output + with open(output + "/" + output + '_lineages.pkl', 'wb') as pickle_file: + pickle.dump([lineage_clustering, overall_lineage_seeds, rank_list], pickle_file) + + # process multirank lineages + overall_lineages = {} + overall_lineages = {'Rank_' + str(rank):{} for rank in rank_list} + overall_lineages['overall'] = {} + for index,isolate in enumerate(isolate_list): + overall_lineage = None + for rank in rank_list: + overall_lineages['Rank_' + str(rank)][isolate] = lineage_clustering[rank][index] + if overall_lineage is None: + overall_lineage = str(lineage_clustering[rank][index]) + else: + overall_lineage = overall_lineage + '-' + str(lineage_clustering[rank][index]) + overall_lineages['overall'][isolate] = overall_lineage + + # print output as CSV + writeClusterCsv(output + "/" + output + '_lineages.csv', + isolate_list, + isolate_list, + overall_lineages, + output_format = 'phandango', + epiCsv = None, + queryNames = qlist, + 
+def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input = None, isolates = None, previous_seeds = None):
+    """ Clusters isolates into lineages based on their
+    relative distances using a single rank to enable
+    parallelisation.
+
+    Args:
+        rank (int)
+            Integer specifying the maximum rank of neighbour used
+            for clustering. Should be changed to int list for hierarchical
+            clustering.
+        distances_input (ndarray in shared memory)
+            Pairwise distances between isolates.
+        distance_ranks_input (ndarray in shared memory)
+            Ranks of pairwise distances between isolates.
+        isolates (list)
+            List of isolate indices.
+        previous_seeds (dict)
+            Seeds of existing lineages, if extending a previous analysis.
+
+    Returns:
+        lineage_clustering (dict)
+            Assignment of each isolate to a cluster.
+        lineage_seed (dict)
+            Seed isolate used to initiate each cluster.
+    """
+
+    # load shared memory objects
+    distances_shm = shared_memory.SharedMemory(name = distances_input.name)
+    distances = np.ndarray(distances_input.shape, dtype = distances_input.dtype, buffer = distances_shm.buf)
+    distance_ranks_shm = shared_memory.SharedMemory(name = distance_ranks_input.name)
+    distance_ranks = np.ndarray(distance_ranks_input.shape, dtype = distance_ranks_input.dtype, buffer = distance_ranks_shm.buf)
+    isolate_list = isolates
+    isolate_indices = range(0, len(isolate_list))
+
+    # load previous scheme
+    seeds = {}
+    if previous_seeds is not None:
+        seeds = previous_seeds[rank]
+
+    # create graph structure
+    G = nx.Graph()
+    G.add_nodes_from(isolate_indices)
+    G.nodes.data('lineage', default = 0)
+
+    # identify nearest neighbours
+    nn = get_nearest_neighbours(rank,
+                                ranks = distance_ranks_input,
+                                isolates = isolate_list)
+
+    # iteratively identify lineages
+    lineage_index = 1
+    while nx.number_of_isolates(G) > 0:
+        if lineage_index in seeds.keys():
+            seed_isolate = seeds[lineage_index]
+        else:
+            seed_isolate = pick_seed_isolate(G, distances = distances_input)
+        # skip over previously-defined seeds if amalgamated into different lineage now
+        if nx.is_isolate(G, seed_isolate):
+            seeds[lineage_index] = seed_isolate
+            G = get_lineage(G, nn, seed_isolate, lineage_index)
+        lineage_index = lineage_index + 1
+
+    # identify components and name lineages
+    lineage_clustering = {node:nodedata for (node, nodedata) in G.nodes(data='lineage')}
+
+    # return clustering
+    return lineage_clustering, seeds
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 40b7c494..4930bef4 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -18,9 +18,10 @@
 from tempfile import mkstemp, mkdtemp
 from collections import defaultdict, Counter
 
+from .sketchlib import calculateQueryQueryDistances
+
 from .utils import iterDistRows
-from .utils import readClusters
-from .utils import readExternalClusters
+from .utils import readIsolateTypeFromCsv
 from .utils import readRfile
 
 def fetchNetwork(network_dir, model, refList,
@@ -320,13 +321,11 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
         distMat (numpy.array)
             Query-query distances
     """
-    constructDatabase = dbFuncs['constructDatabase']
-    queryDatabase = dbFuncs['queryDatabase']
-    readDBParams = dbFuncs['readDBParams']
 
     # initialise links data structure
     new_edges = []
     assigned = set()
+
     # These are returned
     qlist1 = None
     distMat = None
@@ -338,7 +337,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
         rNames = None
         qNames = qSeqs
     else:
-        rNames = qList
+        rNames = qList
         qNames = rNames
 
    # store links for each query in a list of edge tuples
@@ -350,17 +349,17 @@
     # Calculate all query-query distances too, if updating database
     if queryQuery:
         sys.stderr.write("Calculating all query-query distances\n")
+        qlist1, distMat = calculateQueryQueryDistances(dbFuncs,
+                                                       rNames,
+                                                       qfile,
+                                                       kmers,
+                                                       estimated_length,
+                                                       queryDB,
+                                                       use_mash,
+                                                       threads)
-        qlist1, qlist2, distMat = queryDatabase(rNames = rNames,
-                                                qNames = qNames,
-                                                dbPrefix = queryDB,
-                                                queryPrefix = queryDB,
-                                                klist = kmers,
-                                                self = True,
-                                                number_plot_fits = 0,
-                                                threads=threads)
         queryAssignation = model.assign(distMat)
-        for assignment, (ref, query) in zip(queryAssignation, iterDistRows(qlist1, qlist2, self=True)):
+        for assignment, (ref, query) in zip(queryAssignation, iterDistRows(qlist1, qlist1, self=True)):
             if assignment == model.within_label:
                 new_edges.append((ref, query))
@@ -399,7 +398,7 @@
                                                 number_plot_fits = 0,
                                                 threads = threads)
         queryAssignation = model.assign(distMat)
-
+
         # identify any links between queries and store in the same links dict
         # links dict now contains lists of links both to original database and new queries
         for assignment, (query1, query2) in zip(queryAssignation, iterDistRows(qlist1, qlist2, self=True)):
@@ -416,7 +415,7 @@
     return qlist1, distMat
 
 def printClusters(G, outPrefix = "_clusters.csv", oldClusterFile = None,
-                  externalClusterCSV = None, printRef = True, printCSV = True):
+                  externalClusterCSV = None, printRef = True, printCSV = True, clustering_type = 'combined'):
     """Get cluster assignments
 
     Also writes assignments to a CSV file
@@ -427,26 +426,24 @@ def printClusters(G, outPrefix = "_clusters.csv", oldClusterFile = None,
            :func:`~addQueryToNetwork`)
        outPrefix (str)
            Prefix for output CSV
-            Default = "_clusters.csv"
        oldClusterFile (str)
            CSV with previous cluster assignments.
            Pass to ensure consistency in cluster assignment name.
-            Default = None
        externalClusterCSV (str)
            CSV with cluster assignments from any source.
            Will print a file relating these to new cluster assignments
-            Default = None
        printRef (bool)
            If false, print only query sequences in the output
-            Default = True
        printCSV (bool)
            Print results to file
-            Default = True
+        clustering_type (str)
+            Type of clustering network, used for comparison with old clusters
+            Default = 'combined'
 
     Returns:
        clustering (dict)
@@ -460,8 +457,10 @@
     oldNames = set()
 
     if oldClusterFile != None:
-        oldClusters = readClusters(oldClusterFile)
-        new_id = len(oldClusters) + 1 # 1-indexed
+        oldAllClusters = readIsolateTypeFromCsv(oldClusterFile, mode = 'external', return_dict = False)
+        oldClusters = oldAllClusters[list(oldAllClusters.keys())[0]]
+        new_id = len(oldClusters.keys()) + 1 # 1-indexed
         while new_id in oldClusters:
             new_id += 1 # in case clusters have been merged
@@ -571,7 +570,8 @@ def printExternalClusters(newClusters, extClusterFile, outPrefix,
     d = defaultdict(list)
 
     # Read in external clusters
-    extClusters = readExternalClusters(extClusterFile)
+    extClusters = readIsolateTypeFromCsv(extClusterFile, mode = 'external', return_dict = False)
 
     # Go through each cluster (as defined by poppunk) and find the external
     # clusters that had previously been assigned to any sample in the cluster
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index 7ac72d5a..13441f0c 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -399,7 +399,7 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
                         epiCsv, queryList)
 
-def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format = 'microreact', epiCsv = None, queryNames = None):
+def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format = 'microreact', epiCsv = None, queryNames = None, suffix = '_Cluster'):
     """Print CSV file of clustering and optionally epi data
 
     Writes CSV output of clusters which can be used as input to microreact and cytoscape.
@@ -433,7 +433,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
     if output_format == 'microreact':
         colnames = ['id']
         for cluster_type in clustering:
-            col_name = cluster_type + '_Cluster__autocolour'
+            col_name = cluster_type + suffix + '__autocolour'
             colnames.append(col_name)
         if queryNames is not None:
             colnames.append('Status')
@@ -441,7 +441,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
     elif output_format == 'phandango':
         colnames = ['id']
         for cluster_type in clustering:
-            col_name = cluster_type + '_Cluster'
+            col_name = cluster_type + suffix
             colnames.append(col_name)
         if queryNames is not None:
             colnames.append('Status')
@@ -449,14 +449,14 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
     elif output_format == 'grapetree':
         colnames = ['ID']
         for cluster_type in clustering:
-            col_name = cluster_type + '_Cluster'
+            col_name = cluster_type + suffix
             colnames.append(col_name)
         if queryNames is not None:
             colnames.append('Status')
     elif output_format == 'cytoscape':
         colnames = ['id']
         for cluster_type in clustering:
-            col_name = cluster_type + '_Cluster'
+            col_name = cluster_type + suffix
             colnames.append(col_name)
         if queryNames is not None:
             colnames.append('Status')
@@ -477,12 +477,15 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
     # process clustering data
     nodeLabels = [r.split('/')[-1].split('.')[0] for r in nodeNames]
 
+    # get example clustering name for validation
+    example_cluster_title = list(clustering.keys())[0]
+
     for name, label in zip(nodeNames, nodeLabels):
-        if name in clustering['combined']:
+        if name in clustering[example_cluster_title]:
             if output_format == 'microreact':
                 d['id'].append(label)
                 for cluster_type in clustering:
-                    col_name = cluster_type + "_Cluster__autocolour"
+                    col_name = cluster_type + suffix + "__autocolour"
                     d[col_name].append(clustering[cluster_type][name])
                 if queryNames is not None:
                     if name in queryNames:
@@ -494,7 +497,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
             elif output_format == 'phandango':
                 d['id'].append(label)
                 for cluster_type in clustering:
-                    col_name = cluster_type + "_Cluster"
+                    col_name = cluster_type + suffix
                    d[col_name].append(clustering[cluster_type][name])
                 if queryNames is not None:
                     if name in queryNames:
@@ -506,7 +509,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
             elif output_format == 'grapetree':
                 d['ID'].append(label)
                 for cluster_type in clustering:
-                    col_name = cluster_type + "_Cluster"
+                    col_name = cluster_type + suffix
                     d[col_name].append(clustering[cluster_type][name])
                 if queryNames is not None:
                     if name in queryNames:
@@ -516,7 +519,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
             elif output_format == 'cytoscape':
                 d['id'].append(name)
                 for cluster_type in clustering:
-                    col_name = cluster_type + "_Cluster"
+                    col_name = cluster_type + suffix
                     d[col_name].append(clustering[cluster_type][name])
                 if queryNames is not None:
                     if name in queryNames:
@@ -527,7 +530,8 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
            # avoid adding
            if len(columns_to_be_omitted) == 0:
                columns_to_be_omitted = ['id', 'Id', 'ID', 'combined_Cluster__autocolour',
-                                        'core_Cluster__autocolour', 'accessory_Cluster__autocolour']
+                                        'core_Cluster__autocolour', 'accessory_Cluster__autocolour',
+                                        'overall_Lineage']
                for c in d:
                    if c not in columns_to_be_omitted:
                        columns_to_be_omitted.append(c)
@@ -543,7 +547,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
         else:
             sys.stderr.write("Cannot find " + name + " in clustering\n")
             sys.exit(1)
-
+
     # print CSV
     sys.stderr.write("Parsed data, now writing to CSV\n")
     try:
diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py
index 88c623d1..f6a1e120 100644
--- a/PopPUNK/sketchlib.py
+++ b/PopPUNK/sketchlib.py
@@ -445,3 +445,59 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
                      False, threads, use_gpu, deviceid)
 
     return(rNames, qNames, distMat)
+
+def calculateQueryQueryDistances(dbFuncs, rlist, qfile, kmers, estimated_length,
+                                 queryDB, use_mash = False, threads = 1):
+    """Calculates distances between queries.
+
+    Args:
+        dbFuncs (list)
+            List of backend functions from :func:`~PopPUNK.utils.setupDBFuncs`
+        rlist (list)
+            List of reference names
+        qfile (str)
+            File containing queries
+        kmers (list)
+            List of k-mer sizes
+        estimated_length (int)
+            Estimated length of genome, if not calculated from data
+        queryDB (str)
+            Query database location
+        use_mash (bool)
+            Use the mash backend
+        threads (int)
+            Number of threads to use if new db created
+            (default = 1)
+
+    Returns:
+        qlist1 (list)
+            Ordered list of queries
+        distMat (numpy.array)
+            Query-query distances
+    """
+
+    constructDatabase = dbFuncs['constructDatabase']
+    queryDatabase = dbFuncs['queryDatabase']
+    readDBParams = dbFuncs['readDBParams']
+
+    # Set up query names
+    qList, qSeqs = readRfile(qfile, oneSeq = use_mash)
+    queryFiles = dict(zip(qList, qSeqs))
+    if use_mash == True:
+        rNames = None
+        qNames = qSeqs
+    else:
+        rNames = qList
+        qNames = rNames
+
+    # Calculate all query-query distances too, if updating database
+    qlist1, qlist2, distMat = queryDatabase(rNames = rNames,
+                                            qNames = qNames,
+                                            dbPrefix = queryDB,
+                                            queryPrefix = queryDB,
+                                            klist = kmers,
+                                            self = True,
+                                            number_plot_fits = 0,
+                                            threads=threads)
+
+    return qlist1, distMat
diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index 64cf4525..9618c945 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -142,9 +142,7 @@ def iterDistRows(refSeqs, querySeqs, self=True):
            List of query sequence names.
        self (bool)
            Whether a self-comparison, used when constructing a database.
-            Requires refSeqs == querySeqs
-            Default is True
 
    Returns:
        ref, query (str, str)
@@ -161,7 +159,6 @@ def iterDistRows(refSeqs, querySeqs, self=True):
        for ref in refSeqs:
            yield(ref, query)
 
-
 def writeTmpFile(fileList):
    """Writes a list to a temporary file. Used for turning variable into mash input.
 
@@ -213,12 +210,12 @@ def qcDistMat(distMat, refList, queryList, a_max):
 
    return passed
 
-def readClusters(clustCSV, return_dict=False):
-    """Read a previous reference clustering from CSV
+def readIsolateTypeFromCsv(clustCSV, mode = 'clusters', return_dict = False):
+    """Read isolate types from CSV file.
 
    Args:
        clustCSV (str)
-            File name of CSV with previous cluster assignments
+            File name of CSV with isolate assignments
        return_type (str)
            If True, return a dict with sample->cluster instead of sets
 
    Returns:
        clusters (dict)
            Dictionary of cluster assignments (keys are cluster names, values are
            sets containing samples in the cluster). Or if return_dict is set keys
            are sample names, values are cluster assignments.
""" + # data structures if return_dict: + clusters = defaultdict(dict) + else: clusters = {} + + # read CSV + clustersCsv = pd.read_csv(clustCSV, index_col = 0, quotechar='"') + + # select relevant columns according to mode + if mode == 'clusters': + type_columns = [n for n,col in enumerate(clustersCsv.columns) if ('Cluster' in col)] + elif mode == 'lineages': + type_columns = [n for n,col in enumerate(clustersCsv.columns) if ('Rank_' in col or 'overall' in col)] + elif mode == 'external': + if len(clustersCsv.columns) == 1: + type_columns = [0] + elif len(clustersCsv.columns) > 1: + type_columns = range((len(clustersCsv.columns)-1)) else: - clusters = defaultdict(set) - - with open(clustCSV, 'r') as csv_file: - header = csv_file.readline() - for line in csv_file: - (sample, clust_id) = line.rstrip().split(",")[:2] + sys.stderr.write('Unknown CSV reading mode: ' + mode + '\n') + exit(1) + + # read file + for row in clustersCsv.itertuples(): + for cls_idx in type_columns: + cluster_name = clustersCsv.columns[cls_idx] + cluster_name = cluster_name.replace('__autocolour','') if return_dict: - clusters[sample] = clust_id + clusters[cluster_name][row.Index] = str(row[cls_idx + 1]) else: - clusters[clust_id].add(sample) + if cluster_name not in clusters.keys(): + clusters[cluster_name] = defaultdict(set) + clusters[cluster_name][str(row[cls_idx + 1])].add(row.Index) + # return data structure return clusters - -def readExternalClusters(clustCSV): - """Read a cluster definition from CSV (does not have to be PopPUNK - generated clusters). Rows samples, columns clusters. - - Args: - clustCSV (str) - File name of CSV with previous cluster assingments - - Returns: - extClusters (dict) - Dictionary of dictionaries of cluster assignments - (first key cluster assignment name, second key sample, value cluster assignment) - """ - extClusters = defaultdict(lambda: defaultdict(str)) - - extClustersFile = pd.read_csv(clustCSV, index_col = 0, quotechar='"') - for row in extClustersFile.itertuples(): - for cls_idx, cluster in enumerate(extClustersFile.columns): - extClusters[str(cluster)][row.Index] = str(row[cls_idx + 1]) - - return(extClusters) - - def translate_distMat(combined_list, core_distMat, acc_distMat): """Convert distances from a square form (2 NxN matrices) to a long form (1 matrix with n_comparisons rows and 2 columns). 
@@ -469,8 +464,10 @@ def assembly_qc(assemblyList, klist, ignoreLengthOutliers, estimated_length):
    k_min = min(klist)
    max_prob = 1/(pow(4, k_min)/float(genome_length) + 1)
-    if 1/(pow(4, k_min)/float(genome_length) + 1) > 0.05:
-        sys.stderr.write("Minimum k-mer length " + str(k_min) + " is too small for genome length " + str(genome_length) +"; please increase to avoid nonsense results\n")
+    if max_prob > 0.05:
+        sys.stderr.write("Minimum k-mer length " + str(k_min) + " is too small for genome length " + str(genome_length) +"; results will be adjusted for random match probabilities\n")
+    if k_min < 6:
+        sys.stderr.write("Minimum k-mer length is too low; please increase to at least 6\n")
        exit(1)
 
    return (int(genome_length), max_prob)
diff --git a/test/run_test.py b/test/run_test.py
index fd2bed8b..8f736720 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -72,6 +72,22 @@
 sys.stderr.write("Running microreact visualisations (--generate-viz)\n")
 subprocess.run("python ../poppunk-runner.py --generate-viz --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_viz --microreact --subset subset.txt", shell=True, check=True)
 
+# lineage clustering
+sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")
+subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5", shell=True, check=True)
+
+# assign query to lineages
+sys.stderr.write("Running query assignment (--assign-lineages)\n")
+subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --existing-scheme example_lineages/example_lineages_lineages.pkl --output example_lineage_query --update-db", shell=True, check=True)
+
+# lineage clustering with mash
+sys.stderr.write("Running lineage clustering test with mash (--lineage-clustering)\n")
+subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db_mash/example_db_mash.dists --output example_lineages_mash --ranks 1,2,3,5 --use-mash", shell=True, check=True)
+
+# assign query to lineages with mash
+sys.stderr.write("Running query assignment with mash (--assign-lineages)\n")
+subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --existing-scheme example_lineages_mash/example_lineages_mash_lineages.pkl --output example_lineage_mash_query --update-db --use-mash", shell=True, check=True)
+
 # tests of other command line programs (TODO)