Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Label order #152

Merged
merged 6 commits into from
Feb 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,19 +292,19 @@ def main():

rNames = seq_names
qNames = seq_names
refList, queryList, distMat = queryDatabase(rNames = rNames,
qNames = qNames,
dbPrefix = args.output,
queryPrefix = args.output,
klist = kmers,
self = True,
number_plot_fits = args.plot_fit,
threads = args.threads)
qcDistMat(distMat, refList, queryList, args.max_a_dist)
distMat = queryDatabase(rNames = rNames,
qNames = qNames,
dbPrefix = args.output,
queryPrefix = args.output,
klist = kmers,
self = True,
number_plot_fits = args.plot_fit,
threads = args.threads)
qcDistMat(distMat, rNames, qNames, args.max_a_dist)

# Save results
dists_out = args.output + "/" + os.path.basename(args.output) + ".dists"
storePickle(refList, queryList, True, distMat, dists_out)
storePickle(rNames, qNames, True, distMat, dists_out)

# Plot results
plot_scatter(distMat,
Expand Down
68 changes: 36 additions & 32 deletions PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def assign_query(dbFuncs,

from .prune_db import prune_distance_matrix

from .sketchlib import calculateQueryQueryDistances
from .sketchlib import addRandom

from .utils import storePickle
Expand Down Expand Up @@ -131,37 +130,37 @@ def assign_query(dbFuncs,
codon_phased = codon_phased,
calc_random = False)
# run query
refList, queryList, qrDistMat = queryDatabase(rNames = rNames,
qNames = qNames,
dbPrefix = ref_db,
queryPrefix = output,
klist = kmers,
self = False,
number_plot_fits = plot_fit,
threads = threads)
qrDistMat = queryDatabase(rNames = rNames,
qNames = qNames,
dbPrefix = ref_db,
queryPrefix = output,
klist = kmers,
self = False,
number_plot_fits = plot_fit,
threads = threads)
# QC distance matrix
qcPass = qcDistMat(qrDistMat, refList, queryList, max_a_dist)
qcPass = qcDistMat(qrDistMat, rNames, qNames, max_a_dist)

# Load the network based on supplied options
genomeNetwork, old_cluster_file = \
fetchNetwork(prev_clustering,
model,
refList,
rNames,
ref_graph = use_ref_graph,
core_only = core_only,
accessory_only = accessory_only)

if model.type == 'lineage':
# Assign lineages by calculating query-query information
addRandom(output, qNames, kmers, strand_preserved, overwrite, threads)
qlist1, qlist2, qqDistMat = queryDatabase(rNames = qNames,
qNames = qNames,
dbPrefix = output,
queryPrefix = output,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)
qqDistMat = queryDatabase(rNames = qNames,
qNames = qNames,
dbPrefix = output,
queryPrefix = output,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)
model.extend(qqDistMat, qrDistMat)

genomeNetwork = {}
Expand All @@ -182,18 +181,18 @@ def assign_query(dbFuncs,

isolateClustering[rank] = \
printClusters(genomeNetwork[rank],
refList + queryList,
rNames + qNames,
printCSV = False)

overall_lineage = createOverallLineage(model.ranks, isolateClustering)
writeClusterCsv(
output + "/" + os.path.basename(output) + '_lineages.csv',
refList + queryList,
refList + queryList,
rNames + qNames,
rNames + qNames,
overall_lineage,
output_format = 'phandango',
epiCsv = None,
queryNames = queryList,
queryNames = qNames,
suffix = '_Lineage')

else:
Expand All @@ -206,14 +205,14 @@ def assign_query(dbFuncs,
else:
weights = None
qqDistMat = \
addQueryToNetwork(dbFuncs, refList, queryList,
addQueryToNetwork(dbFuncs, rNames, qNames,
genomeNetwork, kmers,
queryAssignments, model, output, update_db,
strand_preserved,
weights = weights, threads = threads)

isolateClustering = \
{'combined': printClusters(genomeNetwork, refList + queryList,
{'combined': printClusters(genomeNetwork, rNames + qNames,
output + "/" + os.path.basename(output),
old_cluster_file,
external_clustering,
Expand Down Expand Up @@ -248,24 +247,29 @@ def assign_query(dbFuncs,

combined_seq, core_distMat, acc_distMat = \
update_distance_matrices(refList, rrDistMat,
queryList, qrDistMat,
qqDistMat, threads = threads)
assert combined_seq == refList + queryList
qNames, qrDistMat,
qqDistMat, threads = threads)
assert combined_seq == refList + qNames

# Get full distance matrix and save
complete_distMat = \
np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1),
pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1)))
storePickle(combined_seq, combined_seq, True, complete_distMat, dists_out)

# Copy model if needed
if output != model.outPrefix:
model.outPrefix = output
model.save()

# Clique pruning
if model.type != 'lineage':
dbOrder = refList + queryList
dbOrder = refList + qNames
newRepresentativesIndices, newRepresentativesNames, \
newRepresentativesFile, genomeNetwork = \
extractReferences(genomeNetwork, dbOrder, output, refList, threads = threads)
# intersection that maintains order
newQueries = [x for x in queryList if x in frozenset(newRepresentativesNames)]
newQueries = [x for x in qNames if x in frozenset(newRepresentativesNames)]

# could also have newRepresentativesNames in this diff (should be the same) - but want
# to ensure consistency with the network in case of bad input/bugs
Expand All @@ -280,12 +284,12 @@ def assign_query(dbFuncs,
genomeNetwork.save(output + "/" + os.path.basename(output) + '.refs_graph.gt', fmt = 'gt')
removeFromDB(output, output, names_to_remove)
os.rename(output + "/" + os.path.basename(output) + ".tmp.h5",
output + "/" + os.path.basename(output) + ".refs.h5")
output + "/" + os.path.basename(output) + ".refs.h5")

# ensure sketch and distMat order match
assert postpruning_combined_seq == refList + newQueries
else:
storePickle(refList, queryList, False, qrDistMat, dists_out)
storePickle(rNames, qNames, False, qrDistMat, dists_out)
if save_partial_query_graph:
if model.type == 'lineage':
genomeNetwork[min(model.ranks)].save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
Expand Down
32 changes: 16 additions & 16 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,14 +489,14 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,
if queryQuery:
sys.stderr.write("Calculating all query-query distances\n")
addRandom(queryDB, qList, kmers, strand_preserved, threads = threads)
qlist1, qlist2, qqDistMat = queryDatabase(rNames = qList,
qNames = qList,
dbPrefix = queryDB,
queryPrefix = queryDB,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)
qqDistMat = queryDatabase(rNames = qList,
qNames = qList,
dbPrefix = queryDB,
queryPrefix = queryDB,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)

queryAssignation = model.assign(qqDistMat)
for row_idx, (assignment, (ref, query)) in enumerate(zip(queryAssignation, listDistInts(qList, qList, self = True))):
Expand All @@ -519,14 +519,14 @@ def addQueryToNetwork(dbFuncs, rList, qList, G, kmers,

# use database construction methods to find links between unassigned queries
addRandom(queryDB, qList, kmers, strand_preserved, threads = threads)
qlist1, qlist2, qqDistMat = queryDatabase(rNames = list(unassigned),
qNames = list(unassigned),
dbPrefix = queryDB,
queryPrefix = queryDB,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)
qqDistMat = queryDatabase(rNames = list(unassigned),
qNames = list(unassigned),
dbPrefix = queryDB,
queryPrefix = queryDB,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)

queryAssignation = model.assign(qqDistMat)

Expand Down
46 changes: 2 additions & 44 deletions PopPUNK/sketchlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def joinDBs(db1, db2, output, update_random = None):
# Need to close before adding random
hdf_join.close()
if len(sequence_names) > 2:
sys.stderr.write("Updating random match chances")
sys.stderr.write("Updating random match chances\n")
pp_sketchlib.addRandom(join_prefix + ".tmp",
sequence_names,
kmer_size,
Expand Down Expand Up @@ -517,10 +517,6 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
(default = 0)

Returns:
refList (list)
Names of reference sequences
queryList (list)
Names of query sequences
distMat (numpy.array)
Core distances (column 0) and accessory distances (column 1) between
refList and queryList
Expand Down Expand Up @@ -568,46 +564,8 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
distMat = pp_sketchlib.queryDatabase(ref_db, query_db, rNames, qNames, klist,
True, False, threads, use_gpu, deviceid)

return(rNames, qNames, distMat)
return distMat

def calculateQueryQueryDistances(dbFuncs, qlist, kmers,
                                 queryDB, threads = 1):
    """Calculates distances between queries.

    Runs the backend ``queryDatabase`` function with the query list used as
    both the reference and the query set (``self = True``), reading from the
    same database prefix on both sides and with fit plotting disabled.

    Args:
        dbFuncs (dict)
            Backend functions from :func:`~PopPUNK.utils.setupDBFuncs`.
            Only the ``'queryDatabase'`` entry is used here.
        qlist (list)
            List of query names
        kmers (list)
            List of k-mer sizes
        queryDB (str)
            Query database location (used as both ``dbPrefix`` and
            ``queryPrefix``)
        threads (int)
            Number of threads passed on to the backend query function
            (default = 1)

    Returns:
        qlist1 (list)
            Ordered list of queries, as returned by the backend
        distMat (numpy.array)
            Query-query distances
    """
    query_database = dbFuncs['queryDatabase']

    # The backend returns (reference names, query names, distances); in a
    # self-comparison both name lists describe the same queries, so the
    # second one is discarded.
    qlist1, _, distMat = query_database(rNames = qlist,
                                        qNames = qlist,
                                        dbPrefix = queryDB,
                                        queryPrefix = queryDB,
                                        klist = kmers,
                                        self = True,
                                        number_plot_fits = 0,
                                        threads = threads)

    return qlist1, distMat

def sketchlibAssemblyQC(prefix, klist, qc_dict, strand_preserved, threads):
"""Calculates random match probability based on means of genomes
Expand Down
3 changes: 3 additions & 0 deletions PopPUNK/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,9 @@ def readRfile(rFile, oneSeq=False):
"Must contain sample name and file, tab separated\n")
sys.exit(1)

if "/" in rFields[0]:
sys.stderr.write("Sample names may not contain slashes\n")
sys.exit(1)
names.append(rFields[0])
sample_files = []
for sequence in rFields[1:]:
Expand Down
76 changes: 40 additions & 36 deletions PopPUNK/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,43 +342,47 @@ def generate_visualisations(query_db,
# Generate MST
mst_tree = None
mst_graph = None
if tree == 'mst' or tree == 'both':
existing_tree = None
if not overwrite:
existing_tree = load_tree(output, "MST", distances=mst_distances)
if existing_tree is None:
complete_distMat = \
np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1),
pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1)))
# Dense network may be slow
sys.stderr.write("Generating MST from dense distances (may be slow)\n")
G = constructNetwork(combined_seq,
combined_seq,
np.zeros(complete_distMat.shape[0]),
0,
weights=complete_distMat,
weights_type=mst_distances,
summarise=False)
mst_graph = generate_minimum_spanning_tree(G)
drawMST(mst_graph, output, isolateClustering, overwrite)
mst_tree = mst_to_phylogeny(mst_graph, isolateNameToLabel(combined_seq))
else:
mst_tree = existing_tree

# Generate NJ tree
nj_tree = None
if tree == 'nj' or tree == 'both':
existing_tree = None
if not overwrite:
existing_tree = load_tree(output, "NJ")
if existing_tree is None:
nj_tree = generate_nj_tree(core_distMat,
combined_seq,
output,
rapidnj,
threads = threads)
else:
nj_tree = existing_tree
if len(combined_seq) >= 3:
# MST tree
if tree == 'mst' or tree == 'both':
existing_tree = None
if not overwrite:
existing_tree = load_tree(output, "MST", distances=mst_distances)
if existing_tree is None:
complete_distMat = \
np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1),
pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1)))
# Dense network may be slow
sys.stderr.write("Generating MST from dense distances (may be slow)\n")
G = constructNetwork(combined_seq,
combined_seq,
np.zeros(complete_distMat.shape[0]),
0,
weights=complete_distMat,
weights_type=mst_distances,
summarise=False)
mst_graph = generate_minimum_spanning_tree(G)
drawMST(mst_graph, output, isolateClustering, overwrite)
mst_tree = mst_to_phylogeny(mst_graph, isolateNameToLabel(combined_seq))
else:
mst_tree = existing_tree

# Generate NJ tree
if tree == 'nj' or tree == 'both':
existing_tree = None
if not overwrite:
existing_tree = load_tree(output, "NJ")
if existing_tree is None:
nj_tree = generate_nj_tree(core_distMat,
combined_seq,
output,
rapidnj,
threads = threads)
else:
nj_tree = existing_tree
else:
sys.stderr.write("Fewer than three sequences, not drawing trees\n")

# Now have all the objects needed to generate selected visualisations
if microreact:
Expand Down
Loading