diff --git a/.github/workflows/test_cc.yml b/.github/workflows/test_cc.yml index 768590980f..b7fd9c6bb4 100644 --- a/.github/workflows/test_cc.yml +++ b/.github/workflows/test_cc.yml @@ -27,7 +27,13 @@ jobs: mpi: mpich - uses: lukka/get-cmake@latest - run: python -m pip install uv - - run: source/install/uv_with_retry.sh pip install --system tensorflow + - name: Install Python dependencies + run: | + source/install/uv_with_retry.sh pip install --system tensorflow-cpu + export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') + source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py + - name: Convert models + run: source/tests/infer/convert-models.sh - name: Download libtorch run: | wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip @@ -47,12 +53,6 @@ jobs: CMAKE_GENERATOR: Ninja CXXFLAGS: ${{ matrix.check_memleak && '-fsanitize=leak' || '' }} # test lammps - - run: | - export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') - source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp] mpi4py - env: - DP_BUILD_TESTING: 1 - if: ${{ !matrix.check_memleak }} - run: pytest --cov=deepmd source/lmp/tests env: OMP_NUM_THREADS: 1 diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 996a1bcff0..4dbdc5acb9 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -19,7 +19,7 @@ jobs: runs-on: nvidia # https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845 container: - image: nvidia/cuda:12.3.1-devel-ubuntu22.04 + image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04 options: --gpus all if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group') steps: @@ -63,12 +63,15 @@ jobs: CUDA_VISIBLE_DEVICES: 0 # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html XLA_PYTHON_CLIENT_PREALLOCATE: false + - name: Convert models + run: source/tests/infer/convert-models.sh - name: Download libtorch run: | wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip unzip libtorch.zip - run: | export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch + export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH source/install/test_cc_local.sh env: OMP_NUM_THREADS: 1 @@ -79,7 +82,7 @@ jobs: DP_VARIANT: cuda DP_USE_MPICH2: 1 - run: | - export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$LD_LIBRARY_PATH export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1) python -m pytest source/ipi/tests diff --git a/doc/backend.md b/doc/backend.md index dd20193d58..5943165dd5 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -31,7 +31,9 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different [JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required. Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions. `.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow. -Currently, this backend is developed actively, and has no support for training and the C++ interface. +Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface. +The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs. +Currently, this backend is developed actively, and has no support for training. ### DP {{ dpmodel_icon }} diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 4a0a104b7e..0bf6fa5ee3 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte ::::{tab-set} -:::{tab-item} TensorFlow {{ tensorflow_icon }} +:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }} + +The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library. Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually. @@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html ::::{tab-set} -:::{tab-item} TensorFlow {{ tensorflow_icon }} +:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }} I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake @@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D ==nl.set_mask(mask); } +void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) { + nl->nl.set_mapping(mapping); +} void DP_DeleteNlist(DP_Nlist* nl) { delete nl; } // DP Base Model diff --git a/source/api_cc/include/DeepPotJAX.h b/source/api_cc/include/DeepPotJAX.h new file mode 100644 index 0000000000..76533fcc35 --- /dev/null +++ b/source/api_cc/include/DeepPotJAX.h @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#pragma once + +#include +#include + +#include "DeepPot.h" +#include "common.h" +#include "neighbor_list.h" + +namespace deepmd { +/** + * @brief TensorFlow implementation for Deep Potential. + **/ +class DeepPotJAX : public DeepPotBackend { + public: + /** + * @brief DP constructor without initialization. + **/ + DeepPotJAX(); + virtual ~DeepPotJAX(); + /** + * @brief DP constructor with initialization. + * @param[in] model The name of the frozen model file. + * @param[in] gpu_rank The GPU rank. Default is 0. If < 0, use CPU. + * @param[in] file_content The content of the model file. If it is not empty, + *DP will read from the string instead of the file. + **/ + DeepPotJAX(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + /** + * @brief Initialize the DP. + * @param[in] model The name of the frozen model file. + * @param[in] gpu_rank The GPU rank. Default is 0. If < 0, use CPU. + * @param[in] file_content The content of the model file. If it is not empty, + *DP will read from the string instead of the file. + **/ + void init(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + /** + * @brief Get the cutoff radius. + * @return The cutoff radius. + **/ + double cutoff() const { + assert(inited); + return rcut; + }; + /** + * @brief Get the number of types. + * @return The number of types. + **/ + int numb_types() const { + assert(inited); + return ntypes; + }; + /** + * @brief Get the number of types with spin. + * @return The number of types with spin. + **/ + int numb_types_spin() const { + assert(inited); + return 0; + }; + /** + * @brief Get the dimension of the frame parameter. + * @return The dimension of the frame parameter. + **/ + int dim_fparam() const { + assert(inited); + return dfparam; + }; + /** + * @brief Get the dimension of the atomic parameter. + * @return The dimension of the atomic parameter. + **/ + int dim_aparam() const { + assert(inited); + return daparam; + }; + /** + * @brief Get the type map (element name of the atom types) of this model. + * @param[out] type_map The type map of this model. + **/ + void get_type_map(std::string& type_map); + + /** + * @brief Get whether the atom dimension of aparam is nall instead of fparam. + * @param[out] aparam_nall whether the atom dimension of aparam is nall + *instead of fparam. + **/ + bool is_aparam_nall() const { + assert(inited); + return false; + }; + + // forward to template class + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + + private: + bool inited; + // device + std::string device; + // the cutoff radius + double rcut; + // the number of types + int ntypes; + // the dimension of the frame parameter + int dfparam; + // the dimension of the atomic parameter + int daparam; + // type map + std::string type_map; + // sel + std::vector sel; + // number of neighbors + int nnei; + // padding to nall + int padding_to_nall = 0; + // padding for nloc + int padding_for_nloc = 0; + /** TF C API objects. + * @{ + */ + TF_Graph* graph; + TF_Status* status; + TF_Session* session; + TF_SessionOptions* sessionopts; + TFE_ContextOptions* ctx_opts; + TFE_Context* ctx; + std::vector func_vector; + /** + * @} + */ + // neighbor list data + NeighborListData nlist_data; + /** + * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[out] atom_energy The atomic energy. + * @param[out] atom_virial The atomic virial. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The list should contain natoms ints. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + * @param[in] atomic Whether to compute the atomic energy and virial. + **/ + template + void compute(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + + /** + * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial + *by using this DP. + * @param[out] ener The system energy. + * @param[out] force The force on each atom. + * @param[out] virial The virial. + * @param[out] atom_energy The atomic energy. + * @param[out] atom_virial The atomic virial. + * @param[in] coord The coordinates of atoms. The array should be of size + *nframes x natoms x 3. + * @param[in] atype The atom types. The list should contain natoms ints. + * @param[in] box The cell of the region. The array should be of size nframes + *x 9. + * @param[in] nghost The number of ghost atoms. + * @param[in] lmp_list The input neighbour list. + * @param[in] ago Update the internal neighbour list if ago is 0. + * @param[in] fparam The frame parameter. The array can be of size : + * nframes x dim_fparam. + * dim_fparam. Then all frames are assumed to be provided with the same + *fparam. + * @param[in] aparam The atomic parameter The array can be of size : + * nframes x natoms x dim_aparam. + * natoms x dim_aparam. Then all frames are assumed to be provided with the + *same aparam. + * @param[in] atomic Whether to compute atomic energy and virial. + **/ + template + void compute(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); +}; +} // namespace deepmd diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 9b1adcbd62..def3df933b 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -13,7 +13,7 @@ namespace deepmd { typedef double ENERGYTYPE; -enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown }; +enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown }; struct NeighborListData { /// Array stores the core region atom's index diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index b47c8a9ba1..6f8724f78e 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -7,6 +7,7 @@ #include "AtomMap.h" #include "common.h" #ifdef BUILD_TENSORFLOW +#include "DeepPotJAX.h" #include "DeepPotTF.h" #endif #ifdef BUILD_PYTORCH @@ -41,6 +42,9 @@ void DeepPot::init(const std::string& model, backend = deepmd::DPBackend::PyTorch; } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { backend = deepmd::DPBackend::TensorFlow; + } else if (model.length() >= 11 && + model.substr(model.length() - 11) == ".savedmodel") { + backend = deepmd::DPBackend::JAX; } else { throw deepmd::deepmd_exception("Unsupported model file format"); } @@ -58,6 +62,14 @@ void DeepPot::init(const std::string& model, #endif } else if (deepmd::DPBackend::Paddle == backend) { throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet"); + } else if (deepmd::DPBackend::JAX == backend) { +#ifdef BUILD_TENSORFLOW + dp = std::make_shared(model, gpu_rank, file_content); +#else + throw deepmd::deepmd_exception( + "TensorFlow backend is not built, which is used to load JAX2TF " + "SavedModels"); +#endif } else { throw deepmd::deepmd_exception("Unknown file type"); } diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc new file mode 100644 index 0000000000..be1a5542b4 --- /dev/null +++ b/source/api_cc/src/DeepPotJAX.cc @@ -0,0 +1,783 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#ifdef BUILD_TENSORFLOW + +#include "DeepPotJAX.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "device.h" +#include "errors.h" + +#define PADDING_FACTOR 1.05 + +inline void check_status(TF_Status* status) { + if (TF_GetCode(status) != TF_OK) { + throw deepmd::deepmd_exception("TensorFlow C API Error: " + + std::string(TF_Message(status))); + } +} + +inline void find_function(TF_Function*& found_func, + const std::vector& funcs, + const std::string func_name) { + for (size_t i = 0; i < funcs.size(); i++) { + TF_Function* func = funcs[i]; + const char* name = TF_FunctionName(func); + std::string name_(name); + // remove trailing integer e.g. _123 + std::string::size_type pos = name_.find_last_not_of("0123456789_"); + if (pos != std::string::npos) { + name_ = name_.substr(0, pos + 1); + } + if (name_ == "__inference_" + func_name) { + found_func = func; + return; + } + } + found_func = NULL; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_DOUBLE; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_FLOAT; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_INT32; +} + +inline TF_DataType get_data_tensor_type(const std::vector& data) { + return TF_INT64; +} + +inline TFE_Op* get_func_op(TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TF_Function* func = NULL; + find_function(func, funcs, func_name); + if (func == NULL) { + throw std::runtime_error("Function " + func_name + " not found"); + } + const char* real_func_name = TF_FunctionName(func); + // execute the function + TFE_Op* op = TFE_NewOp(ctx, real_func_name, status); + check_status(status); + TFE_OpSetDevice(op, device.c_str(), status); + check_status(status); + return op; +} + +template +inline T get_scalar(TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status); + check_status(status); + TFE_TensorHandle* retvals[1]; + int nretvals = 1; + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + TFE_TensorHandle* retval = retvals[0]; + TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status); + check_status(status); + T* data = (T*)TF_TensorData(tensor); + // copy data + T result = *data; + // deallocate + TFE_DeleteOp(op); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retval); + return result; +} + +template +inline std::vector get_vector(TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status); + check_status(status); + TFE_TensorHandle* retvals[1]; + int nretvals = 1; + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + TFE_TensorHandle* retval = retvals[0]; + // copy data + std::vector result; + tensor_to_vector(result, retval, status); + // deallocate + TFE_DeleteTensorHandle(retval); + TFE_DeleteOp(op); + return result; +} + +inline std::vector get_vector_string( + TFE_Context* ctx, + const std::string func_name, + const std::vector& funcs, + const std::string device, + TF_Status* status) { + TFE_Op* op = get_func_op(ctx, func_name, funcs, device, status); + check_status(status); + TFE_TensorHandle* retvals[1]; + int nretvals = 1; + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + TFE_TensorHandle* retval = retvals[0]; + TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status); + check_status(status); + // calculate the number of bytes in each string + const void* data = TF_TensorData(tensor); + int64_t bytes_each_string = + TF_TensorByteSize(tensor) / TF_TensorElementCount(tensor); + // copy data + std::vector result; + for (int ii = 0; ii < TF_TensorElementCount(tensor); ++ii) { + const TF_TString* datastr = + static_cast(static_cast( + static_cast(data) + ii * bytes_each_string)); + const char* dst = TF_TString_GetDataPointer(datastr); + size_t dst_len = TF_TString_GetSize(datastr); + result.push_back(std::string(dst, dst_len)); + } + + // deallocate + TFE_DeleteOp(op); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retval); + return result; +} + +template +inline TF_Tensor* create_tensor(const std::vector& data, + const std::vector& shape) { + TF_Tensor* tensor = + TF_AllocateTensor(get_data_tensor_type(data), shape.data(), shape.size(), + data.size() * sizeof(T)); + memcpy(TF_TensorData(tensor), data.data(), TF_TensorByteSize(tensor)); + return tensor; +} + +template +inline TFE_TensorHandle* add_input(TFE_Op* op, + const std::vector& data, + const std::vector& data_shape, + TF_Tensor*& data_tensor, + TF_Status* status) { + data_tensor = create_tensor(data, data_shape); + TFE_TensorHandle* handle = TFE_NewTensorHandle(data_tensor, status); + check_status(status); + + TFE_OpAddInput(op, handle, status); + check_status(status); + return handle; +} + +template +inline void tensor_to_vector(std::vector& result, + TFE_TensorHandle* retval, + TF_Status* status) { + TF_Tensor* tensor = TFE_TensorHandleResolve(retval, status); + check_status(status); + T* data = (T*)TF_TensorData(tensor); + // copy data + result.resize(TF_TensorElementCount(tensor)); + for (int i = 0; i < TF_TensorElementCount(tensor); i++) { + result[i] = data[i]; + } + // Delete the tensor to free memory + TF_DeleteTensor(tensor); +} + +deepmd::DeepPotJAX::DeepPotJAX() : inited(false) {} +deepmd::DeepPotJAX::DeepPotJAX(const std::string& model, + const int& gpu_rank, + const std::string& file_content) + : inited(false) { + init(model, gpu_rank, file_content); +} +void deepmd::DeepPotJAX::init(const std::string& model, + const int& gpu_rank, + const std::string& file_content) { + if (inited) { + std::cerr << "WARNING: deepmd-kit should not be initialized twice, do " + "nothing at the second call of initializer" + << std::endl; + return; + } + + const char* saved_model_dir = model.c_str(); + graph = TF_NewGraph(); + status = TF_NewStatus(); + + sessionopts = TF_NewSessionOptions(); + int num_intra_nthreads, num_inter_nthreads; + get_env_nthreads(num_intra_nthreads, num_inter_nthreads); + // https://github.com/Neargye/hello_tf_c_api/blob/51516101cf59408a6bb456f7e5f3c6628e327b3a/src/tf_utils.cpp#L400-L401 + // https://github.com/Neargye/hello_tf_c_api/blob/51516101cf59408a6bb456f7e5f3c6628e327b3a/src/tf_utils.cpp#L364-L379 + // The following is an equivalent of setting this in Python: + // config = tf.ConfigProto( allow_soft_placement = True ) + // config.gpu_options.allow_growth = True + // config.gpu_options.per_process_gpu_memory_fraction = percentage + // Create a byte-array for the serialized ProtoConfig, set the mandatory bytes + // (first three and last four) + std::array config = { + {0x10, static_cast(num_intra_nthreads), 0x28, + static_cast(num_inter_nthreads), 0x32, 0xb, 0x9, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x20, 0x1, 0x38, 0x1}}; + + // Convert the desired percentage into a byte-array. + double gpu_memory_fraction = 0.9; + auto bytes = reinterpret_cast(&gpu_memory_fraction); + + // Put it to the config byte-array, from 7 to 14: + for (std::size_t i = 0; i < sizeof(gpu_memory_fraction); ++i) { + config[i + 7] = bytes[i]; + } + + TF_SetConfig(sessionopts, config.data(), config.size(), status); + check_status(status); + + TF_Buffer* runopts = NULL; + + const char* tags = "serve"; + int ntags = 1; + + session = TF_LoadSessionFromSavedModel(sessionopts, runopts, saved_model_dir, + &tags, ntags, graph, NULL, status); + check_status(status); + + int nfuncs = TF_GraphNumFunctions(graph); + // allocate memory for the TF_Function* array + func_vector.resize(nfuncs); + TF_Function** funcs = func_vector.data(); + TF_GraphGetFunctions(graph, funcs, nfuncs, status); + check_status(status); + + ctx_opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetConfig(ctx_opts, config.data(), config.size(), status); + check_status(status); + ctx = TFE_NewContext(ctx_opts, status); + check_status(status); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + int gpu_num; + DPGetDeviceCount(gpu_num); // check current device environment + if (gpu_num > 0 && gpu_rank >= 0) { + DPErrcheck(DPSetDevice(gpu_rank % gpu_num)); + device = "/gpu:" + std::to_string(gpu_rank % gpu_num); + } else { + device = "/cpu:0"; + } +#else + device = "/cpu:0"; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + + // add all functions, otherwise the function will not be found + // even for tf.cond + for (size_t i = 0; i < func_vector.size(); i++) { + TF_Function* func = func_vector[i]; + TFE_ContextAddFunction(ctx, func, status); + check_status(status); + } + + rcut = get_scalar(ctx, "get_rcut", func_vector, device, status); + dfparam = + get_scalar(ctx, "get_dim_fparam", func_vector, device, status); + daparam = + get_scalar(ctx, "get_dim_aparam", func_vector, device, status); + std::vector type_map_ = + get_vector_string(ctx, "get_type_map", func_vector, device, status); + // deepmd-kit stores type_map as a concatenated string, split by ' ' + type_map = type_map_[0]; + for (size_t i = 1; i < type_map_.size(); i++) { + type_map += " " + type_map_[i]; + } + ntypes = type_map_.size(); + sel = get_vector(ctx, "get_sel", func_vector, device, status); + nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0)); + inited = true; +} + +deepmd::DeepPotJAX::~DeepPotJAX() { + if (inited) { + TF_DeleteSession(session, status); + TF_DeleteGraph(graph); + TF_DeleteSessionOptions(sessionopts); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); + TFE_DeleteContextOptions(ctx_opts); + for (size_t i = 0; i < func_vector.size(); i++) { + TF_DeleteFunction(func_vector[i]); + } + } +} + +template +void deepmd::DeepPotJAX::compute(std::vector& ener, + std::vector& force_, + std::vector& virial, + std::vector& atom_energy_, + std::vector& atom_virial_, + const std::vector& dcoord, + const std::vector& datype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic) { + std::vector coord, force, aparam, atom_energy, atom_virial; + std::vector ener_double, force_double, virial_double, + atom_energy_double, atom_virial_double; + std::vector atype, fwd_map, bkw_map; + int nghost_real, nall_real, nloc_real; + int nall = datype.size(); + // nlist passed to the model + int nframes = nall > 0 ? (dcoord.size() / 3 / nall) : 1; + int nghost = 0; + + select_real_atoms_coord(coord, atype, aparam, nghost_real, fwd_map, bkw_map, + nall_real, nloc_real, dcoord, datype, aparam_, nghost, + ntypes, nframes, daparam, nall, false); + + if (nloc_real == 0) { + // no real atoms, fill 0 for all outputs + // this can prevent a Xla error + ener.resize(nframes, 0.0); + force_.resize(static_cast(nframes) * nall * 3, 0.0); + virial.resize(static_cast(nframes) * 9, 0.0); + atom_energy_.resize(static_cast(nframes) * nall, 0.0); + atom_virial_.resize(static_cast(nframes) * nall * 9, 0.0); + return; + } + + // cast coord, fparam, and aparam to double - I think it's useless to have a + // float model interface + std::vector coord_double(coord.begin(), coord.end()); + std::vector box_double(box.begin(), box.end()); + std::vector fparam_double(fparam.begin(), fparam.end()); + std::vector aparam_double(aparam.begin(), aparam.end()); + + TFE_Op* op; + if (atomic) { + op = get_func_op(ctx, "call_with_atomic_virial", func_vector, device, + status); + } else { + op = get_func_op(ctx, "call_without_atomic_virial", func_vector, device, + status); + } + std::vector input_list(5); + std::vector data_tensor(5); + // coord + std::vector coord_shape = {nframes, nloc_real, 3}; + input_list[0] = + add_input(op, coord_double, coord_shape, data_tensor[0], status); + // atype + std::vector atype_shape = {nframes, nloc_real}; + input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status); + // box + int box_size = box_double.size() > 0 ? 3 : 0; + std::vector box_shape = {nframes, box_size, box_size}; + input_list[2] = add_input(op, box_double, box_shape, data_tensor[2], status); + // fparam + std::vector fparam_shape = {nframes, dfparam}; + input_list[3] = + add_input(op, fparam_double, fparam_shape, data_tensor[3], status); + // aparam + std::vector aparam_shape = {nframes, nloc_real, daparam}; + input_list[4] = + add_input(op, aparam_double, aparam_shape, data_tensor[4], status); + // execute the function + int nretvals = 6; + TFE_TensorHandle* retvals[nretvals]; + + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + + // copy data + // for atom virial, the order is: + // Identity_15 energy -1, -1, 1 + // Identity_16 energy_derv_c -1, -1, 1, 9 (may pop) + // Identity_17 energy_derv_c_redu -1, 1, 9 + // Identity_18 energy_derv_r -1, -1, 1, 3 + // Identity_19 energy_redu -1, 1 + // Identity_20 mask (int32) -1, -1 + // + // for no atom virial, the order is: + // Identity_15 energy -1, -1, 1 + // Identity_16 energy_derv_c -1, 1, 9 + // Identity_17 energy_derv_r -1, -1, 1, 3 + // Identity_18 energy_redu -1, 1 + // Identity_19 mask (int32) -1, -1 + // + // it seems the order is the alphabet order? + // not sure whether it is safe to assume the order + if (atomic) { + tensor_to_vector(ener_double, retvals[4], status); + tensor_to_vector(force_double, retvals[3], status); + tensor_to_vector(virial_double, retvals[2], status); + tensor_to_vector(atom_energy_double, retvals[0], status); + tensor_to_vector(atom_virial_double, retvals[1], status); + } else { + tensor_to_vector(ener_double, retvals[3], status); + tensor_to_vector(force_double, retvals[2], status); + tensor_to_vector(virial_double, retvals[1], status); + tensor_to_vector(atom_energy_double, retvals[0], status); + } + + // cast back to VALUETYPE + ener = std::vector(ener_double.begin(), ener_double.end()); + force = std::vector(force_double.begin(), force_double.end()); + virial = std::vector(virial_double.begin(), virial_double.end()); + atom_energy = std::vector(atom_energy_double.begin(), + atom_energy_double.end()); + atom_virial = std::vector(atom_virial_double.begin(), + atom_virial_double.end()); + force.resize(static_cast(nframes) * nall_real * 3); + atom_virial.resize(static_cast(nframes) * nall_real * 9); + + // nall atom_energy is required in the C++ API; + // we always forget it! + atom_energy.resize(static_cast(nframes) * nall_real, 0.0); + + force_.resize(static_cast(nframes) * fwd_map.size() * 3); + atom_energy_.resize(static_cast(nframes) * fwd_map.size()); + atom_virial_.resize(static_cast(nframes) * fwd_map.size() * 9); + select_map(force_, force, bkw_map, 3, nframes, fwd_map.size(), + nall_real); + select_map(atom_energy_, atom_energy, bkw_map, 1, nframes, + fwd_map.size(), nall_real); + select_map(atom_virial_, atom_virial, bkw_map, 9, nframes, + fwd_map.size(), nall_real); + + // cleanup input_list, etc + for (size_t i = 0; i < 5; i++) { + TFE_DeleteTensorHandle(input_list[i]); + TF_DeleteTensor(data_tensor[i]); + } + for (size_t i = 0; i < nretvals; i++) { + TFE_DeleteTensorHandle(retvals[i]); + } + TFE_DeleteOp(op); +} + +template +void deepmd::DeepPotJAX::compute(std::vector& ener, + std::vector& force_, + std::vector& virial, + std::vector& atom_energy_, + std::vector& atom_virial_, + const std::vector& dcoord, + const std::vector& datype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic) { + std::vector coord, force, aparam, atom_energy, atom_virial; + std::vector ener_double, force_double, virial_double, + atom_energy_double, atom_virial_double; + std::vector atype, fwd_map, bkw_map; + int nghost_real, nall_real, nloc_real; + int nall = datype.size(); + // nlist passed to the model + int nframes = 1; + + select_real_atoms_coord(coord, atype, aparam, nghost_real, fwd_map, bkw_map, + nall_real, nloc_real, dcoord, datype, aparam_, nghost, + ntypes, nframes, daparam, nall, false); + + if (nloc_real == 0) { + // no real atoms, fill 0 for all outputs + // this can prevent a Xla error + ener.resize(nframes, 0.0); + force_.resize(static_cast(nframes) * nall * 3, 0.0); + virial.resize(static_cast(nframes) * 9, 0.0); + atom_energy_.resize(static_cast(nframes) * nall, 0.0); + atom_virial_.resize(static_cast(nframes) * nall * 9, 0.0); + return; + } + + // cast coord, fparam, and aparam to double - I think it's useless to have a + // float model interface + std::vector coord_double(coord.begin(), coord.end()); + std::vector fparam_double(fparam.begin(), fparam.end()); + std::vector aparam_double(aparam.begin(), aparam.end()); + + if (padding_for_nloc != nloc_real) { + padding_to_nall = nall_real * PADDING_FACTOR; + padding_for_nloc = nloc_real; + } + while (padding_to_nall < nall_real) { + padding_to_nall *= PADDING_FACTOR; + } + // do padding + coord_double.resize(nframes * padding_to_nall * 3, 0.0); + atype.resize(nframes * padding_to_nall, -1); + + TFE_Op* op; + if (atomic) { + op = get_func_op(ctx, "call_lower_with_atomic_virial", func_vector, device, + status); + } else { + op = get_func_op(ctx, "call_lower_without_atomic_virial", func_vector, + device, status); + } + std::vector input_list(6); + std::vector data_tensor(6); + // coord + std::vector coord_shape = {nframes, padding_to_nall, 3}; + input_list[0] = + add_input(op, coord_double, coord_shape, data_tensor[0], status); + // atype + std::vector atype_shape = {nframes, padding_to_nall}; + input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status); + // nlist + if (ago == 0) { + nlist_data.copy_from_nlist(lmp_list); + nlist_data.shuffle_exclude_empty(fwd_map); + } + size_t max_size = 0; + for (const auto& row : nlist_data.jlist) { + max_size = std::max(max_size, row.size()); + } + std::vector nlist_shape = {nframes, nloc_real, + static_cast(max_size)}; + std::vector nlist(static_cast(nframes) * nloc_real * + max_size); + // pass nlist_data.jlist to nlist + for (int ii = 0; ii < nloc_real; ii++) { + for (int jj = 0; jj < max_size; jj++) { + if (jj < nlist_data.jlist[ii].size()) { + nlist[ii * max_size + jj] = nlist_data.jlist[ii][jj]; + } else { + nlist[ii * max_size + jj] = -1; + } + } + } + input_list[2] = add_input(op, nlist, nlist_shape, data_tensor[2], status); + // mapping; for now, set it to -1, assume it is not used + std::vector mapping_shape = {nframes, padding_to_nall}; + std::vector mapping(nframes * padding_to_nall, -1); + // pass mapping if it is given in the neighbor list + if (lmp_list.mapping) { + // assume nframes is 1 + for (size_t ii = 0; ii < nall_real; ii++) { + mapping[ii] = lmp_list.mapping[fwd_map[ii]]; + } + } + input_list[3] = add_input(op, mapping, mapping_shape, data_tensor[3], status); + // fparam + std::vector fparam_shape = {nframes, dfparam}; + input_list[4] = + add_input(op, fparam_double, fparam_shape, data_tensor[4], status); + // aparam + std::vector aparam_shape = {nframes, nloc_real, daparam}; + input_list[5] = + add_input(op, aparam_double, aparam_shape, data_tensor[5], status); + // execute the function + int nretvals = 6; + TFE_TensorHandle* retvals[nretvals]; + + TFE_Execute(op, retvals, &nretvals, status); + check_status(status); + + // copy data + // the order is: + // energy + // energy_derv_c + // energy_derv_c_redu + // energy_derv_r + // energy_redu + // mask + // it seems the order is the alphabet order? + // not sure whether it is safe to assume the order + tensor_to_vector(ener_double, retvals[4], status); + tensor_to_vector(force_double, retvals[3], status); + tensor_to_vector(virial_double, retvals[2], status); + tensor_to_vector(atom_energy_double, retvals[0], status); + tensor_to_vector(atom_virial_double, retvals[1], status); + + // cast back to VALUETYPE + ener = std::vector(ener_double.begin(), ener_double.end()); + force = std::vector(force_double.begin(), force_double.end()); + virial = std::vector(virial_double.begin(), virial_double.end()); + atom_energy = std::vector(atom_energy_double.begin(), + atom_energy_double.end()); + atom_virial = std::vector(atom_virial_double.begin(), + atom_virial_double.end()); + force.resize(static_cast(nframes) * nall_real * 3); + atom_virial.resize(static_cast(nframes) * nall_real * 9); + + // nall atom_energy is required in the C++ API; + // we always forget it! + atom_energy.resize(static_cast(nframes) * nall_real, 0.0); + + force_.resize(static_cast(nframes) * fwd_map.size() * 3); + atom_energy_.resize(static_cast(nframes) * fwd_map.size()); + atom_virial_.resize(static_cast(nframes) * fwd_map.size() * 9); + select_map(force_, force, bkw_map, 3, nframes, fwd_map.size(), + nall_real); + select_map(atom_energy_, atom_energy, bkw_map, 1, nframes, + fwd_map.size(), nall_real); + select_map(atom_virial_, atom_virial, bkw_map, 9, nframes, + fwd_map.size(), nall_real); + + // cleanup input_list, etc + for (size_t i = 0; i < 6; i++) { + TFE_DeleteTensorHandle(input_list[i]); + TF_DeleteTensor(data_tensor[i]); + } + for (size_t i = 0; i < nretvals; i++) { + TFE_DeleteTensorHandle(retvals[i]); + } + TFE_DeleteOp(op); +} + +template void deepmd::DeepPotJAX::compute( + std::vector& dener, + std::vector& dforce_, + std::vector& dvirial, + std::vector& datom_energy_, + std::vector& datom_virial_, + const std::vector& dcoord_, + const std::vector& datype_, + const std::vector& dbox, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic); + +template void deepmd::DeepPotJAX::compute( + std::vector& dener, + std::vector& dforce_, + std::vector& dvirial, + std::vector& datom_energy_, + std::vector& datom_virial_, + const std::vector& dcoord_, + const std::vector& datype_, + const std::vector& dbox, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam_, + const bool atomic); + +void deepmd::DeepPotJAX::get_type_map(std::string& type_map_) { + type_map_ = type_map; +} + +// forward to template method +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + fparam, aparam, atomic); +} +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + fparam, aparam, atomic); +} +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + nghost, inlist, ago, fparam, aparam, atomic); +} +void deepmd::DeepPotJAX::computew(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, + nghost, inlist, ago, fparam, aparam, atomic); +} +void deepmd::DeepPotJAX::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + throw deepmd::deepmd_exception("not implemented"); +} +void deepmd::DeepPotJAX::computew_mixed_type(std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const int& nframes, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + throw deepmd::deepmd_exception("not implemented"); +} +#endif diff --git a/source/api_cc/tests/test_deeppot_jax.cc b/source/api_cc/tests/test_deeppot_jax.cc new file mode 100644 index 0000000000..439a271015 --- /dev/null +++ b/source/api_cc/tests/test_deeppot_jax.cc @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "DeepPot.h" +#include "neighbor_list.h" +#include "test_utils.h" + +template +class TestInferDeepPotAJAX : public ::testing::Test { + protected: + // import numpy as np + // from deepmd.infer import DeepPot + // coord = np.array([ + // 12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + // 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + // 3.51, 2.51, 2.60, 4.27, 3.22, 1.56 + // ]).reshape(1, -1) + // atype = np.array([0, 1, 1, 0, 1, 1]) + // box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1) + // dp = DeepPot("deeppot_sea.savedmodel") + // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True) + // np.set_printoptions(precision=16) + // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=} + // {av.ravel()=}") + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + // the data in this file is just copied from PT + std::vector expected_e = { + + -93.016873944029, -185.923296645958, -185.927096544970, + -93.019371018039, -185.926179995548, -185.924351901852}; + std::vector expected_f = { + + 0.006277522211, -0.001117962774, 0.000618580445, 0.009928999655, + 0.003026035654, -0.006941982227, 0.000667853212, -0.002449963843, + 0.006506463508, -0.007284129115, 0.000530662205, -0.000028806821, + 0.000068097781, 0.006121331983, -0.009019754602, -0.009658343745, + -0.006110103225, 0.008865499697}; + std::vector expected_v = { + -0.000155238009, 0.000116605516, -0.007869862476, 0.000465578340, + 0.008182547185, -0.002398713212, -0.008112887338, -0.002423738425, + 0.007210716605, -0.019203504012, 0.001724938709, 0.009909211091, + 0.001153857542, -0.001600015103, -0.000560024090, 0.010727836276, + -0.001034836404, -0.007973454377, -0.021517399106, -0.004064359664, + 0.004866398692, -0.003360038617, -0.007241406162, 0.005920941051, + 0.004899151657, 0.006290788591, -0.006478820311, 0.001921504710, + 0.001313470921, -0.000304091236, 0.001684345981, 0.004124109256, + -0.006396084465, -0.000701095618, -0.006356507032, 0.009818550859, + -0.015230664587, -0.000110244376, 0.000690319396, 0.000045953023, + -0.005726548770, 0.008769818495, -0.000572380210, 0.008860603423, + -0.013819348050, -0.021227082558, -0.004977781343, 0.006646239696, + -0.005987066507, -0.002767831232, 0.003746502525, 0.007697590397, + 0.003746130152, -0.005172634748}; + int natoms; + double expected_tot_e; + std::vector expected_tot_v; + + deepmd::DeepPot dp; + + void SetUp() override { + std::string file_name = "../../tests/infer/deeppot_sea.savedmodel"; + + dp.init(file_name); + + natoms = expected_e.size(); + EXPECT_EQ(natoms * 3, expected_f.size()); + EXPECT_EQ(natoms * 9, expected_v.size()); + expected_tot_e = 0.; + expected_tot_v.resize(9); + std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.); + for (int ii = 0; ii < natoms; ++ii) { + expected_tot_e += expected_e[ii]; + } + for (int ii = 0; ii < natoms; ++ii) { + for (int dd = 0; dd < 9; ++dd) { + expected_tot_v[dd] += expected_v[ii * 9 + dd]; + } + } + } + + void TearDown() override {} +}; + +TYPED_TEST_SUITE(TestInferDeepPotAJAX, ValueTypes); + +TYPED_TEST(TestInferDeepPotAJAX, cpu_build_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + double ener; + std::vector force, virial; + dp.compute(ener, force, virial, coord, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_build_nlist_numfv) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + class MyModel : public EnergyModelTest { + deepmd::DeepPot& mydp; + const std::vector atype; + + public: + MyModel(deepmd::DeepPot& dp_, const std::vector& atype_) + : mydp(dp_), atype(atype_) {}; + virtual void compute(double& ener, + std::vector& force, + std::vector& virial, + const std::vector& coord, + const std::vector& box) { + mydp.compute(ener, force, virial, coord, atype, box); + } + }; + MyModel model(dp, atype); + model.test_f(coord, box); + model.test_v(coord, box); + std::vector box_(box); + box_[1] -= 0.4; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[2] += 0.5; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[4] += 0.2; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[3] -= 0.3; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[6] -= 0.7; + model.test_f(coord, box_); + model.test_v(coord, box_); + box_[7] += 0.6; + model.test_f(coord, box_); + model.test_v(coord, box_); +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_build_nlist_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + double ener; + std::vector force, virial, atom_ener, atom_vir; + dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force_, virial; + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + double ener; + std::vector force_, atom_ener_, atom_vir_, virial; + std::vector force, atom_ener, atom_vir; + dp.compute(ener, force_, virial, atom_ener_, atom_vir_, coord_cpy, atype_cpy, + box, nall - nloc, inlist, 0); + _fold_back(force, force_, mapping, nloc, nall, 3); + _fold_back(atom_ener, atom_ener_, mapping, nloc, nall, 1); + _fold_back(atom_vir, atom_vir_, mapping, nloc, nall, 9); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + std::fill(atom_ener_.begin(), atom_ener_.end(), 0.0); + std::fill(atom_vir_.begin(), atom_vir_.end(), 0.0); + dp.compute(ener, force_, virial, atom_ener_, atom_vir_, coord_cpy, atype_cpy, + box, nall - nloc, inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + _fold_back(atom_ener, atom_ener_, mapping, nloc, nall, 1); + _fold_back(atom_vir, atom_vir_, mapping, nloc, nall, 9); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_EQ(atom_ener.size(), natoms); + EXPECT_EQ(atom_vir.size(), natoms * 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_2rc) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc * 2); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + + ener = 0.; + std::fill(force_.begin(), force_.end(), 0.0); + std::fill(virial.begin(), virial.end(), 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 1); + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_type_sel) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + + // add vir atoms + int nvir = 2; + std::vector coord_vir(nvir * 3); + std::vector atype_vir(nvir, 2); + for (int ii = 0; ii < nvir; ++ii) { + coord_vir[ii] = coord[ii]; + } + coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end()); + atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end()); + natoms += nvir; + std::vector expected_f_vir(nvir * 3, 0.0); + expected_f.insert(expected_f.begin(), expected_f_vir.begin(), + expected_f_vir.end()); + + // build nlist + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + // dp compute + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + // fold back + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, cpu_lmp_nlist_type_sel_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_v = this->expected_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + + // add vir atoms + int nvir = 2; + std::vector coord_vir(nvir * 3); + std::vector atype_vir(nvir, 2); + for (int ii = 0; ii < nvir; ++ii) { + coord_vir[ii] = coord[ii]; + } + coord.insert(coord.begin(), coord_vir.begin(), coord_vir.end()); + atype.insert(atype.begin(), atype_vir.begin(), atype_vir.end()); + natoms += nvir; + std::vector expected_f_vir(nvir * 3, 0.0); + expected_f.insert(expected_f.begin(), expected_f_vir.begin(), + expected_f_vir.end()); + + // build nlist + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + // dp compute + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0), atomic_energy, + atomic_virial; + dp.compute(ener, force_, virial, atomic_energy, atomic_virial, coord_cpy, + atype_cpy, box, nall - nloc, inlist, 0); + // fold back + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepPotAJAX, print_summary) { + deepmd::DeepPot& dp = this->dp; + dp.print_summary(""); +} + +TYPED_TEST(TestInferDeepPotAJAX, get_type_map) { + deepmd::DeepPot& dp = this->dp; + std::string type_map; + dp.get_type_map(type_map); + EXPECT_EQ(type_map, "O H"); +} diff --git a/source/cmake/googletest.cmake.in b/source/cmake/googletest.cmake.in index 5d167cf774..85c3745c00 100644 --- a/source/cmake/googletest.cmake.in +++ b/source/cmake/googletest.cmake.in @@ -11,7 +11,7 @@ endif() include(ExternalProject) ExternalProject_Add(googletest GIT_REPOSITORY ${GTEST_REPO_ADDRESS} - GIT_TAG release-1.12.1 + GIT_TAG v1.14.0 GIT_SHALLOW TRUE SOURCE_DIR "@CMAKE_CURRENT_BINARY_DIR@/googletest-src" BINARY_DIR "@CMAKE_CURRENT_BINARY_DIR@/googletest-build" diff --git a/source/lib/include/neighbor_list.h b/source/lib/include/neighbor_list.h index bb4b8cf13c..5b39ea7454 100644 --- a/source/lib/include/neighbor_list.h +++ b/source/lib/include/neighbor_list.h @@ -44,6 +44,8 @@ struct InputNlist { void* world; /// mask to the neighbor index int mask = 0xFFFFFFFF; + /// mapping from all atoms to real atoms, in the size of nall + int* mapping = nullptr; InputNlist() : inum(0), ilist(NULL), @@ -99,6 +101,10 @@ struct InputNlist { * @brief Set mask for this neighbor list. */ void set_mask(int mask_) { mask = mask_; }; + /** + * @brief Set mapping for this neighbor list. + */ + void set_mapping(int* mapping_) { mapping = mapping_; }; }; /** diff --git a/source/lmp/fix_dplr.cpp b/source/lmp/fix_dplr.cpp index 8e54410d0a..ac161730db 100644 --- a/source/lmp/fix_dplr.cpp +++ b/source/lmp/fix_dplr.cpp @@ -467,6 +467,14 @@ void FixDPLR::pre_force(int vflag) { int nghost = atom->nghost; int nall = nlocal + nghost; + // mapping (for DPA-2 JAX) + std::vector mapping_vec(nall, -1); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + for (size_t ii = 0; ii < nall; ++ii) { + mapping_vec[ii] = atom->map(atom->tag[ii]); + } + } + // if (eflag_atom) { // error->all(FLERR,"atomic energy calculation is not supported by this // fix\n"); @@ -499,6 +507,9 @@ void FixDPLR::pre_force(int vflag) { deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, list->numneigh, list->firstneigh); lmp_list.set_mask(NEIGHMASK); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + lmp_list.set_mapping(mapping_vec.data()); + } // declear output vector tensor; // compute diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 6d12fda20a..8127979cd1 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -155,6 +155,14 @@ void PairDeepMD::compute(int eflag, int vflag) { } } + // mapping (for DPA-2 JAX) + std::vector mapping_vec(nall, -1); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + for (size_t ii = 0; ii < nall; ++ii) { + mapping_vec[ii] = atom->map(atom->tag[ii]); + } + } + if (do_compute_aparam) { make_aparam_from_compute(daparam); } else if (aparam.size() > 0) { @@ -198,6 +206,9 @@ void PairDeepMD::compute(int eflag, int vflag) { commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc, commdata_->recvproc, &world); lmp_list.set_mask(NEIGHMASK); + if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + lmp_list.set_mapping(mapping_vec.data()); + } deepmd_compat::InputNlist extend_lmp_list; if (single_model || multi_models_no_mod_devi) { // cvflag_atom is the right flag for the cvatom matrix diff --git a/source/lmp/tests/test_lammps_dpa_jax.py b/source/lmp/tests/test_lammps_dpa_jax.py new file mode 100644 index 0000000000..10428b2374 --- /dev/null +++ b/source/lmp/tests/test_lammps_dpa_jax.py @@ -0,0 +1,726 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import constants +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data, +) + +pbtxt_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt" +) +pb_file = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa.savedmodel" +) +pb_file2 = Path(__file__).parent / "graph2.pb" +system_file = Path(__file__).parent.parent.parent / "tests" +data_file = Path(__file__).parent / "data.lmp" +data_file_si = Path(__file__).parent / "data.si" +data_type_map_file = Path(__file__).parent / "data_type_map.lmp" +md_file = Path(__file__).parent / "md.out" + +# this is as the same as python and c++ tests, test_deeppot_a.py +expected_ae = np.array( + [ + -94.24098099691867, + -187.8049502787117, + -187.80486052083617, + -94.24059525229518, + -187.80366985846246, + -187.8042377490619, + ] +) +expected_e = np.sum(expected_ae) +expected_f = np.array( + [ + -0.0020150115442053, + -0.0133389255924977, + -0.0014347177433057, + -0.0140757358179293, + 0.0031373814221557, + 0.0098594354314677, + 0.004755683505073, + 0.0099471082374397, + -0.0080868184532793, + -0.0086166721574536, + 0.0037803939137322, + -0.0075733131286482, + 0.0037437603038209, + -0.008452527996008, + 0.0134837461840424, + 0.0162079757106944, + 0.0049265700151781, + -0.0062483322902769, + ] +).reshape(6, 3) + +expected_f2 = np.array( + [ + [-0.6454949, 1.72457783, 0.18897958], + [1.68936514, -0.36995299, -1.36044464], + [-1.09902692, -1.35487928, 1.17416702], + [1.68426111, -0.50835585, 0.98340415], + [0.05771758, 1.12515818, -1.77561531], + [-1.686822, -0.61654789, 0.78950921], + ] +) + +expected_v = -np.array( + [ + 0.0133534319524089, + 0.0013445914938337, + -0.0029370551651952, + 0.0002611806151294, + 0.004662662211533, + -0.0002717443796319, + -0.0027779798869954, + -0.0003277976466339, + 0.0018284972283065, + 0.0085710118978246, + 0.0003865036653608, + -0.0057964032875089, + -0.0014358330222619, + 0.0002912625128908, + 0.001212630641674, + -0.0050582608957046, + -0.0001087907763249, + 0.0040068757134429, + 0.0116736349373084, + 0.0007055477968445, + -0.0019544933708784, + 0.0032997459258512, + 0.0037887116116712, + -0.0043140890650835, + -0.0034418738401156, + -0.0029420616852742, + 0.0038219676716965, + 0.0147134944025738, + 0.0005214313829998, + -0.0006524136175906, + 0.0003656980996363, + 0.0010046161607714, + -0.0017279359476254, + 0.000111127036911, + -0.0017063190420654, + 0.0030174567965904, + 0.0104435705455108, + -0.0008704394438241, + 0.0012354202650812, + 0.0009397615830053, + 0.0029105236407293, + -0.0044188897903449, + -0.0011461513500477, + -0.0045759080125852, + 0.0070310883421107, + 0.0089818851995049, + 0.0038819466696704, + -0.005443705549253, + 0.0025390283635246, + 0.0012121502955869, + -0.0016998728971157, + -0.0032355117893925, + -0.0015590242752438, + 0.0021980725909838, + ] +).reshape(6, 9) +expected_v2 = -np.array( + [ + [ + -0.70008436, + -0.06399891, + 0.63678391, + -0.07642171, + -0.70580035, + 0.20506145, + 0.64098364, + 0.20305781, + -0.57906794, + ], + [ + -0.6372635, + 0.14315552, + 0.51952246, + 0.04604049, + -0.06003681, + -0.02688702, + 0.54489318, + -0.10951559, + -0.43730539, + ], + [ + -0.25090748, + -0.37466262, + 0.34085833, + -0.26690852, + -0.37676917, + 0.29080825, + 0.31600481, + 0.37558276, + -0.33251064, + ], + [ + -0.80195614, + -0.10273138, + 0.06935364, + -0.10429256, + -0.29693811, + 0.45643496, + 0.07247872, + 0.45604679, + -0.71048816, + ], + [ + -0.03840668, + -0.07680205, + 0.10940472, + -0.02374189, + -0.27610266, + 0.4336071, + 0.02465248, + 0.4290638, + -0.67496763, + ], + [ + -0.61475065, + -0.21163135, + 0.26652929, + -0.26134659, + -0.11560267, + 0.15415902, + 0.34343952, + 0.1589482, + -0.21370642, + ], + ] +).reshape(6, 9) + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) +type_HO = np.array([2, 1, 1, 2, 1, 1]) + + +sp.check_output( + f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split() +) + + +def setup_module(): + write_lmp_data(box, coord, type_OH, data_file) + write_lmp_data(box, coord, type_HO, data_type_map_file) + write_lmp_data( + box * constants.dist_metal2si, + coord * constants.dist_metal2si, + type_OH, + data_file_si, + ) + + +def teardown_module(): + os.remove(data_file) + os.remove(data_type_map_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("p p p") + lammps.atom_style("atomic") + # Requires for DPA-2 + lammps.atom_modify("map yes") + if units == "metal" or units == "real": + lammps.neighbor("2.0 bin") + elif units == "si": + lammps.neighbor("2.0e-10 bin") + else: + raise ValueError("units should be metal, real, or si") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal" or units == "real": + lammps.mass("1 16") + lammps.mass("2 2") + elif units == "si": + lammps.mass("1 %.10e" % (16 * constants.mass_metal2si)) + lammps.mass("2 %.10e" % (2 * constants.mass_metal2si)) + else: + raise ValueError("units should be metal, real, or si") + if units == "metal": + lammps.timestep(0.0005) + elif units == "real": + lammps.timestep(0.5) + elif units == "si": + lammps.timestep(5e-16) + else: + raise ValueError("units should be metal, real, or si") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_type_map(): + lmp = _lammps(data_file=data_type_map_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_real(): + lmp = _lammps(data_file=data_file, units="real") + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_si(): + lmp = _lammps(data_file=data_file_si, units="si") + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_virial(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + + +def test_pair_deepmd_model_devi(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_virial(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative_v(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_type_map(lammps_type_map): + lammps_type_map.pair_style(f"deepmd {pb_file.resolve()}") + lammps_type_map.pair_coeff("* * H O") + lammps_type_map.run(0) + assert lammps_type_map.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps_type_map.atoms[ii].force == pytest.approx( + expected_f[lammps_type_map.atoms[ii].id - 1] + ) + lammps_type_map.run(1) + + +def test_pair_deepmd_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + lammps_real.run(1) + + +def test_pair_deepmd_virial_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_virial_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative * constants.force_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_v_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative * constants.ener_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_si(lammps_si): + lammps_si.pair_style(f"deepmd {pb_file.resolve()}") + lammps_si.pair_coeff("* *") + lammps_si.run(0) + assert lammps_si.eval("pe") == pytest.approx(expected_e * constants.ener_metal2si) + for ii in range(6): + assert lammps_si.atoms[ii].force == pytest.approx( + expected_f[lammps_si.atoms[ii].id - 1] * constants.force_metal2si + ) + lammps_si.run(1) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +@pytest.mark.parametrize( + ("balance_args",), + [(["--balance"],), ([],)], +) +@pytest.mark.skip("MPI is not supported") +def test_pair_deepmd_mpi(balance_args: list): + with tempfile.NamedTemporaryFile() as f: + sp.check_call( + [ + "mpirun", + "-n", + "2", + sys.executable, + Path(__file__).parent / "run_mpi_pair_deepmd.py", + data_file, + pb_file, + pb_file2, + md_file, + f.name, + *balance_args, + ] + ) + arr = np.loadtxt(f.name, ndmin=1) + pe = arr[0] + + relative = 1.0 + assert pe == pytest.approx(expected_e) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) diff --git a/source/lmp/tests/test_lammps_jax.py b/source/lmp/tests/test_lammps_jax.py new file mode 100644 index 0000000000..6d67cd3203 --- /dev/null +++ b/source/lmp/tests/test_lammps_jax.py @@ -0,0 +1,723 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import constants +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data, +) + +pbtxt_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt" +) +pb_file = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_sea.savedmodel" +) +pb_file2 = Path(__file__).parent / "graph2.pb" +system_file = Path(__file__).parent.parent.parent / "tests" +data_file = Path(__file__).parent / "data.lmp" +data_file_si = Path(__file__).parent / "data.si" +data_type_map_file = Path(__file__).parent / "data_type_map.lmp" +md_file = Path(__file__).parent / "md.out" + +# this is as the same as python and c++ tests, test_deeppot_a.py +expected_ae = np.array( + [ + -93.016873944029, + -185.923296645958, + -185.927096544970, + -93.019371018039, + -185.926179995548, + -185.924351901852, + ] +) +expected_e = np.sum(expected_ae) +expected_f = np.array( + [ + 0.006277522211, + -0.001117962774, + 0.000618580445, + 0.009928999655, + 0.003026035654, + -0.006941982227, + 0.000667853212, + -0.002449963843, + 0.006506463508, + -0.007284129115, + 0.000530662205, + -0.000028806821, + 0.000068097781, + 0.006121331983, + -0.009019754602, + -0.009658343745, + -0.006110103225, + 0.008865499697, + ] +).reshape(6, 3) + +expected_f2 = np.array( + [ + [-0.6454949, 1.72457783, 0.18897958], + [1.68936514, -0.36995299, -1.36044464], + [-1.09902692, -1.35487928, 1.17416702], + [1.68426111, -0.50835585, 0.98340415], + [0.05771758, 1.12515818, -1.77561531], + [-1.686822, -0.61654789, 0.78950921], + ] +) + +expected_v = -np.array( + [ + -0.000155238009, + 0.000116605516, + -0.007869862476, + 0.000465578340, + 0.008182547185, + -0.002398713212, + -0.008112887338, + -0.002423738425, + 0.007210716605, + -0.019203504012, + 0.001724938709, + 0.009909211091, + 0.001153857542, + -0.001600015103, + -0.000560024090, + 0.010727836276, + -0.001034836404, + -0.007973454377, + -0.021517399106, + -0.004064359664, + 0.004866398692, + -0.003360038617, + -0.007241406162, + 0.005920941051, + 0.004899151657, + 0.006290788591, + -0.006478820311, + 0.001921504710, + 0.001313470921, + -0.000304091236, + 0.001684345981, + 0.004124109256, + -0.006396084465, + -0.000701095618, + -0.006356507032, + 0.009818550859, + -0.015230664587, + -0.000110244376, + 0.000690319396, + 0.000045953023, + -0.005726548770, + 0.008769818495, + -0.000572380210, + 0.008860603423, + -0.013819348050, + -0.021227082558, + -0.004977781343, + 0.006646239696, + -0.005987066507, + -0.002767831232, + 0.003746502525, + 0.007697590397, + 0.003746130152, + -0.005172634748, + ] +).reshape(6, 9) +expected_v2 = -np.array( + [ + [ + -0.70008436, + -0.06399891, + 0.63678391, + -0.07642171, + -0.70580035, + 0.20506145, + 0.64098364, + 0.20305781, + -0.57906794, + ], + [ + -0.6372635, + 0.14315552, + 0.51952246, + 0.04604049, + -0.06003681, + -0.02688702, + 0.54489318, + -0.10951559, + -0.43730539, + ], + [ + -0.25090748, + -0.37466262, + 0.34085833, + -0.26690852, + -0.37676917, + 0.29080825, + 0.31600481, + 0.37558276, + -0.33251064, + ], + [ + -0.80195614, + -0.10273138, + 0.06935364, + -0.10429256, + -0.29693811, + 0.45643496, + 0.07247872, + 0.45604679, + -0.71048816, + ], + [ + -0.03840668, + -0.07680205, + 0.10940472, + -0.02374189, + -0.27610266, + 0.4336071, + 0.02465248, + 0.4290638, + -0.67496763, + ], + [ + -0.61475065, + -0.21163135, + 0.26652929, + -0.26134659, + -0.11560267, + 0.15415902, + 0.34343952, + 0.1589482, + -0.21370642, + ], + ] +).reshape(6, 9) + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) +type_HO = np.array([2, 1, 1, 2, 1, 1]) + + +sp.check_output( + f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split() +) + + +def setup_module(): + write_lmp_data(box, coord, type_OH, data_file) + write_lmp_data(box, coord, type_HO, data_type_map_file) + write_lmp_data( + box * constants.dist_metal2si, + coord * constants.dist_metal2si, + type_OH, + data_file_si, + ) + + +def teardown_module(): + os.remove(data_file) + os.remove(data_type_map_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("p p p") + lammps.atom_style("atomic") + if units == "metal" or units == "real": + lammps.neighbor("2.0 bin") + elif units == "si": + lammps.neighbor("2.0e-10 bin") + else: + raise ValueError("units should be metal, real, or si") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal" or units == "real": + lammps.mass("1 16") + lammps.mass("2 2") + elif units == "si": + lammps.mass("1 %.10e" % (16 * constants.mass_metal2si)) + lammps.mass("2 %.10e" % (2 * constants.mass_metal2si)) + else: + raise ValueError("units should be metal, real, or si") + if units == "metal": + lammps.timestep(0.0005) + elif units == "real": + lammps.timestep(0.5) + elif units == "si": + lammps.timestep(5e-16) + else: + raise ValueError("units should be metal, real, or si") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_type_map(): + lmp = _lammps(data_file=data_type_map_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_real(): + lmp = _lammps(data_file=data_file, units="real") + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_si(): + lmp = _lammps(data_file=data_file_si, units="si") + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_virial(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + + +def test_pair_deepmd_model_devi(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_virial(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative_v(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_type_map(lammps_type_map): + lammps_type_map.pair_style(f"deepmd {pb_file.resolve()}") + lammps_type_map.pair_coeff("* * H O") + lammps_type_map.run(0) + assert lammps_type_map.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps_type_map.atoms[ii].force == pytest.approx( + expected_f[lammps_type_map.atoms[ii].id - 1] + ) + lammps_type_map.run(1) + + +def test_pair_deepmd_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + lammps_real.run(1) + + +def test_pair_deepmd_virial_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_virial_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative * constants.force_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_v_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative * constants.ener_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_si(lammps_si): + lammps_si.pair_style(f"deepmd {pb_file.resolve()}") + lammps_si.pair_coeff("* *") + lammps_si.run(0) + assert lammps_si.eval("pe") == pytest.approx(expected_e * constants.ener_metal2si) + for ii in range(6): + assert lammps_si.atoms[ii].force == pytest.approx( + expected_f[lammps_si.atoms[ii].id - 1] * constants.force_metal2si + ) + lammps_si.run(1) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +@pytest.mark.parametrize( + ("balance_args",), + [(["--balance"],), ([],)], +) +def test_pair_deepmd_mpi(balance_args: list): + with tempfile.NamedTemporaryFile() as f: + sp.check_call( + [ + "mpirun", + "-n", + "2", + sys.executable, + Path(__file__).parent / "run_mpi_pair_deepmd.py", + data_file, + pb_file, + pb_file2, + md_file, + f.name, + *balance_args, + ] + ) + arr = np.loadtxt(f.name, ndmin=1) + pe = arr[0] + + relative = 1.0 + assert pe == pytest.approx(expected_e) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) diff --git a/source/tests/infer/convert-models.sh b/source/tests/infer/convert-models.sh new file mode 100755 index 0000000000..d74023b9fd --- /dev/null +++ b/source/tests/infer/convert-models.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set -ev + +SCRIPT_PATH=$(dirname $(realpath -s $0)) + +dp convert-backend ${SCRIPT_PATH}/deeppot_sea.yaml ${SCRIPT_PATH}/deeppot_sea.savedmodel +dp convert-backend ${SCRIPT_PATH}/deeppot_dpa.yaml ${SCRIPT_PATH}/deeppot_dpa.savedmodel diff --git a/source/tests/infer/deeppot_dpa.yaml b/source/tests/infer/deeppot_dpa.yaml new file mode 100644 index 0000000000..29cf7c7b5d --- /dev/null +++ b/source/tests/infer/deeppot_dpa.yaml @@ -0,0 +1,5520 @@ +"@variables": {} +backend: PyTorch +model: + "@class": Model + "@variables": + out_bias: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - -93.57372029622395 + - - -187.1474405924479 + out_std: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - - 1.0 + "@version": 2 + atom_exclude_types: [] + descriptor: + "@class": Descriptor + "@version": 3 + add_tebd_to_repinit_out: false + concat_output_tebd: true + env_protection: 0.0 + exclude_types: [] + g1_shape_tranform: + "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.20483269303791724 + - -0.21992468328757978 + - -0.26684291550680395 + - 0.37176791042370194 + - 0.012112465815266748 + - - -0.10228811869581654 + - 0.08533578921915844 + - -0.3127677683059358 + - 0.032891262984076214 + - 0.23853603507174417 + - - -0.11012778988715048 + - 0.002332609590948461 + - 0.1086418697477636 + - -0.1658026622411995 + - -0.21664326015096075 + - - -0.3527700628717402 + - -0.0075194560495788045 + - 0.18179527579497576 + - -0.003956996745727664 + - 0.1718867759576585 + - - 0.0690111857182912 + - 0.12422934001100412 + - 0.2162151055573501 + - -0.24322535261890577 + - -0.029043758026887957 + - - 0.34780578564256326 + - 0.20364573090748753 + - -0.05444165068567216 + - 0.030228628305679458 + - -0.36336997419904504 + - - -0.20379881010311193 + - -0.2128475427919148 + - -0.037252887701215716 + - -0.23563042001690995 + - 0.038193025137633965 + - - 0.11171587196232823 + - 0.1670429227878146 + - -0.1463146137690288 + - 0.33725869994865676 + - -0.251623438391177 + - - -0.3906697623605178 + - 0.3863150686237999 + - -0.37690452827133364 + - -0.34893479030358127 + - -0.07335302931907087 + - - 0.16693800781399992 + - 0.005394633503732046 + - 0.05236532931922369 + - -0.2320975481004007 + - -0.25842256129817975 + - - -0.07725894463746617 + - 0.150649458663615 + - 0.050087480712582094 + - 0.3745134756446062 + - -0.02431033793295473 + - - -0.27003908777563884 + - 0.2901709025927058 + - -0.3546670523393916 + - -0.36436155271432985 + - 0.03651470716837575 + - - -0.32108833643814677 + - 0.11802772716870968 + - 0.14881127437278965 + - 0.26406953165348784 + - 0.26049877020851886 + - - 0.17508942727037166 + - 0.07516552799833877 + - 0.17850291315018707 + - 0.3601769599806957 + - 0.003466820196038751 + - - 0.3667723047384918 + - -0.11287631309355094 + - 0.2272234189515299 + - 0.05287271135296442 + - 0.13257295366264804 + - - -0.22845857796273042 + - 0.03096872749873777 + - -0.23055796402292011 + - 0.13873363458620155 + - -0.18488270405859109 + - - -0.18400661973285898 + - -0.016123380005281133 + - -0.15654111351971653 + - 0.32911107863020456 + - -0.19672320640130256 + - - 0.05193497288510673 + - -0.3032556512293548 + - 0.36893267229696947 + - 0.003620948210812366 + - -0.0390227870596316 + - - 0.11798065615207687 + - -0.21463894028272787 + - 0.14470957159240572 + - 0.10473807077972676 + - -0.3034827834248236 + - - 0.18143999845728037 + - -0.29022845895287624 + - 0.24239889170806395 + - 0.08092498547972143 + - -0.27091058258618883 + - - -0.025962254895722407 + - -0.32320042757474404 + - 0.00944493440361434 + - -0.37659211401566733 + - -0.12090975734412258 + - - 0.13914846278313975 + - -0.36190265610728367 + - -0.345121739571488 + - 0.3666897475740704 + - -0.23663407984613768 + - - -0.28117724230097424 + - -0.3482935640251905 + - 0.011078574073914218 + - 0.16004223335456821 + - 0.3614686961516761 + - - 0.15265019896041013 + - 0.03919674668030359 + - 0.07933011131248942 + - 0.4118041022473589 + - -0.19917149715506582 + - - -0.29705350142063586 + - -0.2573913724532826 + - 0.3413330988333964 + - -0.2629345752977957 + - -0.33925716711712006 + - - -0.3090126659916563 + - 0.10052600358008285 + - 0.3295391136582312 + - -0.20642069215799752 + - 0.1505575795303182 + - - -0.3432080532310526 + - 0.3164411332145225 + - 0.21984254570387593 + - 0.14569302992407368 + - -0.2789802513964337 + - - 0.06858600492516646 + - 0.003213025851039919 + - -0.34745085008036425 + - -0.3025951406540648 + - 0.3360807249658782 + - - 0.32079300335979 + - 0.10559011069365329 + - 0.12094144536768299 + - 0.33662630333028337 + - -0.35442232207439917 + - - -0.30186065289187725 + - -0.346981607427445 + - 0.3899552474706737 + - 0.16646982977851396 + - -0.06503610879430799 + - - 0.023620358718866294 + - -0.13711543258356879 + - -0.22371664723989837 + - 0.18122464518071227 + - -0.2238394582304553 + - - -0.23312745666197218 + - 0.10470224590876251 + - 0.050192516659196086 + - -0.3607304972415714 + - 0.39592397024699677 + - - 0.3541674922988369 + - 0.034944883058862765 + - 0.004196837614533963 + - 0.36513421403017227 + - -0.2939792596841648 + "@version": 1 + activation_function: none + bias: false + precision: float64 + resnet: false + use_timestep: false + ntypes: 2 + precision: default + repformer_args: + activation_function: tanh + attn1_hidden: 5 + attn1_nhead: 4 + attn2_has_gate: true + attn2_hidden: 5 + attn2_nhead: 4 + axis_neuron: 4 + direct_dist: false + g1_dim: 5 + g1_out_conv: true + g1_out_mlp: true + g2_dim: 5 + ln_eps: 1.0e-05 + nlayers: 3 + nsel: 40 + rcut: 4.0 + rcut_smth: 3.5 + set_davg_zero: true + trainable_ln: true + update_g1_has_attn: false + update_g1_has_conv: true + update_g1_has_drrd: true + update_g1_has_grrg: true + update_g2_has_attn: false + update_g2_has_g1g1: false + update_h2: false + update_residual: 0.01 + update_residual_init: norm + update_style: res_residual + use_sqrt_nnei: true + repformers_variable: + "@variables": + davg: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + dstd: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + env_mat: + rcut: 4.0 + rcut_smth: 3.5 + g2_embd: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -1.0120525255965693 + - -1.3410942661505185 + - 0.6444248980785466 + - 1.3379360916834435 + - -0.9400579171669438 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.019464291658135142 + - -0.018782174023876802 + - 0.053736857134061365 + - -0.21761095463941962 + - 0.22908429306304245 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + repformer_layers: + - "@class": RepformerLayer + "@variables": + g1_residual: + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.03603003169654137 + - 0.012497448175246137 + - -0.01095320901085369 + - -0.03415370741201022 + - -0.033128449146010305 + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.006919593857755369 + - 0.010748609458630657 + - 0.004842926409136031 + - -0.011656955546082724 + - 0.004788783773278474 + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.0004263964322652511 + - -0.0028300152727785068 + - 0.01952962243649864 + - 0.0027873043447218014 + - 0.001309288766687558 + g2_residual: + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.01160023747691338 + - -0.0004448807549913108 + - 0.0008057624078188488 + - -0.024948515472675986 + - -0.019416324515653484 + h2_residual: [] + "@version": 2 + activation_function: tanh + attn1_hidden: 5 + attn1_nhead: 4 + attn2_has_gate: true + attn2_hidden: 5 + attn2_nhead: 4 + axis_neuron: 4 + g1_dim: 5 + g1_out_conv: true + g1_out_mlp: true + g1_self_mlp: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -1.30811432030891 + - 1.7317839475528145 + - 1.399643928837452 + - -2.217598009660054 + - -0.3280510654318928 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.19058890417971588 + - 0.3463019129386775 + - -0.3238515676656896 + - 0.4620257390575868 + - -0.11979316985218211 + - - -0.5091348418247256 + - 0.5937549073105088 + - -0.4101966440417052 + - 0.07750470968878109 + - -0.3424037916252215 + - - -0.6019537986876181 + - 0.18600176889845013 + - 0.1980666907358186 + - -0.09605024233548902 + - -0.5160961441700344 + - - -0.22056069372777706 + - -0.7447000890755333 + - -0.06841639167039011 + - -0.3211447460158319 + - 0.01879819061642492 + - - 0.17548998966047413 + - -0.007952594370924146 + - 0.19396423635351934 + - -0.274239700940922 + - 0.2571037042854489 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + g2_dim: 5 + linear1: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.6292962632873957 + - 0.3041398278338624 + - -0.26570178182208953 + - 1.0069393034707894 + - -0.017831983893162602 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.08396232707317755 + - -0.2590789034376074 + - 0.06449599213906514 + - -0.026485184870448455 + - -0.11502287050655546 + - - -0.3394134083847728 + - -0.04182949795373899 + - -0.18355759680044953 + - -0.07272151368022486 + - 0.202216848914141 + - - -0.09852715985600499 + - -0.10543705177859736 + - 0.08879304650716603 + - 0.08474028791239975 + - 0.22150741767418541 + - - -0.03346579823549347 + - 0.08142402535224946 + - 0.20444458395431703 + - 0.2061630292035231 + - -0.34623750552383953 + - - 0.045881739265562126 + - 0.03235926439065396 + - -0.06514538897263505 + - -0.09157218195401022 + - -0.16078786248733104 + - - 0.10337932580707862 + - -0.009900109711506542 + - -0.09848266744555643 + - -0.06743723410229494 + - -0.034295685655753826 + - - -0.024749851911743403 + - -0.10168528807416925 + - -0.08321216530050654 + - 0.08263273817835241 + - 0.06961247702103983 + - - -0.026904632217682176 + - 0.024565437698427974 + - -0.015916989311626998 + - -0.2529377102697093 + - -0.33605661981070634 + - - -0.027971199641699127 + - -0.09923898814189575 + - -0.24178299388114752 + - 0.029602144279883807 + - 0.09859016934223175 + - - 0.12904137491494896 + - 0.01463153375175068 + - -0.030196990745127137 + - -0.03427710675704333 + - 0.0769092571140682 + - - 0.06078298146276607 + - 0.018624014276540105 + - 0.007625746651479272 + - -0.05212163740127107 + - -0.22961856283351262 + - - -0.24276750771398306 + - 0.015459667548397407 + - -0.024917383364799983 + - -0.07328719787935754 + - 0.10317938994270194 + - - 0.06207164078936751 + - -0.10450229276912223 + - 0.06706176715814717 + - 0.055976275244021684 + - -0.018233802569777677 + - - -0.06651532308872152 + - -0.08647319137042646 + - -0.09169294749680779 + - 0.129348538280104 + - 0.006274092835851255 + - - -0.062332536772224355 + - -0.12119349201175855 + - -0.06183693675207785 + - 0.0008820048050103422 + - 0.12406819642946329 + - - 0.036026233381260604 + - 0.08058280500947529 + - 0.02802481232492143 + - 0.14557999799990246 + - 0.25206957815030306 + - - -0.19644397693165241 + - 0.11483898331021172 + - -0.07929205314453085 + - 0.13685466931987425 + - 0.029385541250605204 + - - -0.20239024913689344 + - 0.14260455767505056 + - -0.12672049424541829 + - 0.19878940946568238 + - 0.21052724203164538 + - - -0.14411612791600975 + - -0.04190247752914743 + - 0.07299916367427191 + - 0.1440709108775927 + - -0.1613102600483591 + - - 0.3114860420321912 + - 0.07687254774336891 + - 0.03723959101138982 + - 0.012729756801082936 + - -0.24720255506387367 + - - 0.1623672745194676 + - -0.17803926391541774 + - 0.07601730846902242 + - -0.07896998871476738 + - -0.1424258418559638 + - - 0.03553129535089544 + - -0.14295036742419862 + - 0.07638378413519416 + - -0.07390601295144453 + - 0.12940724982987217 + - - -0.02842164520441565 + - 0.13939917807001465 + - 0.47701969821105683 + - -0.031117611945881554 + - -0.03500981062313414 + - - 0.17217654358849513 + - -0.20172150565508037 + - 0.14182112976591155 + - -0.06382811614171423 + - -0.007306675305514548 + - - -0.15447228144755248 + - -0.02312456900409192 + - -0.09544935736484167 + - 0.07250584700668984 + - 0.1574085006156906 + - - -0.12603968435322543 + - -0.025573781632627562 + - 0.1508227463536938 + - -0.30345724164088406 + - 0.1854020066002888 + - - -0.1635859501573259 + - -0.08480199517991185 + - 0.08817054366324961 + - 0.16171683703905798 + - -0.12004204073976865 + - - -0.08111522326105529 + - -0.0024435761674501118 + - -0.052941182960640855 + - 0.23910390490710043 + - -0.21315524715166578 + - - -0.05926997962527499 + - -0.008935034265033262 + - 0.14683102932558992 + - -0.01918527903899294 + - -0.03406480065791599 + - - -0.1337717203489477 + - -0.021818018335351835 + - 0.024971591316419486 + - 0.06704069613758198 + - 0.040648989162459674 + - - -0.004444956659464356 + - -0.0003425146958406023 + - 0.17960753789619216 + - -0.27214634408723903 + - 0.07711820341688266 + - - 0.133652261731334 + - -0.009797974080371686 + - -0.4259816212787344 + - 0.23200562464230318 + - -0.02493040660085401 + - - -0.046836875861095834 + - 0.001645473952655283 + - -0.05473403991636848 + - -0.05378549705402274 + - -0.3889996124421639 + - - 0.02092041296068538 + - 0.21891176724882613 + - 0.25305998826597703 + - -0.00737941664247247 + - -0.07319199260494484 + - - 0.22207531087979715 + - 0.1167463561271954 + - 0.05854914028062706 + - 0.02828761486751832 + - 0.014280499410575949 + - - 0.21548357500898635 + - -0.14325757178810586 + - -0.37108308687673025 + - -0.30841574977024155 + - -0.15765248368668322 + - - -0.19065434721674077 + - 0.2894449163901936 + - 0.10368226466064218 + - 0.002272586212844238 + - -0.06685991839970858 + - - 0.15850988743880434 + - -0.19930536837479582 + - 0.06352955279220247 + - 0.18233644655152334 + - 0.177009225453199 + - - 0.08440437483086662 + - 0.053514367173557176 + - -0.08416559269376686 + - 0.04962224184852935 + - -0.19450578606057445 + - - -0.02627920444598498 + - -0.11978290209672943 + - 0.06513795302813882 + - 0.028386691085269353 + - -0.16676268182611712 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + linear2: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.2592964101062212 + - 0.6050098232448291 + - -0.005515314503147 + - 0.11603766186289032 + - 0.7299519869702973 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.22980393702142735 + - 0.08816839372355742 + - -0.7825210216367764 + - 0.47400651965129326 + - -0.09715501151613053 + - - 0.051226778173106544 + - -0.0586569497395102 + - 0.07008245083492698 + - -0.27909388065497365 + - -0.403409427380388 + - - 0.4155289437793008 + - -0.23023614159453773 + - -0.06319940162259689 + - 0.49105437825377674 + - -0.672285862471505 + - - 0.2628024247581461 + - 0.30092184220502505 + - -0.04589377746162835 + - -0.15376402934013342 + - -0.3459379045607231 + - - 0.055019077191780244 + - -0.15428864792280397 + - 0.03686133046649239 + - 0.22797050584554945 + - -0.2895756986826399 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + ln_eps: 1.0e-05 + ntypes: 2 + precision: default + proj_g1g2: + "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.5532734073335777 + - -0.2894171115724634 + - -0.34154590578539124 + - -0.5741289385215242 + - 0.22542651626880558 + - - -0.1258999860718108 + - -0.23796762711809444 + - 0.056052343409511166 + - 0.20182107205251548 + - 0.07978830712924143 + - - -0.10108756888794897 + - -0.31868861303852497 + - 0.2266451364666591 + - 0.013827304866388349 + - 0.050907970632244975 + - - -0.09037692956629029 + - -0.27064175886126934 + - -0.25149093441780035 + - -0.11031261523979898 + - 0.20910864059883363 + - - 0.1704620928115622 + - 0.45496457562319714 + - -0.39714014820404625 + - 0.052072985229575805 + - 0.20305217115768973 + "@version": 1 + activation_function: none + bias: false + precision: float64 + resnet: false + use_timestep: false + rcut: 4.0 + rcut_smth: 3.5 + sel: &id001 + - 40 + smooth: true + trainable_ln: true + update_chnnl_2: true + update_g1_has_attn: false + update_g1_has_conv: true + update_g1_has_drrd: true + update_g1_has_grrg: true + update_g2_has_attn: false + update_g2_has_g1g1: false + update_h2: false + update_style: res_residual + use_sqrt_nnei: true + - "@class": RepformerLayer + "@variables": + g1_residual: + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.030953153877771103 + - 0.004550247326878251 + - 0.013303927865191467 + - 0.01018780786448025 + - 0.0155549810199352 + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.016276946544039366 + - -0.0007364959557409706 + - -0.004972067050977963 + - 0.005248593373274716 + - -0.0056900226522463215 + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.008800796363452382 + - -0.005593693472478371 + - 0.032199083632300035 + - 8.400962254808775e-05 + - -0.012730621609703547 + g2_residual: + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.00982260814551794 + - 0.0042572728334056365 + - 0.009538800228021472 + - 0.00047524228131468776 + - -0.015464154773430883 + h2_residual: [] + "@version": 2 + activation_function: tanh + attn1_hidden: 5 + attn1_nhead: 4 + attn2_has_gate: true + attn2_hidden: 5 + attn2_nhead: 4 + axis_neuron: 4 + g1_dim: 5 + g1_out_conv: true + g1_out_mlp: true + g1_self_mlp: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.0730561183538059 + - 2.776096418581049 + - -1.6878494103454023 + - 0.6319621912698188 + - 0.7296819544535204 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.30406067533186854 + - -0.23319974862010298 + - 0.4268841113647301 + - -0.3171107228509156 + - 0.23705159038283963 + - - 0.14924431344154163 + - -0.3277346372033825 + - 0.03732222579147385 + - -0.06231168041312384 + - -0.27967515711810453 + - - -0.17033596314716415 + - -0.002809366135203804 + - 0.5350728562707233 + - 0.04409125365519289 + - -0.23373164532325605 + - - 0.16251375866692905 + - -0.2669074996515733 + - 0.4046092505537821 + - -0.03440243154913439 + - 0.14308632480367547 + - - -0.3212722334397611 + - -0.14397893865966555 + - 0.2031848262908836 + - 0.08760210337263934 + - 0.690524747593835 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + g2_dim: 5 + linear1: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.3904504721256209 + - -0.7387582311089921 + - -0.6516004102844352 + - -0.4170397034404462 + - -0.8279414759908088 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.07233303518597486 + - -0.12559877932713687 + - 0.05665805267436587 + - 0.08680631673659558 + - 0.0025329042894111187 + - - 0.2870588620333185 + - -0.07786427816699869 + - -0.07104575536059457 + - 0.048253184585550064 + - -0.1640024445711842 + - - 0.18797971218330461 + - 0.06771349169052267 + - -0.07907380461980497 + - 0.05525437486420736 + - 0.02009520745986023 + - - 0.281038963580294 + - 0.2270448019984451 + - 0.04662719213217196 + - 0.07123094368297082 + - -0.10806283754971098 + - - -0.07898984260721922 + - -0.10739362148061957 + - 0.0033415947030256544 + - 0.032288529511892566 + - -0.03173790906780475 + - - -0.011155332801251748 + - 0.1316727549039841 + - 0.0971263766783617 + - 0.17867926273983348 + - 0.03672718518126696 + - - 0.09157897721667 + - 0.09926237313114537 + - -0.21258585423818205 + - 0.1269364386377713 + - 0.30222502480313734 + - - -0.4265044494307036 + - 0.014698325150853153 + - 0.222526397099498 + - -0.0515765918179406 + - 0.04461419978150989 + - - -0.17174281294203816 + - 0.2751118183268816 + - -0.09571052405552033 + - 0.10674231950444127 + - -0.2665717704888117 + - - -0.08713033504338041 + - 0.05286235812039554 + - -0.13465378907856126 + - 0.26029622715251594 + - 0.20429631782598345 + - - -0.1931696011063644 + - 0.03179373592007634 + - -0.01705227065987145 + - -0.15627913293510198 + - 0.07755903941265974 + - - -0.11049696053897917 + - -0.057292985864107124 + - -0.010211172262537599 + - -0.16280061653720834 + - 0.07661291164398892 + - - -0.35140753207139847 + - -0.06810004447528932 + - 0.10839694449933315 + - 0.03542613606995244 + - -0.06775694369641179 + - - -0.04147634682326751 + - -0.056638697310459854 + - -0.05330569239674951 + - 0.13901111988568746 + - -0.06824938335129761 + - - -0.11203618974941376 + - 0.3604814644601604 + - -0.10640848463730161 + - 0.02699744755022092 + - 0.010353729905313842 + - - 0.12662021724846018 + - -0.15433041202985273 + - -0.28807521824574284 + - 0.20393023164034899 + - -0.09407420816857529 + - - -0.11791962062708798 + - 0.10090518368799273 + - 0.3166277996749642 + - 0.09456371288740745 + - 0.13776853376528053 + - - 0.21434164527122126 + - -0.0473284024955143 + - -0.09804549668998058 + - -0.03296726988358063 + - -0.02246160526373819 + - - -0.04965182688872488 + - -0.1861644408143712 + - 0.005313676453730348 + - -0.010360120573932947 + - -0.3078070692016113 + - - 0.10048335615799668 + - -0.08731584533464844 + - -0.19140719426127487 + - 0.09445837662615653 + - 0.27757444837207024 + - - -0.07841883765491009 + - 0.0471220563723184 + - 0.021673650695265518 + - 0.13963872028707364 + - -0.006825295820998376 + - - -0.19271438135057883 + - -0.007110790191465846 + - -0.08217641423459685 + - -0.025890309182111926 + - 0.06403918521905971 + - - -0.2968684802524053 + - -0.07467269668438187 + - -0.019663765996688794 + - -0.20157188265864767 + - -0.015590896465403993 + - - 0.15120523240819497 + - 0.040544407430732454 + - -0.14427351582541523 + - -0.08588535919441226 + - -0.09050202476744904 + - - 0.055015238092817935 + - -0.14935031270240165 + - -0.11538708778119777 + - 0.2118218514797107 + - 0.06777301739072661 + - - -0.018499557084470993 + - -0.1961687250411221 + - 0.3651860346902693 + - 0.27833350069061785 + - 0.025619643555402506 + - - 0.05902038518224982 + - 0.2800929655786798 + - -0.1055623585063209 + - -0.15935345894232575 + - 0.2517592270664209 + - - -0.1313006730876836 + - 0.21938815584964638 + - -0.18293917289674955 + - 0.005219129308061493 + - -0.10128044828072286 + - - 0.2452850640707537 + - 0.037943632583206474 + - 0.08442200151768535 + - 0.0068233590078222535 + - 0.25510063580027 + - - -0.0017976530933192004 + - -0.07441759598153459 + - 0.0814058326168008 + - 0.10158131544467495 + - 0.16953645649921872 + - - 0.12128479191043083 + - -0.042781762199358454 + - 0.007426922447889091 + - 0.04627483821729731 + - 0.07022983337840087 + - - -0.0511176676858341 + - -0.06232114476630841 + - -0.09207601218193863 + - 0.03228648493806836 + - 0.3338877697814744 + - - 0.055592097950266525 + - 0.2427801930543228 + - 0.009889478166058258 + - -0.027476734426876356 + - -0.027202887337284466 + - - 0.17481965178682457 + - 0.1064657969944493 + - 0.09565306493013129 + - 0.29187520131762024 + - 0.04937333337412164 + - - 0.0711114075126626 + - -0.03924682971036787 + - -0.05911820824729128 + - -0.016236488108268037 + - -0.24994434115230696 + - - -0.017473370139498092 + - 0.005054368701559544 + - -0.0018005121808373533 + - 0.2494665904609818 + - -0.13141342964673194 + - - -0.48951968995103823 + - -0.07991897204533437 + - -0.07265925729464501 + - -0.35037006167764917 + - -0.035530910846305705 + - - -0.024529386317393667 + - 0.125441077090495 + - 0.05072755218603061 + - -0.03358126521780828 + - -0.20649865854135296 + - - 0.27783626218193974 + - 0.3489366772451567 + - -0.09634345998938784 + - -0.061231461806113344 + - -0.12164226668578611 + - - -0.11350896510363769 + - 0.20888045493504787 + - -0.061339620252498527 + - -0.0494159943463361 + - 0.07315534698550205 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + linear2: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.39419698530848374 + - -0.6417653194947952 + - 0.1468172295786452 + - 1.184577336104801 + - 0.5081978769638908 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.0066873932717765395 + - -0.12251150304647253 + - 0.7086314961675694 + - 0.08910236706793025 + - 0.1156306065848897 + - - -0.20486974188140508 + - 0.4061737224632258 + - -0.03632714518429385 + - 0.046287906724936156 + - -0.3900507826710705 + - - -0.2351514218004438 + - -0.5837581895052254 + - -0.24421716883694933 + - 0.14058729936441064 + - -0.39625249214829994 + - - -0.3368066757750851 + - -0.11444928445140763 + - -0.216424188896182 + - -0.03014362208217597 + - -0.21739981725330465 + - - -0.4151991607396354 + - 0.14239346929129224 + - 0.6065187044001205 + - 0.17679720954491984 + - 0.18137573756648276 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + ln_eps: 1.0e-05 + ntypes: 2 + precision: default + proj_g1g2: + "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.013767481419125458 + - 0.7128447681363075 + - -0.3388321535306186 + - 0.12895013965521143 + - 0.1759839796572137 + - - -0.17995579630766076 + - -0.10490891548428505 + - 0.17227112758807675 + - 0.5043682738830814 + - -0.5177896277554952 + - - -0.028086154477676623 + - 0.6331413729833129 + - 0.041191722613779125 + - -0.18618113150914578 + - 0.20480913933465447 + - - 0.1161772622437584 + - -0.43735266009160034 + - -0.1966245142374793 + - 0.7403002581659407 + - 0.7077246669326973 + - - -0.06815761371165154 + - 0.13160740208693508 + - 0.5921753126883629 + - 0.3803766737952691 + - -0.046876897835655046 + "@version": 1 + activation_function: none + bias: false + precision: float64 + resnet: false + use_timestep: false + rcut: 4.0 + rcut_smth: 3.5 + sel: *id001 + smooth: true + trainable_ln: true + update_chnnl_2: true + update_g1_has_attn: false + update_g1_has_conv: true + update_g1_has_drrd: true + update_g1_has_grrg: true + update_g2_has_attn: false + update_g2_has_g1g1: false + update_h2: false + update_style: res_residual + use_sqrt_nnei: true + - "@class": RepformerLayer + "@variables": + g1_residual: + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.012170950513902995 + - 0.0056982009234836295 + - -0.014772204332654533 + - 0.008484024106957531 + - 0.00849914730737344 + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.03246587217116422 + - -0.016893025175096388 + - 0.012435698330610843 + - 0.020187947808578702 + - 0.006687558900631235 + - "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.011748661520302154 + - -0.002501192068144276 + - -0.014619574984816258 + - -0.00025857395887020297 + - -0.011081880220023974 + g2_residual: [] + h2_residual: [] + "@version": 2 + activation_function: tanh + attn1_hidden: 5 + attn1_nhead: 4 + attn2_has_gate: true + attn2_hidden: 5 + attn2_nhead: 4 + axis_neuron: 4 + g1_dim: 5 + g1_out_conv: true + g1_out_mlp: true + g1_self_mlp: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.22622119705468932 + - 1.679972567436481 + - 0.8617849583979595 + - 0.7627209179213263 + - 0.08736709822833234 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.12167279667097665 + - -0.05876474635765746 + - 0.01001600975306451 + - 0.5491115954893747 + - 0.1738236476960245 + - - -0.3362357644991349 + - -0.2174002731903439 + - -0.12518015854971282 + - -0.5928577746423984 + - 0.07857201204414647 + - - -0.6509218388064963 + - 0.3415018469079166 + - -0.08849907468808393 + - -0.0072378826748915985 + - 0.14619134458027205 + - - -0.14740735248742085 + - -0.0063016032593476295 + - -0.09839776527112142 + - 0.19985375347739473 + - 0.19764726057708193 + - - -0.16575148152801386 + - -0.3304732289437867 + - 0.45087541559586225 + - -0.11539370596670748 + - -0.5659729701689729 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + g2_dim: 5 + linear1: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.04476047832627152 + - -1.8078872112598667 + - 2.047538554462726 + - -0.4313795822933001 + - -0.7411876884018038 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.02120826400167494 + - -0.20276242117630586 + - 0.24387467563095325 + - 0.10303126756068341 + - 0.051025069833096175 + - - -0.03374151054389157 + - 0.032688321848361145 + - -0.13788376296513702 + - -0.10889072552319669 + - 0.22432429428309092 + - - -0.17057797559937082 + - -0.01330308815114757 + - 0.17325681446662627 + - -0.045057383108055585 + - -0.09778619431562305 + - - -0.38340838579797915 + - 0.09212560805687599 + - -0.14003455421868935 + - -0.22433560375321876 + - -0.2834092517263271 + - - -0.11250033669131652 + - 0.2936406637464358 + - -0.1634908218886667 + - 0.09011330992430655 + - -0.251553266648588 + - - 0.2002812725896845 + - 0.18806047286167354 + - -0.1445594203730198 + - 0.2596828111976336 + - -0.003130421956802045 + - - 0.21828687409081016 + - -0.13414842719178494 + - -0.24666529269789395 + - -0.006883518402413685 + - 0.05769283363709148 + - - -0.3054723048189706 + - -0.06387554521612063 + - -0.18766119271870504 + - -0.12509613691242047 + - 0.055240333151730776 + - - 0.015469398501610828 + - -0.0823150717826605 + - -0.16675512192642816 + - 0.12935809271217594 + - 0.07713906361783558 + - - -0.045899178406135374 + - 0.1396668327265014 + - 0.26340289691123836 + - 0.0048105822723732644 + - -0.005366986901183898 + - - 0.009290570011027783 + - -0.060519424426768775 + - 0.13006197150464363 + - -0.01785373737250758 + - -0.014898842675596258 + - - 0.1269696841481518 + - -0.03191557171377839 + - 0.1526428291590831 + - 0.06812463690309147 + - 0.1181645022470457 + - - -0.008354378466241394 + - 0.05983437407138071 + - 0.06201711768881509 + - -0.2712250988812433 + - -0.12835557643255655 + - - 0.3624815275568562 + - 0.02126087985755094 + - 0.11631761909889272 + - -0.055310241744757925 + - 0.16972568359433227 + - - -0.07463063674973691 + - 0.22762638743923655 + - 0.24432912187935354 + - -0.079083681709922 + - -0.014821139250702014 + - - 0.033932738843908794 + - -0.11705856512428435 + - 0.16089535272200908 + - 0.3188475709224389 + - -0.19925058656337122 + - - 0.05149532136704679 + - 0.34912853325652876 + - 0.03914985231841249 + - -0.14289800364536395 + - -0.024926754252559545 + - - 0.005878903575216287 + - 0.006351181553721704 + - 0.2151655940717292 + - -0.10973066883698986 + - 0.1965176044595831 + - - -0.04881865143285429 + - 0.05682148666175087 + - -0.004566442841927496 + - -0.10219483242326999 + - 0.05265465863908208 + - - -0.02890441167758548 + - 0.11807479061676636 + - -0.026782896957314253 + - -0.10741381535691857 + - -0.08460612842711608 + - - -0.29912522080035436 + - -0.041053273463352566 + - -0.08341514675036599 + - -0.15255641816499257 + - 0.032738166656675345 + - - -0.13127827514841212 + - 0.17337979922895738 + - -0.07999054412107975 + - 0.029187688342863987 + - -0.027991659012987048 + - - 0.0953753641153965 + - -0.1154037586787982 + - -0.1606200159638128 + - 0.03963080020458736 + - -0.28877505888277183 + - - 0.04834257331300868 + - 0.07109182789858672 + - -0.04018473714707446 + - 0.3392501402646008 + - 0.051706351939985955 + - - 0.2262052651345229 + - 0.06567796566839683 + - 0.02969873382446752 + - 0.307137329425125 + - 0.17947875928066645 + - - 0.08614098800132988 + - -0.29531982917953764 + - -0.16051475876915816 + - 0.19569343061448652 + - 0.03787520855977715 + - - -0.21429349518787014 + - 0.05394593886889696 + - -0.04993584771039846 + - 0.19564022713367052 + - 0.11362262179190846 + - - -0.004245970327980209 + - -0.1940574165136546 + - 0.007951406114803104 + - -0.08859796346311412 + - -0.032649124565685736 + - - -0.21731129328505974 + - -0.06621352576033905 + - 0.09129544921890727 + - -0.06541669039814363 + - -0.10724204962303797 + - - 0.057357048566617864 + - -0.008345673540796282 + - 0.29292258751655026 + - -0.12308726040552805 + - 0.04054304095472656 + - - 0.10394124106268385 + - 0.17444441607800482 + - 0.12456110243374384 + - 0.050099919442543044 + - -0.11282643776218643 + - - 0.17835449092593977 + - 0.09637084989472562 + - -0.1934254167912494 + - -0.2545118504213924 + - 0.06827181091761918 + - - 0.19626184614240083 + - -0.04699868343310416 + - -0.10007243420160061 + - -0.09383391698441179 + - -0.0990655246656314 + - - 0.012335198304450592 + - -0.2180497687723705 + - 0.1541401217483529 + - -0.11638836449441982 + - 0.2683999798967227 + - - 0.10896246206993876 + - -0.3225899091557129 + - 0.08447514888719702 + - -0.16532213737983084 + - -0.14340574861753197 + - - 0.27134405821317176 + - 0.37994741042339447 + - -0.061229587223300345 + - -0.10718219560731332 + - -0.0342547953758568 + - - 0.01995848716225226 + - -0.2369496772630444 + - -0.12037037553687044 + - -0.11437070301591287 + - 0.244046581192314 + - - 0.16709557426845523 + - 0.029705706114580222 + - 0.02950591956405181 + - 0.042079856333530945 + - 0.1289641022252176 + - - 0.1269877310037645 + - -0.2914364205391121 + - -0.2493182842173884 + - -0.3155095607393483 + - 0.057889936497034826 + - - -0.026087582504925794 + - 0.09059376415886668 + - 0.05302656666128242 + - 0.19074227440516286 + - -0.16835197038194713 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + ln_eps: 1.0e-05 + ntypes: 2 + precision: default + proj_g1g2: + "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.44534011123758094 + - 0.3517185018820327 + - 0.09140581073576556 + - 0.10509259688629408 + - -0.693942488230306 + - - -0.24023695288558378 + - -0.10202481638123778 + - 0.19072504173684576 + - 0.2987115067295286 + - 0.3039842054402882 + - - -0.14614782464590242 + - -0.2746073100636409 + - -0.6649874649402261 + - 0.30314217114609954 + - 0.8548523617893014 + - - -0.0037272725651430706 + - 0.42826436495159964 + - 0.14658274447536565 + - 0.3890848512382279 + - -0.7377284335342408 + - - 0.3773681827887734 + - -0.19069778577208663 + - -0.40778172197678264 + - 0.21385929050684915 + - -0.23807817854870678 + "@version": 1 + activation_function: none + bias: false + precision: float64 + resnet: false + use_timestep: false + rcut: 4.0 + rcut_smth: 3.5 + sel: *id001 + smooth: true + trainable_ln: true + update_chnnl_2: false + update_g1_has_attn: false + update_g1_has_conv: true + update_g1_has_drrd: true + update_g1_has_grrg: true + update_g2_has_attn: false + update_g2_has_g1g1: false + update_h2: false + update_style: res_residual + use_sqrt_nnei: true + repinit_args: + activation_function: tanh + axis_neuron: 5 + neuron: + - 5 + - 5 + - 5 + nsel: 108 + rcut: 6.0 + rcut_smth: 0.5 + resnet_dt: false + set_davg_zero: true + tebd_dim: 8 + tebd_input_mode: concat + three_body_neuron: + - 2 + - 4 + - 8 + three_body_rcut: 4.0 + three_body_rcut_smth: 3.5 + three_body_sel: 40 + type_one_side: false + use_three_body: true + repinit_three_body_variable: + "@variables": + davg: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + dstd: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - 0.2484323774795112 + - 0.18856311281779198 + - 0.18856311281779198 + - 0.18856311281779198 + - - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + - - 0.22487048413495345 + - 0.17407539794806373 + - 0.17407539794806373 + - 0.17407539794806373 + embeddings: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: embedding_network + networks: + - "@class": EmbeddingNetwork + "@version": 2 + activation_function: tanh + bias: true + in_dim: 17 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.43941908550486863 + - 0.9378261950772157 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.026265096160513004 + - -0.4639835283413547 + - - 0.03635647454640597 + - 0.07473758444527034 + - - -0.06992753819659699 + - 0.21986075742162492 + - - 0.07061419602422543 + - -0.0524592050658983 + - - -0.10868945785879074 + - -0.6393923325990258 + - - -0.4771175911983638 + - 0.2185816237968037 + - - -0.21775778011277241 + - -0.03846588547990851 + - - 0.2688575804342918 + - -0.022262286686160183 + - - 0.11139791007840963 + - -0.11383753559255314 + - - -0.5714738566517861 + - -0.11339892392396721 + - - 0.234274561882303 + - -0.2815645637897343 + - - -0.12797356315352978 + - -0.042039107070487636 + - - -0.0663246827412385 + - 0.1748397461193538 + - - -0.036346593063428266 + - -0.5376841980420106 + - - 0.002151021953151517 + - 0.15460640930804267 + - - 0.006481066742778678 + - -0.10520778210184074 + - - 0.24902329498264406 + - 0.1643944050570271 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.46486827708713296 + - 0.11781816864962305 + - 2.043777904940986 + - -0.06518765760721654 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.8388031982509411 + - 0.19742570339479432 + - -0.4757894936761817 + - 0.25494057061469866 + - - -0.26501627450969395 + - 0.29921478857516215 + - -0.4503188694359583 + - -0.555699105983989 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.36211325433682123 + - -1.474782667517953 + - -1.577999656211323 + - -0.665208648579387 + - -0.2968754579079675 + - 0.33237673982661636 + - -0.23852601602175583 + - -0.48970534633738644 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.15347549712808095 + - 0.06557762027792953 + - 0.0674498985020097 + - -0.28931880581807395 + - 0.08433956624127203 + - 0.40007512162053827 + - 0.047171932510215106 + - -0.43192983392825585 + - - -0.29002127476738354 + - 0.3010519137520722 + - 0.6236997835707944 + - -0.08736593092255791 + - -0.13973966203097968 + - 0.17215384978373047 + - 0.08768909794026612 + - 0.005232646216097761 + - - -0.14487959711385273 + - 0.13440625093514774 + - 0.11219192922053893 + - -0.3172037430673178 + - -0.028935720530630663 + - -0.1404831879240233 + - 0.03421585634032436 + - 0.031902064691486466 + - - -0.4670784225758848 + - 0.23080792888721513 + - -0.034071548427154236 + - 0.3635026497032477 + - 0.01725141617134657 + - 0.08667184666806436 + - -0.027586440920311633 + - 0.3547335062572569 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + neuron: + - 2 + - 4 + - 8 + precision: float64 + resnet_dt: false + ntypes: 2 + env_mat: + rcut: 4.0 + rcut_smth: 3.5 + repinit_variable: + "@variables": + davg: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + dstd: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - 0.15519872219179812 + - 0.09699061260807051 + - 0.09699061260807051 + - 0.09699061260807051 + - - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + - - 0.13706276479932664 + - 0.08672582449204416 + - 0.08672582449204416 + - 0.08672582449204416 + embeddings: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: embedding_network + networks: + - "@class": EmbeddingNetwork + "@version": 2 + activation_function: tanh + bias: true + in_dim: 17 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9269543229421101 + - -0.23703601479020053 + - -0.4823021177542036 + - -1.4228502883640606 + - -2.011213316161816 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.19283446648725555 + - -0.36591818629805983 + - -0.15211699281081714 + - -0.039633317878387885 + - 0.08412479500393036 + - - 0.26348734235446575 + - -0.11555307251357866 + - 0.006660027683729128 + - -0.019346367290438243 + - 0.021346630671683325 + - - 0.13448295451127248 + - -0.018369442651599413 + - 0.21822347435795875 + - -0.5862539690338648 + - -0.03882603721619449 + - - 0.4562437634385346 + - 0.0922020249035163 + - -0.08537728490580257 + - -0.021504575261091438 + - 0.7537494236523268 + - - 0.058577520908596456 + - -0.297231972380086 + - 0.31401326943447117 + - 0.46604390397410916 + - -0.1079660066450081 + - - -0.016397124963663072 + - 0.193877750145573 + - 0.27323597326251514 + - 0.19024717476711756 + - 0.4902622942180982 + - - -0.07398866271485056 + - -0.32387509250925417 + - -0.20144274644415622 + - 0.3184551957437682 + - 0.05231949884880957 + - - -0.06009353811716006 + - 0.1737334138251616 + - -0.030962167721662194 + - -0.3036329366266574 + - -0.03161563299627668 + - - -0.22386979617490793 + - -0.11241202512640262 + - 0.0956477518147091 + - 0.149473360709335 + - -0.5834842034552281 + - - -0.29621075081021037 + - -0.2830710916386785 + - -0.40204993504967584 + - -0.43071084216223127 + - 0.14104372569244336 + - - 0.01845290217537068 + - -0.07375824446050101 + - 0.12504281011315418 + - -0.15207732006538546 + - 0.2551513482964309 + - - -0.024460540016230218 + - -0.2997661998703328 + - 0.19544625222543976 + - 0.04807700173152966 + - -0.22437119604868247 + - - 0.13184243108870328 + - -0.00655765860570148 + - 0.38000312984475265 + - -0.27888427289533213 + - 0.0523989672895371 + - - -0.36462365192301033 + - 0.2424520289960548 + - -0.14147066767231933 + - -0.2628867980693902 + - 0.4298456873841488 + - - -0.17116788917189402 + - -0.18512494024858211 + - -0.18136447808496087 + - -0.3087350626017898 + - -0.45704153153985133 + - - -0.21642261388083456 + - 0.1735543598039101 + - 0.12961593877456143 + - 0.30453456449045946 + - 0.013523974010561922 + - - -0.2135497141182742 + - 0.1500783241424191 + - 0.11827264405808366 + - 0.02356320702723967 + - -0.12559343259716768 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.10135465784739055 + - 0.5563348164690164 + - -0.35665415789438026 + - -2.529333391603986 + - -0.6897370433713996 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.02728733154369319 + - 0.34130728631608 + - -0.03568730885475739 + - 0.2806831273127007 + - -0.2534529086482955 + - - 0.0938752210762737 + - 0.20716028680714232 + - -0.18709754258589167 + - -0.003746032146125652 + - -0.07950919542774962 + - - 0.38983379896498505 + - -0.3109696383589871 + - 0.16903169446000998 + - -0.43353022307834055 + - -0.07037761125406733 + - - -0.1771462337798467 + - 0.3699660287991744 + - 0.49624472738258896 + - 0.13924527117477936 + - 0.6182673667623552 + - - -0.3420404033654241 + - -0.10075864284818886 + - -0.0020930780896084565 + - -0.2535436493462217 + - -0.5878699357609878 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9852698210289226 + - -0.7516427000555517 + - 1.954718712962926 + - -1.3330436009980762 + - 1.1124496499921253 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.610428145819889 + - 0.16457695424561933 + - -0.10628976953976568 + - -0.10270959121450651 + - 0.710010727298032 + - - -0.21316268003180913 + - 0.1601285991489345 + - 0.017821166720948873 + - -0.1658766563788847 + - -0.5157841154491742 + - - 0.1301734160138573 + - 0.44333852482803526 + - -0.4706128872051425 + - -0.026771684706215094 + - -0.5432790014276224 + - - 0.13432831349954696 + - 0.05915400213144824 + - 0.30068669323555086 + - -0.2597730228141981 + - -0.04895968969299514 + - - -0.22048431980779148 + - -0.4194490160378933 + - 0.007756177005246415 + - -0.3298891389130536 + - -0.5385584464634693 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + neuron: + - 5 + - 5 + - 5 + precision: float64 + resnet_dt: false + ntypes: 2 + env_mat: + rcut: 6.0 + rcut_smth: 0.5 + smooth: true + trainable: true + type: dpa2 + type_embedding: + "@class": TypeEmbedNet + "@version": 2 + activation_function: Linear + embedding: + "@class": EmbeddingNetwork + "@version": 2 + activation_function: Linear + bias: false + in_dim: 2 + layers: + - "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.5016826093036629 + - -0.5279194903676838 + - -0.0549678404477373 + - -0.038069148958863784 + - -0.36174151294282214 + - -0.2514151046175414 + - 0.08245655199851126 + - 0.026302755666446104 + - - -0.07831498854680775 + - -0.008641242501326772 + - 0.22177790678578946 + - -0.14363349563193728 + - -0.3164380301609334 + - -0.14466111798293158 + - 0.06012655136452968 + - 0.07616883527453565 + "@version": 1 + activation_function: Linear + bias: false + precision: float64 + resnet: true + use_timestep: false + neuron: + - 8 + precision: float64 + resnet_dt: false + neuron: + - 8 + ntypes: 2 + padding: true + precision: default + resnet_dt: false + trainable: true + type_map: &id002 + - O + - H + use_econf_tebd: false + use_tebd_bias: false + type_map: *id002 + use_econf_tebd: false + use_tebd_bias: false + fitting: + "@class": Fitting + "@variables": + aparam_avg: null + aparam_inv_std: null + bias_atom_e: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.0 + - - 0.0 + fparam_avg: null + fparam_inv_std: null + "@version": 2 + activation_function: tanh + atom_ener: [] + dim_descrpt: 13 + dim_out: 1 + exclude_types: [] + layer_name: null + mixed_types: true + nets: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: fitting_network + networks: + - "@class": FittingNetwork + "@version": 1 + activation_function: tanh + bias_out: true + in_dim: 13 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.3349859616903922 + - -0.45748055191068804 + - 1.0314434101351464 + - -0.6562475549893949 + - 0.9498122103417428 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.28734085240489987 + - -0.22799376929792833 + - 0.09371468724415302 + - -0.060343990641131336 + - -0.1605587616678341 + - - -0.500828607464722 + - -0.39769032445849883 + - -0.13097956796562052 + - -0.23212793717704652 + - -0.038828884528037115 + - - 0.23321046838546103 + - 0.017315597110918363 + - 0.13631624095760458 + - 0.23691624195155855 + - 0.005574145933647628 + - - -0.20133887020224747 + - -0.328829379978232 + - 0.34759416395637244 + - 0.1944497447741702 + - 0.3193324891906819 + - - 0.019880427431400238 + - 0.006034434080859869 + - 0.28792741597790694 + - -0.05048129942486406 + - 0.2568884031435667 + - - -0.07159867806442435 + - 0.2954506374545977 + - 0.29676436126009775 + - 0.11528861832629143 + - 0.08716069714021854 + - - 0.02839550934091151 + - -0.5658931776162197 + - -0.12736808072489597 + - -0.2978927418496175 + - 0.08156233385616206 + - - -0.21714384551138263 + - -0.17752510849514816 + - -0.1873420473309751 + - -0.24301014238016175 + - 0.15351478595992818 + - - 0.3901008433965192 + - -0.23477779245126928 + - 0.07766405087449278 + - -0.28986099335396964 + - -0.055080083740405905 + - - -0.31078867820981393 + - -0.22574724498621582 + - 0.4107563625227253 + - -0.2557694947708352 + - 0.36134010096258296 + - - 0.2572164450054561 + - -0.20523865915689532 + - -0.10738066546619592 + - -0.306753245712776 + - -0.3876447597823812 + - - 0.04476080392863874 + - 0.166299366041299 + - -0.13581197624774488 + - -0.3027853280458097 + - -0.2496523280475521 + - - -0.010164582247014432 + - 0.5254917836303269 + - -0.11728939959987111 + - 0.031401166283109405 + - 0.202952049786492 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.7612288939582822 + - 0.22011088454870795 + - -0.6930292890564576 + - 0.24515979935539478 + - -0.9119932094656207 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.12083338973169744 + - 0.12359911898435172 + - 0.11847294748624924 + - 0.1204604519504766 + - 0.07610319945201176 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.12870200352732564 + - -0.10389868985844915 + - -0.37501879915699127 + - -0.18541690943539932 + - -0.04702771493728981 + - - -0.28377397781017044 + - 0.04631072135573396 + - 0.3373706928254973 + - 0.053062537153292455 + - -0.2800937779745926 + - - -0.15595871887877247 + - 0.5560575834005833 + - -0.027504448094158956 + - 0.1068959101261423 + - -0.04916392093285025 + - - 0.7159719256460106 + - 0.4405997085081062 + - -0.7393022855683045 + - 0.32721085056544846 + - 0.4379543738536064 + - - -0.16927349569863864 + - -0.06948480276565093 + - -0.059155283940860406 + - 0.22445565063495596 + - 0.24290367853391648 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.268423528320119 + - -0.13946633307874465 + - -0.7993169599497402 + - 1.8957665864366233 + - 0.830006098693181 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.08137930994046798 + - 0.09111025159099423 + - 0.12083475394061953 + - 0.12136745805960235 + - 0.12132465520083764 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.1198813779466343 + - 0.07320296304118168 + - 0.23796678982452435 + - 0.0667895312516395 + - 0.10427067882701904 + - - 0.08373108216057079 + - -0.1780521559772933 + - 0.05778911769623431 + - -0.3942187090319752 + - 0.19806462779250142 + - - 0.2171199847078245 + - -0.11237105080550389 + - 0.46886249866228147 + - 0.21442756872773663 + - 0.014044661093248303 + - - -0.49713589716071865 + - 0.27064695337075706 + - -0.2772472406473757 + - -0.42266647999572204 + - -0.19290646520604507 + - - 0.13360533303184338 + - 0.46104850502457395 + - 0.02001739740431832 + - -0.08158918138208685 + - 0.3628296701734589 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9728956257191301 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.23602425791819237 + - - 0.1825842903969548 + - - -0.1141469354572199 + - - 0.11064695516945687 + - - 0.6363405008198829 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + neuron: + - 5 + - 5 + - 5 + out_dim: 1 + precision: default + resnet_dt: true + ntypes: 2 + neuron: + - 5 + - 5 + - 5 + ntypes: 2 + numb_aparam: 0 + numb_fparam: 0 + precision: default + rcond: null + resnet_dt: true + spin: null + tot_ener_zero: false + trainable: + - true + - true + - true + - true + type: ener + type_map: + - O + - H + use_aparam_as_mask: false + var_name: energy + pair_exclude_types: [] + preset_out_bias: null + rcond: null + type: standard + type_map: + - O + - H +model_def_script: + atom_exclude_types: [] + data_bias_nsample: 10 + data_stat_nbatch: 10 + data_stat_protect: 0.01 + descriptor: + add_tebd_to_repinit_out: false + concat_output_tebd: true + env_protection: 0.0 + exclude_types: [] + precision: default + repformer: + activation_function: tanh + attn1_hidden: 5 + attn1_nhead: 4 + attn2_has_gate: true + attn2_hidden: 5 + attn2_nhead: 4 + axis_neuron: 4 + direct_dist: false + g1_dim: 5 + g1_out_conv: true + g1_out_mlp: true + g2_dim: 5 + ln_eps: null + nlayers: 3 + nsel: 40 + rcut: 4.0 + rcut_smth: 3.5 + set_davg_zero: true + trainable_ln: true + update_g1_has_attn: false + update_g1_has_conv: true + update_g1_has_drrd: true + update_g1_has_grrg: true + update_g2_has_attn: false + update_g2_has_g1g1: false + update_h2: false + update_residual: 0.01 + update_residual_init: norm + update_style: res_residual + use_sqrt_nnei: true + repinit: + activation_function: tanh + axis_neuron: 5 + neuron: + - 5 + - 5 + - 5 + nsel: 108 + rcut: 6.0 + rcut_smth: 0.5 + resnet_dt: false + set_davg_zero: true + tebd_dim: 8 + tebd_input_mode: concat + three_body_neuron: + - 2 + - 4 + - 8 + three_body_rcut: 4.0 + three_body_rcut_smth: 3.5 + three_body_sel: 40 + type_one_side: false + use_three_body: true + smooth: true + trainable: true + type: dpa2 + use_econf_tebd: false + use_tebd_bias: false + fitting_net: + activation_function: tanh + atom_ener: [] + neuron: + - 5 + - 5 + - 5 + numb_aparam: 0 + numb_fparam: 0 + precision: default + rcond: null + resnet_dt: true + seed: 1 + trainable: true + type: ener + use_aparam_as_mask: false + pair_exclude_types: [] + preset_out_bias: null + srtab_add_bias: true + type: standard + type_map: + - O + - H +pt_version: 2.5.0+cu124 +software: deepmd-kit +time: "2024-11-12 20:55:56.133493+00:00" +version: 3.0.0b5.dev48+gd46d5f0b7.d20241025 diff --git a/source/tests/infer/deeppot_sea.yaml b/source/tests/infer/deeppot_sea.yaml new file mode 100644 index 0000000000..f51e0cdc0d --- /dev/null +++ b/source/tests/infer/deeppot_sea.yaml @@ -0,0 +1,3975 @@ +"@variables": {} +backend: PyTorch +model: + "@class": Model + "@variables": + out_bias: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - - 0.0 + out_std: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - - 1.0 + "@version": 2 + atom_exclude_types: [] + descriptor: + "@class": Descriptor + "@variables": + davg: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - 0.05033023990415701 + - 0.0 + - 0.0 + - 0.0 + - - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + - - 0.04810435854572292 + - 0.0 + - 0.0 + - 0.0 + dstd: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - 0.13983584883542127 + - 0.08580324224912261 + - 0.08580324224912261 + - 0.08580324224912261 + - - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + - - 0.12387506031307452 + - 0.0767222430740185 + - 0.0767222430740185 + - 0.0767222430740185 + "@version": 2 + activation_function: tanh + axis_neuron: 2 + embeddings: + "@class": NetworkCollection + "@version": 1 + ndim: 1 + network_type: embedding_network + networks: + - "@class": EmbeddingNetwork + "@version": 2 + activation_function: tanh + bias: true + in_dim: 1 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9135780953141011 + - -0.21988704992411853 + - -0.4665052129042189 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.4480607783689787 + - -0.8184588710236892 + - -0.3212229262177685 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.08271688634770508 + - 0.5710172961098444 + - -0.3386457716744601 + - -2.5417164583577074 + - -0.7193287459119063 + - 0.7652256530675483 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.019875525969363945 + - 0.35820662383392005 + - -0.03324304377423504 + - 0.27688221097732024 + - -0.24989229837693086 + - 0.0783270796917456 + - - 0.20945742957428046 + - -0.211872271466121 + - 0.0008358507886423699 + - -0.05453596666424369 + - 0.39210040951537417 + - -0.3419570224964847 + - - 0.1653069701250396 + - -0.44312835436502285 + - -0.05373650575797951 + - -0.19747867430009755 + - 0.38126230712956843 + - 0.5052052142698005 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9648211602390828 + - -0.7329251669015443 + - 1.9627035817615672 + - -1.3461633369429176 + - 1.0936693885979498 + - 1.3607024848954599 + - 0.8334459935068284 + - -1.8810072976327308 + - -1.1989114806542243 + - 1.4281403830794184 + - -1.3843058263856687 + - -0.052672020574027804 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.47249502253413767 + - 0.12647502337259248 + - -0.06699924946784985 + - -0.09423315846872603 + - 0.5333188515820263 + - -0.1727275917320159 + - 0.09597175344674748 + - -0.005932492557303448 + - -0.11177655892814739 + - -0.3658612150653396 + - 0.0900624080089449 + - 0.3087569242518805 + - - -0.3531998974784297 + - -0.016843749414359895 + - -0.3978187976699206 + - 0.09613768242432623 + - 0.038704464021698956 + - 0.2121530086788622 + - -0.17793330160589513 + - -0.01613256387026088 + - -0.18276498936872523 + - -0.3375118318769777 + - -0.015101439735200641 + - -0.2428858281631755 + - - -0.3856297271047689 + - 0.34124658239474304 + - -0.019121537646152507 + - -0.2957945097641662 + - 0.3338945897757846 + - -0.11131338121737373 + - 0.4990860582585273 + - 0.18017764178170176 + - 0.12372129206846949 + - -0.0053360180956128446 + - -0.1427637886528732 + - 0.17922837261556487 + - - 0.01826665666023047 + - -0.31470232462679454 + - -0.10194269104342851 + - 0.140105295020757 + - 0.20718071908235003 + - -0.18636425764042444 + - 0.22150662118375927 + - -0.44467695847359695 + - 0.19376190244147914 + - 0.3744529008877187 + - 0.187347523462222 + - 0.24725889785262806 + - - -0.05569133835982222 + - 0.4056690684281682 + - -0.4050124133142187 + - 0.06911914042311589 + - -0.02507250287261658 + - 0.12447300353713507 + - -0.08866881586874285 + - 0.09787093395803924 + - -0.5732585744778257 + - 0.3127186256273779 + - -0.04216391443940198 + - 0.14144785545496388 + - - 0.06883848922135298 + - -0.1953127184316002 + - 0.016320845694316112 + - -0.17358161953878057 + - -0.27455717485484615 + - 0.237686322099333 + - 0.003476977953736387 + - -0.5070084729698298 + - -0.09432938135450629 + - -0.01422933742930722 + - 0.06222538657948485 + - 0.2710869470606886 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + neuron: + - 3 + - 6 + - 12 + precision: float64 + resnet_dt: false + - "@class": EmbeddingNetwork + "@version": 2 + activation_function: tanh + bias: true + in_dim: 1 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.42401411440600284 + - 0.9335508956852949 + - 1.2815073047858874 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.02969569009440913 + - -1.0118009942211346 + - 0.06896408835119755 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.45914204107700757 + - 0.11575034881349684 + - 2.042472615214458 + - -0.04460101257497492 + - -0.09496399374159296 + - -0.02778538768440204 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.7035503395434383 + - 0.14786570627009207 + - -0.38542169764931816 + - 0.20612207919736217 + - -0.2061862442776192 + - 0.24707340883641274 + - - -0.35356238326930606 + - -0.4611521817765432 + - 0.058136587416921855 + - 0.4183052227368231 + - -0.16602218293244755 + - -0.0786167186136645 + - - 0.19848155121056013 + - 0.3739182045960408 + - -0.25545938052117345 + - -0.18953224983729802 + - -0.41858486773276415 + - -0.3770657866346729 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.3691793459040411 + - -1.4655414502333097 + - -1.6050232638930426 + - -0.6681646046581292 + - -0.3140923743196441 + - 0.3343664152287671 + - -0.24004830128184454 + - -0.4732990868654677 + - -0.32975769383273484 + - 0.6314644563710302 + - 2.1358167277049422 + - 0.6117069749191495 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.14563116173349283 + - 0.04829045319671905 + - 0.07994197920232828 + - -0.23860016370960932 + - 0.04963336794016981 + - 0.3068430990490502 + - 0.01936056893300573 + - -0.3550040917950255 + - -0.21780471274370697 + - 0.264292708909017 + - 0.5059605396244347 + - -0.07740470349959563 + - - -0.09592344803557039 + - 0.14557821903298793 + - 0.07125651278685914 + - 0.008116574101943882 + - -0.11727715174998957 + - 0.11448630070265027 + - 0.09347138153544679 + - -0.24422598023495334 + - -0.04354907253714727 + - -0.13208668521533803 + - 0.008239117976368397 + - 0.034859296288528165 + - - -0.3631037056379459 + - 0.1978243087998284 + - -0.049890357214265546 + - 0.2963972390728509 + - 0.013778872768427486 + - 0.08639903620572627 + - -0.008822281966412574 + - 0.2933178206949785 + - -0.08230554347102534 + - -0.05800263712621382 + - 0.08577681436425053 + - -0.166578601938574 + - - 0.37022110375622513 + - 0.28402854017517876 + - 0.3418359729210835 + - -0.20230643070146162 + - -0.2111416939858313 + - -0.15959751348169024 + - 0.04722110963098123 + - 0.15879411283730685 + - 0.1556805204442205 + - 0.19973471060846398 + - -0.1302345558079619 + - 0.35036032507149595 + - - 0.22368587041266697 + - -0.11532135231064344 + - -0.03134221053922798 + - -0.10236704897227157 + - -0.19255680351960777 + - 0.0925144003206949 + - -0.1027075690403189 + - 0.31624210244063766 + - -0.043468953265288414 + - 0.4041294754912617 + - -0.2828764908196377 + - -0.31702892676419375 + - - -0.2574112780622649 + - -0.19937051087385846 + - 0.12304523584653093 + - -0.45599474001808815 + - -0.13757553285321772 + - 0.15271930575698398 + - 0.1160045218588379 + - -0.007113074597327396 + - -0.1376347835486681 + - -0.5184472929334725 + - -0.2921898733184288 + - -0.6940203486091138 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + neuron: + - 3 + - 6 + - 12 + precision: float64 + resnet_dt: false + ntypes: 2 + env_mat: + rcut: 6.0 + rcut_smth: 0.5 + env_protection: 0.0 + exclude_types: [] + neuron: + - 3 + - 6 + - 12 + precision: float64 + rcut: 6.0 + rcut_smth: 0.5 + resnet_dt: false + sel: + - 46 + - 92 + set_davg_zero: false + spin: null + trainable: true + type: se_e2_a + type_map: + - O + - H + type_one_side: true + fitting: + "@class": Fitting + "@variables": + aparam_avg: null + aparam_inv_std: null + bias_atom_e: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -93.57372029622398 + - - -187.14744059244796 + fparam_avg: null + fparam_inv_std: null + "@version": 2 + activation_function: tanh + atom_ener: null + dim_descrpt: 24 + dim_out: 1 + exclude_types: [] + layer_name: null + mixed_types: false + nets: + "@class": NetworkCollection + "@version": 1 + ndim: 1 + network_type: fitting_network + networks: + - "@class": FittingNetwork + "@version": 1 + activation_function: tanh + bias_out: true + in_dim: 24 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9981537200365072 + - -1.315191620198832 + - 0.6369789163922976 + - 1.337158860758462 + - -0.9472125858788953 + - -0.6323336917857305 + - -0.00672112112193493 + - 0.550266797507348 + - -0.7859665327566713 + - 0.9559627789152517 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.0042052253448055035 + - 0.0012790546447848106 + - 0.024459922240004015 + - -0.08938464488474171 + - 0.09579565174216789 + - 0.06359047895799255 + - -0.23858555445678956 + - 0.24574716961976314 + - -0.30148053467596875 + - 0.018318522349898603 + - - 0.13428259692815644 + - 0.02493265251926957 + - 0.2121507005915962 + - 0.26295247288893503 + - 0.1431369868225353 + - 0.24016486486115748 + - 0.01699816451436643 + - 0.09133735203040197 + - -0.29745245083276833 + - 0.056392950181727595 + - - -0.2802586118952235 + - -0.1316082125599456 + - 0.2632040292764991 + - 0.11419401800523367 + - 0.2607000737789437 + - 0.11963076327749202 + - -0.20127652325886658 + - 0.3165101919991174 + - -0.038561905875380734 + - -0.0042471663599128115 + - - 0.1635954491570036 + - 0.03590994868003968 + - -0.21038601686901315 + - -0.14028713956327443 + - 0.07567915811764823 + - -0.03368584094543363 + - -0.21812728000462345 + - 0.23355325490724416 + - -0.056981076241320164 + - 0.08885470630211721 + - - -0.0039599483516003075 + - -0.09200095301889753 + - -0.12235481385358529 + - -0.17658616519785011 + - -0.05859356694366821 + - 0.06997022727882568 + - -0.18718983767065284 + - 0.2546703433218559 + - -0.013893634233231148 + - 0.14028529992874594 + - - -0.09860275843967325 + - -0.08940099797057129 + - 0.08290195726553863 + - 0.2651291904246647 + - -0.0810781079200689 + - -0.26944309498149577 + - -0.06398119285959505 + - -0.0824025975377168 + - 0.1517972454111715 + - 0.15972842705634308 + - - -0.0689818055157703 + - -0.2472544578106405 + - 0.20391933074516952 + - -0.3384475880777855 + - 0.024170530156081366 + - 0.12082177662274075 + - -0.03015793325701928 + - 0.26097457806906477 + - 0.13234544720128233 + - 0.10969962380216029 + - - -0.1410236369792598 + - 0.2565220565743628 + - -0.09813900782449475 + - -0.16475308162821475 + - -0.4382269634184545 + - -0.09618619130856115 + - -0.06151773086475732 + - -0.3679955124060774 + - 0.09489492151931028 + - 0.1699636112301434 + - - -0.06466507122698881 + - -0.03399661756182682 + - -0.08465867612319165 + - -0.02820143776252954 + - -0.04364607418496975 + - 0.10649423660304615 + - 0.1117641236592904 + - 0.242076918802004 + - 0.23336414341736533 + - -0.10973931654443178 + - - -0.11803473383837482 + - 0.26023291747567073 + - -0.036576563288649645 + - 0.031801452475671496 + - 0.027499520410698055 + - 0.21832021718409408 + - 0.002563705598485124 + - -0.023939863087569735 + - -0.15508788279340086 + - -0.009958330239924398 + - - 0.2029037816897343 + - 0.1236798701169301 + - -0.09018519532479251 + - 0.10235830736281337 + - 0.22210742529901917 + - -0.11510008731097023 + - -0.2907045465491179 + - 0.19173281505129286 + - -0.1826620176135373 + - -0.3858276132763665 + - - 0.25873945816211946 + - 0.20526926855225042 + - -0.1946447069870335 + - 0.0920127246001931 + - 0.18416151105208917 + - 0.13335997520924525 + - 0.11244062601373198 + - 0.006020269240571531 + - 0.22885110139445944 + - -0.24802252929895452 + - - 0.18188964884594755 + - 0.09757316104582281 + - 0.21451544724583474 + - -0.03795113704629376 + - 0.23089932965728366 + - -0.1653256730851483 + - 0.05472418048709738 + - -0.03607142063053913 + - -0.12208110627218638 + - -0.05489922891341493 + - - 0.003967605563087107 + - 0.09393663928648069 + - -0.07443910754909523 + - 0.07398359030793977 + - -0.09616706889032565 + - 0.04408314592707753 + - -0.027688682986298615 + - -0.3893989079960798 + - -0.019583475709650577 + - 0.15263602009279958 + - - 0.23242566280078358 + - -0.20751154943785288 + - -0.1057531593975858 + - 0.0097666599925194 + - 0.01723815124684891 + - -0.14321076447143907 + - -0.05668897839878534 + - 0.4076712003422085 + - -0.12968800173163939 + - -0.11833504008442321 + - - -0.0382870843112316 + - 0.1923434843810583 + - 0.18933356350887878 + - 0.1935721875785095 + - 0.1040048511477642 + - 0.2018770661645953 + - -0.007807146975480021 + - -0.0007246459780983533 + - -0.15092810563793607 + - -0.1266003232191877 + - - 0.14101405179181745 + - 0.07251429705927144 + - -0.009029755426231204 + - -0.030652737781312292 + - 0.028970459200312354 + - -0.06279658944669542 + - -0.16561866479726714 + - 0.10135625033764951 + - 0.20076140517783675 + - 0.18808803566465912 + - - -0.062259847002018434 + - 0.16140409497962688 + - -0.047263255762838075 + - 0.02161784552788213 + - -0.34189275457722157 + - 0.23624471783983889 + - -0.07604339720459162 + - 0.11972193795451776 + - 0.008718648262089578 + - -0.14597435103721404 + - - -0.14629679080173855 + - -0.010059437004151458 + - 0.12313644848963382 + - 0.07363202275050783 + - 0.3158951551063782 + - 0.30872396959636816 + - 0.11068402170384259 + - 0.13798118511411697 + - -0.17580561411126133 + - 0.2309366660573632 + - - 0.24607121569674878 + - 0.04633671294658672 + - -0.07794643734438188 + - 0.2637145614710995 + - 0.1673610762049212 + - -0.050889657349376165 + - 0.08082949413787617 + - 0.24349734338054974 + - -0.22618034164526177 + - -0.2699847668276845 + - - -0.0930135234346539 + - 0.11048617654242995 + - -0.10055237297539021 + - 0.2114601268899422 + - -0.08746970446094664 + - -0.16504052170679984 + - 0.11382627278043715 + - 0.3029957109835582 + - -0.011668717315770452 + - -0.191775672823171 + - - -0.23450812115563055 + - -0.009067565673813316 + - 0.07750402173739522 + - 0.07176706287311344 + - 0.11641190944846258 + - -0.1357712681870809 + - -0.18494204455539703 + - 0.11317724042565168 + - 0.20294076058173613 + - 0.2652756944904472 + - - -0.05606948494123854 + - 0.07125608415284536 + - 0.21786208482337216 + - 0.06403079423538505 + - -0.023569537627091536 + - -0.08670783596141193 + - -0.06814400771010572 + - 0.16076713468366743 + - 0.05133570956677264 + - 0.04802388700136072 + - - 0.07703357030277777 + - 0.049584559880658324 + - 0.31461230774136195 + - -0.13778140662267765 + - 0.2840242396477052 + - 0.03509103879961169 + - 0.3406341956660782 + - -0.09133539705008382 + - -0.07324504845041965 + - 0.09846652096261795 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.6163330784017647 + - -1.770660783796408 + - 0.5161258405142825 + - -0.1612283658165593 + - -0.7523557826606485 + - -2.224214510392427 + - -0.33342190547093487 + - -1.126722365560626 + - -0.4720938937079637 + - 1.3963705617158102 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.11119561540931056 + - 0.08970452515484859 + - 0.08915193432675515 + - 0.09048898647907772 + - 0.11061185348485504 + - 0.08865526034423891 + - 0.10991823920411607 + - 0.08882526076676668 + - 0.11002293261542401 + - 0.10972840025346457 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.3296701598540796 + - 0.2296310333592966 + - 0.12856996740720156 + - -0.2899652443366801 + - -0.30948078796262024 + - -0.03584087926405295 + - -0.17713330810709904 + - -0.028899062709869576 + - 0.1701000953073888 + - -0.1455204576271593 + - - -0.058465655159713324 + - 0.025638712945035648 + - -0.11357307412441271 + - -0.022422611999286395 + - -0.0638233851458305 + - -0.17943732395298106 + - -0.35728668924171725 + - 0.2633856900982372 + - 0.31930240317937025 + - -0.13942155663674075 + - - 0.25925030472898564 + - -0.08600915200889982 + - 0.013237271542377893 + - 0.08270645093988137 + - 0.07938652623724432 + - -0.07712393546328064 + - 0.0353061287875199 + - -0.09040033388673825 + - -0.10510751746845258 + - -0.239521654551994 + - - 0.1950191665752893 + - 0.05733138609046635 + - 0.06036323899411715 + - 0.08603163367369916 + - -0.06351003379199475 + - -0.11695534727101047 + - -0.29962219683967695 + - -0.16592572533204622 + - 0.04512989503642938 + - -0.5426376699838075 + - - -0.16145358850836278 + - 0.1472987868317316 + - -0.1685368576975565 + - 0.1486279679321272 + - 0.4495215217544417 + - 0.11479724794417837 + - -0.16865830590287098 + - -0.06396524972992296 + - 0.11824420098274822 + - -0.3484497284909885 + - - -0.13690561062125733 + - -0.13492860372177048 + - 0.35633010151792976 + - 0.12420350039852161 + - -0.13175470708945922 + - 0.05482431376078837 + - 0.029804708140519526 + - -0.06731920990649255 + - 0.18110006421406621 + - -0.03754748826767218 + - - 0.12325162433762646 + - -0.06526178999491541 + - 0.02939936255381139 + - 0.14462063711735892 + - -0.1315751283801704 + - 0.021863822341382386 + - 0.11981563675784059 + - 0.03806003514789598 + - -0.004355114224454414 + - -0.11544599826628787 + - - 0.2604909324088583 + - 0.21536103820459426 + - -0.121929486161253 + - -0.2582532715867142 + - -0.1627398924375832 + - -0.18123721122539346 + - -0.1803681902724512 + - -0.247559536889046 + - -0.2506988345912087 + - -0.012069856489839355 + - - 0.03898964086288465 + - 0.03094500722331738 + - -0.38373964405311967 + - -0.0705277215507272 + - 0.1177653852310158 + - -0.1105579483799866 + - -0.19736589071870414 + - 0.11345784613887841 + - 0.08554759880845628 + - 0.25028576336592095 + - - -0.38439612321931316 + - 0.4083379401980623 + - 0.02577391248413277 + - 0.3357676686485732 + - 0.15595467768056223 + - 0.24350752960431812 + - -0.331118357474186 + - 0.026624072250718868 + - 0.0469741762786187 + - 0.1417201508528421 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.2443577621260153 + - 0.624131229202312 + - -0.0058782517543746185 + - 0.1006152125408876 + - 0.7079714568863247 + - 0.09603584514823546 + - -0.6583908819630836 + - 1.6357423422817616 + - -0.0618699131255024 + - 0.4326396377925964 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.10967226479267887 + - 0.11114996160896161 + - 0.09000111259693142 + - 0.10890636464120326 + - 0.0883775382571271 + - 0.09102278408813357 + - 0.11021326614143533 + - 0.1098526664530503 + - 0.08886435538847728 + - 0.11066158832294501 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.17618485260708863 + - 0.045662247265226524 + - -0.5499256594329442 + - 0.34921229907826634 + - -0.04998263541098386 + - 0.02885288947399967 + - -0.036935148275496446 + - 0.03168683104426859 + - -0.18333941288288597 + - -0.2877694783301828 + - - 0.30086376971246415 + - -0.16704806261539829 + - -0.02718284690086246 + - 0.3540665247754029 + - -0.47316920072321395 + - 0.17479351669434584 + - 0.2222871117126889 + - -0.04083715524524929 + - -0.09965086611311093 + - -0.25783386657611146 + - - 0.026827655453065244 + - -0.09119899103973973 + - 0.019197115436614908 + - 0.1507771055095364 + - -0.21686370675957115 + - -0.08131250420897763 + - -0.04988204840585801 + - 0.0874730513357209 + - 0.23496968515864666 + - -0.26002029020484513 + - - -0.021338944609839042 + - 0.09905589528721442 + - 0.41041040462733797 + - -0.0533099757907697 + - -0.021340758072178192 + - 0.3166872698691298 + - -0.15138364509791633 + - 0.11433952007507375 + - 0.02953263811461213 + - -0.013537413535656817 + - - 0.3009126159955464 + - 0.1954788190243878 + - 0.21111462027375286 + - -0.19605908878211573 + - 0.04831200653374466 + - 0.3456478755867381 + - 0.1844910166298473 + - 0.4334051763004547 + - -0.11278043986647618 + - 0.0029415936287763853 + - - 0.4131691255165403 + - 0.008090227469773597 + - 0.02601770026739277 + - 0.057518289846918556 + - 0.045924119256435546 + - -0.1639339084346573 + - -0.016992038540731294 + - 0.39067897412791813 + - -0.15408600681685158 + - -0.3682538203037312 + - - -0.12391815059857987 + - 0.20270400475705594 + - 0.05799313895689829 + - -0.13403630322954838 + - -0.02240264202173418 + - 0.04562362082331838 + - 0.5614711146060805 + - -0.11039776332473819 + - 0.2691093146500228 + - -0.06570941118786719 + - - -0.4507256649511681 + - 0.012534253590173329 + - 0.02695666344894358 + - -0.24904640409426215 + - -0.031153981418388346 + - -0.5520342779112677 + - -0.033871400211516405 + - 0.15977107583011851 + - 0.039250575027518594 + - 0.10224610487637376 + - - 0.19048716887591094 + - 0.3354937447407981 + - 0.14225222638765445 + - 0.1779119193457957 + - 0.1563110490437229 + - -0.10540873992351772 + - -0.19084132513916224 + - 0.13029029243756504 + - 0.22645260563443315 + - 0.02354655421863742 + - - 0.4944264160553193 + - -0.21930927535519323 + - -0.2530454534409796 + - 0.011598014978998455 + - 0.0841671191792054 + - 0.10389936624580148 + - 0.10109843311103231 + - 0.30195061372635124 + - -0.17459900475580292 + - 0.15902741038759083 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.3888875707253524 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.011842172174146248 + - - -0.10595673764230339 + - - 0.6574265076935244 + - - 0.07000937566890184 + - - 0.1337010097188838 + - - -0.18983960016519466 + - - 0.39772644612916264 + - - -0.052834839876726186 + - - 0.050494269297535775 + - - -0.36970159517523593 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + neuron: + - 10 + - 10 + - 10 + out_dim: 1 + precision: float64 + resnet_dt: true + - "@class": FittingNetwork + "@version": 1 + activation_function: tanh + bias_out: true + in_dim: 24 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9000215275312623 + - -0.7227983960332327 + - 1.321515394024242 + - -1.0068830901074755 + - 0.8067888775589227 + - 0.5259057745703067 + - -1.0576301719828265 + - 0.07108081373858936 + - -0.18919061094060166 + - -0.8438625943799956 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.2639494686727459 + - 0.010256088113037042 + - 0.10252432422254726 + - 0.1358424628723544 + - -0.06740107220848378 + - 0.2483353972875281 + - 0.1283761651669193 + - -0.15185866988889724 + - -0.010024301844827593 + - 0.03270748203887378 + - - -0.24135443878774207 + - 0.06680607543652448 + - -0.20526105984848583 + - -0.08170985489575135 + - -0.014980349046510608 + - -0.12251724207605148 + - -0.33215839642502065 + - -0.0893955914470478 + - 0.09685445903892786 + - -0.017194770138568405 + - - -0.15723278542924113 + - 0.3044613549688298 + - 0.019364461806932475 + - 0.06021115227209433 + - 0.14496091559550034 + - 0.07508084172243198 + - -0.0877602205085306 + - 0.07182335283162082 + - 0.035870017932231726 + - 0.38929526764814876 + - - 0.16203882562124738 + - -0.006541279986094555 + - 0.06663819764767388 + - 0.034230902521745404 + - 0.23042498769919634 + - -0.099609321017009 + - -0.5115848193328142 + - -0.10943344489952037 + - 0.09052949992901997 + - -0.09867770244150034 + - - -0.16792948505177765 + - -0.10141216787820605 + - 0.04509922259178674 + - 0.01834823651662897 + - -0.0947941363957327 + - 0.011585898708337426 + - -0.22242196389762037 + - -0.02906434835708795 + - 0.035931171395247034 + - -0.1551391234188593 + - - -0.11776430983031586 + - 0.16368543523909046 + - -0.03382559903040933 + - -0.10158099352181799 + - 0.04226584034972191 + - -0.2800772224818269 + - -0.1107114178030891 + - -0.20171969920755975 + - 0.018747187046783933 + - -0.1956648548513648 + - - -0.009540451682665336 + - -0.04322394377916856 + - 0.05711019844711882 + - -0.17375929192144648 + - 0.08809681311993994 + - -0.05280096463310341 + - -0.16431224282159823 + - -0.166866126039713 + - -0.17435131036511972 + - 0.22597287533447752 + - - 0.0507339892429798 + - 0.09579538728179318 + - 0.038039862708474176 + - 0.0028466082820301723 + - -0.05651059197513672 + - -0.10098375458059926 + - -0.1571779231828787 + - 0.02384123879165128 + - -0.049876844658735084 + - 0.04523498526212328 + - - -0.4447121005124451 + - 0.25891242331505243 + - -0.08492103152745828 + - -0.026516847851845854 + - -0.27500865889520304 + - -0.1352245901963451 + - -0.2816293054397084 + - -0.26771014483176436 + - 0.0394212322338621 + - -0.10990601907118915 + - - -0.27635148046036195 + - -0.0751644443246153 + - 0.04675683081585847 + - -0.2402155597668226 + - -0.1430883558338466 + - 0.123899488905589 + - 0.025577533316584806 + - -0.09807864897434167 + - 0.09903175267746918 + - 0.0684929646793837 + - - -0.12679113232734612 + - 0.016568913234529754 + - -0.04587053800138196 + - 0.011498746771685862 + - -0.17994957845492482 + - 0.09845997049967746 + - -0.17032015097430434 + - 0.09096849426272227 + - -0.18366375790865738 + - -0.01270384812610915 + - - -0.2690129610756614 + - -0.058612835837374276 + - -0.10810671748420293 + - -0.12975485588268149 + - -0.2505978093628489 + - -0.101708409209403 + - 0.18612170716022863 + - 0.29354112140454053 + - 0.1640974741755058 + - -0.21399285226775538 + - - 0.014432650198421848 + - 0.12990120625924478 + - 0.12467157085766334 + - 0.12443424527019636 + - -0.03672542117823035 + - -0.11479058801452288 + - -0.007680294256678681 + - -0.2929250704773807 + - -0.08020608845647081 + - 0.11582168938812817 + - - 0.11728584361060512 + - -0.10601859600331497 + - 0.002794232973850923 + - 0.1627873351259995 + - -0.025293428980551407 + - -0.11724616652140807 + - -0.22680671870617725 + - 0.1099539565411639 + - 0.1776642924177351 + - 0.10371783226753746 + - - 0.3056032976023567 + - -0.3665719924864891 + - 0.08653537156898466 + - 0.0609487406692957 + - 0.13764575247447544 + - 0.16073262532054974 + - 0.11250966014845575 + - 0.02753441289238271 + - -0.015872675462566162 + - 0.10351056173414053 + - - -0.05017533401353259 + - -0.07384519749788616 + - -0.17614423543644447 + - -0.09663927227023622 + - -0.1790259621151453 + - 0.03168735029070646 + - -0.26112613330746653 + - 0.3288292945779274 + - -0.05557832956378767 + - 0.053100216246127895 + - - 0.25194208657331535 + - 0.046802913549916024 + - 0.012827957057808835 + - 0.060730872638830716 + - -0.16661704697368657 + - 0.1784818136494304 + - 0.2819638309602722 + - -0.17565794216032193 + - 0.28692647269535654 + - -0.12885201804513213 + - - -0.14921299976152355 + - -0.15752553954165557 + - 0.04876986997607481 + - -0.1863098849339127 + - 0.09128999377895769 + - 0.008009921667809717 + - -0.3017259463864289 + - -0.3171946274933575 + - 0.10186123663885215 + - -0.4413740947953646 + - - -0.08999848872321205 + - -0.2771120878253297 + - 0.20539242757822895 + - 0.005196131487372434 + - 0.09996012014582731 + - 0.2657724574170448 + - 0.16975588443885592 + - 0.06600840258420763 + - 0.21760144308034488 + - -0.10174231479171901 + - - -0.10630677074119527 + - -0.2025794786472348 + - 0.06086165293497518 + - 0.1521617829360692 + - -0.3400013083716394 + - 0.19892602735923454 + - 0.06371760535565664 + - -0.042456612062881635 + - 0.0608862002079617 + - -0.13855964354657122 + - - 0.07582742023083668 + - -0.12280411037083581 + - -0.2312295729747888 + - 0.09282181630764516 + - -0.11774711035032617 + - -0.31037439415574963 + - -0.19072723466736397 + - 0.13810854925479657 + - 0.06740177744292578 + - -0.23896620577909786 + - - -0.162341629422499 + - 0.04435484266535598 + - 0.12949549440468158 + - -0.16297035472848678 + - 0.08726167365622212 + - 0.04366956611879564 + - -0.027206957390748736 + - 0.1333275950891988 + - -0.2613085150183285 + - -0.20947871852476566 + - - -0.09117864158954347 + - 0.11184338889765945 + - -0.10492660212956549 + - 0.037476509645114925 + - 0.22490240548059126 + - -0.15238707485438738 + - -0.05298925855685219 + - -0.1864744562511453 + - 0.16282101762464754 + - -0.15821562334837078 + - - 0.11064503009817617 + - -0.08129274403273178 + - -0.11562116185307988 + - 0.036873955142837256 + - -0.4232929777513909 + - 0.3379396690789762 + - -0.22790642930236557 + - -0.12393699374270109 + - -0.13795658045985393 + - 0.28580886098129543 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.38707932767395387 + - -0.342039075676344 + - -0.7605327261659611 + - 0.11960682112234215 + - 0.22877035392822276 + - 1.2538633129456696 + - 0.3451934629791991 + - 0.057223403845354286 + - 0.41432284297592403 + - -1.057538843832911 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.11057731344173005 + - 0.09166489068928595 + - 0.08834237103327737 + - 0.08985530004808084 + - 0.09004106819400194 + - 0.110088004003701 + - 0.1102947353036592 + - 0.11122172956610853 + - 0.11222041413258106 + - 0.10895646062165928 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.22562359760266557 + - 0.6093700510258442 + - -0.3862283030393641 + - 0.13190938086077073 + - 0.16904684054901958 + - 0.34209971229462666 + - 0.21129876725330834 + - -0.2755838784617781 + - -0.00958156785007398 + - -0.0792576938319562 + - - -0.07590864199833239 + - 0.1330950724591729 + - 0.16239841094514196 + - 0.06216113664814958 + - -0.2724028166870573 + - -0.0950820514913243 + - 0.29024492128515733 + - -0.046965745288487834 + - -0.27571817414854705 + - 0.22166174173429865 + - - 0.2035656947004025 + - 0.0031920482731557197 + - -0.42269138021776165 + - 0.2457720240311169 + - -0.3007697984627373 + - -0.2586916704989178 + - 0.011124140548716552 + - 0.28748139157650465 + - -0.07939623462531636 + - -0.142930432028218 + - - 0.18313342574526303 + - -0.10683878239024348 + - 0.169257801643557 + - 0.5262453606874672 + - -0.04879583171885936 + - 0.25763233021925236 + - 0.020272263444042087 + - -0.03904631900691532 + - -0.18716181672868185 + - 0.07230450466148552 + - - -0.30592931443701327 + - -0.1252728151762653 + - 0.11473858127268824 + - -0.06372339464341502 + - -0.14453755741135488 + - 0.2826035720710313 + - -0.19879292892809405 + - 0.006055059682360477 + - -0.2538459959901242 + - 0.10745867501569967 + - - 0.22970647900288899 + - 0.138364153785151 + - -0.07735393462141985 + - 0.07530026753816028 + - 0.07112656939827086 + - -0.2214458770953512 + - -0.4924904487955132 + - -0.1622388614063035 + - -0.197804383645105 + - -0.40774478244528006 + - - -0.17575776197771453 + - 0.12146708364526686 + - -0.034463932487421085 + - 0.1788850492610013 + - -0.05731103348416332 + - 0.06889448330307754 + - 0.24516690736509378 + - 0.3127497159747833 + - -0.030280799327598185 + - 0.09556479648505135 + - - -0.014413020252857104 + - -0.28698490503832147 + - 0.09483994866387635 + - 0.05222910607348386 + - 0.1994789368157623 + - -0.2818741859625111 + - -0.010396515902957145 + - 0.04183557337111259 + - -0.008287492548398551 + - -0.2925052749940528 + - - 0.017256393993720674 + - -0.09336793519057174 + - -0.14184201486390932 + - -0.20290610789388205 + - 0.03808761725237456 + - -0.062149851543097115 + - -0.00864623950442309 + - -0.35021969228968475 + - -0.2052184205698035 + - -0.4436374571930673 + - - 0.05734301760745409 + - -0.03714558201311452 + - -0.4173373754967207 + - -0.08679355824638953 + - -0.07306918814547493 + - 0.19544632145934263 + - 0.497186002866004 + - 0.06458406592447367 + - -0.36548805693881903 + - -0.13124714330887066 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.22997830608154382 + - 0.4435768941487313 + - 0.14303044935320117 + - -1.16577110008048 + - 0.4816720053756091 + - -2.9765647280871423 + - 0.5889251249161973 + - -0.38050218573495515 + - 0.18979171907573328 + - 0.45877476495214425 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.11070143628243294 + - 0.10922602584484076 + - 0.11195094567841188 + - 0.09004675088245527 + - 0.08980017610007925 + - 0.08912480626627803 + - 0.0896492384050733 + - 0.08849818353918305 + - 0.10983521381344624 + - 0.09094846995634104 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.10108210517328102 + - -0.05232561496473743 + - 0.36668855727973637 + - 0.09257418721057574 + - 0.17191215843944357 + - -0.11483389015076328 + - 0.0656162851585946 + - -0.09176178089980337 + - -0.15664437870164957 + - -0.16346971077648034 + - - 0.02558528376556598 + - 0.5832659444951624 + - -0.37826154374912035 + - -0.2956624104776143 + - 0.08919028399896758 + - 0.15925510299907994 + - 0.28617892614917384 + - -0.1767224311399701 + - -0.3197317172438727 + - 0.3121168642622297 + - - 0.07216872045443962 + - -0.13598929105799168 + - -0.03536008289115259 + - 0.3093828991431095 + - 0.31602137772259464 + - -0.18710495215339262 + - 0.010401609944592771 + - 0.03476012661438127 + - -0.012983716604445061 + - 0.04684740161761361 + - - 0.04247512111775211 + - -0.09371981872684937 + - -0.3663641703790199 + - 0.46223401434724043 + - 0.20247958630383156 + - -0.08317997337403724 + - 0.3634265230232031 + - 0.3458274937618409 + - 0.025543925139465067 + - -0.03106949633698399 + - - 0.040260607731424286 + - 0.009309373996654179 + - 0.10520190131367661 + - 0.11129333178305297 + - -0.12162768486020593 + - -0.12192274836206039 + - 0.11950586018766925 + - 0.2291912405064738 + - -0.10968589986322486 + - -0.29068612133443905 + - - 0.1438823277917827 + - 0.18216007495495068 + - -0.009571238259658254 + - 0.16707448786759976 + - -0.04281652925723783 + - 0.32047875873206777 + - -0.08763911799482689 + - -0.1245252648441373 + - 0.46887113009782105 + - -0.19844629125139612 + - - -0.24011562238772022 + - -0.10204798730286967 + - -0.21125021451360748 + - 0.08315306052287015 + - -0.1793580497394702 + - -0.08407247375098266 + - -0.06370754052005728 + - -0.20060326958664665 + - -0.17948953871230516 + - 0.09669298322007468 + - - -0.13343204571646267 + - 0.18008449100797377 + - -0.46816412675017605 + - -0.05507106060832335 + - -0.16700345090023166 + - 0.15508384962457078 + - 0.017734250652154813 + - -0.1550230036549483 + - -0.02628806339568613 + - 0.049337498067655874 + - - 0.02679767080621389 + - 0.12092724481351536 + - 0.23803055650836435 + - 0.2918643383802174 + - -0.0050387329311884395 + - 0.10526966012765887 + - 0.034679495369277794 + - -0.06543260933885439 + - 0.005063852100596086 + - 0.09933328759193807 + - - -0.2777375125368255 + - -0.23070868272286654 + - 0.6108565671340866 + - 0.06073857349664605 + - -0.024130496263620627 + - -0.016698846402277097 + - 0.020472805597523167 + - -0.16180878068639318 + - -0.16084599014917336 + - -0.3314925208838775 + "@version": 1 + activation_function: tanh + bias: true + precision: float64 + resnet: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.1981756305945983 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.38583843326969747 + - - -0.35536303933445884 + - - -0.3982240653639878 + - - -0.06690466423340803 + - - 0.27802346228262165 + - - -0.2061475291897428 + - - 0.28581111357063477 + - - -0.43468045802907607 + - - -0.3247707927028263 + - - 0.09371781830568159 + "@version": 1 + activation_function: none + bias: true + precision: float64 + resnet: false + use_timestep: false + neuron: + - 10 + - 10 + - 10 + out_dim: 1 + precision: float64 + resnet_dt: true + ntypes: 2 + neuron: + - 10 + - 10 + - 10 + ntypes: 2 + numb_aparam: 0 + numb_fparam: 0 + precision: float64 + rcond: null + resnet_dt: true + spin: null + tot_ener_zero: false + trainable: + - true + - true + - true + - true + type: ener + type_map: + - O + - H + use_aparam_as_mask: false + var_name: energy + pair_exclude_types: [] + preset_out_bias: null + rcond: null + type: standard + type_map: + - O + - H +model_def_script: + _comment4: " that's all" + descriptor: + _comment2: " that's all" + axis_neuron: 2 + neuron: + - 3 + - 6 + - 12 + ntypes: 2 + precision: float64 + rcut: 6.0 + rcut_smth: 0.5 + resnet_dt: false + seed: 1 + sel: + - 46 + - 92 + type: se_e2_a + fitting_net: + _comment3: " that's all" + dim_descrpt: 24 + embedding_width: 24 + mixed_types: false + neuron: + - 10 + - 10 + - 10 + ntypes: 2 + precision: float64 + resnet_dt: true + seed: 1 + type: ener + resuming: true + type_map: + - O + - H +pt_version: 2.5.0+cu124 +software: deepmd-kit +time: "2024-11-12 20:54:08.623904+00:00" +version: 3.0.0b5.dev48+gd46d5f0b7.d20241025