Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates #550

Merged
merged 9 commits into from
Nov 28, 2023
Merged
52 changes: 29 additions & 23 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import collections
import inspect
import warnings
import numbers
import warnings
from typing import Union, Dict, Callable, Sequence, Optional, Any

import numpy as np
Expand All @@ -13,7 +13,7 @@
from brainpy._src.deprecations import _update_deprecate_msg
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, _get_delay_tool
from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape

__all__ = [
Expand All @@ -27,9 +27,9 @@
'Dynamic', 'Projection',
]


IonChaDyn = None  # placeholder; presumably rebound elsewhere to the ion-channel dynamics class — TODO confirm
SLICE_VARS = 'slice_vars'  # attribute key used when slicing model variables — verify against its users
the_top_layer_reset_state = True  # module flag: True except while a ``_compatible_reset_state`` call is in progress


def not_implemented(fun):
Expand Down Expand Up @@ -138,16 +138,12 @@ def update(self, *args, **kwargs):
"""
raise NotImplementedError('Must implement "update" function by subclass self.')

def reset(self, *args, **kwargs):
  """Reset function which resets all variables in the model (including its children models).

  ``reset()`` is a collective behavior which resets all states in this model.

  See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
  """
  # Imported lazily — presumably to avoid a circular import between this
  # module and ``brainpy._src.helpers``; confirm before hoisting to top level.
  from brainpy._src.helpers import reset_state
  reset_state(self, *args, **kwargs)
Expand All @@ -162,19 +158,6 @@ def reset_state(self, *args, **kwargs):
"""
pass

# raise APIChangedError(
# '''
# From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
#
# 1. If you are resetting all states in a network by calling "net.reset_state()", please use
# "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states
# in a local node (excluded its children nodes).
#
# 2. If you do not customize the "reset_state()" function for a local node, please implement it in your subclass.
#
# '''
# )

def clear_input(self, *args, **kwargs):
"""Clear the input at the current time step."""
pass
Expand Down Expand Up @@ -344,14 +327,37 @@ def _compatible_update(self, *args, **kwargs):
return ret
return update_fun(*args, **kwargs)

def _compatible_reset_state(self, *args, **kwargs):
  # Backward-compatible shim for the old ``net.reset_state()`` API: it
  # redirects the call to the new collective ``reset()`` and then emits a
  # deprecation warning.
  global the_top_layer_reset_state
  # Lower the module flag so that nested attribute lookups of
  # ``reset_state`` during ``self.reset`` resolve to the real per-node
  # method instead of re-entering this shim (see ``__getattribute__``).
  the_top_layer_reset_state = False
  try:
    self.reset(*args, **kwargs)
  finally:
    # Always restore the flag, even if the reset raised.
    the_top_layer_reset_state = True
  warnings.warn(
    '''
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details.

1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use
"bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)".
".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes).

2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.

''',
    DeprecationWarning
  )

def _get_update_fun(self):
return object.__getattribute__(self, 'update')

def __getattribute__(self, item):
  """Intercept attribute access to return backward-compatibility wrappers.

  The stale ``else: return super().__getattribute__(item)`` that previously
  followed the ``update`` branch made the ``reset_state`` branch unreachable;
  it has been removed so both interceptions are live.
  """
  if item == 'update':
    # update function compatible with previous ``update()`` function
    return self._compatible_update
  if item == 'reset_state':
    # Only the top-level lookup is redirected to the deprecation shim;
    # lookups made while a shim call is already in progress (flag is False)
    # fall through to the real per-node ``reset_state``.
    if the_top_layer_reset_state:
      return self._compatible_reset_state
  return super().__getattribute__(item)

def __repr__(self):
return f'{self.name}(mode={self.mode})'
Expand Down
48 changes: 41 additions & 7 deletions brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from typing import Any, Callable, TypeVar, cast

import jax
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge

Expand Down Expand Up @@ -682,7 +683,11 @@ def set_host_device_count(n):
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)


def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    compilation: bool = False
):
  """Clear all on-device buffers.

  This function will be very useful when you call models in a Python loop,
  because it can release device memory that is no longer referenced.

  Parameters
  ----------
  platform: str
    The device to clear its memory. ``None`` selects the default backend.
  array: bool
    Clear all live buffer arrays on the device.
  compilation: bool
    Clear the JIT compilation caches.
  """
  if array:
    # Delete every live on-device buffer held by the selected backend.
    for buf in xla_bridge.get_backend(platform=platform).live_buffers():
      buf.delete()
  if compilation:
    # Drop jax's compilation caches so subsequent calls re-trace/re-compile.
    jax.clear_caches()


def disable_gpu_memory_preallocation():
"""Disable pre-allocating the GPU memory."""
def disable_gpu_memory_preallocation(release_memory: bool = True):
  """Disable pre-allocating the GPU memory.

  This disables the preallocation behavior. JAX will instead allocate GPU memory as needed,
  potentially decreasing the overall memory usage. However, this behavior is more prone to
  GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory
  may OOM with preallocation disabled.

  Args:
    release_memory: bool. Whether we release memory during the computation.
  """
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
  # Only override the allocator when asked: the previous unconditional
  # assignment made ``release_memory=False`` a no-op.
  if release_memory:
    # The 'platform' allocator returns freed blocks to the system eagerly.
    os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'


def enable_gpu_memory_preallocation():
  """Enable pre-allocating the GPU memory."""
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
  # Drop the allocator override if present; the ``None`` default avoids a
  # KeyError when preallocation was never disabled beforehand.
  os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None)


def gpu_memory_preallocation(percent: float):
  """Set the fraction of total GPU memory that JAX preallocates.

  If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory,
  instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.

  Args:
    percent: float. Fraction in the half-open interval ``[0., 1.)``.
  """
  # The condition is half-open; the previous message wrongly advertised
  # the closed interval "[0., 1.]".
  assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.). But we got {percent}.'
  os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent)

13 changes: 0 additions & 13 deletions brainpy/_src/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,19 +519,6 @@ def __subclasscheck__(self, subclass):
return all([issubclass(subclass, cls) for cls in self.__bases__])


class UnionType2(MixIn):
"""Union type for multiple types.

>>> import brainpy as bp
>>>
>>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay])
"""

@classmethod
def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type:
return _MetaUnionType('UnionType', types, {})


if sys.version_info.minor > 8:
class _JointGenericAlias(_UnionGenericAlias, _root=True):
def __subclasscheck__(self, subclass):
Expand Down
7 changes: 7 additions & 0 deletions brainpy/_src/running/pathos_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- ``cpu_unordered_parallel``: Performs a parallel unordered map.
"""

import sys
from collections.abc import Sized
from typing import (Any, Callable, Generator, Iterable, List,
Union, Optional, Sequence, Dict)
Expand All @@ -20,6 +21,8 @@
try:
from pathos.helpers import cpu_count # noqa
from pathos.multiprocessing import ProcessPool # noqa
import multiprocess.context as ctx # noqa
ctx._force_start_method('spawn')
except ModuleNotFoundError:
cpu_count = None
ProcessPool = None
Expand Down Expand Up @@ -63,6 +66,10 @@ def _parallel(
A generator which will apply the function to each element of the given Iterables
in parallel in order with a progress bar.
"""
if sys.platform == 'win32' and sys.version_info.minor >= 11:
raise NotImplementedError('Multiprocessing is not available in Python >=3.11 on Windows. '
'Please use Linux or MacOS, or Windows with Python <= 3.10.')

if ProcessPool is None or cpu_count is None:
raise PackageMissingError(
'''
Expand Down
41 changes: 41 additions & 0 deletions brainpy/_src/running/tests/test_pathos_multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import sys

import jax
import pytest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

# NOTE(review): the ``else`` branch below skips this module unconditionally on
# every platform, so the Windows/3.11 branch above is effectively dead and the
# tests never run. Confirm whether the blanket skip is intentional.
if sys.platform == 'win32' and sys.version_info.minor >= 11:
  pytest.skip('python 3.11 does not support.', allow_module_level=True)
else:
  pytest.skip('Cannot pass tests.', allow_module_level=True)


class TestParallel(parameterized.TestCase):
  """Smoke tests for the pathos-based CPU parallel runners.

  Each test maps the same jitted body over two identical inputs with two
  worker processes and checks both workers produce the same result.
  """

  @parameterized.product(
    duration=[1e2, 1e3, 1e4, 1e5]
  )
  def test_cpu_ordered_parallel(self, duration):
    # Renamed: this case exercises ``cpu_ordered_parallel`` but was
    # previously named ``test_cpu_unordered_parallel_v1``.
    @jax.jit
    def body(inp):
      return bm.for_loop(lambda x: x + 1e-9, inp)

    input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100

    r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2)
    assert bm.allclose(r[0], r[1])

  @parameterized.product(
    duration=[1e2, 1e3, 1e4, 1e5]
  )
  def test_cpu_unordered_parallel(self, duration):
    @jax.jit
    def body(inp):
      return bm.for_loop(lambda x: x + 1e-9, inp)

    input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100

    r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2)
    assert bm.allclose(r[0], r[1])
1 change: 1 addition & 0 deletions brainpy/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
clear_buffer_memory as clear_buffer_memory,
enable_gpu_memory_preallocation as enable_gpu_memory_preallocation,
disable_gpu_memory_preallocation as disable_gpu_memory_preallocation,
gpu_memory_preallocation as gpu_memory_preallocation,
ditype as ditype,
dftype as dftype,
)
2 changes: 1 addition & 1 deletion docs/tutorial_advanced/operator_custom_with_numba.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"collapsed": true
},
"source": [
"# Operator Customization with Numba"
"# CPU Operator Customization with Numba"
]
},
{
Expand Down
11 changes: 10 additions & 1 deletion docs/tutorial_advanced/operator_custom_with_taichi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Operator Customization with Taichi"
"# CPU and GPU Operator Customization with Taichi"
]
},
{
"cell_type": "markdown",
"source": [
"This functionality is only available for ``brainpylib>=0.2.0``. "
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ numba
brainpylib
jax
jaxlib
matplotlib>=3.4
matplotlib
msgpack
tqdm
pathos

# test requirements
pytest
Expand Down
4 changes: 2 additions & 2 deletions requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ msgpack
numba
jax
jaxlib
matplotlib>=3.4
scipy>=1.1.0
matplotlib
scipy
numba

# document requirements
Expand Down
13 changes: 11 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
# installation packages
packages = find_packages(exclude=['lib*', 'docs', 'tests'])


# setup
setup(
name='brainpy',
Expand All @@ -51,13 +50,23 @@
author_email='[email protected]',
packages=packages,
python_requires='>=3.8',
install_requires=['numpy>=1.15', 'jax', 'tqdm', 'msgpack', 'numba'],
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'msgpack', 'numba'],
url='https://github.com/brainpy/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
"Documentation": "https://brainpy.readthedocs.io/",
"Source Code": "https://github.com/brainpy/BrainPy",
},
dependency_links=[
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html',
],
extras_require={
'cpu': ['jaxlib>=0.4.13', 'brainpylib'],
'cuda': ['jax[cuda]', 'brainpylib-cu11x'],
'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'],
'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'],
'tpu': ['jax[tpu]'],
},
keywords=('computational neuroscience, '
'brain-inspired computation, '
'dynamical systems, '
Expand Down