Update to sketchlib v1.5.1 #104

Merged: 8 commits, Sep 30, 2020
PopPUNK/__init__.py (1 addition, 1 deletion)
@@ -3,4 +3,4 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

-__version__ = '2.1.1'
+__version__ = '2.2.0'
PopPUNK/__main__.py (54 additions, 46 deletions)
@@ -128,10 +128,12 @@ def get_options():

# comparison metrics
kmerGroup = parser.add_argument_group('Kmer comparison options')
-kmerGroup.add_argument('--min-k', default = 13, type=int, help='Minimum kmer length [default = 9]')
+kmerGroup.add_argument('--min-k', default = 13, type=int, help='Minimum kmer length [default = 13]')
kmerGroup.add_argument('--max-k', default = 29, type=int, help='Maximum kmer length [default = 29]')
kmerGroup.add_argument('--k-step', default = 4, type=int, help='K-mer step size [default = 4]')
kmerGroup.add_argument('--sketch-size', default=10000, type=int, help='Kmer sketch size [default = 10000]')
+kmerGroup.add_argument('--codon-phased', default=False, action='store_true',
+help='Used codon phased seeds X--X--X [default = False]')
kmerGroup.add_argument('--min-kmer-count', default=0, type=int, help='Minimum k-mer count when using reads as input [default = 0]')
kmerGroup.add_argument('--exact-count', default=False, action='store_true',
help='Use the exact k-mer counter with reads '
@@ -218,7 +220,8 @@ def get_options():
other.add_argument('--use-mash', default=False, action='store_true', help='Use the old mash sketch backend [default = False]')
other.add_argument('--mash', default='mash', help='Location of mash executable')
other.add_argument('--threads', default=1, type=int, help='Number of threads to use [default = 1]')
-other.add_argument('--use-gpu', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]')
+other.add_argument('--gpu-sketch', default=False, action='store_true', help='Use a GPU when calculating sketches (read data only) [default = False]')
+other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]')
other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]')
other.add_argument('--no-stream', help='Use temporary files for mash dist interfacing. Reduce memory use/increase disk use for large datasets', default=False, action='store_true')
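A note on the new options: the single `--use-gpu` flag is split into `--gpu-sketch` (GPU sketching, read data only) and `--gpu-dist` (GPU distance calculation), and `--codon-phased` enables codon-phased seeds. A database build using the new flags might look like `poppunk --create-db --r-files rfiles.txt --output example_db --gpu-sketch --gpu-dist --deviceid 0`; the file and output names here are placeholders, not taken from this PR.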

@@ -265,20 +268,14 @@ def main():
elif no_sketchlib and (args.min_k < 9 or args.max_k > 31):
sys.stderr.write("When using Mash, Kmer size must be between 9 and 31\n")
sys.exit(1)
-elif args.min_k < 5 or args.max_k > 51:
-sys.stderr.write("Very short or very long kmers are not recommended\n")
+elif args.min_k < 3:
+sys.stderr.write("Min k-mer length must be 3 or higher\n")
sys.exit(1)
kmers = np.arange(args.min_k, args.max_k + 1, args.k_step)

-# Dict of DB access functions for assign_query (which is out of scope)
-dbFuncs = setupDBFuncs(args, kmers, args.min_kmer_count)
-createDatabaseDir = dbFuncs['createDatabaseDir']
-constructDatabase = dbFuncs['constructDatabase']
-queryDatabase = dbFuncs['queryDatabase']
-readDBParams = dbFuncs['readDBParams']

# Dict of QC options for passing to database construction and querying functions
qc_dict = {
+'run_qc': args.create_db or args.easy_run,
'qc_filter': args.qc_filter,
'retain_failures': args.retain_failures,
'length_sigma': args.length_sigma,
@@ -287,6 +284,13 @@
'upper_n': args.upper_n
}

+# Dict of DB access functions for assign_query (which is out of scope)
+dbFuncs = setupDBFuncs(args, kmers, args.min_kmer_count, qc_dict)
+createDatabaseDir = dbFuncs['createDatabaseDir']
+constructDatabase = dbFuncs['constructDatabase']
+queryDatabase = dbFuncs['queryDatabase']
+readDBParams = dbFuncs['readDBParams']

# define sketch sizes, store in hash in case one day
# different kmers get different hash sizes
sketch_sizes = {}
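`setupDBFuncs` now takes the QC options as a fourth argument, so `qc_dict` must be assembled before the database access functions are created. A minimal runnable sketch of the new ordering; the `Namespace` values are illustrative stand-ins for parsed options, and keys hidden behind the diff fold are omitted:

```python
from argparse import Namespace

# Stand-in for the parsed options; values are illustrative, not from this PR
args = Namespace(create_db=True, easy_run=False, qc_filter='stop',
                 retain_failures=False, length_sigma=5, upper_n=None)

qc_dict = {
    'run_qc': args.create_db or args.easy_run,   # QC only when building a DB
    'qc_filter': args.qc_filter,
    'retain_failures': args.retain_failures,
    'length_sigma': args.length_sigma,
    'upper_n': args.upper_n,
}
# qc_dict is then handed to setupDBFuncs(args, kmers, args.min_kmer_count, qc_dict)
```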
@@ -357,14 +361,14 @@ def main():
if args.r_files is not None:
# generate sketches and QC sequences
createDatabaseDir(args.output, kmers)
-seq_names = constructDatabase(args.r_files, kmers, sketch_sizes,
-args.output,
-args.threads,
-args.overwrite,
-strand_preserved = args.strand_preserved,
-min_count = args.min_kmer_count,
-use_exact = args.exact_count,
-qc_dict = qc_dict)
+seq_names = constructDatabase(
+args.r_files,
+kmers,
+sketch_sizes,
+args.output,
+args.threads,
+args.overwrite,
+calc_random = True)

# Calculate and QC distances
if args.use_mash == True:
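Across the new call sites the per-call keywords (`strand_preserved`, `min_count`, `use_exact`, `qc_dict`) disappear, since `setupDBFuncs` now bakes them into the returned functions, and a `calc_random` argument appears instead. In this PR it is `True` when a reference database is built or overwritten and `False` for query sketches; reading that as "random match chances are precomputed only for reference databases" is an inference from the call sites, not something the diff states. Condensed from the diff (not standalone code):

```python
# Reference database build (--create-db): calc_random enabled
seq_names = constructDatabase(args.r_files, kmers, sketch_sizes,
                              args.output, args.threads, args.overwrite,
                              calc_random = True)

# Query sketching in assign_query: calc_random disabled
qNames = constructDatabase(q_files, kmers, sketch_sizes, output,
                           threads, overwrite, calc_random = False)
```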
@@ -480,13 +484,13 @@ def main():
if args.use_model:
assignments = model.assign(distMat)
model.plot(distMat, assignments)

#******************************#
#* *#
#* network construction *#
#* *#
#******************************#

if not args.lineage_clustering:
genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
# Ensure all in dists are in final network
Expand Down Expand Up @@ -527,7 +531,7 @@ def main():
#* lineages analysis *#
#* *#
#******************************#

if args.lineage_clustering:

# load distances
@@ -538,21 +542,21 @@
sys.exit(1)

refList, queryList, self, distMat = readPickle(distances)

# make directory for new output files
if not os.path.isdir(args.output):
try:
os.makedirs(args.output)
except OSError:
sys.stderr.write("Cannot create output directory\n")
sys.exit(1)

# run lineage clustering
if self:
isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
else:
isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, qlist = queryList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)

# load networks
indivNetworks = {}
for rank in rank_list:
@@ -609,8 +613,9 @@ def main():
#******************************#
# extract limited references from clique by default
if not args.full_db:
-newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = extractReferences(genomeNetwork, refList, args.output)
-nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
+newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = \
+extractReferences(genomeNetwork, refList, args.output)
+nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
names_to_remove = [refList[n] for n in nodes_to_remove]
prune_distance_matrix(refList, names_to_remove, distMat,
args.output + "/" + os.path.basename(args.output) + ".dists")
@@ -621,8 +626,14 @@
dummyRefFile = writeDummyReferences(newReferencesNames, args.output)
# Read and overwrite previous database
kmers, sketch_sizes = readDBParams(ref_db, kmers, sketch_sizes)
-constructDatabase(dummyRefFile, kmers, sketch_sizes, args.output,
-True, args.threads, True) # overwrite old db
+constructDatabase(dummyRefFile,
+kmers,
+sketch_sizes,
+args.output,
+True,
+args.threads,
+True, # overwrite old db
+calc_random = True)
os.remove(dummyRefFile)

genomeNetwork.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')
@@ -639,9 +650,7 @@ def main():
args.threads, args.use_mash, args.mash, args.overwrite, args.plot_fit, args.no_stream,
args.max_a_dist, args.model_dir, args.previous_clustering, args.external_clustering,
args.core_only, args.accessory_only, args.phandango, args.grapetree, args.info_csv,
-args.rapidnj, args.perplexity, args.assign_lineages, args.existing_scheme, rank_list, args.use_accessory,
-strand_preserved = args.strand_preserved, min_count = args.min_kmer_count,
-use_exact = args.exact_count, qc_dict = qc_dict)
+args.rapidnj, args.perplexity, args.assign_lineages, args.existing_scheme, rank_list, args.use_accessory)

#******************************#
#* *#
@@ -695,7 +704,7 @@ def main():
complete_distMat, dists_out)

combined_seq, core_distMat, acc_distMat = \
-update_distance_matrices(viz_subset, newDistMat,
+update_distance_matrices(viz_subset, newDistMat,
threads = args.threads)

# reorder subset to ensure list orders match
@@ -732,7 +741,7 @@ def main():
prev_clustering = args.previous_clustering
else:
prev_clustering = os.path.dirname(args.distances + ".pkl")

# load clustering
if model.indiv_fitted:
cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
@@ -745,7 +754,7 @@
else:
cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)

# generate selected visualisations
if args.microreact:
sys.stderr.write("Writing microreact output\n")
@@ -796,9 +805,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
kmers, sketch_sizes, threads, use_mash, mash, overwrite,
plot_fit, no_stream, max_a_dist, model_dir, previous_clustering,
external_clustering, core_only, accessory_only, phandango, grapetree,
-info_csv, rapidnj, perplexity, assign_lineage, existing_scheme, rank_list, use_accessory,
-# added extra arguments for constructing sketchlib libraries
-strand_preserved, min_count, use_exact, qc_dict):
+info_csv, rapidnj, perplexity, assign_lineage, existing_scheme, rank_list, use_accessory):
"""Code for assign query mode. Written as a separate function so it can be called
by pathogen.watch API
"""
@@ -845,13 +852,14 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
rNames.append(reference.rstrip())
else:
rNames = getSeqsInDb(ref_db + "/" + os.path.basename(ref_db) + ".h5")
-# construct database and QC
-qNames = constructDatabase(q_files, kmers, sketch_sizes, output,
-threads, overwrite,
-strand_preserved = strand_preserved,
-min_count = min_count,
-use_exact = use_exact,
-qc_dict = qc_dict)
+# construct database
+qNames = constructDatabase(q_files,
+kmers,
+sketch_sizes,
+output,
+threads,
+overwrite,
+calc_random = False)

# run query
refList, queryList, distMat = queryDatabase(rNames = rNames,
@@ -862,7 +870,7 @@
self = False,
number_plot_fits = plot_fit,
threads = threads)

# QC distance matrix
qcPass = qcDistMat(distMat, refList, queryList, max_a_dist)

PopPUNK/network.py (8 additions, 8 deletions)
@@ -112,7 +112,7 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
references = set(existingRefs)
index_lookup = {v:k for k,v in enumerate(dbOrder)}
reference_indices = [index_lookup[r] for r in references]

# extract cliques from network
cliques_in_overall_graph = [c.tolist() for c in gt.max_cliques(G)]
# order list by size of clique
@@ -175,12 +175,12 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
for vertex in vertex_list:
reference_vertex[vertex] = True
reference_indices.add(int(vertex))

# update reference graph if vertices have been added
if network_update_required:
G_ref = gt.GraphView(G, vfilt = reference_vertex)
G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object

# Order found references as in mash sketch files
reference_names = [dbOrder[int(x)] for x in sorted(reference_indices)]
refFileName = writeReferences(reference_names, outPrefix)
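`extractReferences` filters the network to the chosen reference vertices and then copies it with `prune = True`, the GraphView-then-prune idiom from the Stack Overflow link in the code (a `GraphView` is a lazy filter; pruning materialises a standalone graph). A minimal toy example of the idiom, independent of PopPUNK data:

```python
import graph_tool.all as gt

g = gt.Graph(directed=False)
g.add_vertex(4)
g.add_edge_list([(0, 1), (1, 2), (2, 3)])

keep = g.new_vertex_property("bool")      # plays the role of reference_vertex
for i in (0, 1, 2):
    keep[g.vertex(i)] = True

g_ref = gt.GraphView(g, vfilt=keep)       # lazy filtered view
g_ref = gt.Graph(g_ref, prune=True)       # standalone copy; vertex 3 dropped
print(g_ref.num_vertices())               # prints 3
```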
@@ -258,12 +258,12 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
connections = []
self_comparison = True
vertex_labels = rlist

# check if self comparison
if rlist != qlist:
self_comparison = False
vertex_labels.append(qlist)

# identify edges
for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qlist, self = self_comparison)):
if assignment == within_label:
@@ -418,7 +418,7 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,

# Otherwise only calculate query-query distances for new clusters
else:

# identify potentially new lineages in list: unassigned is a list of queries with no hits
unassigned = set(qSeqs).difference(assigned)
query_indices = {k:v+ref_count for v,k in enumerate(qSeqs)}
@@ -467,7 +467,7 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
# finish by updating the network
G.add_vertex(len(qNames))
G.add_edge_list(new_edges)

# including the vertex ID property map
for i,q in enumerate(qSeqs):
G.vp.id[i + len(rlist)] = q
@@ -521,7 +521,7 @@ def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
component = component_assignments.a[isolate_index]
component_rank = component_frequency_ranks[component]
newClusters[component_rank].add(isolate_name)

oldNames = set()

if oldClusterFile != None: