feat(jax): support neural networks (#4156)
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced JAX support, enhancing functionality and compatibility with the JAX library.
  - Added a new `JAXBackend` class for backend integration with JAX.
  - Added new functions for converting between NumPy and JAX arrays.

- **Bug Fixes**
  - Improved compatibility of neural network layers with the array API standard.

- **Tests**
  - Added tests for JAX functionality and consistency checks against reference outputs.
  - Enhanced the testing framework for activation functions and type embeddings.

- **Chores**
  - Updated dependency requirements to include the JAX library.


---------

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Sep 23, 2024
1 parent f5cfeab commit 0b72dae
Showing 17 changed files with 393 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_cuda.yml
@@ -51,7 +51,7 @@ jobs:
- run: |
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
-          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch] mpi4py
+          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py
env:
DP_VARIANT: cuda
DP_ENABLE_NATIVE_OPTIMIZATION: 1
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
@@ -28,7 +28,7 @@ jobs:
source/install/uv_with_retry.sh pip install --system mpich
source/install/uv_with_retry.sh pip install --system "torch==2.3.0+cpu.cxx11.abi" -i https://download.pytorch.org/whl/
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
-          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test] horovod[tensorflow-cpu] mpi4py
+          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
env:
# Please note that uv has some issues with finding
# existing TensorFlow package. Currently, it uses
110 changes: 110 additions & 0 deletions deepmd/backend/jax.py
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from importlib.util import (
find_spec,
)
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
List,
Type,
)

from deepmd.backend.backend import (
Backend,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


@Backend.register("jax")
class JAXBackend(Backend):
"""JAX backend."""

name = "JAX"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature(0)
# Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
# | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = []
"""The suffixes of the backend."""

def is_available(self) -> bool:
"""Check if the backend is available.
Returns
-------
bool
Whether the backend is available.
"""
return find_spec("jax") is not None

@property
def entry_point_hook(self) -> Callable[["Namespace"], None]:
"""The entry point hook of the backend.
Returns
-------
Callable[[Namespace], None]
The entry point hook of the backend.
"""
raise NotImplementedError

@property
def deep_eval(self) -> Type["DeepEvalBackend"]:
"""The Deep Eval backend of the backend.
Returns
-------
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError

@property
def neighbor_stat(self) -> Type["NeighborStat"]:
"""The neighbor statistics of the backend.
Returns
-------
type[NeighborStat]
The neighbor statistics of the backend.
"""
raise NotImplementedError

@property
def serialize_hook(self) -> Callable[[str], dict]:
"""The serialize hook to convert the model file to a dictionary.
Returns
-------
Callable[[str], dict]
The serialize hook of the backend.
"""
raise NotImplementedError

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
"""The deserialize hook to convert the dictionary to a model file.
Returns
-------
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
raise NotImplementedError
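
Every functional hook above still raises `NotImplementedError`, so for now the class only registers the backend and reports whether it can be used. A minimal sketch of probing it (the print strings are illustrative; the `jax` extra matches the workflow changes above):

```python
# Probe the newly registered JAX backend. Only is_available() does real work
# here; deep_eval, neighbor_stat, and the (de)serialize hooks are still stubs.
from deepmd.backend.jax import JAXBackend

backend = JAXBackend()
if backend.is_available():  # True whenever `import jax` would succeed
    print(f"{backend.name} backend is available")
else:
    print("JAX missing; install it via the new extra: pip install deepmd-kit[jax]")
```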
22 changes: 22 additions & 0 deletions deepmd/dpmodel/common.py
@@ -3,6 +3,10 @@
ABC,
abstractmethod,
)
from typing import (
Any,
Optional,
)

import ml_dtypes
import numpy as np
@@ -59,6 +63,24 @@ def __call__(self, *args, **kwargs):
return self.call(*args, **kwargs)


def to_numpy_array(x: Any) -> Optional[np.ndarray]:
"""Convert an array to a NumPy array.
Parameters
----------
x : Any
The array to be converted.
Returns
-------
Optional[np.ndarray]
The NumPy array.
"""
if x is None:
return None
return np.asarray(x)


__all__ = [
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
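
The new `to_numpy_array` helper lets serialization code (see `NativeLayer.serialize` below) emit plain NumPy arrays no matter which array library produced the weights, with `None` (e.g. an unset `idt`) passing through. A small usage sketch, assuming JAX is installed:

```python
import numpy as np

from deepmd.dpmodel.common import to_numpy_array
from deepmd.jax.env import jnp

assert to_numpy_array(None) is None     # None passes through unchanged
out = to_numpy_array(jnp.ones((2, 3)))  # any array exposing __array__ works
assert isinstance(out, np.ndarray) and out.shape == (2, 3)
```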
50 changes: 38 additions & 12 deletions deepmd/dpmodel/utils/network.py
@@ -15,13 +15,20 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.array_api import (
support_array_api,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.seed import (
child_seed,
)
@@ -105,9 +112,9 @@ def serialize(self) -> dict:
The serialized layer.
"""
data = {
"w": self.w,
"b": self.b,
"idt": self.idt,
"w": to_numpy_array(self.w),
"b": to_numpy_array(self.b),
"idt": to_numpy_array(self.idt),
}
return {
"@class": "Layer",
@@ -215,6 +222,7 @@ def dim_in(self) -> int:
def dim_out(self) -> int:
return self.w.shape[1]

@support_array_api(version="2022.12")
def call(self, x: np.ndarray) -> np.ndarray:
"""Forward pass.
@@ -230,59 +238,77 @@ def call(self, x: np.ndarray) -> np.ndarray:
"""
if self.w is None or self.activation_function is None:
raise ValueError("w, b, and activation_function must be set")
xp = array_api_compat.array_namespace(x)
fn = get_activation_fn(self.activation_function)
y = (
-            np.matmul(x, self.w) + self.b
+            xp.matmul(x, self.w) + self.b
             if self.b is not None
-            else np.matmul(x, self.w)
+            else xp.matmul(x, self.w)
)
y = fn(y)
if self.idt is not None:
y *= self.idt
if self.resnet and self.w.shape[1] == self.w.shape[0]:
y += x
elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
-            y += np.concatenate([x, x], axis=-1)
+            y += xp.concatenate([x, x], axis=-1)
return y


@support_array_api(version="2022.12")
def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
activation_function = activation_function.lower()
if activation_function == "tanh":
-        return np.tanh
+
+        def fn(x):
+            xp = array_api_compat.array_namespace(x)
+            return xp.tanh(x)
+
+        return fn
elif activation_function == "relu":

def fn(x):
xp = array_api_compat.array_namespace(x)
# https://stackoverflow.com/a/47936476/9567349
-            return x * (x > 0)
+            return x * xp.astype(x > 0, x.dtype)

return fn
elif activation_function in ("gelu", "gelu_tf"):

def fn(x):
xp = array_api_compat.array_namespace(x)
# generated by GitHub Copilot
-            return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+            return (
+                0.5
+                * x
+                * (1 + xp.tanh(xp.sqrt(xp.asarray(2 / xp.pi)) * (x + 0.044715 * x**3)))
+            )

return fn
elif activation_function == "relu6":

def fn(x):
xp = array_api_compat.array_namespace(x)
# generated by GitHub Copilot
-            return np.minimum(np.maximum(x, 0), 6)
+            return xp.where(
+                x < 0, xp.full_like(x, 0), xp.where(x > 6, xp.full_like(x, 6), x)
+            )

return fn
elif activation_function == "softplus":

def fn(x):
xp = array_api_compat.array_namespace(x)
# generated by GitHub Copilot
-            return np.log(1 + np.exp(x))
+            return xp.log(1 + xp.exp(x))

return fn
elif activation_function == "sigmoid":

def fn(x):
xp = array_api_compat.array_namespace(x)
# generated by GitHub Copilot
-            return 1 / (1 + np.exp(-x))
+            return 1 / (1 + xp.exp(-x))

return fn
elif activation_function.lower() in ("none", "linear"):
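
Every branch follows the same pattern: the namespace `xp` is resolved from the input at call time, so a single callable serves both NumPy and JAX arrays. A quick sketch of that dispatch, assuming `jax` and a JAX-aware `array-api-compat` are installed:

```python
import numpy as np

from deepmd.dpmodel.utils.network import get_activation_fn
from deepmd.jax.env import jnp

relu = get_activation_fn("relu")
print(relu(np.asarray([-1.0, 2.0])))   # NumPy in, NumPy out: [0. 2.]
print(relu(jnp.asarray([-1.0, 2.0])))  # JAX in, JAX out:     [0. 2.]
```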
14 changes: 10 additions & 4 deletions deepmd/dpmodel/utils/type_embed.py
@@ -5,8 +5,12 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
support_array_api,
)
from deepmd.dpmodel.common import (
PRECISION_DICT,
NativeOP,
@@ -92,16 +96,18 @@ def __init__(
bias=self.use_tebd_bias,
)

@support_array_api(version="2022.12")
def call(self) -> np.ndarray:
"""Compute the type embedding network."""
sample_array = self.embedding_net[0]["w"]
xp = array_api_compat.array_namespace(sample_array)
if not self.use_econf_tebd:
-            embed = self.embedding_net(
-                np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
-            )
+            embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype))
else:
embed = self.embedding_net(self.econf_tebd)
if self.padding:
-            embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
+            embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype)
+            embed = xp.concatenate([embed, embed_pad], axis=0)
return embed

@classmethod
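
`np.pad` is not part of the 2022.12 array API standard, hence the zeros-plus-concatenate rewrite above. The two formulations agree, as this NumPy-only check sketches:

```python
import numpy as np

embed = np.arange(6.0).reshape(3, 2)
ref = np.pad(embed, ((0, 1), (0, 0)), mode="constant")       # old formulation
pad_row = np.zeros((1, embed.shape[-1]), dtype=embed.dtype)  # new formulation
assert np.array_equal(ref, np.concatenate([embed, pad_row], axis=0))
```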
2 changes: 2 additions & 0 deletions deepmd/jax/__init__.py
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""JAX backend."""
37 changes: 37 additions & 0 deletions deepmd/jax/common.py
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Optional,
overload,
)

import numpy as np

from deepmd.jax.env import (
jnp,
)


@overload
def to_jax_array(array: np.ndarray) -> jnp.ndarray: ...


@overload
def to_jax_array(array: None) -> None: ...


def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]:
    """Convert a NumPy array to a JAX array.

    Parameters
    ----------
    array : np.ndarray
        The NumPy array to convert.

    Returns
    -------
    jnp.ndarray
        The JAX array.
    """
if array is None:
return None
return jnp.array(array)
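
Paired with `to_numpy_array` from `deepmd.dpmodel.common`, this gives a lossless round trip between backends (lossless because `deepmd.jax.env` enables 64-bit mode at import). A short sketch:

```python
import numpy as np

from deepmd.dpmodel.common import to_numpy_array
from deepmd.jax.common import to_jax_array

w = np.random.default_rng(0).normal(size=(4, 4))
w_back = to_numpy_array(to_jax_array(w))  # NumPy -> JAX -> NumPy
assert np.allclose(w, w_back)
assert to_jax_array(None) is None         # per the overloads above
```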
14 changes: 14 additions & 0 deletions deepmd/jax/env.py
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os

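# Keep XLA from preallocating most of the GPU memory; this must be set
# before `jax` is first imported.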
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

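# Enable double precision (float64) by default, as the dpmodel reference uses 64-bit floats.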
jax.config.update("jax_enable_x64", True)

__all__ = [
"jax",
"jnp",
]
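
Both settings take effect at import time, so importing anything from `deepmd.jax` yields a float64-by-default JAX. A one-line sanity check:

```python
from deepmd.jax.env import jnp

assert jnp.asarray(1.0).dtype == jnp.float64  # x64 mode is active
```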
1 change: 1 addition & 0 deletions deepmd/jax/utils/__init__.py
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
(Diffs for the remaining changed files were not loaded in this view.)
