Skip to content

Commit

Permalink
update package
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 6, 2024
1 parent 2235915 commit cfca8b3
Show file tree
Hide file tree
Showing 23 changed files with 17 additions and 22 deletions.
4 changes: 4 additions & 0 deletions brainstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@
from . import transform
from . import typing
from . import util
from . import surrogate
from . import functional
from . import init
from ._module import *
from ._module import __all__ as _module_all
from ._state import *
from ._state import __all__ as _state_all

__all__ = (
['environ', 'share', 'nn', 'optim', 'random',
'surrogate', 'functional', 'init',
'mixin', 'math', 'transform', 'util', 'typing'] +
_module_all + _state_all
)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 0 additions & 2 deletions brainstate/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================

from . import init, functional, surrogate
from ._base import *
from ._base import __all__ as base_all
from ._connections import *
Expand All @@ -40,7 +39,6 @@
from ._synouts import __all__ as synouts_all

__all__ = (
['init', 'functional', 'surrogate'] +
base_all +
connections_all +
dynamics_all +
Expand Down
3 changes: 1 addition & 2 deletions brainstate/nn/_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
import jax
import jax.numpy as jnp

from . import functional
from . import init
from ._base import DnnLayer
from .. import init, functional
from .._state import ParamState
from ..mixin import Mode
from ..typing import ArrayLike
Expand Down
3 changes: 1 addition & 2 deletions brainstate/nn/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
import jax
import jax.numpy as jnp

from . import init, surrogate
from ._base import ExplicitInOutSize
from ._misc import exp_euler_step
from .. import environ
from .. import environ, init, surrogate
from .._module import Dynamics
from .._state import ShortTermState
from ..mixin import DelayedInit, Mode, AlignPost
Expand Down
3 changes: 1 addition & 2 deletions brainstate/nn/_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import jax.numpy as jnp
import jax.typing

from . import functional as F
from ._base import ElementWiseBlock
from .. import math, environ, random
from .. import math, environ, random, functional as F
from .._module import Module
from .._state import ParamState
from ..mixin import Mode
Expand Down
3 changes: 1 addition & 2 deletions brainstate/nn/_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
import jax
import jax.numpy as jnp

from . import init
from ._base import DnnLayer
from .. import environ
from .. import environ, init
from .._state import LongTermState, ParamState
from ..mixin import Mode
from ..typing import DTypeLike, ArrayLike, Size, Axes
Expand Down
3 changes: 1 addition & 2 deletions brainstate/nn/_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@

import jax.numpy as jnp

from . import init
from ._base import DnnLayer
from .. import random, math, environ, typing
from .. import random, math, environ, typing, init
from ..mixin import Mode

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion brainstate/nn/_projection/_align_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

from typing import Optional, Union

from brainstate._utils import set_module_as
from brainstate._module import (register_delay_of_target,
Projection,
Module,
Dynamics,
ReceiveInputProj,
ExtendedUpdateWithBA)
from brainstate._utils import set_module_as
from brainstate.mixin import (Mode, AllOfTypes, DelayedInitializer, BindCondData, AlignPost, UpdateReturn)
from ._utils import is_instance

Expand Down
2 changes: 1 addition & 1 deletion brainstate/nn/_projection/_align_pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

from typing import Optional, Union

from brainstate._utils import set_module_as
from brainstate._module import (Module, DelayAccess, Projection,
ExtendedUpdateWithBA, ReceiveInputProj,
register_delay_of_target)
from brainstate._utils import set_module_as
from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode, AllOfTypes)
from ._utils import is_instance

Expand Down
2 changes: 1 addition & 1 deletion brainstate/nn/_projection/_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from typing import Optional, Union

from brainstate._utils import set_module_as
from brainstate._module import (Module, Dynamics, Projection, ReceiveInputProj,
UpdateReturn, register_delay_of_target)
from brainstate._utils import set_module_as
from brainstate.mixin import (Mode, BindCondData)
from ._utils import is_instance

Expand Down
2 changes: 1 addition & 1 deletion brainstate/nn/_projection/_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from typing import Optional

from brainstate._utils import set_module_as
from brainstate._module import (Module, Projection, Dynamics, ReceiveInputProj)
from brainstate._utils import set_module_as
from brainstate.mixin import (BindCondData, Mode)
from ._utils import is_instance

Expand Down
5 changes: 2 additions & 3 deletions brainstate/nn/_rate_rnns.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,16 @@

import jax.numpy as jnp

from . import init, functional
from ._base import ExplicitInOutSize
from ._connections import Linear
from .. import random
from .. import random, init, functional
from .._module import Module
from .._state import ShortTermState, ParamState
from ..mixin import DelayedInit, Mode
from ..typing import ArrayLike

__all__ = [
'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
]


Expand Down
5 changes: 2 additions & 3 deletions brainstate/nn/_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
import jax
import jax.numpy as jnp

from . import init, surrogate
from ._base import DnnLayer
from ._misc import exp_euler_step
from ._dynamics import Neuron
from .. import environ
from ._misc import exp_euler_step
from .. import environ, init, surrogate
from .._state import ShortTermState, ParamState
from ..mixin import Mode
from ..typing import Size, ArrayLike, DTypeLike
Expand Down
File renamed without changes.

0 comments on commit cfca8b3

Please sign in to comment.