Skip to content

Commit

Permalink
add graph_invisible_attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 16, 2024
1 parent 8c90098 commit e7d9568
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
</a>
<a href="https://badge.fury.io/py/brainstate"><img alt="PyPI version" src="https://badge.fury.io/py/brainstate.svg"></a>
<a href="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml/badge.svg"></a>
<a href="https://pepy.tech/projects/brainstate"><img src="https://static.pepy.tech/badge/brainstate" alt="PyPI Downloads"></a>
</p>


Expand Down
10 changes: 9 additions & 1 deletion brainstate/graph/_graph_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
- Deepcopy the node.
"""

graph_invisible_attrs = ()

if TYPE_CHECKING:
_trace_state: StateJaxTracer

Expand Down Expand Up @@ -170,7 +173,12 @@ def _to_shape_dtype(value):
def _node_flatten(
node: Node
) -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[Type]]:
nodes = sorted((key, value) for key, value in vars(node).items() if key != '_trace_state')
graph_invisible_attrs = getattr(node, 'graph_invisible_attrs', ())
graph_invisible_attrs = tuple(graph_invisible_attrs) + ('_trace_state',)
nodes = sorted(
(key, value) for key, value in vars(node).items()
if (key not in graph_invisible_attrs)
)
return nodes, (type(node),)


Expand Down
2 changes: 2 additions & 0 deletions brainstate/nn/_dynamics/_dynamics_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class Dynamics(Module):

__module__ = 'brainstate.nn'

graph_invisible_attrs = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')

# before updates
_before_updates: Optional[Dict[Hashable, Callable]]

Expand Down
23 changes: 16 additions & 7 deletions brainstate/nn/_dynamics/_projection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
from __future__ import annotations

from typing import Union, Callable
from typing import Union, Callable, Optional

from brainstate._state import State
from brainstate.mixin import AlignPost, ParamDescriber, BindCondData, JointTypes
Expand Down Expand Up @@ -60,24 +60,28 @@ def is_instance(x, cls) -> bool:
return isinstance(x, cls)


def get_post_repr(syn, out):
return f'{syn.identifier} // {out.identifier}'
def get_post_repr(label, syn, out):
if label is None:
return f'{syn.identifier} // {out.identifier}'
else:
return f'{label}{syn.identifier} // {out.identifier}'


def align_post_add_bef_update(
syn_desc: ParamDescriber[AlignPost],
out_desc: ParamDescriber[BindCondData],
post: Dynamics,
proj_name: str
proj_name: str,
label: str,
):
# synapse and output initialization
_post_repr = get_post_repr(syn_desc, out_desc)
_post_repr = get_post_repr(label, syn_desc, out_desc)
if not post._has_before_update(_post_repr):
syn_cls = syn_desc()
out_cls = out_desc()

# synapse and output initialization
post.add_current_input(proj_name, out_cls)
post.add_current_input(proj_name, out_cls, label=label)
post._add_before_update(_post_repr, _AlignPost(syn_cls, out_cls))
syn = post._get_before_update(_post_repr).syn
out = post._get_before_update(_post_repr).out
Expand Down Expand Up @@ -139,6 +143,7 @@ def __init__(
syn: Union[ParamDescriber[AlignPost], AlignPost],
out: Union[ParamDescriber[SynOut], SynOut],
post: Dynamics,
label: Optional[str] = None,
):
super().__init__(name=get_unique_name(self.__class__.__name__))

Expand Down Expand Up @@ -185,7 +190,11 @@ def __init__(

if merging:
# synapse and output initialization
syn, out = align_post_add_bef_update(syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
syn, out = align_post_add_bef_update(syn_desc=syn,
out_desc=out,
post=post,
proj_name=self.name,
label=label)
else:
post.add_current_input(self.name, out)

Expand Down

0 comments on commit e7d9568

Please sign in to comment.