Skip to content

Commit

Permalink
improve node graph visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed May 5, 2022
1 parent 9b339b9 commit 4e6f7ac
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
1 change: 0 additions & 1 deletion brainpy/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from .autograd import *
from .controls import *
from .jit import *
# from .parallels import *

# settings
from . import setting
Expand Down
37 changes: 24 additions & 13 deletions brainpy/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,17 +1335,24 @@ def forward(self,

def plot_node_graph(self,
fig_size: tuple = (10, 10),
node_size: int = 2000,
node_size: int = 1000,
arrow_size: int = 20,
layout='shell_layout'):
layout='shell_layout',
show=True,
legends=None,
ax=None):
"""Plot the node graph based on NetworkX package
Parameters
----------
fig_size: tuple, default to (10, 10)
The size of the figure
node_size: int, default to 2000
The size of the node
.. deprecated:: 2.1.9
Please use ``ax`` variable.
node_size: int
The size of the node. default to 1000
arrow_size:int, default to 20
The size of the arrow
layout: str
Expand Down Expand Up @@ -1412,11 +1419,15 @@ def plot_node_graph(self,
raise UnsupportedError(f'Only support layouts: {SUPPORTED_LAYOUTS}')
layout = getattr(nx, layout)(G)

plt.figure(figsize=fig_size)
if ax is None:
from brainpy.visualization.figures import get_figure
fig, gs = get_figure(1, 1, fig_size[1], fig_size[0])
ax = fig.add_subplot(gs[0, 0])
nx.draw_networkx_nodes(G, pos=layout,
nodelist=nodes_trainable,
node_color=trainable_color,
node_size=node_size)
node_size=node_size,
ax=ax)
nx.draw_networkx_nodes(G, pos=layout,
nodelist=nodes_untrainable,
node_color=untrainable_color,
Expand Down Expand Up @@ -1449,12 +1460,10 @@ def plot_node_graph(self,
proxie = []
labels = []
if len(nodes_trainable):
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=trainable_color))
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=trainable_color))
labels.append('Trainable')
if len(nodes_untrainable):
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=untrainable_color))
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=untrainable_color))
labels.append('Nontrainable')
if len(ff_edges):
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
Expand All @@ -1466,9 +1475,11 @@ def plot_node_graph(self,
proxie.append(Line2D([], [], color=rec_color, linewidth=2))
labels.append('Recurrent')

plt.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best')
plt.tight_layout()
plt.show()
legends = dict() if legends is None else legends
ax.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best', **legends)
if show:
plt.tight_layout()
plt.show()


class FrozenNetwork(Network):
Expand Down

0 comments on commit 4e6f7ac

Please sign in to comment.