-
Notifications
You must be signed in to change notification settings - Fork 10
/
clustermap_triangle.py
82 lines (73 loc) · 2.2 KB
/
clustermap_triangle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from collections import defaultdict
import seaborn
import matplotlib.pyplot as plt
import sys
sys.setrecursionlimit(100000)
from scipy.cluster import hierarchy
import scipy
# Linkage method handed to scipy.cluster.hierarchy.linkage; 'single' and
# 'complete' are the other methods this script has been tried with.
LINKAGE_METHOD = 'average'

# Tick labels are only drawn when there are fewer genomes than this.
MAX_LABELLED_ITEMS = 50


def to_percent_identity(value, source):
    """Convert one matrix entry to a percent identity.

    The conversion is selected by substrings of *source* (the input file
    name):

    * contains 'mash'    -> the value is treated as a distance in [0, 1]
                            and mapped to ``100 - 100 * value``;
    * contains 'fastani' -> the value is used unchanged;
    * otherwise          -> values <= 1 are scaled by 100, larger values
                            are used unchanged.
    """
    if 'mash' in source:
        return 100 - 100 * value
    if 'fastani' in source:
        return value
    if value <= 1:
        return value * 100
    return value


def read_lower_triangle(lines, source, delim='\t'):
    """Parse a lower-triangular matrix from an iterable of text lines.

    The first line is a header: if it splits into more than two
    *delim*-separated fields, the field count is taken as the number of
    items, otherwise its last field is parsed as an integer item count.
    Every following non-blank line is ``label<delim>v1<delim>v2...``; the
    k-th data line contributes its first ``k - 1`` values (anything beyond
    that on the line is ignored). Lines that do not contain *delim* are
    split on arbitrary whitespace instead.

    Args:
        lines: iterable of strings (e.g. an open text file).
        source: input file name, used only to choose the score conversion
            (see ``to_percent_identity``).
        delim: field delimiter.

    Returns:
        ``(labels, columns)`` where ``labels`` holds the last '/'-separated
        component of each row label and ``columns[j]`` holds the percent
        identities between item ``j`` and every later item, in row order.

    Raises:
        ValueError: if a data row has fewer values than its position in
            the triangle requires.
    """
    labels = []
    columns = []
    row_number = 0  # 0 until the header has been consumed
    for line in lines:
        if row_number == 0:
            header = line.split(delim)
            if len(header) > 2:
                items = len(header)
            else:
                items = int(header[-1])
            columns = [[] for _ in range(items)]
            row_number = 1
            continue
        if not line.strip():
            # Blank (typically trailing) lines carry no row.
            continue
        fields = line.split(delim) if delim in line else line.split()
        if len(fields) < row_number:
            raise ValueError(
                "row {} ({!r}) has {} values, expected at least {}".format(
                    row_number, fields[0], len(fields) - 1, row_number - 1))
        # Keep only the file name of a path-like label.
        labels.append(fields[0].split('/')[-1])
        for i in range(1, row_number):
            columns[i - 1].append(to_percent_identity(float(fields[i]), source))
        row_number += 1
    return labels, columns


def to_condensed_distances(columns):
    """Flatten per-column identities into distances ``100 - identity``.

    Columns are emitted in order, so the result lists item 0 against items
    1..n-1, then item 1 against items 2..n-1, and so on.
    """
    return [100 - score for column in columns for score in column]


def main(argv):
    """Cluster the matrix named by ``argv[1]`` and show it as a heatmap.

    ``argv[2]``, when present, is parsed as a float and passed to
    ``seaborn.clustermap`` as ``vmax``.
    """
    if len(argv) < 2:
        sys.exit("usage: {} <matrix_file> [vmax]".format(argv[0]))
    path = argv[1]
    if 'mash' in path:
        print("ANI matrix obtained from Mash detected.")
    if 'fastani' in path:
        print("ANI matrix obtained from FastANI detected.")

    # The context manager guarantees the handle is closed after parsing.
    with open(path, 'r') as handle:
        labels, columns = read_lower_triangle(handle, path)

    condensed = to_condensed_distances(columns)
    if not condensed:
        sys.exit("Need at least two items to cluster; found {}.".format(
            len(labels)))

    Z = hierarchy.linkage(condensed, LINKAGE_METHOD)
    square_mat = scipy.spatial.distance.squareform(condensed)

    options = {'row_linkage': Z, 'col_linkage': Z, 'cmap': seaborn.cm.rocket_r}
    if len(argv) > 2:
        options['vmax'] = float(argv[2])
    cg = seaborn.clustermap(square_mat, **options)

    # Labels in the order the dendrogram arranged the rows/columns.
    ordered_labels = [labels[i] for i in cg.dendrogram_row.reordered_ind]
    if len(labels) < MAX_LABELLED_ITEMS:
        # Heatmap cell i spans [i, i + 1]; put the tick in the middle of the
        # cell so each label lines up with its column rather than its edge.
        cg.ax_heatmap.set_xticks(np.arange(len(labels)) + 0.5)
        cg.ax_heatmap.set_xticklabels(ordered_labels, rotation=90)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main(sys.argv)