diff --git a/PopPUNK/__init__.py b/PopPUNK/__init__.py
index 165032b3..22718977 100644
--- a/PopPUNK/__init__.py
+++ b/PopPUNK/__init__.py
@@ -3,4 +3,4 @@
 '''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''
 
-__version__ = '2.1.1'
+__version__ = '2.2.0'
diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 681e1f4c..14d25a00 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -128,10 +128,12 @@ def get_options():
 
     # comparison metrics
     kmerGroup = parser.add_argument_group('Kmer comparison options')
-    kmerGroup.add_argument('--min-k', default = 13, type=int, help='Minimum kmer length [default = 9]')
+    kmerGroup.add_argument('--min-k', default = 13, type=int, help='Minimum kmer length [default = 13]')
     kmerGroup.add_argument('--max-k', default = 29, type=int, help='Maximum kmer length [default = 29]')
     kmerGroup.add_argument('--k-step', default = 4, type=int, help='K-mer step size [default = 4]')
     kmerGroup.add_argument('--sketch-size', default=10000, type=int, help='Kmer sketch size [default = 10000]')
+    kmerGroup.add_argument('--codon-phased', default=False, action='store_true',
+                           help='Use codon phased seeds X--X--X [default = False]')
     kmerGroup.add_argument('--min-kmer-count', default=0, type=int, help='Minimum k-mer count when using reads as input [default = 0]')
     kmerGroup.add_argument('--exact-count', default=False, action='store_true',
                            help='Use the exact k-mer counter with reads '
@@ -218,7 +220,8 @@ def get_options():
     other.add_argument('--use-mash', default=False, action='store_true', help='Use the old mash sketch backend [default = False]')
     other.add_argument('--mash', default='mash', help='Location of mash executable')
     other.add_argument('--threads', default=1, type=int, help='Number of threads to use [default = 1]')
-    other.add_argument('--use-gpu', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]')
+    other.add_argument('--gpu-sketch', default=False, action='store_true', help='Use a GPU when calculating sketches (read data only) [default = False]')
+    other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]')
     other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]')
     other.add_argument('--no-stream', help='Use temporary files for mash dist interfacing. Reduce memory use/increase disk use for large datasets', default=False, action='store_true')
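A note on the new `--codon-phased` option: a codon-phased spaced seed samples every third base, so a seed of weight k covers 3k - 2 bases of sequence. pp-sketchlib implements this natively; the sketch below only illustrates which bases an X--X--X pattern selects, and `codon_phased_kmer` is an illustrative name, not a PopPUNK or pp-sketchlib function.

```python
# Illustration of an X--X--X codon-phased spaced seed (not PopPUNK code).
def codon_phased_kmer(seq, start, weight):
    """Return the bases sampled by a codon-phased seed of the given weight."""
    span = 3 * weight - 2          # total bases covered by the seed
    window = seq[start:start + span]
    if len(window) < span:
        return None                # seed would run off the end of the sequence
    return window[::3]             # keep every third base: X--X--X...

print(codon_phased_kmer("ATGCGTACG", 0, 3))  # 'ACA' (positions 0, 3 and 6)
```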
@@ -265,20 +268,14 @@ def main():
     elif no_sketchlib and (args.min_k < 9 or args.max_k > 31):
         sys.stderr.write("When using Mash, Kmer size must be between 9 and 31\n")
         sys.exit(1)
-    elif args.min_k < 5 or args.max_k > 51:
-        sys.stderr.write("Very short or very long kmers are not recommended\n")
+    elif args.min_k < 3:
+        sys.stderr.write("Min k-mer length must be 3 or higher\n")
         sys.exit(1)
     kmers = np.arange(args.min_k, args.max_k + 1, args.k_step)
 
-    # Dict of DB access functions for assign_query (which is out of scope)
-    dbFuncs = setupDBFuncs(args, kmers, args.min_kmer_count)
-    createDatabaseDir = dbFuncs['createDatabaseDir']
-    constructDatabase = dbFuncs['constructDatabase']
-    queryDatabase = dbFuncs['queryDatabase']
-    readDBParams = dbFuncs['readDBParams']
-
     # Dict of QC options for passing to database construction and querying functions
     qc_dict = {
+        'run_qc': args.create_db or args.easy_run,
         'qc_filter': args.qc_filter,
         'retain_failures': args.retain_failures,
         'length_sigma': args.length_sigma,
         'length_range': args.length_range,
         'prop_n': args.prop_n,
         'upper_n': args.upper_n
     }
@@ -287,6 +284,13 @@ def main():
 
+    # Dict of DB access functions for assign_query (which is out of scope)
+    dbFuncs = setupDBFuncs(args, kmers, args.min_kmer_count, qc_dict)
+    createDatabaseDir = dbFuncs['createDatabaseDir']
+    constructDatabase = dbFuncs['constructDatabase']
+    queryDatabase = dbFuncs['queryDatabase']
+    readDBParams = dbFuncs['readDBParams']
+
     # define sketch sizes, store in hash in case one day
     # different kmers get different hash sizes
     sketch_sizes = {}
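For reference, the `np.arange(args.min_k, args.max_k + 1, args.k_step)` line above is what turns the command-line options into the list of k-mer lengths that get sketched; the `+ 1` makes the range inclusive of `--max-k`:

```python
import numpy as np

# With the defaults from get_options(): --min-k 13, --max-k 29, --k-step 4
kmers = np.arange(13, 29 + 1, 4)
print(kmers)  # [13 17 21 25 29]
```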
@@ -357,14 +361,14 @@ def main():
         if args.r_files is not None:
             # generate sketches and QC sequences
             createDatabaseDir(args.output, kmers)
-            seq_names = constructDatabase(args.r_files, kmers, sketch_sizes,
-                                          args.output,
-                                          args.threads,
-                                          args.overwrite,
-                                          strand_preserved = args.strand_preserved,
-                                          min_count = args.min_kmer_count,
-                                          use_exact = args.exact_count,
-                                          qc_dict = qc_dict)
+            seq_names = constructDatabase(
+                args.r_files,
+                kmers,
+                sketch_sizes,
+                args.output,
+                args.threads,
+                args.overwrite,
+                calc_random = True)
 
             # Calculate and QC distances
             if args.use_mash == True:
@@ -480,13 +484,13 @@ def main():
         if args.use_model:
             assignments = model.assign(distMat)
             model.plot(distMat, assignments)
-    
+
     #******************************#
     #*                            *#
     #*  network construction      *#
     #*                            *#
     #******************************#
-    
+
     if not args.lineage_clustering:
         genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
         # Ensure all in dists are in final network
@@ -527,7 +531,7 @@ def main():
     #*  lineages analysis         *#
     #*                            *#
     #******************************#
-    
+
     if args.lineage_clustering:
 
         # load distances
@@ -538,7 +542,7 @@ def main():
             sys.exit(1)
 
         refList, queryList, self, distMat = readPickle(distances)
-        
+
         # make directory for new output files
         if not os.path.isdir(args.output):
             try:
@@ -546,13 +550,13 @@ def main():
             except OSError:
                 sys.stderr.write("Cannot create output directory\n")
                 sys.exit(1)
-        
+
         # run lineage clustering
         if self:
            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
         else:
            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, qlist = queryList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
-        
+
         # load networks
         indivNetworks = {}
         for rank in rank_list:
@@ -609,8 +613,9 @@ def main():
             #******************************#
             # extract limited references from clique by default
             if not args.full_db:
-                newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = extractReferences(genomeNetwork, refList, args.output)
-                nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
+                newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = \
+                    extractReferences(genomeNetwork, refList, args.output)
+                nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
                 names_to_remove = [refList[n] for n in nodes_to_remove]
                 prune_distance_matrix(refList, names_to_remove, distMat,
                                       args.output + "/" + os.path.basename(args.output) + ".dists")
@@ -621,8 +626,14 @@ def main():
                     dummyRefFile = writeDummyReferences(newReferencesNames, args.output)
                     # Read and overwrite previous database
                     kmers, sketch_sizes = readDBParams(ref_db, kmers, sketch_sizes)
-                    constructDatabase(dummyRefFile, kmers, sketch_sizes, args.output,
-                                      True, args.threads, True) # overwrite old db
+                    constructDatabase(dummyRefFile,
+                                      kmers,
+                                      sketch_sizes,
+                                      args.output,
+                                      True,
+                                      args.threads,
+                                      True, # overwrite old db
+                                      calc_random = True)
                     os.remove(dummyRefFile)
 
                 genomeNetwork.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')
@@ -639,9 +650,7 @@ def main():
                      args.threads, args.use_mash, args.mash, args.overwrite, args.plot_fit, args.no_stream,
                      args.max_a_dist, args.model_dir, args.previous_clustering, args.external_clustering,
                      args.core_only, args.accessory_only, args.phandango, args.grapetree, args.info_csv,
-                     args.rapidnj, args.perplexity, args.assign_lineages, args.existing_scheme, rank_list, args.use_accessory,
-                     strand_preserved = args.strand_preserved, min_count = args.min_kmer_count,
-                     use_exact = args.exact_count, qc_dict = qc_dict)
+                     args.rapidnj, args.perplexity, args.assign_lineages, args.existing_scheme, rank_list, args.use_accessory)
 
     #******************************#
     #*                            *#
@@ -695,7 +704,7 @@ def main():
                                           complete_distMat, dists_out)
 
             combined_seq, core_distMat, acc_distMat = \
-                update_distance_matrices(viz_subset, newDistMat, 
+                update_distance_matrices(viz_subset, newDistMat,
                                          threads = args.threads)
 
             # reorder subset to ensure list orders match
@@ -732,7 +741,7 @@ def main():
                 prev_clustering = args.previous_clustering
             else:
                 prev_clustering = os.path.dirname(args.distances + ".pkl")
-            
+
             # load clustering
             if model.indiv_fitted:
                 cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
@@ -745,7 +754,7 @@ def main():
             else:
                 cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
                 isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
-            
+
             # generate selected visualisations
             if args.microreact:
                 sys.stderr.write("Writing microreact output\n")
@@ -796,9 +805,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                  kmers, sketch_sizes, threads, use_mash, mash, overwrite,
                  plot_fit, no_stream, max_a_dist, model_dir, previous_clustering,
                  external_clustering, core_only, accessory_only, phandango, grapetree,
-                 info_csv, rapidnj, perplexity, assign_lineage, existing_scheme, rank_list, use_accessory,
-                 # added extra arguments for constructing sketchlib libraries
-                 strand_preserved, min_count, use_exact, qc_dict):
+                 info_csv, rapidnj, perplexity, assign_lineage, existing_scheme, rank_list, use_accessory):
     """Code for assign query mode. Written as a separate function so it can be called by pathogen.watch API
     """
@@ -845,13 +852,14 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                 rNames.append(reference.rstrip())
     else:
         rNames = getSeqsInDb(ref_db + "/" + os.path.basename(ref_db) + ".h5")
-    # construct database and QC
-    qNames = constructDatabase(q_files, kmers, sketch_sizes, output,
-                               threads, overwrite,
-                               strand_preserved = strand_preserved,
-                               min_count = min_count,
-                               use_exact = use_exact,
-                               qc_dict = qc_dict)
+    # construct database
+    qNames = constructDatabase(q_files,
+                               kmers,
+                               sketch_sizes,
+                               output,
+                               threads,
+                               overwrite,
+                               calc_random = False)
 
     # run query
     refList, queryList, distMat = queryDatabase(rNames = rNames,
@@ -862,7 +870,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                                                 self = False,
                                                 number_plot_fits = plot_fit,
                                                 threads = threads)
-    
+
     # QC distance matrix
     qcPass = qcDistMat(distMat, refList, queryList, max_a_dist)
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index aa207724..e3254aea 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -112,7 +112,7 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
         references = set(existingRefs)
         index_lookup = {v:k for k,v in enumerate(dbOrder)}
         reference_indices = [index_lookup[r] for r in references]
-    
+
     # extract cliques from network
     cliques_in_overall_graph = [c.tolist() for c in gt.max_cliques(G)]
     # order list by size of clique
@@ -175,12 +175,12 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
             for vertex in vertex_list:
                 reference_vertex[vertex] = True
                 reference_indices.add(int(vertex))
-    
+
     # update reference graph if vertices have been added
     if network_update_required:
         G_ref = gt.GraphView(G, vfilt = reference_vertex)
         G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
-    
+
     # Order found references as in mash sketch files
     reference_names = [dbOrder[int(x)] for x in sorted(reference_indices)]
     refFileName = writeReferences(reference_names, outPrefix)
@@ -258,12 +258,12 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
     connections = []
     self_comparison = True
     vertex_labels = rlist
-    
+
     # check if self comparison
     if rlist != qlist:
         self_comparison = False
         vertex_labels.append(qlist)
-    
+
     # identify edges
     for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qlist, self = self_comparison)):
         if assignment == within_label:
@@ -418,7 +418,7 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
 
     # Otherwise only calculate query-query distances for new clusters
     else:
-        
+
         # identify potentially new lineages in list: unassigned is a list of queries with no hits
         unassigned = set(qSeqs).difference(assigned)
         query_indices = {k:v+ref_count for v,k in enumerate(qSeqs)}
@@ -467,7 +467,7 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
     # finish by updating the network
     G.add_vertex(len(qNames))
     G.add_edge_list(new_edges)
-    
+
     # including the vertex ID property map
     for i,q in enumerate(qSeqs):
         G.vp.id[i + len(rlist)] = q
@@ -521,7 +521,7 @@ def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
         component = component_assignments.a[isolate_index]
         component_rank = component_frequency_ranks[component]
         newClusters[component_rank].add(isolate_name)
-    
+
     oldNames = set()
 
     if oldClusterFile != None:
diff --git a/PopPUNK/reference_pick.py b/PopPUNK/reference_pick.py
index fcc72680..6ce33ac5 100755
--- a/PopPUNK/reference_pick.py
+++ b/PopPUNK/reference_pick.py
@@ -7,12 +7,11 @@ import sys
 
 # additional
 from shutil import copyfile
-import networkx as nx
+import graph_tool.all as gt
 
 # import poppunk package
 from .__init__ import __version__
-from .sketchlib import no_sketchlib
 from .sketchlib import removeFromDB
 from .mash import checkMashVersion
 
@@ -48,13 +47,9 @@ def get_options():
     # output options
     oGroup = parser.add_argument_group('Output options')
     oGroup.add_argument('--output', required=True, help='Prefix for output files (required)')
-    oGroup.add_argument('--no-resketch', default=False, action='store_true', help='Do not resketch the references (--use-mash only)'
-                                                                                  '[default = False]')
 
     # processing
     other = parser.add_argument_group('Other options')
-    other.add_argument('--use-mash', default=False, action='store_true', help='Use the old mash sketch backend [default = False]')
-    other.add_argument('--mash', default='mash', help='Location of mash executable')
     other.add_argument('--threads', default=1, type=int, help='Number of threads to use [default = 1]')
     other.add_argument('--version', action='version',
 
@@ -66,17 +61,6 @@ def main():
     # Check input args ok
     args = get_options()
-    resketch = not args.no_resketch
-    if no_sketchlib:
-        args.use_mash = True
-
-    if args.use_mash:
-        checkMashVersion(args.mash)
-    else:
-        resketch = True
-    if resketch and (args.ref_db is None or not os.path.isdir(args.ref_db)):
-        sys.stderr.write("Must provide original --ref-db if using --resketch\n")
-        sys.exit(1)
 
     # Check output path ok
     if not os.path.isdir(args.output):
@@ -92,38 +76,23 @@ def main():
         raise RuntimeError("Distance DB should be self-self distances")
 
     # Read in full network
-    genomeNetwork = nx.read_gpickle(args.network)
-    sys.stderr.write("Network loaded: " + str(genomeNetwork.number_of_nodes()) + " samples\n")
+    genomeNetwork = gt.load_graph(network_file)
+    sys.stderr.write("Network loaded: " + str(len(list(genomeNetwork.vertices()))) + " samples\n")
 
     # This is the same set of function calls for --fit-model when no --full-db in __main__.py
     # Find refs and prune network
-    newReferencesNames, newReferencesFile = extractReferences(genomeNetwork, refList, args.output)
-    nodes_to_remove = set(refList).difference(newReferencesNames)
-    genomeNetwork.remove_nodes_from(nodes_to_remove)
-    nx.write_gpickle(genomeNetwork, args.output + "/" + os.path.basename(args.output) + '_graph.gpickle')
+    reference_indices, reference_names, refFileName, G_ref = \
+        extractReferences(genomeNetwork, refList, args.output)
+    G_ref.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')
+    nodes_to_remove = set(refList).difference(reference_names)
 
     # Prune distances
     prune_distance_matrix(refList, nodes_to_remove, distMat,
                           args.output + "/" + os.path.basename(args.output) + ".dists")
 
-    # Resketch
+    # 'Resketch'
     if len(nodes_to_remove) > 0:
-        if resketch:
-            if args.use_mash:
-                sys.stderr.write("Resketching " + str(len(newReferencesNames)) + " sequences\n")
-
-                # Find db properties
-                kmers = getKmersFromReferenceDatabase(args.ref_db)
-                sketch_sizes = getSketchSize(args.ref_db, kmers, args.mash)
-
-                # Resketch all
-                createDatabaseDir(args.output, kmers)
-                dummyRefFile = writeDummyReferences(newReferencesNames, args.output)
-                constructDatabase(dummyRefFile, kmers, sketch_sizes, args.output, args.estimated_length, True, args.threads, args.mash, True)
-                os.remove(dummyRefFile)
-        else:
-            removeFromDB(args.ref_db, args.output, set(refList) - set(newReferencesNames))
-
+        removeFromDB(args.ref_db, args.output, set(refList) - set(reference_names))
     else:
         sys.stderr.write("No sequences to remove\n")
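The reference_pick.py changes above replace networkx pickles with graph-tool's binary `.gt` format. A minimal self-contained sketch of the same load, filter, prune and save pattern used here and in `extractReferences` (the toy graph and file name are purely illustrative):

```python
import graph_tool.all as gt

# Toy graph standing in for gt.load_graph(<network file>)
G = gt.Graph(directed=False)
G.add_vertex(4)
G.add_edge_list([(0, 1), (1, 2), (2, 3)])

# A boolean vertex property marks the reference vertices to keep
keep = G.new_vertex_property("bool")
for i in [0, 1]:
    keep[G.vertex(i)] = True

G_ref = gt.GraphView(G, vfilt=keep)   # lazy filtered view of G
G_ref = gt.Graph(G_ref, prune=True)   # concrete copy without filtered vertices
G_ref.save("refs.gt", fmt="gt")
print(G_ref.num_vertices())           # 2
```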
sys.stderr.write("No sequences to remove\n") diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py index a6dc7037..fff65840 100644 --- a/PopPUNK/sketchlib.py +++ b/PopPUNK/sketchlib.py @@ -287,7 +287,9 @@ def removeFromDB(db_name, out_name, removeSeqs): def constructDatabase(assemblyList, klist, sketch_size, oPrefix, threads, overwrite, strand_preserved, min_count, - use_exact, qc_dict): + use_exact, qc_dict, calc_random = True, + codon_phased = False, + use_gpu = False, deviceid = 0): """Sketch the input assemblies at the requested k-mer lengths A multithread wrapper around :func:`~runSketch`. Threads are used to either run multiple sketch @@ -321,6 +323,17 @@ def constructDatabase(assemblyList, klist, sketch_size, oPrefix, (default = False) qc_dict (dict) Dict containg QC settings + calc_random (bool) + Add random match chances to DB (turn off for queries) + codon_phased (bool) + Use codon phased seeds + (default = False) + use_gpu (bool) + Use GPU for read sketching + (default = False) + deviceid (int) + GPU device id + (default = 0) """ # read file names names, sequences = readRfile(assemblyList) @@ -333,11 +346,38 @@ def constructDatabase(assemblyList, klist, sketch_size, oPrefix, os.remove(dbfilename) # generate sketches - pp_sketchlib.constructDatabase(dbname, names, sequences, klist, sketch_size, - not strand_preserved, min_count, use_exact, threads) - + pp_sketchlib.constructDatabase(dbname, + names, + sequences, + klist, + sketch_size, + codon_phased, + False, + not strand_preserved, + min_count, + use_exact, + threads, + use_gpu, + deviceid) + # QC sequences - filtered_names = sketchlib_assembly_qc(oPrefix, klist, qc_dict, strand_preserved, threads) + if qc_dict['run_qc']: + filtered_names = sketchlib_assembly_qc(oPrefix, + klist, + qc_dict, + strand_preserved, + threads) + else: + filtered_names = names + + # Add random matches if required + # (typically on for reference, off for query) + if (calc_random): + pp_sketchlib.addRandom(dbname, + filtered_names, + klist, + not strand_preserved, + threads) # return filtered file names return filtered_names diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py index 3e11d455..532db601 100644 --- a/PopPUNK/utils.py +++ b/PopPUNK/utils.py @@ -21,7 +21,7 @@ # Use partials to set up slightly different function calls between # both possible backends -def setupDBFuncs(args, kmers, min_count): +def setupDBFuncs(args, kmers, min_count, qc_dict): """Wraps common database access functions from sketchlib and mash, to try and make their API more similar @@ -32,30 +32,16 @@ def setupDBFuncs(args, kmers, min_count): List of k-mer sizes min_count (int) Minimum k-mer count for reads + qc_dict (dict) + Table of parameters for QC function Returns: dbFuncs (dict) Functions with consistent arguments to use as the database API """ if args.use_mash: - from .mash import checkMashVersion - from .mash import createDatabaseDir - from .mash import getKmersFromReferenceDatabase - from .mash import joinDBs as joinDBsMash - from .mash import constructDatabase as constructDatabaseMash - from .mash import queryDatabase as queryDBMash - from .mash import readMashDBParams - from .mash import getSeqsInDb - - # check mash is installed - backend = "mash" - version = checkMashVersion(args.mash) - - constructDatabase = partial(constructDatabaseMash, mash_exec = args.mash) - readDBParams = partial(readMashDBParams, mash_exec = args.mash) - queryDatabase = partial(queryDBMash, no_stream = args.no_stream, mash_exec = args.mash) - joinDBs = partial(joinDBsMash, klist = 
diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index 3e11d455..532db601 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -21,7 +21,7 @@
 
 # Use partials to set up slightly different function calls between
 # both possible backends
-def setupDBFuncs(args, kmers, min_count):
+def setupDBFuncs(args, kmers, min_count, qc_dict):
     """Wraps common database access functions from sketchlib and mash,
     to try and make their API more similar
 
@@ -32,30 +32,17 @@
             List of k-mer sizes
         min_count (int)
             Minimum k-mer count for reads
+        qc_dict (dict)
+            Table of parameters for QC function
 
     Returns:
         dbFuncs (dict)
             Functions with consistent arguments to use as the database API
     """
     if args.use_mash:
-        from .mash import checkMashVersion
-        from .mash import createDatabaseDir
-        from .mash import getKmersFromReferenceDatabase
-        from .mash import joinDBs as joinDBsMash
-        from .mash import constructDatabase as constructDatabaseMash
-        from .mash import queryDatabase as queryDBMash
-        from .mash import readMashDBParams
-        from .mash import getSeqsInDb
-
-        # check mash is installed
-        backend = "mash"
-        version = checkMashVersion(args.mash)
-
-        constructDatabase = partial(constructDatabaseMash, mash_exec = args.mash)
-        readDBParams = partial(readMashDBParams, mash_exec = args.mash)
-        queryDatabase = partial(queryDBMash, no_stream = args.no_stream, mash_exec = args.mash)
-        joinDBs = partial(joinDBsMash, klist = getKmersFromReferenceDatabase(args.output), mash_exec = args.mash)
-
+        sys.stderr.write("mash no longer supported. "
+                         "Please downgrade to <=v2.0.2 to use\n")
+        sys.exit(1)
 
     else:
         from .sketchlib import checkSketchlibVersion
@@ -69,9 +55,17 @@
         backend = "sketchlib"
         version = checkSketchlibVersion()
 
-        constructDatabase = partial(constructDatabaseSketchlib, strand_preserved = args.strand_preserved,
-                                    min_count = args.min_kmer_count, use_exact = args.exact_count)
-        queryDatabase = partial(queryDatabaseSketchlib, use_gpu = args.use_gpu, deviceid = args.deviceid)
+        constructDatabase = partial(constructDatabaseSketchlib,
+                                    codon_phased = args.codon_phased,
+                                    strand_preserved = args.strand_preserved,
+                                    min_count = args.min_kmer_count,
+                                    use_exact = args.exact_count,
+                                    qc_dict = qc_dict,
+                                    use_gpu = args.gpu_sketch,
+                                    deviceid = args.deviceid)
+        queryDatabase = partial(queryDatabaseSketchlib,
+                                use_gpu = args.gpu_dist,
+                                deviceid = args.deviceid)
 
     # Dict of DB access functions for assign_query (which is out of scope)
     dbFuncs = {'createDatabaseDir': createDatabaseDir,
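The `functools.partial` pattern in `setupDBFuncs` above is what lets `main()` and `assign_query` drop all of the backend-specific keyword arguments: they are frozen once here, and every later call site passes only what varies. A self-contained sketch of the idea (the `sketch` function and its signature are illustrative, not the real API):

```python
from functools import partial

def sketch(rfile, output, codon_phased=False, use_gpu=False, deviceid=0):
    print(f"sketching {rfile} -> {output} "
          f"(codon_phased={codon_phased}, use_gpu={use_gpu}, device={deviceid})")

# Freeze the backend options once, as setupDBFuncs does with qc_dict,
# args.codon_phased, args.gpu_sketch and so on...
constructDatabase = partial(sketch, codon_phased=True, use_gpu=True)

# ...so downstream callers only supply the per-run arguments:
constructDatabase("rfiles.txt", "output_db")
```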
@@ -402,7 +396,7 @@ def readRfile(rFile, oneSeq=False):
 
 def isolateNameToLabel(names):
     """Function to process isolate names to labels
     appropriate for visualisation.
-    
+
     Args:
         names (list)
             List of isolate names.
@@ -448,7 +442,7 @@ def sketchlib_assembly_qc(prefix, klist, qc_dict, strand_preserved, threads):
     if qc_dict['retain_failures']:
         failed_db_name = prefix + '/' + 'failed.' + os.path.basename(prefix) + '.h5'
         hdf_fail = h5py.File(failed_db_name, 'w')
-    
+
     # try/except structure to prevent h5 corruption
     try:
         # process data structures
@@ -459,10 +453,8 @@ def sketchlib_assembly_qc(prefix, klist, qc_dict, strand_preserved, threads):
             fail_grp = hdf_fail.create_group('sketches')
         seq_length = {}
         seq_ambiguous = {}
-        seq_excluded = {}
-        removed = []
         retained = []
-        
+
         # iterate through sketches
         for dataset in read_grp:
             # test thresholds
@@ -477,21 +469,21 @@ def sketchlib_assembly_qc(prefix, klist, qc_dict, strand_preserved, threads):
         # get mean length
         genome_lengths = np.fromiter(seq_length.values(), dtype = int)
         mean_genome_length = np.mean(genome_lengths)
-        
+
         # calculate length threshold unless user-supplied
         if qc_dict['length_range'][0] is None:
-            lower_length = mean_genome_length - qc_dict['length_sigma']*np.std(genome_lengths)
-            upper_length = mean_genome_length + qc_dict['length_sigma']*np.std(genome_lengths)
+            lower_length = mean_genome_length - \
+                qc_dict['length_sigma'] * np.std(genome_lengths)
+            upper_length = mean_genome_length + \
+                qc_dict['length_sigma'] * np.std(genome_lengths)
         else:
             lower_length, upper_length = qc_dict['length_range']
 
         # open file to report QC failures
         with open(prefix + '/' + os.path.basename(prefix) + '_qcreport.txt', 'a+') as qc_file:
-            # iterate through and filter
             failed_sample = False
             for dataset in seq_length.keys():
-                # determine if sequence passes filters
                 remove = False
                 if seq_length[dataset] < lower_length:
@@ -513,13 +505,10 @@ def sketchlib_assembly_qc(prefix, klist, qc_dict, strand_preserved, threads):
                     if qc_dict['retain_failures']:
                         fail_grp.copy(read_grp[dataset], dataset)
                 else:
+                    retained.append(dataset)
                     if qc_dict['qc_filter'] == 'prune':
                         out_grp.copy(read_grp[dataset], dataset)
-                    retained.append(dataset)
-
-        # get kmers from original database
-        db_kmers = hdf_in['sketches'][retained[0]].attrs['kmers']
-        
+
         # close files
         hdf_in.close()
         if qc_dict['qc_filter'] == 'prune':
@@ -530,7 +519,7 @@ def sketchlib_assembly_qc(prefix, klist, qc_dict, strand_preserved, threads):
         # replace original database with pruned version
         if qc_dict['qc_filter'] == 'prune':
             os.rename(filtered_db_name, db_name)
-    
+
     # if failure still close files to avoid corruption
     except:
         hdf_in.close()
@@ -539,25 +528,26 @@ def sketchlib_assembly_qc(prefix, klist, qc_dict, strand_preserved, threads):
         if qc_dict['retain_failures']:
             hdf_fail.close()
         sys.stderr.write('Problem processing h5 databases during QC - aborting\n')
-        sys.exit(1)
-
-
+
+        print("Unexpected error:", sys.exc_info()[0], file = sys.stderr)
+        raise
+
     # stop if at least one sample fails QC and option is not continue/prune
     if failed_sample and qc_dict['qc_filter'] == 'stop':
-        sys.stderr.write('Sequences failed QC filters - details in ' + prefix + '/' + os.path.basename(prefix) + '_qcreport.txt\n')
+        sys.stderr.write('Sequences failed QC filters - details in ' + \
+                         prefix + '/' + os.path.basename(prefix) + \
+                         '_qcreport.txt\n')
         sys.exit(1)
-    
+
     # calculate random matches if any sequences pass QC filters
     if len(retained) == 0:
         sys.stderr.write('No sequences passed QC filters - please adjust your settings\n')
         sys.exit(1)
-    use_rc = not strand_preserved
-    db_name_prefix = prefix + '/' + os.path.basename(prefix)
+    # remove random matches if already present
     hdf_in = h5py.File(db_name, 'r+')
     if 'random' in hdf_in:
         del hdf_in['random']
     hdf_in.close()
-    pp_sketchlib.addRandom(db_name_prefix, retained, db_kmers.tolist(), use_rc, threads)
-    
+
     return retained
diff --git a/README.md b/README.md
index 4948c904..6fc42e19 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,20 @@
 Lees JA, Harris SR, Tonkin-Hill G, Gladstone RA, Lo SW, Weiser JN, Corander J, Bentley SD, Croucher NJ. Fast and flexible
 bacterial genomic epidemiology with PopPUNK. *Genome Research* **29**:304-316 (2019).
 doi:[10.1101/gr.241455.118](https://doi.org/10.1101/gr.241455.118)
 
+## News
+
+### 2020-09-30
+We have discovered a bug affecting the interaction of pp-sketchlib and PopPUNK.
+If you have used `PopPUNK >=v2.0.0` with `pp-sketchlib <v1.5.1`, distances may
+have been matched to the wrong sample labels. To fix this, upgrade to
+`PopPUNK >=v2.2` and `pp-sketchlib >=v1.5.1`. If this is not
+possible, you can either:
+- Run `scripts/poppunk_pickle_fix.py` on your `.dists.pkl` file and re-run
+  model fits.
+- Create the database with `poppunk_sketch` directly, rather than
+  `PopPUNK --create-db`.
+
 ## Installation
 This is for the command line version. For more details see [installation](https://poppunk.readthedocs.io/en/latest/installation.html) in the documentation.
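To check whether an existing database is affected by the label-ordering bug described in the README entry above, the `.dists.pkl` can be inspected directly; this is the same three-element structure that the new `scripts/poppunk_pickle_fix.py` (added below) sorts and rewrites. The path here is a placeholder:

```python
import pickle

# A .dists.pkl stores [reference_names, query_names, is_self_comparison]
with open("example_db/example_db.dists.pkl", "rb") as f:
    rNames, qNames, self_comparison = pickle.load(f)

# Affected pickles have names out of alphabetical order relative to the
# distance matrix; the fix script simply sorts them and re-pickles.
print(rNames == sorted(rNames) and qNames == sorted(qNames))
```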
diff --git a/docs/conf.py b/docs/conf.py
index 643eb2e3..6c8dc82c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -69,9 +69,9 @@
 # built documents.
 #
 # The short X.Y version.
-version = '2.0.2'
+version = '2.2.0'
 # The full version, including alpha/beta/rc tags.
-release = '2.0.2'
+release = '2.2.0'
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
diff --git a/docs/setup.py b/docs/setup.py
index 18717e06..fb8eb5be 100644
--- a/docs/setup.py
+++ b/docs/setup.py
@@ -49,7 +49,7 @@ def find_version(*file_paths):
         'License :: OSI Approved :: Apache Software License',
         'Programming Language :: Python :: 3.8',
     ],
-    python_requires='>=3.7.0',
+    python_requires='>=3.8.0',
     keywords='bacteria genomics population-genetics k-mer',
     packages=['PopPUNK'],
     entry_points={
diff --git a/docs/troubleshooting.rst b/docs/troubleshooting.rst
index b392f21b..b22afa7e 100644
--- a/docs/troubleshooting.rst
+++ b/docs/troubleshooting.rst
@@ -7,6 +7,21 @@
 installing or running the software please raise an issue on github.
 
 .. contents::
    :local:
 
+Known bugs
+----------
+
+When I look at my clusters on a tree, they make no sense
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+This is a bug caused by alphabetic sorting of labels in ``PopPUNK >=v2.0.0``
+with ``pp-sketchlib <v1.5.1``. To fix it, upgrade to ``PopPUNK >=v2.2`` and
+``pp-sketchlib >=v1.5.1`` (preferred), or:
+
+- Run ``scripts/poppunk_pickle_fix.py`` on your ``.dists.pkl`` file and re-run
+  model fits.
+- Create the database with ``poppunk_sketch --sketch`` and
+  ``poppunk_sketch --query`` directly, rather than ``PopPUNK --create-db``.
+
+
 Error/warning messages
 ----------------------
diff --git a/environment.yml b/environment.yml
index 8ec247c9..ad0d5f6b 100644
--- a/environment.yml
+++ b/environment.yml
@@ -14,10 +14,9 @@ dependencies:
   - scikit-learn
   - dendropy
   - matplotlib
-  - networkx
   - mash
   - hdbscan
   - rapidnj
   - h5py
-  - pp-sketchlib
+  - pp-sketchlib >=1.5.1
   - graph-tool
diff --git a/scripts/poppunk_calculate_rand_indices.py b/scripts/poppunk_calculate_rand_indices.py
index 86ece8f8..b1031c26 100755
--- a/scripts/poppunk_calculate_rand_indices.py
+++ b/scripts/poppunk_calculate_rand_indices.py
@@ -22,7 +22,7 @@ def rand_index_score(labels_true, labels_pred):
 
     # check clusterings
     labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
-    
+
     # initial statistics calculations
     n_samples = labels_true.shape[0]
     n_samples_comb = comb(n_samples,2)
@@ -30,7 +30,7 @@ def rand_index_score(labels_true, labels_pred):
     n_clusters = np.unique(labels_pred).shape[0]
     class_freq = np.bincount(labels_true)
     cluster_freq = np.bincount(labels_pred)
-    
+
     # Special limit cases: no clustering since the data is not split;
     # or trivial clustering where each document is assigned a unique cluster.
     # These are perfect matches hence return 1.0.
@@ -44,7 +44,7 @@ def rand_index_score(labels_true, labels_pred):
     sum_comb_c = sum((n_c**2) for n_c in cluster_freq)
     sum_comb_k = sum((n_k**2) for n_k in class_freq)
     sum_comb = sum((n_ij**2) for n_ij in contingency.data)
-    
+
     return (1 + (sum_comb - 0.5 * sum_comb_k - 0.5 * sum_comb_c)/n_samples_comb)
 
 # command line parsing
diff --git a/scripts/poppunk_extract_components.py b/scripts/poppunk_extract_components.py
index 842b5675..51f93faf 100755
--- a/scripts/poppunk_extract_components.py
+++ b/scripts/poppunk_extract_components.py
@@ -28,7 +28,7 @@ def get_options():
 
     # open stored graph
     G = gt.load_graph(args.graph)
-    
+
     # extract individual components
     component_assignments, component_frequencies = gt.label_components(G)
     component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)
@@ -41,5 +41,5 @@ def get_options():
         component_G = gt.Graph(component_gv, prune = True)
         component_fn = args.output + ".component_" + str(component_frequency_ranks[component_index]) + ".graphml"
         component_G.save(component_fn, fmt = 'graphml')
-    
+
     sys.exit(0)
diff --git a/scripts/poppunk_pickle_fix.py b/scripts/poppunk_pickle_fix.py
new file mode 100755
index 00000000..bfde518f
--- /dev/null
+++ b/scripts/poppunk_pickle_fix.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+# vim: set fileencoding=<encoding name> :
+# Copyright 2018-2020 John Lees and Nick Croucher
+
+import sys
+import argparse
+import pickle
+
+# command line parsing
+def get_options():
+
+    parser = argparse.ArgumentParser(description='Fix pickle files with incorrect label order',
+                                     prog='pickle_fix')
+
+    # input options
+    parser.add_argument('pickle', help='Input pickle (.dists.pkl)')
+    parser.add_argument('output', help='Prefix for output files')
+
+    return parser.parse_args()
+
+# main code
+if __name__ == "__main__":
+
+    # Check input ok
+    args = get_options()
+
+    with open(args.pickle, 'rb') as pickled_names:
+        rNames, qNames, self = pickle.load(pickled_names)
+
+    rNames = sorted(rNames)
+    qNames = sorted(qNames)
+
+    with open(args.output + ".dists.pkl", 'wb') as pickle_fixed:
+        pickle.dump([rNames, qNames, self], pickle_fixed)
+
+    sys.exit(0)
diff --git a/setup.py b/setup.py
index c375969d..6c697db4 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@ def find_version(*file_paths):
         'License :: OSI Approved :: Apache Software License',
         'Programming Language :: Python :: 3.8',
     ],
-    python_requires='>=3.7.0',
+    python_requires='>=3.8.0',
     keywords='bacteria genomics population-genetics k-mer',
     packages=['PopPUNK'],
     entry_points={
@@ -63,7 +63,8 @@ def find_version(*file_paths):
     scripts=['scripts/poppunk_calculate_rand_indices.py',
              'scripts/poppunk_extract_components.py',
              'scripts/poppunk_calculate_silhouette.py',
-             'scripts/poppunk_extract_distances.py'],
+             'scripts/poppunk_extract_distances.py',
+             'scripts/poppunk_pickle_fix.py'],
     install_requires=['numpy',
                       'scipy',
                       'scikit-learn',
diff --git a/test/clean_test.py b/test/clean_test.py
index 303b2410..cd6e9b01 100755
--- a/test/clean_test.py
+++ b/test/clean_test.py
@@ -25,8 +25,8 @@ def deleteDir(dirname):
     "example_use",
     "example_viz",
     "example_lineages",
-    "example_lineage_query"
-    ""
+    "example_lineage_query",
+    "example_qc"
     ]
 for outDir in outputDirs:
     deleteDir(outDir)