forked from Justin0111/V-Net.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmake_graph.py
25 lines (21 loc) · 827 Bytes
/
make_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# From https://gist.github.com/apaszke/01aae7a0494c55af6242f06fad1f8b70
from graphviz import Digraph
from torch.autograd import Variable
def save(fname, creator):
    """Write the autograd graph reachable from ``creator`` as Graphviz DOT source.

    Every node reachable from ``creator`` is emitted once. ``Variable``
    nodes are labelled with their size and filled light blue; any other
    node (an autograd function object) is labelled with its type name. For
    each predecessor ``u`` listed by a node, an edge ``u -> node`` is drawn,
    so shared predecessors receive one edge per consumer.

    Predecessors are read from ``next_functions`` when the node has it and
    from the legacy ``previous_functions`` attribute otherwise. Both are
    iterated as sequences whose items' first element is the predecessor.
    ``None`` predecessors are skipped. This assumes the PyTorch convention
    that ``None`` marks an input with no graph node — TODO confirm for the
    torch version in use.

    Args:
        fname: Output path passed to ``Digraph.save``.
        creator: Root node of the graph to draw (e.g. an output's
            creator / grad_fn).
    """
    dot = Digraph(comment='LRP',
                  node_attr={'style': 'filled', 'shape': 'box'})

    # id(node) -> node. Holding a reference keeps each visited object alive
    # for the whole walk, so the ids used as Graphviz node names cannot be
    # reused by a different, newly created wrapper object.
    seen = {}
    # Explicit stack instead of recursion: deep networks can otherwise
    # exceed Python's default recursion limit.
    stack = [creator]
    while stack:
        var = stack.pop()
        if id(var) in seen:
            continue
        seen[id(var)] = var

        if isinstance(var, Variable):
            dot.node(str(id(var)), str(var.size()), fillcolor='lightblue')
        else:
            dot.node(str(id(var)), type(var).__name__)

        # Prefer the current attribute name; fall back to the legacy one.
        parents = getattr(var, 'next_functions', None)
        if parents is None:
            parents = getattr(var, 'previous_functions', ())

        for entry in parents:
            parent = entry[0]
            if parent is None:
                # Input without a graph node: nothing to draw.
                continue
            dot.edge(str(id(parent)), str(id(var)))
            stack.append(parent)

    dot.save(fname)