
Graph tool #83
Merged: 62 commits, Jul 16, 2020
Changes from 28 commits

Commits
7260e50
First addition of graph-tools code
nickjcroucher May 7, 2020
09ebf6d
Include listDistInts routine in utils.py
nickjcroucher May 7, 2020
9eb7d84
Functioning model refinement with graph-tools
nickjcroucher May 7, 2020
09aacaf
Update extraction of references from network
nickjcroucher May 8, 2020
e2476af
More efficient extraction of references
nickjcroucher May 8, 2020
027b36d
Remove redundant imports and fix output
nickjcroucher May 8, 2020
5623a67
Switch from multiprocessing to OpenMP parallelisation using graph-tool
nickjcroucher May 8, 2020
88714dd
Fix network loading message
nickjcroucher May 9, 2020
145eee6
Graph loading function updated
nickjcroucher May 9, 2020
e75541f
Update visualisation code
nickjcroucher May 9, 2020
0109e0b
Refactor lineage_clustering code to use graph-tool
nickjcroucher May 12, 2020
188ed40
Update docstrings
nickjcroucher May 12, 2020
a75cb4f
Enable visualisation of lineage networks using Cytoscape
nickjcroucher May 12, 2020
2fc779b
Add extra network printing features
nickjcroucher May 12, 2020
fc3a69a
Change to network-based definitions of lineages
nickjcroucher May 12, 2020
dc042b9
Enable visualisation of networks post-processing
nickjcroucher May 13, 2020
4a18c1e
Enable querying and pruning of networks
nickjcroucher May 13, 2020
a263367
Fix output of names and labels
nickjcroucher May 14, 2020
9eecdaa
Remove debugging message
nickjcroucher May 14, 2020
09e6564
Add new dependency of lineage clustering on ref-db to tests
nickjcroucher May 14, 2020
09db4dc
Add graph-tool to dependencies
nickjcroucher May 14, 2020
9520b91
Overwrite for local running of tests
nickjcroucher May 14, 2020
2235736
Change references to query sequences in network extension
nickjcroucher May 14, 2020
1df2387
Use hash for query sequence name retrieval
nickjcroucher May 14, 2020
a38e58d
Use list for query sequence retrieval
nickjcroucher May 14, 2020
92956f0
Correct maths of listDistInts
nickjcroucher May 14, 2020
14a465f
Merge branch 'sketchlib140' into graph-tool
nickjcroucher Jul 3, 2020
a6037c3
Remove legacy mash test
nickjcroucher Jul 3, 2020
97dffbd
Merge branch 'sketchlib140' into graph-tool
johnlees Jul 3, 2020
a661bd3
Merge remote-tracking branch 'origin/master' into graph-tool
johnlees Jul 3, 2020
957434c
Adjusting test file
nickjcroucher Jul 4, 2020
1a36336
Fix test file
nickjcroucher Jul 4, 2020
22f20d9
Change minimum k step
nickjcroucher Jul 13, 2020
e00707a
Restore generate-viz mode test
nickjcroucher Jul 15, 2020
f6b86f6
Specified graph-tool package as a dependency in documentation
nickjcroucher Jul 15, 2020
04df0df
Removed outdated parts from troubleshooting document
nickjcroucher Jul 15, 2020
651586f
Update docstrings for graph-tool
nickjcroucher Jul 15, 2020
1d5766b
Update PopPUNK/__main__.py
nickjcroucher Jul 15, 2020
ad04c72
Update PopPUNK/__main__.py
nickjcroucher Jul 15, 2020
82e245a
Update PopPUNK/__main__.py
nickjcroucher Jul 15, 2020
46cc362
Remove debug file printing
nickjcroucher Jul 15, 2020
c6167e8
Update PopPUNK/mash.py
nickjcroucher Jul 15, 2020
3fb0006
Whitespace removed
nickjcroucher Jul 15, 2020
8e29c1e
Update PopPUNK/lineage_clustering.py
nickjcroucher Jul 15, 2020
c040518
Update PopPUNK/lineage_clustering.py
nickjcroucher Jul 15, 2020
0f8d1f9
Change default lineage cluster
nickjcroucher Jul 15, 2020
3c495c0
Merge branch 'graph-tool' of https://github.com/johnlees/PopPUNK into…
nickjcroucher Jul 15, 2020
be09381
Update PopPUNK/mash.py
nickjcroucher Jul 15, 2020
169fe66
Tidying up network construction
nickjcroucher Jul 15, 2020
cdacc00
Merge branch 'graph-tool' of https://github.com/johnlees/PopPUNK into…
nickjcroucher Jul 15, 2020
0ed9bcc
Assign local variable more clearly
nickjcroucher Jul 15, 2020
ed2820f
Improve error message for isolates missing from network
nickjcroucher Jul 15, 2020
abf0bc3
Tidying of excess code
nickjcroucher Jul 15, 2020
06a8648
Update PopPUNK/__main__.py
nickjcroucher Jul 15, 2020
39dda3d
Update PopPUNK/__main__.py
nickjcroucher Jul 15, 2020
a5a58a8
Replace mashOrder with dbOrder
nickjcroucher Jul 15, 2020
0a299b7
Reinstate model.save()
nickjcroucher Jul 16, 2020
4ff2a49
Update network component extraction code
nickjcroucher Jul 16, 2020
0ac1fd4
Expand comment to explain network reuse
nickjcroucher Jul 16, 2020
8e8fefa
Expanded explanation in comments
nickjcroucher Jul 16, 2020
ecbdbe9
Convert listDistInts to a generator function
nickjcroucher Jul 16, 2020
17a3f0e
Remove redundant function
nickjcroucher Jul 16, 2020
30 changes: 16 additions & 14 deletions PopPUNK/__main__.py
@@ -425,6 +425,8 @@ def main():
model = BGMMFit(args.output)
assignments = model.fit(distMat, args.K)
model.plot(distMat, assignments)
# save model
model.save()

# Run model refinement
if args.refine_model or args.threshold or args.easy_run:
@@ -442,10 +444,6 @@
if args.use_model:
assignments = model.assign(distMat)
model.plot(distMat, assignments)

if not args.lineage_clustering: # change this once lineage clustering is refined as a model
fit_type = 'combined'
model.save()

#******************************#
#* *#
@@ -458,10 +456,12 @@
# Ensure all in dists are in final network
networkMissing = set(range(len(refList))).difference(list(genomeNetwork.vertices()))
if len(networkMissing) > 0:
sys.stderr.write("WARNING: Samples " + ",".join(networkMissing) + " are missing from the final network\n")
missing_isolates = [refList[m] for m in networkMissing]
sys.stderr.write("WARNING: Samples " + ", ".join(missing_isolates) + " are missing from the final network\n")

fit_type = None
isolateClustering = {fit_type: printClusters(genomeNetwork,
refList, # assume no rlist+qlist?
refList,
args.output + "/" + os.path.basename(args.output),
externalClusterCSV = args.external_clustering)}

@@ -520,7 +520,7 @@ def main():
# load networks
indivNetworks = {}
for rank in rank_list:
indivNetworks[rank] = gt.load_graph(args.output + "/" + args.output + '_rank_' + str(rank) + '_lineages.gt')
indivNetworks[rank] = gt.load_graph(args.output + "/" + os.path.basename(args.output) + '_rank_' + str(rank) + '_lineages.gt')
if rank == min(rank_list):
genomeNetwork = indivNetworks[rank]

@@ -575,7 +575,7 @@ def main():
if not args.full_db:
newReferencesIndices, newReferencesNames, newReferencesFile, genomeNetwork = extractReferences(genomeNetwork, refList, args.output)
nodes_to_remove = set(range(len(refList))).difference(newReferencesIndices)
# genomeNetwork.remove_vertex(list(nodes_to_remove))
#
names_to_remove = [refList[n] for n in nodes_to_remove]
prune_distance_matrix(refList, names_to_remove, distMat,
args.output + "/" + os.path.basename(args.output) + ".dists")
@@ -657,7 +657,9 @@ def main():
postpruning_combined_seq, newDistMat = prune_distance_matrix(rlist, isolates_to_remove,
complete_distMat, dists_out)

combined_seq, core_distMat, acc_distMat = update_distance_matrices(viz_subset, newDistMat)
combined_seq, core_distMat, acc_distMat = \
update_distance_matrices(viz_subset, newDistMat,
threads = args.threads)

# reorder subset to ensure list orders match
try:
@@ -692,7 +694,7 @@ def main():
# load clustering
cluster_file = args.ref_db + '/' + args.ref_db + '_clusters.csv'
isolateClustering = readIsolateTypeFromCsv(cluster_file, mode = 'clusters', return_dict = True)

cluster_file = args.ref_db + '/' + os.path.basename(args.ref_db) + '_clusters.csv'
# generate selected visualisations
if args.microreact:
sys.stderr.write("Writing microreact output\n")
@@ -712,7 +714,7 @@
for rank in isolateClustering.keys():
numeric_rank = rank.split('_')[1]
if numeric_rank.isdigit():
genomeNetwork = gt.load_graph(args.ref_db + '/' + args.ref_db + '_rank_' + str(numeric_rank) + '_lineages.gt')
genomeNetwork = gt.load_graph(args.ref_db + '/' + os.path.basename(args.ref_db) + '_rank_' + str(numeric_rank) + '_lineages.gt')
outputsForCytoscape(genomeNetwork, isolateClustering, args.output,
args.info_csv, suffix = 'rank_' + str(rank), viz_subset = viz_subset)
else:
@@ -859,9 +861,9 @@ def assign_query(dbFuncs, ref_db, q_files, output, update_db, full_db, distances
# Update the network + ref list
# only update network if assigning to strains
if full_db is False and assign_lineage is False:
mashOrder = refList + ordered_queryList
newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, mashOrder, output, refList)
isolates_to_remove = set(mashOrder).difference(newRepresentativesNames)
dbOrder = refList + ordered_queryList
newRepresentativesIndices, newRepresentativesNames, newRepresentativesFile, genomeNetwork = extractReferences(genomeNetwork, dbOrder, output, refList)
isolates_to_remove = set(dbOrder).difference(newRepresentativesNames)
newQueries = [x for x in ordered_queryList if x in frozenset(newRepresentativesNames)] # intersection that maintains order
genomeNetwork.save(output + "/" + os.path.basename(output) + '_graph.gt', fmt = 'gt')
else:
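Several of the `__main__.py` hunks above replace bare `args.output + "/" + args.output` prefixes with `args.output + "/" + os.path.basename(args.output)`, so output file prefixes stay correct when the output directory is given as a nested path. A minimal sketch of the idea (the `prefixed_path` helper is illustrative, not part of PopPUNK):

```python
import os

def prefixed_path(out_dir, suffix):
    # Using basename keeps the file prefix correct even when out_dir is a
    # nested path such as "results/run1"; the naive
    # out_dir + "/" + out_dir + suffix would duplicate the parent directory.
    return os.path.join(out_dir, os.path.basename(out_dir) + suffix)

print(prefixed_path("results/run1", "_clusters.csv"))
# results/run1/run1_clusters.csv (on POSIX systems)
```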
12 changes: 6 additions & 6 deletions PopPUNK/lineage_clustering.py
@@ -311,7 +311,7 @@ def cluster_into_lineages(distMat, rank_list = None, output = None,
component_name[component_number] = overall_lineage_seeds[rank][seed]
# name remaining components in rank order
for component_rank in range(len(component_frequency_ranks)):
# component_number = component_frequency_ranks[np.where(component_frequency_ranks == component_rank)]
#
component_number = component_frequency_ranks.index(component_rank)
if component_name[component_number] is None:
component_name[component_number] = max_existing_cluster[rank]
@@ -328,8 +328,8 @@
renamed_component = component_name[original_component]
lineage_assignation[rank][isolate_name] = renamed_component
# save network
G.save(file_name = output + "/" + output + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
# clear edges
G.save(file_name = output + "/" + os.path.basename(output) + '_rank_' + str(rank) + '_lineages.gt', fmt = 'gt')
# clear edges - nodes in graph can be reused but edges differ between ranks
G.clear_edges()

# store output
@@ -409,15 +409,15 @@ def run_clustering_for_rank(rank, distances_input = None, distance_ranks_input =
# iteratively identify lineages
lineage_index = 1
connections = set()
lineage_assignation = {isolate:0 for isolate in isolate_list}
lineage_assignation = {isolate:None for isolate in isolate_list}

while 0 in lineage_assignation.values():
while None in lineage_assignation.values():
if lineage_index in seeds.keys():
seed_isolate = seeds[lineage_index]
else:
seed_isolate = pick_seed_isolate(lineage_assignation, distances = distances_input)
# skip over previously-defined seeds if amalgamated into different lineage now
if lineage_assignation[seed_isolate] == 0:
if lineage_assignation[seed_isolate] is None:
seeds[lineage_index] = seed_isolate
lineage_assignation, added_connections = get_lineage(lineage_assignation, nn, seed_isolate, lineage_index)
connections.update(added_connections)
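The change from `0` to `None` above gives `run_clustering_for_rank` an unambiguous "unassigned" sentinel that cannot collide with a real lineage index. The overall seed-and-expand loop can be sketched like this (a hypothetical `assign_lineages` over a plain neighbour dict, not the PopPUNK implementation):

```python
def assign_lineages(isolates, neighbours):
    # None (rather than 0) marks unassigned isolates, so any integer
    # lineage index is safe to use as a real assignment
    assignation = {iso: None for iso in isolates}
    lineage = 1
    while None in assignation.values():
        # pick the next unassigned isolate as the seed for this lineage
        seed = next(i for i in isolates if assignation[i] is None)
        stack = [seed]
        while stack:  # flood-fill the seed's neighbourhood
            iso = stack.pop()
            if assignation[iso] is None:
                assignation[iso] = lineage
                stack.extend(neighbours.get(iso, ()))
        lineage += 1
    return assignation

nn = {"a": ["b"], "b": ["a"], "c": []}
print(assign_lineages(["a", "b", "c"], nn))
# {'a': 1, 'b': 1, 'c': 2}
```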
3 changes: 0 additions & 3 deletions PopPUNK/mash.py
@@ -615,7 +615,6 @@ def queryDatabase(rNames, qNames, dbPrefix, queryPrefix, klist, self = True, num
end += 1
mat_chunks.append((start, end))
start = end

# create empty distMat that can be shared with multiple processes
distMat = np.zeros((number_pairs, 2), dtype=raw.dtype)
with SharedMemoryManager() as smm:
@@ -627,7 +626,6 @@

shm_distMat = smm.SharedMemory(size = distMat.nbytes)
distMat_shared = NumpyShared(name = shm_distMat.name, shape = (number_pairs, 2), dtype = raw.dtype)

# Run regressions
with Pool(processes = threads) as pool:
pool.map(partial(fitKmerBlock,
@@ -713,4 +711,3 @@ def fitKmerCurve(pairwise, klist, jacobian):

# Return core, accessory
return(np.flipud(transformed_params))
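The `queryDatabase` context above shares a zeroed `distMat` across worker processes with a `SharedMemoryManager` before `Pool.map` runs the per-chunk regressions. A self-contained sketch of that pattern, with a toy `fill_block` standing in for `fitKmerBlock` (names and chunk sizes are illustrative):

```python
from functools import partial
from multiprocessing import Pool, shared_memory
from multiprocessing.managers import SharedMemoryManager

import numpy as np

def fill_block(bounds, shm_name, shape, dtype):
    # Workers re-attach to the shared block by name and write their own slice
    start, end = bounds
    shm = shared_memory.SharedMemory(name=shm_name)
    mat = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    mat[start:end, :] = np.arange(start, end).reshape(-1, 1)
    shm.close()

def parallel_fill(n_rows, threads=2):
    result = np.zeros((n_rows, 2))
    # split rows into contiguous (start, end) chunks, one per task
    chunks = [(i, min(i + 2, n_rows)) for i in range(0, n_rows, 2)]
    with SharedMemoryManager() as smm:
        shm = smm.SharedMemory(size=result.nbytes)
        mat = np.ndarray(result.shape, dtype=result.dtype, buffer=shm.buf)
        mat[:] = 0.0
        with Pool(processes=threads) as pool:
            pool.map(partial(fill_block, shm_name=shm.name,
                             shape=result.shape, dtype=result.dtype), chunks)
        result[:] = mat  # copy out before the manager frees the block
    return result

print(parallel_fill(4)[:, 0])
```

The copy back into `result` matters: once the `SharedMemoryManager` context exits, the shared block is unlinked and its buffer is no longer valid.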

50 changes: 22 additions & 28 deletions PopPUNK/network.py
@@ -32,8 +32,8 @@ def fetchNetwork(network_dir, model, refList,
core_only = False, accessory_only = False):
"""Load the network based on input options

Returns the network as a networkx, and sets the slope parameter of
the passed model object.
Returns the network as a graph-tool format graph, and sets
the slope parameter of the passed model object.

Args:
network_dir (str)
@@ -52,7 +52,7 @@
[default = False]

Returns:
genomeNetwork (nx.Graph)
genomeNetwork (graph)
The loaded network
cluster_file (str)
The CSV of cluster assignments corresponding to this network
@@ -84,15 +84,15 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
return (genomeNetwork, cluster_file)


def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
def extractReferences(G, dbOrder, outPrefix, existingRefs = None):
"""Extract references for each cluster based on cliques

Writes chosen references to file by calling :func:`~writeReferences`

Args:
G (networkx.Graph)
G (graph)
A network used to define clusters from :func:`~constructNetwork`
mashOrder (list)
dbOrder (list)
The order of files in the sketches, so returned references are in the same order
outPrefix (str)
Prefix for output file (.refs will be appended)
@@ -110,7 +110,7 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
reference_indices = []
else:
references = set(existingRefs)
index_lookup = {v:k for k,v in enumerate(mashOrder)}
index_lookup = {v:k for k,v in enumerate(dbOrder)}
reference_indices = [index_lookup[r] for r in references]

# extract cliques from network
@@ -129,13 +129,14 @@

# Find any clusters which are represented by multiple references
# First get cluster assignments
clusters_in_overall_graph = printClusters(G, mashOrder, printCSV=False)
# Construct a dict of sets for each cluster
clusters_in_overall_graph = printClusters(G, dbOrder, printCSV=False)
# Construct a dict containing one empty set for each cluster
reference_clusters_in_overall_graph = [set() for c in set(clusters_in_overall_graph.items())]
# Iterate through references
for reference_index in reference_indices:
# Add references to the appropriate cluster
reference_clusters_in_overall_graph[clusters_in_overall_graph[mashOrder[reference_index]]].add(reference_index)
# Add references to the originally empty set for the appropriate cluster
# Allows enumeration of the number of references per cluster
reference_clusters_in_overall_graph[clusters_in_overall_graph[dbOrder[reference_index]]].add(reference_index)

# Use a vertex filter to extract the subgraph of references
# as a graphview
@@ -145,19 +146,14 @@
reference_vertex[vertex] = True
else:
reference_vertex[vertex] = False
# G.set_vertex_filter(reference_vertex)
G_ref = gt.GraphView(G, vfilt = reference_vertex)
G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object
# Calculate component membership for reference graph
# reference_graph_components, reference_graph_component_frequencies = gt.label_components(G_ref)
clusters_in_reference_graph = printClusters(G, mashOrder, printCSV=False)
clusters_in_reference_graph = printClusters(G, dbOrder, printCSV=False)
# Record to which components references belong in the reference graph
reference_clusters_in_reference_graph = {}
for reference_index in reference_indices:
reference_clusters_in_reference_graph[mashOrder[reference_index]] = clusters_in_reference_graph[mashOrder[reference_index]]

# Unset mask on network for shortest path calculations
# G.set_vertex_filter(None)
reference_clusters_in_reference_graph[dbOrder[reference_index]] = clusters_in_reference_graph[dbOrder[reference_index]]

# Check if multi-reference components have been split as a validation test
# First iterate through clusters
@@ -168,10 +164,10 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
check = list(cluster)
# check if these are still in the same component in the reference graph
for i in range(len(check)):
component_i = reference_clusters_in_reference_graph[mashOrder[check[i]]]
component_i = reference_clusters_in_reference_graph[dbOrder[check[i]]]
for j in range(i, len(check)):
# Add intermediate nodes
component_j = reference_clusters_in_reference_graph[mashOrder[check[j]]]
component_j = reference_clusters_in_reference_graph[dbOrder[check[j]]]
if component_i != component_j:
network_update_required = True
vertex_list, edge_list = gt.shortest_path(G, check[i], check[j])
@@ -186,7 +182,7 @@ def extractReferences(G, mashOrder, outPrefix, existingRefs = None):
G_ref = gt.Graph(G_ref, prune = True) # https://stackoverflow.com/questions/30839929/graph-tool-graphview-object

# Order found references as in mash sketch files
reference_names = [mashOrder[int(x)] for x in sorted(reference_indices)]
reference_names = [dbOrder[int(x)] for x in sorted(reference_indices)]
refFileName = writeReferences(reference_names, outPrefix)
return reference_indices, reference_names, refFileName, G_ref

@@ -255,7 +251,7 @@ def constructNetwork(rlist, qlist, assignments, within_label, summarise = True):
(default = True)

Returns:
G (networkx.Graph)
G (graph)
The resulting network
"""
# data structures
@@ -298,7 +294,7 @@ def networkSummary(G):
"""Provides summary values about the network

Args:
G (networkx.Graph)
G (graph)
The network of strains from :func:`~constructNetwork`

Returns:
@@ -332,7 +328,7 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
List of reference names
qfile (str)
File containing queries
G (networkx.Graph)
G (graph)
Network to add to (mutated)
kmers (list)
List of k-mer sizes
@@ -463,13 +459,12 @@ def addQueryToNetwork(dbFuncs, rlist, qfile, G, kmers, estimated_length,
shutil.rmtree(tmpDirName)

# finish by updating the network
G.save('before.graphml',fmt='graphml')
G.add_vertex(len(qNames))
G.add_edge_list(new_edges)

# including the vertex ID property map
for i,q in enumerate(qSeqs):
G.vp.id[i + len(rlist)] = q
G.save('after.graphml',fmt='graphml')

return qlist1, distMat

Expand All @@ -480,7 +475,7 @@ def printClusters(G, rlist, outPrefix = "_clusters.csv", oldClusterFile = None,
Also writes assignments to a CSV file

Args:
G (networkx.Graph)
G (graph)
Network used to define clusters (from :func:`~constructNetwork` or
:func:`~addQueryToNetwork`)
outPrefix (str)
@@ -636,7 +631,6 @@ def printExternalClusters(newClusters, extClusterFile, outPrefix,
d = defaultdict(list)

# Read in external clusters
# extClusters = readExternalClusters(extClusterFile)
readIsolateTypeFromCsv(clustCSV, mode = 'external', return_dict = False)

# Go through each cluster (as defined by poppunk) and find the external
2 changes: 1 addition & 1 deletion PopPUNK/plot.py
@@ -362,7 +362,7 @@ def outputsForCytoscape(G, clustering, outPrefix, epiCsv, queryList = None, suff
"""Write outputs for cytoscape. A graphml of the network, and CSV with metadata

Args:
G (networkx.Graph)
G (graph)
The network to write from :func:`~PopPUNK.network.constructNetwork`
clustering (dict)
Dictionary of cluster assignments (keys are nodeNames).
36 changes: 35 additions & 1 deletion PopPUNK/utils.py
Original file line number Diff line number Diff line change
@@ -161,7 +161,7 @@ def iterDistRows(refSeqs, querySeqs, self=True):
for ref in refSeqs:
yield(ref, query)

def listDistInts(refSeqs, querySeqs, self=True):
def old_listDistInts(refSeqs, querySeqs, self=True):
[Review comment] Member: Can you just remove this function

[Reply] Collaborator (Author): Sorry, just hadn't updated a local change, updated in 17a3f0e.

"""Gets the ref and query ID for each row of the distance matrix

Returns an iterable with ref and query ID pairs by row.
@@ -199,6 +199,40 @@ def listDistInts(refSeqs, querySeqs, self=True):

return comparisons

def listDistInts(refSeqs, querySeqs, self=True):
"""Gets the ref and query ID for each row of the distance matrix

Returns an iterable with ref and query ID pairs by row.

Args:
refSeqs (list)
List of reference sequence names.
querySeqs (list)
List of query sequence names.
self (bool)
Whether a self-comparison, used when constructing a database.
Requires refSeqs == querySeqs
Default is True
Returns:
ref, query (str, str)
Iterable of tuples with ref and query names for each distMat row.
"""
num_ref = len(refSeqs)
num_query = len(querySeqs)
if self:
if refSeqs != querySeqs:
raise RuntimeError('refSeqs must equal querySeqs for db building (self = true)')
for i in range(num_ref):
for j in range(i + 1, num_ref):
yield(j, i)
else:
for i in range(num_query):
for j in range(num_ref):
yield(j, i)


def writeTmpFile(fileList):
"""Writes a list to a temporary file. Used for turning variable into mash
input.
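The generator form of `listDistInts` added above yields each `(ref, query)` index pair lazily instead of materialising the full list of comparisons, which matters when the number of pairs is quadratic in the database size. A minimal standalone sketch of the same iteration order (snake_case name is illustrative):

```python
def list_dist_ints(ref_seqs, query_seqs, self=True):
    # Yield (ref_index, query_index) pairs lazily, one distMat row at a time
    if self:
        if ref_seqs != query_seqs:
            raise RuntimeError("refSeqs must equal querySeqs when self=True")
        # upper-triangle pairs only: each unordered pair appears once
        for i in range(len(ref_seqs)):
            for j in range(i + 1, len(ref_seqs)):
                yield j, i
    else:
        # all ref x query pairs
        for i in range(len(query_seqs)):
            for j in range(len(ref_seqs)):
                yield j, i

print(list(list_dist_ints(["a", "b", "c"], ["a", "b", "c"])))
# [(1, 0), (2, 0), (2, 1)]
```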
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -38,7 +38,7 @@
# Causes a problem with rtd: https://github.com/pypa/setuptools/issues/1694
autodoc_mock_imports = ["hdbscan",
"numpy",
"networkx",
"graph-tool",
"pandas",
"scipy",
"sklearn",