From 7260e5025e3422cca7206fb2b348e5d0f451f01b Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 7 May 2020 17:59:27 +0100
Subject: [PATCH 01/56] First addition of graph-tool code

---
 PopPUNK/network.py | 35 +++++++++++++++++++++++++++--------
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 4930bef4..d154b16c 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -13,6 +13,7 @@
 import shutil
 import subprocess
 import networkx as nx
+import graph_tool.all as gt
 import numpy as np
 import pandas as pd
 from tempfile import mkstemp, mkdtemp
@@ -21,6 +22,7 @@
 from .sketchlib import calculateQueryQueryDistances
 
 from .utils import iterDistRows
+from .utils import listDistInts
 from .utils import readIsolateTypeFromCsv
 from .utils import readRfile
 
@@ -226,20 +228,32 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
         G (networkx.Graph)
             The resulting network
     """
+    # data structures
     connections = []
-    for assignment, (ref, query) in zip(assignments, iterDistRows(rlist, qlist, self=True)):
+    self_comparison = True
+    num_vertices = len(rlist)
+    
+    # check if self comparison
+    if rlist != qlist:
+        self_comparison = False
+        num_vertices = num_vertices + len(qlist)
+    
+    # identify edges
+    for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qlist, self = self_comparison)):
         if assignment == within_label:
             connections.append((ref, query))
 
+    # warn if the network being built is very dense
     density_proportion = len(connections) / (0.5 * (len(rlist) * (len(rlist) + 1)))
     if density_proportion > 0.4 or len(connections) > 500000:
         sys.stderr.write("Warning: trying to create very dense network\n")
 
     # build the graph
-    G = nx.Graph()
-    G.add_nodes_from(rlist)
-    for connection in connections:
-        G.add_edge(*connection)
+    G = gt.Graph(directed = False)
+    G.add_vertex(num_vertices)
+    G.add_edge_list(connections)
 
     # give some summaries
     if summarise:
@@ -269,9 +283,14 @@ def networkSummary(G):
         score (float)
             A score of network fit, given by :math:`\mathrm{transitivity} * (1-\mathrm{density})`
     """
-    components = nx.number_connected_components(G)
-    density = nx.density(G)
-    transitivity = nx.transitivity(G)
+    comp, hist = gt.label_components(G)
+    components = len(set(comp.a))
+#    density = nx.density(G)
+    density = len(list(G.edges()))/(len(list(G.vertices())) * (len(list(G.vertices())) - 1))
+#    transitivity = nx.transitivity(G)
+    transitivity = gt.global_clustering(G)[0]
+    print('Transitivity: '+ str(gt.global_clustering(G)))
+    quit()
     score = transitivity * (1-density)
 
     return(components, density, transitivity, score)
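A minimal sketch of the construction pattern this patch adopts (sample names and
edges here are illustrative, not PopPUNK's API): graph-tool addresses vertices by
integer index, so edges are added as index pairs and the index-to-name mapping is
kept in an external list.

    import graph_tool.all as gt

    names = ['s1', 's2', 's3']        # external name list, cf. rlist
    edges = [(0, 1), (1, 2)]          # within-strain index pairs

    G = gt.Graph(directed = False)
    G.add_vertex(len(names))          # one vertex per sample
    G.add_edge_list(edges)            # bulk edge insertion, no Python loop
    # vertex i corresponds to names[i]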

From 09ebf6d2851a1dc23059124e67af6da92b58a1b6 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 7 May 2020 18:01:32 +0100
Subject: [PATCH 02/56] Include listDistInts routine in utils.py

---
 PopPUNK/utils.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index 9618c945..4d396eb8 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -159,6 +159,42 @@ def iterDistRows(refSeqs, querySeqs, self=True):
             for ref in refSeqs:
                 yield(ref, query)
 
+def listDistInts(refSeqs, querySeqs, self=True):
+    """Gets the ref and query index for each row of the distance matrix
+
+    Returns a list of tuples of ref and query integer indices by row.
+
+    Args:
+        refSeqs (list)
+            List of reference sequence names.
+        querySeqs (list)
+            List of query sequence names.
+        self (bool)
+            Whether a self-comparison, used when constructing a database.
+            Requires refSeqs == querySeqs
+            Default is True
+    Returns:
+        comparisons (list of tuples)
+            Ref and query indices for each distMat row.
+    """
+    n = 0
+    num_ref = len(refSeqs)
+    num_query = len(querySeqs)
+    if self:
+        comparisons = [(0,0)] * ((num_ref * (num_ref - 1)) // 2)
+        if refSeqs != querySeqs:
+            raise RuntimeError('refSeqs must equal querySeqs for db building (self = True)')
+        for i in range(num_ref):
+            for j in range(i + 1, num_ref):
+                comparisons[n] = (j, i)
+    else:
+        comparisons = [(0,0)] * (len(refSeqs) * len(querySeqs))
+        for i in range(num_query):
+            for j in range(num_ref):
+                comparisons[n] = (j, i)
+                
+    return comparisons
+
 def writeTmpFile(fileList):
     """Writes a list to a temporary file. Used for turning variable into mash
     input.
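For reference, the enumeration listDistInts is intended to produce for a
self-comparison: the upper triangle of the distance matrix, one (j, i) integer
pair per row. A standalone sketch with toy values:

    num_ref = 3                       # e.g. three reference samples
    comparisons = []
    for i in range(num_ref):
        for j in range(i + 1, num_ref):
            comparisons.append((j, i))
    print(comparisons)                # [(1, 0), (2, 0), (2, 1)]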

From 9eb7d8413cfea0947d3cf0484a04b7609e6262ca Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 7 May 2020 22:00:35 +0100
Subject: [PATCH 03/56] Functioning model refinement with graph-tools

---
 PopPUNK/__main__.py |  8 ++++++--
 PopPUNK/network.py  | 22 ++++++++++++----------
 PopPUNK/utils.py    |  2 ++
 3 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 40909770..0c5a0adf 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -8,6 +8,7 @@
 # additional
 import numpy as np
 import networkx as nx
+import graph_tool.all as gt
 import subprocess
 
 # import poppunk package
@@ -446,11 +447,12 @@ def main():
         genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
 
         # Ensure all in dists are in final network
-        networkMissing = set(refList).difference(list(genomeNetwork.nodes()))
+        networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
         if len(networkMissing) > 0:
             sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
 
         isolateClustering = {fit_type: printClusters(genomeNetwork,
+                                                    refList, # assume no rlist+qlist?
                                                      args.output + "/" + os.path.basename(args.output),
                                                      externalClusterCSV = args.external_clustering)}
 
@@ -461,6 +463,7 @@ def main():
                 indivAssignments = model.assign(distMat, slope)
                 indivNetworks[dist_type] = constructNetwork(refList, queryList, indivAssignments, model.within_label)
                 isolateClustering[dist_type] = printClusters(indivNetworks[dist_type],
+                                                refList,
                                                  args.output + "/" + os.path.basename(args.output) + "_" + dist_type,
                                                  externalClusterCSV = args.external_clustering)
                 nx.write_gpickle(indivNetworks[dist_type], args.output + "/" + os.path.basename(args.output) +
@@ -806,7 +809,8 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             print_full_clustering = False
             if update_db:
                 print_full_clustering = True
-            isolateClustering = {'combined': printClusters(genomeNetwork, output + "/" + os.path.basename(output),
+            isolateClustering = {'combined': printClusters(genomeNetwork, refList + ordered_queryList,
+                                                            output + "/" + os.path.basename(output),
                                                             old_cluster_file, external_clustering, print_full_clustering)}
 
         # Update DB as requested
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index d154b16c..07ffff7b 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -16,6 +16,7 @@
 import graph_tool.all as gt
 import numpy as np
 import pandas as pd
+from scipy.stats import rankdata
 from tempfile import mkstemp, mkdtemp
 from collections import defaultdict, Counter
 
@@ -283,14 +284,10 @@ def networkSummary(G):
         score (float)
             A score of network fit, given by :math:`\mathrm{transitivity} * (1-\mathrm{density})`
     """
-    comp, hist = gt.label_components(G)
-    components = len(set(comp.a))
-#    density = nx.density(G)
-    density = len(list(G.edges()))/(len(list(G.vertices())) * (len(list(G.vertices())) - 1))
-#    transitivity = nx.transitivity(G)
+    component_assignments, component_frequencies = gt.label_components(G)
+    components = len(component_frequencies)
+    density = len(list(G.edges()))/(0.5 * len(list(G.vertices())) * (len(list(G.vertices())) - 1))
     transitivity = gt.global_clustering(G)[0]
-    print('Transitivity: '+ str(gt.global_clustering(G)))
-    quit()
     score = transitivity * (1-density)
 
     return(components, density, transitivity, score)
@@ -433,7 +430,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
 
     return qlist1, distMat
 
-def printClusters(G, outPrefix = "_clusters.csv", oldClusterFile = None,
+def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
                   externalClusterCSV = None, printRef = True, printCSV = True, clustering_type = 'combined'):
     """Get cluster assignments
 
@@ -472,13 +469,18 @@ def printClusters(G, outPrefix = "_clusters.csv", oldClusterFile = None,
     if oldClusterFile == None and printRef == False:
         raise RuntimeError("Trying to print query clusters with no query sequences")
 
-    newClusters = sorted(nx.connected_components(G), key=len, reverse=True)
+    # get a sorted list of component assignments
+    newClusters = {}
+    component_assignments, component_frequencies = gt.label_components(G)
+    component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)
+    for n,v in enumerate(rlist):
+        newClusters[rlist[n]] = component_frequency_ranks[component_assignments.a]
+    
     oldNames = set()
 
     if oldClusterFile != None:
         oldAllClusters = readIsolateTypeFromCsv(oldClusterFile, mode = 'external', return_dict = False)
         oldClusters = oldAllClusters[list(oldAllClusters.keys())[0]]
-        print('oldCluster is ' + str(oldClusters))
         new_id = len(oldClusters.keys()) + 1 # 1-indexed
         while new_id in oldClusters:
             new_id += 1 # in case clusters have been merged
diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index 4d396eb8..ef3ac112 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -187,11 +187,13 @@ def listDistInts(refSeqs, querySeqs, self=True):
         for i in range(num_ref):
             for j in range(i + 1, num_ref):
                 comparisons[n] = (j, i)
+                n = n + 1
     else:
         comparisons = [(0,0)] * (len(refSeqs) * len(querySeqs))
         for i in range(num_query):
             for j in range(num_ref):
                 comparisons[n] = (j, i)
+                n = n + 1
                 
     return comparisons
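The rankdata step in printClusters converts component sizes into 0-based cluster
numbers ordered largest-first. A self-contained sketch with illustrative sizes:

    import numpy as np
    from scipy.stats import rankdata

    # component sizes, as from gt.label_components (illustrative values)
    component_frequencies = np.array([3, 10, 5])
    component_frequency_ranks = len(component_frequencies) - \
        rankdata(component_frequencies, method = 'ordinal').astype(int)
    print(component_frequency_ranks)  # [2 0 1]: largest component is cluster 0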
 

From 09aacafb551f2e1a84e114b6cab4a141792e9c81 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Fri, 8 May 2020 14:20:00 +0100
Subject: [PATCH 04/56] Update extraction of references from network

---
 PopPUNK/__main__.py |  9 +++--
 PopPUNK/network.py  | 95 ++++++++++++++++++++++++++-------------------
 2 files changed, 60 insertions(+), 44 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 0c5a0adf..344dcb92 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -466,8 +466,9 @@ def main():
                                                 refList,
                                                  args.output + "/" + os.path.basename(args.output) + "_" + dist_type,
                                                  externalClusterCSV = args.external_clustering)
-                nx.write_gpickle(indivNetworks[dist_type], args.output + "/" + os.path.basename(args.output) +
-                                                           "_" + dist_type + '_graph.gpickle')
+                indivNetworks[dist_type].save(args.output + "/" + os.path.basename(args.output) +
+                "_" + dist_type + '_graph.gt', fmt = 'gt')
+
             if args.core_only:
                 fit_type = 'core'
                 genomeNetwork = indivNetworks['core']
@@ -534,7 +535,7 @@ def main():
                                 args.estimated_length, True, args.threads, True) # overwrite old db
                 os.remove(dummyRefFile)
 
-        nx.write_gpickle(genomeNetwork, args.output + "/" + os.path.basename(args.output) + '_graph.gpickle')
+        genomeNetwork.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')
 
     #******************************#
     #*                            *#
@@ -829,7 +830,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
                 newRepresentativesNames, newRepresentativesFile = extractReferences(genomeNetwork, mashOrder, output, refList)
                 genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(newRepresentativesNames))
                 newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
-                nx.write_gpickle(genomeNetwork, output + "/" + os.path.basename(output) + '_graph.gpickle')
+                genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
             else:
                 newQueries = ordered_queryList
 
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 07ffff7b..b6391cb7 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -59,24 +59,24 @@ def fetchNetwork(network_dir, model, refList,
     # If a refined fit, may use just core or accessory distances
     if core_only and model.type == 'refine':
         model.slope = 0
-        network_file = network_dir + "/" + os.path.basename(network_dir) + '_core_graph.gpickle'
+        network_file = network_dir + "/" + os.path.basename(network_dir) + '_core_graph.gt'
         cluster_file = network_dir + "/" + os.path.basename(network_dir) + '_core_clusters.csv'
     elif accessory_only and model.type == 'refine':
         model.slope = 1
-        network_file = network_dir + "/" + os.path.basename(network_dir) + '_accessory_graph.gpickle'
+        network_file = network_dir + "/" + os.path.basename(network_dir) + '_accessory_graph.gt'
         cluster_file = network_dir + "/" + os.path.basename(network_dir) + '_accessory_clusters.csv'
     else:
-        network_file = network_dir + "/" + os.path.basename(network_dir) + '_graph.gpickle'
+        network_file = network_dir + "/" + os.path.basename(network_dir) + '_graph.gt'
         cluster_file = network_dir + "/" + os.path.basename(network_dir) + '_clusters.csv'
         if core_only or accessory_only:
             sys.stderr.write("Can only do --core-only or --accessory-only fits from "
                              "a refined fit. Using the combined distances.\n")
 
-    genomeNetwork = nx.read_gpickle(network_file)
+    genomeNetwork = gt.load_graph(network_file)
     sys.stderr.write("Network loaded: " + str(genomeNetwork.number_of_nodes()) + " samples\n")
 
     # Ensure all in dists are in final network
-    networkMissing = set(refList).difference(list(genomeNetwork.nodes()))
+    networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
     if len(networkMissing) > 0:
         sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
 
@@ -110,9 +110,9 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
         references = set(existingRefs)
 
     # extract cliques from network
-    cliques = list(nx.find_cliques(G))
+    cliques = [c.tolist() for c in gt.max_cliques(G)]
     # order list by size of clique
-    cliques.sort(key = len, reverse=True)
+    cliques.sort(key = len, reverse = True)
     # iterate through cliques
     for clique in cliques:
         alreadyRepresented = 0
@@ -124,40 +124,53 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
             references.add(clique[0])
 
     # Find any clusters which are represented by multiple references
-    clusters = printClusters(G, printCSV=False)
-    ref_clusters = set()
-    multi_ref_clusters = set()
-    for reference in references:
-        if clusters[reference] in ref_clusters:
-            multi_ref_clusters.add(clusters[reference])
+    # First get cluster assignments
+    clusters = printClusters(G, mashOrder, printCSV=False)
+    # Construct a dict of sets for each cluster
+    ref_clusters = [set() for c in set(clusters.values())]
+    # Iterate through references
+    for reference_index in references:
+        # Add references to the appropriate cluster
+        ref_clusters[clusters[mashOrder[reference_index]]].add(reference_index)
+
+    # Make a mask of the existing network to only retain references
+    # Use a vertex filter
+    reference_vertex = G.new_vertex_property('bool')
+    for n,vertex in enumerate(G.vertices()):
+        if n in references:
+            reference_vertex[vertex] = True
         else:
-            ref_clusters.add(clusters[reference])
-
-    # Check if these multi reference components have been split
-    if len(multi_ref_clusters) > 0:
-        # Initial reference graph
-        ref_G = G.copy()
-        ref_G.remove_nodes_from(set(ref_G.nodes).difference(references))
-
-        for multi_ref_cluster in multi_ref_clusters:
-            # Get a list of nodes that need to be in the same component
-            check = []
-            for reference in references:
-                if clusters[reference] == multi_ref_cluster:
-                    check.append(reference)
-
-            # Pairwise check that nodes are in same component
+            reference_vertex[vertex] = False
+    G.set_vertex_filter(reference_vertex)
+    # Calculate component membership for reference graph
+    reference_components, reference_component_frequencies = gt.label_components(G)
+    # Record to which components references belong in the reference graph
+    reference_cluster = {}
+    for reference_index in references:
+        reference_cluster[reference_index] = reference_components.a[reference_index]
+
+    # Unset mask on network for shortest path calculations
+    G.set_vertex_filter(None)
+
+    # Check if multi-reference components have been split as a validation test
+    # First iterate through clusters
+    for cluster in ref_clusters:
+        # Identify multi-reference clusters by this length
+        if len(cluster) > 1:
+            check = list(cluster)
+            # check if these are still in the same component in the reference graph
             for i in range(len(check)):
-                component = nx.node_connected_component(ref_G, check[i])
+                component_i = reference_cluster[check[i]]
                 for j in range(i, len(check)):
                     # Add intermediate nodes
-                    if check[j] not in component:
-                        new_path = nx.shortest_path(G, check[i], check[j])
-                        for node in new_path:
-                            references.add(node)
-
+                    component_j = reference_cluster[check[j]]
+                    if component_i != component_j:
+                        vertex_list, edge_list = gt.shortest_path(G, check[i], check[j])
+                        for vertex in vertex_list:
+                            references.add(vertex)
+    
     # Order found references as in mash sketch files
-    references = [x for x in mashOrder if x in references]
+    references = [mashOrder[x] for x in sorted(references)]
     refFileName = writeReferences(references, outPrefix)
     return references, refFileName
 
@@ -470,12 +483,14 @@ def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
         raise RuntimeError("Trying to print query clusters with no query sequences")
 
     # get a sorted list of component assignments
-    newClusters = {}
     component_assignments, component_frequencies = gt.label_components(G)
     component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)
-    for n,v in enumerate(rlist):
-        newClusters[rlist[n]] = component_frequency_ranks[component_assignments.a]
-    
+    newClusters = [set() for rank in range(len(component_frequency_ranks))]
+    for isolate_index, isolate_name in enumerate(rlist):
+        component = component_assignments.a[isolate_index]
+        component_rank = component_frequency_ranks[component]
+        newClusters[component_rank].add(isolate_name)
+        
     oldNames = set()
 
     if oldClusterFile != None:
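The masking idiom used above, in isolation: a boolean vertex property map passed
to set_vertex_filter hides the unmarked vertices from subsequent algorithms, and
passing None restores the full graph. Toy graph for illustration:

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(4)
    G.add_edge_list([(0, 1), (1, 2), (2, 3)])

    reference_vertex = G.new_vertex_property('bool')
    for vertex in G.vertices():
        reference_vertex[vertex] = int(vertex) in (0, 3)   # illustrative refs

    G.set_vertex_filter(reference_vertex)      # only vertices 0 and 3 visible
    components, frequencies = gt.label_components(G)
    G.set_vertex_filter(None)                  # unset before shortest_path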

From e2476af7e022e777ab1985f8b84adcd1c4639888 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Fri, 8 May 2020 14:32:58 +0100
Subject: [PATCH 05/56] More efficient extraction of references

---
 PopPUNK/__main__.py | 9 +++++----
 PopPUNK/network.py  | 6 +++---
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 344dcb92..0c1699b4 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -519,10 +519,11 @@ def main():
         #******************************# 
         # extract limited references from clique by default
         if not args.full_db:
-            newReferencesNames, newReferencesFile = extractReferences(genomeNetwork, refList, args.output)
-            nodes_to_remove = set(refList).difference(newReferencesNames)
-            genomeNetwork.remove_nodes_from(nodes_to_remove)
-            prune_distance_matrix(refList, nodes_to_remove, distMat,
+            newReferencesIndices, newReferencesNames, newReferencesFile = extractReferences(genomeNetwork, refList, args.output)
+            nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
+            genomeNetwork.remove_vertex(list(nodes_to_remove))
+            names_to_remove = [refList[n] for n in nodes_to_remove]
+            prune_distance_matrix(refList, names_to_remove, distMat,
                                   args.output + "/" + os.path.basename(args.output) + ".dists")
             
             # With mash, the sketches are actually removed from the database
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index b6391cb7..bd8e51bf 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -170,9 +170,9 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
                             references.add(vertex)
     
     # Order found references as in mash sketch files
-    references = [mashOrder[x] for x in sorted(references)]
-    refFileName = writeReferences(references, outPrefix)
-    return references, refFileName
+    reference_names = [mashOrder[x] for x in sorted(references)]
+    refFileName = writeReferences(reference_names, outPrefix)
+    return references, reference_names, refFileName
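One caveat with the remove_vertex pruning used by the caller in __main__.py:
graph-tool reindexes the remaining vertices after removal, so integer indices
held from before the pruning are no longer valid; hence indices are mapped to
names via the external refList rather than by querying the pruned graph. A toy
sketch:

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(5)
    G.add_edge_list([(0, 1), (2, 3), (3, 4)])

    G.remove_vertex([1, 2])      # accepts a list of indices
    print(G.num_vertices())      # 3; survivors are reindexed to 0..2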
 
 def writeReferences(refList, outPrefix):
     """Writes chosen references to file

From 027b36d9ec4a46d3975efc1a7094beb0f1f611d1 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Fri, 8 May 2020 14:43:40 +0100
Subject: [PATCH 06/56] Remove redundant imports and fix output

---
 PopPUNK/__main__.py  | 1 -
 PopPUNK/mash.py      | 1 -
 PopPUNK/network.py   | 1 -
 PopPUNK/plot.py      | 3 +--
 PopPUNK/sketchlib.py | 1 -
 5 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 0c1699b4..e78eef4d 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -7,7 +7,6 @@
 import sys
 # additional
 import numpy as np
-import networkx as nx
 import graph_tool.all as gt
 import subprocess
 
diff --git a/PopPUNK/mash.py b/PopPUNK/mash.py
index f8e9bc3d..9b01864a 100644
--- a/PopPUNK/mash.py
+++ b/PopPUNK/mash.py
@@ -18,7 +18,6 @@
 from glob import glob
 from random import sample
 import numpy as np
-import networkx as nx
 from scipy import optimize
 try:
     from multiprocessing import Pool, shared_memory
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index bd8e51bf..32113e4c 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -12,7 +12,6 @@
 import operator
 import shutil
 import subprocess
-import networkx as nx
 import graph_tool.all as gt
 import numpy as np
 import pandas as pd
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index 13441f0c..a9b6dc7c 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -24,7 +24,6 @@
 except ImportError:
     from sklearn.neighbors.kde import KernelDensity
 import dendropy
-import networkx as nx
 
 def plot_scatter(X, scale, out_prefix, title, kde = True):
     """Draws a 2D scatter plot (png) of the core and accessory distances
@@ -385,7 +384,7 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
         graph_file_name = os.path.basename(outPrefix) + "_cytoscape.graphml"
     else:
         graph_file_name = os.path.basename(outPrefix) + "_" + suffix + "_cytoscape.graphml"
-    nx.write_graphml(G, outPrefix + "/" + graph_file_name)
+    G.save(outPrefix + "/" + graph_file_name, fmt = 'graphml')
 
     # Write CSV of metadata
     if writeCsv:
diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py
index f6a1e120..52c9f954 100644
--- a/PopPUNK/sketchlib.py
+++ b/PopPUNK/sketchlib.py
@@ -18,7 +18,6 @@
 from glob import glob
 from random import sample
 import numpy as np
-import networkx as nx
 from scipy import optimize
 
 # Try to import sketchlib
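The save calls introduced across these patches select the on-disk format via the
fmt argument (or the file extension): 'gt' is graph-tool's binary format used for
reloading, 'graphml' the XML format Cytoscape reads. A minimal round trip:

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(2)
    G.add_edge_list([(0, 1)])

    G.save('example_graph.gt', fmt = 'gt')               # native binary
    G.save('example_cytoscape.graphml', fmt = 'graphml') # for Cytoscape
    H = gt.load_graph('example_graph.gt')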

From 5623a6710236a0535ec08eed0b1841641e1b3dc0 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Fri, 8 May 2020 17:31:21 +0100
Subject: [PATCH 07/56] Switch from multiprocessing to OpenMP parallelisation
 using graph-tool

---
 PopPUNK/__main__.py |  6 ++++++
 PopPUNK/network.py  |  2 +-
 PopPUNK/refine.py   | 27 ++++++++++-----------------
 3 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index e78eef4d..9521eaa1 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -298,6 +298,12 @@ def main():
     if args.ref_db is not None and args.ref_db.endswith('/'):
         args.ref_db = args.ref_db[:-1]
 
+    # Check on parallelisation of graph-tool
+    if gt.openmp_enabled():
+        gt.openmp_set_num_threads(args.threads)
+        sys.stderr.write('\nGraph-tool OpenMP parallelisation enabled:')
+        sys.stderr.write(' with ' + str(gt.openmp_get_num_threads()) + ' threads\n')
+
     # run according to mode
     sys.stderr.write("PopPUNK (POPulation Partitioning Using Nucleotide Kmers)\n")
     sys.stderr.write("\t(with backend: " + dbFuncs['backend'] + " v" + dbFuncs['backend_version'] + "\n")
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 32113e4c..47066a1d 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -169,7 +169,7 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
                             references.add(vertex)
     
     # Order found references as in mash sketch files
-    reference_names = [mashOrder[x] for x in sorted(references)]
+    reference_names = [mashOrder[int(x)] for x in sorted(references)]
     refFileName = writeReferences(reference_names, outPrefix)
     return references, reference_names, refFileName
 
diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py
index 056151d4..97ec156f 100644
--- a/PopPUNK/refine.py
+++ b/PopPUNK/refine.py
@@ -81,23 +81,16 @@ def refineFit(distMat, sample_names, start_s, mean0, mean1,
     sys.stderr.write("Trying to optimise score globally\n")
     global_grid_resolution = 40 # Seems to work
     s_range = np.linspace(-min_move, max_move, num = global_grid_resolution)
-    
-    # Move distMat into shared memory
-    with SharedMemoryManager() as smm:
-        shm_distMat = smm.SharedMemory(size = distMat.nbytes)
-        distances_shared_array = np.ndarray(distMat.shape, dtype = distMat.dtype, buffer = shm_distMat.buf)
-        distances_shared_array[:] = distMat[:]
-        distances_shared = NumpyShared(name = shm_distMat.name, shape = distMat.shape, dtype = distMat.dtype)
-        
-        with Pool(processes = num_processes) as pool:
-            global_s = pool.map(partial(newNetwork,
-                                        sample_names = sample_names,
-                                        distMat = distances_shared,
-                                        start_point = start_point,
-                                        mean1 = mean1,
-                                        gradient = gradient,
-                                        slope = slope),
-                                s_range)
+
+    # Global optimisation of boundary position
+    global_s = [newNetwork(s,
+                            sample_names = sample_names,
+                            distMat = distMat,
+                            start_point = start_point,
+                            mean1 = mean1,
+                            gradient = gradient,
+                            slope = slope)
+                            for s in s_range]
 
     # Local optimisation around global optimum
     min_idx = np.argmin(np.array(global_s))
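The thread-count handling added to __main__.py, shown standalone: graph-tool only
exposes the OpenMP controls when compiled with OpenMP support, so the check
guards the setter (the thread count here is illustrative).

    import graph_tool.all as gt

    if gt.openmp_enabled():
        gt.openmp_set_num_threads(4)          # e.g. args.threads
        print(gt.openmp_get_num_threads())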

From 88714dd65cccf0b22555f3b3e771c76a51e61d0e Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Sat, 9 May 2020 11:24:21 +0100
Subject: [PATCH 08/56] Fix network loading message

---
 PopPUNK/network.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 47066a1d..f1131c80 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -72,7 +72,7 @@ def fetchNetwork(network_dir, model, refList,
                              "a refined fit. Using the combined distances.\n")
 
     genomeNetwork = gt.load_graph(network_file)
-    sys.stderr.write("Network loaded: " + str(genomeNetwork.number_of_nodes()) + " samples\n")
+    sys.stderr.write("Network loaded: " + str(len(list(genomeNetwork.vertices()))) + " samples\n")
 
     # Ensure all in dists are in final network
     networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
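A side note on the counting idiom: len(list(G.vertices())) materialises the
vertex iterator, while graph-tool's num_vertices() returns the same count
directly (and likewise respects any active vertex filter).

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(10)
    print(len(list(G.vertices())), G.num_vertices())   # 10 10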

From 145eee69fd25544aabcb66de1289391a9e0c6b80 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Sat, 9 May 2020 12:42:10 +0100
Subject: [PATCH 09/56] Graph loading function updated

---
 PopPUNK/__main__.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 9521eaa1..6fac4613 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -637,7 +637,6 @@ def main():
                 postpruning_combined_seq, newDistMat = prune_distance_matrix(rlist, isolates_to_remove,
                                                                       complete_distMat, dists_out)
 
-            rlist = viz_subset
             combined_seq, core_distMat, acc_distMat = update_distance_matrices(viz_subset, newDistMat)
 
             # reorder subset to ensure list orders match
@@ -669,13 +668,6 @@ def main():
                     prev_clustering = args.previous_clustering
                 else:
                     prev_clustering = os.path.dirname(args.distances + ".pkl")
-
-                # Read in network and cluster assignment
-                genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, args.core_only, args.accessory_only)
-                isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
-
-                # prune the network and dictionary of assignments
-                genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(viz_subset))
                 
             # generate selected visualisations
             if args.microreact:
@@ -695,6 +687,19 @@ def main():
                     sys.stderr.write("Can only generate a network output for fitted models\n")
                 else:
                     sys.stderr.write("Writing cytoscape output\n")
+                    # Read in network and cluster assignment
+                    genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, args.core_only, args.accessory_only)
+                    isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
+
+                    # mask the network and dictionary of assignments
+                    viz_vertex = genomeNetwork.new_vertex_property('bool')
+                    for n,vertex in enumerate(genomeNetwork.vertices()):
+                        if rlist[n] in viz_subset:
+                            viz_vertex[vertex] = True
+                        else:
+                            viz_vertex[vertex] = False
+                    genomeNetwork.set_vertex_filter(viz_vertex)
+                    # write output
                     outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
 
         else:
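An alternative to the destructive set_vertex_filter call used here, should the
unfiltered graph still be needed afterwards, is a GraphView, which wraps the
graph with a filter without modifying it (subset chosen for illustration):

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(3)
    G.add_edge_list([(0, 1), (1, 2)])

    keep = G.new_vertex_property('bool')
    for vertex in G.vertices():
        keep[vertex] = int(vertex) != 2        # illustrative subset

    sub = gt.GraphView(G, vfilt = keep)        # G itself is untouched
    print(sub.num_vertices())                  # 2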

From e75541fa20f2c38fd36d3f7c6cc7832c38f4390a Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Sat, 9 May 2020 13:14:48 +0100
Subject: [PATCH 10/56] Update visualisation code

---
 PopPUNK/__main__.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 6fac4613..23585d2c 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -448,9 +448,9 @@ def main():
         #*                            *#
         #* network construction       *#
         #*                            *#
-        #******************************#  
+        #******************************#
+        
         genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
-
         # Ensure all in dists are in final network
         networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
         if len(networkMissing) > 0:
@@ -656,7 +656,7 @@ def main():
                 cluster_file = args.viz_lineages
                 isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'lineages', return_dict = True)
             else:
-                # identify existing analysis files
+#                # identify existing analysis files
                 model_prefix = args.ref_db
                 if args.model_dir is not None:
                     model_prefix = args.model_dir
@@ -668,7 +668,11 @@ def main():
                     prev_clustering = args.previous_clustering
                 else:
                     prev_clustering = os.path.dirname(args.distances + ".pkl")
-                
+                    
+                # load clustering
+                cluster_file = args.ref_db + '/' + args.ref_db + '_clusters.csv'
+                isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
+
             # generate selected visualisations
             if args.microreact:
                 sys.stderr.write("Writing microreact output\n")

From 0109e0b470df9809d60a09463d19496f103cd451 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Tue, 12 May 2020 10:09:06 +0100
Subject: [PATCH 11/56] Refactor lineage_clustering code to use graph-tool

---
 PopPUNK/lineage_clustering.py | 102 ++++++++++++++++++++--------------
 1 file changed, 60 insertions(+), 42 deletions(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index d553801a..b90ff0c9 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -12,7 +12,7 @@
 from collections import defaultdict
 import pickle
 import collections
-import networkx as nx
+import graph_tool.all as gt
 from multiprocessing import Pool, RawArray, shared_memory, managers
 try:
     from multiprocessing import Pool, shared_memory
@@ -104,7 +104,7 @@ def get_nearest_neighbours(rank, isolates = None, ranks = None):
     return nn
     
 
-def pick_seed_isolate(G, distances = None):
+def pick_seed_isolate(lineage_assignation, distances = None):
     """ Identifies seed isolate from the closest pair of
     unclustered isolates.
     
@@ -122,7 +122,7 @@ def pick_seed_isolate(G, distances = None):
     distances_shm = shared_memory.SharedMemory(name = distances.name)
     distances = np.ndarray(distances.shape, dtype = distances.dtype, buffer = distances_shm.buf)
     # identify unclustered isolates
-    unclustered_isolates = list(nx.isolates(G))
+    unclustered_isolates = [isolate for isolate,lineage in lineage_assignation.items() if lineage == 0]
     # select minimum distance between unclustered isolates
     minimum_distance_between_unclustered_isolates = np.amin(distances[unclustered_isolates,unclustered_isolates],axis = 0)
     # select occurrences of this distance
@@ -135,7 +135,7 @@ def pick_seed_isolate(G, distances = None):
     # return unclustered isolate with minimum distance to another isolate
     return seed_isolate
 
-def get_lineage(G, neighbours, seed_isolate, lineage_index):
+def get_lineage(lineage_assignation, neighbours, seed_isolate, lineage_index):
     """ Identifies isolates corresponding to a particular
     lineage given a cluster seed.
 
@@ -155,27 +155,28 @@ def get_lineage(G, neighbours, seed_isolate, lineage_index):
     """
     # initiate lineage as the seed isolate and immediate unclustered neighbours
     in_lineage = {seed_isolate}
-    G.nodes[seed_isolate]['lineage'] = lineage_index
+    lineage_assignation[seed_isolate] = lineage_index
+    edges_to_add = set()
     for seed_neighbour in neighbours[seed_isolate]:
-        if nx.is_isolate(G, seed_neighbour):
-            G.add_edge(seed_isolate, seed_neighbour)
-            G.nodes[seed_neighbour]['lineage'] = lineage_index
+        if lineage_assignation[seed_neighbour] == 0:
+            edges_to_add.add((seed_isolate, seed_neighbour))
             in_lineage.add(seed_neighbour)
+            lineage_assignation[seed_neighbour] = lineage_index
     # iterate through other isolates until converged on a stable clustering
     alterations = len(neighbours.keys())
     while alterations > 0:
         alterations = 0
         for isolate in neighbours.keys():
-            if nx.is_isolate(G, isolate):
+            if lineage_assignation[isolate] == 0:
                 intersection_size = in_lineage.intersection(neighbours[isolate])
                 if intersection_size is not None and len(intersection_size) > 0:
                     for i in intersection_size:
-                        G.add_edge(isolate, i)
-                        G.nodes[isolate]['lineage'] = lineage_index
+                        edges_to_add.add((isolate, i))
                     in_lineage.add(isolate)
+                    lineage_assignation[isolate] = lineage_index
                     alterations = alterations + 1
     # return final clustering
-    return G
+    return lineage_assignation, edges_to_add
 
 def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list = None, qlist = None, existing_scheme = None, use_accessory = False, num_processes = 1):
     """ Clusters isolates into lineages based on their
@@ -209,12 +210,12 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
     """
     
     # data structures
-    lineage_clustering = defaultdict(dict)
+    lineage_assignation = defaultdict(dict)
     overall_lineage_seeds = defaultdict(dict)
     overall_lineages = defaultdict(dict)
     if existing_scheme is not None:
         with open(existing_scheme, 'rb') as pickle_file:
-            lineage_clustering, overall_lineage_seeds, rank_list = pickle.load(pickle_file)
+            lineage_assignation, overall_lineage_seeds, rank_list = pickle.load(pickle_file)
 
     # generate square distance matrix
     seqLabels, coreMat, accMat = update_distance_matrices(isolate_list, distMat)
@@ -258,24 +259,46 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
         distance_ranks_shared_array = np.ndarray(distance_ranks.shape, dtype = distance_ranks.dtype, buffer = distance_ranks_raw.buf)
         distance_ranks_shared_array[:] = distance_ranks[:]
         distance_ranks_shared_array = NumpyShared(name = distance_ranks_raw.name, shape = distance_ranks.shape, dtype = distance_ranks.dtype)
-                
+        
+        # alter parallelisation of graph-tools to account for multiprocessing
+        num_gt_processes = 1
+        if num_processes > len(rank_list):
+            num_gt_processes = max(1,int(num_processes/len(rank_list)))
+            num_mp_processes = len(rank_list)
+        else:
+            num_mp_processes = num_processes
+        
         # parallelise neighbour identification for each rank
-        with Pool(processes = num_processes) as pool:
+        with Pool(processes = num_mp_processes) as pool:
             results = pool.map(partial(run_clustering_for_rank,
                                 distances_input = distances_shared_array,
                                 distance_ranks_input = distance_ranks_shared_array,
                                 isolates = isolate_list_shared,
                                 previous_seeds = overall_lineage_seeds),
                                 rank_list)
-        
-        # extract results from multiprocessing pool
+
+        # extract results from multiprocessing pool and save output network
         for n,result in enumerate(results):
+            
+            # get results per rank
             rank = rank_list[n]
-            lineage_clustering[rank], overall_lineage_seeds[rank] = result
+            lineage_assignation[rank], overall_lineage_seeds[rank], connections = result
+
+            # create graph structure with an internal vertex property map
+            # storing the lineage assignation; boost.python cannot be loaded
+            # within the spawned processes, so the network analysis has to
+            # run separately from the multiprocessing pool
+            G = gt.Graph()
+            G.add_vertex(len(isolate_list))
+
+            # store results in network
+            G.add_edge_list(connections)
+
+            # save network
+            G.save(file_name = output + "/" + output + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
 
     # store output
     with open(output + "/" + output + '_lineages.pkl', 'wb') as pickle_file:
-        pickle.dump([lineage_clustering, overall_lineage_seeds, rank_list], pickle_file)
+        pickle.dump([lineage_assignation, overall_lineage_seeds, rank_list], pickle_file)
     
     # process multirank lineages
     overall_lineages = {}
@@ -284,11 +307,11 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
     for index,isolate in enumerate(isolate_list):
         overall_lineage = None
         for rank in rank_list:
-            overall_lineages['Rank_' + str(rank)][isolate] = lineage_clustering[rank][index]
+            overall_lineages['Rank_' + str(rank)][isolate] = lineage_assignation[rank][index]
             if overall_lineage is None:
-                overall_lineage = str(lineage_clustering[rank][index])
+                overall_lineage = str(lineage_assignation[rank][index])
             else:
-                overall_lineage = overall_lineage + '-' + str(lineage_clustering[rank][index])
+                overall_lineage = overall_lineage + '-' + str(lineage_assignation[rank][index])
         overall_lineages['overall'][isolate] = overall_lineage
     
     # print output as CSV
@@ -321,14 +344,13 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
             Whether to extend a previously generated analysis or not.
             
     Returns:
-        lineage_clustering (dict)
+        lineage_assignation (dict)
             Assignment of each isolate to a cluster.
         lineage_seed (dict)
             Seed isolate used to initiate each cluster.
         neighbours (nested dict)
             Neighbour relationships between isolates for R.
-    """
-    
+    """    
     # load shared memory objects
     distances_shm = shared_memory.SharedMemory(name = distances_input.name)
     distances = np.ndarray(distances_input.shape, dtype = distances_input.dtype, buffer = distances_shm.buf)
@@ -336,17 +358,12 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
     distance_ranks = np.ndarray(distance_ranks_input.shape, dtype = distance_ranks_input.dtype, buffer = distance_ranks_shm.buf)
     isolate_list = isolates
     isolate_indices = range(0,len(isolate_list))
-
+    
     # load previous scheme
     seeds = {}
     if previous_seeds is not None:
         seeds = previous_seeds[rank]
-
-    # create graph structure
-    G = nx.Graph()
-    G.add_nodes_from(isolate_indices)
-    G.nodes.data('lineage', default = 0)
-    
+   
     # identify nearest neighbours
     nn = get_nearest_neighbours(rank,
                             ranks = distance_ranks_input,
@@ -354,19 +371,20 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
     
     # iteratively identify lineages
     lineage_index = 1
-    while nx.number_of_isolates(G) > 0:
+    connections = set()
+    lineage_assignation = {isolate:0 for isolate in isolate_list}
+    
+    while 0 in lineage_assignation.values():
         if lineage_index in seeds.keys():
             seed_isolate = seeds[lineage_index]
         else:
-            seed_isolate = pick_seed_isolate(G, distances = distances_input)
+            seed_isolate = pick_seed_isolate(lineage_assignation, distances = distances_input)
         # skip over previously-defined seeds if amalgamated into different lineage now
-        if nx.is_isolate(G, seed_isolate):
+        if lineage_assignation[seed_isolate] == 0:
             seeds[lineage_index] = seed_isolate
-            G = get_lineage(G, nn, seed_isolate, lineage_index)
+            lineage_assignation, added_connections = get_lineage(lineage_assignation, nn, seed_isolate, lineage_index)
+            connections.update(added_connections)
         lineage_index = lineage_index + 1
-
-    # identify components and name lineages
-    lineage_clustering = {node:nodedata for (node, nodedata) in G.nodes(data='lineage')}
-
+    
     # return clustering
-    return lineage_clustering, seeds
+    return lineage_assignation, seeds, connections
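A condensed, pure-Python sketch of the seed-and-grow loop implemented by
pick_seed_isolate and get_lineage: starting from a seed, any unclustered isolate
that shares a neighbour with the growing lineage joins it, and the outer loop
repeats until every isolate is assigned. (Toy neighbour sets; the real code
picks seeds by minimum pairwise distance rather than lowest index.)

    neighbours = {0: {1}, 1: {0, 2}, 2: {1}, 3: {4}, 4: {3}}
    lineage_assignation = {isolate: 0 for isolate in neighbours}

    lineage_index = 1
    while 0 in lineage_assignation.values():
        # simplified seed choice: lowest unclustered index
        seed = min(i for i, l in lineage_assignation.items() if l == 0)
        in_lineage = {seed}
        lineage_assignation[seed] = lineage_index
        altered = True
        while altered:
            altered = False
            for isolate in neighbours:
                if lineage_assignation[isolate] == 0 and \
                        in_lineage & neighbours[isolate]:
                    in_lineage.add(isolate)
                    lineage_assignation[isolate] = lineage_index
                    altered = True
        lineage_index += 1
    print(lineage_assignation)   # {0: 1, 1: 1, 2: 1, 3: 2, 4: 2}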

From 188ed403361a021418303b3045166c390fc85fee Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Tue, 12 May 2020 10:12:16 +0100
Subject: [PATCH 12/56] Update docstrings

---
 PopPUNK/lineage_clustering.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index b90ff0c9..4540b5c2 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -109,8 +109,9 @@ def pick_seed_isolate(lineage_assignation, distances = None):
     unclustered isolates.
     
     Args:
-        G (network)
-            Network with one node per isolate.
+        lineage_assignation (dict)
+            Dict of lineage assignments (int) of each
+            isolate index (int).
         distances (ndarray in shared memory)
             Pairwise distances between isolates.
             
@@ -140,8 +141,9 @@ def get_lineage(lineage_assignation, neighbours, seed_isolate, lineage_index):
     lineage given a cluster seed.
 
     Args:
-        G (network)
-            Network with one node per isolate.
+        lineage_assignation (dict)
+            Dict of lineage assignments (int) of each
+            isolate index (int).
         neighbours (dict of frozen sets)
            Pre-calculated neighbour relationships.
         seed_isolate (int)
@@ -150,8 +152,12 @@ def get_lineage(lineage_assignation, neighbours, seed_isolate, lineage_index):
            Label of current lineage.
         
     Returns:
-        G (network)
-            Network modified with new edges.
+        lineage_assignation (dict)
+            Dict of lineage assignments (int) of each
+            isolate index (int).
+        edges_to_add (set of tuples)
+            Edges to add to network describing lineages
+            of this rank.
     """
     # initiate lineage as the seed isolate and immediate unclustered neighbours
     in_lineage = {seed_isolate}
@@ -348,8 +354,8 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
             Assignment of each isolate to a cluster.
         lineage_seed (dict)
             Seed isolate used to initiate each cluster.
-        neighbours (nested dict)
-            Neighbour relationships between isolates for R.
+        connections (set of tuples)
+            Edges to add to network describing lineages.
     """    
     # load shared memory objects
     distances_shm = shared_memory.SharedMemory(name = distances_input.name)

From a75cb4f5f5d1f63d96de9b702ff7c3f296d47d02 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Tue, 12 May 2020 13:09:32 +0100
Subject: [PATCH 13/56] Enable visualisation of lineage networks using
 Cytoscape

---
 PopPUNK/__main__.py | 166 ++++++++++++++++++++++++--------------------
 PopPUNK/plot.py     |   2 +-
 2 files changed, 91 insertions(+), 77 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 23585d2c..59893775 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -104,7 +104,6 @@ def get_options():
             default=False,
             action='store_true')
 
-
     # input options
     iGroup = parser.add_argument_group('Input files')
     iGroup.add_argument('--ref-db',type = str, help='Location of built reference database')
@@ -248,7 +247,7 @@ def main():
         sys.stderr.write("Minimum kmer size " + str(args.min_k) + " must be smaller than maximum kmer size\n")
         sys.exit(1)
     elif args.k_step < 2:
-        sys.stderr.write("Kmer size step must be at least one\n")
+        sys.stderr.write("Kmer size step must be at least two\n")
         sys.exit(1)
     elif no_sketchlib and (args.min_k < 9 or args.max_k > 31):
         sys.stderr.write("When using Mash, Kmer size must be between 9 and 31\n")
@@ -362,7 +361,7 @@ def main():
     #*                            *#
     #******************************#    
     # refine model also needs to run all model steps
-    if args.fit_model or args.use_model or args.refine_model or args.threshold or args.easy_run:
+    if args.fit_model or args.use_model or args.refine_model or args.threshold or args.easy_run or args.lineage_clustering:
         # Set up saved data from first step, if easy_run mode
         if args.easy_run:
             distances = dists_out
@@ -374,9 +373,10 @@ def main():
                 sys.stderr.write("Mode: Using previous model with a reference database\n\n")
             elif args.threshold:
                 sys.stderr.write("Mode: Applying a core distance threshold\n\n")
-            else:
+            elif args.refine_model:
                 sys.stderr.write("Mode: Refining model fit using network properties\n\n")
-
+            elif args.lineage_clustering:
+                sys.stderr.write("Mode: Identifying lineages from neighbouring isolates\n\n")
             if args.distances is not None and args.ref_db is not None:
                 distances = args.distances
                 ref_db = args.ref_db
@@ -441,8 +441,9 @@ def main():
             assignments = model.assign(distMat)
             model.plot(distMat, assignments)
 
-        fit_type = 'combined'
-        model.save()
+        if not args.lineage_clustering: # change this once lineage clustering is refined as a model
+            fit_type = 'combined'
+            model.save()
         
         #******************************#
         #*                            *#
@@ -450,36 +451,75 @@ def main():
         #*                            *#
         #******************************#
         
-        genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
-        # Ensure all in dists are in final network
-        networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
-        if len(networkMissing) > 0:
-            sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
-
-        isolateClustering = {fit_type: printClusters(genomeNetwork,
-                                                    refList, # assume no rlist+qlist?
-                                                     args.output + "/" + os.path.basename(args.output),
-                                                     externalClusterCSV = args.external_clustering)}
-
-        # Write core and accessory based clusters, if they worked
-        if model.indiv_fitted:
+        if not args.lineage_clustering:
+            genomeNetwork = constructNetwork(refList, queryList, assignments, model.within_label)
+            # Ensure all in dists are in final network
+            networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
+            if len(networkMissing) > 0:
+                sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
+
+            isolateClustering = {fit_type: printClusters(genomeNetwork,
+                                                        refList, # assume no rlist+qlist?
+                                                         args.output + "/" + os.path.basename(args.output),
+                                                         externalClusterCSV = args.external_clustering)}
+
+            # Write core and accessory based clusters, if they worked
+            if model.indiv_fitted:
+                indivNetworks = {}
+                for dist_type, slope in zip(['core', 'accessory'], [0, 1]):
+                    indivAssignments = model.assign(distMat, slope)
+                    indivNetworks[dist_type] = constructNetwork(refList, queryList, indivAssignments, model.within_label)
+                    isolateClustering[dist_type] = printClusters(indivNetworks[dist_type],
+                                                    refList,
+                                                     args.output + "/" + os.path.basename(args.output) + "_" + dist_type,
+                                                     externalClusterCSV = args.external_clustering)
+                    indivNetworks[dist_type].save(args.output + "/" + os.path.basename(args.output) +
+                    "_" + dist_type + '_graph.gt', fmt = 'gt')
+
+                if args.core_only:
+                    fit_type = 'core'
+                    genomeNetwork = indivNetworks['core']
+                elif args.accessory_only:
+                    fit_type = 'accessory'
+                    genomeNetwork = indivNetworks['accessory']
+
+        #******************************#
+        #*                            *#
+        #* lineages analysis          *#
+        #*                            *#
+        #******************************#
+        
+        if args.lineage_clustering:
+
+            # load distances
+            if args.distances is not None:
+                distances = args.distances
+            else:
+                sys.stderr.write("Need to provide an input set of distances with --distances\n\n")
+                sys.exit(1)
+
+            refList, queryList, self, distMat = readPickle(distances)
+            
+            # make directory for new output files
+            if not os.path.isdir(args.output):
+                try:
+                    os.makedirs(args.output)
+                except OSError:
+                    sys.stderr.write("Cannot create output directory\n")
+                    sys.exit(1)
+            
+            # run lineage clustering
+            if self:
+                isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
+            else:
+                isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, qlist = queryList, use_accessory = args.use_accessory,  existing_scheme = args.existing_scheme, num_processes = args.threads)
+                
+            # load networks
             indivNetworks = {}
-            for dist_type, slope in zip(['core', 'accessory'], [0, 1]):
-                indivAssignments = model.assign(distMat, slope)
-                indivNetworks[dist_type] = constructNetwork(refList, queryList, indivAssignments, model.within_label)
-                isolateClustering[dist_type] = printClusters(indivNetworks[dist_type],
-                                                refList,
-                                                 args.output + "/" + os.path.basename(args.output) + "_" + dist_type,
-                                                 externalClusterCSV = args.external_clustering)
-                indivNetworks[dist_type].save(args.output + "/" + os.path.basename(args.output) +
-                "_" + dist_type + '_graph.gt', fmt = 'gt')
-
-            if args.core_only:
-                fit_type = 'core'
-                genomeNetwork = indivNetworks['core']
-            elif args.accessory_only:
-                fit_type = 'accessory'
-                genomeNetwork = indivNetworks['accessory']
+            for rank in rank_list:
+                indivNetworks[rank] = gt.load_graph(args.output + "/" + args.output + '_rank_' + str(rank) + '_lineages.gt')
+                if rank == min(rank_list):
+                    genomeNetwork = indivNetworks[rank]
 
         #******************************#
         #*                            *#
@@ -506,16 +546,21 @@ def main():
                                         overwrite = args.overwrite, microreact = args.microreact)
                 if args.cytoscape:
                     sys.stderr.write("Writing cytoscape output\n")
-                    outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
-                    if model.indiv_fitted:
-                        sys.stderr.write("Writing individual cytoscape networks\n")
-                        for dist_type in ['core', 'accessory']:
-                            outputsForCytoscape(indivNetworks[dist_type], isolateClustering, args.output,
-                                        args.info_csv, suffix = dist_type, writeCsv = False)
+                    if args.lineage_clustering:
+                        for rank in rank_list:
+                            outputsForCytoscape(indivNetworks[rank], isolateClustering, args.output,
+                                        args.info_csv, suffix = 'rank_' + str(rank), writeCsv = False)
+                    else:
+                        outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
+                        if model.indiv_fitted:
+                            sys.stderr.write("Writing individual cytoscape networks\n")
+                            for dist_type in ['core', 'accessory']:
+                                outputsForCytoscape(indivNetworks[dist_type], isolateClustering, args.output,
+                                            args.info_csv, suffix = dist_type, writeCsv = False)
         except:
             # Go ahead with final steps even if visualisations fail
             # (e.g. rapidnj not found)
-            sys.stderr.write("Error creating files for visualisation:", sys.exc_info()[0])
+            sys.stderr.write("Error creating files for visualisation: " + str(sys.exc_info()[0]))
 
         #******************************#
         #*                            *#
@@ -543,37 +588,6 @@ def main():
 
         genomeNetwork.save(args.output + "/" + os.path.basename(args.output) + '_graph.gt', fmt = 'gt')
 
-    #******************************#
-    #*                            *#
-    #* within-strain analysis     *#
-    #*                            *#
-    #******************************#
-    if args.lineage_clustering:
-        sys.stderr.write("Mode: Identifying lineages within a clade\n\n")
-
-        # load distances
-        if args.distances is not None:
-            distances = args.distances
-        else:
-            sys.stderr.write("Need to provide an input set of distances with --distances\n\n")
-            sys.exit(1)
-
-        refList, queryList, self, distMat = readPickle(distances)
-        
-        # make directory for new output files
-        if not os.path.isdir(args.output):
-            try:
-                os.makedirs(args.output)
-            except OSError:
-                sys.stderr.write("Cannot create output directory\n")
-                sys.exit(1)
-        
-        # run lineage clustering
-        if self:
-            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, use_accessory = args.use_accessory, existing_scheme = args.existing_scheme, num_processes = args.threads)
-        else:
-            isolateClustering = cluster_into_lineages(distMat, rank_list, args.output, isolate_list = refList, qlist = queryList, use_accessory = args.use_accessory,  existing_scheme = args.existing_scheme, num_processes = args.threads)
-
     #*******************************#
     #*                             *#
     #* query assignment (function  *#
@@ -601,7 +615,7 @@ def main():
             sys.stderr.write("Must specify at least one type of visualisation to output\n")
             sys.exit(1)
 
-        if args.distances is not None and args.ref_db is not None:
+        if args.distances is not None and (args.ref_db is not None or args.lineage_clustering):
             
             # Initial processing
             # Load original distances
@@ -656,7 +670,7 @@ def main():
                 cluster_file = args.viz_lineages
                 isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'lineages', return_dict = True)
             else:
-#                # identify existing analysis files
+                # identify existing analysis files
                 model_prefix = args.ref_db
                 if args.model_dir is not None:
                     model_prefix = args.model_dir
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index a9b6dc7c..16041b74 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -388,7 +388,7 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
 
     # Write CSV of metadata
     if writeCsv:
-        refNames = G.nodes(data=False)
+        refNames = G.vertices
         seqLabels = [r.split('/')[-1].split('.')[0] for r in refNames]
         writeClusterCsv(outPrefix + "/" + outPrefix + "_cytoscape.csv",
                         refNames,

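The hunks above depend on graph-tool's integer vertex indices lining up with positions in refList, and on casting Vertex objects to int before doing set arithmetic on them. A minimal standalone sketch of that bookkeeping, with illustrative names and a deliberately short vertex count (not PopPUNK code):

    import graph_tool.all as gt

    ref_list = ['s1', 's2', 's3']           # hypothetical isolate names
    G = gt.Graph(directed = False)
    G.add_vertex(len(ref_list) - 1)         # deliberately one vertex short
    G.add_edge_list([(0, 1)])

    # G.vertices() yields Vertex objects, so cast to int before comparing
    present = set(int(v) for v in G.vertices())
    missing = set(range(len(ref_list))).difference(present)
    if len(missing) > 0:
        print("missing: " + ",".join(map(str, missing)))   # prints "missing: 2"
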
From 2fc779ba8e8a94bbba535ea918039b960c4c4a94 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Tue, 12 May 2020 14:42:35 +0100
Subject: [PATCH 14/56] Add extra network printing features

---
 PopPUNK/lineage_clustering.py | 59 ++++++++++++++++++++++++-----------
 1 file changed, 40 insertions(+), 19 deletions(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index 4540b5c2..4410e1ab 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -266,16 +266,19 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
         distance_ranks_shared_array[:] = distance_ranks[:]
         distance_ranks_shared_array = NumpyShared(name = distance_ranks_raw.name, shape = distance_ranks.shape, dtype = distance_ranks.dtype)
         
-        # alter parallelisation of graph-tools to account for multiprocessing
-        num_gt_processes = 1
-        if num_processes > len(rank_list):
-            num_gt_processes = max(1,int(num_processes/len(rank_list)))
-            num_mp_processes = rank_list
-        else:
-            num_mp_processes = num_processes
+        # build a graph framework for network outputs
+        # create graph structure with internal vertex property map storing
+        # lineage assignation; cannot load boost.python within spawned
+        # processes, so the network analysis has to run separately
+        G = gt.Graph()
+        G.add_vertex(len(isolate_list))
+        # add sequence labels for visualisation
+        vid = G.new_vertex_property('string',
+                                    vals = [i.split('/')[-1].split('.')[0] for i in isolate_list])
+        G.vp.id = vid
         
         # parallelise neighbour identification for each rank
-        with Pool(processes = num_mp_processes) as pool:
+        with Pool(processes = num_processes) as pool:
             results = pool.map(partial(run_clustering_for_rank,
                                 distances_input = distances_shared_array,
                                 distance_ranks_input = distance_ranks_shared_array,
@@ -284,23 +287,19 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
                                 rank_list)
 
         # extract results from multiprocessing pool and save output network
+        nn = defaultdict(dict)
         for n,result in enumerate(results):
-            
             # get results per rank
             rank = rank_list[n]
-            lineage_assignation[rank], overall_lineage_seeds[rank], connections = result
-
-            # create graph structure with internal vertex property map
-            # storing lineage assignation cannot load boost.python within spawned
-            # processes so have to run network analysis separately
-            G = gt.Graph()
-            G.add_vertex(len(isolate_list))
-
+            lineage_assignation[rank], overall_lineage_seeds[rank], nn[rank], connections = result
+            # produce nearest neighbour network for alternative downstream analyses
+            make_nn_network(G,nn[rank],output,rank)
             # store results in network
             G.add_edge_list(connections)
-
             # save network
             G.save(file_name = output + "/" + output + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
+            # clear edges
+            G.clear_edges()
 
     # store output
     with open(output + "/" + output + '_lineages.pkl', 'wb') as pickle_file:
@@ -393,4 +392,26 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
         lineage_index = lineage_index + 1
     
     # return clustering
-    return lineage_assignation, seeds, connections
+    return lineage_assignation, seeds, nn, connections
+
+def make_nn_network(G,nn,output,rank):
+    """Output all nearest neighbour relationships for
+    deeper analysis of lineage structure.
+    Args:
+        G (graph-tools network)
+        
+        nn (nested dict)
+        
+        output (str)
+        
+        rank (int)
+    
+    Returns:
+        Void
+    """
+    edges_to_add = set()
+    for i in nn.keys():
+        for j in nn[i]:
+            edges_to_add.add((i,j))
+    G.add_edge_list(edges_to_add)
+    G.save(file_name = output + "/" + output + '_nearestNeighbours_rank_' + str(rank) + '_lineages.graphml', fmt = 'graphml')

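The make_nn_network helper introduced in this patch batches all neighbour relationships into a single add_edge_list call rather than adding edges one at a time, which is markedly faster in graph-tool. A self-contained sketch of the same pattern, with an illustrative nn structure and output name:

    import graph_tool.all as gt

    # isolate index -> set of nearest-neighbour indices (toy data)
    nn = {0: {1, 2}, 1: {0}, 2: {0}}

    G = gt.Graph(directed = False)
    G.add_vertex(len(nn))

    # flatten the nested mapping into edge tuples, then add in one batch
    edges_to_add = set()
    for i in nn.keys():
        for j in nn[i]:
            edges_to_add.add((i, j))
    G.add_edge_list(edges_to_add)

    G.save('example_nearestNeighbours.graphml', fmt = 'graphml')
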
From fc3a69a3da3eef18c5d9c7c9d8b0c3276c4c5ef4 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Tue, 12 May 2020 19:08:59 +0100
Subject: [PATCH 15/56] Change to network-based definitions of lineages

---
 PopPUNK/lineage_clustering.py | 169 ++++++++++------------------------
 1 file changed, 49 insertions(+), 120 deletions(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index 4410e1ab..f3cd4a5b 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -89,100 +89,18 @@ def get_nearest_neighbours(rank, isolates = None, ranks = None):
             frozen set of nearest neighbours.
     """
     # data structure
-    nn = {}
+    nn = set()
     # load shared ranks
     ranks_shm = shared_memory.SharedMemory(name = ranks.name)
     ranks = np.ndarray(ranks.shape, dtype = ranks.dtype, buffer = ranks_shm.buf)
     # apply along axis
     for i in isolates:
-        nn[i] = defaultdict(frozenset)
         isolate_ranks = ranks[i,:]
         closest_ranked = np.ravel(np.where(isolate_ranks <= rank))
-        neighbours = frozenset(closest_ranked.tolist())
-        nn[i] = neighbours
+        for j in closest_ranked.tolist():
+            nn.add((i,j))
-    # return dict
+    # return set of neighbour edge tuples
     return nn
-    
-
-def pick_seed_isolate(lineage_assignation, distances = None):
-    """ Identifies seed isolate from the closest pair of
-    unclustered isolates.
-    
-    Args:
-        lineage_assignation (dict)
-            Dict of lineage assignments (int) of each
-            isolate index (int).
-        distances (ndarray in shared memory)
-            Pairwise distances between isolates.
-            
-    Returns:
-        seed_isolate (int)
-            Index of isolate selected as seed.
-    """
-    # load distances from shared memory
-    distances_shm = shared_memory.SharedMemory(name = distances.name)
-    distances = np.ndarray(distances.shape, dtype = distances.dtype, buffer = distances_shm.buf)
-    # identify unclustered isolates
-    unclustered_isolates = [isolate for isolate,lineage in lineage_assignation.items() if lineage == 0]
-    # select minimum distance between unclustered isolates
-    minimum_distance_between_unclustered_isolates = np.amin(distances[unclustered_isolates,unclustered_isolates],axis = 0)
-    # select occurrences of this distance
-    minimum_distance_coordinates = np.where(distances == minimum_distance_between_unclustered_isolates)
-    # identify case where both isolates are unclustered
-    for i in range(len(minimum_distance_coordinates[0])):
-        if minimum_distance_coordinates[0][i] in unclustered_isolates and minimum_distance_coordinates[1][i] in unclustered_isolates:
-            seed_isolate = minimum_distance_coordinates[0][i]
-            break
-    # return unclustered isolate with minimum distance to another isolate
-    return seed_isolate
-
-def get_lineage(lineage_assignation, neighbours, seed_isolate, lineage_index):
-    """ Identifies isolates corresponding to a particular
-    lineage given a cluster seed.
-
-    Args:
-        lineage_assignation (dict)
-            Dict of lineage assignments (int) of each
-            isolate index (int).
-        neighbours (dict of frozen sets)
-           Pre-calculated neighbour relationships.
-        seed_isolate (int)
-           Index of isolate selected as seed.
-        lineage_index (int)
-           Label of current lineage.
-        
-    Returns:
-        lineage_assignation (dict)
-            Dict of lineage assignments (int) of each
-            isolate index (int).
-        edges_to_add (set of tuples)
-            Edges to add to network describing lineages
-            of this rank.
-    """
-    # initiate lineage as the seed isolate and immediate unclustered neighbours
-    in_lineage = {seed_isolate}
-    lineage_assignation[seed_isolate] = lineage_index
-    edges_to_add = set()
-    for seed_neighbour in neighbours[seed_isolate]:
-        if lineage_assignation[seed_neighbour] == 0:
-            edges_to_add.add((seed_isolate, seed_neighbour))
-            in_lineage.add(seed_neighbour)
-            lineage_assignation[seed_neighbour] = lineage_index
-    # iterate through other isolates until converged on a stable clustering
-    alterations = len(neighbours.keys())
-    while alterations > 0:
-        alterations = 0
-        for isolate in neighbours.keys():
-            if lineage_assignation[isolate] == 0:
-                intersection_size = in_lineage.intersection(neighbours[isolate])
-                if intersection_size is not None and len(intersection_size) > 0:
-                    for i in intersection_size:
-                        edges_to_add.add((isolate, i))
-                    in_lineage.add(isolate)
-                    lineage_assignation[isolate] = lineage_index
-                    alterations = alterations + 1
-    # return final clustering
-    return lineage_assignation, edges_to_add
 
 def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list = None, qlist = None, existing_scheme = None, use_accessory = False, num_processes = 1):
     """ Clusters isolates into lineages based on their
@@ -219,9 +137,14 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
     lineage_assignation = defaultdict(dict)
     overall_lineage_seeds = defaultdict(dict)
     overall_lineages = defaultdict(dict)
+    max_existing_cluster = {rank:1 for rank in rank_list}
+    
+    # load existing scheme if supplied
     if existing_scheme is not None:
         with open(existing_scheme, 'rb') as pickle_file:
             lineage_assignation, overall_lineage_seeds, rank_list = pickle.load(pickle_file)
+        for rank in rank_list:
+            max_existing_cluster[rank] = max(lineage_assignation[rank].values()) + 1
 
     # generate square distance matrix
     seqLabels, coreMat, accMat = update_distance_matrices(isolate_list, distMat)
@@ -270,7 +193,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
         # create graph structure with internal vertex property map storing
         # lineage assignation; cannot load boost.python within spawned
         # processes, so the network analysis has to run separately
-        G = gt.Graph()
+        G = gt.Graph(directed = False)
         G.add_vertex(len(isolate_list))
         # add sequence labels for visualisation
         vid = G.new_vertex_property('string',
@@ -279,11 +202,9 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
         
         # parallelise neighbour identification for each rank
         with Pool(processes = num_processes) as pool:
-            results = pool.map(partial(run_clustering_for_rank,
-                                distances_input = distances_shared_array,
-                                distance_ranks_input = distance_ranks_shared_array,
-                                isolates = isolate_list_shared,
-                                previous_seeds = overall_lineage_seeds),
+            results = pool.map(partial(get_nearest_neighbours,
+                                ranks = distance_ranks_shared_array,
+                                isolates = isolate_list_shared),
                                 rank_list)
 
         # extract results from multiprocessing pool and save output network
@@ -291,11 +212,41 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
         for n,result in enumerate(results):
             # get results per rank
             rank = rank_list[n]
-            lineage_assignation[rank], overall_lineage_seeds[rank], nn[rank], connections = result
-            # produce nearest neighbour network for alternative downstream analyses
-            make_nn_network(G,nn[rank],output,rank)
+            # get neighbours
+            edges_to_add = result
             # store results in network
-            G.add_edge_list(connections)
+            G.add_edge_list(edges_to_add)
+            # calculate connectivity of each vertex
+            vertex_out_degrees = G.get_out_degrees(G.get_vertices())
+            # identify components and rank by frequency
+            components, component_frequencies = gt.label_components(G)
+            component_frequency_ranks = (len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)).tolist()
+            # construct a name translation table
+            # begin with previously defined clusters
+            component_name = [None] * len(component_frequencies)
+            for seed in overall_lineage_seeds[rank]:
+                isolate_index = isolate_list.index(seed)
+                component_number = components[isolate_index]
+                if component_name[component_number] is None or component_name[component_number] > overall_lineage_seeds[rank][seed]:
+                    component_name[component_number] = overall_lineage_seeds[rank][seed]
+            # name remaining components in rank order
+            for component_rank in range(len(component_frequency_ranks)):
+                component_number = component_frequency_ranks.index(component_rank)
+                if component_name[component_number] is None:
+                    component_name[component_number] = max_existing_cluster[rank]
+                    # find seed isolate
+                    component_max_degree = np.amax(vertex_out_degrees[np.where(components.a == component_number)])
+                    seed_isolate_index = int(np.where((components.a == component_number) & (vertex_out_degrees == component_max_degree))[0][0])
+                    seed_isolate = isolate_list[seed_isolate_index]
+                    overall_lineage_seeds[rank][seed_isolate] = max_existing_cluster[rank]
+                    # increment
+                    max_existing_cluster[rank] = max_existing_cluster[rank] + 1
+            # store assignments
+            for isolate_index,isolate_name in enumerate(isolate_list):
+                original_component = components.a[isolate_index]
+                renamed_component = component_name[original_component]
+                lineage_assignation[rank][isolate_name] = renamed_component
             # save network
             G.save(file_name = output + "/" + output + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
             # clear edges
@@ -312,11 +263,11 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
     for index,isolate in enumerate(isolate_list):
         overall_lineage = None
         for rank in rank_list:
-            overall_lineages['Rank_' + str(rank)][isolate] = lineage_assignation[rank][index]
+            overall_lineages['Rank_' + str(rank)][isolate] = lineage_assignation[rank][isolate]
             if overall_lineage is None:
-                overall_lineage = str(lineage_assignation[rank][index])
+                overall_lineage = str(lineage_assignation[rank][isolate])
             else:
-                overall_lineage = overall_lineage + '-' + str(lineage_assignation[rank][index])
+                overall_lineage = overall_lineage + '-' + str(lineage_assignation[rank][isolate])
         overall_lineages['overall'][isolate] = overall_lineage
     
     # print output as CSV
@@ -393,25 +344,3 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
     
     # return clustering
     return lineage_assignation, seeds, nn, connections
-
-def make_nn_network(G,nn,output,rank):
-    """Output all nearest neighbour relationships for
-    deeper analysis of lineage structure.
-    Args:
-        G (graph-tools network)
-        
-        nn (nested dict)
-        
-        output (str)
-        
-        rank (int)
-    
-    Returns:
-        Void
-    """
-    edges_to_add = set()
-    for i in nn.keys():
-        for j in nn[i]:
-            edges_to_add.add((i,j))
-    G.add_edge_list(edges_to_add)
-    G.save(file_name = output + "/" + output + '_nearestNeighbours_rank_' + str(rank) + '_lineages.graphml', fmt = 'graphml')

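The renaming logic added in this patch labels connected components so that lineage numbers follow component size, with lineage 1 the largest. The ordinal-rank arithmetic is the subtle part; a sketch of it in isolation on a toy graph (assumes graph-tool and scipy, as in the patched module):

    import graph_tool.all as gt
    from scipy.stats import rankdata

    G = gt.Graph(directed = False)
    G.add_vertex(5)
    G.add_edge_list([(0, 1), (1, 2), (3, 4)])   # components {0,1,2} and {3,4}

    components, component_frequencies = gt.label_components(G)
    # invert ordinal ranks so the largest component receives rank 0
    component_frequency_ranks = (len(component_frequencies)
        - rankdata(component_frequencies, method = 'ordinal').astype(int)).tolist()
    for component_rank in range(len(component_frequency_ranks)):
        component_number = component_frequency_ranks.index(component_rank)
        print('lineage', component_rank + 1, 'is component', component_number)
    # lineage 1 is component 0 (size 3); lineage 2 is component 1 (size 2)
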
From dc042b981849c32f3ccae205bd2ba85288012497 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 13 May 2020 06:00:15 +0100
Subject: [PATCH 16/56] Enable visualisation of networks post-processing

---
 PopPUNK/__main__.py | 35 +++++++++++++++--------------------
 PopPUNK/plot.py     | 20 ++++++++++++++++----
 2 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 59893775..db2958d6 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -615,7 +615,7 @@ def main():
             sys.stderr.write("Must specify at least one type of visualisation to output\n")
             sys.exit(1)
 
-        if args.distances is not None and (args.ref_db is not None or args.lineage_clustering):
+        if args.distances is not None and args.ref_db is not None:
             
             # Initial processing
             # Load original distances
@@ -701,24 +701,22 @@ def main():
                 outputsForGrapetree(viz_subset, core_distMat, isolateClustering, args.output, args.info_csv, args.rapidnj,
                                     overwrite = args.overwrite, microreact = args.microreact)
             if args.cytoscape:
-                if args.viz_lineages or args.external_clustering:
-                    sys.stderr.write("Can only generate a network output for fitted models\n")
+                sys.stderr.write("Writing cytoscape output\n")
+                if args.viz_lineages:
+                    for rank in isolateClustering.keys():
+                        numeric_rank = rank.split('_')[1]
+                        if numeric_rank.isdigit():
+                            genomeNetwork = gt.load_graph(args.ref_db + '/' + args.ref_db + '_rank_' + str(numeric_rank) + '_lineages.gt')
+                            outputsForCytoscape(genomeNetwork, isolateClustering, args.output,
+                                        args.info_csv, suffix = 'rank_' + str(numeric_rank), viz_subset = viz_subset)
                 else:
-                    sys.stderr.write("Writing cytoscape output\n")
-                    # Read in network and cluster assignment
                     genomeNetwork, cluster_file = fetchNetwork(prev_clustering, model, rlist, args.core_only, args.accessory_only)
-                    isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
-
-                    # mask the network and dictionary of assignments
-                    viz_vertex = G.new_vertex_property('bool')
-                    for n,vertex in enumerate(G.vertices()):
-                        if rlist[n] in viz_subset:
-                            viz_vertex[vertex] = True
-                        else:
-                            viz_vertex[vertex] = False
-                    G.set_vertex_filter(viz_vertex)
-                    # write output
-                    outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv)
+                    outputsForCytoscape(genomeNetwork, isolateClustering, args.output, args.info_csv, viz_subset = viz_subset)
+                    if model.indiv_fitted:
+                        sys.stderr.write("Writing individual cytoscape networks\n")
+                        for dist_type in ['core', 'accessory']:
+                            outputsForCytoscape(indivNetworks[dist_type], isolateClustering, args.output,
+                                        args.info_csv, suffix = dist_type, viz_subset = viz_subset)
 
         else:
             # Cannot read input files
@@ -922,9 +920,6 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             outputsForGrapetree(combined_seq, core_distMat, isolateClustering, output, info_csv, rapidnj,
                                 queryList = ordered_queryList, overwrite = overwrite, microreact = microreact)
         if cytoscape:
-            if assign_lineage:
-                sys.stderr.write("Cannot generate a cytoscape network from a lineage assignment")
-            else:
                 sys.stderr.write("Writing cytoscape output\n")
                 outputsForCytoscape(genomeNetwork, isolateClustering, output, info_csv, ordered_queryList)
 
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index 16041b74..938e7612 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -356,7 +356,7 @@ def get_grid(minimum, maximum, resolution):
     return(xx, yy, xy)
 
 
-def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suffix = None, writeCsv = True):
+def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suffix = None, writeCsv = True, viz_subset = None):
     """Write outputs for cytoscape. A graphml of the network, and CSV with metadata
 
     Args:
@@ -379,6 +379,19 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
             Whether to print CSV file to accompany network
 
     """
+    # get list of isolate names
+    isolate_names = list(G.vp.id)
+    
+    # mask network if subsetting
+    if viz_subset is not None:
+        viz_vertex = G.new_vertex_property('bool')
+        for n,vertex in enumerate(G.vertices()):
+            if isolate_names[n] in viz_subset:
+                viz_vertex[vertex] = True
+            else:
+                viz_vertex[vertex] = False
+        G.set_vertex_filter(viz_vertex)
+    
     # write graph file
     if suffix is None:
         graph_file_name = os.path.basename(outPrefix) + "_cytoscape.graphml"
@@ -388,10 +401,9 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
 
     # Write CSV of metadata
     if writeCsv:
-        refNames = G.vertices
-        seqLabels = [r.split('/')[-1].split('.')[0] for r in refNames]
+        seqLabels = [i.split('/')[-1].split('.')[0] for i in isolate_names]
         writeClusterCsv(outPrefix + "/" + outPrefix + "_cytoscape.csv",
-                        refNames,
+                        isolate_names,
                         seqLabels,
                         clustering,
                         'cytoscape',

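The viz_subset masking added to outputsForCytoscape hides vertices behind a boolean vertex filter instead of deleting them, so the graphml written out contains only the requested isolates while the in-memory network stays complete. The same idiom in isolation (labels and the subset are illustrative):

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(3)
    G.vp.id = G.new_vertex_property('string', vals = ['a', 'b', 'c'])
    G.add_edge_list([(0, 1), (1, 2)])

    viz_subset = {'a', 'b'}
    isolate_names = list(G.vp.id)
    viz_vertex = G.new_vertex_property('bool')
    for n, vertex in enumerate(G.vertices()):
        viz_vertex[vertex] = isolate_names[n] in viz_subset
    G.set_vertex_filter(viz_vertex)          # 'c' is now hidden, not deleted

    G.save('subset.graphml', fmt = 'graphml')
    G.set_vertex_filter(None)                # unmask if G is reused afterwards
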
From 4a18c1e083513398b4a8714f39bf404709b7895a Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 13 May 2020 22:17:28 +0100
Subject: [PATCH 17/56] Enable querying and pruning of networks

---
 PopPUNK/__main__.py |  15 ++---
 PopPUNK/mash.py     |  15 +++--
 PopPUNK/network.py  | 132 +++++++++++++++++++++++++++-----------------
 PopPUNK/plot.py     |   2 +-
 PopPUNK/utils.py    |  19 ++++++-
 5 files changed, 119 insertions(+), 64 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index db2958d6..2521fb70 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -569,9 +569,9 @@ def main():
         #******************************# 
         # extract limited references from clique by default
         if not args.full_db:
-            newReferencesIndices, newReferencesNames, newReferencesFile = extractReferences(genomeNetwork, refList, args.output)
+            newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = extractReferences(genomeNetwork, refList, args.output)
             nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
-            genomeNetwork.remove_vertex(list(nodes_to_remove))
+#            genomeNetwork.remove_vertex(list(nodes_to_remove))
             names_to_remove = [refList[n] for n in nodes_to_remove]
             prune_distance_matrix(refList, names_to_remove, distMat,
                                   args.output + "/" + os.path.basename(args.output) + ".dists")
@@ -796,7 +796,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
 
         # Calculate query-query distances
         ordered_queryList = []
-        
+
         # Assign to strains or lineages, as requested
         if assign_lineage:
 
@@ -854,8 +854,9 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             # only update network if assigning to strains
             if full_db is False and assign_lineage is False:
                 mashOrder = refList + ordered_queryList
-                newRepresentativesNames, newRepresentativesFile = extractReferences(genomeNetwork, mashOrder, output, refList)
-                genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(newRepresentativesNames))
+                newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, mashOrder, output, refList)
+#                genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(newRepresentativesNames))
+                isolates_to_remove = set(mashOrder).difference(newRepresentativesNames)
                 newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
                 genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
             else:
@@ -865,7 +866,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             if newQueries != queryList and use_mash:
                 tmpRefFile = writeTmpFile(newQueries)
                 constructDatabase(tmpRefFile, kmers, sketch_sizes, output,
-                                    args.estimated_length, True, threads, True) # overwrite old db
+                                    estimated_length, True, threads, True) # overwrite old db
                 os.remove(tmpRefFile)
             # With mash, this is the reduced DB constructed,
             # with sketchlib, all sketches
@@ -892,7 +893,7 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             if full_db is False and assign_lineage is False:
                 # could also have newRepresentativesNames in this diff (should be the same) - but want
                 # to ensure consistency with the network in case of bad input/bugs
-                nodes_to_remove = set(combined_seq).difference(genomeNetwork.nodes)
+                nodes_to_remove = set(combined_seq).difference(newRepresentativesNames)
                 # This function also writes out the new distance matrix
                 postpruning_combined_seq, newDistMat = prune_distance_matrix(combined_seq, nodes_to_remove,
                                                                                 complete_distMat, dists_out)
diff --git a/PopPUNK/mash.py b/PopPUNK/mash.py
index 9b01864a..b88791ef 100644
--- a/PopPUNK/mash.py
+++ b/PopPUNK/mash.py
@@ -541,10 +541,10 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
             # Check mash output is consistent with expected order
             # This is ok in all tests, but best to check and exit in case something changes between mash versions
             expected_names = iterDistRows(refList, qNames, self)
-
             prev_ref = ""
             skip = 0
             skipped = 0
+            
             for line in mashOut:
                 # Skip the first row with self and symmetric elements
                 if skipped < skip:
@@ -601,17 +601,21 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
 
     # run pairwise analyses across kmer lengths, mutating distMat
     # Create range of rows that each thread will work with
+    # if there is only one pair, apply_along_axis will not work
+    if threads > number_pairs:
+        threads = number_pairs
     rows_per_thread = int(number_pairs / threads)
     big_threads = number_pairs % threads
     start = 0
     mat_chunks = []
+
     for thread in range(threads):
         end = start + rows_per_thread
         if thread < big_threads:
             end += 1
         mat_chunks.append((start, end))
         start = end
-
+    
     # create empty distMat that can be shared with multiple processes
     distMat = np.zeros((number_pairs, 2), dtype=raw.dtype)
     with SharedMemoryManager() as smm:
@@ -623,7 +627,7 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
 
         shm_distMat = smm.SharedMemory(size = distMat.nbytes)
         distMat_shared = NumpyShared(name = shm_distMat.name, shape = (number_pairs, 2), dtype = raw.dtype)
-
+        
         # Run regressions
         with Pool(processes = threads) as pool:
             pool.map(partial(fitKmerBlock,
@@ -667,7 +671,10 @@ def fitKmerBlock(idxRanges, distMat, raw, klist, jacobian):
     
     # analyse
     (start, end) = idxRanges
-    distMat[start:end, :] = np.apply_along_axis(fitKmerCurve, 1, raw[start:end, :], klist, jacobian)
+    # np.apply_along_axis cannot handle a single-row matrix, so call fitKmerCurve directly
+    if raw.shape[0] == 1:
+        distMat[start:end, :] = fitKmerCurve(raw[0,:], klist, jacobian)
+    else:
+        distMat[start:end, :] = np.apply_along_axis(fitKmerCurve, 1, raw[start:end, :], klist, jacobian)
 
 
 def fitKmerCurve(pairwise, klist, jacobian):
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index f1131c80..7eaf9cb0 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -25,6 +25,8 @@
 from .utils import listDistInts
 from .utils import readIsolateTypeFromCsv
 from .utils import readRfile
+from .utils import setupDBFuncs
+from .utils import isolateNameToLabel
 
 def fetchNetwork(network_dir, model, refList,
                   core_only = False, accessory_only = False):
@@ -105,73 +107,88 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
     """
     if existingRefs == None:
         references = set()
+        reference_indices = []
     else:
         references = set(existingRefs)
-
+        index_lookup = {v:k for k,v in enumerate(mashOrder)}
+        reference_indices = [index_lookup[r] for r in references]
+    
     # extract cliques from network
-    cliques = [c.tolist() for c in gt.max_cliques(G)]
+    cliques_in_overall_graph = [c.tolist() for c in gt.max_cliques(G)]
     # order list by size of clique
-    cliques.sort(key = len, reverse = True)
+    cliques_in_overall_graph.sort(key = len, reverse = True)
     # iterate through cliques
-    for clique in cliques:
+    for clique in cliques_in_overall_graph:
         alreadyRepresented = 0
         for node in clique:
-            if node in references:
+            if node in reference_indices:
                 alreadyRepresented = 1
                 break
         if alreadyRepresented == 0:
-            references.add(clique[0])
+            reference_indices.append(clique[0])
 
     # Find any clusters which are represented by multiple references
     # First get cluster assignments
-    clusters = printClusters(G, mashOrder, printCSV=False)
+    clusters_in_overall_graph = printClusters(G, mashOrder, printCSV=False)
     # Construct a dict of sets for each cluster
-    ref_clusters = [set() for c in set(clusters.items())]
+    reference_clusters_in_overall_graph = [set() for c in set(clusters_in_overall_graph.items())]
     # Iterate through references
-    for reference_index in references:
+    for reference_index in reference_indices:
         # Add references to the appropriate cluster
-        ref_clusters[clusters[mashOrder[reference_index]]].add(reference_index)
+        reference_clusters_in_overall_graph[clusters_in_overall_graph[mashOrder[reference_index]]].add(reference_index)
 
-    # Make a mask of the existing network to only retain references
-    # Use a vertex filter
+    # Use a vertex filter to extract the subgraph of references
+    # as a graphview
     reference_vertex = G.new_vertex_property('bool')
     for n,vertex in enumerate(G.vertices()):
-        if n in references:
+        if n in reference_indices:
             reference_vertex[vertex] = True
         else:
             reference_vertex[vertex] = False
-    G.set_vertex_filter(reference_vertex)
+#    G.set_vertex_filter(reference_vertex)
+    G_ref = gt.GraphView(G, vfilt = reference_vertex)
+    G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
     # Calculate component membership for reference graph
-    reference_components, reference_component_frequencies = gt.label_components(G)
+#    reference_graph_components, reference_graph_component_frequencies = gt.label_components(G_ref)
+    clusters_in_reference_graph = printClusters(G, mashOrder, printCSV=False)
     # Record to which components references belong in the reference graph
-    reference_cluster = {}
-    for reference_index in references:
-        reference_cluster[reference_index] = reference_components.a[reference_index]
+    reference_clusters_in_reference_graph = {}
+    for reference_index in reference_indices:
+        reference_clusters_in_reference_graph[mashOrder[reference_index]] = clusters_in_reference_graph[mashOrder[reference_index]]
 
     # Unset mask on network for shortest path calculations
-    G.set_vertex_filter(None)
+#    G.set_vertex_filter(None)
 
     # Check if multi-reference components have been split as a validation test
     # First iterate through clusters
-    for cluster in ref_clusters:
+    network_update_required = False
+    for cluster in reference_clusters_in_overall_graph:
         # Identify multi-reference clusters by this length
         if len(cluster) > 1:
             check = list(cluster)
             # check if these are still in the same component in the reference graph
             for i in range(len(check)):
-                component_i = reference_cluster[check[i]]
+                component_i = reference_clusters_in_reference_graph[mashOrder[check[i]]]
                 for j in range(i, len(check)):
                     # Add intermediate nodes
-                    component_j = reference_cluster[check[j]]
+                    component_j = reference_clusters_in_reference_graph[mashOrder[check[j]]]
                     if component_i != component_j:
+                        network_update_required = True
                         vertex_list, edge_list = gt.shortest_path(G, check[i], check[j])
+                        # update reference list
                         for vertex in vertex_list:
-                            references.add(vertex)
+                            reference_vertex[vertex] = True
+                            reference_indices.append(int(vertex))
+    
+    # update reference graph if vertices have been added
+    if network_update_required:
+        G_ref = gt.GraphView(G, vfilt = reference_vertex)
+        G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
     
     # Order found references as in mash sketch files
-    reference_names = [mashOrder[int(x)] for x in sorted(references)]
+    reference_names = [mashOrder[int(x)] for x in sorted(reference_indices)]
     refFileName = writeReferences(reference_names, outPrefix)
-    return references, reference_names, refFileName
+    return reference_indices, reference_names, refFileName, G_ref
 
 def writeReferences(refList, outPrefix):
     """Writes chosen references to file
@@ -244,31 +261,29 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
     # data structures
     connections = []
     self_comparison = True
-    num_vertices = len(rlist)
+    vertex_labels = rlist
     
     # check if self comparison
     if rlist != qlist:
         self_comparison = False
-        num_vertices = num_vertices + len(qlist)
+        vertex_labels = vertex_labels + qlist
     
     # identify edges
     for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qlist, self = self_comparison)):
         if assignment == within_label:
             connections.append((ref, query))
 
-    # issue warning
-    density_proportion = len(connections) / (0.5 * (len(rlist) * (len(rlist) + 1)))
-    if density_proportion > 0.4 or len(connections) > 500000:
-        sys.stderr.write("Warning: trying to create very dense network\n")
-
     # build the graph
     G = gt.Graph(directed = False)
-    G.add_vertex(num_vertices)
+    G.add_vertex(len(vertex_labels))
     G.add_edge_list(connections)
-#    for connection in connections:
-#        G.add_edge(*connection)
 
-    # give some summaries
+    # add isolate ID to network
+    vid = G.new_vertex_property('string',
+                                vals = isolateNameToLabel(vertex_labels))
+    G.vp.id = vid
+
+    # print some summaries
     if summarise:
         (components, density, transitivity, score) = networkSummary(G)
         sys.stderr.write("Network summary:\n" + "\n".join(["\tComponents\t" + str(components),
@@ -349,6 +364,11 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
         distMat (numpy.array)
             Query-query distances
     """
+    # initialise functions
+    readDBParams = dbFuncs['readDBParams']
+    constructDatabase = dbFuncs['constructDatabase']
+    queryDatabase = dbFuncs['queryDatabase']
 
     # initialise links data structure
     new_edges = []
@@ -360,19 +380,24 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
 
     # Set up query names
     qList, qSeqs = readRfile(qfile, oneSeq = use_mash)
-    queryFiles = dict(zip(qList, qSeqs))
     if use_mash == True:
         rNames = None
-        qNames = qSeqs
+        qNames = [i.split('/')[-1].split('.')[0] for i in qSeqs]
+        # mash must use sequence file names for both testing for
+        # assignment and for generating a new database
+        queryFiles = dict(zip(qSeqs, qSeqs))
     else:
         rNames = qList
         qNames = rNames
+        queryFiles = dict(zip(qNames, qSeqs))
 
     # store links for each query in a list of edge tuples
-    for assignment, (ref, query) in zip(assignments, iterDistRows(rlist, qList, self=False)):
+    ref_count = len(rlist)
+    for assignment, (ref, query) in zip(assignments, listDistInts(rlist, qNames, self = False)):
         if assignment == model.within_label:
-            new_edges.append((ref, query))
-            assigned.add(query)
+            # query index needs to be adjusted for existing vertices in network
+            new_edges.append((ref, query + ref_count))
+            assigned.add(qSeqs[query])
 
     # Calculate all query-query distances too, if updating database
     if queryQuery:
@@ -387,15 +412,15 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
                                                         threads)
 
         queryAssignation = model.assign(distMat)
-        for assignment, (ref, query) in zip(queryAssignation, iterDistRows(qlist1, qlist1, self=True)):
+        for assignment, (ref, query) in zip(queryAssignation, listDistInts(qNames, qNames, self = True)):
             if assignment == model.within_label:
-                new_edges.append((ref, query))
+                new_edges.append((ref + ref_count, query + ref_count))
 
     # Otherwise only calculate query-query distances for new clusters
     else:
         # identify potentially new lineages in list: unassigned is a list of queries with no hits
-        unassigned = set(qNames).difference(assigned)
-
+        unassigned = set(qSeqs).difference(assigned)
+        query_indices = {k:v+ref_count for v,k in enumerate(qSeqs)}
         # process unassigned query sequences, if there are any
         if len(unassigned) > 1:
             sys.stderr.write("Found novel query clusters. Calculating distances between them:\n")
@@ -415,8 +440,8 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
 
             # use database construction methods to find links between unassigned queries
             sketchSize = readDBParams(queryDB, kmers, None)[1]
-
             constructDatabase(tmpFile, kmers, sketchSize, tmpDirName, estimated_length, True, threads, False)
+
             qlist1, qlist2, distMat = queryDatabase(rNames = list(unassigned),
                                                     qNames = list(unassigned), 
                                                     dbPrefix = tmpDirName,
@@ -425,20 +450,27 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
                                                     self = True,
                                                     number_plot_fits = 0,
                                                     threads = threads)
+
             queryAssignation = model.assign(distMat)
             
             # identify any links between queries and store in the same links dict
             # links dict now contains lists of links both to original database and new queries
-            for assignment, (query1, query2) in zip(queryAssignation, iterDistRows(qlist1, qlist2, self=True)):
+            # have to use names and link to query list in order to match to node indices
+            for assignment, (query1, query2) in zip(queryAssignation, iterDistRows(qlist1, qlist2, self = True)):
                 if assignment == model.within_label:
-                    new_edges.append((query1, query2))
+                    new_edges.append((query_indices[query1], query_indices[query2]))
 
             # remove directory
             shutil.rmtree(tmpDirName)
 
     # finish by updating the network
-    G.add_nodes_from(qNames)
-    G.add_edges_from(new_edges)
+    G.add_vertex(len(qNames))
+    G.add_edge_list(new_edges)
+    # including the vertex ID property map
+    for i,q in enumerate(qNames):
+        G.vp.id[i + len(rlist)] = q
 
     return qlist1, distMat
 
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index 938e7612..aa5a98bf 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -528,7 +528,7 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
                     else:
                         d['Status'].append("Reference")
             elif output_format == 'cytoscape':
-                d['id'].append(name)
+                d['id'].append(label)
                 for cluster_type in clustering:
                     col_name = cluster_type + suffix
                     d[col_name].append(clustering[cluster_type][name])
diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index ef3ac112..c9a72f71 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -211,7 +211,7 @@ def writeTmpFile(fileList):
     tmpName = mkstemp(suffix=".tmp", dir=".")[1]
     with open(tmpName, 'w') as tmpFile:
         for fileName in fileList:
-            tmpFile.write(fileName + "\n")
+            tmpFile.write(fileName + '\t' + fileName + "\n")
 
     return tmpName
 
@@ -396,7 +396,6 @@ def update_distance_matrices(refList, distMat, queryList = None, query_ref_distM
 
     # if query vs refdb (--assign-query), also include these comparisons
     if queryList is not None:
-
         # query v query - symmetric
         i = len(refList)
         j = len(refList)+1
@@ -555,3 +554,19 @@ def readRfile(rFile, oneSeq=False):
         sys.exit(1)
 
     return (names, sequences)
+
+def isolateNameToLabel(names):
+    """Function to process isolate names to labels
+    appropriate for visualisation.
+    
+    Args:
+        names (list)
+            List of isolate names.
+    Returns:
+        labels (list)
+            List of isolate labels.
+    """
+    # useful to have as a function in case we
+    # want to remove certain characters
+    labels = [name.split('/')[-1].split('.')[0] for name in names]
+    return labels

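extractReferences now derives the pruned reference network non-destructively: a boolean filter marks the reference vertices, a GraphView exposes only that subset, and copying the view with prune = True produces an independent graph with vertices re-indexed from zero. The core of that idiom on toy data:

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(4)
    G.add_edge_list([(0, 1), (1, 2), (2, 3)])

    reference_indices = [0, 2]
    reference_vertex = G.new_vertex_property('bool')
    for n, vertex in enumerate(G.vertices()):
        reference_vertex[vertex] = n in reference_indices

    # the view shares storage with G; the pruned copy is independent
    G_ref = gt.GraphView(G, vfilt = reference_vertex)
    G_ref = gt.Graph(G_ref, prune = True)
    print(G_ref.num_vertices(), G.num_vertices())   # prints "2 4"
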
From a263367e707f0ca2206506fa2996c691d785370a Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 06:40:49 +0100
Subject: [PATCH 18/56] Fix output of names and labels

---
 PopPUNK/__main__.py           |  1 -
 PopPUNK/lineage_clustering.py |  6 +++---
 PopPUNK/network.py            |  6 +++---
 PopPUNK/plot.py               | 15 +++++++--------
 4 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 2521fb70..d9532485 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -855,7 +855,6 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             if full_db is False and assign_lineage is False:
                 mashOrder = refList + ordered_queryList
                 newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, mashOrder, output, refList)
-#                genomeNetwork.remove_nodes_from(set(genomeNetwork.nodes).difference(newRepresentativesNames))
                 isolates_to_remove = set(mashOrder).difference(newRepresentativesNames)
                 newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
                 genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index f3cd4a5b..e470490f 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -197,7 +197,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
         G.add_vertex(len(isolate_list))
         # add sequence labels for visualisation
         vid = G.new_vertex_property('string',
-                                    vals = [i.split('/')[-1].split('.')[0] for i in isolate_list])
+                                    vals = isolate_list)
         G.vp.id = vid
         
         # parallelise neighbour identification for each rank
@@ -269,7 +269,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
             else:
                 overall_lineage = overall_lineage + '-' + str(lineage_assignation[rank][isolate])
         overall_lineages['overall'][isolate] = overall_lineage
-    
+    print('ISOLATES: ' + str(overall_lineages))
     # print output as CSV
     writeClusterCsv(output + "/" + output + '_lineages.csv',
                     isolate_list,
@@ -314,7 +314,7 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
     distance_ranks = np.ndarray(distance_ranks_input.shape, dtype = distance_ranks_input.dtype, buffer = distance_ranks_shm.buf)
     isolate_list = isolates
     isolate_indices = range(0,len(isolate_list))
-    
+
     # load previous scheme
     seeds = {}
     if previous_seeds is not None:
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 7eaf9cb0..f5d52417 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -280,7 +280,7 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
 
     # add isolate ID to network
     vid = G.new_vertex_property('string',
-                                vals = isolateNameToLabel(vertex_labels))
+                                vals = vertex_labels)
     G.vp.id = vid
 
     # print some summaries
@@ -382,7 +382,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
     qList, qSeqs = readRfile(qfile, oneSeq = use_mash)
     if use_mash == True:
         rNames = None
-        qNames = [i.split('/')[-1].split('.')[0] for i in qSeqs]
+        qNames = isolateNameToLabel(qSeqs)
         # mash must use sequence file names for both testing for
         # assignment and for generating a new database
         queryFiles = dict(zip(qSeqs, qSeqs))
@@ -468,7 +468,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
     G.add_vertex(len(qNames))
     G.add_edge_list(new_edges)
     # including the vertex ID property map
-    for i,q in enumerate(qNames):
+    for i,q in enumerate(qSeqs):
         G.vp.id[i + len(rlist)] = q
 
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index aa5a98bf..888b86d0 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -25,6 +25,8 @@
     from sklearn.neighbors.kde import KernelDensity
 import dendropy
 
+from .utils import isolateNameToLabel
+
 def plot_scatter(X, scale, out_prefix, title, kde = True):
     """Draws a 2D scatter plot (png) of the core and accessory distances
 
@@ -401,7 +403,7 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
 
     # Write CSV of metadata
     if writeCsv:
-        seqLabels = [i.split('/')[-1].split('.')[0] for i in isolate_names]
+        seqLabels = isolateNameToLabel(isolate_names)
         writeClusterCsv(outPrefix + "/" + outPrefix + "_cytoscape.csv",
                         isolate_names,
                         seqLabels,
@@ -479,15 +481,12 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, output_format =
     d = defaultdict(list)
     if epiCsv is not None:
         epiData = pd.read_csv(epiCsv, index_col = 0, quotechar='"')
-        epiData.index = [i.split('/')[-1].split('.')[0] for i in epiData.index]
+        epiData.index = isolateNameToLabel(epiData.index)
         for e in epiData.columns.values:
             colnames.append(str(e))
 
     columns_to_be_omitted = []
 
-    # process clustering data
-    nodeLabels = [r.split('/')[-1].split('.')[0] for r in nodeNames]
-
     # get example clustering name for validation
     example_cluster_title = list(clustering.keys())[0]
 
@@ -668,7 +667,7 @@ def outputsForMicroreact(combined_list, coreMat, accMat, clustering, perplexity,
     from .tsne import generate_tsne
 
     # generate sequence labels
-    seqLabels = [r.split('/')[-1].split('.')[0] for r in combined_list]
+    seqLabels = isolateNameToLabel(combined_list)
 
     # check CSV before calculating other outputs
     writeClusterCsv(outPrefix + "/" + os.path.basename(outPrefix) + "_microreact_clusters.csv",
@@ -762,7 +761,7 @@ def outputsForPhandango(combined_list, coreMat, clustering, outPrefix, epiCsv, r
             Avoid regenerating tree if already built for microreact (default = False)
     """
     # generate sequence labels
-    seqLabels = [r.split('/')[-1].split('.')[0] for r in combined_list]
+    seqLabels = isolateNameToLabel(combined_list)
 
     # print clustering file
     writeClusterCsv(outPrefix + "/" + os.path.basename(outPrefix) + "_phandango_clusters.csv",
@@ -813,7 +812,7 @@ def outputsForGrapetree(combined_list, coreMat, clustering, outPrefix, epiCsv, r
             Avoid regenerating tree if already built for microreact (default = False).
     """
     # generate sequence labels
-    seqLabels = [r.split('/')[-1].split('.')[0] for r in combined_list]
+    seqLabels = isolateNameToLabel(combined_list)
 
     # print clustering file
     writeClusterCsv(outPrefix + "/" + os.path.basename(outPrefix) + "_grapetree_clusters.csv",
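
The helper consolidated here lives in PopPUNK/utils.py; judging by the list comprehension it replaces throughout this patch, its behaviour is presumably equivalent to the following minimal sketch (hypothetical standalone name)::

    def isolate_name_to_label(names):
        # strip any directory components and the first file extension
        return [name.split('/')[-1].split('.')[0] for name in names]

    assert isolate_name_to_label(['data/sample1.fa', 'sample2.fasta.gz']) == ['sample1', 'sample2']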

From 9eecdaa119c54ebf4bc93341cf0a449e2a8bb1a9 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 07:08:23 +0100
Subject: [PATCH 19/56] Remove debugging message

---
 PopPUNK/lineage_clustering.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index e470490f..023ae724 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -269,7 +269,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None, isolate_list
             else:
                 overall_lineage = overall_lineage + '-' + str(lineage_assignation[rank][isolate])
         overall_lineages['overall'][isolate] = overall_lineage
-    print('ISOLATES: ' + str(overall_lineages))
+
     # print output as CSV
     writeClusterCsv(output + "/" + output + '_lineages.csv',
                     isolate_list,

From 09e656453dff5627b69026ae6e9ab213cbfa5310 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 09:13:41 +0100
Subject: [PATCH 20/56] Add new dependency of lineage clustering on ref-db to
 tests

---
 test/run_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/run_test.py b/test/run_test.py
index 8f736720..f6bb6577 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -74,7 +74,7 @@
 
 # lineage clustering
 sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")
-subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5 --ref-db example_db", shell=True, check=True)
 
 # assign query to lineages
 sys.stderr.write("Running query assignment (--assign-lineages)\n")
@@ -82,7 +82,7 @@
 
 # lineage clustering with mash
 sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")
-subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db_mash/example_db_mash.dists --output example_lineages_mash --ranks 1,2,3,5 --use-mash", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db_mash/example_db_mash.dists --output example_lineages_mash --ranks 1,2,3,5  --ref-db example_db --use-mash", shell=True, check=True)
 
 # assign query to lineages with mash
 sys.stderr.write("Running query assignment (--assign-lineages)\n")

From 09db4dcb568b5d9fa4d70b450af90d106d4edcf3 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 12:24:21 +0100
Subject: [PATCH 21/56] Add graph-tool to dependencies

---
 environment.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/environment.yml b/environment.yml
index 794e97d9..845327bf 100644
--- a/environment.yml
+++ b/environment.yml
@@ -21,3 +21,4 @@ dependencies:
   - rapidnj
   - h5py
   - pp-sketchlib
+  - graph-tool

From 9520b91623ec81999aa222ca77c174daa1ef75a6 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 13:00:23 +0100
Subject: [PATCH 22/56] Add --overwrite flags for local running of tests


---
 test/run_test.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/test/run_test.py b/test/run_test.py
index f6bb6577..2d939661 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -18,23 +18,23 @@
 
 #easy run
 sys.stderr.write("Running database creation + DBSCAN model fit + fit refinement (--easy-run)\n")
-subprocess.run("python ../poppunk-runner.py --easy-run --r-files references.txt --min-k 13 --k-step 3 --output example_db --full-db", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --easy-run --r-files references.txt --min-k 13 --k-step 3 --output example_db --full-db --overwrite", shell=True, check=True)
 
 #fit GMM
 sys.stderr.write("Running GMM model fit (--fit-model)\n")
-subprocess.run("python ../poppunk-runner.py --fit-model --distances example_db/example_db.dists --ref-db example_db --output example_db --full-db --K 4 --microreact --cytoscape", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --fit-model --distances example_db/example_db.dists --ref-db example_db --output example_db --full-db --K 4 --microreact --cytoscape --overwrite", shell=True, check=True)
 
 #refine model with GMM
 sys.stderr.write("Running model refinement (--refine-model)\n")
-subprocess.run("python ../poppunk-runner.py --refine-model --distances example_db/example_db.dists --ref-db example_db --output example_refine --neg-shift 0.8", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --refine-model --distances example_db/example_db.dists --ref-db example_db --output example_refine --neg-shift 0.8 --overwrite", shell=True, check=True)
 
 #assign query
 sys.stderr.write("Running query assignment (--assign-query)\n")
-subprocess.run("python ../poppunk-runner.py --assign-query --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --output example_query --update-db", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --assign-query --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --output example_query --update-db --overwrite", shell=True, check=True)
 
 #use model
 sys.stderr.write("Running with an existing model (--use-model)\n")
-subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db --model-dir example_db --distances example_db/example_db.dists --output example_use", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db --model-dir example_db --distances example_db/example_db.dists --output example_use --overwrite", shell=True, check=True)
 
 
 # tests with mash backend
@@ -46,23 +46,23 @@
 
 #easy run
 sys.stderr.write("Running database creation + DBSCAN model fit + fit refinement (--easy-run)\n")
-subprocess.run("python ../poppunk-runner.py --easy-run --r-files references.txt --min-k 13 --k-step 3 --output example_db_mash --full-db --no-stream --use-mash --mash " + mash_exec, shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --easy-run --r-files references.txt --min-k 13 --k-step 3 --output example_db_mash --full-db --no-stream --overwrite --use-mash --mash " + mash_exec, shell=True, check=True)
 
 #fit GMM
 sys.stderr.write("Running GMM model fit (--fit-model)\n")
-subprocess.run("python ../poppunk-runner.py --fit-model --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_db_mash --full-db --K 4 --microreact --cytoscape --no-stream --use-mash --mash " + mash_exec, shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --fit-model --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_db_mash --full-db --K 4 --microreact --cytoscape --no-stream --overwrite --use-mash --mash " + mash_exec, shell=True, check=True)
 
 #refine model with GMM
 sys.stderr.write("Running model refinement (--refine-model)\n")
-subprocess.run("python ../poppunk-runner.py --refine-model --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_refine_mash --neg-shift 0.8 --use-mash --mash " + mash_exec, shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --refine-model --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_refine_mash --neg-shift 0.8 --overwrite --use-mash --mash " + mash_exec, shell=True, check=True)
 
 #assign query
 sys.stderr.write("Running query assignment (--assign-query)\n")
-subprocess.run("python ../poppunk-runner.py --assign-query --q-files queries.txt --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_query_mash --update-db --no-stream --use-mash --mash " + mash_exec, shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --assign-query --q-files queries.txt --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_query_mash --update-db --overwrite --no-stream --use-mash --mash " + mash_exec, shell=True, check=True)
 
 #use model
 sys.stderr.write("Running with an existing model (--use-model)\n")
-subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db_mash --model-dir example_db_mash --distances example_db_mash/example_db_mash.dists --output example_use_mash --no-stream --use-mash --mash " + mash_exec, shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db_mash --model-dir example_db_mash --distances example_db_mash/example_db_mash.dists --output example_use_mash --overwrite --no-stream --use-mash --mash " + mash_exec, shell=True, check=True)
 
 
 # general tests
@@ -70,23 +70,23 @@
 
 #generate viz
 sys.stderr.write("Running microreact visualisations (--generate-viz)\n")
-subprocess.run("python ../poppunk-runner.py --generate-viz --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_viz --microreact --subset subset.txt", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --generate-viz --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_viz --microreact --subset subset.txt --overwrite", shell=True, check=True)
 
 # lineage clustering
 sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")
-subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5 --ref-db example_db", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5 --ref-db example_db --overwrite", shell=True, check=True)
 
 # assign query to lineages
 sys.stderr.write("Running query assignment (--assign-lineages)\n")
-subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --existing-scheme example_lineages/example_lineages_lineages.pkl --output example_lineage_query --update-db", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db/example_db.dists --ref-db example_db --existing-scheme example_lineages/example_lineages_lineages.pkl --output example_lineage_query --update-db --overwrite", shell=True, check=True)
 
 # lineage clustering with mash
 sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")
-subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db_mash/example_db_mash.dists --output example_lineages_mash --ranks 1,2,3,5  --ref-db example_db --use-mash", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db_mash/example_db_mash.dists --output example_lineages_mash --ranks 1,2,3,5  --ref-db example_db --use-mash --overwrite", shell=True, check=True)
 
 # assign query to lineages with mash
 sys.stderr.write("Running query assignment (--assign-lineages)\n")
-subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --existing-scheme example_lineages_mash/example_lineages_mash_lineages.pkl --output example_lineage_mash_query --update-db --use-mash", shell=True, check=True)
+subprocess.run("python ../poppunk-runner.py --assign-lineages --q-files queries.txt --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --existing-scheme example_lineages_mash/example_lineages_mash_lineages.pkl --output example_lineage_mash_query --update-db --use-mash --overwrite", shell=True, check=True)
 
 
 # tests of other command line programs (TODO)

From 2235736b554fac7cbeabc2b939ec4e603fa56ea1 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 13:31:29 +0100
Subject: [PATCH 23/56] Change references to query sequences in network
 extension

---
 PopPUNK/network.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index f5d52417..90d82f89 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -381,15 +381,14 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
     # Set up query names
     qList, qSeqs = readRfile(qfile, oneSeq = use_mash)
     if use_mash == True:
-        rNames = None
-        qNames = isolateNameToLabel(qSeqs)
         # mash must use sequence file names for both testing for
         # assignment and for generating a new database
-        queryFiles = dict(zip(qSeqs, qSeqs))
+        rNames = None
+        qNames = isolateNameToLabel(qSeqs)
     else:
         rNames = qList
         qNames = rNames
-        queryFiles = dict(zip(qNames, qSeqs))
+    queryFiles = dict(zip(qNames, qSeqs))
 
     # store links for each query in a list of edge tuples
     ref_count = len(rlist)

From 1df2387dd19bec2ee2cee8e8b49ea3bb11f2d0e8 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 13:34:08 +0100
Subject: [PATCH 24/56] Use hash for query sequence name retrieval

---
 PopPUNK/network.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 90d82f89..fa7b3b60 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -396,7 +396,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
         if assignment == model.within_label:
             # query index needs to be adjusted for existing vertices in network
             new_edges.append((ref, query + ref_count))
-            assigned.add(qSeqs[query])
+            assigned.add(queryFiles[qNames[query]])
 
     # Calculate all query-query distances too, if updating database
     if queryQuery:

From a38e58d207ae93cf1c82b8a1eb605518a2ae6520 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 13:53:34 +0100
Subject: [PATCH 25/56] Use list for query sequence retrieval

---
 PopPUNK/network.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index fa7b3b60..f8b0de0a 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -396,7 +396,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
         if assignment == model.within_label:
             # query index needs to be adjusted for existing vertices in network
             new_edges.append((ref, query + ref_count))
-            assigned.add(queryFiles[qNames[query]])
+            assigned.add(qNames[query])
 
     # Calculate all query-query distances too, if updating database
     if queryQuery:
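
The ``query + ref_count`` offset in these two patches assumes reference vertices occupy indices 0 to len(rlist) - 1, so query vertices are appended after them. A small sketch of that bookkeeping with hypothetical assignments::

    ref_count = 3                         # three existing reference vertices
    new_edges = []
    for ref, query in [(0, 0), (2, 1)]:   # hypothetical within-strain matches
        new_edges.append((ref, query + ref_count))
    assert new_edges == [(0, 3), (2, 4)]  # queries become vertices 3 and 4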

From 92956f04c468aac1f0ecd4533d242a5e4f1dbb21 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 14 May 2020 16:23:06 +0100
Subject: [PATCH 26/56] Correct maths of listDistInts

---
 PopPUNK/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index c9a72f71..379ebe1c 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -181,7 +181,7 @@ def listDistInts(refSeqs, querySeqs, self=True):
     num_ref = len(refSeqs)
     num_query = len(querySeqs)
     if self:
-        comparisons = [(0,0)] * (num_ref * (num_ref-1))
+        comparisons = [(0,0)] * int((num_ref * (num_ref-1)) * 0.5)
         if refSeqs != querySeqs:
             raise RuntimeError('refSeqs must equal querySeqs for db building (self = true)')
         for i in range(num_ref):
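
The corrected preallocation matches the number of unordered pairs in a self-comparison, n(n-1)/2, rather than the n(n-1) ordered pairs allocated before. A quick check for n = 4::

    num_ref = 4
    pairs = [(i, j) for i in range(num_ref) for j in range(i + 1, num_ref)]
    assert len(pairs) == int((num_ref * (num_ref - 1)) * 0.5)  # 6 comparisons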

From a6037c36ae1f25e69e3e7843ba287d6af5a47dfd Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Fri, 3 Jul 2020 17:11:58 +0100
Subject: [PATCH 27/56] Remove legacy mash test

---
 test/run_test.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/test/run_test.py b/test/run_test.py
index 277506b4..96743d2b 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -39,10 +39,6 @@
 # general tests
 sys.stderr.write("Running general tests\n\n")
 
-#generate viz
-sys.stderr.write("Running microreact visualisations (--generate-viz)\n")
-subprocess.run("python ../poppunk-runner.py --generate-viz --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_viz --microreact --subset subset.txt --overwrite", shell=True, check=True)
-
 # lineage clustering
 sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")
 subprocess.run("python ../poppunk-runner.py --lineage-clustering --distances example_db/example_db.dists --output example_lineages --ranks 1,2,3,5 --ref-db example_db --overwrite", shell=True, check=True)

From 1a363361e5c853ad907b293d08486f96ce77999d Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Sat, 4 Jul 2020 06:51:52 +0100
Subject: [PATCH 28/56] Fix test file

---
 test/run_test.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/test/run_test.py b/test/run_test.py
index a54c57a0..da5824a1 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -32,14 +32,8 @@
 sys.stderr.write("Running with an existing model (--use-model)\n")
 subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db --model-dir example_db --distances example_db/example_db.dists --output example_use --overwrite", shell=True, check=True)
 
-<<<<<<< HEAD
 # general tests
 sys.stderr.write("Running general tests\n\n")
-=======
-#generate viz
-sys.stderr.write("Running microreact visualisations (--generate-viz)\n")
-subprocess.run("python ../poppunk-runner.py --generate-viz --distances example_db_mash/example_db_mash.dists --ref-db example_db_mash --output example_viz --microreact --subset subset.txt --overwrite", shell=True, check=True)
->>>>>>> a661bd3cf83811a443817f8065e1e9d7b5e41c22
 
 # lineage clustering
 sys.stderr.write("Running lineage clustering test (--lineage-clustering)\n")

From 22f20d9fae2414cf5129b7c4d21f54f8854ef724 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Mon, 13 Jul 2020 15:04:01 +0100
Subject: [PATCH 29/56] Change minimum k step

---
 PopPUNK/__main__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 99618b8b..595b5fc8 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -248,8 +248,8 @@ def main():
     if args.min_k >= args.max_k:
         sys.stderr.write("Minimum kmer size " + str(args.min_k) + " must be smaller than maximum kmer size\n")
         sys.exit(1)
-    elif args.k_step < 2:
-        sys.stderr.write("Kmer size step must be at least two\n")
+    elif args.k_step < 1:
+        sys.stderr.write("Kmer size step must be at least one\n")
         sys.exit(1)
     elif no_sketchlib and (args.min_k < 9 or args.max_k > 31):
         sys.stderr.write("When using Mash, Kmer size must be between 9 and 31\n")

From e00707acbf9ab6a5f3cc3c01e9f0255f41609523 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 13:13:31 +0100
Subject: [PATCH 30/56] Restore generate-viz mode test

---
 test/run_test.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/run_test.py b/test/run_test.py
index da5824a1..13f889b5 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -32,6 +32,10 @@
 sys.stderr.write("Running with an existing model (--use-model)\n")
 subprocess.run("python ../poppunk-runner.py --use-model --ref-db example_db --model-dir example_db --distances example_db/example_db.dists --output example_use --overwrite", shell=True, check=True)
 
+#generate viz
+sys.stderr.write("Running microreact visualisations (--generate-viz)\n")
+subprocess.run("python ../poppunk-runner.py --generate-viz --distances example_db/example_db.dists --ref-db example_db --output example_viz --microreact --subset subset.txt", shell=True, check=True)
+
 # general tests
 sys.stderr.write("Running general tests\n\n")
 

From f6b86f61b0d966d696e42c4651c54fa6d2de8b82 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 13:23:17 +0100
Subject: [PATCH 31/56] Specify graph-tool package as a dependency in
 documentation

---
 docs/conf.py          | 2 +-
 docs/installation.rst | 3 +--
 requirements.txt      | 2 +-
 setup.py              | 2 +-
 4 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index b986e1b9..f19a9d93 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -38,7 +38,7 @@
 # Causes a problem with rtd: https://github.com/pypa/setuptools/issues/1694
 autodoc_mock_imports = ["hdbscan",
                         "numpy",
-                        "networkx",
+                        "graph-tool",
                         "pandas",
                         "scipy",
                         "sklearn",
diff --git a/docs/installation.rst b/docs/installation.rst
index 74ff2091..645ccd92 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -59,7 +59,7 @@ We tested PopPUNK with the following packages:
 * ``DendroPy`` (4.3.0)
 * ``hdbscan`` (0.8.13)
 * ``matplotlib`` (2.1.2)
-* ``networkx`` (2.1)
+* ``graph-tool`` (2.31)
 * ``numpy`` (1.14.1)
 * ``pandas`` (0.22.0)
 * ``scikit-learn`` (0.19.1)
@@ -69,4 +69,3 @@ We tested PopPUNK with the following packages:
 Optionally, you can use `rapidnj <http://birc.au.dk/software/rapidnj/>`__
 if producing output with ``--microreact`` and ``--rapidnj`` options. We used
 v2.3.2.
-
diff --git a/requirements.txt b/requirements.txt
index 4a73dcbe..fbdb5eec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ DendroPy>=4.3.0
 h5py>=2.10.0
 hdbscan>=0.8.13
 matplotlib>=2.1.2
-networkx>=2.1
+graph-tool>=2.31
 numpy>=1.14.1
 pandas>=0.22.0
 scikit-learn>=0.19.1
diff --git a/setup.py b/setup.py
index 1c268b1a..0a38235f 100644
--- a/setup.py
+++ b/setup.py
@@ -69,7 +69,7 @@ def find_version(*file_paths):
                       'scikit-learn',
                       'DendroPy',
                       'pandas',
-                      'networkx>=2.0',
+                      'graph-tool',
                       'matplotlib',
                       'hdbscan'],
     test_suite="test",

From 04df0df0561572c17863809dd9563fd3b3e13330 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 13:32:20 +0100
Subject: [PATCH 32/56] Remove outdated parts from troubleshooting document

---
 docs/troubleshooting.rst | 24 +-----------------------
 1 file changed, 1 insertion(+), 23 deletions(-)

diff --git a/docs/troubleshooting.rst b/docs/troubleshooting.rst
index 6418ce05..b392f21b 100644
--- a/docs/troubleshooting.rst
+++ b/docs/troubleshooting.rst
@@ -10,27 +10,6 @@ installing or running the software please raise an issue on github.
 Error/warning messages
 ----------------------
 
-Errors in graph.py
-^^^^^^^^^^^^^^^^^^
-If you get an ``AttributeError``::
-
-    AttributeError: 'Graph' object has no attribute 'node'
-
-Then your ``networkx`` package is out of date. Its version needs to be at >=v2.0.
-
-Trying to create a very large network
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-When using ``--refine-model`` you may see the message::
-
-    Warning: trying to create very large network
-
-One or more times. This is triggered if :math:`5 \times 10^5` edges or greater than 40%
-of the maximum possible number of edges have been added into the network. This suggests that
-the boundary is too large including too many links as within sample. This isn't necessarily a
-problem as it can occur at the edge of the optimisation range, so will not be the final optimised
-result. However, if you have a large number of samples it may make this step run very slowly
-and/or use a lot of memory. If that is the case, decrease ``--pos-shift``.
-
 Row name mismatch
 ^^^^^^^^^^^^^^^^^
 PopPUNK may throw::
@@ -236,7 +215,7 @@ Finding which isolates contribute to these distances reveals a clear culprit::
        1 14412_4_10
       28 14412_4_15
 
-In this case it is sufficent to increase the number of mixture components to four,
+In this case it is sufficient to increase the number of mixture components to four,
 which no longer includes these inflated distances. This gives a score of 0.9401 and 28 components:
 
 .. image:: images/contam_DPGMM_better_fit.png
@@ -301,4 +280,3 @@ resources. Here are some tips based on these experiences:
 
 Another option for scaling is to run ``--create-db`` with a smaller initial set (not
 using the ``--full-db`` command), then use ``--assign-query`` to add to this.
-

From 651586feaca82aaa8e0fc543026ab9b23f4e7141 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 13:43:30 +0100
Subject: [PATCH 33/56] Update docstrings for graph-tool

---
 PopPUNK/network.py | 16 ++++++++--------
 PopPUNK/plot.py    |  2 +-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 43a85b08..d2b5974f 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -32,8 +32,8 @@ def fetchNetwork(network_dir, model, refList,
                   core_only = False, accessory_only = False):
     """Load the network based on input options
 
-       Returns the network as a networkx, and sets the slope parameter of
-       the passed model object.
+       Returns the network as a graph-tool format graph, and sets
+       the slope parameter of the passed model object.
 
        Args:
             network_dir (str)
@@ -52,7 +52,7 @@ def fetchNetwork(network_dir, model, refList,
                 [default = False]
 
        Returns:
-            genomeNetwork (nx.Graph)
+            genomeNetwork (graph)
                 The loaded network
             cluster_file (str)
                 The CSV of cluster assignments corresponding to this network
@@ -90,7 +90,7 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
        Writes chosen references to file by calling :func:`~writeReferences`
 
        Args:
-           G (networkx.Graph)
+           G (graph)
                A network used to define clusters from :func:`~constructNetwork`
            mashOrder (list)
                The order of files in the sketches, so returned references are in the same order
@@ -255,7 +255,7 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
             (default = True)
 
     Returns:
-        G (networkx.Graph)
+        G (graph)
             The resulting network
     """
     # data structures
@@ -298,7 +298,7 @@ def networkSummary(G):
     """Provides summary values about the network
 
     Args:
-        G (networkx.Graph)
+        G (graph)
             The network of strains from :func:`~constructNetwork`
 
     Returns:
@@ -332,7 +332,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
             List of reference names
         qfile (str)
             File containing queries
-        G (networkx.Graph)
+        G (graph)
             Network to add to (mutated)
         kmers (list)
             List of k-mer sizes
@@ -480,7 +480,7 @@ def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
     Also writes assignments to a CSV file
 
     Args:
-        G (networkx.Graph)
+        G (graph)
             Network used to define clusters (from :func:`~constructNetwork` or
             :func:`~addQueryToNetwork`)
         outPrefix (str)
diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py
index 888b86d0..279a89fe 100644
--- a/PopPUNK/plot.py
+++ b/PopPUNK/plot.py
@@ -362,7 +362,7 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
     """Write outputs for cytoscape. A graphml of the network, and CSV with metadata
 
     Args:
-        G (networkx.Graph)
+        G (graph)
             The network to write from :func:`~PopPUNK.network.constructNetwork`
         clustering (dict)
             Dictionary of cluster assignments (keys are nodeNames).

From 1d5766b8231f2215b4060dadac43c5c427c77a78 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 14:14:58 +0100
Subject: [PATCH 34/56] Update PopPUNK/__main__.py

Add file name correction.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/__main__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 595b5fc8..24896cf2 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -712,7 +712,7 @@ def main():
                     for rank in isolateClustering.keys():
                         numeric_rank = rank.split('_')[1]
                         if numeric_rank.isdigit():
-                            genomeNetwork = gt.load_graph(args.ref_db + '/' + args.ref_db + '_rank_' + str(numeric_rank) + '_lineages.gt')
+                            genomeNetwork = gt.load_graph(args.ref_db + '/' + os.path.basename(args.ref_db) + '_rank_' + str(numeric_rank) + '_lineages.gt')
                             outputsForCytoscape(genomeNetwork, isolateClustering, args.output,
                                         args.info_csv, suffix = 'rank_' + str(rank), viz_subset = viz_subset)
                 else:
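
The correction matters whenever ``--ref-db`` includes a directory component, as naive prefix duplication repeats the directory in the path. A worked example with a hypothetical database location::

    import os

    ref_db = 'runs/example_db'
    broken = ref_db + '/' + ref_db + '_rank_1_lineages.gt'
    # 'runs/example_db/runs/example_db_rank_1_lineages.gt' - no such file
    fixed = ref_db + '/' + os.path.basename(ref_db) + '_rank_1_lineages.gt'
    assert fixed == 'runs/example_db/example_db_rank_1_lineages.gt'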

From ad04c72b79729eef36ea74099439c7cbd7c2d1a8 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 14:15:46 +0100
Subject: [PATCH 35/56] Update PopPUNK/__main__.py

Remove commented line.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/__main__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 24896cf2..b89900c9 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -575,7 +575,7 @@ def main():
         if not args.full_db:
             newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = extractReferences(genomeNetwork, refList, args.output)
             nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
-#            genomeNetwork.remove_vertex(list(nodes_to_remove))
+#            
             names_to_remove = [refList[n] for n in nodes_to_remove]
             prune_distance_matrix(refList, names_to_remove, distMat,
                                   args.output + "/" + os.path.basename(args.output) + ".dists")

From 82e245a354d4141b4e1a36194f6d117b772ae5ac Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 14:25:57 +0100
Subject: [PATCH 36/56] Update PopPUNK/__main__.py

Reverting merge error.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/__main__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index b89900c9..04d7dfb8 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -657,7 +657,9 @@ def main():
                 postpruning_combined_seq, newDistMat = prune_distance_matrix(rlist, isolates_to_remove,
                                                                       complete_distMat, dists_out)
 
-            combined_seq, core_distMat, acc_distMat = update_distance_matrices(viz_subset, newDistMat)
+            combined_seq, core_distMat, acc_distMat = \	
+                update_distance_matrices(viz_subset, newDistMat,	
+                                         threads = args.threads)
 
             # reorder subset to ensure list orders match
             try:

From 46cc3624538e9978bcbeae52978fd070f2642fcf Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 14:27:53 +0100
Subject: [PATCH 37/56] Remove debug file printing

---
 PopPUNK/network.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index d2b5974f..103f879d 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -463,13 +463,12 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
             shutil.rmtree(tmpDirName)
 
     # finish by updating the network
-    G.save('before.graphml',fmt='graphml')
     G.add_vertex(len(qNames))
     G.add_edge_list(new_edges)
+    
     # including the vertex ID property map
     for i,q in enumerate(qSeqs):
         G.vp.id[i + len(rlist)] = q
-    G.save('after.graphml',fmt='graphml')
 
     return qlist1, distMat
 

From c6167e8fb9519f534db95668f811567d07f2c9eb Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 15:17:28 +0100
Subject: [PATCH 38/56] Update PopPUNK/mash.py

Remove whitespace.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/mash.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/PopPUNK/mash.py b/PopPUNK/mash.py
index b88791ef..0aa551e5 100644
--- a/PopPUNK/mash.py
+++ b/PopPUNK/mash.py
@@ -627,7 +627,6 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
 
         shm_distMat = smm.SharedMemory(size = distMat.nbytes)
         distMat_shared = NumpyShared(name = shm_distMat.name, shape = (number_pairs, 2), dtype = raw.dtype)
-        
         # Run regressions
         with Pool(processes = threads) as pool:
             pool.map(partial(fitKmerBlock,
@@ -713,4 +712,3 @@ def fitKmerCurve(pairwise, klist, jacobian):
 
     # Return core, accessory
     return(np.flipud(transformed_params))
-

From 3fb00065788291e4b99733c1d0335cb949f00211 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 15:42:40 +0100
Subject: [PATCH 39/56] Whitespace removed

---
 PopPUNK/__main__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 04d7dfb8..24bf5b0b 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -657,7 +657,7 @@ def main():
                 postpruning_combined_seq, newDistMat = prune_distance_matrix(rlist, isolates_to_remove,
                                                                       complete_distMat, dists_out)
 
-            combined_seq, core_distMat, acc_distMat = \	
+            combined_seq, core_distMat, acc_distMat = \
                 update_distance_matrices(viz_subset, newDistMat,	
                                          threads = args.threads)
 

From 8e29c1e80a9e27b20e3d031c768c7b83cfe8d9d8 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 15:43:55 +0100
Subject: [PATCH 40/56] Update PopPUNK/lineage_clustering.py

Remove commented line.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/lineage_clustering.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index 90ad2b2f..9e340995 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -311,7 +311,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None,
                     component_name[component_number] = overall_lineage_seeds[rank][seed]
             # name remaining components in rank order
             for component_rank in range(len(component_frequency_ranks)):
-#                component_number = component_frequency_ranks[np.where(component_frequency_ranks == component_rank)]
+#                
                 component_number = component_frequency_ranks.index(component_rank)
                 if component_name[component_number] is None:
                     component_name[component_number] = max_existing_cluster[rank]

From c040518b18cce9696c99bf4fd5bf798c6ffa8db8 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 15:44:19 +0100
Subject: [PATCH 41/56] Update PopPUNK/lineage_clustering.py

Fix file naming.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/lineage_clustering.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index 9e340995..d0830758 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -328,7 +328,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None,
                 renamed_component = component_name[original_component]
                 lineage_assignation[rank][isolate_name] = renamed_component
             # save network
-            G.save(file_name = output + "/" + output + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
+            G.save(file_name = output + "/" + os.path.basename(output) + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
             # clear edges
             G.clear_edges()
 

From 0f8d1f981c9cfa6607d03ceeb68ec7ace67d9781 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 15:50:59 +0100
Subject: [PATCH 42/56] Change default lineage cluster

---
 PopPUNK/lineage_clustering.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index 90ad2b2f..8ce0816d 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -409,15 +409,15 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
     # iteratively identify lineages
     lineage_index = 1
     connections = set()
-    lineage_assignation = {isolate:0 for isolate in isolate_list}
+    lineage_assignation = {isolate:None for isolate in isolate_list}
     
-    while 0 in lineage_assignation.values():
+    while None in lineage_assignation.values():
         if lineage_index in seeds.keys():
             seed_isolate = seeds[lineage_index]
         else:
             seed_isolate = pick_seed_isolate(lineage_assignation, distances = distances_input)
         # skip over previously-defined seeds if amalgamated into different lineage now
-        if lineage_assignation[seed_isolate] == 0:
+        if lineage_assignation[seed_isolate] is None:
             seeds[lineage_index] = seed_isolate
             lineage_assignation, added_connections = get_lineage(lineage_assignation, nn, seed_isolate, lineage_index)
             connections.update(added_connections)
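
Using ``None`` rather than 0 as the "unassigned" sentinel removes any possible collision with an integer lineage index. A toy sketch of the seeding loop under that convention (real runs pick seeds by minimum distance from shared memory arrays)::

    lineage_assignation = {iso: None for iso in ['a', 'b', 'c']}
    lineage_index = 1
    while None in lineage_assignation.values():
        # stand-in for pick_seed_isolate: take any unassigned isolate
        seed = next(iso for iso, v in lineage_assignation.items() if v is None)
        lineage_assignation[seed] = lineage_index  # a real run expands from the seed
        lineage_index += 1
    assert None not in lineage_assignation.values()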

From be09381063e6e83071ca0dcda98b837c3ecf7996 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 15:53:07 +0100
Subject: [PATCH 43/56] Update PopPUNK/mash.py

Remove whitespace.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/mash.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/PopPUNK/mash.py b/PopPUNK/mash.py
index 0aa551e5..c153195a 100644
--- a/PopPUNK/mash.py
+++ b/PopPUNK/mash.py
@@ -615,7 +615,6 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
             end += 1
         mat_chunks.append((start, end))
         start = end
-    
     # create empty distMat that can be shared with multiple processes
     distMat = np.zeros((number_pairs, 2), dtype=raw.dtype)
     with SharedMemoryManager() as smm:

From 169fe668964e134a42b550447b0e5820fe860aea Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 15:59:42 +0100
Subject: [PATCH 44/56] Tidy up network construction

---
 PopPUNK/__main__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 24bf5b0b..4b7b2c64 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -461,7 +461,7 @@ def main():
                 sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
 
             isolateClustering = {fit_type: printClusters(genomeNetwork,
-                                                        refList, # assume no rlist+qlist?
+                                                         refList,
                                                          args.output + "/" + os.path.basename(args.output),
                                                          externalClusterCSV = args.external_clustering)}
 

From 0ed9bcc3962edcb485126aaf2ad9f66c33ce5a1f Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 16:24:21 +0100
Subject: [PATCH 45/56] Assign local variable more clearly

---
 PopPUNK/__main__.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 4b7b2c64..d4d10497 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -442,10 +442,6 @@ def main():
         if args.use_model:
             assignments = model.assign(distMat)
             model.plot(distMat, assignments)
-
-        if not args.lineage_clustering: # change this once lineage clustering is refined as a model
-            fit_type = 'combined'
-            model.save()
         
         #******************************#
         #*                            *#
@@ -460,6 +456,7 @@ def main():
             if len(networkMissing) > 0:
                 sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
 
+            fit_type = None
             isolateClustering = {fit_type: printClusters(genomeNetwork,
                                                          refList,
                                                          args.output + "/" + os.path.basename(args.output),

From ed2820f5ea15e716f1f5d3da7b4dd10b0bde4344 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 16:41:44 +0100
Subject: [PATCH 46/56] Improve error message for isolates missing from network

---
 PopPUNK/__main__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index d4d10497..fddff76d 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -454,7 +454,8 @@ def main():
             # Ensure all in dists are in final network
             networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
             if len(networkMissing) > 0:
-                sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
+                missing_isolates = [refList[m] for m in networkMissing]
+                sys.stderr.write("WARNING: Samples " + ", ".join(missing_isolates) + " are missing from the final network\n")
 
             fit_type = None
             isolateClustering = {fit_type: printClusters(genomeNetwork,
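
Besides being unreadable, the previous message would have raised a TypeError, since str.join cannot concatenate the integer vertex indices in ``networkMissing``. A minimal sketch of the translation with hypothetical data::

    refList = ['s1', 's2', 's3']
    present = [0, 2]                      # hypothetical vertices in the graph
    networkMissing = set(range(len(refList))).difference(present)
    missing_isolates = [refList[m] for m in networkMissing]
    assert missing_isolates == ['s2']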

From abf0bc3b0a910e11d4cf84c49105e6921c9a91d0 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 16:56:10 +0100
Subject: [PATCH 47/56] Tidy up excess code

---
 PopPUNK/network.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index 103f879d..b871d89a 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -145,20 +145,15 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
             reference_vertex[vertex] = True
         else:
             reference_vertex[vertex] = False
-#    G.set_vertex_filter(reference_vertex)
     G_ref = gt.GraphView(G, vfilt = reference_vertex)
     G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
     # Calculate component membership for reference graph
-#    reference_graph_components, reference_graph_component_frequencies = gt.label_components(G_ref)
     clusters_in_reference_graph = printClusters(G, mashOrder, printCSV=False)
     # Record to which components references belong in the reference graph
     reference_clusters_in_reference_graph = {}
     for reference_index in reference_indices:
         reference_clusters_in_reference_graph[mashOrder[reference_index]] = clusters_in_reference_graph[mashOrder[reference_index]]
 
-    # Unset mask on network for shortest path calculations
-#    G.set_vertex_filter(None)
-
     # Check if multi-reference components have been split as a validation test
     # First iterate through clusters
     network_update_required = False
@@ -635,7 +630,6 @@ def printExternalClusters(newClusters, extClusterFile, outPrefix,
     d = defaultdict(list)
 
     # Read in external clusters
-#    extClusters = readExternalClusters(extClusterFile)
     readIsolateTypeFromCsv(clustCSV, mode = 'external', return_dict = False)
 
     # Go through each cluster (as defined by poppunk) and find the external
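
The vfilt-then-prune idiom kept above is the standard graph-tool way to materialise an induced subgraph with its own compact vertex indices, which is why the commented-out alternatives could go. A minimal sketch::

    import graph_tool.all as gt

    G = gt.Graph(directed = False)
    G.add_vertex(4)
    G.add_edge_list([(0, 1), (1, 2), (2, 3)])
    keep = G.new_vertex_property('bool', vals = [True, True, False, True])
    G_sub = gt.Graph(gt.GraphView(G, vfilt = keep), prune = True)
    assert G_sub.num_vertices() == 3      # vertex 2 and its edges are gone
    assert G_sub.num_edges() == 1         # only (0, 1) survives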

From 06a8648c391794591363101e52968bc50172ccb5 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 16:57:52 +0100
Subject: [PATCH 48/56] Update PopPUNK/__main__.py

Fix file naming.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/__main__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index fddff76d..f87318f5 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -692,7 +692,7 @@ def main():
                 # load clustering
                 cluster_file = args.ref_db + '/' + args.ref_db + '_clusters.csv'
                 isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)
-
+                cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
             # generate selected visualisations
             if args.microreact:
                 sys.stderr.write("Writing microreact output\n")

From 39dda3d059b4a378e415194071b17851c0c9d7ef Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 16:58:16 +0100
Subject: [PATCH 49/56] Update PopPUNK/__main__.py

Fix network file naming.

Co-authored-by: John Lees <lees.john6@gmail.com>
---
 PopPUNK/__main__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index f87318f5..497d24fb 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -518,7 +518,7 @@ def main():
             # load networks
             indivNetworks = {}
             for rank in rank_list:
-                indivNetworks[rank] = gt.load_graph(args.output + "/" + args.output + '_rank_' + str(rank) + '_lineages.gt')
+                indivNetworks[rank] = gt.load_graph(args.output + "/" + os.path.basename(args.output) + '_rank_' + str(rank) + '_lineages.gt')
                 if rank == min(rank_list):
                     genomeNetwork = indivNetworks[rank]
 

From a5a58a8981597733636950f53ad213c85e006c96 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Wed, 15 Jul 2020 17:02:01 +0100
Subject: [PATCH 50/56] Replace mashOrder with dbOrder

---
 PopPUNK/__main__.py |  6 +++---
 PopPUNK/network.py  | 20 ++++++++++----------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 497d24fb..2f15f2ab 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -859,9 +859,9 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
             # Update the network + ref list
             # only update network if assigning to strains
             if full_db is False and assign_lineage is False:
-                mashOrder = refList + ordered_queryList
-                newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, mashOrder, output, refList)
-                isolates_to_remove = set(mashOrder).difference(newRepresentativesNames)
+                dbOrder = refList + ordered_queryList
+                newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, dbOrder, output, refList)
+                isolates_to_remove = set(dbOrder).difference(newRepresentativesNames)
                 newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
                 genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
             else:
diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index b871d89a..d63129b2 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -84,7 +84,7 @@ def fetchNetwork(network_dir, model, refList,
     return (genomeNetwork, cluster_file)
 
 
-def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
+def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
     """Extract references for each cluster based on cliques
 
        Writes chosen references to file by calling :func:`~writeReferences`
@@ -92,7 +92,7 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
        Args:
            G (graph)
                A network used to define clusters from :func:`~constructNetwork`
-           mashOrder (list)
+           dbOrder (list)
                The order of files in the sketches, so returned references are in the same order
            outPrefix (str)
                Prefix for output file (.refs will be appended)
@@ -110,7 +110,7 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
         reference_indices = []
     else:
         references = set(existingRefs)
-        index_lookup = {v:k for k,v in enumerate(mashOrder)}
+        index_lookup = {v:k for k,v in enumerate(dbOrder)}
         reference_indices = [index_lookup[r] for r in references]
     
     # extract cliques from network
@@ -129,13 +129,13 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
 
     # Find any clusters which are represented by multiple references
     # First get cluster assignments
-    clusters_in_overall_graph = printClusters(G, mashOrder, printCSV=False)
+    clusters_in_overall_graph = printClusters(G, dbOrder, printCSV=False)
     # Construct a dict of sets for each cluster
     reference_clusters_in_overall_graph = [set() for c in set(clusters_in_overall_graph.items())]
     # Iterate through references
     for reference_index in reference_indices:
         # Add references to the appropriate cluster
-        reference_clusters_in_overall_graph[clusters_in_overall_graph[mashOrder[reference_index]]].add(reference_index)
+        reference_clusters_in_overall_graph[clusters_in_overall_graph[dbOrder[reference_index]]].add(reference_index)
 
     # Use a vertex filter to extract the subgraph of references
     # as a graphview
@@ -148,11 +148,11 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
     G_ref = gt.GraphView(G, vfilt = reference_vertex)
     G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
     # Calculate component membership for reference graph
-    clusters_in_reference_graph = printClusters(G, mashOrder, printCSV=False)
+    clusters_in_reference_graph = printClusters(G, dbOrder, printCSV=False)
     # Record to which components references belong in the reference graph
     reference_clusters_in_reference_graph = {}
     for reference_index in reference_indices:
-        reference_clusters_in_reference_graph[mashOrder[reference_index]] = clusters_in_reference_graph[mashOrder[reference_index]]
+        reference_clusters_in_reference_graph[dbOrder[reference_index]] = clusters_in_reference_graph[dbOrder[reference_index]]
 
     # Check if multi-reference components have been split as a validation test
     # First iterate through clusters
@@ -163,10 +163,10 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
             check = list(cluster)
             # check if these are still in the same component in the reference graph
             for i in range(len(check)):
-                component_i = reference_clusters_in_reference_graph[mashOrder[check[i]]]
+                component_i = reference_clusters_in_reference_graph[dbOrder[check[i]]]
                 for j in range(i, len(check)):
                     # Add intermediate nodes
-                    component_j = reference_clusters_in_reference_graph[mashOrder[check[j]]]
+                    component_j = reference_clusters_in_reference_graph[dbOrder[check[j]]]
                     if component_i != component_j:
                         network_update_required = True
                         vertex_list, edge_list = gt.shortest_path(G, check[i], check[j])
@@ -181,7 +181,7 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
         G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
     
     # Order found references as in mash sketch files
-    reference_names = [mashOrder[int(x)] for x in sorted(reference_indices)]
+    reference_names = [dbOrder[int(x)] for x in sorted(reference_indices)]
     refFileName = writeReferences(reference_names, outPrefix)
     return reference_indices, reference_names, refFileName, G_ref
 

From 0a299b7f2cbaf703f2724bc20eb17f59067fc41a Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 16 Jul 2020 08:44:46 +0100
Subject: [PATCH 51/56] Reinstate model.save()

---
 PopPUNK/__main__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py
index 2f15f2ab..cccef1af 100644
--- a/PopPUNK/__main__.py
+++ b/PopPUNK/__main__.py
@@ -425,6 +425,8 @@ def main():
                 model = BGMMFit(args.output)
                 assignments = model.fit(distMat, args.K)
                 model.plot(distMat, assignments)
+            # save model
+            model.save()
 
         # Run model refinement
         if args.refine_model or args.threshold or args.easy_run:

From 4ff2a49581e058b38cd845f7c653a1336752934d Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 16 Jul 2020 09:30:44 +0100
Subject: [PATCH 52/56] Update network component extraction code

---
 scripts/poppunk_extract_components.py | 28 +++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/scripts/poppunk_extract_components.py b/scripts/poppunk_extract_components.py
index e8c3be38..842b5675 100755
--- a/scripts/poppunk_extract_components.py
+++ b/scripts/poppunk_extract_components.py
@@ -3,7 +3,8 @@
 # Copyright 2018 John Lees and Nick Croucher
 
 import sys
-import networkx as nx
+import graph_tool.all as gt
+from scipy.stats import rankdata
 import argparse
 
 # command line parsing
@@ -14,8 +15,8 @@ def get_options():
                                      prog='extract_components')
 
     # input options
-    parser.add_argument('graph', help='Input graph pickle (.gpickle)')
-    parser.add_argument('output', help='Prefix for output files')
+    parser.add_argument('--graph', help='Input graph pickle (.gt)')
+    parser.add_argument('--output', help='Prefix for output files')
 
     return parser.parse_args()
 
@@ -25,13 +26,20 @@ def get_options():
     # Check input ok
     args = get_options()
 
-    # open stored distances
-    G = nx.read_gpickle(args.graph)
-    sys.stderr.write("Writing " + str(nx.number_connected_components(G)) + " components "
+    # open stored graph
+    G = gt.load_graph(args.graph)
+    
+    # extract individual components
+    component_assignments, component_frequencies = gt.label_components(G)
+    component_frequency_ranks = len(component_frequencies) - rankdata(component_frequencies, method = 'ordinal').astype(int)
+    sys.stderr.write("Writing " + str(len(component_frequencies)) + " components "
                      "in reverse order of size\n")
 
-    components = sorted(nx.connected_components(G), key=len, reverse=True)
-    for component_idx, component in enumerate(components):
-        nx.write_graphml(G.subgraph(component), args.output + ".component_" + str(component_idx + 1) + ".graphml")
-
+    # extract as GraphView objects and print
+    for component_index in range(len(component_frequency_ranks)):
+        component_gv = gt.GraphView(G, vfilt = component_assignments.a == component_index)
+        component_G = gt.Graph(component_gv, prune = True)
+        component_fn = args.output + ".component_" + str(component_frequency_ranks[component_index]) + ".graphml"
+        component_G.save(component_fn, fmt = 'graphml')
+    
     sys.exit(0)

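The size-ranking line above is terse, so a worked example may help.
rankdata(..., method = 'ordinal') assigns ascending ranks starting from 1,
so subtracting the ranks from the number of components maps the largest
component to 0. A sketch with invented frequencies (not PopPUNK output):

    import numpy as np
    from scipy.stats import rankdata

    # Hypothetical component sizes: component 0 has 3 members, 1 has 7, 2 has 5
    component_frequencies = np.array([3, 7, 5])

    # Ordinal ranks are [1, 3, 2]; subtracting from the length inverts them
    ranks = len(component_frequencies) - rankdata(component_frequencies,
                                                  method = 'ordinal').astype(int)
    print(ranks)  # [2 0 1]: component 1 (size 7) is saved as component_0.graphml

Note that this names the largest component component_0, whereas the networkx
code it replaces numbered components from 1.
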
From 0ac1fd46af554f11aff60af7d9ff256142b9b70e Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 16 Jul 2020 09:47:09 +0100
Subject: [PATCH 53/56] Expand comment to explain network reuse

---
 PopPUNK/lineage_clustering.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PopPUNK/lineage_clustering.py b/PopPUNK/lineage_clustering.py
index 1360f49b..7afe57c6 100644
--- a/PopPUNK/lineage_clustering.py
+++ b/PopPUNK/lineage_clustering.py
@@ -329,7 +329,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None,
                 lineage_assignation[rank][isolate_name] = renamed_component
             # save network
             G.save(file_name = output + "/" + os.path.basename(output) + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
-            # clear edges
+            # clear edges - nodes in graph can be reused but edges differ between ranks
             G.clear_edges()
 
     # store output

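A minimal sketch of the reuse pattern this comment describes, with
hypothetical per-rank edge lists: the vertex set (one vertex per isolate)
persists across ranks, and only the edges are rebuilt.

    import graph_tool.all as gt

    # Hypothetical edges for each rank threshold
    edges_by_rank = {1: [(0, 1), (1, 2)], 2: [(0, 1), (1, 2), (2, 3)]}

    G = gt.Graph(directed = False)
    G.add_vertex(4)  # one vertex per isolate, shared by every rank

    for rank in sorted(edges_by_rank.keys()):
        G.add_edge_list(edges_by_rank[rank])
        G.save('rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
        G.clear_edges()  # keep the vertices, drop the rank-specific edges
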
From 8e8fefa358f6d47fe54a4d2c01f52e721ee0df33 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 16 Jul 2020 09:54:39 +0100
Subject: [PATCH 54/56] Expand explanation in comments

---
 PopPUNK/network.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/PopPUNK/network.py b/PopPUNK/network.py
index d63129b2..41fe4150 100644
--- a/PopPUNK/network.py
+++ b/PopPUNK/network.py
@@ -130,11 +130,12 @@ def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
     # Find any clusters which are represented by multiple references
     # First get cluster assignments
     clusters_in_overall_graph = printClusters(G, dbOrder, printCSV=False)
-    # Construct a dict of sets for each cluster
+    # Construct a dict containing one empty set for each cluster
     reference_clusters_in_overall_graph = [set() for c in set(clusters_in_overall_graph.items())]
     # Iterate through references
     for reference_index in reference_indices:
-        # Add references to the appropriate cluster
+        # Add references to the originally empty set for the appropriate cluster
+        # Allows enumeration of the number of references per cluster
         reference_clusters_in_overall_graph[clusters_in_overall_graph[dbOrder[reference_index]]].add(reference_index)
 
     # Use a vertex filter to extract the subgraph of references

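To make the structure these comments describe concrete, a toy version
follows (isolate names and assignments invented). It allocates with
.values(), one set per distinct cluster, which matches the stated intent;
the .items() call in the hunk above allocates one set per isolate, which
also works but over-allocates.

    # Hypothetical name-to-cluster mapping, in the shape printClusters
    # appears to return (keyed by isolate name)
    clusters_in_overall_graph = {'iso0': 0, 'iso1': 0, 'iso2': 1}
    dbOrder = ['iso0', 'iso1', 'iso2']
    reference_indices = [0, 1, 2]

    # One empty set per cluster; sets de-duplicate reference indices
    reference_clusters = [set() for c in set(clusters_in_overall_graph.values())]
    for reference_index in reference_indices:
        cluster = clusters_in_overall_graph[dbOrder[reference_index]]
        reference_clusters[cluster].add(reference_index)

    print(reference_clusters)  # [{0, 1}, {2}]: cluster 0 has two references
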
From ecbdbe9be8ed4f3ada2f3ec4ae2bd95437dd7f09 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 16 Jul 2020 10:06:51 +0100
Subject: [PATCH 55/56] Convert listDistInts to a generator function

---
 PopPUNK/utils.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index 0cb0603b..a734e889 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -161,7 +161,7 @@ def iterDistRows(refSeqs, querySeqs, self=True):
             for ref in refSeqs:
                 yield(ref, query)
 
-def listDistInts(refSeqs, querySeqs, self=True):
+def old_listDistInts(refSeqs, querySeqs, self=True):
     """Gets the ref and query ID for each row of the distance matrix
 
     Returns an iterable with ref and query ID pairs by row.
@@ -199,6 +199,38 @@ def listDistInts(refSeqs, querySeqs, self=True):
                 
     return comparisons
 
+def listDistInts(refSeqs, querySeqs, self=True):
+    """Gets the ref and query ID for each row of the distance matrix
+
+    Returns an iterable with ref and query ID pairs by row.
+
+    Args:
+        refSeqs (list)
+            List of reference sequence names.
+        querySeqs (list)
+            List of query sequence names.
+        self (bool)
+            Whether a self-comparison, used when constructing a database.
+            Requires refSeqs == querySeqs
+            Default is True
+    Returns:
+        ref, query (int, int)
+            Iterable of tuples with ref and query indices for each distMat row.
+    """
+    num_ref = len(refSeqs)
+    num_query = len(querySeqs)
+    if self:
+        if refSeqs != querySeqs:
+            raise RuntimeError('refSeqs must equal querySeqs for db building (self = True)')
+        for i in range(num_ref):
+            for j in range(i + 1, num_ref):
+                yield(j, i)
+    else:
+        for i in range(num_query):
+            for j in range(num_ref):
+                yield(j, i)
+
+
 def writeTmpFile(fileList):
     """Writes a list to a temporary file. Used for turning variable into mash
     input.

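To make the yielded ordering concrete, a short usage sketch of the generator
(toy names; assumes the patched listDistInts is importable from
PopPUNK.utils):

    from PopPUNK.utils import listDistInts

    refs = ['a', 'b', 'c']

    # Self-comparison: upper-triangle index pairs, yielded as (j, i) with j > i
    print(list(listDistInts(refs, refs, self = True)))
    # [(1, 0), (2, 0), (2, 1)]

    queries = ['q1', 'q2']

    # Ref-vs-query: every ref index paired with every query index
    print(list(listDistInts(refs, queries, self = False)))
    # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]

Because it is a generator, pairs are produced lazily and the full list of
comparisons is never held in memory, unlike the list-building version it
replaces.
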
From 17a3f0ede8ca111a5e8ce9bab6278eec71c7ba53 Mon Sep 17 00:00:00 2001
From: nickjcroucher <n.croucher@imperial.ac.uk>
Date: Thu, 16 Jul 2020 12:06:52 +0100
Subject: [PATCH 56/56] Remove redundant function

---
 PopPUNK/utils.py | 38 --------------------------------------
 1 file changed, 38 deletions(-)

diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py
index a734e889..997bd638 100644
--- a/PopPUNK/utils.py
+++ b/PopPUNK/utils.py
@@ -161,44 +161,6 @@ def iterDistRows(refSeqs, querySeqs, self=True):
             for ref in refSeqs:
                 yield(ref, query)
 
-def old_listDistInts(refSeqs, querySeqs, self=True):
-    """Gets the ref and query ID for each row of the distance matrix
-
-    Returns an iterable with ref and query ID pairs by row.
-
-    Args:
-        refSeqs (list)
-            List of reference sequence names.
-        querySeqs (list)
-            List of query sequence names.
-        self (bool)
-            Whether a self-comparison, used when constructing a database.
-            Requires refSeqs == querySeqs
-            Default is True
-    Returns:
-        ref, query (str, str)
-            Iterable of tuples with ref and query names for each distMat row.
-    """
-    n = 0
-    num_ref = len(refSeqs)
-    num_query = len(querySeqs)
-    if self:
-        comparisons = [(0,0)] * int((num_ref * (num_ref-1)) * 0.5)
-        if refSeqs != querySeqs:
-            raise RuntimeError('refSeqs must equal querySeqs for db building (self = true)')
-        for i in range(num_ref):
-            for j in range(i + 1, num_ref):
-                comparisons[n] = (j, i)
-                n = n + 1
-    else:
-        comparisons = [(0,0)] * (len(refSeqs) * len(querySeqs))
-        for i in range(num_query):
-            for j in range(num_ref):
-                comparisons[n] = (j, i)
-                n = n + 1
-                
-    return comparisons
-
 def listDistInts(refSeqs, querySeqs, self=True):
     """Gets the ref and query ID for each row of the distance matrix