Skip to content

Commit

Permalink
Pick references in a way which maintains connected components
Browse files Browse the repository at this point in the history
Add extra references taken from the paths, in the original graph, between connected components (CCs) that were split when only references were kept. This ensures clusters stay consistent when a reference-only graph is used with assign query (and prevents warnings).

Closes #50
  • Loading branch information
johnlees committed Aug 23, 2019
1 parent 2f4cbd4 commit 4155aac
Showing 1 changed file with 64 additions and 23 deletions.
87 changes: 64 additions & 23 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
An updated list of the reference names
"""
if existingRefs == None:
references = []
references = set()
else:
references = existingRefs
references = set(existingRefs)

# extract cliques from network
cliques = list(nx.find_cliques(G))
Expand All @@ -120,10 +120,43 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
alreadyRepresented = 1
break
if alreadyRepresented == 0:
references.append(clique[0])
references.add(clique[0])

# Find any clusters which are represented by multiple references
clusters = printClusters(G, printCSV=False)
ref_clusters = set()
multi_ref_clusters = set()
for reference in references:
if clusters[reference] in ref_clusters:
multi_ref_clusters.add(clusters[reference])
else:
ref_clusters.add(clusters[reference])

# Check if these multi reference components have been split
if len(multi_ref_clusters) > 0:
# Initial reference graph
ref_G = G.copy()
ref_G.remove_nodes_from(set(ref_G.nodes).difference(references))

for multi_ref_cluster in multi_ref_clusters:
# Get a list of nodes that need to be in the same component
check = []
for reference in references:
if clusters[reference] == multi_ref_cluster:
check.append(reference)

# Pairwise check that nodes are in same component
for i in range(len(check)):
component = nx.node_connected_component(ref_G, check[i])
for j in range(i, len(check)):
# Add intermediate nodes
if check[j] not in component:
new_path = nx.shortest_path(G, check[i], check[j])
for node in new_path:
references.add(node)

# Order found references as in mash sketch files
references = [x for x in mashOrder if x in frozenset(references)]
references = [x for x in mashOrder if x in references]
refFileName = writeReferences(references, outPrefix)
return references, refFileName

Expand Down Expand Up @@ -328,7 +361,8 @@ def addQueryToNetwork(rlist, qlist, qfile, G, kmers, assignments, model,

return qlist1, distMat

def printClusters(G, outPrefix, oldClusterFile = None, externalClusterCSV = None, printRef = True):
def printClusters(G, outPrefix = "_clusters.csv", oldClusterFile = None,
externalClusterCSV = None, printRef = True, printCSV = True):
"""Get cluster assignments
Also writes assignments to a CSV file
Expand All @@ -338,7 +372,9 @@ def printClusters(G, outPrefix, oldClusterFile = None, externalClusterCSV = None
Network used to define clusters (from :func:`~constructNetwork` or
:func:`~addQueryToNetwork`)
outPrefix (str)
Prefix for output CSV (_clusters.csv)
Prefix for output CSV
Default = "_clusters.csv"
oldClusterFile (str)
CSV with previous cluster assignments.
Pass to ensure consistency in cluster assignment name.
Expand All @@ -352,6 +388,10 @@ def printClusters(G, outPrefix, oldClusterFile = None, externalClusterCSV = None
printRef (bool)
If false, print only query sequences in the output
Default = True
printCSV (bool)
Print results to file
Default = True
Returns:
Expand Down Expand Up @@ -432,23 +472,24 @@ def printClusters(G, outPrefix, oldClusterFile = None, externalClusterCSV = None
clustering[cluster_member] = cls_id

# print clustering to file
outFileName = outPrefix + "_clusters.csv"
with open(outFileName, 'w') as cluster_file:
cluster_file.write("Taxon,Cluster\n")

# sort the clusters by frequency - define a list with a custom sort order
# first line gives tuples e.g. (1, 28), (2, 17) - cluster 1 has 28 members, cluster 2 has 17 members
# second line takes first element - the cluster IDs sorted by frequency
freq_order = sorted(dict(Counter(clustering.values())).items(), key=operator.itemgetter(1), reverse=True)
freq_order = [x[0] for x in freq_order]

# iterate through cluster dictionary sorting by value using above custom sort order
for cluster_member, cluster_name in sorted(clustering.items(), key=lambda i:freq_order.index(i[1])):
if printRef or cluster_member not in oldNames:
cluster_file.write(",".join((cluster_member, str(cluster_name))) + "\n")

if externalClusterCSV is not None:
printExternalClusters(newClusters, externalClusterCSV, outPrefix, oldNames, printRef)
if printCSV:
outFileName = outPrefix + "_clusters.csv"
with open(outFileName, 'w') as cluster_file:
cluster_file.write("Taxon,Cluster\n")

# sort the clusters by frequency - define a list with a custom sort order
# first line gives tuples e.g. (1, 28), (2, 17) - cluster 1 has 28 members, cluster 2 has 17 members
# second line takes first element - the cluster IDs sorted by frequency
freq_order = sorted(dict(Counter(clustering.values())).items(), key=operator.itemgetter(1), reverse=True)
freq_order = [x[0] for x in freq_order]

# iterate through cluster dictionary sorting by value using above custom sort order
for cluster_member, cluster_name in sorted(clustering.items(), key=lambda i:freq_order.index(i[1])):
if printRef or cluster_member not in oldNames:
cluster_file.write(",".join((cluster_member, str(cluster_name))) + "\n")

if externalClusterCSV is not None:
printExternalClusters(newClusters, externalClusterCSV, outPrefix, oldNames, printRef)

return(clustering)

Expand Down

0 comments on commit 4155aac

Please sign in to comment.