From 54986ea0e5b2d3e923cc3b9d9ec262e696c01538 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 5 Dec 2024 16:56:02 +0800 Subject: [PATCH] Improve error handling in `dnn/linear` module --- brainpy/_src/dnn/linear.py | 16 ++++++++-------- docker/requirements.txt | 2 +- docs/quickstart/installation.rst | 24 ------------------------ 3 files changed, 9 insertions(+), 33 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 06fa9413f..8942375eb 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -12,7 +12,7 @@ from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share -from brainpy._src.dependency_check import import_taichi, import_braintaichi +from brainpy._src.dependency_check import import_taichi, import_braintaichi, raise_braintaichi_not_found from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP from brainpy.check import is_initializer @@ -241,7 +241,7 @@ def update(self, x): return x -if ti is not None: +if ti is not None and bti is not None: # @numba.njit(nogil=True, fastmath=True, parallel=False) # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): @@ -321,7 +321,7 @@ def _dense_on_pre( def dense_on_pre(weight, spike, trace, w_min, w_max): if dense_on_pre_prim is None: - raise PackageMissingError.by_purpose('taichi', 'custom operators') + raise_braintaichi_not_found() if w_min is None: w_min = -np.inf @@ -341,7 +341,7 @@ def dense_on_pre(weight, spike, trace, w_min, w_max): def dense_on_post(weight, spike, trace, w_min, w_max): if dense_on_post_prim is None: - raise PackageMissingError.by_purpose('taichi', 'custom operators') + raise_braintaichi_not_found() if w_min is None: w_min = -np.inf @@ -728,7 +728,7 @@ def _batch_csrmv(self, x): transpose=self.transpose) -if ti is not None: +if ti is not None and bti is not None: @ti.kernel def _csr_on_pre_update( old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) @@ -852,7 +852,7 @@ def _csc_on_post_update( def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): if csr_on_pre_update_prim is None: - raise PackageMissingError.by_purpose('taichi', 'customized operators') + raise_braintaichi_not_found() if w_min is None: w_min = -np.inf @@ -874,7 +874,7 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None): if coo_on_pre_update_prim is None: - raise PackageMissingError.by_purpose('taichi', 'customized operators') + raise_braintaichi_not_found() if w_min is None: w_min = -np.inf @@ -897,7 +897,7 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None): if csc_on_post_update_prim is None: - raise PackageMissingError.by_purpose('taichi', 'customized operators') + raise_braintaichi_not_found() if w_min is None: w_min = -np.inf diff --git a/docker/requirements.txt b/docker/requirements.txt index 460371906..72e9c84da 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -6,10 +6,10 @@ jax jaxlib scipy>=1.1.0 brainpy -brainpylib brainpy_datasets h5py pathos +braintaichi # test requirements pytest diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index 395bf627c..a94e22bf7 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -36,7 +36,6 @@ To install brainpy with minimum requirements (only depends on ``jax``), you can # or - pip install brainpy[cuda11_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0 pip install brainpy[cuda12_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0 # or @@ -64,7 +63,6 @@ To install a GPU-only version of BrainPy, you can run .. code-block:: bash pip install brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0 - pip install brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0 @@ -79,25 +77,3 @@ you can run the following in your cloud TPU VM: pip install brainpy[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # for google TPU - -``brainpylib`` --------------- - - -``brainpylib`` defines a set of useful operators for building and simulating spiking neural networks. - - -To install the ``brainpylib`` package on CPU devices, you can run - -.. code-block:: bash - - pip install brainpylib - - -To install the ``brainpylib`` package on CUDA (Linux only), you can run - - -.. code-block:: bash - - pip install brainpylib -