Skip to content

Commit

Permalink
feat(jax): SavedModel C++ interface (including DPA-2 supports) (#4307)
Browse files Browse the repository at this point in the history
Including nlist and no nlist interface.

The limitation: A SavedModel created on a device cannot be run on
another. For example, a CUDA model cannot be run on the CPU.

The model is generated using #4336.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Added support for the JAX backend, including specific model and
checkpoint file formats.
- Introduced a new shell script for model conversion to enhance
usability.
- Updated installation documentation to clarify JAX support and
requirements.
- New section in documentation detailing limitations of the JAX backend
with LAMMPS.

- **Bug Fixes**
- Enhanced error handling for model initialization and backend
compatibility.

- **Documentation**
- Updated backend documentation to include JAX details and limitations.
- Improved clarity in installation instructions for both TensorFlow and
JAX.

- **Tests**
- Added comprehensive unit tests for JAX integration with the Deep
Potential class.
  - Expanded test coverage for LAMMPS integration with DeepMD.

- **Chores**
- Updated CMake configurations and workflow files for improved testing
and dependency management.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Your Name <[email protected]>
  • Loading branch information
3 people authored Nov 13, 2024
1 parent 85e5e20 commit 698b08d
Show file tree
Hide file tree
Showing 24 changed files with 12,703 additions and 19 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test_cc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ jobs:
mpi: mpich
- uses: lukka/get-cmake@latest
- run: python -m pip install uv
- run: source/install/uv_with_retry.sh pip install --system tensorflow
- name: Install Python dependencies
run: |
source/install/uv_with_retry.sh pip install --system tensorflow-cpu
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py
- name: Convert models
run: source/tests/infer/convert-models.sh
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip
Expand All @@ -47,12 +53,6 @@ jobs:
CMAKE_GENERATOR: Ninja
CXXFLAGS: ${{ matrix.check_memleak && '-fsanitize=leak' || '' }}
# test lammps
- run: |
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp] mpi4py
env:
DP_BUILD_TESTING: 1
if: ${{ !matrix.check_memleak }}
- run: pytest --cov=deepmd source/lmp/tests
env:
OMP_NUM_THREADS: 1
Expand Down
7 changes: 5 additions & 2 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
runs-on: nvidia
# https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04
options: --gpus all
if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group')
steps:
Expand Down Expand Up @@ -63,12 +63,15 @@ jobs:
CUDA_VISIBLE_DEVICES: 0
# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
XLA_PYTHON_CLIENT_PREALLOCATE: false
- name: Convert models
run: source/tests/infer/convert-models.sh
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
unzip libtorch.zip
- run: |
export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch
export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
source/install/test_cc_local.sh
env:
OMP_NUM_THREADS: 1
Expand All @@ -79,7 +82,7 @@ jobs:
DP_VARIANT: cuda
DP_USE_MPICH2: 1
- run: |
export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$LD_LIBRARY_PATH
export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH
python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1)
python -m pytest source/ipi/tests
Expand Down
4 changes: 3 additions & 1 deletion doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
`.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow.
Currently, this backend is developed actively, and has no support for training and the C++ interface.
Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface.
The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
Currently, this backend is developed actively, and has no support for training.

### DP {{ dpmodel_icon }}

Expand Down
10 changes: 6 additions & 4 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library.

Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually.

Expand Down Expand Up @@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake

Expand Down Expand Up @@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ tensorflow_icon }} Whether building the TensorFlow backend.
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.

:::

Expand All @@ -391,7 +393,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `PATH`

{{ tensorflow_icon }} The Path to TensorFlow's C++ interface.
{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.

:::

Expand Down
10 changes: 10 additions & 0 deletions doc/model/dpa2.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ If one runs LAMMPS with MPI, the customized OP library for the C++ interface sho
If one runs LAMMPS with MPI and CUDA devices, it is recommended to compile the customized OP library for the C++ interface with a [CUDA-Aware MPI](https://developer.nvidia.com/mpi-solutions-gpus) library and CUDA,
otherwise the communication between GPU cards falls back to the slower CPU implementation.

## Limiations of the JAX backend with LAMMPS {{ jax_icon }}

When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.

```lammps
atom_modify map yes
```

See the example `examples/water/lmp/jax_dpa2.lammps`.

## Data format

DPA-2 supports both the [standard data format](../data/system.md) and the [mixed type data format](../data/system.md#mixed-type).
Binary file added examples/water/dpa2/frozen_model.pth
Binary file not shown.
31 changes: 31 additions & 0 deletions examples/water/lmp/jax_dpa2.lammps
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@


# bulk water

units metal
boundary p p p
atom_style atomic
# Below line is required when using DPA-2 with the JAX backend
atom_modify map yes

neighbor 2.0 bin
neigh_modify every 10 delay 0 check no

read_data water.lmp
mass 1 16
mass 2 2

# See https://deepmd.rtfd.io/lammps/ for usage
pair_style deepmd frozen_model.savedmodel
# If atom names (O H in this example) are not set in the pair_coeff command, the type_map defined by the training parameter will be used by default.
pair_coeff * * O H

velocity all create 330.0 23456789

fix 1 all nvt temp 330.0 330.0 0.5
timestep 0.0005
thermo_style custom step pe ke etotal temp press vol
thermo 100
dump 1 all custom 100 water.dump id type x y z

run 1000
16 changes: 13 additions & 3 deletions source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extern "C" {
/** C API version. Bumped whenever the API is changed.
* @since API version 22
*/
#define DP_C_API_VERSION 24
#define DP_C_API_VERSION 25

/**
* @brief Neighbor list.
Expand All @@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
/*
/**
* @brief Create a new neighbor list with communication capabilities.
* @details This function extends DP_NewNlist by adding support for parallel
* communication, allowing the neighbor list to be used in distributed
Expand Down Expand Up @@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* recvproc,
void* world);

/*
/**
* @brief Set mask for a neighbor list.
*
* @param nl Neighbor list.
Expand All @@ -78,6 +78,16 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
**/
extern void DP_NlistSetMask(DP_Nlist* nl, int mask);

/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 25
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);

/**
* @brief Delete a neighbor list.
*
Expand Down
5 changes: 5 additions & 0 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,11 @@ struct InputNlist {
* @brief Set mask for this neighbor list.
*/
void set_mask(int mask) { DP_NlistSetMask(nl, mask); };
/**
* @brief Set mapping for this neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
*/
void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
};

/**
Expand Down
3 changes: 3 additions & 0 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
return new_nl;
}
void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); }
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

// DP Base Model
Expand Down
Loading

0 comments on commit 698b08d

Please sign in to comment.