Skip to content

Commit

Permalink
Merge pull request #115 from johnlees/assign_hotfix
Browse files Browse the repository at this point in the history
Quick fix for query assign code
  • Loading branch information
johnlees authored Oct 20, 2020
2 parents e95a720 + 933cab0 commit 90e558c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 140 deletions.
39 changes: 20 additions & 19 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@

from .prune_db import prune_distance_matrix

from .sketchlib import calculateQueryQueryDistances

from .utils import setupDBFuncs
from .utils import storePickle
from .utils import readPickle
Expand Down Expand Up @@ -874,15 +872,18 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
# QC distance matrix
qcPass = qcDistMat(distMat, refList, queryList, max_a_dist)

# Calculate query-query distances
ordered_queryList = []

# Assign to strains or lineages, as requested
if assign_lineage:

# Assign lineages by calculating query-query information
ordered_queryList, query_distMat = calculateQueryQueryDistances(dbFuncs, refList, qNames,
kmers, output, use_mash, threads)
qlist1, qlist2, query_distMat = queryDatabase(rNames = qNames,
qNames = qNames,
dbPrefix = output,
queryPrefix = output,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)

else:
# Assign these distances as within or between strain
Expand All @@ -909,15 +910,15 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
core_only, accessory_only)

# Assign clustering by adding to network
ordered_queryList, query_distMat = addQueryToNetwork(dbFuncs, refList, queryList, q_files,
query_distMat = addQueryToNetwork(dbFuncs, refList, queryList, q_files,
genomeNetwork, kmers, queryAssignments, model, output, update_db,
use_mash, threads)

# if running simple query
print_full_clustering = False
if update_db:
print_full_clustering = True
isolateClustering = {'combined': printClusters(genomeNetwork, refList + ordered_queryList,
isolateClustering = {'combined': printClusters(genomeNetwork, refList + queryList,
output + "/" + os.path.basename(output),
old_cluster_file, external_clustering, print_full_clustering)}

Expand All @@ -933,13 +934,13 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
# Update the network + ref list
# only update network if assigning to strains
if full_db is False and assign_lineage is False:
dbOrder = refList + ordered_queryList
dbOrder = refList + queryList
newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, dbOrder, output, refList)
isolates_to_remove = set(dbOrder).difference(newRepresentativesNames)
newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
newQueries = [x for x in queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
else:
newQueries = ordered_queryList
newQueries = queryList

# Update the sketch database
if newQueries != queryList and use_mash:
Expand All @@ -959,7 +960,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
refList, refList_copy, self, ref_distMat = readPickle(distanceFiles)
combined_seq, core_distMat, acc_distMat = \
update_distance_matrices(refList, ref_distMat,
ordered_queryList, distMat,
queryList, distMat,
query_distMat, threads = threads)
complete_distMat = \
np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1),
Expand All @@ -972,7 +973,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
isolateClustering = cluster_into_lineages(complete_distMat,
rank_list, output,
combined_seq,
ordered_queryList,
queryList,
expected_lineage_name,
use_accessory,
threads)
Expand Down Expand Up @@ -1000,18 +1001,18 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
if microreact:
sys.stderr.write("Writing microreact output\n")
outputsForMicroreact(combined_seq, core_distMat, acc_distMat, isolateClustering, perplexity,
output, info_csv, rapidnj, ordered_queryList, overwrite)
output, info_csv, rapidnj, queryList, overwrite)
if phandango:
sys.stderr.write("Writing phandango output\n")
outputsForPhandango(combined_seq, core_distMat, isolateClustering, output, info_csv, rapidnj,
queryList = ordered_queryList, overwrite = overwrite, microreact = microreact)
queryList = queryList, overwrite = overwrite, microreact = microreact)
if grapetree:
sys.stderr.write("Writing grapetree output\n")
outputsForGrapetree(combined_seq, core_distMat, isolateClustering, output, info_csv, rapidnj,
queryList = ordered_queryList, overwrite = overwrite, microreact = microreact)
queryList = queryList, overwrite = overwrite, microreact = microreact)
if cytoscape:
sys.stderr.write("Writing cytoscape output\n")
outputsForCytoscape(genomeNetwork, isolateClustering, output, info_csv, ordered_queryList)
sys.stderr.write("Writing cytoscape output\n")
outputsForCytoscape(genomeNetwork, isolateClustering, output, info_csv, queryList)

else:
sys.stderr.write("Need to provide both a reference database with --ref-db and "
Expand Down
95 changes: 26 additions & 69 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from tempfile import mkstemp, mkdtemp
from collections import defaultdict, Counter

from .sketchlib import calculateQueryQueryDistances

from .utils import iterDistRows
from .utils import listDistInts
from .utils import readIsolateTypeFromCsv
Expand Down Expand Up @@ -353,8 +351,6 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
(default = 1)
Returns:
qlist1 (list)
Ordered list of queries
distMat (numpy.array)
Query-query distances
"""
Expand All @@ -369,90 +365,53 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
assigned = set()

# These are returned
qlist1 = None
distMat = None

# Set up query names
if use_mash == True:
# mash must use sequence file names for both testing for
# assignment and for generating a new database
rNames = None
qNames = qList
else:
rNames = qList
qNames = rNames

# identify query sequence files
qSeqs = []
queryFiles = {}
with open(qFile, 'r') as qfile:
for line in qfile.readlines():
info = line.rstrip().split()
if info[0] in qNames:
qSeqs.append(info[1])
queryFiles[info[0]] = info[1]
qqDistMat = None

# store links for each query in a list of edge tuples
ref_count = len(rlist)
for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qNames, self = False)):
for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qList, self = False)):
if assignment == model.within_label:
# query index needs to be adjusted for existing vertices in network
new_edges.append((ref, query + ref_count))
assigned.add(qNames[query])
assigned.add(qList[query])

# Calculate all query-query distances too, if updating database
if queryQuery:
sys.stderr.write("Calculating all query-query distances\n")
qlist1, distMat = calculateQueryQueryDistances(dbFuncs,
rNames,
qNames,
kmers,
queryDB,
use_mash,
threads)

queryAssignation = model.assign(distMat)
for assignment, (ref, query) in zip(queryAssignation, listDistInts(qNames, qNames, self = True)):
qlist1, qlist2, qqDistMat = queryDatabase(rNames = qList,
qNames = qList,
dbPrefix = queryDB,
queryPrefix = queryDB,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)

queryAssignation = model.assign(qqDistMat)
for assignment, (ref, query) in zip(queryAssignation, listDistInts(qList, qList, self = True)):
if assignment == model.within_label:
new_edges.append((ref + ref_count, query + ref_count))

# Otherwise only calculate query-query distances for new clusters
else:

# identify potentially new lineages in list: unassigned is a list of queries with no hits
unassigned = set(qSeqs).difference(assigned)
query_indices = {k:v+ref_count for v,k in enumerate(qSeqs)}
unassigned = set(qList).difference(assigned)
query_indices = {k:v+ref_count for v,k in enumerate(qList)}
# process unassigned query sequences, if there are any
if len(unassigned) > 1:
sys.stderr.write("Found novel query clusters. Calculating distances between them:\n")

# write unassigned queries to file as if a list of references
tmpDirName = mkdtemp(prefix=os.path.basename(queryDB), suffix="_tmp", dir="./")
tmpHandle, tmpFile = mkstemp(prefix=os.path.basename(queryDB), suffix="_tmp", dir=tmpDirName)
with open(tmpFile, 'w') as tFile:
for query in unassigned:
if isinstance(queryFiles[query], list):
seqFiles = "\t".join(queryFiles[query])
elif isinstance(queryFiles[query], str):
seqFiles = queryFiles[query]
else:
raise RuntimeError("Error with formatting of q-file")
tFile.write(query + '\t' + seqFiles + '\n')
sys.stderr.write("Found novel query clusters. Calculating distances between them.\n")

# use database construction methods to find links between unassigned queries
sketchSize = readDBParams(queryDB, kmers, None)[1]
constructDatabase(tmpFile, kmers, sketchSize, tmpDirName, True, threads, False)

qlist1, qlist2, distMat = queryDatabase(rNames = list(unassigned),
qlist1, qlist2, qqDistMat = queryDatabase(rNames = list(unassigned),
qNames = list(unassigned),
dbPrefix = tmpDirName,
queryPrefix = tmpDirName,
dbPrefix = queryDB,
queryPrefix = queryDB,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)

queryAssignation = model.assign(distMat)
queryAssignation = model.assign(qqDistMat)

# identify any links between queries and store in the same links dict
# links dict now contains lists of links both to original database and new queries
Expand All @@ -461,18 +420,15 @@ def addQueryToNetwork(dbFuncs, rlist, qList, qFile, G, kmers,
if assignment == model.within_label:
new_edges.append((query_indices[query1], query_indices[query2]))

# remove directory
shutil.rmtree(tmpDirName)

# finish by updating the network
G.add_vertex(len(qNames))
G.add_vertex(len(qList))
G.add_edge_list(new_edges)

# including the vertex ID property map
for i,q in enumerate(qSeqs):
for i,q in enumerate(qList):
G.vp.id[i + len(rlist)] = q

return qlist1, distMat
return qqDistMat

def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
externalClusterCSV = None, printRef = True, printCSV = True, clustering_type = 'combined'):
Expand Down Expand Up @@ -637,7 +593,8 @@ def printExternalClusters(newClusters, extClusterFile, outPrefix,
d = defaultdict(list)

# Read in external clusters
readIsolateTypeFromCsv(clustCSV, mode = 'external', return_dict = False)
extClusters = \
readIsolateTypeFromCsv(extClusterFile, mode = 'external', return_dict = False)

# Go through each cluster (as defined by poppunk) and find the external
# clusters that had previously been assigned to any sample in the cluster
Expand Down
52 changes: 0 additions & 52 deletions PopPUNK/sketchlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,55 +468,3 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
True, False, threads, use_gpu, deviceid)

return(rNames, qNames, distMat)

def calculateQueryQueryDistances(dbFuncs, rlist, qlist, kmers,
queryDB, use_mash = False, threads = 1):
"""Calculates distances between queries.
Args:
dbFuncs (list)
List of backend functions from :func:`~PopPUNK.utils.setupDBFuncs`
rlist (list)
List of reference names
qlist (list)
List of query names
kmers (list)
List of k-mer sizes
queryDB (str)
Query database location
use_mash (bool)
Use the mash backend
threads (int)
Number of threads to use if new db created
(default = 1)
Returns:
qlist1 (list)
Ordered list of queries
distMat (numpy.array)
Query-query distances
"""

constructDatabase = dbFuncs['constructDatabase']
queryDatabase = dbFuncs['queryDatabase']
readDBParams = dbFuncs['readDBParams']

# Set up query names
if use_mash == True:
rNames = None
qNames = qlist
else:
rNames = qlist
qNames = rNames

# Calculate all query-query distances too, if updating database
qlist1, qlist2, distMat = queryDatabase(rNames = rNames,
qNames = qNames,
dbPrefix = queryDB,
queryPrefix = queryDB,
klist = kmers,
self = True,
number_plot_fits = 0,
threads = threads)

return qlist1, distMat

0 comments on commit 90e558c

Please sign in to comment.