diff --git a/.github/workflows/build-cpu-windows.yml b/.github/workflows/build-cpu-windows.yml index 559a8e04e..22266bf1a 100644 --- a/.github/workflows/build-cpu-windows.yml +++ b/.github/workflows/build-cpu-windows.yml @@ -21,7 +21,7 @@ jobs: run: | mkdir build cd build - cmake -G "Visual Studio 16 2019" -A x64 -DBUILD_SUPERBUILD=ON -DBUILD_SHARED_LIBS=OFF -DBUILD_TARGET=CPU -DBUILD_HPC=OFF -DBUILD_TESTS=ON .. + cmake -G "Visual Studio 16 2019" -A x64 -DBUILD_SUPERBUILD=ON -DBUILD_SHARED_LIBS=OFF -DBUILD_TARGET=CPU -DBUILD_HPC=OFF -DBUILD_TESTS=ON -DBUILD_DIST=OFF -DBUILD_RUNTIME=OFF .. shell: cmd - name: Build run: cmake --build build --config Release diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index da628efe3..7fdd66ad9 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -16,12 +16,14 @@ jobs: steps: - uses: actions/checkout@v2 - name: Install utilities - run: sudo apt-get install -y cmake wget graphviz + run: | + sudo apt-get install -y cmake wget graphviz + #sudo apt-get install -y build-essential checkinstall zlib1g-dev libcrypto++-dev libssl-dev - name: Build run: | mkdir build cd build - cmake .. -DBUILD_SUPERBUILD=ON -DBUILD_HPC=OFF -DBUILD_TESTS=ON + cmake .. 
-DBUILD_SUPERBUILD=ON -DBUILD_HPC=OFF -DBUILD_TESTS=ON -DBUILD_DIST=OFF -DBUILD_RUNTIME=OFF make - name: Test run: | diff --git a/.gitignore b/.gitignore index 9fabfc14a..a4928f0ae 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ src/serialization/onnx/onnx.pb.cc *.DS_Store .idea .vscode +*~ # Build /[Bb]uild* diff --git a/CMakeLists.txt b/CMakeLists.txt index b94564f01..4a71e7f7f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,15 +9,18 @@ cmake_minimum_required(VERSION 3.9.2) option(BUILD_SUPERBUILD "Compile using the superbuild system" OFF) option(BUILD_PROTOBUF "Compile using Protobuf" ON) option(BUILD_OPENMP "Compile using OpenMP" ON) -option(BUILD_HPC "Compile using aggressive flags" ON) +option(BUILD_HPC "Compile using aggressive flags for performance" ON) option(BUILD_TESTS "Compile tests (HPC needs to be disabled)" OFF) # Disable HPC to pass tests (there are numerical errors) option(BUILD_EXAMPLES "Compile examples" ON) -option(BUILD_SHARED_LIBS "Global flag to cause add_library to create shared libraries if on" ON) +option(BUILD_DIST "Compile for a distributed execution" OFF) +option(BUILD_RUNTIME "Compile runtime" OFF) option(BUILD_COVERAGE "Flag to compile for coverage information" OFF) option(BUILD_SANITIZERS "Flag to compile with sanitizers information" OFF) if(WIN32) - option(BUILD_SHARED_LIBS "" OFF) # Prefer lib over dll in windows + option(BUILD_SHARED_LIBS "Global flag to cause add_library to create shared libraries if on" OFF) +else() + option(BUILD_SHARED_LIBS "Global flag to cause add_library to create shared libraries if on" ON) endif() ########################################################################### @@ -144,6 +147,11 @@ if(BUILD_EXAMPLES) add_subdirectory(examples) endif(BUILD_EXAMPLES) +# Build runtime +if(BUILD_RUNTIME) + add_subdirectory(runtime) +endif(BUILD_RUNTIME) + ########################################################################### ########################## INSTALLATION 
################################### diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index fd9ec6011..70a91abe4 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -20,11 +20,12 @@ endif() # For development (ignore) option(USE_SYSTEM_GTEST "Use system dependency" OFF) option(USE_SYSTEM_EIGEN "Use system dependency" OFF) -option(USE_SYSTEM_ZLIB "Use system dependency" OFF) option(USE_SYSTEM_PROTOBUF "Use system dependency" OFF) -if(WIN32) - option(USE_SYSTEM_PTHREADS "Use system dependency" OFF) -endif() +option(USE_SYSTEM_ZLIB "Use system dependency" OFF) +option(USE_SYSTEM_OPENSSL "Use system dependency" OFF) +#if(WIN32) +# option(USE_SYSTEM_PTHREADS "Use system dependency" OFF) +#endif() # Set variables set(EDDL_DEPENDENCIES) @@ -73,14 +74,14 @@ if(CMAKE_GENERATOR MATCHES "Visual Studio") list(APPEND config_types ${CMAKE_BUILD_TYPE}) endif() -############## -### GTEST ## -############## +############# +## GTEST ## +############# message(STATUS "Subproject: GTEST...") if(USE_SYSTEM_GTEST) find_package(GTest REQUIRED) else() - # Download and unpack googletest at configure time + # Download and unpack googletest at configure time foreach(config ${config_types}) CONF_PACKAGE(googletest ${config} "") BUILD_PACKAGE(googletest ${config}) @@ -120,31 +121,74 @@ else() CONF_PACKAGE(protobuf ${config} @ONLY) BUILD_PACKAGE(protobuf ${config}) endforeach() - + set(Protobuf_ROOT "${EP_BASE_DIR}/protobuf" PARENT_SCOPE) set(Protobuf_INCLUDE_DIRS "${EP_BASE_DIR}/protobuf/include" PARENT_SCOPE) set(Protobuf_PROTOC_EXECUTABLE "${EP_BASE_DIR}/protobuf/bin/protoc" PARENT_SCOPE) endif() -add_custom_target(protobuf_files +add_custom_target(protobuf_files protoc --cpp_out=../src/serialization/onnx ../src/serialization/onnx/onnx.proto ) message(STATUS "Subproject: Protobuf...DONE") -################## -## PTHREADS4W ## -################## -if(WIN32) - if(USE_SYSTEM_PTHREADS) - set(PTHREADS_INSTALL_PATH "$ENV{PTHREADS_ROOT}" CACHE PATH "Path to the installation of 
Pthreads under Windows (PTHREADS_ROOT env variable)" PARENT_SCOPE) + +# Only for distributed versions +if(BUILD_DIST) + + ############ + ## ZLIB ## + ############ + message(STATUS "Subproject: ZLIB...") + if(USE_SYSTEM_ZLIB) + find_package(ZLIB REQUIRED) else() - # Download and unpack pthreads4w at configure time - CONF_PACKAGE(pthreads4w ${CMAKE_BUILD_TYPE} @ONLY) - BUILD_PACKAGE(pthreads4w ${CMAKE_BUILD_TYPE}) - - set(PTHREADS_INSTALL_PATH "${EP_BASE_DIR}/PTHREADS-BUILT" PARENT_SCOPE) + # Download and unpack ZLIB at configure time + foreach(config ${config_types}) + CONF_PACKAGE(zlib ${config} "") + BUILD_PACKAGE(zlib ${config}) + endforeach() + # Set variables + SET(ZLIB_ROOT "${EP_BASE_DIR}/zlib" PARENT_SCOPE) + SET(ZLIB_INCLUDE_DIRS "${EP_BASE_DIR}/zlib/include" PARENT_SCOPE) endif() + message(STATUS "Subproject: ZLIB...DONE") + + + ############### + ## OPENSSL ## + ############### + message(STATUS "Subproject: OpenSSL...") + if(USE_SYSTEM_OPENSSL) + find_package(OpenSSL REQUIRED) + else() + # Download and unpack ZLIB at configure time + foreach(config ${config_types}) + CONF_PACKAGE(openssl ${config} "") + BUILD_PACKAGE(openssl ${config}) + endforeach() + # Set variables + SET(OPENSSL_ROOT_DIR "${EP_BASE_DIR}/openssl" PARENT_SCOPE) + SET(OPENSSL_INCLUDE_DIR "${EP_BASE_DIR}/openssl/include" PARENT_SCOPE) + endif() + message(STATUS "Subproject: OpenSSL...DONE") + endif() +################## +## PTHREADS4W ## +################## +#if(WIN32) +# if(USE_SYSTEM_PTHREADS) +# set(PTHREADS_INSTALL_PATH "$ENV{PTHREADS_ROOT}" CACHE PATH "Path to the installation of Pthreads under Windows (PTHREADS_ROOT env variable)" PARENT_SCOPE) +# else() +# # Download and unpack pthreads4w at configure time +# CONF_PACKAGE(pthreads4w ${CMAKE_BUILD_TYPE} @ONLY) +# BUILD_PACKAGE(pthreads4w ${CMAKE_BUILD_TYPE}) +# +# set(PTHREADS_INSTALL_PATH "${EP_BASE_DIR}/PTHREADS-BUILT" PARENT_SCOPE) +# endif() +#endif() + ############### ##### EDDL ## ############### diff --git 
a/cmake/EDDLConfig.cmake.in b/cmake/EDDLConfig.cmake.in index b978e62db..bdc27c2f8 100644 --- a/cmake/EDDLConfig.cmake.in +++ b/cmake/EDDLConfig.cmake.in @@ -3,14 +3,19 @@ include(CMakeFindDependencyMacro) # VARIABLES SET(USE_SUPERBUILD @BUILD_SUPERBUILD@) # Set in parent +SET(USE_DIST @BUILD_DIST@) # Set in parent + SET(USE_PROTOBUF @BUILD_PROTOBUF@) # Set in parent SET(USE_OPENMP @USE_OPENMP@) # Modified in a subdirectory +SET(USE_DIST @USE_DIST@) # Modified in a subdirectory SET(USE_CUDA @USE_CUDA@) # Modified in a subdirectory SET(USE_CUDNN @USE_CUDNN@) # Modified in a subdirectory SET(USE_FPGA @USE_FPGA@) # Modified in a subdirectory + SET(USE_SYSTEM_EIGEN @USE_SYSTEM_EIGEN@) # Modified in a subdirectory -#SET(USE_SYSTEM_ZLIB @USE_SYSTEM_ZLIB@) # Modified in a subdirectory SET(USE_SYSTEM_PROTOBUF @USE_SYSTEM_PROTOBUF@) # Modified in a subdirectory +SET(USE_SYSTEM_ZLIB ON) # Modified in a subdirectory (the superbuild installation does not include the cmakes to find the library) +SET(USE_SYSTEM_OPENSSL ON) # Modified in a subdirectory (the superbuild installation does not include the cmakes to find the library) SET(EP_BASE_DIR @EP_BASE_DIR@) # Modified in a subdirectory # Threads (linux) @@ -35,6 +40,11 @@ if(USE_FPGA) find_dependency(OpenCL REQUIRED) endif() +# Distributed +if(USE_DIST) + find_dependency(OpenSSL REQUIRED) +endif() + if(USE_SUPERBUILD) list(APPEND CMAKE_MODULE_PATH "${EP_BASE_DIR}") if(NOT USE_SYSTEM_EIGEN) @@ -56,6 +66,24 @@ if(USE_PROTOBUF) endif() endif() +if(USE_DIST) + # Zlib + if(USE_SUPERBUILD AND NOT USE_SYSTEM_ZLIB) + SET(ZLIB_ROOT @ZLIB_ROOT@) # Modified in a subdirectory + find_dependency(ZLIB CONFIG REQUIRED HINTS ${ZLIB_ROOT}) + else() + find_dependency(ZLIB REQUIRED) + endif() + + # OpenSSL + if(USE_SUPERBUILD AND NOT USE_SYSTEM_OPENSSL) + SET(OPENSSL_ROOT_DIR @OPENSSL_ROOT_DIR@) # Modified in a subdirectory + find_dependency(OpenSSL CONFIG REQUIRED HINTS ${OPENSSL_ROOT_DIR}) + else() + find_dependency(OpenSSL REQUIRED) + endif() 
+endif() + include("${CMAKE_CURRENT_LIST_DIR}/EDDLTargets.cmake") # Set default paths diff --git a/cmake/openssl.CMakeLists.txt.in b/cmake/openssl.CMakeLists.txt.in new file mode 100644 index 000000000..ab9d46529 --- /dev/null +++ b/cmake/openssl.CMakeLists.txt.in @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.9.2) + +project(openssl-download NONE) + +# Set base dir +SET(EP_BASE_DIR @EP_BASE_DIR@) +SET(CMAKE_BUILD_TYPE @CMAKE_BUILD_TYPE@) + +include(ExternalProject) +ExternalProject_Add(openssl + PREFIX openssl + GIT_REPOSITORY "https://github.com/openssl/openssl.git" + GIT_TAG "OpenSSL_1_1_1i" + SOURCE_DIR "${EP_BASE_DIR}/openssl-src" + BINARY_DIR "${EP_BASE_DIR}/openssl-build" + INSTALL_DIR "${EP_BASE_DIR}/openssl" + CONFIGURE_COMMAND sh ${EP_BASE_DIR}/openssl-src/config --prefix=${EP_BASE_DIR}/openssl + ) diff --git a/docs/markdown/bundle/requirements.txt b/docs/markdown/bundle/requirements.txt index 27db7e64e..54466e0fb 100644 --- a/docs/markdown/bundle/requirements.txt +++ b/docs/markdown/bundle/requirements.txt @@ -13,10 +13,7 @@ ca-certificates 2020.12.5 ha878542_0 conda-forge cairo 1.16.0 h7979940_1007 conda-forge certifi 2020.12.5 py38h578d9bd_1 conda-forge chardet 4.0.0 pypi_0 pypi -cmake 3.19.4 h4547794_0 conda-forge -cudatoolkit 11.0.3 h15472ef_7 conda-forge -cudatoolkit-dev 10.1.243 h516909a_3 conda-forge -cudnn 8.0.5.39 ha5ca753_1 conda-forge +cmake 3.19.4 h3020d66_1 conda-forge docutils 0.16 pypi_0 pypi doxygen 1.9.1 hb166930_0 conda-forge eigen 3.3.7 hc9558a2_1001 conda-forge @@ -51,7 +48,7 @@ libedit 3.1.20191231 he28a2e2_2 conda-forge libev 4.33 h516909a_1 conda-forge libffi 3.3 h58526e2_2 conda-forge libgcc-ng 9.3.0 h2828fa1_18 conda-forge -libglib 2.66.4 h748fe8e_2 conda-forge +libglib 2.66.6 h1f3bc88_3 conda-forge libgomp 9.3.0 h2828fa1_18 conda-forge libiconv 1.16 h516909a_0 conda-forge libidn2 2.3.0 h516909a_0 conda-forge diff --git a/docs/sphinx/source/intro/build-options.rst b/docs/sphinx/source/intro/build-options.rst index 
4039223cc..d52401e4e 100644 --- a/docs/sphinx/source/intro/build-options.rst +++ b/docs/sphinx/source/intro/build-options.rst @@ -25,31 +25,31 @@ environment by running the following commands **from the source directory**: .. code:: bash - conda env create -f environment-cpu.yml # -cpu, -gpu, -cudnn + conda env create -f environment.yml conda activate eddl You can also update your environment with: .. code:: bash - conda env update -f environment-cpu.yml # -cpu, -gpu, -cudnn + conda env update -f environment.yml If you decide to manually install these dependencies in your system (make sure they are at standard paths): .. code:: yaml - - cmake>=3.9.2 + - cmake>=3.17.2 - eigen==3.3.7 - protobuf==3.11.4 - - libprotobuf==3.11.4 - - cudnn==8.0.5.39 - - cudatoolkit-dev==10.1.243 - - gtest - - graphviz - - wget - - doxygen - - python - - pip + - libprotobuf==3.11.4 # We need to avoid problems with paths (idk why) + - zlib==1.2.11 + - openssl==1.1.1i + - gtest==1.10.0 + - graphviz==2.42.3 # Build & Run + - wget==1.20.1 + - doxygen==1.9.1 # Docs + - python==3.8.6 + - pip==21.0.1 - pip: - sphinx==3.2.1 - sphinx_rtd_theme==0.5.0 @@ -134,7 +134,7 @@ not found (or CUDA), the EDDL will automatically fallback to CPU. Additional flags ^^^^^^^^^^^^^^^^^ -These flags can enable/disable features of the EDDL so that you can optimized and +These flags can enable/disable features of the EDDL so that you can optimize and troubleshoot the compilation process (see: :doc:`troubleshoot`). @@ -181,8 +181,18 @@ troubleshoot the compilation process (see: :doc:`troubleshoot`). .. note:: - This flag is needed to known which CUDA Toolkit/cuDNN the user wants to use. By default cmake looks in the ``PATH``. + This flag is needed to known which CUDA Toolkit the user wants to use. By default cmake looks in the ``PATH``. +- **CUDNN ROOT DIR:** + +.. code:: bash + + --DCUDNN_ROOT_DIR=/path/to/cuda #/usr/local/cuda + +.. note:: + + This flag is needed to known where to look for the cuDNN libraries. 
By default cuda is expected to be installed in + along with the CUDA toolkit. - **CUDA host compiler:** @@ -296,8 +306,17 @@ troubleshoot the compilation process (see: :doc:`troubleshoot`). If you want to distribute the resulting shared library, you should use the flag ``-DBUILD_SUPERBUILD=ON`` so that we can make specific tunings to our dependencies. +- **Build distributed:** To let the EDDL work in a distributed mode, use the setting ``BUILD_DIST``: + +.. code:: bash + + -DBUILD_DIST=ON + +.. note:: + + Enabled by default. .. _Anaconda: https://docs.conda.io/en/latest/miniconda.html .. _Eigen3: http://eigen.tuxfamily.org/index.php?title=Main_Page -.. _Requirements: https://github.com/deephealthproject/eddl/blob/develop/docs/markdown/bundle/requirements.txt \ No newline at end of file +.. _Requirements: https://github.com/deephealthproject/eddl/blob/develop/docs/markdown/bundle/requirements.txt diff --git a/docs/sphinx/source/intro/installation.rst b/docs/sphinx/source/intro/installation.rst index 6099b2449..e1a2721fd 100644 --- a/docs/sphinx/source/intro/installation.rst +++ b/docs/sphinx/source/intro/installation.rst @@ -79,7 +79,7 @@ You can also install ``EDDL`` from source with cmake. * C++ compiler - * Anaconda or CMake *(if using* ``-D BUILD_SUPERBUILD=ON`` *)* + * Anaconda * [Optional] CUDA Toolkit 10 or later (to compile for GPU) @@ -99,7 +99,7 @@ You can also install ``EDDL`` from source with cmake. cd eddl/ # Install dependencies - conda env create -f environment-cpu.yml # -cpu, -gpu, -cudnn + conda env create -f environment.yml conda activate eddl # Build and install @@ -124,7 +124,7 @@ You can also install ``EDDL`` from source with cmake. 
cd eddl/ # Install dependencies - conda env create -f environment-cpu.yml # -cpu, -gpu, -cudnn + conda env create -f environment.yml conda activate eddl # Build and install diff --git a/docs/sphinx/source/intro/troubleshoot.rst b/docs/sphinx/source/intro/troubleshoot.rst index cc3dc3bf7..45405b0ca 100644 --- a/docs/sphinx/source/intro/troubleshoot.rst +++ b/docs/sphinx/source/intro/troubleshoot.rst @@ -16,11 +16,11 @@ Also, if it is using CPU and the library has been compiled for CPU, it could be support AVX instructions. -OpenMP -^^^^^^^^ +OpenMP (MacOS) +^^^^^^^^^^^^^^^ On MacOS, the Clang that comes with XCode doesn't support OpenMP. Therefore, we recommend you to install -the GCC compile so that you can take advatange of OpenMP. +the GCC compile so that you can take advantange of OpenMP. *(Note: By default, GCC is just a symlink to Clang. more_)* .. code:: bash @@ -39,7 +39,32 @@ As a last resort, you can always disable OpenMP and use the EDDL by making use o Undefined symbols for architecture x86_64 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -This error might be due to the CMake cache or a conflict between compilers. +If you cannot compile the EDDL using the distributed mode due to OpenSSL, you might try these things: + +First, make you you have OpenSSL installed: + +- Ubuntu/Debian: ``sudo apt-get install libcrypto++-dev libssl-dev`` +- MacOS: ``brew install openssl`` + +If this does not work, check if the following paths are correctly setup: + +.. code:: bash + + # NOTE: This is a copy-paste from "brew", but for linux should be quite similar. 
+ + If you need to have openssl@1.1 first in your PATH run: + echo 'export PATH="/usr/local/opt/openssl@1.1/bin:$PATH"' >> ~/.zshrc + + For compilers to find openssl@1.1 you may need to set: + export LDFLAGS="-L/usr/local/opt/openssl@1.1/lib" + export CPPFLAGS="-I/usr/local/opt/openssl@1.1/include" + + For pkg-config to find openssl@1.1 you may need to set: + export PKG_CONFIG_PATH="/usr/local/opt/openssl@1.1/lib/pkgconfig" + + +(MacOS) Undefined symbols for architecture x86_64 +-------------------------------------------------- First, try deleting the ``build/`` folder and run ``cmake`` again. If this doesn't work, try forcing a specific compiler either with the flag ``-DCMAKE_CXX_COMPILER`` or by exporting these variables to your environment: @@ -285,4 +310,4 @@ If you want to run it using the conda environment, add: To get the path, activate the environment and type: echo $CONDA_PREFIX -.. _more: https://stackoverflow.com/questions/39979836/using-openmp-with-c11-on-mac-os \ No newline at end of file +.. 
_more: https://stackoverflow.com/questions/39979836/using-openmp-with-c11-on-mac-os diff --git a/environment-cudnn.yml b/environment-cudnn.yml deleted file mode 100644 index 8ed2f7acc..000000000 --- a/environment-cudnn.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: eddl -channels: - - conda-forge - - defaults - - anaconda -dependencies: - - cmake>=3.9.2 - - eigen==3.3.7 - - protobuf==3.11.4 - - libprotobuf==3.11.4 # We need to avoid problems with paths (idk why) - - cudnn==8.0.5.39 - - cudatoolkit-dev==10.1.243 - - gtest - - graphviz # Build & Run - - wget - - doxygen # Docs - - python - - pip - - pip: - - sphinx==3.2.1 - - sphinx_rtd_theme==0.5.0 - - sphinx-tabs==1.3.0 - - breathe==4.22.1 \ No newline at end of file diff --git a/environment-gpu.yml b/environment-gpu.yml deleted file mode 100644 index ff72fe95a..000000000 --- a/environment-gpu.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: eddl -channels: - - conda-forge - - defaults - - anaconda -dependencies: - - cmake>=3.9.2 - - eigen==3.3.7 - - protobuf==3.11.4 - - libprotobuf==3.11.4 # We need to avoid problems with paths (idk why) - - cudatoolkit-dev==10.1.243 - - gtest - - graphviz # Build & Run - - wget - - doxygen # Docs - - python - - pip - - pip: - - sphinx==3.2.1 - - sphinx_rtd_theme==0.5.0 - - sphinx-tabs==1.3.0 - - breathe==4.22.1 \ No newline at end of file diff --git a/environment-cpu.yml b/environment.yml similarity index 61% rename from environment-cpu.yml rename to environment.yml index dca5b6ab7..12b9cedea 100644 --- a/environment-cpu.yml +++ b/environment.yml @@ -4,16 +4,18 @@ channels: - defaults - anaconda dependencies: - - cmake>=3.9.2 + - cmake>=3.17.2 - eigen==3.3.7 - protobuf==3.11.4 - libprotobuf==3.11.4 # We need to avoid problems with paths (idk why) - - gtest - - graphviz # Build & Run - - wget - - doxygen # Docs - - python - - pip + - zlib==1.2.11 + - openssl==1.1.1i + - gtest==1.10.0 + - graphviz==2.42.3 # Build & Run + - wget==1.20.1 + - doxygen==1.9.1 # Docs + - python==3.8.6 + - 
pip==21.0.1 - pip: - sphinx==3.2.1 - sphinx_rtd_theme==0.5.0 diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 2737043fd..f1801f42b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -94,6 +94,16 @@ target_link_libraries(nlp_video_to_labels eddl) add_executable(nlp_text_generation "nn/4_NLP/5_nlp_text_generation.cpp") target_link_libraries(nlp_text_generation eddl) +# TEST INTERNALS **************************************************** + add_executable(test1 "test_internals/test1.cpp") + target_link_libraries(test1 eddl) + add_executable(test2 "test_internals/test2.cpp") + target_link_libraries(test2 eddl) + add_executable(test3 "test_internals/test3.cpp") + target_link_libraries(test3 eddl) + add_executable(test4 "test_internals/test4.cpp") + target_link_libraries(test4 eddl) + # EXAMPLES: Tensor **************************************************** add_executable(tensor_ops "tensor/eddl_ops.cpp") @@ -123,6 +133,9 @@ if(BUILD_PROTOBUF) add_executable(utils_serialization "onnx/4_utils_serialization.cpp") target_link_libraries(utils_serialization eddl) + add_executable(onnx_gradients "onnx/4_onnx_test_gradients.cpp") + target_link_libraries(onnx_gradients eddl) + # EXAMPLES: ONNX MNIST **************************************************** add_executable(onnx_mnist_mlp "onnx/nn/1_mnist/1_mnist_mlp.cpp") target_link_libraries(onnx_mnist_mlp eddl) @@ -197,7 +210,8 @@ if(BUILD_PROTOBUF) add_executable(onnx_drive_seg "onnx/nn/3_drive/1_drive_seg.cpp") target_link_libraries(onnx_drive_seg eddl) - # EXAMPLES: ONNX NLP **************************************************** + +# EXAMPLES: ONNX NLP **************************************************** add_executable(onnx_nlp_sentiment_rnn "onnx/nn/4_NLP/1_nlp_sentiment_rnn.cpp") target_link_libraries(onnx_nlp_sentiment_rnn eddl) diff --git a/examples/nn/1_mnist/14_mnist_losses.cpp b/examples/nn/1_mnist/14_mnist_losses.cpp index 1c60142e1..e6ff845dd 100644 --- 
a/examples/nn/1_mnist/14_mnist_losses.cpp +++ b/examples/nn/1_mnist/14_mnist_losses.cpp @@ -154,7 +154,7 @@ int main(int argc, char **argv) { fflush(stdout); optimize(dicep); - //optimize({mse,dicep}); + //optimize({dicei,mse}); update(net); diff --git a/examples/nn/3_drive/1_drive_seg.cpp b/examples/nn/3_drive/1_drive_seg.cpp index 483651959..782da5925 100644 --- a/examples/nn/3_drive/1_drive_seg.cpp +++ b/examples/nn/3_drive/1_drive_seg.cpp @@ -38,49 +38,45 @@ layer UNetWithPadding(layer x) int depth=32; - x = LeakyReLu(Conv(x, depth, { 3,3 }, { 1, 1 }, "same")); - x = LeakyReLu(Conv(x, depth, { 3,3 }, { 1, 1 }, "same")); + x = LeakyReLu(BatchNormalization(Conv(x, depth, { 3,3 }, { 1, 1 }, "same"))); + x = LeakyReLu(BatchNormalization(Conv(x, depth, { 3,3 }, { 1, 1 }, "same"))); x2 = MaxPool(x, { 2,2 }, { 2,2 }); - x2 = LeakyReLu(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same")); - x2 = LeakyReLu(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same")); + x2 = LeakyReLu(BatchNormalization(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same"))); + x2 = LeakyReLu(BatchNormalization(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same"))); x3 = MaxPool(x2, { 2,2 }, { 2,2 }); - x3 = LeakyReLu(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same")); - x3 = LeakyReLu(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same")); + x3 = LeakyReLu(BatchNormalization(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same"))); + x3 = LeakyReLu(BatchNormalization(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same"))); x4 = MaxPool(x3, { 2,2 }, { 2,2 }); - x4 = LeakyReLu(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same")); - x4 = LeakyReLu(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same")); + x4 = LeakyReLu(BatchNormalization(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same"))); + x4 = LeakyReLu(BatchNormalization(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same"))); x5 = MaxPool(x4, { 2,2 }, { 2,2 }); - x5 = LeakyReLu(Conv(x5, 8*depth, { 3,3 }, { 1, 1 }, "same")); - x5 = LeakyReLu(Conv(x5, 8*depth, { 3,3 }, { 1, 1 }, "same")); - x5 = Conv(UpSampling(x5, { 2,2 }), 8*depth, { 
3,3 }, { 1, 1 }, "same"); - //x5 = Conv(UpSampling(x5, { 2,2 }), 8*depth, { 2,2 }, { 1, 1 }, "same"); + x5 = LeakyReLu(BatchNormalization(Conv(x5, 8*depth, { 3,3 }, { 1, 1 }, "same"))); + x5 = LeakyReLu(BatchNormalization(Conv(x5, 8*depth, { 3,3 }, { 1, 1 }, "same"))); + x5 = BatchNormalization(Conv(UpSampling(x5, { 2,2 }), 8*depth, { 3,3 }, { 1, 1 }, "same")); if (USE_CONCAT) x4 = Concat({x4,x5}); else x4 = Sum(x4,x5); - x4 = LeakyReLu(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same")); - x4 = LeakyReLu(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same")); - x4 = Conv(UpSampling(x4, { 2,2 }), 4*depth, { 3,3 }, { 1, 1 }, "same"); - //x4 = Conv(UpSampling(x4, { 2,2 }), 4*depth, { 2,2 }, { 1, 1 }, "same"); + x4 = LeakyReLu(BatchNormalization(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same"))); + x4 = LeakyReLu(BatchNormalization(Conv(x4, 8*depth, { 3,3 }, { 1, 1 }, "same"))); + x4 = BatchNormalization(Conv(UpSampling(x4, { 2,2 }), 4*depth, { 3,3 }, { 1, 1 }, "same")); if (USE_CONCAT) x3 = Concat({x3,x4}); else x3 = Sum(x3,x4); - x3 = LeakyReLu(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same")); - x3 = LeakyReLu(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same")); - //x3 = Conv(UpSampling(x3, { 2,2 }), 2*depth, { 2,2 }, { 1, 1 }, "same"); - x3 = Conv(UpSampling(x3, { 2,2 }), 2*depth, { 3,3 }, { 1, 1 }, "same"); + x3 = LeakyReLu(BatchNormalization(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same"))); + x3 = LeakyReLu(BatchNormalization(Conv(x3, 4*depth, { 3,3 }, { 1, 1 }, "same"))); + x3 = BatchNormalization(Conv(UpSampling(x3, { 2,2 }), 2*depth, { 3,3 }, { 1, 1 }, "same")); if (USE_CONCAT) x2 = Concat({x2,x3}); else x2 = Sum(x2,x3); - x2 = LeakyReLu(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same")); - x2 = LeakyReLu(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same")); - //x2 = Conv(UpSampling(x2, { 2,2 }), depth, { 2,2 }, { 1, 1 }, "same"); - x2 = Conv(UpSampling(x2, { 2,2 }), depth, { 3,3 }, { 1, 1 }, "same"); + x2 = LeakyReLu(BatchNormalization(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same"))); + x2 = 
LeakyReLu(BatchNormalization(Conv(x2, 2*depth, { 3,3 }, { 1, 1 }, "same"))); + x2 = BatchNormalization(Conv(UpSampling(x2, { 2,2 }), depth, { 3,3 }, { 1, 1 }, "same")); if (USE_CONCAT) x = Concat({x,x2}); else x = Sum(x,x2); - x = LeakyReLu(Conv(x, depth, { 3,3 }, { 1, 1 }, "same")); - x = LeakyReLu(Conv(x, depth, { 3,3 }, { 1, 1 }, "same")); - x = Conv(x, 1, { 1,1 }); + x = LeakyReLu(BatchNormalization(Conv(x, depth, { 3,3 }, { 1, 1 }, "same"))); + x = LeakyReLu(BatchNormalization(Conv(x, depth, { 3,3 }, { 1, 1 }, "same"))); + x = BatchNormalization(Conv(x, 1, { 1,1 })); return x; } @@ -93,15 +89,16 @@ int main(int argc, char **argv){ // Settings int epochs = 100000; - int batch_size = 2; + int batch_size = 4; ////////////////////////////////////////////////////////////// // Network for Data Augmentation + layer in1=Input({3,584,584}); layer in2=Input({1,584,584}); layer l=Concat({in1,in2}); // Cat image and mask - l= RandomCropScale(l, {0.9f, 1.0f}); // Random Crop and Scale to orig size + //l= RandomCropScale(l, {0.9f, 1.0f}); // Random Crop and Scale to orig size l= CenteredCrop(l,{512,512}); // Crop to work with sizes power 2 layer img=Select(l,{"0:3"}); // UnCat [0-2] image layer mask=Select(l,{"3"}); // UnCat [3] mask @@ -112,7 +109,7 @@ int main(int argc, char **argv){ // Build model for DA build(danet); - toGPU(danet,"low_mem"); // only in GPU 0 with low_mem setup + toGPU(danet,{1,1},10,"low_mem"); summary(danet); ////////////////////////////////////////////////////////////// @@ -121,28 +118,20 @@ int main(int argc, char **argv){ layer out=Sigmoid(UNetWithPadding(in)); model segnet=Model({in},{out}); build(segnet, - adam(0.00001), // Optimizer - {"mse"}, // Losses - {"mse"}, // Metrics - CS_GPU({1}, "low_mem") -// CS_CPU(-1) + adam(0.001), // Optimizer + {"mse"}, // Losses + {"mse"}, // Metrics + CS_GPU({1,1}, 10, "low_mem") ); - // Train on multi-gpu with sync weights every 100 batches: -// toGPU(segnet,{1},100,"low_mem"); // In two gpus, syncronize every 
100 batches, low_mem setup summary(segnet); plot(segnet,"segnet.pdf"); ////////////////////////////////////////////////////////////// // Load and preprocess training data - cout<<"Reading train numpy\n"; Tensor* x_train = Tensor::load("drive_trX.bin"); - x_train->info(); x_train->div_(255.0f); - //permute - cout<<"Reading test numpy\n"; Tensor* y_train = Tensor::load("drive_trY.bin"); - y_train->info(); y_train->div_(255.0f); Tensor* xbatch = new Tensor({batch_size,3,584,584}); diff --git a/examples/nn/4_NLP/2_nlp_sentiment_gru.cpp b/examples/nn/4_NLP/2_nlp_sentiment_gru.cpp index 362cf4ec6..27f706b9c 100644 --- a/examples/nn/4_NLP/2_nlp_sentiment_gru.cpp +++ b/examples/nn/4_NLP/2_nlp_sentiment_gru.cpp @@ -61,8 +61,8 @@ int main(int argc, char **argv) { optimizer opt = adam(0.001); //opt->set_clip_val(0.01); - compserv cs = CS_CPU(); - //compserv cs = CS_GPU({1}); // one GPU + //compserv cs = CS_CPU(); + compserv cs = CS_GPU({1}); // one GPU //compserv cs = CS_GPU({1,1},100); // two GPU with weight sync every 100 batches // Build model diff --git a/examples/nn/4_NLP/2_nlp_sentiment_lstm.cpp b/examples/nn/4_NLP/2_nlp_sentiment_lstm.cpp index 139285f6b..70b2ae773 100644 --- a/examples/nn/4_NLP/2_nlp_sentiment_lstm.cpp +++ b/examples/nn/4_NLP/2_nlp_sentiment_lstm.cpp @@ -59,8 +59,8 @@ int main(int argc, char **argv) { optimizer opt=adam(0.001); //opt->set_clip_val(0.01); - compserv cs = CS_CPU(); - //compserv cs = CS_GPU({1}); // one GPU + //compserv cs = CS_CPU(); + compserv cs = CS_GPU({1}); // one GPU //compserv cs = CS_GPU({1,1},100); // two GPU with weight sync every 100 batches // Build model diff --git a/examples/nn/4_NLP/3_nlp_machine_translation.cpp b/examples/nn/4_NLP/3_nlp_machine_translation.cpp index 5f7e4f596..53088bf2e 100644 --- a/examples/nn/4_NLP/3_nlp_machine_translation.cpp +++ b/examples/nn/4_NLP/3_nlp_machine_translation.cpp @@ -43,7 +43,7 @@ int main(int argc, char **argv) { download_eutrans(); // Settings - int epochs = 10; + int epochs = 1; 
int batch_size = 32; int ilength=30; diff --git a/examples/nn/4_NLP/4_nlp_video_to_labels.cpp b/examples/nn/4_NLP/4_nlp_video_to_labels.cpp index eb5e6f7f8..3562dfb95 100644 --- a/examples/nn/4_NLP/4_nlp_video_to_labels.cpp +++ b/examples/nn/4_NLP/4_nlp_video_to_labels.cpp @@ -41,17 +41,16 @@ int main(int argc, char **argv) { l = ReLu(l); l = Dense(l, 2); layer out = ReLu(l); - model deepVO = Model({in},{out}); + model net = Model({in},{out}); - build(deepVO, + build(net, adam(), {"mse"}, {"mse"}, CS_GPU({1}) -// CS_CPU() ); - plot(deepVO,"model.pdf","TB"); - summary(deepVO); + plot(net,"model.pdf","LR"); + summary(net); // Input: 32 samples that are sequences of 10 3D RGB images of 256x256. Tensor* seqImages = Tensor::randu({32, 10, 3, 10, size, size}); @@ -59,8 +58,9 @@ int main(int argc, char **argv) { // Target: A sequence of 7 samples of 2 values per image Tensor* seqLabels = Tensor::randu({32, 7, 2}); + fit(net, {seqImages}, {seqLabels}, 4, 10); - fit(deepVO, {seqImages}, {seqLabels}, 4, 10); + delete net; return 0; diff --git a/examples/nn/4_NLP/5_nlp_text_generation.cpp b/examples/nn/4_NLP/5_nlp_text_generation.cpp index 419367290..3da03c078 100644 --- a/examples/nn/4_NLP/5_nlp_text_generation.cpp +++ b/examples/nn/4_NLP/5_nlp_text_generation.cpp @@ -35,9 +35,9 @@ layer ResBlock(layer l, int filters,int nconv,int half) { l=ReLu(BatchNormalization(Conv(l,filters,{3,3},{1,1}))); if (half) - return Sum(BatchNormalization(Conv(in,filters,{1,1},{2,2})),l); + return Add(BatchNormalization(Conv(in,filters,{1,1},{2,2})),l); else - return Sum(l,in); + return Add(l,in); } @@ -132,16 +132,16 @@ int main(int argc, char **argv) { // Load dataset Tensor *x_train=Tensor::load("flickr_trX.bin","bin"); - x_train->info(); //1000,256,256,3 + //x_train->info(); //1000,256,256,3 Tensor *xtrain=Tensor::permute(x_train,{0,3,1,2});//1000,3,256,256 Tensor *y_train=Tensor::load("flickr_trY.bin","bin"); - y_train->info(); + //y_train->info(); y_train=onehot(y_train,outvs); 
y_train->reshape_({y_train->shape[0],olength,outvs}); //batch x timesteps x input_dim - y_train->info(); + //y_train->info(); //load(net,"img2text.bin","bin"); diff --git a/examples/onnx/4_onnx_test_gradients.cpp b/examples/onnx/4_onnx_test_gradients.cpp new file mode 100644 index 000000000..239c25213 --- /dev/null +++ b/examples/onnx/4_onnx_test_gradients.cpp @@ -0,0 +1,164 @@ +/* +* EDDL Library - European Distributed Deep Learning Library. +* Version: 0.2 +* copyright (c) 2019, Universidad Politécnica de Valencia (UPV), PRHLT Research Centre +* Date: October 2019 +* Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) +* All rights reserved +*/ + +#include +#include +#include + + +#include "eddl/apis/eddl.h" + +#include "eddl/serialization/onnx/eddl_onnx.h" // Not allowed + +using namespace eddl; + +////////////////////////////////// +// mnist_mlp.cpp: +// A very basic MLP for mnist +// Using fit for training +////////////////////////////////// + + +int main(int argc, char **argv) { + + // Download mnist + download_mnist(); + + // Settings + int epochs = 1; + int batch_size = 100; + int num_classes = 10; + CompServ* export_CS = CS_GPU({1}); + CompServ* import_CS = CS_GPU({1}); + + // Define network + layer in = Input({784}); + layer l = in; // Aux var + + l=Reshape(l,{1,28, 28}); + l=ReLu(Conv(l,32,{3,3},{1,1})); + //l=BatchNormalization(l, true); + l=MaxPool(l,{2,2}); + l=ReLu(Conv(l,32,{3,3},{1,1})); + //l=BatchNormalization(l, true); + l=MaxPool(l,{2,2}); + + l=Reshape(l,{-1}); + + layer out = Activation(Dense(l, num_classes), "softmax"); + + cout << "Creating model" << endl; + model net = Model({in}, {out}); + cout << "Model created" << endl; + + // Build model + cout << "Building the model" << endl; + build(net, + rmsprop(0.01), // Optimizer + {"soft_cross_entropy"}, // Losses + {"categorical_accuracy"}, // Metrics + export_CS, // Computing service + true // Enable parameters initialization + ); + cout << "Model is correctly 
built" << endl; + + cout << "Enabling distributed training" << endl; + net->enable_distributed(); + cout << "Distributed training enabled" << endl; + + // Export the net before training + void* serialized_net; + cout << "Serializing net (without training) to pointer" << endl; + size_t model_size = serialize_net_to_onnx_pointer(net, serialized_net, false); + cout << "Net serialized to pointer" << endl; + + // View model + summary(net); + + // Load dataset + Tensor* x_train = Tensor::load("mnist_trX.bin"); + Tensor* y_train = Tensor::load("mnist_trY.bin"); + Tensor* x_test = Tensor::load("mnist_tsX.bin"); + Tensor* y_test = Tensor::load("mnist_tsY.bin"); + + // Preprocessing + x_train->div_(255.0f); + x_test->div_(255.0f); + + // Train model + cout << "Training the first model" << endl; + fit(net, {x_train}, {y_train}, batch_size, epochs); + + // Evaluate + cout << "Evaluating the first model" << endl; + evaluate(net, {x_test}, {y_test}); + + // Export gradients + void* serialized_gradients; + string path("mnist.onnx"); + cout << "Exporting gradients" << endl; + size_t gradients_size = serialize_net_to_onnx_pointer(net, serialized_gradients, true); + cout << "Gradients exported" << endl; + + // Export trained model + void * serialized_net_once_trained; + cout << "Exporting trained weights" << endl; + size_t snot_size = serialize_net_to_onnx_pointer(net, serialized_net_once_trained, false); + cout << "Trained weights exported" << endl; + + // Reset the counter of the layers index + LConv::reset_name_counter(); + LDense::reset_name_counter(); + + // Import net topology without trained weights + cout << "Importing original net topology (without training)" << endl; + Net* imported_net = import_net_from_onnx_pointer(serialized_net, model_size); + cout << "Untrained net imported" << endl; + + // Build model + cout << "Building the loaded topology" << endl; + build(imported_net, + rmsprop(0.01), // Optimizer + {"soft_cross_entropy"}, // Losses + {"categorical_accuracy"}, // 
Metrics + import_CS, // Computing service + false // Disable parameters initialization + ); + cout << "Model is correctly built" << endl; + + // Resize the net to the desired batch size + imported_net->resize(batch_size); + + // View loaded model + summary(imported_net); + + // Evaluate with untrained model + cout << "Evaluating test with the untrained weights" << endl; + evaluate(imported_net, {x_test}, {y_test}); + + // Apply grads + cout << "Applying grads from training" << endl; + apply_grads_from_onnx_pointer(imported_net, serialized_gradients, gradients_size); + cout << "Grads applied" << endl; + + // Evaluate net with accumulated gradients applied + cout << "Evaluating test after applying gradients" << endl; + evaluate(imported_net, {x_test}, {y_test}); + + // Set trained weights + cout << "Putting the trained weights" << endl; + set_weights_from_onnx_pointer(imported_net, serialized_net_once_trained, snot_size); + cout << "Trained weights set" << endl; + + // Evaluate with trained weights + cout << "Evaluating test after putting the trained weights" << endl; + evaluate(imported_net, {x_test}, {y_test}); + + return 0; +} diff --git a/examples/test_internals/test1.cpp b/examples/test_internals/test1.cpp new file mode 100644 index 000000000..3e6d4601c --- /dev/null +++ b/examples/test_internals/test1.cpp @@ -0,0 +1,151 @@ +/* +* EDDL Library - European Distributed Deep Learning Library. 
+* Version: 0.9 +* copyright (c) 2020, Universidad Politécnica de Valencia (UPV), PRHLT Research Centre +* Date: November 2020 +* Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) +* All rights reserved +*/ + +#include +#include +#include + +#include "eddl/apis/eddl.h" + + +using namespace eddl; + +// Checking deletes, memory leaks +// CNN models, CPU and GPU +// In a separate terminal try +// top/htop and nvidia-smi (GPU) +// to check memory + +layer BG(layer l) { + return GaussianNoise(BatchNormalization(l),0.3); + //return l; +} + +layer ResBlock(layer l, int filters,int nconv,int half) { + layer in=l; + + if (half) + l=ReLu(BG(Conv(l,filters,{3,3},{2,2}))); + else + l=ReLu(BG(Conv(l,filters,{3,3},{1,1}))); + + + for(int i=0;i +#include +#include + +#include "eddl/apis/eddl.h" + + +using namespace eddl; + +// Checking deletes, memory leaks +// RNN models, CPU and GPU +// In a separate terminal try +// top/htop and nvidia-smi (GPU) +// to check memory + +int main(int argc, char **argv){ + + int times_cpu=100; + int times_gpu=100; + + int ilength=30; + int olength=30; + int invs=687; + int outvs=514; + int embedding=64; + + //CPU + for(int i=0;i +#include +#include + +#include "eddl/apis/eddl.h" + + +using namespace eddl; + +// Checking deletes, memory leaks +// CNN decoder models, CPU, GPU +// In a separate terminal try +// top/htop and nvidia-smi (GPU) +// to check memory + +layer ResBlock(layer l, int filters,int nconv,int half) { + layer in=l; + + if (half) + l=ReLu(BatchNormalization(Conv(l,filters,{3,3},{2,2}))); + else + l=ReLu(BatchNormalization(Conv(l,filters,{3,3},{1,1}))); + + + for(int i=0;iset_clip_val(0.01); + + // Build model + build(net, + opt, // Optimizer + {"softmax_cross_entropy"}, // Losses + {"accuracy"}, // Metrics + CS_CPU() + ); + + + // Load dataset + Tensor *x_train=Tensor::zeros({10,3,256,256}); //batch x input_dim + Tensor *y_train=Tensor::zeros({10,olength,outvs}); //batch x timesteps x ouput_dim + + // 
to force unrolling + fit(net, {x_train}, {y_train}, 10, 1); + + + delete x_train; + delete y_train; + delete net; + + } + + //GPU + for(int i=0;iset_clip_val(0.01); + + // Build model + build(net, + opt, // Optimizer + {"softmax_cross_entropy"}, // Losses + {"accuracy"}, // Metrics + CS_GPU({1}) + ); + + + // Load dataset + Tensor *x_train=Tensor::zeros({10,3,256,256}); //batch x input_dim + Tensor *y_train=Tensor::zeros({10,olength,outvs}); //batch x timesteps x ouput_dim + + // to force unrolling + fit(net, {x_train}, {y_train}, 10, 1); + + + delete x_train; + delete y_train; + delete net; + + } +} diff --git a/examples/test_internals/test4.cpp b/examples/test_internals/test4.cpp new file mode 100644 index 000000000..3941fdf2e --- /dev/null +++ b/examples/test_internals/test4.cpp @@ -0,0 +1,120 @@ +/* +* EDDL Library - European Distributed Deep Learning Library. +* Version: 0.9 +* copyright (c) 2020, Universidad Politécnica de Valencia (UPV), PRHLT Research Centre +* Date: November 2020 +* Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) +* All rights reserved +*/ + +#include +#include +#include + +#include "eddl/apis/eddl.h" + + +using namespace eddl; + +// Checking deletes, memory leaks +// CNN 3D Synchronous rnn, CPU, GPU +// In a separate terminal try +// top/htop and nvidia-smi (GPU) +// to check memory + +int main(int argc, char **argv){ + + int times_cpu=10; + int times_gpu=10; + + //CPU + for(int i=0;i &ks, const vector &st, const string& p, int mem=0); PoolDescriptor3D(const vector &ks, const vector &st, const vector &p, int mem=0); diff --git a/include/eddl/distributed/distributed_environment.h b/include/eddl/distributed/distributed_environment.h new file mode 100644 index 000000000..afe879505 --- /dev/null +++ b/include/eddl/distributed/distributed_environment.h @@ -0,0 +1,149 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: August 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __DISTRIBUTED_ENVIRONMENT_H__ +#define __DISTRIBUTED_ENVIRONMENT_H__ 1 + +#include + +#include +#include +#include + +namespace eddl { + +class DistributedEnvironment +{ +public: + DistributedEnvironment() + { + this->master_ip_addr = ""; + this->master_s_addr = 0; + this->tcp_port = base_tcp_port; + this->udp_data_port = base_udp_data_port; + this->udp_ack_port = base_udp_ack_port; + this->my_ip_addr = ""; + this->my_s_addr = 0; + this->set_multicast_group_addr(eddl_multicast_group_addr); + this->verbose_level = 0; + this->batch_size = 10; + + init_message_type_names(); + } + + std::string get_master_ip_addr() + { + if (this->master_s_addr == 0) + throw std::runtime_error(err_msg("master_ip_addr is not set yet!")); + + return this->master_ip_addr; + } + std::string get_my_ip_addr() + { + if (this->my_s_addr == 0) + throw std::runtime_error(err_msg("my_ip_addr not is not set yet!")); + + return this->my_ip_addr; + } + std::string get_multicast_group_addr() + { + if (this->multicast_s_addr == 0) + throw std::runtime_error(err_msg("multicast_group_addr not is not set yet!")); + + return this->multicast_group_addr; + } + void set_master_ip_addr(std::string s) + { + struct in_addr addr; + int rc = inet_aton(s.c_str(), &addr); + if (rc == 1) { + this->master_ip_addr = s; + this->master_s_addr = addr.s_addr; + } else { + throw std::runtime_error(err_msg("invalid ip addr provided: " + s)); + } + } + void set_my_ip_addr(std::string s) + { + struct in_addr addr; + int rc = inet_aton(s.c_str(), &addr); + if (rc == 1) { + this->my_ip_addr = s; + this->my_s_addr = addr.s_addr; + } else { + throw std::runtime_error(err_msg("invalid ip addr provided: " + s)); + } + } + void set_multicast_group_addr(std::string s) + { + struct 
in_addr addr; + int rc = inet_aton(s.c_str(), &addr); + if (rc == 1) { + this->multicast_group_addr = s; + this->multicast_s_addr = addr.s_addr; + } else { + throw std::runtime_error(err_msg("invalid ip addr provided: " + s)); + } + } + in_addr_t get_master_s_addr() + { + if (this->master_s_addr == 0) + throw std::runtime_error(err_msg("master_ip_addr is not set yet!")); + + return this->master_s_addr; + } + in_addr_t get_my_s_addr() + { + if (this->my_s_addr == 0) + throw std::runtime_error(err_msg("my_ip_addr is not set yet!")); + + return this->my_s_addr; + } + in_addr_t get_multicast_s_addr() + { + if (this->multicast_s_addr == 0) + throw std::runtime_error(err_msg("multicast_group_addr is not set yet!")); + + return this->multicast_s_addr; + } + + int get_verbose_level() { return this->verbose_level; } + void set_verbose_level(int v) { this->verbose_level = std::max(0,v); } + void increase_verbose_level() { this->verbose_level++; } + + int get_batch_size() { return this->batch_size; } + void set_batch_size(int bs) { this->batch_size = std::max(1,bs); } + + int get_tcp_port() { return this->tcp_port; } + void set_tcp_port(int port_number) { this->tcp_port = port_number; } + + int get_udp_data_port() { return this->udp_data_port; } + void set_udp_data_port(int port_number) { this->udp_data_port = port_number; } + + int get_udp_ack_port() { return this->udp_ack_port; } + void set_udp_ack_port(int port_number) { this->udp_ack_port = port_number; } + + +private: + std::string master_ip_addr; + in_addr_t master_s_addr; + int tcp_port; + int udp_data_port; + int udp_ack_port; + std::string my_ip_addr; + in_addr_t my_s_addr; + int verbose_level; + int batch_size; + std::string multicast_group_addr; + in_addr_t multicast_s_addr; + +}; // of class DistributedEnvironment +}; // of namespace eddl + +#endif // __DISTRIBUTED_ENVIRONMENT_H__ diff --git a/include/eddl/distributed/eddl_distributed.h b/include/eddl/distributed/eddl_distributed.h new file mode 100644 index 
000000000..fffeb9289 --- /dev/null +++ b/include/eddl/distributed/eddl_distributed.h @@ -0,0 +1,104 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __EDDL_DISTRIBUTED_H__ +#define __EDDL_DISTRIBUTED_H__ 1 + +#include +#include + +namespace eddl { + +enum eddl_thread_status {INACTIVE, RUNNING, STOPPED}; +enum eddl_worker_status { WORKER_WAITING, + WORKER_RUNNING, + WORKER_STOPPING, + WORKER_TO_SHUTDOWN}; +/* + - a worker starts in the WAITING state + - and can move + + from WAITING to RUNNING by means of a command from the master + + from WAITING to TO_SHUTDOWN by means of a command from the master + + from RUNNING to STOPPING by means of a command from the master + + from STOPPING to WAITING after completing and/or aborting pending tasks + + from STOPPING to TO_SHUTDOWN by means of a command from the master + + - when a worker ends its execution, a new worker process can be launched + automatically if the system is configured to do it, otherwise it must be + launched manually. 
+*/ + +enum eddl_message_types {DATA_WEIGHTS= 0x00501, + DATA_GRADIENTS= 0x00502, + DATA_SAMPLES= 0x00504, + PARAMETER= 0x00a01, + COMMAND= 0x00a02, + PKG_ACK= 0x00a04, + MSG_CHKSUM= 0x00a08, +// PKG_CHKSUM= 0x00a10, + MSG_ACK_WEIGHTS= 0x00a21, + MSG_ACK_GRADIENTS=0x00a22, + MSG_ACK_SAMPLES= 0x00a24}; + +enum eddl_command_types {START=0x041, STOP=0x042, SHUTDOWN=0x044}; + +enum eddl_worker_modes {FEDERATED_ML=0x011, // no data is accepted from the master + ONE_MASTER=0x022, // only obey to one master that must be specified + ANY_MASTER=0x044 // worker servers to any master if not busy + }; + +static constexpr int base_tcp_port = 3017; ///< port master node will accept connections from worker nodes: 3x17, where x={0..9} +static constexpr int base_udp_data_port = 3011; ///< port master node will send datagrams to worker nodes +static constexpr int base_udp_ack_port = 3013; ///< port master node will receive acknowledgements from worker nodes + +// see https://www.cisco.com/c/dam/en/us/support/docs/ip/ip-multicast/ipmlt_wp.pdf +static std::string eddl_multicast_group_addr("239.193.111.211"); // campus scope +//static std::string eddl_multicast_group_addr("225.1.1.1"); // testing example + +#define next_multiple(_x_,_y_) ((_y_)*(((_x_)/(_y_))+(((_x_)%(_y_))!=0))) +#define prev_multiple(_x_,_y_) ((_y_)*(((_x_)/(_y_)))) + +static constexpr size_t eddl_alignment = 8; ///< alignment in bytes to allocate memory +static constexpr int listen_max_pending = 50; ///< maximum number of connections pending to be accepted by the master node +static constexpr int eddl_checksum_len = 32; ///< SHA256 algorithm is used, whose output is 256 bits (32 bytes) length +static constexpr size_t eddl_msg_id_len = 19; ///< 19=8+3+8 hexadecimal digits, 8 of the IP address, 3 of the message type and 8 of the timestamp in milliseconds +static constexpr size_t _eddl_msg_id_len_ = next_multiple(eddl_msg_id_len,eddl_alignment); ///< next eight-multiple from 19 +static constexpr size_t eddl_default_mtu 
= 1500; //1536; ///< MTU -- block size for sending/receiving packets +static constexpr size_t eddl_packet_data_size = prev_multiple(eddl_default_mtu + - 4*sizeof(uint32_t) + - _eddl_msg_id_len_ + - 4*sizeof(size_t) + - eddl_checksum_len, 8); + // check this with eddl_packet class definition + +uint64_t get_system_milliseconds(); +std::vector str_split(std::string s, char sep); +std::string get_ip_address(uint32_t s_addr); +std::string pointer_to_string(void * ptr); + +size_t compute_aligned_size(size_t size); +void * eddl_malloc(size_t size); + +std::string compose_log_message(const char * filename, const int line_number, const char * function_name, const char *msg); +std::string compose_log_message(const char * filename, const int line_number, const char * function_name, std::string msg); +#define err_msg(s) compose_log_message(__FILE__,__LINE__,__func__,s) +void print_log_message(const char * filename, const int line_number, const char * function_name, const char *msg); +void print_log_message(const char * filename, const int line_number, const char * function_name, std::string msg); +#define print_log_msg(s) print_log_message(__FILE__,__LINE__,__func__,s) +void print_err_message(const char * filename, const int line_number, const char * function_name, const char *msg); +void print_err_message(const char * filename, const int line_number, const char * function_name, std::string msg); +#define print_err_msg(s) print_err_message(__FILE__,__LINE__,__func__,s) + +void init_message_type_names(); +std::string get_message_type_name(int value); +void show_all_message_type_names(); + +}; + +#endif // __EDDL_DISTRIBUTED_H__ diff --git a/include/eddl/distributed/eddl_message.h b/include/eddl/distributed/eddl_message.h new file mode 100644 index 000000000..0430510f3 --- /dev/null +++ b/include/eddl/distributed/eddl_message.h @@ -0,0 +1,121 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __EDDL_MESSAGE_H__ +#define __EDDL_MESSAGE_H__ 1 + +#include + +#include +#include + +namespace eddl { + +class eddl_message +{ +public: + eddl_message(uint32_t type, + uint32_t source_addr, + uint32_t target_addr, + size_t message_data_size, + size_t packet_data_size, + void * data ); + eddl_message(eddl_packet * packet); + + ~eddl_message(); + + void set_data(size_t message_data_size, void * data); + void set_source_addr(uint32_t source_addr); + void set_target_addr(uint32_t target_addr); + void set_message_id(char * message_id = nullptr); + std::string & get_message_id() { return message_id; } + std::string get_acknowledged_message_id(); + uint32_t get_acknowledged_message_type(); + + inline uint32_t get_type() { return type; } + inline uint32_t get_source_addr() { return source_addr; } + inline uint32_t get_target_addr() { return target_addr; } + inline uint64_t get_timestamp() { return timestamp; } + inline size_t get_seq_len() { return seq_len; } + inline bool is_complete() { return 0 == pending_packets && checksum_has_been_set; } + inline bool was_checksum_already_set() { return checksum_has_been_set; } + inline uint32_t get_message_data_size() { return message_data_size; } + inline std::string get_checksum() { return std::string((char *)checksum, eddl_checksum_len); } + inline unsigned char * get_checksum_ptr() { return checksum; } + inline void * get_data() { return data; } + std::string get_ip_address(); + + uint32_t get_command(); + + void compute_checksum(); + bool is_checksum_valid(); + + void set_checksum(unsigned char * checksum); + void add_packet(eddl_packet * packet); + bool was_packet_already_added(size_t seq_no); + eddl_packet * get_packet(size_t packet_index); + eddl_packet * 
create_packet_for_checksum(); + + eddl_message * create_acknowledgement(); + + static eddl_message * start_command(uint32_t target_addr) + { + return new eddl_message(eddl_message_types::COMMAND, + 0, target_addr, + 0, eddl_command_types::START, + nullptr); + } + static eddl_message * stop_command(uint32_t target_addr) + { + return new eddl_message(eddl_message_types::COMMAND, + 0, target_addr, + 0, eddl_command_types::STOP, + nullptr); + } + static eddl_message * shutdown_command(uint32_t target_addr) + { + return new eddl_message(eddl_message_types::COMMAND, + 0, target_addr, + 0, eddl_command_types::SHUTDOWN, + nullptr); + } + static eddl_message * acknowledgement(eddl_packet * packet, uint32_t my_s_addr) + { + size_t data[2]; + data[0] = packet->get_seq_no(); + data[1] = packet->get_type(); + + return new eddl_message(eddl_message_types::PKG_ACK, + my_s_addr, + packet->get_source_addr(), // acknowledgement must be sent to the packet emisor + sizeof(data), + sizeof(data), + &data); + } + +private: + uint32_t type; + uint32_t source_addr; + uint32_t target_addr; + uint64_t timestamp; + size_t seq_len; + size_t message_data_size; + size_t packet_data_size; + std::string message_id; + unsigned char checksum[eddl_checksum_len]; + unsigned char * data; + + size_t pending_packets; + bool * received_packet; + bool checksum_has_been_set; +}; + +}; + +#endif // __EDDL_MESSAGE_H__ diff --git a/include/eddl/distributed/eddl_message_acks.h b/include/eddl/distributed/eddl_message_acks.h new file mode 100644 index 000000000..d0317932a --- /dev/null +++ b/include/eddl/distributed/eddl_message_acks.h @@ -0,0 +1,50 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __EDDL_MESSAGE_ACKS_H__ +#define __EDDL_MESSAGE_ACKS_H__ 1 + +#include +#include + +#include +#include +#include + +namespace eddl { + +class eddl_message_acks +{ +public: + eddl_message_acks(std::vector & workers, + eddl_message * message); + ~eddl_message_acks(); + + void acknowledge(uint32_t source_addr, size_t seq_no); + void acknowledge_whole_message(uint32_t source_addr); + bool all_has_been_acknowledged(); + bool packet_already_acknowledged(size_t seq_no); + bool lasting_too_much_time(); + ssize_t get_pending_acknowledgements(); + +private: + std::map acks; + size_t living_workers; + size_t num_acks_per_worker; + size_t total_num_acks; + size_t ack_counter; + + size_t * packet_counters; + + uint64_t starting_timestamp; +}; + +}; + +#endif // __EDDL_MESSAGE_ACKS_H__ diff --git a/include/eddl/distributed/eddl_packet.h b/include/eddl/distributed/eddl_packet.h new file mode 100644 index 000000000..3e2696726 --- /dev/null +++ b/include/eddl/distributed/eddl_packet.h @@ -0,0 +1,85 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __EDDL_PACKET_H__ +#define __EDDL_PACKET_H__ 1 + +#include + +#include +#include + +namespace eddl { + +class eddl_packet +{ +public: + eddl_packet(uint32_t type, + uint32_t source_addr, + uint32_t target_addr, + std::string & message_id, + size_t seq_no, + size_t seq_len, + size_t message_size, + uint32_t all_but_last_packet_size, + size_t data_size, + void * data); + eddl_packet(uint32_t type, + uint32_t source_addr, + uint32_t target_addr, + std::string & message_id, + size_t seq_no, + size_t seq_len, + uint32_t command); + ~eddl_packet(); + + inline uint32_t get_type() { return type; } + inline uint32_t get_source_addr() { return source_addr; } + inline uint32_t get_target_addr() { return target_addr; } + inline size_t get_seq_no() { return seq_no; } + inline size_t get_seq_len() { return seq_len; } + inline void * get_data() { return data; } + + uint32_t get_command(); + + inline unsigned char * get_checksum_ptr() { return checksum; } + inline std::string get_checksum() { return std::string((char *)checksum, eddl_checksum_len); } + std::string get_ip_address(); + + std::string get_message_id() { return std::string(message_id, eddl_msg_id_len); } + inline char * get_message_id_ptr() { return message_id; } + // returns the size in bytes of the whole message this packet belongs to + inline size_t get_message_size() { return message_size; } + // return the size in bytes of the data contained in this packet + inline size_t get_data_size() { return data_size; } + // return the size in bytes of the all the packets of the same message but the last one + inline size_t get_all_but_last_packet_size() { return all_but_last_packet_size; } + + void compute_checksum(); + bool is_checksum_valid(); + + eddl_packet_ack * 
create_acknowledgement(uint32_t worker_addr); + +private: + uint32_t type; + uint32_t source_addr; + uint32_t target_addr; + uint32_t all_but_last_packet_size; + char message_id[_eddl_msg_id_len_]; + size_t seq_no; + size_t seq_len; + size_t message_size; + size_t data_size; // must be less than or equal to eddl_packet_data_size + unsigned char checksum[eddl_checksum_len]; + unsigned char data[eddl_packet_data_size]; +}; + +}; + +#endif // __EDDL_PACKET_H__ diff --git a/include/eddl/distributed/eddl_packet_ack.h b/include/eddl/distributed/eddl_packet_ack.h new file mode 100644 index 000000000..58478b116 --- /dev/null +++ b/include/eddl/distributed/eddl_packet_ack.h @@ -0,0 +1,48 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: August 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __EDDL_PACKET_ACK_H__ +#define __EDDL_PACKET_ACK_H__ 1 + +#include +#include + +#include + +namespace eddl { + +class eddl_packet_ack +{ +public: + eddl_packet_ack(uint32_t source_addr, + uint32_t seq_no, + char * message_id) : + source_addr(source_addr), + seq_no(seq_no) + { + memset(this->message_id, 0, sizeof(this->message_id)); + strncpy(this->message_id, message_id, eddl_msg_id_len); + } + + ~eddl_packet_ack() + { + } + + inline uint32_t get_source_addr() { return this->source_addr; } + inline size_t get_seq_no() { return this->seq_no; } + std::string get_message_id() { return std::string(this->message_id, eddl_msg_id_len); } + inline char * get_message_id_ptr() { return this->message_id; } + +private: + uint32_t source_addr; + uint32_t seq_no; + char message_id[_eddl_msg_id_len_]; +}; +}; +#endif // __EDDL_PACKET_ACK_H__ diff --git a/include/eddl/distributed/eddl_queue.h b/include/eddl/distributed/eddl_queue.h new file mode 100644 index 000000000..b746aba2a --- 
/dev/null +++ b/include/eddl/distributed/eddl_queue.h @@ -0,0 +1,123 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __EDDL_QUEUE_H__ +#define __EDDL_QUEUE_H__ 1 + +#include +#include +#include + +#include + +namespace eddl { + +class eddl_queue +{ +public: + eddl_queue() {} + + ~eddl_queue() + { + clear(); + } + + void clear() + { + // Critical region starts + std::unique_lock lck(mutex_queue); + + while (! q.empty()) { + eddl_message *m = q.front(); + q.pop(); + delete m; + } + // Critical region ends + } + + void push(eddl_message * message) + { + // Critical region starts + std::unique_lock lck(mutex_queue); + q.push(message); + cond_var.notify_one(); + // Critical region ends + } + + void push_front(eddl_message * message) + { + // Critical region starts + std::unique_lock lck(mutex_queue); + q.push(message); + for (auto i = q.size(); i > 0; i--) { + q.push(q.front()); + q.pop(); + } + cond_var.notify_one(); + // Critical region ends + } + + eddl_message * front() + { + eddl_message * message = nullptr; + // Critical region starts + std::unique_lock lck(mutex_queue); + + if ( q.empty() ) cond_var.wait(lck); + + if ( ! q.empty() ) { + message = q.front(); + } + + return message; + // Critical region ends + } + + eddl_message * pop() + { + eddl_message * message = nullptr; + // Critical region starts + std::unique_lock lck(mutex_queue); + + if ( q.empty() ) cond_var.wait(lck); + + if ( ! 
q.empty() ) { + message = q.front(); + q.pop(); + } + + return message; + // Critical region ends + } + + size_t size() + { + // Critical region starts + std::unique_lock lck(mutex_queue); + return q.size(); + // Critical region ends + } + + bool empty() + { + // Critical region starts + std::unique_lock lck(mutex_queue); + return q.empty(); + // Critical region ends + } + +private: + std::queue q; // the actual queue + std::mutex mutex_queue; + std::condition_variable cond_var; +}; + +}; + +#endif // __EDDL_QUEUE_H__ diff --git a/include/eddl/distributed/eddl_worker_node.h b/include/eddl/distributed/eddl_worker_node.h new file mode 100644 index 000000000..6aa626614 --- /dev/null +++ b/include/eddl/distributed/eddl_worker_node.h @@ -0,0 +1,75 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __EDDL_WORKER_NODE_H__ +#define __EDDL_WORKER_NODE_H__ 1 + +#include +#include +#include + +#include + +namespace eddl { + +class eddl_worker_node +{ +public: + eddl_worker_node(std::string description); + + inline int get_cpu_cores() { return cpu_cores; } + inline int get_cpu_mem() { return cpu_cores; } + inline int get_gpu_cards() { return gpu_cards; } + inline std::string get_gpu_mem() { return gpu_mem_mode; } + inline int get_fpga_cards() { return fpga_cards; } + inline int get_fpga_mem() { return fpga_mem; } + inline int get_batch_size() { return batch_size; } + + inline void set_batch_size(int b) { batch_size=b; } + + std::string get_ip_address(); + inline uint32_t get_s_addr() { return s_addr; } + + inline bool is_active() { return active; } + void activate() { active=true; } + void deactivate() { active=false; } + + +private: + std::string hostname_or_ip_address; + uint32_t s_addr; + int cpu_cores; + int 
cpu_mem; // in megas + int gpu_cards; + std::string gpu_mem_mode; // "low_mem", "mid_mem", "full_mem" + int fpga_cards; + int fpga_mem; // in megas + int batch_size; + bool active; + + std::string data_subset; // short description or identifier of the subset assigned to the worker node + + std::queue gradient_timestamps; +}; + +}; + +/* + * computing service example + + ip:192.168.13.11;cpu:2,8192;gpu:1,low_mem;fpga:0,0;batch_size:10; + + this line describes a worker node whose ip address is 192.168.13.11, + from which this task will use 2 cores assuming 8 GB is the total RAM of + the computer, one GPU in low_mem mode will be used, 0 FPGAs are available + with 0 MB of memory, and the batch_size used in the work node by the + train_batch() method will be of 10 samples. + */ + +#endif // __EDDL_WORKER_NODE_H__ diff --git a/include/eddl/distributed/multicast_receiver.h b/include/eddl/distributed/multicast_receiver.h new file mode 100644 index 000000000..dd46a47ae --- /dev/null +++ b/include/eddl/distributed/multicast_receiver.h @@ -0,0 +1,62 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __MulticastReceiver_H__ +#define __MulticastReceiver_H__ 1 + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace eddl { + +class MulticastReceiver +{ +public: + MulticastReceiver(eddl_queue & input_queue, + eddl_queue & ack_queue, + eddl_queue & output_queue, + DistributedEnvironment & distributed_environment); + ~MulticastReceiver(); + + void stop(); + void receiver(); + void send_ack(eddl_packet_ack * ack); + +private: + eddl_queue & input_queue; + eddl_queue & ack_queue; + eddl_queue & output_queue; + DistributedEnvironment & distributed_environment; + + int socket_fd_in; // input socket file descriptor + int socket_fd_out; // output socket file descriptor + int port_number_in; // input port number to use + int port_number_out; // output port number to use + + uint32_t my_s_addr; + + bool receiver_active; + std::thread receiver_thread; + + std::map active_messages; + std::map recently_received_messages; +}; + +}; + +#endif // __MulticastReceiver_H__ diff --git a/include/eddl/distributed/multicast_sender.h b/include/eddl/distributed/multicast_sender.h new file mode 100644 index 000000000..b09ddbc8f --- /dev/null +++ b/include/eddl/distributed/multicast_sender.h @@ -0,0 +1,73 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: August 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __MULTICAST_SENDER_H__ +#define __MULTICAST_SENDER_H__ 1 + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace eddl { + +class MulticastSender +{ +public: + MulticastSender(std::vector & workers, + eddl_queue & output_queue, + eddl_queue & ack_queue, + DistributedEnvironment & distributed_environment); + ~MulticastSender(); + + void stop(); + void sender(); + void ack_processor(); + bool send_message(eddl_message * message); + +private: + std::vector & workers; + eddl_queue & output_queue; + eddl_queue & ack_queue; + DistributedEnvironment & distributed_environment; +/* + std::string multicast_group_addr; + int port_number_out; // input port number to use + int port_number_in; // input port number to use +*/ + int socket_fd_in; // input socket file descriptor + int socket_fd_out; // output socket file descriptor + + struct sockaddr_in target_group_addr; + + bool sender_active; + std::thread sender_thread; + std::thread ack_processor_thread; + std::mutex ack_processor_mutex; + + std::map active_acknowledgements; +}; // class MulticastSender + +}; // namespace eddl + +#endif // __MULTICAST_SENDER_H__ diff --git a/include/eddl/distributed/tcp_receiver.h b/include/eddl/distributed/tcp_receiver.h new file mode 100644 index 000000000..b882c9cc8 --- /dev/null +++ b/include/eddl/distributed/tcp_receiver.h @@ -0,0 +1,86 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __TCP_RECEIVER_H__ +#define __TCP_RECEIVER_H__ 1 + +#include +#include +#include +#include + +#include +#include +#include + +namespace eddl { + +class TCP_Receiver +{ +private: + class ActiveThread + { + public: + ActiveThread(int socket_fd, + eddl_queue & input_queue, + eddl_queue & weights_ack_queue, + eddl_queue & generic_ack_queue, + eddl_queue & output_queue, + TCP_Receiver * tcp_receiver); + ~ActiveThread(); + inline eddl_thread_status get_status() { return this->status; } + inline void stop() { this->status=STOPPED; } + inline void disable() { this->status=INACTIVE; } + void join() { this->thread->join(); } + eddl_message * receive_message(); + void thread_receiver(); + + private: + std::thread * thread; + eddl_thread_status status; + int socket_fd; + eddl_queue & input_queue; + eddl_queue & weights_ack_queue; + eddl_queue & generic_ack_queue; + eddl_queue & output_queue; + TCP_Receiver * tcp_receiver; + }; + +public: + TCP_Receiver( eddl_queue & input_queue, + eddl_queue & weights_ack_queue, + eddl_queue & generic_ack_queue, + eddl_queue & output_queue, + DistributedEnvironment & distributed_environment); + ~TCP_Receiver(); + + void stop(); + void drop_stopped(); + void joiner(); + void acceptor(); + +private: + eddl_queue & input_queue; + eddl_queue & weights_ack_queue; + eddl_queue & generic_ack_queue; + eddl_queue & output_queue; + DistributedEnvironment & distributed_environment; + + int socket_fd; // socket file descriptor + bool receiver_active; + std::thread joiner_thread; + std::thread acceptor_thread; + + std::queue active_threads; + std::mutex mutex_active_threads; +}; + +}; + +#endif // __TCP_RECEIVER_H__ diff --git a/include/eddl/distributed/tcp_sender.h b/include/eddl/distributed/tcp_sender.h new 
file mode 100644 index 000000000..e41c13960 --- /dev/null +++ b/include/eddl/distributed/tcp_sender.h @@ -0,0 +1,59 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#ifndef __TCP_SENDER_H__ +#define __TCP_SENDER_H__ 1 + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace eddl { + +class TCP_Sender +{ +public: + TCP_Sender(eddl_queue & output_queue, + eddl_queue & ack_queue, + DistributedEnvironment & distributed_environment); + ~TCP_Sender(); + + void stop(); + void sender(); + bool send_message(eddl_message * msg); + void manage_to_send_message(eddl_message * msg); + + void change_status_to(int new_status); + +private: + eddl_queue & output_queue; + eddl_queue & ack_queue; + DistributedEnvironment & distributed_environment; + bool sender_active; + std::thread sender_thread; + eddl_queue queue_of_pending_messages; + std::map sent_messages; + int sender_status; + uint64_t timestamp_last_status_change; + + static constexpr int NORMAL_OPERATION=0; + static constexpr int FAILED_TO_CONNECT=1; + static constexpr int FAILED_TO_WRITE=2; +}; + +}; + +#endif // __TCP_SENDER_H__ diff --git a/include/eddl/hardware/gpu/gpu_hw.h b/include/eddl/hardware/gpu/gpu_hw.h index 8e6cece1d..86290692c 100644 --- a/include/eddl/hardware/gpu/gpu_hw.h +++ b/include/eddl/hardware/gpu/gpu_hw.h @@ -241,16 +241,4 @@ int* get_block_dim(int N, int blockSize); void copy_cpu2gpu(float* cpu_addresses, float* gpu_addresses, int size, bool delete_cpu); void gpu_initialize_rd(ReduceDescriptor2 *rd, Tensor *A, Tensor *B, bool reverse=false); -// new batchnorm implementation -void gpu_batchnorm_forward(int gpu_device, int b, int z, int rc, - float *input, float *output, float *opa, - float 
*global_mean, float *global_variance, - float *affine_g, float *affine_b, - float *mean, float *variance, - bool trmode, float epsilon, float momentum); - -void gpu_batchnorm_backward(int gpu_device, int b, int z, int rc, - float *delta, float *opa, float *pdelta, float *gbn_g, float *gbn_b, - float *bn_g, float *variance, float *mean1, float *mean2); - #endif //EDDL_GPU_HW_H diff --git a/include/eddl/hardware/gpu/gpu_kernels.h b/include/eddl/hardware/gpu/gpu_kernels.h index f6eff5f7c..a8f5504af 100644 --- a/include/eddl/hardware/gpu/gpu_kernels.h +++ b/include/eddl/hardware/gpu/gpu_kernels.h @@ -195,14 +195,4 @@ __global__ void gpu_not_equal(float *A, float *B, float *C, long int size); __global__ void mask(float *A, float v, long int size); -// new batchnorm implementation -const int batch_norm_block_size = 256; -__global__ void gpu_batchnorm_forward_1(int b, int rc, int rcz, float *input, float *mean, float *variance); -__global__ void gpu_batchnorm_forward_2(int z, float inv_N, float *mean, float *variance, float momentum, float *global_mean, float *global_variance, float epsilon); -__global__ void gpu_batchnorm_forward_3(int b, int rc, int rcz, float *input, float *mean, float *variance, float *affine_g, float *affine_b, float *opa, float *output); - -__global__ void gpu_batchnorm_backward_1(int b, int rc, int rcz, float *delta, float *opa, float *bn_g, float *mean1, float *mean2); -__global__ void gpu_batchnorm_backward_2(int z, float inv_N, float *mean1, float *mean2, float *gbn_g, float *gbn_b, float *bn_g); -__global__ void gpu_batchnorm_backward_3(int b, int rc, int rcz, float *delta, float *opa, float *pdelta, float *mean1, float *mean2, float *variance); - #endif diff --git a/include/eddl/hardware/gpu/nn/gpu_tensor_nn.h b/include/eddl/hardware/gpu/nn/gpu_tensor_nn.h index 4e9fd9c39..a1db5286d 100644 --- a/include/eddl/hardware/gpu/nn/gpu_tensor_nn.h +++ b/include/eddl/hardware/gpu/nn/gpu_tensor_nn.h @@ -117,4 +117,16 @@ void 
gpu_permute_channels_last(Tensor *A,Tensor *B); void gpu_permute_batch_first(Tensor *A,Tensor *B); void gpu_permute_batch_last(Tensor *A,Tensor *B); +// new batchnorm implementation +void gpu_batchnorm_forward(int gpu_device, int b, int z, int rc, + float *input, float *output, float *opa, + float *global_mean, float *global_variance, + float *affine_g, float *affine_b, + float *mean, float *variance, + bool trmode, float epsilon, float momentum); + +void gpu_batchnorm_backward(int gpu_device, int b, int z, int rc, + float *delta, float *opa, float *pdelta, float *gbn_g, float *gbn_b, + float *bn_g, float *variance, float *mean1, float *mean2); + #endif //EDDL_GPU_TENSOR_NN_H diff --git a/include/eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h b/include/eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h index 498032702..56541141e 100644 --- a/include/eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h +++ b/include/eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h @@ -108,6 +108,14 @@ __global__ void bn_permute_channels_last(float *src, float *dest,int b,int z,int __global__ void bn_permute_batch_first(float *src, float *dest,int b,int z,int r,int c,long int size); __global__ void bn_permute_batch_last(float *src, float *dest,int b,int z,int r,int c,long int size); - +// new batchnorm implementation +const int batch_norm_block_size = 256; +__global__ void gpu_batchnorm_forward_1(int b, int rc, int rcz, float *input, float *mean, float *variance); +__global__ void gpu_batchnorm_forward_2(int z, float inv_N, float *mean, float *variance, float momentum, float *global_mean, float *global_variance, float epsilon); +__global__ void gpu_batchnorm_forward_3(int b, int rc, int rcz, float *input, float *mean, float *variance, float *affine_g, float *affine_b, float *opa, float *output); + +__global__ void gpu_batchnorm_backward_1(int b, int rc, int rcz, float *delta, float *opa, float *bn_g, float *mean1, float *mean2); +__global__ void gpu_batchnorm_backward_2(int z, float inv_N, float *mean1, 
float *mean2, float *gbn_g, float *gbn_b, float *bn_g); +__global__ void gpu_batchnorm_backward_3(int b, int rc, int rcz, float *delta, float *opa, float *pdelta, float *mean1, float *mean2, float *variance); #endif diff --git a/include/eddl/layers/conv/layer_conv.h b/include/eddl/layers/conv/layer_conv.h index fe58481ca..50f65e24b 100644 --- a/include/eddl/layers/conv/layer_conv.h +++ b/include/eddl/layers/conv/layer_conv.h @@ -30,7 +30,6 @@ using namespace std; class LConv : public LinLayer { public: static int total_layers; - bool distributed_training; ConvolDescriptor *cd; diff --git a/include/eddl/layers/core/layer_core.h b/include/eddl/layers/core/layer_core.h index 4c0f4aba3..b194a64f1 100644 --- a/include/eddl/layers/core/layer_core.h +++ b/include/eddl/layers/core/layer_core.h @@ -113,7 +113,6 @@ class LDense : public LinLayer { static int total_layers; int ndim; bool use_bias; // TODO: Implement - bool distributed_training; // Params Tensor *W; diff --git a/include/eddl/layers/layer.h b/include/eddl/layers/layer.h index 5dd482b43..8c9c44bcd 100644 --- a/include/eddl/layers/layer.h +++ b/include/eddl/layers/layer.h @@ -45,6 +45,7 @@ class Layer { bool iscloned; bool isnorm; bool isdecoder; + bool distributed_training; vector params; vector gradients; diff --git a/include/eddl/net/compserv.h b/include/eddl/net/compserv.h index 4e0fd80e8..93b9a4bd6 100644 --- a/include/eddl/net/compserv.h +++ b/include/eddl/net/compserv.h @@ -19,8 +19,8 @@ using namespace std; class CompServ { public: - string type; - + string type; // "local" or "distributed" + string hw; //CPU, GPU, FPGA int threads_arg; // The value passed to the constructor int local_threads; @@ -28,6 +28,7 @@ class CompServ { vector local_fpgas; int lsb; //local sync batches bool isshared; + @@ -41,7 +42,7 @@ class CompServ { CompServ(); CompServ * share(); - + CompServ * clone(); // for local CompServ(int threads, const vector g, const vector &f,int lsb=1, int mem=0); diff --git 
a/include/eddl/net/net.h b/include/eddl/net/net.h index bb2f5aa08..d12c05ff3 100644 --- a/include/eddl/net/net.h +++ b/include/eddl/net/net.h @@ -148,6 +148,8 @@ class Net { void reset_accumulated_gradients(); void apply_accumulated_gradients(); + void collect_acc_grads(); + void distribute_weights(); void sync_weights(); // API diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt new file mode 100644 index 000000000..1290e5e17 --- /dev/null +++ b/runtime/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.9.2) + +project(eddl-examples) + + +# RUNTIME EXECUTABLES FOR AD HOC DISTRIBUTED VERSION: REQUIRES ONNX +if(BUILD_PROTOBUF) + add_executable(master "distributed/master.cpp") + target_link_libraries(master eddl) + + add_executable(worker "distributed/worker.cpp") + target_link_libraries(worker eddl) + + add_executable(misc_info "distributed/misc_info.cpp") + target_link_libraries(misc_info eddl) +endif() diff --git a/runtime/distributed/master.cpp b/runtime/distributed/master.cpp new file mode 100644 index 000000000..258775872 --- /dev/null +++ b/runtime/distributed/master.cpp @@ -0,0 +1,223 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +eddl::eddl_queue * global_input_queue; + +void handler_funtion(int parameter) +{ + switch (parameter) { + case SIGUSR1: + global_input_queue->push(eddl::eddl_message::start_command(0)); + break; + case SIGHUP: + global_input_queue->push(eddl::eddl_message::stop_command(0)); + break; + case SIGINT: // also generated from keyboard by CTRL+C in Unix systems + case SIGUSR2: + case SIGTERM: + global_input_queue->push(eddl::eddl_message::shutdown_command(0)); + break; + } + eddl::print_log_msg("signal caught " + std::to_string(parameter)); +} + +int main(int argc, char *argv[]) +{ + eddl::DistributedEnvironment distributed_environment; + distributed_environment.set_my_ip_addr("10.81.25.6"); // socrates.vpn + eddl::eddl_queue 
input_queue; + eddl::eddl_queue generic_output_queue; + eddl::eddl_queue generic_ack_queue; + eddl::eddl_queue weights_output_queue; + eddl::eddl_queue weights_ack_queue; + std::vector workers; + + for (int i=0; i < argc; i++) { + if (! strcmp(argv[i], "--my-ip-addr")) { + distributed_environment.set_my_ip_addr(argv[++i]); + } else if (! strcmp(argv[i], "--tcp-port")) { + distributed_environment.set_tcp_port(atoi(argv[++i])); + /* + } else if (! strncmp(argv[i], "--mode=", 7)) { + std::vector parts = eddl::str_split(argv[i],'='); + if (parts[1] == "federated_ml") + worker_mode = eddl::eddl_worker_modes::FEDERATED_ML; + else if (parts[1] == "one_master") + worker_mode = eddl::eddl_worker_modes::ONE_MASTER; + else if (parts[1] == "any_master") + worker_mode = eddl::eddl_worker_modes::ANY_MASTER; + else + throw std::runtime_error(eddl::err_msg("unrecognized worker mode")); + */ + } else if (! strcmp(argv[i], "--multicast-group-addr")) { + distributed_environment.set_multicast_group_addr(argv[++i]); + } else if (! strncmp(argv[i], "--verbose=", 10)) { + std::vector parts = eddl::str_split(argv[i],'='); + distributed_environment.set_verbose_level(std::stoi(parts[1])); + } else if (! strcmp(argv[i], "--verbose")) { + distributed_environment.increase_verbose_level(); + } + } + + workers.push_back(new eddl::eddl_worker_node("ip:10.81.25.1;cpu:4,8192;gpu:0,low_mem;fpga:0,0;batch_size:10")); + + /* + tcp_receiver pushes: + 1. weights acknowledgements from workers into the weights_ack_queue + 2. other acknowledgements from workers into the generic_ack_queue + 3. acknowledgements created by the master when a data message is complete into the generic_output_queue + 4. messages which are not acknowledgements into the input_queue + */ + eddl::TCP_Receiver tcp_receiver( input_queue, + weights_ack_queue, + generic_ack_queue, + generic_output_queue, + distributed_environment); + /* + tcp_sender pops: + 1. output messages from the generic_output_queue + 2. 
other acknowledgements sent by workers from the generic_ack_queue + + */ + eddl::TCP_Sender tcp_sender( generic_output_queue, + generic_ack_queue, + distributed_environment); + /* + multicast_sender pops: + 1. output weight data messages from the weights_output_queue + 2. weights acknowledgements sent by workers from the weights_ack_queue + */ + eddl::MulticastSender multicast_sender( workers, + weights_output_queue, + weights_ack_queue, + distributed_environment); + + //////////////////////////////////////////////////////////////// + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::mt19937_64 generator_1(seed); // mt19937 is a standard mersenne_twister_engine + std::mt19937_64 generator_2(seed + 1); // mt19937 is a standard mersenne_twister_engine + std::mt19937_64 generator_3(seed + 2); // mt19937 is a standard mersenne_twister_engine + std::bernoulli_distribution bernoulli(0.3); + std::uniform_int_distribution dist_sizes(100,50*1024*1024); + std::uniform_int_distribution dist_content(0,255); + //////////////////////////////////////////////////////////////// + + //////////////////////////////////////////////////////////////// + global_input_queue = & input_queue; + // setting sigaction structure: begin + static struct sigaction _action; + _action.sa_sigaction = nullptr; + _action.sa_handler = handler_funtion; + for (unsigned long int i=0; i < _SIGSET_NWORDS; i++) + _action.sa_mask.__val[i] = 0; + _action.sa_flags = 0; + // setting sigaction structure: end + sigaction(SIGHUP, &_action, nullptr); + sigaction(SIGINT, &_action, nullptr); + sigaction(SIGTERM, &_action, nullptr); + sigaction(SIGUSR1, &_action, nullptr); + sigaction(SIGUSR2, &_action, nullptr); + //////////////////////////////////////////////////////////////// + + //////////////////////////////////////////////////////////////// + eddl::print_log_msg("ready to receive messages"); + + bool master_active = true; + bool shutdown_master = false; + + for (auto w: workers) + 
generic_output_queue.push(eddl::eddl_message::start_command(w->get_s_addr())); + + while (! shutdown_master) { + eddl::eddl_message * msg = input_queue.pop(); + if (nullptr == msg) continue; + + std::cout << "received message: " + << std::hex << msg->get_message_id() << " " + << std::hex << msg->get_type() << " " + << std::dec << msg->get_timestamp() << " " + << std::dec << msg->get_message_data_size() + << " bytes" << std::endl; + + if (msg->get_type() == eddl::eddl_message_types::COMMAND) { + switch (msg->get_command()) { + case eddl::eddl_command_types::START: + master_active = true; + for (auto w: workers) + generic_output_queue.push(eddl::eddl_message::start_command(w->get_s_addr())); + break; + case eddl::eddl_command_types::STOP: + master_active = false; + for (auto w: workers) + generic_output_queue.push(eddl::eddl_message::stop_command(w->get_s_addr())); + break; + case eddl::eddl_command_types::SHUTDOWN: + master_active = false; + for (auto w: workers) + generic_output_queue.push(eddl::eddl_message::stop_command(w->get_s_addr())); + shutdown_master = true; + generic_output_queue.clear(); + weights_output_queue.clear(); + // first send a shutdown command to ensure tcp receivers change status to stop + for (auto w: workers) + generic_output_queue.push(eddl::eddl_message::shutdown_command(w->get_s_addr())); + // a first packet to myself to make my tcp receiver be aware of stopping + generic_output_queue.push(eddl::eddl_message::shutdown_command(distributed_environment.get_my_s_addr())); + // a packet to unlock multicast receiver threads + weights_output_queue.push(eddl::eddl_message::shutdown_command(0)); + // second send a shutdown command to unlock the acceptor thread of tcp receivers + for (auto w: workers) + generic_output_queue.push(eddl::eddl_message::shutdown_command(w->get_s_addr())); + // a second packet to myself to stop my tcp receiver + generic_output_queue.push(eddl::eddl_message::shutdown_command(distributed_environment.get_my_s_addr())); + // 
wait a little bit + std::this_thread::sleep_for(std::chrono::seconds(1)); + tcp_sender.stop(); + tcp_receiver.stop(); + multicast_sender.stop(); + break; + } + } + + delete msg; + + if (master_active && weights_output_queue.size() < 10 && bernoulli(generator_1)) { + size_t size_in_bytes = dist_sizes(generator_2); + unsigned char * data = new unsigned char [size_in_bytes]; + for (size_t i=0; i < size_in_bytes; i++) data[i] = (unsigned char)dist_content(generator_3); + + msg = new eddl::eddl_message(eddl::eddl_message_types::DATA_WEIGHTS, + 0, 0, + size_in_bytes, eddl::eddl_packet_data_size, data); + + delete [] data; + weights_output_queue.push(msg); + } + + std::cout << " |input_queue| = " << input_queue.size() + << " |generic_output_queue| = " << generic_output_queue.size() + << " |weights_output_queue| = " << weights_output_queue.size() + << " |generic_ack_queue| = " << generic_ack_queue.size() + << " |weights_ack_queue| = " << weights_ack_queue.size() + << std::endl; + } + + for (auto w : workers) delete w; + workers.clear(); + + eddl::print_log_msg("master main thread ready to finish when threads stop"); + + return 0; +} diff --git a/runtime/distributed/misc_info.cpp b/runtime/distributed/misc_info.cpp new file mode 100644 index 000000000..53e93d493 --- /dev/null +++ b/runtime/distributed/misc_info.cpp @@ -0,0 +1,51 @@ + +#include +#include +#include +#include + +#include +#include + +using namespace eddl; + +int main(int argc, char *argv[]) +{ + std::cout + << " sizeof(eddl_packet) = " << sizeof(eddl_packet) << std::endl + << " sizeof(eddl_message) = " << sizeof(eddl_message) << std::endl + << " eddl_default_mtu = " << eddl_default_mtu << std::endl + << " eddl_packet_data_size = " << eddl_packet_data_size << std::endl + << " eddl_checksum_len = " << eddl_checksum_len << std::endl + << " eddl_msg_id_len = " << eddl_msg_id_len << std::endl + << " _eddl_msg_id_len_ = " << _eddl_msg_id_len_ << std::endl + << " eddl_alignment = " << eddl_alignment << std::endl + << 
" listen_max_pending = " << listen_max_pending << std::endl + << " base_tcp_port = " << base_tcp_port << std::endl + << " base_udp_data_port = " << base_udp_data_port << std::endl + << " base_udp_ack_port = " << base_udp_ack_port << std::endl + << std::endl; + + char buffer[1024*1024]; + eddl_message * message = new eddl_message(DATA_GRADIENTS, 0, 0, 1024*1024, eddl_packet_data_size, buffer); + + std::cout + << " message->get_message_data_size() = " << message->get_message_data_size() << std::endl + << " message->get_seq_len() = " << message->get_seq_len() << std::endl + << std::endl; + + eddl_packet * packet = message->get_packet(0); + + std::cout + << " packet->get_message_size() = " << packet->get_message_size() << std::endl + << " packet->get_data_size() = " << packet->get_data_size() << std::endl + << std::endl; + + delete packet; + delete message; + + init_message_type_names(); + show_all_message_type_names(); + + return EXIT_SUCCESS; +} diff --git a/runtime/distributed/worker.cpp b/runtime/distributed/worker.cpp new file mode 100644 index 000000000..21b2cff4d --- /dev/null +++ b/runtime/distributed/worker.cpp @@ -0,0 +1,178 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +int main(int argc, char *argv[]) +{ + eddl::DistributedEnvironment distributed_environment; + eddl::eddl_worker_modes worker_mode = eddl::eddl_worker_modes::ANY_MASTER; + + // manual settings for testing + distributed_environment.set_my_ip_addr("10.81.25.1"); // platon.vpn + distributed_environment.set_master_ip_addr("10.81.25.6"); // socrates.vpn + + for (int i=0; i < argc; i++) { + if (! strcmp(argv[i], "--my-ip-addr")) { + distributed_environment.set_my_ip_addr(argv[++i]); + } else if (! strcmp(argv[i], "--server")) { + distributed_environment.set_master_ip_addr(argv[++i]); + } else if (! strcmp(argv[i], "--tcp-port")) { + distributed_environment.set_tcp_port(atoi(argv[++i])); + } else if (! 
strncmp(argv[i], "--mode=", 7)) { + std::vector parts = eddl::str_split(argv[i],'='); + if (parts[1] == "federated_ml") + worker_mode = eddl::eddl_worker_modes::FEDERATED_ML; + else if (parts[1] == "one_master") + worker_mode = eddl::eddl_worker_modes::ONE_MASTER; + else if (parts[1] == "any_master") + worker_mode = eddl::eddl_worker_modes::ANY_MASTER; + else + throw std::runtime_error(eddl::err_msg("unrecognized worker mode")); + + } else if (! strcmp(argv[i], "--multicast-group-addr")) { + distributed_environment.set_multicast_group_addr(argv[++i]); + } else if (! strncmp(argv[i], "--verbose=", 10)) { + std::vector parts = eddl::str_split(argv[i],'='); + distributed_environment.set_verbose_level(std::stoi(parts[1])); + } else if (! strcmp(argv[i], "--verbose")) { + distributed_environment.increase_verbose_level(); + } + } + + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::mt19937_64 generator_1(seed); // mt19937 is a standard mersenne_twister_engine + std::mt19937_64 generator_2(seed + 3); // mt19937 is a standard mersenne_twister_engine + std::uniform_int_distribution dist_sizes(100,50*1024*1024); + std::uniform_int_distribution dist_content(0,255); + + eddl::eddl_queue input_queue; + eddl::eddl_queue output_queue; + eddl::eddl_queue ack_queue; + + eddl::TCP_Receiver tcp_receiver( input_queue, + ack_queue, + ack_queue, + output_queue, + distributed_environment); + eddl::TCP_Sender tcp_sender( output_queue, + ack_queue, + distributed_environment); + eddl::MulticastReceiver multicast_receiver( input_queue, + ack_queue, + output_queue, + distributed_environment); + + eddl::eddl_message * message = nullptr; + eddl::eddl_worker_status worker_status = eddl::eddl_worker_status::WORKER_WAITING; + int seconds_to_wait_while_waiting = 1; + int iterations_waiting = 0; + + while (worker_status != eddl::eddl_worker_status::WORKER_TO_SHUTDOWN) { + + // independently of the status the input queue must be processed + if (! 
input_queue.empty()) { + message = input_queue.pop(); + std::cout << "received message: " + << std::hex << message->get_message_id() << " " + << std::hex << message->get_type() << " " + << std::dec << message->get_timestamp() << " " + << std::dec << message->get_message_data_size() + << " bytes" << std::endl; + + if (message->get_type() == eddl::eddl_message_types::COMMAND) { + switch (message->get_command()) { + case eddl::eddl_command_types::START: + if (worker_status == eddl::eddl_worker_status::WORKER_WAITING) { + worker_status = eddl::eddl_worker_status::WORKER_RUNNING; + seconds_to_wait_while_waiting = 1; + iterations_waiting = 0; + } + break; + case eddl::eddl_command_types::STOP: + if (worker_status == eddl::eddl_worker_status::WORKER_RUNNING) { + worker_status = eddl::eddl_worker_status::WORKER_STOPPING; + } + break; + case eddl::eddl_command_types::SHUTDOWN: + worker_status = eddl::eddl_worker_status::WORKER_TO_SHUTDOWN; + break; + } + } + delete message; + } + + switch (worker_status) { + case eddl::eddl_worker_status::WORKER_RUNNING: + if (output_queue.size() < 10) { + size_t size_in_bytes = dist_sizes(generator_1); + unsigned char * data = new unsigned char [size_in_bytes]; + for (size_t i=0; i < size_in_bytes; i++) + data[i] = (unsigned char)dist_content(generator_1); + + message = new eddl::eddl_message( + eddl::eddl_message_types::DATA_GRADIENTS, + 0, // source addr will be set by the sender thread + distributed_environment.get_master_s_addr(), + size_in_bytes, + eddl::eddl_packet_data_size, + data); + delete [] data; + output_queue.push(message); + } else { + /* + sleep 1 second to avoid having too many messages in the + output queue; this should be reviewed in the real + worker implementation + */ + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + break; + + case eddl::eddl_worker_status::WORKER_STOPPING: + // does not perform new actions like sending new messages + if (output_queue.empty()) + worker_status = 
eddl::eddl_worker_status::WORKER_WAITING; + else + std::this_thread::sleep_for(std::chrono::seconds(1)); + break; + + case eddl::eddl_worker_status::WORKER_TO_SHUTDOWN: + output_queue.clear(); + tcp_sender.stop(); + tcp_receiver.stop(); + multicast_receiver.stop(); + break; + + case eddl::eddl_worker_status::WORKER_WAITING: + std::cout << "worker inactive waiting for " + << seconds_to_wait_while_waiting + << " second(s)." << std::endl; + std::this_thread::sleep_for(std::chrono::seconds(seconds_to_wait_while_waiting)); + if (++iterations_waiting >= 10) + seconds_to_wait_while_waiting = 10; + break; + } // of switch + + if (distributed_environment.get_verbose_level() >= 1) + std::cout << " |input_queue| = " << input_queue.size() + << " |output_queue| = " << output_queue.size() + << " |ack_queue| = " << ack_queue.size() + << std::endl; + } // of while worker_status + + eddl::print_log_msg("worker main thread ready to finish when threads stop"); + + return EXIT_SUCCESS; +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6825bce9c..e6229c21c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,7 +8,7 @@ set(BUILD_TARGET "CUDNN" CACHE STRING "Compile library for {CPU, GPU, CUDNN, FPG option(OVERWRITE_PROTO_FILES "Overwrite Protobuf files (requires a compatible Protobuf compiler)" ON) # Double checks (restricted args) -set_property(CACHE BUILD_TARGET PROPERTY STRINGS CPU GPU FPGA) +set_property(CACHE BUILD_TARGET PROPERTY STRINGS CPU GPU CUDNN FPGA) # Initializations (Local's scope) SET(USE_OPENMP OFF) @@ -35,9 +35,16 @@ set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options a ########################################################################### -############################### SANITY CHECKS ############################## +############################### SANITY CHECKS ############################# ########################################################################### +# Prefer static libraries +# There are problems 
since not all static libraries have been compile with -fPIC +#IF(WIN32) +# SET(CMAKE_FIND_LIBRARY_SUFFIXES .lib .a ${CMAKE_FIND_LIBRARY_SUFFIXES}) +#ELSE(WIN32) +# SET(CMAKE_FIND_LIBRARY_SUFFIXES .a ${CMAKE_FIND_LIBRARY_SUFFIXES}) +#ENDIF(WIN32) ########################################################################### ############################### GET FILES ################################# @@ -52,6 +59,12 @@ file(GLOB_RECURSE CPP_SOURCES "${PROJECT_SOURCE_DIR}/src/*" *.{cc, cpp}) list(FILTER CPP_HEADERS EXCLUDE REGEX ".*/src/serialization/onnx/onnx.pb*") list(FILTER CPP_SOURCES EXCLUDE REGEX ".*/src/serialization/onnx/onnx.pb*") +# Remove problematic files if they are not needed +if(NOT BUILD_DIST) + list(FILTER PUBLIC_HEADERS EXCLUDE REGEX ".*/src/distributed/*") + list(FILTER CPP_HEADERS EXCLUDE REGEX ".*/src/distributed/*") + list(FILTER CPP_SOURCES EXCLUDE REGEX ".*/src/distributed/*") +endif() SET(ALL_FILES ${PUBLIC_HEADERS} ${CPP_HEADERS} ${CPP_SOURCES}) @@ -75,6 +88,7 @@ FOREACH(item ${SPECIAL_FILES}) ENDFOREACH(item) + ########################################################################### ############################# SET LIBRARY ################################# ########################################################################### @@ -111,17 +125,22 @@ find_package(Eigen3 3.3 REQUIRED NO_MODULE) # EIGEN_DIR => ok target_include_directories(${PROJECT_NAME} PRIVATE ${EIGEN3_INCLUDE_DIRS}) target_link_libraries(${PROJECT_NAME} PUBLIC Eigen3::Eigen) # Header only library - # ONNX files if(BUILD_PROTOBUF) add_definitions(-DcPROTO) + # Link library - if(NOT Protobuf_ROOT) - find_package(Protobuf REQUIRED) # Problems with: Protobuf_ROOT + if(Protobuf_ROOT) + # Find libraries (need absolute paths) + find_library(Protobuf_LIBRARY NAMES protobuf libprotobuf HINTS ${Protobuf_ROOT} PATHS ${Protobuf_ROOT} PATH_SUFFIXES "lib" "lib64") + find_library(Protobuf_LIBRARY_DEBUG NAMES protobufd libprotobufd HINTS ${Protobuf_ROOT} PATHS ${Protobuf_ROOT} 
PATH_SUFFIXES "lib" "lib64") + find_library(Protobuf_LIBRARY_RELEASE NAMES protobuf libprotobuf HINTS ${Protobuf_ROOT} PATHS ${Protobuf_ROOT} PATH_SUFFIXES "lib" "lib64") + else() + find_package(Protobuf) # Problems with: Protobuf_ROOT # Check if Protobuf was really found if(NOT Protobuf_FOUND) - message(FATAL_ERROR "Protobuf was found by CMake but its libraries or includes are missing. + message(FATAL_ERROR "Protobuf was found by CMake. Use '-D BUILD_SUPERBUILD=ON', or try with a different Protobuf installation to fix this problem. Alternatively, you can disable it with '-D BUILD_PROTOBUF=OFF") endif() @@ -133,13 +152,12 @@ if(BUILD_PROTOBUF) # Add includes target_include_directories(${PROJECT_NAME} PUBLIC $) - # Find libraries (need absolute paths) - find_library(Protobuf_LIBRARY NAMES protobuf HINTS ${Protobuf_ROOT} PATHS ${Protobuf_ROOT} PATH_SUFFIXES "lib" "lib64") - find_library(Protobuf_LIBRARY_DEBUG NAMES protobufd HINTS ${Protobuf_ROOT} PATHS ${Protobuf_ROOT} PATH_SUFFIXES "lib" "lib64") - find_library(Protobuf_LIBRARY_RELEASE NAMES protobuf HINTS ${Protobuf_ROOT} PATHS ${Protobuf_ROOT} PATH_SUFFIXES "lib" "lib64") - # Add libraries - target_link_libraries(${PROJECT_NAME} PUBLIC ${Protobuf_LIBRARY}) + if(MSVC) + target_link_libraries(${PROJECT_NAME} PUBLIC optimized ${Protobuf_LIBRARY} debug ${Protobuf_LIBRARY_DEBUG}) + else() + target_link_libraries(${PROJECT_NAME} PUBLIC ${Protobuf_LIBRARY}) + endif() # Create "onnx.pb.cc" and "onnx.pb.h" files (from serialization/onnx) # Equivalent to: /usr/local/bin/protoc --cpp_out . 
onnx.proto @@ -191,6 +209,8 @@ SET(USE_OPENMP ${USE_OPENMP} PARENT_SCOPE) # Parent's scope # CUDA if(USE_CUDA) + cmake_minimum_required(VERSION 3.17.2) # Due to CUDAToolkit + # Check if cuda is available include(CheckLanguage) check_language(CUDA) @@ -230,15 +250,27 @@ if(USE_CUDA) # Add includes target_include_directories(${PROJECT_NAME} PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) -# # Find libraries (need absolute paths) -# find_library(CUBLAS_LIBRARY cublas HINTS ${CUDA_TOOLKIT_ROOT_DIR}) -# find_library(CUDART_LIBRARY cudart HINTS ${CUDA_TOOLKIT_ROOT_DIR}) -# find_library(CURAND_LIBRARY curand HINTS ${CUDA_TOOLKIT_ROOT_DIR}) -# target_link_libraries(${PROJECT_NAME} PRIVATE ${CUBLAS_LIBRARY} ${CUDART_LIBRARY} ${CURAND_LIBRARY}) + # Add libraries target_link_libraries(${PROJECT_NAME} PRIVATE CUDA::cublas CUDA::cudart CUDA::curand) if(USE_CUDNN) - find_library(CUDNN_LIBRARY cudnn HINTS ${CUDAToolkit_LIBRARY_DIR}) - target_link_libraries(${PROJECT_NAME} PRIVATE ${CUDNN_LIBRARY}) + if(CUDNN_ROOT_DIR) + SET(CUDNN_INCLUDE_DIRS ${CUDNN_ROOT_DIR}/include) + find_library(CUDNN_LIBRARY cudnn HINTS ${CUDNN_ROOT_DIR} PATHS ${CUDNN_ROOT_DIR} PATH_SUFFIXES "lib" "lib64") + else() + SET(CUDNN_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS}) + find_library(CUDNN_LIBRARY cudnn HINTS ${CUDAToolkit_LIBRARY_DIR} PATHS ${CUDAToolkit_LIBRARY_DIR}) + endif() + + # Check if the library has been found + if(CUDNN_LIBRARY) + target_include_directories(${PROJECT_NAME} PRIVATE ${CUDNN_INCLUDE_DIRS}) + target_link_libraries(${PROJECT_NAME} PRIVATE ${CUDNN_LIBRARY}) + else() + message(WARNING "[WARNING] CUDNN was not found but requested during compilation. 
(Falling back to: '-D BUILD_TARGET=GPU') + Hint: Install CUDNN in the same path as the CUDA Toolkit, or specify the CUDNN path using this flag '-D CUDNN_ROOT_DIR=path'") + SET(BUILD_TARGET "GPU") # Local's scope + SET(USE_CUDNN OFF) # Local's scope (disable) + endif() endif() if(APPLE) @@ -293,6 +325,40 @@ endif() SET(USE_FPGA ${USE_FPGA} PARENT_SCOPE) # Parent's scope +if(BUILD_DIST) + # ZLIB + if(DEFINED ZLIB_ROOT AND DEFINED ZLIB_INCLUDE_DIRS) + find_library(ZLIB_LIBRARIES z HINTS ${ZLIB_ROOT} PATHS ${ZLIB_ROOT} PATH_SUFFIXES "lib" "lib64") + else() + find_package(ZLIB) + + # Check if ZLIB was really found + if(NOT ZLIB_FOUND) + message(FATAL_ERROR "ZLIB was not found by CMake. + Use '-D BUILD_SUPERBUILD=ON', or try with a different ZLIB installation to fix this problem.") + endif() + endif() + target_include_directories(${PROJECT_NAME} PUBLIC $) + target_link_libraries(${PROJECT_NAME} PUBLIC ${ZLIB_LIBRARIES}) + + # OPENSSL + if(DEFINED OPENSSL_ROOT_DIR AND DEFINED OPENSSL_INCLUDE_DIR) + find_library(OPENSSL_SSL_LIBRARY ssl HINTS ${OPENSSL_ROOT_DIR} PATHS ${OPENSSL_ROOT_DIR} PATH_SUFFIXES "lib" "lib64") + find_library(OPENSSL_CRYPTO_LIBRARY crypto HINTS ${OPENSSL_ROOT_DIR} PATHS ${OPENSSL_ROOT_DIR} PATH_SUFFIXES "lib" "lib64") + SET(OPENSSL_LIBRARIES ${OPENSSL_SSL_LIBRARY} ${OPENSSL_CRYPTO_LIBRARY}) + else() + find_package(OpenSSL) + + # Check if ZLIB was really found + if(NOT OPENSSL_FOUND) + message(FATAL_ERROR "OpenSSL was not found by CMake. 
+ Use '-D BUILD_SUPERBUILD=ON', or try with a different OpenSSL installation to fix this problem.") + endif() + endif() + target_include_directories(${PROJECT_NAME} PUBLIC $) + target_link_libraries(${PROJECT_NAME} PUBLIC ${OPENSSL_LIBRARIES}) +endif() + ########################################################################### ################################## WINDOWS ################################ ########################################################################### @@ -326,6 +392,7 @@ if(MSVC) endif() endif() + ########################################################################## ############################ INSTALLATION ################################ ########################################################################## @@ -361,7 +428,9 @@ message(STATUS "Build shared libs: " ${BUILD_SHARED_LIBS} ) message(STATUS "Build coverage: " ${BUILD_COVERAGE} ) message(STATUS "Build sanitizers: " ${BUILD_SANITIZERS} ) message(STATUS "Build HPC: " ${BUILD_HPC} ) -message(STATUS "Use superbuild: " ${BUILD_SUPERBUILD} ) +message(STATUS "Build distributed: " ${BUILD_DIST} ) +message(STATUS "-------------------------------------------" ) +message(STATUS "Find library suffixes: " ${CMAKE_FIND_LIBRARY_SUFFIXES} ) message(STATUS "-------------------------------------------" ) message(STATUS "C++ compiler: ${CMAKE_CXX_COMPILER_ID} (${CMAKE_CXX_COMPILER}) | Version: ${CMAKE_CXX_COMPILER_VERSION}") message(STATUS "C++ flags: " ${CMAKE_CXX_FLAGS}) @@ -389,6 +458,8 @@ endif() message(STATUS "-------------------------------------------" ) message(STATUS "CUDNN enabled: " ${USE_CUDNN} ) if(USE_CUDNN) + message(STATUS "CuDNN root dir: ${CUDNN_ROOT_DIR}") + message(STATUS "CuDNN include dir: ${CUDNN_INCLUDE_DIRS}") message(STATUS "CuDNN libraries: ${CUDNN_LIBRARY}") endif() message(STATUS "-------------------------------------------" ) @@ -411,11 +482,18 @@ if(BUILD_PROTOBUF) # message(STATUS "Protobuf libraries (release): " ${Protobuf_LIBRARY_RELEASE} ) 
message(STATUS "Protobuf compiler: " ${Protobuf_PROTOC_EXECUTABLE} ) endif() -#if(WIN32) -#message(STATUS "-------------------------------------------" ) -#message(STATUS "Pthreads dir: " ${PTHREADS_INSTALL_PATH} ) -#message(STATUS "Pthreads include: " ${PTHREADS_INSTALL_PATH}/include ) -#message(STATUS "Pthreads libraries: " ${PTHREADS_INSTALL_PATH}/lib ) -#endif() +if(BUILD_DIST) + message(STATUS "-------------------------------------------" ) + message(STATUS "ZLIB root: " ${ZLIB_ROOT} ) + message(STATUS "ZLIB include: " ${ZLIB_INCLUDE_DIRS} ) + message(STATUS "ZLIB libraries: " ${ZLIB_LIBRARIES} ) + message(STATUS "-------------------------------------------" ) + message(STATUS "OpenSSL root: " ${OPENSSL_ROOT_DIR} ) + message(STATUS "OpenSSL include: " ${OPENSSL_INCLUDE_DIR} ) + message(STATUS "OpenSSL SSL library: " ${OPENSSL_SSL_LIBRARY} ) + message(STATUS "OpenSSL crypto library: " ${OPENSSL_CRYPTO_LIBRARY} ) + message(STATUS "OpenSSL version: " ${OPENSSL_VERSION} ) +endif() +message(STATUS "-------------------------------------------" ) message(STATUS "===========================================" ) message(STATUS "===========================================" ) diff --git a/src/descriptors/descriptor_conv3D.cpp b/src/descriptors/descriptor_conv3D.cpp index 3aea30841..d5d14cc50 100644 --- a/src/descriptors/descriptor_conv3D.cpp +++ b/src/descriptors/descriptor_conv3D.cpp @@ -63,6 +63,8 @@ ConvolDescriptor3D::~ConvolDescriptor3D(){ eddl_free(ptrI); } #ifdef cGPU +#ifndef cCUDNN + else if (O->isGPU()) { if (mem_level>1) { @@ -77,6 +79,7 @@ ConvolDescriptor3D::~ConvolDescriptor3D(){ } } #endif +#endif } @@ -144,6 +147,19 @@ void ConvolDescriptor3D::build(Tensor *A) { pad = {padd[0], padd[1], padr[0], padr[1], padc[0], padc[1]}; // (front, back), (top, bottom), (left, right) } +#ifdef cCUDNN + if(!A->isCPU()){ + if(pad[0] != pad[1] || pad[2] != pad[3] || pad[4] != pad[5]){ + std::cout<<"Warning: asymmetric padding not supported by cuDNN... fixing ... 
potential shapes mismatch later"<isGPU()) { +#ifndef cCUDNN if (mem_level>1) { // Lowering gpuIB=new Tensor(vector{d*r*c,kz*kd*kc*kr}, I->device); @@ -183,7 +200,7 @@ void ConvolDescriptor3D::build(Tensor *A) { if (mem_level==0) gpuOB=new Tensor(vector{z,A->shape[0]*d*r*c}, I->device); } - +#endif // Tensor with variable shared ptr, delete create ptr gpuI=new Tensor(vector{d*r*c,kd*kz*kc*kr}, I->device); gpu_delete_tensor(gpuI->gpu_device,gpuI->ptr); @@ -201,6 +218,48 @@ void ConvolDescriptor3D::build(Tensor *A) { gpugK=new Tensor(vector{z, kz*kd*kc*kr}, I->device); gpu_delete_tensor(gpuI->gpu_device,gpugK->ptr); } +#ifdef cCUDNN + //CUDNN + convolution_mode = CUDNN_CONVOLUTION; //CUDNN_CROSS_CORRELATION + data_type = CUDNN_DATA_FLOAT; + tensor_format = CUDNN_TENSOR_NCHW; // CUDNN_TENSOR_NHWC + + cudnnCreateConvolutionDescriptor(&convolution_descriptor); + int padding[3] = {pad[0],pad[2],pad[4]}; + int strides[3] ={sd,sr,sc}; + int dilats[3] = {1,1,1}; + cudnnSetConvolutionNdDescriptor(convolution_descriptor,3, + padding, + strides, + dilats, + convolution_mode, data_type); + + + cudnnCreateTensorDescriptor(&xDesc); + int dims[5] = {in, iz, id, ir, ic}; + int str[5] = {iz*id*ir*ic,id*ir*ic,ir*ic,ic,1}; + cudnnSetTensorNdDescriptor(xDesc, /*tensor_format,*/ data_type,5,dims,str); + + int ydims[5] = {in,z,d,r,c}; + int ystr[5] = {z*d*r*c, d*r*c, r*c, c, 1}; + cudnnCreateTensorDescriptor(&yDesc); + cudnnSetTensorNdDescriptor(yDesc,/* tensor_format,*/ data_type, 5, ydims, ystr); + + int bdims[5] = {1,z,1,1,1}; + int bstr[5] = {z, 1, 1, 1, 1}; + cudnnCreateTensorDescriptor(&bDesc); + cudnnSetTensorNdDescriptor(bDesc,/* tensor_format,*/ data_type, 5, bdims, bstr); + + int fdims[5] = {nk, kz, kd, kr, kc}; + // int fstr[5] = {kz*kd*kr*kc,kd*kr*kc,kr*kc,kc,1}; + cudnnCreateFilterDescriptor(&wDesc); + cudnnSetFilterNdDescriptor(wDesc, data_type, tensor_format, 5, fdims); + + cudnn_env_init = -1; + cudnn_conv_back_init = -1; + +#endif + #endif #ifdef cFPGA @@ -232,12 +291,29 
@@ void ConvolDescriptor3D::resize(int b) } #ifdef cGPU else if (I->isGPU()) { +#ifndef cCUDNN if (mem_level<2) gpuIB->resize(b*d*r*c); if (mem_level==0) { delete gpuOB; gpuOB=new Tensor(vector{z,b*d*r*c}, I->device); } +#endif + +#ifdef cCUDNN + int dims[5] = {b, iz, id, ir, ic}; + int str[5] = {iz*id*ir*ic,id*ir*ic,ir*ic,ic,1}; + cudnnSetTensorNdDescriptor(xDesc, /*tensor_format,*/ data_type,5,dims,str); + + int ydims[5] = {b,z,d,r,c}; + int ystr[5] = {z*d*r*c, d*r*c, r*c, c, 1}; + cudnnSetTensorNdDescriptor(yDesc, /*tensor_format,*/ data_type, 5, ydims, ystr); + + //cudnnSetTensor4dDescriptor(yDesc, tensor_format, data_type, O->shape[0], O->shape[1],O->shape[2],O->shape[3]); + + + +#endif } #endif diff --git a/src/descriptors/descriptor_pool.cpp b/src/descriptors/descriptor_pool.cpp index 2be102e0c..5ce04382e 100644 --- a/src/descriptors/descriptor_pool.cpp +++ b/src/descriptors/descriptor_pool.cpp @@ -144,14 +144,12 @@ void PoolDescriptor::resize(int b) { this->O->resize(b); #ifdef cCUDNN - #ifdef cCUDNN + if(!this->O->isCPU()){ cudnnSetTensor4dDescriptor(xDesc, tensor_format, data_type, b,iz,ir,ic); - cudnnCreateTensorDescriptor(&yDesc); cudnnSetTensor4dDescriptor(yDesc, tensor_format, data_type, O->shape[0], O->shape[1],O->shape[2],O->shape[3]); - -#endif +} #endif // if (!mem_level) { D->resize(b); } } diff --git a/src/descriptors/descriptor_pool3D.cpp b/src/descriptors/descriptor_pool3D.cpp index 4d57cd8a9..238c38730 100644 --- a/src/descriptors/descriptor_pool3D.cpp +++ b/src/descriptors/descriptor_pool3D.cpp @@ -59,7 +59,8 @@ void PoolDescriptor3D::build(Tensor *A) { sd = stride[0]; sr = stride[1]; sc = stride[2]; - + + in = A->shape[0]; iz = A->shape[1]; id = A->shape[2]; ir = A->shape[3]; @@ -91,6 +92,16 @@ void PoolDescriptor3D::build(Tensor *A) { paddf = pad[0]; paddb = pad[1]; // depth: front-top padrt = pad[0]; padrb = pad[1]; // rows: top-bottom padcl = pad[2]; padcr = pad[3]; // cols: left-right +#ifdef cCUDNN + if(!A->isCPU()){ + if(pad[0] != 
pad[1] || pad[2] != pad[3] || pad[4] != pad[5]){ + std::cout<<"Warning: asymmetric padding not supported by cuDNN... fixing ... potential shapes mismatch later"<shape[0]) return; O->resize(b); +#ifdef cCUDNN + if(!I->isCPU()){ + int dims[5] = {b, iz, id, ir, ic}; + int str[5] = {iz*id*ir*ic,id*ir*ic,ir*ic,ic,1}; + cudnnSetTensorNdDescriptor(xDesc, /*tensor_format,*/ data_type,5,dims,str); + + int ydims[5] = {b,z,d,r,c}; + int ystr[5] = {z*d*r*c, d*r*c, r*c, c, 1}; + cudnnSetTensorNdDescriptor(yDesc, /*tensor_format,*/ data_type, 5, ydims, ystr); + + //cudnnSetTensor4dDescriptor(yDesc, tensor_format, data_type, O->shape[0], O->shape[1],O->shape[2],O->shape[3]); +} + + +#endif + // if (!mem_level) { D->resize(b); } } diff --git a/src/distributed/README.md b/src/distributed/README.md new file mode 100644 index 000000000..e39f5a6dc --- /dev/null +++ b/src/distributed/README.md @@ -0,0 +1,31 @@ +

+ EDDL +

+ +![build](https://github.com/deephealthproject/eddl/workflows/build/badge.svg) +[![Documentation Status](https://readthedocs.org/projects/ansicolortags/badge/?version=latest)](https://deephealthproject.github.io/eddl/) +![GitHub release (latest by date)](https://img.shields.io/github/v/release/deephealthproject/eddl) +![GitHub](https://img.shields.io/github/license/deephealthproject/eddl) + +This specific section presents and describes the communication system of EDDL developed *ad hoc* to run training procedures in a distributed way. + +## Notice + +**This is a very draft version, not ready to be used or tested** + + +## EDDL communication system + +### The following figures give a general overview of the communication system: + +- Diagram to illustrate report of gradients from worker nodes to the master node and update weights from master node to worker nodes in the EDDL (used in the proposal): ![Distributed Learning flow proposed for the EDDL](images/hybrid-graph-2.svg.png) +- Diagram of the communication system for the EDDL: ![Diagram of the communication system for the EDDL](images/EDDL-distributed-schema.png) +- Flowchart of the Master node: ![Flowchart of the Master node](images/Master-Node.png) +- Flowchart of Worker nodes: ![Flowchart of Worker nodes](images/Worker-Node.png) +- Timeline of network parameters interchange between master node and worker nodes: ![Timeline of network parameters interchange between master node and worker nodes](images/Timeline-in-master-and-worker-nodes.png) + + +## Progress and coverage + +**Not available yet for the communication system** + diff --git a/src/distributed/communications/eddl_distributed_utils.cpp b/src/distributed/communications/eddl_distributed_utils.cpp new file mode 100644 index 000000000..0b0f1e03a --- /dev/null +++ b/src/distributed/communications/eddl_distributed_utils.cpp @@ -0,0 +1,154 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include +#include + +#include + +#include +#include + +namespace eddl { + +uint64_t get_system_milliseconds() +{ + std::chrono::system_clock::time_point just_now = std::chrono::system_clock::now(); + std::chrono::system_clock::duration since_epoch = just_now.time_since_epoch(); + uint64_t msec = std::chrono::duration_cast(since_epoch).count(); + + return msec; +} +size_t compute_aligned_size(size_t size) +{ + return size + ((eddl_alignment - (size % eddl_alignment)) % eddl_alignment); +} + +void * eddl_malloc(size_t size) +{ + // the memory allocated by this method must be released by using the free() system call + void *ptr = nullptr; + int rc = posix_memalign(&ptr, eddl_alignment, compute_aligned_size(size)); + if (rc != 0) + throw std::runtime_error(err_msg("error allocating memory.")); + + return ptr; +} + +std::vector str_split(std::string s, char sep) +{ + std::vector v; + + std::string part=""; + + size_t i = 0; + while (i < s.size()) { + + size_t j = s.find(sep, i); + + if (j==i) { + v.push_back(std::string("")); + i+=1; + } else if (j > s.size()) { // end of string reached + v.push_back(std::string(s.substr(i,s.size()-i))); + i=s.size(); + } else { + v.push_back(std::string(s.substr(i,j-i))); + i=j+1; + } + } + + return v; +} + +std::string get_ip_address(uint32_t s_addr) +{ + unsigned char *ptr = (unsigned char *) &s_addr; + unsigned int i; + + std::string s1, s2, s3, s4; + + i = ptr[0]; s1 = std::to_string(i); + i = ptr[1]; s2 = std::to_string(i); + i = ptr[2]; s3 = std::to_string(i); + i = ptr[3]; s4 = std::to_string(i); + + return s1 + "." + s2 + "." + s3 + "." 
+ s4; +} +std::string pointer_to_string(void * ptr) +{ + std::stringstream buff; + buff << ptr; + return buff.str(); +} + +std::string compose_log_message(const char * filename, const int line_number, const char * function_name, const char * msg) +{ + std::stringstream buff; + buff << filename; + buff << ":"; + buff << std::to_string(line_number); + buff << ":"; + buff << function_name; + buff << ": "; + buff << msg; + return buff.str(); +} +std::string compose_log_message(const char * filename, const int line_number, const char * function_name, std::string msg) +{ + std::stringstream buff; + buff << filename << ":" << std::to_string(line_number) + ":" << function_name << ": " << msg; + return buff.str(); +} +void print_log_message(const char * filename, const int line_number, const char * function_name, const char * msg) +{ + std::cout << compose_log_message(filename, line_number, function_name, msg) << std::endl; +} +void print_log_message(const char * filename, const int line_number, const char * function_name, std::string msg) +{ + std::cout << compose_log_message(filename, line_number, function_name, msg) << std::endl; +} +void print_err_message(const char * filename, const int line_number, const char * function_name, const char * msg) +{ + std::cerr << compose_log_message(filename, line_number, function_name, msg) << std::endl; +} +void print_err_message(const char * filename, const int line_number, const char * function_name, std::string msg) +{ + std::cerr << compose_log_message(filename, line_number, function_name, msg) << std::endl; +} + + +static std::map __eddl_message_types_names; +void init_message_type_names() +{ + #define stringify(name) # name + __eddl_message_types_names[DATA_SAMPLES] = stringify(DATA_SAMPLES); + __eddl_message_types_names[DATA_WEIGHTS] = stringify(DATA_WEIGHTS); + __eddl_message_types_names[DATA_GRADIENTS] = stringify(DATA_GRADIENTS); + __eddl_message_types_names[MSG_ACK_SAMPLES] = stringify(MSG_ACK_SAMPLES); + 
__eddl_message_types_names[MSG_ACK_WEIGHTS] = stringify(MSG_ACK_WEIGHTS); + __eddl_message_types_names[MSG_ACK_GRADIENTS] = stringify(MSG_ACK_GRADIENTS); + __eddl_message_types_names[PARAMETER] = stringify(PARAMETER); + __eddl_message_types_names[COMMAND] = stringify(COMMAND); + __eddl_message_types_names[PKG_ACK] = stringify(PKG_ACK); + __eddl_message_types_names[MSG_CHKSUM] = stringify(MSG_CHKSUM); + #undef stringify +} +std::string get_message_type_name(int value) +{ + return __eddl_message_types_names[value]; +} +void show_all_message_type_names() +{ + for(auto iter: __eddl_message_types_names) + std::cout << std::hex << iter.first << " " << iter.second << std::endl; +} + + +}; diff --git a/src/distributed/communications/eddl_message.cpp b/src/distributed/communications/eddl_message.cpp new file mode 100644 index 000000000..ec9ee2a2c --- /dev/null +++ b/src/distributed/communications/eddl_message.cpp @@ -0,0 +1,318 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include + +#include +#include +#include +#include + +namespace eddl { + +eddl_message::eddl_message(uint32_t type, + uint32_t source_addr, + uint32_t target_addr, + size_t message_data_size, + size_t packet_data_size, + void * data ) +: type(type), source_addr(source_addr), target_addr(target_addr) +{ + this->timestamp = get_system_milliseconds(); + this->set_message_id(); + this->message_data_size = 0; + this->packet_data_size = 0; + this->seq_len = 0; + this->data = nullptr; + + this->checksum_has_been_set = false; + memset(this->checksum, 0, eddl_checksum_len); + + this->received_packet = nullptr; + + if (0 == message_data_size && nullptr != data) + throw std::runtime_error(err_msg("data size equal to zero when data is not nullptr.")); + + if (type == eddl_message_types::COMMAND) { + this->data = (unsigned char *)eddl_malloc(sizeof(uint32_t)); + uint32_t *p = (uint32_t *)this->data; + // see fabric constructors in eddl_message.h + // the command id is passed via the parameter 'packet_data_size' + *p = (uint32_t)packet_data_size; + this->message_data_size = this->packet_data_size = sizeof(uint32_t); + this->seq_len = 1; + this->compute_checksum(); + this->checksum_has_been_set = true; + } else if (message_data_size > 0) { + this->message_data_size = message_data_size; + this->packet_data_size = packet_data_size; + set_data(message_data_size, data); + } +} +eddl_message::eddl_message(eddl_packet * packet) +{ + this->type = packet->get_type(); + this->source_addr = packet->get_source_addr(); + this->target_addr = packet->get_target_addr(); + this->timestamp = get_system_milliseconds(); + this->seq_len = 1; + this->message_data_size = packet->get_data_size(); + this->packet_data_size = packet->get_data_size(); + 
this->set_message_id(packet->get_message_id_ptr()); + memset(this->checksum, 0, eddl_checksum_len*sizeof(unsigned char)); + this->data = (unsigned char *)eddl_malloc(this->message_data_size); + memcpy(this->data, packet->get_data(), this->message_data_size); + this->pending_packets = 0; + this->received_packet = nullptr; + this->checksum_has_been_set = false; +} + +void eddl_message::set_data(size_t message_data_size, void * data) +{ + this->message_data_size = message_data_size; + + if (0 == packet_data_size) + throw std::runtime_error(err_msg("invalid packet_data_size.")); + + this->seq_len = message_data_size / packet_data_size + + ((message_data_size % packet_data_size) != 0); + + if (0 == this->seq_len) + throw std::runtime_error(err_msg("invalid seq_len.")); + + this->pending_packets = this->seq_len; + + if (nullptr != this->data) free(this->data); + this->data = (unsigned char *)eddl_malloc(message_data_size); + if (nullptr == this->data) + throw std::runtime_error(err_msg("error allocating memory.")); + if (nullptr != data) { + memcpy(this->data, data, message_data_size); + this->compute_checksum(); + this->checksum_has_been_set = true; + this->pending_packets = 0; + } else { + memset(this->data, 0, message_data_size); + } + + if (nullptr != this->received_packet) free(this->received_packet); + this->received_packet = (bool *)eddl_malloc(this->seq_len * sizeof(bool)); + memset(this->received_packet, 0, this->seq_len * sizeof(bool)); +} + +eddl_message::~eddl_message() +{ + if (nullptr != this->data) free(this->data); + if (nullptr != this->received_packet) free(this->received_packet); +} + +uint32_t eddl_message::get_command() +{ + uint32_t *p = (uint32_t *)this->data; + + return *p; +} + +void eddl_message::set_source_addr(uint32_t source_addr) +{ + this->source_addr = source_addr; +} +void eddl_message::set_target_addr(uint32_t target_addr) +{ + this->target_addr = target_addr; +} +void eddl_message::set_message_id(char * message_id) +{ + static char 
hex[20]="0123456789ABCDEF"; + + if (nullptr != message_id) { + if (strlen(message_id) < eddl_msg_id_len) + throw std::runtime_error(err_msg("invalid message id")); + + this->message_id = std::string(message_id, eddl_msg_id_len); + } else { + char s[32]; + int i=0; + + uint32_t s_addr = this->source_addr; + + for (int k=0; k < 8; k++) { + s[i++] = hex[s_addr & 0x00f]; + s_addr >>= 4; + } + + uint32_t type = this->type; + for (int k=0; k < 3; k++) { + s[i++] = hex[type & 0x00f]; + type >>= 4; + } + + uint64_t msec = this->timestamp; + for (int k=0; k < 8; k++) { + s[i++] = hex[msec & 0x00f]; + msec >>= 4; + } + s[i++] = '\0'; + + this->message_id = s; + } +} + +void eddl_message::compute_checksum() +{ + SHA256((unsigned char *)this->data, this->message_data_size, this->checksum); +} +bool eddl_message::is_checksum_valid() +{ + unsigned char checksum[eddl_checksum_len]; + SHA256((unsigned char *)this->data, this->message_data_size, checksum); + + for (int i=0; i < eddl_checksum_len; i++) + if (this->checksum[i] != checksum[i]) return false; + + return true; +} + +void eddl_message::set_checksum(unsigned char * checksum) +{ + memcpy(this->checksum, checksum, eddl_checksum_len); + this->checksum_has_been_set = true; +} + +void eddl_message::add_packet(eddl_packet * packet) +{ + if (0 == this->packet_data_size) + // this could fail if the first packet is the last one of the + // sequence whose size is smaller than the remaining packet + this->packet_data_size = packet->get_all_but_last_packet_size(); + + if (nullptr == this->data) { + set_data(packet->get_message_size(), nullptr); + } + + size_t seq_no = packet->get_seq_no(); + + if (seq_no >= this->seq_len) + throw std::runtime_error(err_msg("invalid packet seq_no.")); + + if (this->message_data_size != packet->get_message_size()) + throw std::runtime_error(err_msg("message_data_size discrepancy: " + + std::to_string(this->message_data_size) + + " vs " + + std::to_string(packet->get_message_size()))); + + if 
(this->packet_data_size != packet->get_data_size()) { + if (seq_no < this->seq_len-1) + throw std::runtime_error(err_msg("packet_data_size discrepancy: " + + std::to_string(this->packet_data_size) + + " vs " + + std::to_string(packet->get_data_size()))); + /* + else + print_err_msg("last packet of the message has a different data size."); + */ + } + + size_t i = seq_no * this->packet_data_size; + memcpy(&this->data[i], packet->get_data(), packet->get_data_size()); + + if (! this->received_packet[seq_no]) { + this->received_packet[seq_no] = true; + this->pending_packets--; + } +} +bool eddl_message::was_packet_already_added(size_t seq_no) +{ + if (seq_no >= this->seq_len) + throw std::runtime_error(err_msg("invalid packet seq_no.")); + + return this->received_packet != nullptr + && this->received_packet[seq_no]; +} + +eddl_packet * eddl_message::get_packet(size_t packet_index) +{ + if (nullptr == this->data) + throw std::runtime_error(err_msg("no data available.")); + + if (packet_index >= this->seq_len) + throw std::runtime_error(err_msg("invalid packet index.")); + + size_t pos = this->packet_data_size * packet_index; + + if (pos >= this->message_data_size) + throw std::runtime_error(err_msg("invalid index to access data.")); + + size_t data_size_of_this_packet = std::min(this->packet_data_size, + this->message_data_size - pos); + + return new eddl_packet(this->type, + this->source_addr, + this->target_addr, + this->message_id, + packet_index, + this->seq_len, + this->message_data_size, + this->packet_data_size, + data_size_of_this_packet, + & this->data[pos] ); +} + +eddl_packet * eddl_message::create_packet_for_checksum() +{ + return new eddl_packet(eddl_message_types::MSG_CHKSUM, //this->type, + this->source_addr, + this->target_addr, + this->message_id, + this->seq_len, // this is a special case where packet index is not used + this->seq_len, + this->message_data_size, + this->packet_data_size, // not sure if here it should be eddl_checksum_len + 
eddl_checksum_len, + this->checksum); +} + +eddl_message * eddl_message::create_acknowledgement() +{ + uint32_t type = 0; + switch (this->get_type()) { + case eddl_message_types::DATA_SAMPLES: + type = eddl_message_types::MSG_ACK_SAMPLES; + break; + case eddl_message_types::DATA_GRADIENTS: + type = eddl_message_types::MSG_ACK_GRADIENTS; + break; + case eddl_message_types::DATA_WEIGHTS: + type = eddl_message_types::MSG_ACK_WEIGHTS; + break; + default: + throw std::runtime_error(err_msg("unexpected message type to be acknowledged.")); + } + + eddl_message * ack = new eddl_message(type, + this->get_target_addr(), // source addr is the target addr of this message + this->get_source_addr(), // target addr is the source addr of this message + this->message_id.size(), + this->message_id.size(), + (void *)this->message_id.c_str() ); + return ack; +} +std::string eddl_message::get_acknowledged_message_id() +{ + return std::string((char *)this->data, eddl_msg_id_len); +} +uint32_t eddl_message::get_acknowledged_message_type() +{ + return ((this->data[ 8]-'0') << 8) + + ((this->data[ 9]-'0') << 4) + + (this->data[10]-'0'); +} + + +}; diff --git a/src/distributed/communications/eddl_message_acks.cpp b/src/distributed/communications/eddl_message_acks.cpp new file mode 100644 index 000000000..56e416f0f --- /dev/null +++ b/src/distributed/communications/eddl_message_acks.cpp @@ -0,0 +1,96 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: August 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include + +#include + +namespace eddl { + +eddl_message_acks::eddl_message_acks(std::vector & workers, + eddl_message * message) +{ + // we need space for the checksum show seq_no is equal to seq_len + this->num_acks_per_worker = message->get_seq_len()+1; + this->living_workers = 0; + for (auto w: workers) { + if (w->is_active()) { + int * ptr = new int [this->num_acks_per_worker]; + memset(ptr, 0, this->num_acks_per_worker * sizeof(int)); + this->acks[w->get_s_addr()] = ptr; + this->living_workers++; + } + } + this->total_num_acks = this->num_acks_per_worker * this->living_workers; + this->ack_counter = 0; + + this->packet_counters = new size_t [this->num_acks_per_worker]; + memset(this->packet_counters, 0, this->num_acks_per_worker * sizeof(size_t)); + + this->starting_timestamp = get_system_milliseconds(); +} +eddl_message_acks::~eddl_message_acks() +{ + for (auto iter : this->acks) + delete [] iter.second; + + delete [] this->packet_counters; +} + +ssize_t eddl_message_acks::get_pending_acknowledgements() +{ + return this->total_num_acks - this->ack_counter; +} + +void eddl_message_acks::acknowledge(uint32_t source_addr, size_t seq_no) +{ + if (seq_no >= this->num_acks_per_worker) + throw std::runtime_error(err_msg("invalid seq_no")); + + if (this->acks.count(source_addr) == 0) + throw std::runtime_error(err_msg("invalid source_addr " + get_ip_address(source_addr))); + + if (this->acks[source_addr][seq_no] == 0) { + this->acks[source_addr][seq_no] = 1; + this->ack_counter++; + this->packet_counters[seq_no]++; + } +} +void eddl_message_acks::acknowledge_whole_message(uint32_t source_addr) +{ + if (this->acks.count(source_addr) == 0) + throw std::runtime_error(err_msg("invalid source_addr " + get_ip_address(source_addr))); 
+ + for(size_t seq_no = 0; seq_no < this->num_acks_per_worker; seq_no++) { + if (this->acks[source_addr][seq_no] == 0) { + this->acks[source_addr][seq_no] = 1; + this->ack_counter++; + this->packet_counters[seq_no]++; + } + } +} +bool eddl_message_acks::all_has_been_acknowledged() +{ + return this->ack_counter == this->total_num_acks; +} +bool eddl_message_acks::packet_already_acknowledged(size_t seq_no) +{ + if (seq_no >= this->num_acks_per_worker) + throw std::runtime_error(err_msg("invalid seq_no")); + + return this->packet_counters[seq_no] == this->living_workers; +} + +bool eddl_message_acks::lasting_too_much_time() +{ + // returns true if more than 60 seconds to be acknowledged + return (get_system_milliseconds() - this->starting_timestamp) > 60*1000; +} + +}; diff --git a/src/distributed/communications/eddl_packet.cpp b/src/distributed/communications/eddl_packet.cpp new file mode 100644 index 000000000..0496bc596 --- /dev/null +++ b/src/distributed/communications/eddl_packet.cpp @@ -0,0 +1,117 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include + +#include +#include +#include + +namespace eddl { + +eddl_packet::eddl_packet(uint32_t type, + uint32_t source_addr, + uint32_t target_addr, + std::string & message_id, + size_t seq_no, + size_t seq_len, + size_t message_size, + uint32_t all_but_last_packet_size, + size_t data_size, + void * data ) +{ + if (data_size == 0) + throw std::runtime_error(err_msg("packet data size cannot be zero.")); + + if (data_size > eddl_packet_data_size) + throw std::runtime_error(err_msg("packet data size cannot be larger than 'eddl_packet_data_size'.")); + + if (eddl_msg_id_len != message_id.size()) + throw std::runtime_error(err_msg("non-valid message id.")); + + this->type = type; + this->source_addr = source_addr; + this->target_addr = target_addr; + memset(this->message_id, 0, sizeof(this->message_id)); + strncpy(this->message_id, message_id.c_str(), eddl_msg_id_len); + this->seq_no = seq_no; + this->seq_len = seq_len; + this->message_size = message_size; + this->all_but_last_packet_size = all_but_last_packet_size; + this->data_size = data_size; + memset(this->data, 0, sizeof(this->data)); + memcpy(this->data, data, data_size); + this->compute_checksum(); +} +eddl_packet::eddl_packet(uint32_t type, + uint32_t source_addr, + uint32_t target_addr, + std::string & message_id, + size_t seq_no, + size_t seq_len, + uint32_t command ) +{ + if (eddl_msg_id_len != message_id.size()) + throw std::runtime_error(err_msg("non-valid message id.")); + + this->type = type; + this->source_addr = source_addr; + this->target_addr = target_addr; + memset(this->message_id, 0, sizeof(this->message_id)); + strncpy(this->message_id, message_id.c_str(), eddl_msg_id_len); + this->seq_no = seq_no; + this->seq_len = seq_len; + this->message_size = 0; + 
this->all_but_last_packet_size = 0; // TO-BE REVIEWED + this->data_size = sizeof(uint32_t); + memset(this->data, 0, sizeof(this->data)); + uint32_t *p = (uint32_t *)this->data; + *p = command; + this->compute_checksum(); +} + +eddl_packet::~eddl_packet() +{ +} + +uint32_t eddl_packet::get_command() +{ + uint32_t *p = (uint32_t *)this->data; + + return *p; +} + +void eddl_packet::compute_checksum() +{ + memset(this->checksum, 0, eddl_checksum_len*sizeof(unsigned char)); + unsigned char checksum[eddl_checksum_len]; + SHA256((unsigned char *)this, sizeof(eddl_packet), checksum); + memcpy(this->checksum, checksum, eddl_checksum_len*sizeof(unsigned char)); +} +bool eddl_packet::is_checksum_valid() +{ + unsigned char checksum_orig[eddl_checksum_len]; + unsigned char checksum_new[eddl_checksum_len]; + memcpy(checksum_orig, this->checksum, eddl_checksum_len*sizeof(unsigned char)); + memset(this->checksum, 0, eddl_checksum_len*sizeof(unsigned char)); + SHA256((unsigned char *)this, sizeof(eddl_packet), checksum_new); + memcpy(this->checksum, checksum_orig, eddl_checksum_len*sizeof(unsigned char)); + + for (int i=0; i < eddl_checksum_len; i++) + if (checksum_orig[i] != checksum_new[i]) return false; + + return true; +} + +eddl_packet_ack * eddl_packet::create_acknowledgement(uint32_t worker_addr) +{ + return new eddl_packet_ack(worker_addr, this->seq_no, this->message_id); +} + +}; diff --git a/src/distributed/communications/eddl_worker_node.cpp b/src/distributed/communications/eddl_worker_node.cpp new file mode 100644 index 000000000..b22033ff9 --- /dev/null +++ b/src/distributed/communications/eddl_worker_node.cpp @@ -0,0 +1,75 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + + +namespace eddl { + +eddl_worker_node::eddl_worker_node(std::string description) +{ +// ip:192.168.13.11;cpu:2,8192;gpu:1,low_mem;fpga:0,0;batch_size:10; + std::vector columns=str_split(description, ';'); + + for (auto s : columns) { + std::vector key_values=str_split(s, ':'); + std::string key = key_values[0]; + std::vector values=str_split(key_values[1], ','); + + if (key == "ip") { + + this->hostname_or_ip_address = values[0]; + struct hostent *host = gethostbyname(this->hostname_or_ip_address.c_str()); + if (sizeof(this->s_addr) != host->h_length) + throw std::runtime_error(err_msg("address error conversion.")); + memcpy((char *)&this->s_addr, (char *)host->h_addr, host->h_length); + + } else if (key == "cpu") { + + this->cpu_cores = std::stoi(values[0]); + this->cpu_mem = std::stoi(values[1]); + + } else if (key == "gpu") { + + this->gpu_cards = std::stoi(values[0]); + this->gpu_mem_mode = values[1]; + + } else if (key == "fga") { + + this->fpga_cards = std::stoi(values[0]); + this->fpga_mem = std::stoi(values[1]); + + } else if (key == "batch_size") { + + this->batch_size = std::stoi(values[0]); + } + } + + this->data_subset=""; + this->active = true; +} + +std::string eddl_worker_node::get_ip_address() +{ + return eddl::get_ip_address(this->s_addr); +} + + +}; diff --git a/src/distributed/communications/multicast_receiver.cpp b/src/distributed/communications/multicast_receiver.cpp new file mode 100644 index 000000000..caf667960 --- /dev/null +++ b/src/distributed/communications/multicast_receiver.cpp @@ -0,0 +1,309 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. 
+ * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#if !defined(MSG_NOSIGNAL) +# if defined(__APPLE__) +# define MSG_NOSIGNAL 0 +# else +# error "MSG_NOSIGNAL is not defined this should be fixed!" +# endif +#endif + +namespace eddl { + + +MulticastReceiver::MulticastReceiver(eddl_queue & input_queue, + eddl_queue & ack_queue, + eddl_queue & output_queue, + DistributedEnvironment & distributed_environment) : + input_queue(input_queue), + ack_queue(ack_queue), + output_queue(output_queue), + distributed_environment(distributed_environment) +{ + socket_fd_in = socket(AF_INET, SOCK_DGRAM, 0); + if (socket_fd_in < 0) + throw std::runtime_error(err_msg("input socket cannot be created.")); + + int reuse=1; + if (setsockopt(socket_fd_in, SOL_SOCKET, SO_REUSEADDR, (char *)&reuse, sizeof(reuse)) < 0) + throw std::runtime_error(err_msg("socket cannot be set to allow multiple instances to receive copies of multicast datagrams.")); + +#if defined(__APPLE__) + { + int set = 1; + if (setsockopt(socket_fd_in, SOL_SOCKET, SO_NOSIGPIPE, (void *)&set, sizeof(int)) < 0) + throw std::runtime_error(err_msg("cannot unset SIGPIPE. 
" + std::to_string(errno) + ":" + strerror(errno))); + } +#endif + + this->port_number_in = distributed_environment.get_udp_data_port(); + + struct sockaddr_in host_addr; + memset(&host_addr, 0, sizeof(host_addr)); + host_addr.sin_family = AF_INET; + host_addr.sin_addr.s_addr = INADDR_ANY; // distributed_environment.get_my_s_addr(); + host_addr.sin_port = htons(this->port_number_in); + + if (bind(socket_fd_in, (struct sockaddr *) &host_addr, sizeof(host_addr)) < 0) + throw std::runtime_error(err_msg("binding socket failed.")); + + struct ip_mreq mreq; + mreq.imr_multiaddr.s_addr = distributed_environment.get_multicast_s_addr(); + mreq.imr_interface.s_addr = distributed_environment.get_my_s_addr(); + + if (setsockopt(socket_fd_in, IPPROTO_IP, IP_ADD_MEMBERSHIP, (char *) &mreq, sizeof(mreq)) < 0) + throw std::runtime_error(err_msg("adding membership to multicast group failed.")); + + //////////////////////////////////////////////////////////////////////////// + + socket_fd_out = socket(AF_INET, SOCK_DGRAM, 0); + if (socket_fd_out < 0) + throw std::runtime_error(err_msg("output socket cannot be created.")); + +#if defined(__APPLE__) + { + int set = 1; + if (setsockopt(socket_fd_out, SOL_SOCKET, SO_NOSIGPIPE, (void *)&set, sizeof(int)) < 0) + throw std::runtime_error(err_msg("cannot unset SIGPIPE. 
 " + std::to_string(errno) + ":" + strerror(errno)));
+ }
+#endif
+
+ this->port_number_out = distributed_environment.get_udp_ack_port();
+
+ ////////////////////////////////////////////////////////////////////////////
+
+ std::cout << "ready to receive messages from multicast group "
+ << get_ip_address(mreq.imr_multiaddr.s_addr) << ":" << this->port_number_in
+ << " via " << distributed_environment.get_my_ip_addr()
+ << " and sent acknowledgements to "
+ << distributed_environment.get_master_ip_addr() << ":" << this->port_number_out
+ << std::endl;
+
+ ////////////////////////////////////////////////////////////////////////////
+
+ receiver_active = true;
+ receiver_thread = std::thread( & MulticastReceiver::receiver, this);
+}
+
+MulticastReceiver::~MulticastReceiver()
+{
+ receiver_active = false;
+ receiver_thread.join();
+ recently_received_messages.clear();
+ for (auto iter : active_messages)
+ delete iter.second;
+ active_messages.clear();
+ close(socket_fd_out);
+}
+
+void MulticastReceiver::stop()
+{
+ receiver_active = false;
+ // TODO(review): should this method send a packet with a closing command
+ // in order to unlock the blocked receiver thread?
+} + +void MulticastReceiver::send_ack(eddl_packet_ack * ack) +{ + int flags = MSG_NOSIGNAL; + struct sockaddr_in peer_addr; + memset(&peer_addr, 0, sizeof(peer_addr)); + peer_addr.sin_family = AF_INET; + peer_addr.sin_addr.s_addr = distributed_environment.get_master_s_addr(); + peer_addr.sin_port = htons(this->port_number_out); + + ssize_t l = sizeof(eddl_packet_ack); + ssize_t n = sendto(socket_fd_out, (void *)ack, l, flags, + (const struct sockaddr *)&peer_addr, sizeof(peer_addr)); + //print_log_msg("sent acknowledgement of packet no " + std::to_string(ack->get_seq_no())); + delete ack; + + if (n != l) + throw std::runtime_error(err_msg("sent " + std::to_string(n) + + " bytes instead of " + std::to_string(l) + + " " + std::to_string(errno) + ": " + + strerror(errno))); +} +void MulticastReceiver::receiver() +{ + void * data; + while (receiver_active) { + struct sockaddr_in peer_addr; + socklen_t peer_addr_size = sizeof(peer_addr); + int flags = MSG_NOSIGNAL; // MSG_WAITALL; + data = eddl_malloc(sizeof(eddl_packet)); + if (nullptr == data) + throw std::runtime_error(err_msg("error allocating memory.")); + + ssize_t l = sizeof(eddl_packet); + // blocking call + ssize_t n = recvfrom(socket_fd_in, data, l, flags, + (struct sockaddr *)&peer_addr, &peer_addr_size); + if (n < 0) { + print_err_msg("error receiving a packet: " + std::to_string(errno) + ": " + strerror(errno)); + free(data); + continue; // do not abort the process, just drop the packet + } + + if (n != l) { + print_err_msg("warning received an incomplete packet of " + + std::to_string(n) + " bytes instead of " + + std::to_string(l) + " bytes requrested"); + free(data); + continue; // do not abort the process, just drop the packet + } + + eddl_packet * packet = (eddl_packet *)data; + /** + print_log_msg("received packet " + std::to_string(packet->get_seq_no()) + + "/" + std::to_string(packet->get_seq_len()) + + " of message " + packet->get_message_id() + + " from " + 
 get_ip_address(packet->get_source_addr()));
+ **/
+ if (packet->is_checksum_valid()) {
+ if (packet->get_source_addr() != peer_addr.sin_addr.s_addr)
+ throw std::runtime_error(err_msg("received packet from "
+ + get_ip_address(peer_addr.sin_addr.s_addr)
+ + " claiming it was sent from "
+ + get_ip_address(packet->get_source_addr())));
+
+ eddl_message * message = nullptr;
+ std::string msg_id = "";
+
+ switch(packet->get_type()) {
+ case eddl_message_types::DATA_WEIGHTS:
+ case eddl_message_types::DATA_GRADIENTS:
+ case eddl_message_types::DATA_SAMPLES:
+ case eddl_message_types::MSG_CHKSUM:
+ /* get info from the packet and add it to the corresponding
+ existing message or just create a new message,
+ but if the message was recently received then resent
+ packets must be ignored (dropped)
+ */
+
+ /* contrary to the above comment, the acknowledgement is sent
+ in any case in order to allow multicast_sender in the peer
+ to know the packet was received -- this is pending to be
+ analysed in more detail, but currently this seems to
+ alleviate the problem of pending messages to be sent in the
+ multicast sender of the peer. 
+ */ + this->send_ack(packet->create_acknowledgement(distributed_environment.get_my_s_addr())); + + msg_id = packet->get_message_id(); + if (this->recently_received_messages.count(msg_id) > 0) { + uint64_t lapse = get_system_milliseconds() - this->recently_received_messages[msg_id]; + // if more than one hour it was received then remove it + if (lapse > 1*60*60*1000) { + this->recently_received_messages.erase(msg_id); + } + } else { + if (this->active_messages.count(packet->get_message_id()) == 0) { + message = new eddl_message(packet->get_type(), + packet->get_source_addr(), + packet->get_target_addr(), + packet->get_message_size(), + packet->get_data_size(), + (void *)nullptr); + message->set_message_id(packet->get_message_id_ptr()); + this->active_messages[message->get_message_id()] = message; + //print_log_msg(std::string("receiving message ") + message->get_message_id()); + //print_log_msg(".....................................................message created from packet with " + std::to_string(packet->get_message_size()) + " vs " + std::to_string(message->get_message_data_size())); + } else { + message = this->active_messages[packet->get_message_id()]; + } + if (packet->get_type() == eddl_message_types::MSG_CHKSUM) { + //print_log_msg(".....................................................message checksum received"); + if (!message->was_checksum_already_set()) { + message->set_checksum((unsigned char *)packet->get_data()); + } + // this->send_ack(packet->create_acknowledgement(distributed_environment.get_my_s_addr())); + } else { + // add the packet to the message --same packet can be received more than once + //print_log_msg(".....................................................message " + pointer_to_string(message)); + //print_log_msg(".....................................................message id " + message->get_message_id()); + //print_log_msg(".....................................................packet msg id " + packet->get_message_id()); + if (! 
message->was_packet_already_added(packet->get_seq_no())) { + message->add_packet(packet); + // acknowledge the received packet + } + // this->send_ack(packet->create_acknowledgement(distributed_environment.get_my_s_addr())); + //print_log_msg(".....................................................added packet to message and sent ack"); + } + // if message complete enqueue the message + if (message->is_complete()) { + this->active_messages.erase(message->get_message_id()); + if (message->is_checksum_valid()) { + switch (message->get_type()) { + case eddl_message_types::DATA_SAMPLES: + case eddl_message_types::DATA_GRADIENTS: + case eddl_message_types::DATA_WEIGHTS: + this->input_queue.push(message); + this->recently_received_messages[message->get_message_id()] = get_system_milliseconds(); + this->output_queue.push_front(message->create_acknowledgement()); + print_log_msg(std::string("received message ") + message->get_message_id()); + break; + default: + throw std::runtime_error(err_msg("unexpected message type.")); + } + } else { + delete message; + } + } + } + break; + case eddl_message_types::COMMAND: + if (packet->get_command() == eddl_command_types::SHUTDOWN) + receiver_active = false; + this->input_queue.push(new eddl_message(packet)); + break; + case eddl_message_types::PARAMETER: + this->input_queue.push(new eddl_message(packet)); + break; + case eddl_message_types::PKG_ACK: + case eddl_message_types::MSG_ACK_SAMPLES: + case eddl_message_types::MSG_ACK_WEIGHTS: + case eddl_message_types::MSG_ACK_GRADIENTS: + //this->ack_queue.push(new eddl_message(packet)); + //break; + default: + throw std::runtime_error(err_msg("unexpected message type.")); + } // switch + } else { + // otherwise do nothing, sender will resend non-acknowledged packets + // so the next print_err_msg() must be commented + print_err_msg("received packet " + std::to_string(packet->get_seq_no()) + + "/" + std::to_string(packet->get_seq_len()) + + " of message " + packet->get_message_id() + + " 
from " + get_ip_address(packet->get_source_addr())); + } + /* + instead of deleting the object of the class eddl_packet, we have + to free the memory block + delete packet -- DON'T DO THIS IN THIS CASE + */ + free(data); + } // while receiver_active + close(socket_fd_in); + print_log_msg("multicast receiver thread stopped normally"); +} +}; diff --git a/src/distributed/communications/multicast_sender.cpp b/src/distributed/communications/multicast_sender.cpp new file mode 100644 index 000000000..7c07fb660 --- /dev/null +++ b/src/distributed/communications/multicast_sender.cpp @@ -0,0 +1,412 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: August 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include + +#include +#include +#include + +#if !defined(MSG_NOSIGNAL) +# if defined(__APPLE__) +# define MSG_NOSIGNAL 0 +# else +# error "MSG_NOSIGNAL is not defined this should be fixed!" 
+# endif
+#endif
+
+namespace eddl {
+
+MulticastSender::MulticastSender(std::vector & workers,
+ eddl_queue & output_queue,
+ eddl_queue & ack_queue,
+ DistributedEnvironment & distributed_environment) :
+ workers(workers),
+ output_queue(output_queue),
+ ack_queue(ack_queue),
+ distributed_environment(distributed_environment)
+{
+ socket_fd_out = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+ if (socket_fd_out < 0)
+ throw std::runtime_error(err_msg("output socket cannot be created."));
+
+ u_char loop = 0; // 0 to disable multicast loop ; not necessary in theory
+ // change this if the master also acts as a worker
+ if (setsockopt(socket_fd_out, IPPROTO_IP, IP_MULTICAST_LOOP, &loop, sizeof(loop)) < 0)
+ throw std::runtime_error(err_msg("cannot deactivate multicast loop."));
+
+ u_char ttl=1; // set ttl to the number of routers multicast packets can go through
+ // change this to adapt to the needs of federated machine learning
+ if (setsockopt(socket_fd_out, IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl)) < 0)
+ throw std::runtime_error(err_msg("cannot set multicast TTL. " + std::to_string(errno) + ":" + strerror(errno)));
+
+#if defined(__APPLE__)
+ {
+ int set = 1;
+ if (setsockopt(socket_fd_out, SOL_SOCKET, SO_NOSIGPIPE, (void *)&set, sizeof(int)) < 0)
+ throw std::runtime_error(err_msg("cannot unset SIGPIPE. 
" + std::to_string(errno) + ":" + strerror(errno))); + } +#endif + + memset(&this->target_group_addr, 0, sizeof(this->target_group_addr)); + this->target_group_addr.sin_family = AF_INET; + this->target_group_addr.sin_addr.s_addr = distributed_environment.get_multicast_s_addr(); + this->target_group_addr.sin_port = htons(distributed_environment.get_udp_data_port()); + + struct in_addr mreq; + mreq.s_addr = distributed_environment.get_my_s_addr(); + /* alternative 1 sinc Linux 1.2 + struct ip_mreq mreq; + mreq.imr_multiaddr.s_addr = 0; //distributed_environment.get_multicast_s_addr(); + mreq.imr_interface.s_addr = distributed_environment.get_my_s_addr(); + */ + /* alternative 2 since Linux 3.5 + struct ip_mreqn mreq; + mreq.imr_multiaddr.s_addr = 0; //distributed_environment.get_multicast_s_addr(); + mreq.imr_address.s_addr = distributed_environment->get_my_s_addr(); + mreq.imr_ifindex = 0; + */ + if (setsockopt(socket_fd_out, IPPROTO_IP, IP_MULTICAST_IF, (char *)&mreq, sizeof(mreq)) < 0) + throw std::runtime_error(err_msg("cannot set my inferface addr for multicast." + + std::to_string(errno) + ":" + strerror(errno))); + + //////////////////////////////////////////////////////////////////////////// + socket_fd_in = socket(AF_INET, SOCK_DGRAM, 0); + if (socket_fd_in < 0) + throw std::runtime_error(err_msg("input socket cannot be created.")); + +#if defined(__APPLE__) + { + int set = 1; + if (setsockopt(socket_fd_in, SOL_SOCKET, SO_NOSIGPIPE, (void *)&set, sizeof(int)) < 0) + throw std::runtime_error(err_msg("cannot unset SIGPIPE. 
" + std::to_string(errno) + ":" + strerror(errno))); + } +#endif + + struct sockaddr_in host_addr; + memset(&host_addr, 0, sizeof(host_addr)); + host_addr.sin_family = AF_INET; + host_addr.sin_addr.s_addr = INADDR_ANY; // distributed_environment->get_my_s_addr(); + host_addr.sin_port = htons(distributed_environment.get_udp_ack_port()); + + if (bind(socket_fd_in, (struct sockaddr *) &host_addr, sizeof(host_addr)) < 0) + throw std::runtime_error(err_msg("binding socket failed.")); + + //////////////////////////////////////////////////////////////////////////// + + std::cout << "ready to sent messages to multicast group " + << get_ip_address(distributed_environment.get_multicast_s_addr()) + << ":" << distributed_environment.get_udp_data_port() + << " via " << get_ip_address(distributed_environment.get_my_s_addr()) + << " and receive acknowledgements from any worker via port " + << distributed_environment.get_udp_ack_port() + << std::endl; + + socklen_t optlen; + int sockt_buffer_size; + int rc; +/* + optlen = sizeof(sockt_buffer_size); + sockt_buffer_size = 30*1024*1024; + rc = setsockopt(socket_fd_out, SOL_SOCKET, SO_SNDBUF, &sockt_buffer_size, optlen); + std::cout << "rc = " << rc << std::endl; +*/ + optlen = sizeof(sockt_buffer_size); + rc = getsockopt(socket_fd_out, SOL_SOCKET, SO_SNDBUF, &sockt_buffer_size, &optlen); + std::cout << "send UDP buffer size is " << sockt_buffer_size + << " rc = " << rc + << std::endl; + optlen = sizeof(sockt_buffer_size); + rc = getsockopt(socket_fd_out, SOL_SOCKET, SO_RCVBUF, &sockt_buffer_size, &optlen); + std::cout << "recv UDP buffer size is " << sockt_buffer_size + << " rc = " << rc + << std::endl; + //////////////////////////////////////////////////////////////////////////// + + sender_active = true; + sender_thread = std::thread( & MulticastSender::sender, this); + ack_processor_thread = std::thread( & MulticastSender::ack_processor, this); +} + +MulticastSender::~MulticastSender() +{ + stop(); + + sender_active = false; + 
output_queue.clear(); + output_queue.push(nullptr); + + sender_thread.join(); + ack_processor_thread.join(); + + for (auto iter: this->active_acknowledgements) + delete iter.second; +} + +void MulticastSender::stop() +{ + sender_active = false; + //////////////////////////////////////////////////////////////////////////// + ////////// this stops the acknowledgement procesor thread ////////////////// + int temp_socket = socket(AF_INET, SOCK_DGRAM, 0); + char data[sizeof(eddl_packet_ack)]; + memset(data, 0, sizeof(data)); + struct sockaddr_in peer; + memset(&peer, 0, sizeof(peer)); + peer.sin_family = AF_INET; + peer.sin_port = htons(distributed_environment.get_udp_ack_port()); + peer.sin_addr.s_addr = distributed_environment.get_my_s_addr(); + ssize_t l = sizeof(data); +#if defined(__APPLE__) + { + int set = 1; + if (setsockopt(temp_socket, SOL_SOCKET, SO_NOSIGPIPE, (void *)&set, sizeof(int)) < 0) + throw std::runtime_error(err_msg("cannot unset SIGPIPE. " + std::to_string(errno) + ":" + strerror(errno))); + } +#endif + int flags = MSG_NOSIGNAL; + ssize_t n = sendto(temp_socket, data, l, flags, + (const struct sockaddr *)&peer, sizeof(peer)); + if (n != l) + print_err_msg("failed to sent a stopping acknowledgement to myself."); + close(temp_socket); + //////////////////////////////////////////////////////////////////////////// +} + +void MulticastSender::sender() +{ + eddl_message * message = nullptr; + while (sender_active) { + if (output_queue.empty()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + // poping from the queue blocks until something his available + message = output_queue.pop(); + // this allows to stop this thread, according to the destructor of this class + if (nullptr == message) { + continue; + } else if (! 
send_message(message)) { + // an error ocurred while sending the message + std::string msg_id = message->get_message_id(); + // destroys the message + print_err_msg("message " + msg_id + " was not sent to all living workers."); + } + delete message; + } + close(socket_fd_out); + print_log_msg("multicast sender thread stopped normally."); +} +void MulticastSender::ack_processor() +{ + int flags = MSG_NOSIGNAL; + struct sockaddr_in peer_addr; + socklen_t peer_addr_size; + unsigned char data[next_multiple(sizeof(eddl_packet_ack),8)]; + + while (sender_active) { + ssize_t l = sizeof(eddl_packet_ack); + memset(data, 0, sizeof(data)); + memset(&peer_addr, 0, sizeof(peer_addr)); + peer_addr_size=0; + // blocking call + ssize_t n = recvfrom(socket_fd_in, data, l, flags, (struct sockaddr *)&peer_addr, &peer_addr_size); + if (n < 0) { + print_err_msg("error receiving an acknowledgement: " + + std::to_string(errno) + ": " + strerror(errno)); + continue; // do not abort the process, just drop the packet + } + if (n != l) { + print_err_msg("warning received an incomplete acknowledgement of " + + std::to_string(n) + " bytes instead of " + + std::to_string(l) + " bytes requested"); + continue; // do not abort the process, just drop the packet + } + + eddl_packet_ack * ack = (eddl_packet_ack *)data; + + //////////////////////////////////////////////////////////////////////// + size_t sum=0; + for (size_t i = 0; i < sizeof(eddl_packet_ack); i++) sum += data[i]; + if (0 == sum) break; // an empty acknowledgement means to stop + //////////////////////////////////////////////////////////////////////// + + std::string message_id = ack->get_message_id(); + { // critical region starts + std::unique_lock lck(ack_processor_mutex); + + if (this->active_acknowledgements.count(message_id) > 0) { + eddl_message_acks * _acks = this->active_acknowledgements[message_id]; + _acks->acknowledge(ack->get_source_addr(), ack->get_seq_no()); + if (distributed_environment.get_verbose_level() > 2) + 
print_log_msg("received acknowledgement " + + std::to_string(ack->get_seq_no()) + + " for message " + message_id); + } else { + print_log_msg("received an obsolete acknowledgement for message " + message_id); + } + } // critical region ends + } + close(socket_fd_in); + print_log_msg("multicast acknowledgment processor thread stopped normally."); +} + +bool MulticastSender::send_message(eddl_message * message) +{ + message->set_source_addr(distributed_environment.get_my_s_addr()); + message->set_target_addr(distributed_environment.get_multicast_s_addr()); + // compulsory to compute again the message id every time source addr is updated + message->set_message_id(); // with no parameter method set_message_id() computes the message_id + + eddl_message_acks * message_acks = nullptr; + + // prepare acknowledgements for the message to be sent here + { // critical region starts + std::unique_lock lck(ack_processor_mutex); + + message_acks = new eddl_message_acks(workers, message); + this->active_acknowledgements[message->get_message_id()] = message_acks; + } // critical region ends + + std::queue seq_no_queue; + eddl_packet * sent_packets[message->get_seq_len()+1]; // includes the checksum + + // populates the queue of packet indices including one additional for the checksum + for (size_t seq_no=0; seq_no <= message->get_seq_len(); ++seq_no) { + seq_no_queue.push(seq_no); + sent_packets[seq_no] = nullptr; + } + + bool return_status = true; + int flags = MSG_NOSIGNAL; + + try { + uint64_t t0 = get_system_milliseconds(); + size_t msec_to_wait_after_sendto=1; + size_t counter=0; + ssize_t sent_bytes=0; + while( sender_active && ! seq_no_queue.empty()) { + size_t pending_packets = seq_no_queue.size(); + for (size_t i=0; sender_active && i < pending_packets; i++) { + size_t seq_no = seq_no_queue.front(); + seq_no_queue.pop(); + + bool packet_to_be_sent = false; + { // critical region starts + std::unique_lock lck(ack_processor_mutex); + + packet_to_be_sent = ! 
message_acks->packet_already_acknowledged(seq_no); + } // critical region ends + + if (packet_to_be_sent) { + eddl_packet * packet = sent_packets[seq_no]; + if (nullptr == packet) { + if (seq_no < message->get_seq_len()) + packet = message->get_packet(seq_no); + else + packet = message->create_packet_for_checksum(); + sent_packets[seq_no] = packet; + } + ssize_t l = sizeof(eddl_packet); + ssize_t n = sendto(socket_fd_out, (void *)packet, l, + flags, + (const struct sockaddr *) &this->target_group_addr, + sizeof(this->target_group_addr)); + sent_bytes += n; + if (sent_bytes >= 4000) { + std::this_thread::sleep_for(std::chrono::milliseconds(msec_to_wait_after_sendto)); + sent_bytes -= 4000; + } + if (n != l) + throw std::runtime_error(err_msg("sent " + std::to_string(n) + + " bytes instead of " + std::to_string(l) + + " " + std::to_string(errno) + ": " + + strerror(errno))); + //print_log_msg("packet sent " + std::to_string(seq_no)); + seq_no_queue.push(seq_no); + } else if (nullptr != sent_packets[seq_no]) { + delete sent_packets[seq_no]; + sent_packets[seq_no] = nullptr; + } + // otherwise the packet at seq_no is considered successfully sent + // and is not pushed into the queue again + + if (message_acks->lasting_too_much_time()) { + return_status = false; + throw std::runtime_error(err_msg("time over sending message " + message->get_message_id())); + } + + ++counter; + } // inner for loop i < pending_packets + + if (! 
ack_queue.empty()) { + eddl_message * ack = ack_queue.pop(); + + std::string message_id = ack->get_acknowledged_message_id(); + if (ack->get_type() == eddl_message_types::MSG_ACK_WEIGHTS) { + // critical region starts + std::unique_lock lck(ack_processor_mutex); + + if (this->active_acknowledgements.count(message_id) > 0) { + eddl_message_acks * _acks = this->active_acknowledgements[message_id]; + _acks->acknowledge_whole_message(ack->get_source_addr()); + if (distributed_environment.get_verbose_level() > 1) + print_log_msg("received acknowledgement of whole message:" + message_id); + } else { + if (distributed_environment.get_verbose_level() > 1) + print_err_msg("received acknowledgement of a non-active whole message:" + message_id); + } + } // critical region ends + delete ack; + } + + print_log_msg("message being sent: " + message->get_message_id() + + " |seq_no_queue| = " + std::to_string(seq_no_queue.size()) + + " pending_acknowledgement_count = " + + std::to_string(message_acks->get_pending_acknowledgements()) + + " counter = " + std::to_string(counter) + + " waiting " + std::to_string((get_system_milliseconds()-t0)) + + " milliseconds from message started to be sent."); + /* + msec_to_wait_after_sendto = (msec_to_wait_after_sendto == 1) + ? 10 + : msec_to_wait_after_sendto+10; + */ + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } // outer while loop ! seq_no_queue.empty() + } + catch(std::exception & e) { + print_err_msg(std::string("an exception ocurred: ") + e.what()); + return_status = false; + } + + // cleaning data structures starts + while (! 
seq_no_queue.empty()) seq_no_queue.pop(); + + for (auto packet : sent_packets) + if (nullptr != packet) delete packet; + // cleaning data structures ends + + { // critical region starts + std::unique_lock lck(ack_processor_mutex); + + if (message_acks->all_has_been_acknowledged()) { + // do any pending action to do + } else { + // review the list of workers and deactivate those who failed systematically + return_status = false; + } + this->active_acknowledgements.erase(message->get_message_id()); + delete message_acks; + } // critical region ends + + return return_status; +} + +}; // namespace eddl diff --git a/src/distributed/communications/tcp_receiver.cpp b/src/distributed/communications/tcp_receiver.cpp new file mode 100644 index 000000000..da569704f --- /dev/null +++ b/src/distributed/communications/tcp_receiver.cpp @@ -0,0 +1,352 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace eddl { + + +TCP_Receiver::TCP_Receiver( eddl_queue & input_queue, + eddl_queue & weights_ack_queue, + eddl_queue & generic_ack_queue, + eddl_queue & output_queue, + DistributedEnvironment & distributed_environment) : + input_queue(input_queue), + weights_ack_queue(weights_ack_queue), + generic_ack_queue(generic_ack_queue), + output_queue(output_queue), + distributed_environment(distributed_environment) +{ + socket_fd = socket(AF_INET, SOCK_STREAM, 0); + if (socket_fd < 0) + throw std::runtime_error(err_msg("socket cannot be created.")); + + struct sockaddr_in my_addr; + + /* Clear data structure */ + memset(&my_addr, 0, sizeof(struct sockaddr_in)); + my_addr.sin_family = AF_INET; + my_addr.sin_addr.s_addr = 
INADDR_ANY; + my_addr.sin_port = htons(distributed_environment.get_tcp_port()); + + if (bind(socket_fd, (struct sockaddr *) &my_addr, sizeof(struct sockaddr_in)) < 0) + throw std::runtime_error(err_msg("binding socket failed.")); + + if (listen(socket_fd, listen_max_pending) < 0) + throw std::runtime_error(err_msg("setting listening state failed.")); + + signal(SIGPIPE, SIG_IGN); + + receiver_active=true; + joiner_thread = std::thread( & TCP_Receiver::joiner, this); + acceptor_thread = std::thread( & TCP_Receiver::acceptor, this); +} + +TCP_Receiver::~TCP_Receiver() +{ + stop(); + receiver_active=false; + + // a signal must be sent to the acceptor thread in order to wake it up + // from the accept() system call, but the solution is to detach the thread + // and leave the program to end, then the thread is killed. + + close(socket_fd); + + joiner_thread.join(); + //acceptor_thread.detach(); + acceptor_thread.join(); // hoping the master sends the stopping commands + + drop_stopped(); + + /* + an object of this class can be destroyed and this destructor completes + its execution while one (or more) objects of the class ActiveThread + remain(s) active while reading an incoming message. + if this occurs, then the master process monitored by valgrind will report + some reachable memory blocks. 
+ */ +} + +void TCP_Receiver::stop() +{ + receiver_active=false; +} +void TCP_Receiver::acceptor() +{ + while (receiver_active) { + + struct sockaddr_in peer_addr; + socklen_t peer_addr_size = sizeof(struct sockaddr_in); + + int connected_socket_fd = accept(socket_fd, (struct sockaddr *) &peer_addr, &peer_addr_size); + + if (connected_socket_fd < 0) { + if (receiver_active) + throw std::runtime_error(err_msg("accepting a connection failed.")); + } else { + /* + if (verbose_level >= 1) + print_log_msg("connection accepted from " + inet_ntoa(peer_addr.sin_addr)); + */ + + ActiveThread * at = new ActiveThread(connected_socket_fd, input_queue, weights_ack_queue, generic_ack_queue, output_queue, this); + + { // critical region for pushing new items in the queue of active threads + std::unique_lock lck(mutex_active_threads); + + active_threads.push(at); + } + } + } + print_log_msg("acceptor thread stopped normally."); +} + +void TCP_Receiver::drop_stopped() +{ /* + this method is wholly executed inside a critical region that + takes exclusive access to the queue of active_threads + */ + // Critical region starts + std::unique_lock lck(mutex_active_threads); + + for (unsigned int i=0; i < active_threads.size(); i++) { + + // pops an active thread from queue + ActiveThread * at = active_threads.front(); active_threads.pop(); + + if (at->get_status() == STOPPED) { + // if the active thread is stopped then it is joined and destroyed + at->join(); + at->disable(); + delete at; + } else { + // otherwise it is pushed again into the queue + active_threads.push(at); + } + } + // Critical region ends +} +void TCP_Receiver::joiner() +{ + while (receiver_active) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + drop_stopped(); + } + + print_log_msg("joiner thread stopped normally."); +} + + + +///////////////////////////////////////////////////////////////////////////// +///////////////////// METHODS OF CLASS ActiveThread //////////////////////// 
+///////////////////////////////////////////////////////////////////////////// +TCP_Receiver::ActiveThread::ActiveThread(int socket_fd, + eddl_queue & input_queue, + eddl_queue & weights_ack_queue, + eddl_queue & generic_ack_queue, + eddl_queue & output_queue, + TCP_Receiver * tcp_receiver) +: socket_fd(socket_fd), + input_queue(input_queue), + weights_ack_queue(weights_ack_queue), + generic_ack_queue(generic_ack_queue), + output_queue(output_queue), + tcp_receiver(tcp_receiver) +{ + status = INACTIVE; + thread = new std::thread( & ActiveThread::thread_receiver, this); +} +TCP_Receiver::ActiveThread::~ActiveThread() +{ + /* + here it is not necessary to join the thread because + it is assumed here the thread was joined from the + drop_stopped() method of the class TCP_Receiver + */ + delete thread; +} +void TCP_Receiver::ActiveThread::thread_receiver() +{ + this->status = RUNNING; + std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); + + eddl_message * message = receive_message(); + close(socket_fd); + + std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); + int msec = std::chrono::duration_cast(end - begin).count(); + + if (nullptr != message) { + if (tcp_receiver->distributed_environment.get_verbose_level() >= 1) + print_log_msg("thread on socket " + std::to_string(socket_fd) + + " completed after receiving " + std::to_string(message->get_message_data_size()) + + " bytes in " + std::to_string(msec/1.0e6) + " seconds!" 
+ + " message_type: " + get_message_type_name(message->get_type())); + + switch (message->get_type()) { + + case eddl_message_types::DATA_SAMPLES: + case eddl_message_types::DATA_GRADIENTS: + case eddl_message_types::DATA_WEIGHTS: + output_queue.push(message->create_acknowledgement()); + input_queue.push(message); + break; + + case eddl_message_types::COMMAND: + if (message->get_command() == eddl_command_types::SHUTDOWN) + this->tcp_receiver->receiver_active = false; + case eddl_message_types::PARAMETER: + input_queue.push(message); + break; + + case eddl_message_types::MSG_ACK_WEIGHTS: + weights_ack_queue.push(message); + break; + + case eddl_message_types::MSG_ACK_GRADIENTS: + case eddl_message_types::MSG_ACK_SAMPLES: + generic_ack_queue.push(message); + break; + + case eddl_message_types::PKG_ACK: + { + size_t * p = (size_t *)message->get_data(); + // in p[1] is the type of the acknowledged message + // see method acknowledgement(eddl_packet *) + // in file eddl_message.h + switch (p[1]) { + case eddl_message_types::DATA_WEIGHTS: + weights_ack_queue.push(message); + break; + default: + generic_ack_queue.push(message); + break; + } + } + break; + + default: + throw std::runtime_error(err_msg("non-expected message type")); + } + } else { + print_err_msg("thread on socket " + std::to_string(socket_fd) + + " received an erroneous message in " + + std::to_string(msec/1.0e6) + " seconds!"); + } + this->status = STOPPED; +} +eddl_message * TCP_Receiver::ActiveThread::receive_message() +{ + uint32_t type; + char msg_id[eddl_msg_id_len+1]; + size_t size_in_bytes; + size_t block_size = eddl_default_mtu; + uint32_t source_addr, target_addr; + unsigned char checksum[eddl_checksum_len]; + + ssize_t n, s; + size_t l; + + eddl_message * message = nullptr; + + try { + // receive message type + s = l = sizeof(type); + n = read(socket_fd, &type, l); + if (n != s) { print_err_msg("message type read failed."); return nullptr; } + + // receive message id + memset(msg_id, 0, 
eddl_msg_id_len+1); + s = l = eddl_msg_id_len; + n = read(socket_fd, msg_id, l); + if (n != s) { print_err_msg("message id read failed."); return nullptr; } + + // receive message sender s_addr + s = l = sizeof(source_addr); + n = read(socket_fd, &source_addr, l); + if (n != s) { print_err_msg("message sender s_addr read failed."); return nullptr; } + + // receive message target s_addr + s = l = sizeof(target_addr); + n = read(socket_fd, &target_addr, l); + if (n != s) { print_err_msg("message receiver s_addr read failed."); return nullptr; } + + // receive message checksum + memset(checksum, 0, eddl_checksum_len); // otherwise valgrind reports a warning + s = l = eddl_checksum_len; + n = read(socket_fd, checksum, l); + if (n != s) { print_err_msg("message checksum read failed."); return nullptr; } + + // receive message size_in_bytes + s = l = sizeof(size_in_bytes); + n = read(socket_fd, &size_in_bytes, l); + if (n != s) { print_err_msg("message size in bytes read failed."); return nullptr; } + + message = new eddl_message(type, + source_addr, + target_addr, + size_in_bytes, + eddl_packet_data_size, + nullptr); + + message->set_message_id(msg_id); + message->set_checksum(checksum); + + char * ptr = (char *)message->get_data(); + size_t bytes_received=0; + + /* + if (verbose_level >= 1) + print_log_msg("receiving a data message of " + std::to_string(size_in_bytes) + " bytes"); + */ + + while( bytes_received < size_in_bytes ) { + + l = size_in_bytes - bytes_received; + l = (l < block_size) ? 
l : block_size; + + n = read(socket_fd, ptr, l); + if (n < 0) + throw std::runtime_error(err_msg(std::string("read data block failed:") + + std::to_string(errno) + ":" + strerror(errno))); + + if (n == 0) break; + + bytes_received += n; + ptr += n; + } + } + catch(std::exception & e) { + print_err_msg("an exception ocurred: " + std::string(e.what())); + delete message; + return nullptr; + } + + if (message->is_checksum_valid()) { + return message; + } else { + delete message; + return nullptr; + } +} + +}; diff --git a/src/distributed/communications/tcp_sender.cpp b/src/distributed/communications/tcp_sender.cpp new file mode 100644 index 000000000..30b488594 --- /dev/null +++ b/src/distributed/communications/tcp_sender.cpp @@ -0,0 +1,298 @@ +/* + * EDDL Library - European Distributed Deep Learning Library. + * Version: x.y + * copyright (c) 2020, Universitat Politècnica de València (UPV), PRHLT Research Centre + * Date: July 2020 + * Author: PRHLT Research Centre, UPV, (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) + * All rights reserved + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace eddl { + + +TCP_Sender::TCP_Sender(eddl_queue & output_queue, + eddl_queue & ack_queue, + DistributedEnvironment & distributed_environment) : + output_queue(output_queue), + ack_queue(ack_queue), + distributed_environment(distributed_environment) +{ + change_status_to(NORMAL_OPERATION); + sender_active = true; + sender_thread = std::thread( & TCP_Sender::sender, this); +} + +void TCP_Sender::change_status_to(int new_status) +{ + this->sender_status = new_status; + this->timestamp_last_status_change = get_system_milliseconds(); +} + +TCP_Sender::~TCP_Sender() +{ + stop(); + sender_thread.join(); + + /* + see the code of eddl_queue, it wipes itself by deleting the pending messages, + so it is not necessary to delete the messages in any of the three queues: + output_queue + ack_queue + queue_of_pending_messages + 
*/ + + for (auto &x : sent_messages) { + print_err_msg("deleting message with id " + + x.first + + " pending to be acknowledged."); + delete x.second; + } + sent_messages.clear(); +} +void TCP_Sender::stop() +{ + sender_active = false; + queue_of_pending_messages.clear(); + // the output queue must be cleared in the main thread of a worker or the master +} + +void TCP_Sender::sender() +{ + while (sender_active) { + eddl_message * message = nullptr; + + if (sender_status == NORMAL_OPERATION) { + // messages in the queue of pending messages have more priority + while (sender_active && nullptr == message && ! queue_of_pending_messages.empty()) { + // the pop() method blocks and waits until data is ready + message = queue_of_pending_messages.pop(); + if (get_system_milliseconds() - message->get_timestamp() > 50000 /* 50 seconds */) { + // too old messages are dropped + print_err_msg("dropping too old message " + message->get_message_id() + + " from the queue of pending messages!"); + delete message; + message = nullptr; + } + } + // only gets messages from the output queue if the queue of pending messages is empty + if (nullptr == message && ! output_queue.empty()) { + // the pop() method blocks and waits until data is ready + message = output_queue.pop(); + } + + if (nullptr != message) { + manage_to_send_message(message); + } else { + if (ack_queue.empty()) + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } else { + if (get_system_milliseconds()-this->timestamp_last_status_change > 1000) { + change_status_to(NORMAL_OPERATION); + } + } + while (! 
ack_queue.empty()) { + eddl_message * ack = ack_queue.pop(); + std::string msg_id = ack->get_acknowledged_message_id(); + if (sent_messages.count(msg_id) > 0) { + delete sent_messages.at(msg_id); + sent_messages.erase(msg_id); + /* + if (verbose_level >= 2) + print_log_msg("sent_messages[" + msg_id + "] ERASED" + + " |sent_messages| = " + std::to_string(sent_messages.size())); + */ + } + delete ack; + } + } + print_log_msg("sender thread stopped normally."); +} + +void TCP_Sender::manage_to_send_message(eddl_message * message) +{ + if (send_message(message)) { + /* + a sent message is maintained in sent_messages map in order to wait + for the corresponding acknowledgement, when the acknowledgement + of a message is received, then the message is removed + */ + std::string msg_id; + switch (message->get_type()) { + case eddl_message_types::DATA_SAMPLES: + case eddl_message_types::DATA_GRADIENTS: + case eddl_message_types::DATA_WEIGHTS: + msg_id = message->get_message_id(); + /* + if (verbose_level >= 2) + print_err_msg("sent_messages[" + msg_id + "] INSERTED" + + " |sent_messages| = " + std::to_string(sent_messages.size()+1)); + */ + if (sent_messages.count(msg_id) > 0) { + throw std::runtime_error(err_msg("recently sent message already existed in the map.")); + } + sent_messages[msg_id] = message; + break; + + default : + delete message; + break; + } + } else if (sender_active) { + queue_of_pending_messages.push(message); + } else { + delete message; + } +} +bool TCP_Sender::send_message(eddl_message * message) +{ + int socket_fd = socket(AF_INET, SOCK_STREAM, 0); + if (socket_fd < 0) + throw std::runtime_error(err_msg("socket cannot be created.")); + + struct sockaddr_in peer_addr; + + memset(&peer_addr, 0, sizeof(struct sockaddr_in)); + peer_addr.sin_family = AF_INET; + peer_addr.sin_addr.s_addr = message->get_target_addr(); + peer_addr.sin_port = htons(distributed_environment.get_tcp_port()); + + /* + if (verbose_level >= 1) + print_log_msg("trying to connect to 
" + inet_ntoa(peer_addr.sin_addr)); + */ + + if (connect(socket_fd, (const sockaddr *)&peer_addr, sizeof(peer_addr)) < 0) { + close(socket_fd); + print_err_msg("failed to connect."); + change_status_to(FAILED_TO_CONNECT); + return false; + } + + message->set_source_addr(distributed_environment.get_my_s_addr()); + // compulsory to compute again the message id every time source addr is updated + message->set_message_id(); + + uint32_t type = message->get_type(); + char msg_id[eddl_msg_id_len+1]; strncpy(msg_id, message->get_message_id().c_str(), eddl_msg_id_len); + uint32_t source_addr = distributed_environment.get_my_s_addr(); + uint32_t target_addr = message->get_target_addr(); + unsigned char * checksum = message->get_checksum_ptr(); + size_t size_in_bytes = message->get_message_data_size(); + + ssize_t n, s; + size_t l; + + try { + // send the message type + s = l = sizeof(type); + n = write(socket_fd, &type, l); + if (n != s) { + close(socket_fd); + print_err_msg("failed to send message type."); + change_status_to(FAILED_TO_WRITE); + return false; + } + + // send the message id + s = l = eddl_msg_id_len; + n = write(socket_fd, &msg_id, l); + if (n != s) { + close(socket_fd); + print_err_msg("failed to send message id."); + change_status_to(FAILED_TO_WRITE); + return false; + } + + // send the message sender s_addr + s = l = sizeof(source_addr); + n = write(socket_fd, &source_addr, l); + if (n != s) { + close(socket_fd); + print_err_msg("failed to send sender s_addr."); + change_status_to(FAILED_TO_WRITE); + return false; + } + + // send the message receiver s_addr + s = l = sizeof(target_addr); + n = write(socket_fd, &target_addr, l); + if (n != s) { + close(socket_fd); + print_err_msg("failed to send receiver s_addr."); + change_status_to(FAILED_TO_WRITE); + return false; + } + + // send the message checksum + s = l = eddl_checksum_len; + n = write(socket_fd, checksum, l); + if (n != s) { + close(socket_fd); + print_err_msg("failed to message checksum."); + 
change_status_to(FAILED_TO_WRITE); + return false; + } + + // send the message size in bytes + s = l = sizeof(size_in_bytes); + n = write(socket_fd, &size_in_bytes, l); + if (n != s) { + close(socket_fd); + print_err_msg("failed to message size."); + change_status_to(FAILED_TO_WRITE); + return false; + } + + // send the message data + size_t block_size = eddl_default_mtu; + size_t pending = size_in_bytes; + char * ptr = (char *)message->get_data(); + + while( pending > 0 ) { + + s = l = (pending < block_size) ? pending : block_size; + + n = write(socket_fd, ptr, l); + if (n < 0) { + std::string str = "n = " + std::to_string(n) + + " errno = " + std::to_string(errno) + + " " + strerror(errno) + + " ptr = " + pointer_to_string(ptr) + + " bytes received = " + std::to_string(ptr - (char *)message->get_data()); + print_err_msg("write failed " + str); + close(socket_fd); + change_status_to(FAILED_TO_WRITE); + return false; + } + + if (n < s) std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + pending -= n; + ptr += n; + } + } + catch(std::exception & e) { + print_err_msg("exception ocurred: " + std::string(e.what())); + close(socket_fd); + change_status_to(FAILED_TO_WRITE); + return false; + } + close(socket_fd); + return true; +} + +}; diff --git a/src/distributed/images/EDDL-distributed-schema.png b/src/distributed/images/EDDL-distributed-schema.png new file mode 100644 index 000000000..f78496d8c Binary files /dev/null and b/src/distributed/images/EDDL-distributed-schema.png differ diff --git a/src/distributed/images/Master-Node.png b/src/distributed/images/Master-Node.png new file mode 100644 index 000000000..a5502a489 Binary files /dev/null and b/src/distributed/images/Master-Node.png differ diff --git a/src/distributed/images/Timeline-in-master-and-worker-nodes.png b/src/distributed/images/Timeline-in-master-and-worker-nodes.png new file mode 100644 index 000000000..ed1a4ff91 Binary files /dev/null and 
b/src/distributed/images/Timeline-in-master-and-worker-nodes.png differ diff --git a/src/distributed/images/Worker-Node.png b/src/distributed/images/Worker-Node.png new file mode 100644 index 000000000..790b3ac9a Binary files /dev/null and b/src/distributed/images/Worker-Node.png differ diff --git a/src/distributed/images/hybrid-graph-2.svg.png b/src/distributed/images/hybrid-graph-2.svg.png new file mode 100644 index 000000000..52fb653a5 Binary files /dev/null and b/src/distributed/images/hybrid-graph-2.svg.png differ diff --git a/src/distributed/images/logo-eddl.png b/src/distributed/images/logo-eddl.png new file mode 100644 index 000000000..b3aa5d895 Binary files /dev/null and b/src/distributed/images/logo-eddl.png differ diff --git a/src/hardware/cpu/cpu_math.cpp b/src/hardware/cpu/cpu_math.cpp index 4e2c7d900..90b62e937 100644 --- a/src/hardware/cpu/cpu_math.cpp +++ b/src/hardware/cpu/cpu_math.cpp @@ -699,124 +699,3 @@ float cpu_median(float *ptr, int size, int *map) { delete[] sorted_data; return median; } - -void cpu_batchnorm_forward(int b, int z, int rc, - float *input, float *output, float *opa, - float *global_mean, float *global_variance, - float *affine_g, float *affine_b, - float *mean, float *variance, - bool trmode, float epsilon, float momentum) -{ - const int block_size = 256; - int rcz = rc * z; - if (trmode) { - // compute mean and variance - for (int j = 0; j < z; j++) mean[j] = variance[j] = 0.0; - #pragma omp parallel for - for (int k = 0; k < rcz; k += block_size) - for (int i = 0; i < b; i++) { - int p = k + i * rcz; - for (int l = 0; l < block_size && k + l < rcz; l++, p++) { - int j = (k + l) / rc; - mean[j] += input[p]; - variance[j] += input[p] * input[p]; - } - } - float N = b * rc; - #pragma omp parallel for - for (int j = 0; j < z; j++) { - mean[j] = mean[j] / N; - variance[j] = variance[j] / N - mean[j] * mean[j]; - // update global statistics - if (momentum != 0.0) { - global_mean[j] = momentum * global_mean[j] + (1.0 - momentum) * 
mean[j]; - global_variance[j] = momentum * global_variance[j] + (1.0 - momentum) * variance[j]; - } - variance[j] = sqrt(variance[j] + epsilon); - } - } else { - // just update variance - mean = global_mean; - #pragma omp parallel for - for (int j = 0; j < z; j++) { - variance[j] = sqrt(global_variance[j] + epsilon); - } - } - // normalization - #pragma omp parallel for - for (int k = 0; k < rcz; k += block_size) - for (int i = 0; i < b; i++) { - int p = k + i * rcz; - for (int l = 0; l < block_size && k + l < rcz; l++, p++) { - int j = (k + l) / rc; - float o = (input[p] - mean[j]) / variance[j]; - // affine transformation - if (affine_g != NULL) { - opa[p] = o; - output[p] = o * affine_g[j] + affine_b[j]; - } else output[p] = o; - } - } -} - -void cpu_batchnorm_backward(int b, int z, int rc, float *delta, float *opa, float *pdelta, float *gbn_g, float *gbn_b, float *bn_g, float *variance, float *mean1, float *mean2) -{ - const int block_size = 256; - int rcz = rc * z; - float N = b * rc; - if (bn_g != NULL) { // affine - // compute mean - for (int j = 0; j < z; j++) mean1[j] = mean2[j] = 0.0; - #pragma omp parallel for - for (int k = 0; k < rcz; k += block_size) - for (int i = 0; i < b; i++) { - int p = k + i * rcz; - for (int l = 0; l < block_size && k + l < rcz; l++, p++) { - int j = (k + l) / rc; - mean1[j] += delta[p] * opa[p]; - mean2[j] += delta[p]; - delta[p] *= bn_g[j]; - } - } - #pragma omp parallel for - for (int j = 0; j < z; j++) { - mean1[j] /= N; - mean2[j] /= N; - gbn_g[j] += mean1[j]; - gbn_b[j] += mean2[j]; - mean1[j] *= bn_g[j]; - mean2[j] *= bn_g[j]; - } - } else { - // compute mean - for (int j = 0; j < z; j++) mean1[j] = mean2[j] = 0.0; - #pragma omp parallel for - for (int k = 0; k < rcz; k += block_size) - for (int i = 0; i < b; i++) { - int p = k + i * rcz; - for (int l = 0; l < block_size && k + l < rcz; l++, p++) { - int j = (k + l) / rc; - mean1[j] += delta[p] * opa[p]; // step 1 & 2 - mean2[j] += delta[p]; // step 4 - } - } - #pragma 
omp parallel for - for (int j = 0; j < z; j++) { - mean1[j] /= N; - mean2[j] /= N; - } - } - #pragma omp parallel for - for (int k = 0; k < rcz; k += block_size) - for (int i = 0; i < b; i++) { - int p = k + i * rcz; - for (int l = 0; l < block_size && k + l < rcz; l++, p++) { - int j = (k + l) / rc; - // opa[p] = opa[p] * mean1[j] + mean2[j]; // step 3 & 5 - // delta[p] -= opa[p]; // step 6 - // delta[p] /= variance[j]; // step 7 - // pdelta[p] += delta[p]; - pdelta[p] += (delta[p] - (opa[p] * mean1[j] + mean2[j])) / variance[j]; - } - } -} diff --git a/src/hardware/cpu/nn/cpu_bn.cpp b/src/hardware/cpu/nn/cpu_bn.cpp index 5c74afb23..ae2b8c6a1 100644 --- a/src/hardware/cpu/nn/cpu_bn.cpp +++ b/src/hardware/cpu/nn/cpu_bn.cpp @@ -113,3 +113,124 @@ void cpu_permute_batch_first(Tensor *A,Tensor *B) _profile(_CPU_PERMUTE_BATCH_FIRST, 1); } + +void cpu_batchnorm_forward(int b, int z, int rc, + float *input, float *output, float *opa, + float *global_mean, float *global_variance, + float *affine_g, float *affine_b, + float *mean, float *variance, + bool trmode, float epsilon, float momentum) +{ + const int block_size = 256; + int rcz = rc * z; + if (trmode) { + // compute mean and variance + for (int j = 0; j < z; j++) mean[j] = variance[j] = 0.0; + #pragma omp parallel for + for (int k = 0; k < rcz; k += block_size) + for (int i = 0; i < b; i++) { + int p = k + i * rcz; + for (int l = 0; l < block_size && k + l < rcz; l++, p++) { + int j = (k + l) / rc; + mean[j] += input[p]; + variance[j] += input[p] * input[p]; + } + } + float N = b * rc; + #pragma omp parallel for + for (int j = 0; j < z; j++) { + mean[j] = mean[j] / N; + variance[j] = variance[j] / N - mean[j] * mean[j]; + // update global statistics + if (momentum != 0.0) { + global_mean[j] = momentum * global_mean[j] + (1.0 - momentum) * mean[j]; + global_variance[j] = momentum * global_variance[j] + (1.0 - momentum) * variance[j]; + } + variance[j] = sqrt(variance[j] + epsilon); + } + } else { + // just update 
variance + mean = global_mean; + #pragma omp parallel for + for (int j = 0; j < z; j++) { + variance[j] = sqrt(global_variance[j] + epsilon); + } + } + // normalization + #pragma omp parallel for + for (int k = 0; k < rcz; k += block_size) + for (int i = 0; i < b; i++) { + int p = k + i * rcz; + for (int l = 0; l < block_size && k + l < rcz; l++, p++) { + int j = (k + l) / rc; + float o = (input[p] - mean[j]) / variance[j]; + // affine transformation + if (affine_g != NULL) { + opa[p] = o; + output[p] = o * affine_g[j] + affine_b[j]; + } else output[p] = o; + } + } +} + +void cpu_batchnorm_backward(int b, int z, int rc, float *delta, float *opa, float *pdelta, float *gbn_g, float *gbn_b, float *bn_g, float *variance, float *mean1, float *mean2) +{ + const int block_size = 256; + int rcz = rc * z; + float N = b * rc; + if (bn_g != NULL) { // affine + // compute mean + for (int j = 0; j < z; j++) mean1[j] = mean2[j] = 0.0; + #pragma omp parallel for + for (int k = 0; k < rcz; k += block_size) + for (int i = 0; i < b; i++) { + int p = k + i * rcz; + for (int l = 0; l < block_size && k + l < rcz; l++, p++) { + int j = (k + l) / rc; + mean1[j] += delta[p] * opa[p]; + mean2[j] += delta[p]; + delta[p] *= bn_g[j]; + } + } + #pragma omp parallel for + for (int j = 0; j < z; j++) { + mean1[j] /= N; + mean2[j] /= N; + gbn_g[j] += mean1[j]; + gbn_b[j] += mean2[j]; + mean1[j] *= bn_g[j]; + mean2[j] *= bn_g[j]; + } + } else { + // compute mean + for (int j = 0; j < z; j++) mean1[j] = mean2[j] = 0.0; + #pragma omp parallel for + for (int k = 0; k < rcz; k += block_size) + for (int i = 0; i < b; i++) { + int p = k + i * rcz; + for (int l = 0; l < block_size && k + l < rcz; l++, p++) { + int j = (k + l) / rc; + mean1[j] += delta[p] * opa[p]; // step 1 & 2 + mean2[j] += delta[p]; // step 4 + } + } + #pragma omp parallel for + for (int j = 0; j < z; j++) { + mean1[j] /= N; + mean2[j] /= N; + } + } + #pragma omp parallel for + for (int k = 0; k < rcz; k += block_size) + for (int i = 
0; i < b; i++) { + int p = k + i * rcz; + for (int l = 0; l < block_size && k + l < rcz; l++, p++) { + int j = (k + l) / rc; + // opa[p] = opa[p] * mean1[j] + mean2[j]; // step 3 & 5 + // delta[p] -= opa[p]; // step 6 + // delta[p] /= variance[j]; // step 7 + // pdelta[p] += delta[p]; + pdelta[p] += (delta[p] - (opa[p] * mean1[j] + mean2[j])) / variance[j]; + } + } +} diff --git a/src/hardware/gpu/gpu_math.cu b/src/hardware/gpu/gpu_math.cu index aa9179649..86a816949 100644 --- a/src/hardware/gpu/gpu_math.cu +++ b/src/hardware/gpu/gpu_math.cu @@ -927,57 +927,3 @@ void gpu_initialize_rd(ReduceDescriptor2 *rd, Tensor *A, Tensor *B, bool reverse } } } - -void gpu_batchnorm_forward(int gpu_device, int b, int z, int rc, - float *input, float *output, float *opa, - float *global_mean, float *global_variance, - float *affine_g, float *affine_b, - float *mean, float *variance, - bool trmode, float epsilon, float momentum) -{ - cudaSetDevice(gpu_device); - int rcz = rc * z; - int num_blocks = rcz / batch_norm_block_size; - if (rcz % batch_norm_block_size) num_blocks++; - int num_blocks_z = z / batch_norm_block_size; - if (z % batch_norm_block_size) num_blocks_z++; - if (trmode) { - // compute mean and variance - // for (int j = 0; j < z; j++) mean[j] = variance[j] = 0.0; - check_cuda(cudaMemset(mean, 0, z * sizeof(float)), "gpu_batchnorm_forward"); - check_cuda(cudaMemset(variance, 0, z * sizeof(float)), "gpu_batchnorm_forward"); - // compute mean and variance - gpu_batchnorm_forward_1<<>>(b, rc, rcz, input, mean, variance); - gpu_batchnorm_forward_2<<>>(z, 1.0 / (b * rc), mean, variance, momentum, global_mean, global_variance, epsilon); - // normalization - gpu_batchnorm_forward_3<<>>(b, rc, rcz, input, mean, variance, affine_g, affine_b, opa, output); - } else { - gpu_batchnorm_forward_2<<>>(z, 1.0 / (b * rc), NULL, variance, momentum, NULL, global_variance, epsilon); - // normalization - gpu_batchnorm_forward_3<<>>(b, rc, rcz, input, mean, variance, affine_g, affine_b, 
opa, output); - } -} - -void gpu_batchnorm_backward(int gpu_device, int b, int z, int rc, float *delta, float *opa, float *pdelta, float *gbn_g, float *gbn_b, float *bn_g, float *variance, float *mean1, float *mean2) -{ - cudaSetDevice(gpu_device); - int rcz = rc * z; - int num_blocks = rcz / batch_norm_block_size; - if (rcz % batch_norm_block_size) num_blocks++; - int num_blocks_z = z / batch_norm_block_size; - if (z % batch_norm_block_size) num_blocks_z++; - float N = b * rc; - // for (int j = 0; j < z; j++) mean1[j] = mean2[j] = 0.0; - check_cuda(cudaMemset(mean1, 0, z * sizeof(float)), "gpu_batchnorm_backward"); - check_cuda(cudaMemset(mean2, 0, z * sizeof(float)), "gpu_batchnorm_backward"); - if (bn_g != NULL) { - // compute mean - gpu_batchnorm_backward_1<<>>(b, rc, rcz, delta, opa, bn_g, mean1, mean2); - gpu_batchnorm_backward_2<<>>(z, 1.0 / (b * rc), mean1, mean2, gbn_g, gbn_b, bn_g); - } else { - // compute mean - gpu_batchnorm_backward_1<<>>(b, rc, rcz, delta, opa, NULL, mean1, mean2); - gpu_batchnorm_backward_2<<>>(z, 1.0 / (b * rc), mean1, mean2, NULL, NULL, NULL); - } - gpu_batchnorm_backward_3<<>>(b, rc, rcz, delta, opa, pdelta, mean1, mean2, variance); -} diff --git a/src/hardware/gpu/gpu_math_kernels.cu b/src/hardware/gpu/gpu_math_kernels.cu index 1cca4010e..02ec65d74 100644 --- a/src/hardware/gpu/gpu_math_kernels.cu +++ b/src/hardware/gpu/gpu_math_kernels.cu @@ -387,118 +387,3 @@ __global__ void gpu_minimum(float* A, float* B, float* C, long int size){ C[thread_id_x] = min(A[thread_id_x], B[thread_id_x]); } } - -// new batchnorm implementation - -__global__ void gpu_batchnorm_forward_1(int b, int rc, int rcz, float *input, float *mean, float *variance) -{ - // for (int k = 0; k < rcz; k += batch_norm_block_size) - int k = blockIdx.x * batch_norm_block_size + threadIdx.x; - if (k < rcz) { - int j = k / rc; - float m = 0, v = 0; - for (int i = 0, p = k; i < b; i++, p += rcz) { - // for (int l = 0; l < batch_norm_block_size && k + l < rcz; l++, p++) { 
- float x = input[p]; - m += x; - v += x * x; - } - atomicAdd(mean + j, m); - atomicAdd(variance + j, v); - } -} - -__global__ void gpu_batchnorm_forward_2(int z, float inv_N, float *mean, float *variance, float momentum, float *global_mean, float *global_variance, float epsilon) -{ - // for (int j = 0; j < z; j++) { - int j = blockIdx.x * batch_norm_block_size + threadIdx.x; - if (j < z) { - if (mean != NULL) { - mean[j] *= inv_N; - variance[j] = variance[j] * inv_N - mean[j] * mean[j]; - // update global statistics - if (momentum != 0.0) { - global_mean[j] = momentum * global_mean[j] + (1.0 - momentum) * mean[j]; - global_variance[j] = momentum * global_variance[j] + (1.0 - momentum) * variance[j]; - } - variance[j] = 1.0f / sqrt(variance[j] + epsilon); - } else { - variance[j] = 1.0f / sqrt(global_variance[j] + epsilon); - } - } -} - -__global__ void gpu_batchnorm_forward_3(int b, int rc, int rcz, float *input, float *mean, float *variance, float *affine_g, float *affine_b, float *opa, float *output) -{ - // for (int k = 0; k < rcz; k += batch_norm_block_size) - int k = blockIdx.x * batch_norm_block_size + threadIdx.x; - if (k < rcz) { - int j = k / rc; - float m = mean[j]; - float v = variance[j]; - for (int i = 0, p = k; i < b; i++, p += rcz) { - // for (int l = 0; l < batch_norm_block_size && k + l < rcz; l++, p++) { - float o = (input[p] - m) * v; - // affine transformation - if (affine_g != NULL) { - opa[p] = o; - output[p] = o * affine_g[j] + affine_b[j]; - } else output[p] = o; - } - } -} - -__global__ void gpu_batchnorm_backward_1(int b, int rc, int rcz, float *delta, float *opa, float *bn_g, float *mean1, float *mean2) -{ - // for (int k = 0; k < rcz; k += batch_norm_block_size) - int k = blockIdx.x * batch_norm_block_size + threadIdx.x; - if (k < rcz) { - int j = k / rc; - float m1 = 0, m2 = 0; - for (int i = 0, p = k; i < b; i++, p += rcz) { - // for (int l = 0; l < batch_norm_block_size && k + l < rcz; l++, p++) { - m1 += delta[p] * opa[p]; // step 1 
& 2 - m2 += delta[p]; // step 4 - if (bn_g != NULL) delta[p] *= bn_g[j]; // affine - } - atomicAdd(mean1 + j, m1); - atomicAdd(mean2 + j, m2); - } -} - -__global__ void gpu_batchnorm_backward_2(int z, float inv_N, float *mean1, float *mean2, float *gbn_g, float *gbn_b, float *bn_g) -{ - // for (int j = 0; j < z; j++) { - int j = blockIdx.x * batch_norm_block_size + threadIdx.x; - if (j < z) { - if (bn_g != NULL) { // affine - float m1 = mean1[j] * inv_N; - float m2 = mean2[j] * inv_N; - gbn_g[j] += m1; - gbn_b[j] += m2; - mean1[j] = m1 * bn_g[j]; - mean2[j] = m2 * bn_g[j]; - } else { - mean1[j] *= inv_N; - mean2[j] *= inv_N; - } - } -} - -__global__ void gpu_batchnorm_backward_3(int b, int rc, int rcz, float *delta, float *opa, float *pdelta, float *mean1, float *mean2, float *variance) -{ - // for (int k = 0; k < rcz; k += batch_norm_block_size) - int k = blockIdx.x * batch_norm_block_size + threadIdx.x; - if (k < rcz) { - int j = k / rc; - for (int i = 0, p = k; i < b; i++, p += rcz) { - // for (int l = 0; l < batch_norm_block_size && k + l < rcz; l++, p++) { - float o = opa[p] * mean1[j] + mean2[j]; // step 3 & 5 - // opa[p] = o; - float d = delta[p] - o; // step 6 - d = d / variance[j]; // step 7 - // delta[p] = d; - pdelta[p] += d; - } - } -} diff --git a/src/hardware/gpu/nn/gpu_bn.cu b/src/hardware/gpu/nn/gpu_bn.cu index b178760c5..f156ca5ae 100644 --- a/src/hardware/gpu/nn/gpu_bn.cu +++ b/src/hardware/gpu/nn/gpu_bn.cu @@ -62,23 +62,56 @@ void gpu_permute_batch_first(Tensor *A,Tensor *B) check_cuda(cudaDeviceSynchronize(),"bn_permute_batch_first"); } +void gpu_batchnorm_forward(int gpu_device, int b, int z, int rc, + float *input, float *output, float *opa, + float *global_mean, float *global_variance, + float *affine_g, float *affine_b, + float *mean, float *variance, + bool trmode, float epsilon, float momentum) +{ + cudaSetDevice(gpu_device); + int rcz = rc * z; + int num_blocks = rcz / batch_norm_block_size; + if (rcz % batch_norm_block_size) 
num_blocks++; + int num_blocks_z = z / batch_norm_block_size; + if (z % batch_norm_block_size) num_blocks_z++; + if (trmode) { + // compute mean and variance + // for (int j = 0; j < z; j++) mean[j] = variance[j] = 0.0; + check_cuda(cudaMemset(mean, 0, z * sizeof(float)), "gpu_batchnorm_forward"); + check_cuda(cudaMemset(variance, 0, z * sizeof(float)), "gpu_batchnorm_forward"); + // compute mean and variance + gpu_batchnorm_forward_1<<>>(b, rc, rcz, input, mean, variance); + gpu_batchnorm_forward_2<<>>(z, 1.0 / (b * rc), mean, variance, momentum, global_mean, global_variance, epsilon); + // normalization + gpu_batchnorm_forward_3<<>>(b, rc, rcz, input, mean, variance, affine_g, affine_b, opa, output); + } else { + gpu_batchnorm_forward_2<<>>(z, 1.0 / (b * rc), NULL, variance, momentum, NULL, global_variance, epsilon); + // normalization + gpu_batchnorm_forward_3<<>>(b, rc, rcz, input, mean, variance, affine_g, affine_b, opa, output); + } +} - - - - - - - - - - - - - - - - - - -///////// +void gpu_batchnorm_backward(int gpu_device, int b, int z, int rc, float *delta, float *opa, float *pdelta, float *gbn_g, float *gbn_b, float *bn_g, float *variance, float *mean1, float *mean2) +{ + cudaSetDevice(gpu_device); + int rcz = rc * z; + int num_blocks = rcz / batch_norm_block_size; + if (rcz % batch_norm_block_size) num_blocks++; + int num_blocks_z = z / batch_norm_block_size; + if (z % batch_norm_block_size) num_blocks_z++; + float N = b * rc; + // for (int j = 0; j < z; j++) mean1[j] = mean2[j] = 0.0; + check_cuda(cudaMemset(mean1, 0, z * sizeof(float)), "gpu_batchnorm_backward"); + check_cuda(cudaMemset(mean2, 0, z * sizeof(float)), "gpu_batchnorm_backward"); + if (bn_g != NULL) { + // compute mean + gpu_batchnorm_backward_1<<>>(b, rc, rcz, delta, opa, bn_g, mean1, mean2); + gpu_batchnorm_backward_2<<>>(z, 1.0 / (b * rc), mean1, mean2, gbn_g, gbn_b, bn_g); + } else { + // compute mean + gpu_batchnorm_backward_1<<>>(b, rc, rcz, delta, opa, NULL, mean1, mean2); + 
gpu_batchnorm_backward_2<<>>(z, 1.0 / (b * rc), mean1, mean2, NULL, NULL, NULL); + } + gpu_batchnorm_backward_3<<>>(b, rc, rcz, delta, opa, pdelta, mean1, mean2, variance); +} diff --git a/src/hardware/gpu/nn/gpu_bn_kernels.cu b/src/hardware/gpu/nn/gpu_bn_kernels.cu index 9b380704c..ba67e2884 100644 --- a/src/hardware/gpu/nn/gpu_bn_kernels.cu +++ b/src/hardware/gpu/nn/gpu_bn_kernels.cu @@ -86,3 +86,118 @@ __global__ void bn_permute_batch_first(float *src, float *dest,int b,int z,int r dest[thread_id_x]=src[pos]; } } + +// new batchnorm implementation + +__global__ void gpu_batchnorm_forward_1(int b, int rc, int rcz, float *input, float *mean, float *variance) +{ + // for (int k = 0; k < rcz; k += batch_norm_block_size) + int k = blockIdx.x * batch_norm_block_size + threadIdx.x; + if (k < rcz) { + int j = k / rc; + float m = 0, v = 0; + for (int i = 0, p = k; i < b; i++, p += rcz) { + // for (int l = 0; l < batch_norm_block_size && k + l < rcz; l++, p++) { + float x = input[p]; + m += x; + v += x * x; + } + atomicAdd(mean + j, m); + atomicAdd(variance + j, v); + } +} + +__global__ void gpu_batchnorm_forward_2(int z, float inv_N, float *mean, float *variance, float momentum, float *global_mean, float *global_variance, float epsilon) +{ + // for (int j = 0; j < z; j++) { + int j = blockIdx.x * batch_norm_block_size + threadIdx.x; + if (j < z) { + if (mean != NULL) { + mean[j] *= inv_N; + variance[j] = variance[j] * inv_N - mean[j] * mean[j]; + // update global statistics + if (momentum != 0.0) { + global_mean[j] = momentum * global_mean[j] + (1.0 - momentum) * mean[j]; + global_variance[j] = momentum * global_variance[j] + (1.0 - momentum) * variance[j]; + } + variance[j] = 1.0f / sqrt(variance[j] + epsilon); + } else { + variance[j] = 1.0f / sqrt(global_variance[j] + epsilon); + } + } +} + +__global__ void gpu_batchnorm_forward_3(int b, int rc, int rcz, float *input, float *mean, float *variance, float *affine_g, float *affine_b, float *opa, float *output) +{ + // 
for (int k = 0; k < rcz; k += batch_norm_block_size) + int k = blockIdx.x * batch_norm_block_size + threadIdx.x; + if (k < rcz) { + int j = k / rc; + float m = mean[j]; + float v = variance[j]; + for (int i = 0, p = k; i < b; i++, p += rcz) { + // for (int l = 0; l < batch_norm_block_size && k + l < rcz; l++, p++) { + float o = (input[p] - m) * v; + // affine transformation + if (affine_g != NULL) { + opa[p] = o; + output[p] = o * affine_g[j] + affine_b[j]; + } else output[p] = o; + } + } +} + +__global__ void gpu_batchnorm_backward_1(int b, int rc, int rcz, float *delta, float *opa, float *bn_g, float *mean1, float *mean2) +{ + // for (int k = 0; k < rcz; k += batch_norm_block_size) + int k = blockIdx.x * batch_norm_block_size + threadIdx.x; + if (k < rcz) { + int j = k / rc; + float m1 = 0, m2 = 0; + for (int i = 0, p = k; i < b; i++, p += rcz) { + // for (int l = 0; l < batch_norm_block_size && k + l < rcz; l++, p++) { + m1 += delta[p] * opa[p]; // step 1 & 2 + m2 += delta[p]; // step 4 + if (bn_g != NULL) delta[p] *= bn_g[j]; // affine + } + atomicAdd(mean1 + j, m1); + atomicAdd(mean2 + j, m2); + } +} + +__global__ void gpu_batchnorm_backward_2(int z, float inv_N, float *mean1, float *mean2, float *gbn_g, float *gbn_b, float *bn_g) +{ + // for (int j = 0; j < z; j++) { + int j = blockIdx.x * batch_norm_block_size + threadIdx.x; + if (j < z) { + if (bn_g != NULL) { // affine + float m1 = mean1[j] * inv_N; + float m2 = mean2[j] * inv_N; + gbn_g[j] += m1; + gbn_b[j] += m2; + mean1[j] = m1 * bn_g[j]; + mean2[j] = m2 * bn_g[j]; + } else { + mean1[j] *= inv_N; + mean2[j] *= inv_N; + } + } +} + +__global__ void gpu_batchnorm_backward_3(int b, int rc, int rcz, float *delta, float *opa, float *pdelta, float *mean1, float *mean2, float *variance) +{ + // for (int k = 0; k < rcz; k += batch_norm_block_size) + int k = blockIdx.x * batch_norm_block_size + threadIdx.x; + if (k < rcz) { + int j = k / rc; + for (int i = 0, p = k; i < b; i++, p += rcz) { + // for (int l = 0; l 
< batch_norm_block_size && k + l < rcz; l++, p++) { + float o = opa[p] * mean1[j] + mean2[j]; // step 3 & 5 + // opa[p] = o; + float d = delta[p] - o; // step 6 + d = d / variance[j]; // step 7 + // delta[p] = d; + pdelta[p] += d; + } + } +} diff --git a/src/hardware/gpu/nn/gpu_conv.cu b/src/hardware/gpu/nn/gpu_conv.cu index d45218c15..fa45aa313 100644 --- a/src/hardware/gpu/nn/gpu_conv.cu +++ b/src/hardware/gpu/nn/gpu_conv.cu @@ -350,13 +350,129 @@ void gpu_conv2D_back(ConvolDescriptor *D){ void gpu_conv3D(ConvolDescriptor3D *D){ + int device=D->I->gpu_device; + cudaSetDevice(device); +#ifdef cCUDNN + // FWD environment + float alpha = 1.0f; + float beta = 0.0f; + if (D->cudnn_env_init < 0){ + D->cudnn_env_init = 1; + + int requestedAlgoCount; + check_cudnn(cudnnGetConvolutionForwardAlgorithmMaxCount( hdnn[device], &requestedAlgoCount), + "cudnnGetConvolutionForwardAlgorithmMaxCount",__FILE__); + + int returnedAlgoCount; + cudnnConvolutionFwdAlgoPerf_t * perfResults = new cudnnConvolutionFwdAlgoPerf_t [requestedAlgoCount]; + check_cudnn(cudnnFindConvolutionForwardAlgorithm( hdnn[device], D->xDesc, D->wDesc, D->convolution_descriptor, D->yDesc, + requestedAlgoCount, &returnedAlgoCount, perfResults),"cudnnFindConvolutionForwardAlgorithm",__FILE__); + + int aux_alg = 0; + size_t size; + do{ + D->fwd_algorithm = perfResults[aux_alg].algo; + + check_cudnn(cudnnGetConvolutionForwardWorkspaceSize(hdnn[device],D->xDesc, D->wDesc, + D->convolution_descriptor, D->yDesc, + D->fwd_algorithm, &size), + "cudnnGetConvolutionForwardWorkspaceSize",__FILE__); + aux_alg++; + } + while(allocate_workspace(size,device)); + } + //BWD environment + if (D->cudnn_conv_back_init < 0){ + D->cudnn_conv_back_init = 1; + int requestedAlgoCount; + + check_cudnn(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + hdnn[device], &requestedAlgoCount),"cudnnGetConvolutionBackwardFilterAlgorithmMaxCount",__FILE__); + int returnedAlgoCount; + cudnnConvolutionBwdFilterAlgoPerf_t * perfResults = new 
cudnnConvolutionBwdFilterAlgoPerf_t [requestedAlgoCount]; + + check_cudnn(cudnnFindConvolutionBackwardFilterAlgorithm(hdnn[device], D->xDesc, D->yDesc, + D->convolution_descriptor, D->wDesc, requestedAlgoCount, + &returnedAlgoCount, perfResults),"cudnnFindConvolutionBackwardFilterAlgorithm",__FILE__); + int aux_alg = 0; + size_t size; + do{ + D->bwd_filter_algorithm = perfResults[aux_alg].algo; + + check_cudnn(cudnnGetConvolutionBackwardFilterWorkspaceSize(hdnn[device],D->xDesc, D->yDesc, + D->convolution_descriptor, D->wDesc, + D->bwd_filter_algorithm, &size),"cudnnGetConvolutionBackwardFilterWorkspaceSize",__FILE__); + aux_alg++; + } + while(allocate_workspace(size,device)); + //////////// DATA!!!! + requestedAlgoCount = 0; + check_cudnn(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(hdnn[device], &requestedAlgoCount),"cudnnGetConvolutionBackwardDataAlgorithmMaxCount", __FILE__); + returnedAlgoCount=0; + cudnnConvolutionBwdDataAlgoPerf_t * perfResults_d = new cudnnConvolutionBwdDataAlgoPerf_t [requestedAlgoCount]; + + check_cudnn(cudnnFindConvolutionBackwardDataAlgorithm(hdnn[device], D->wDesc, D->yDesc, + D->convolution_descriptor, D->xDesc, requestedAlgoCount, + &returnedAlgoCount, perfResults_d),"(cudnnFindConvolutionBackwardDataAlgorithm",__FILE__); + aux_alg = 0; + size=0; + do{ + D->bwd_data_algorithm = perfResults_d[aux_alg].algo; + + check_cudnn(cudnnGetConvolutionBackwardDataWorkspaceSize(hdnn[device],D->wDesc, D->yDesc, + D->convolution_descriptor, D->xDesc, + D->bwd_data_algorithm, &size),"cudnnGetConvolutionBackwardDataWorkspaceSize",__FILE__); + aux_alg++; + } + while(allocate_workspace(size,device)); + + } + + check_cudnn(cudnnConvolutionForward( hdnn[device], &alpha, D->xDesc, D->I->ptr, + D->wDesc, D->K->ptr, + D->convolution_descriptor, D->fwd_algorithm, + shared_workspace[device], workspace_size[device], + &beta, D->yDesc, D->O->ptr),"cudnnConvolutionForward",__FILE__); + if (D->use_bias) { + check_cudnn(cudnnAddTensor(hdnn[device], &alpha, 
D->bDesc, D->bias->ptr, + &alpha, D->yDesc, D->O->ptr),"cudnnAddTensor",__FILE__); + } + + +#endif } void gpu_conv3D_grad(ConvolDescriptor3D *D){ + int device=D->I->gpu_device; + cudaSetDevice(device); +#ifndef cCUDNN + check_cudnn(cudnnConvolutionBackwardFilter(hdnn[device], &alpha, + D->xDesc, D->I->ptr, + D->yDesc, D->D->ptr, D->convolution_descriptor, + D->bwd_filter_algorithm, + shared_workspace[device], workspace_size[device], + &beta, D->wDesc, D->gK->ptr),"cudnnConvolutionBackwardFilter",__FILE__); + if (D->use_bias) { + check_cudnn(cudnnConvolutionBackwardBias(hdnn[device], &alpha, D->yDesc, D->D->ptr, + &beta, D->bDesc, D->gbias->ptr),"cudnnConvolutionBackwardBias",__FILE__); + } +#endif } void gpu_conv3D_back(ConvolDescriptor3D *D){ + int device=D->I->gpu_device; + cudaSetDevice(device); +#ifdef cCUDNN + float alpha = 1.0f; + float beta = 0.0f; + check_cudnn(cudnnConvolutionBackwardData(hdnn[device], &alpha, D->wDesc, D->K->ptr, + D->yDesc, D->D->ptr, + D->convolution_descriptor, D->bwd_data_algorithm, + shared_workspace[device], workspace_size[device], + &beta, D->xDesc, D->ID->ptr),"cudnnConvolutionBackwardData",__FILE__); +#endif + } diff --git a/src/hardware/gpu/nn/gpu_losses.cu b/src/hardware/gpu/nn/gpu_losses.cu index b01869818..8437995be 100644 --- a/src/hardware/gpu/nn/gpu_losses.cu +++ b/src/hardware/gpu/nn/gpu_losses.cu @@ -13,7 +13,7 @@ #include //#include -#include +/* #include #include #include #include @@ -23,7 +23,7 @@ #include #include #include -#include +#include */ #include "eddl/hardware/gpu/nn/gpu_tensor_nn.h" #include "eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h" @@ -60,6 +60,7 @@ float gpu_categorical_cross_entropy(Tensor* y_true, Tensor* y_pred){ float *sum_array; check_cuda(cudaMalloc((void**)&(sum_array), n_batches*sizeof(float)),"create temp array"); + check_cuda(cudaMemset(sum_array, 0, sizeof(float)), "memset"); check_cuda(cudaDeviceSynchronize(), "create"); // Calculate derivative of Softmax @@ -67,8 +68,9 @@ float 
gpu_categorical_cross_entropy(Tensor* y_true, Tensor* y_pred){ check_cuda(cudaDeviceSynchronize(),"gpu_categorical_cross_entropy"); // Reduce sum and compute mean - thrust::device_ptr dev_ptr = thrust::device_pointer_cast(sum_array); - float sum_ce = thrust::reduce(dev_ptr, dev_ptr + n_batches); + // thrust::device_ptr dev_ptr = thrust::device_pointer_cast(sum_array); + float sum_ce; // = thrust::reduce(dev_ptr, dev_ptr + n_batches); + check_cuda(cudaMemcpy(&sum_ce, sum_array, sizeof(float), cudaMemcpyDeviceToHost), "memcpy"); float mean_ce = -sum_ce;//(float)n_batches; // Mean // Delete tmp array @@ -102,6 +104,7 @@ float gpu_binary_cross_entropy(Tensor* y_true, Tensor* y_pred){ float *sum_array; check_cuda(cudaMalloc((void**)&(sum_array), n_batches*sizeof(float)),"create temp array"); + check_cuda(cudaMemset(sum_array, 0, sizeof(float)), "memset"); check_cuda(cudaDeviceSynchronize(), "create"); // Calculate derivative of Softmax @@ -109,8 +112,9 @@ float gpu_binary_cross_entropy(Tensor* y_true, Tensor* y_pred){ check_cuda(cudaDeviceSynchronize(),"gpu_binary_cross_entropy"); // Reduce sum and compute mean - thrust::device_ptr dev_ptr = thrust::device_pointer_cast(sum_array); - float sum_ce = thrust::reduce(dev_ptr, dev_ptr + n_batches); + // thrust::device_ptr dev_ptr = thrust::device_pointer_cast(sum_array); + float sum_ce; // = thrust::reduce(dev_ptr, dev_ptr + n_batches); + check_cuda(cudaMemcpy(&sum_ce, sum_array, sizeof(float), cudaMemcpyDeviceToHost), "memcpy"); float mean_ce = -sum_ce;//(float)n_batches; // Mean // Delete tmp array diff --git a/src/hardware/gpu/nn/gpu_losses_kernels.cu b/src/hardware/gpu/nn/gpu_losses_kernels.cu index c5526b27e..b7fb4108f 100644 --- a/src/hardware/gpu/nn/gpu_losses_kernels.cu +++ b/src/hardware/gpu/nn/gpu_losses_kernels.cu @@ -50,7 +50,8 @@ __global__ void gpu_categorical_cross_entropy(float* y_true, float* y_pred, floa } // Store partial sums (later will be reduced) - sum_array[thread_id_x] = bi_sum; + // 
sum_array[thread_id_x] = bi_sum; + atomicAdd(sum_array, bi_sum); } } @@ -70,7 +71,8 @@ __global__ void gpu_binary_cross_entropy(float* y_true, float* y_pred, float* su float eps =10e-8; // Store sums (later will be reduced) - sum_array[thread_id_x] = y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps); + // sum_array[thread_id_x] = y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps); + atomicAdd(sum_array, y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps)); } } diff --git a/src/hardware/gpu/nn/gpu_pool.cu b/src/hardware/gpu/nn/gpu_pool.cu index cbd1fceba..9241297a6 100644 --- a/src/hardware/gpu/nn/gpu_pool.cu +++ b/src/hardware/gpu/nn/gpu_pool.cu @@ -92,9 +92,33 @@ void gpu_mpool2D_back(PoolDescriptor *D){ void gpu_mpool3D(PoolDescriptor3D *D){ +int device=D->I->gpu_device; + cudaSetDevice(device); + +#ifdef cCUDNN + float alpha=1.0; + float beta=0.0; +// amy_get_descriptor(D->xDesc,"xDesc"); +// amy_get_descriptor(D->yDesc,"yDesc"); + check_cudnn(cudnnPoolingForward(hdnn[device], D->poolingDesc, + &alpha, D->xDesc, D->I->ptr, + &beta, D->yDesc, D->O->ptr),"cudnnPoolingForward",__FILE__); +#endif + } void gpu_mpool3D_back(PoolDescriptor3D *D){ +int device=D->I->gpu_device; + cudaSetDevice(device); + +#ifdef cCUDNN + float alpha=1.0; + float beta=0.0; + + check_cudnn(cudnnPoolingBackward(hdnn[device], D->poolingDesc, &alpha, D->yDesc, D->O->ptr, + D->yDesc, D->D->ptr, D->xDesc, D->I->ptr, + &beta, D->xDesc, D->ID->ptr),"cudnnPoolingBackward",__FILE__); +#endif } diff --git a/src/layers/layer.cpp b/src/layers/layer.cpp index 68f4a0f9a..05c1d1529 100644 --- a/src/layers/layer.cpp +++ b/src/layers/layer.cpp @@ -37,6 +37,7 @@ Layer::Layer(string name, int dev, int mem) { trainable=true; iscloned=false; isdecoder=false; + distributed_training=false; this->do_deletes = true; 
diff --git a/src/layers/normalization/layer_batchnorm.cpp b/src/layers/normalization/layer_batchnorm.cpp index db94ed0a1..59bdde87f 100644 --- a/src/layers/normalization/layer_batchnorm.cpp +++ b/src/layers/normalization/layer_batchnorm.cpp @@ -155,67 +155,65 @@ void LBatchNorm::forward() { // Input = Output = opa = {Batch,Channels,H,W} OR {Batch,Dim} // bn_mean = bn_var = mean = variance = bn_g = bn_b = {Channels} or {Dim} -#ifndef cCUDNN -#ifndef BATCHNORM_ORIG + if ((input->isCPU())||(input->isFPGA())) { + // new implementation for CPU / GPU - if (input->isCPU() || input->isGPU()) { + if (input->isCPU()) { tensorNN::BatchNormForward(input, output, opa, mean, variance, affine ? bn_g : NULL, affine ? bn_b : NULL, bn_mean, bn_var, mode == TRMODE, epsilon, momentum); - } else -#endif - { - - int M,N; - int b,z,r,c,d; - Tensor *in; - - if (input->ndim==2) { - N=b=input->shape[0]; - M=d=input->shape[1]; - in=input->clone(); + } else { + int M,N; + int b,z,r,c,d; + Tensor *in; + + if (input->ndim==2) { + N=b=input->shape[0]; + M=d=input->shape[1]; + in=input->clone(); + } + else { + b=input->shape[0]; + M=z=input->shape[1]; + r=input->shape[2]; + c=input->shape[3]; + N=b*r*c; + + in=new Tensor({b*r*c*z},input->device); + tensorNN::permute_channels_last(input,in); + in->reshape_({N,M}); + opa->reshape_({N,M}); + } + + BN_forward(in,bn_mean,bn_var,mean,variance,momentum,epsilon,mode==TRMODE); + + + Tensor::copy(in,opa); + if (affine) { + Tensor *var=new Tensor({N,M},input->device); + Tensor *ones=new Tensor({N,1},input->device); + ones->fill_(1.0); + + // apply affine transform in=gamma*in+beta + rmult(in,bn_g,ones,var); + rsum(in,bn_b,ones,var); + delete var; + delete ones; + } + + // copy in to ouput + if (input->ndim==4) { + tensorNN::permute_channels_first(in,output); + } + else Tensor::copy(in,output); + + + delete in; } - else { - b=input->shape[0]; - M=z=input->shape[1]; - r=input->shape[2]; - c=input->shape[3]; - N=b*r*c; - - in=new 
Tensor({b*r*c*z},input->device); - tensorNN::permute_channels_last(input,in); - in->reshape_({N,M}); - opa->reshape_({N,M}); - } - - BN_forward(in,bn_mean,bn_var,mean,variance,momentum,epsilon,mode==TRMODE); - - Tensor::copy(in,opa); - if (affine) { - Tensor *var=new Tensor({N,M},input->device); - Tensor *ones=new Tensor({N,1},input->device); - ones->fill_(1.0); - - // apply affine transform in=gamma*in+beta - rmult(in,bn_g,ones,var); - rsum(in,bn_b,ones,var); - delete var; - delete ones; - } - - // copy in to ouput - if (input->ndim==4) { - tensorNN::permute_channels_first(in,output); } - else Tensor::copy(in,output); - - - delete in; - - } - -#else + else { // GPU +#ifdef cCUDNN float alpha = 1.0; float beta = 0.0; @@ -224,90 +222,94 @@ void LBatchNorm::forward() { bnScaleBiasMeanVarDesc, bn_g->ptr, bn_b->ptr, exponentialAverageFactor, mean->ptr, variance->ptr, epsilon, bn_mean->ptr, bn_var->ptr); - if(nnn != CUDNN_STATUS_SUCCESS) std::cout<<"Error fwd BN "<< cudnnGetErrorString(nnn) <name <isCPU() || input->isGPU()) { - tensorNN::BatchNormBackward(delta, opa, parent[0]->delta, - affine ? gbn_g : NULL, affine ? gbn_b : NULL, affine ? bn_g : NULL, - bn_var, work1, work2); - } else -#endif - { - - int M,N; - int b,z,r,c,d; - - Tensor *dp; - - if (input->ndim==2) { - N=b=input->shape[0]; - M=d=input->shape[1]; - - dp=delta->clone(); + if ((input->isCPU())||(input->isFPGA())) { + // new implementation for CPU + if (input->isCPU()) { + tensorNN::BatchNormBackward(delta, opa, parent[0]->delta, + affine ? gbn_g : NULL, affine ? gbn_b : NULL, affine ? 
bn_g : NULL, + bn_var, work1, work2); + } else + { + + int M,N; + int b,z,r,c,d; + + Tensor *dp; + + if (input->ndim==2) { + N=b=input->shape[0]; + M=d=input->shape[1]; + + dp=delta->clone(); + } + else { + b=input->shape[0]; + M=z=input->shape[1]; + r=input->shape[2]; + c=input->shape[3]; + + N=b*r*c; + + // permute input and delta + dp=new Tensor({b,r,c,z},input->device); + + tensorNN::permute_channels_last(delta,dp); + + dp->reshape_({N,M}); + + } + + // Affine + if (affine) { + Tensor *A=new Tensor({N,M},delta->device); + Tensor *ones=new Tensor({N},delta->device); + ones->fill_(1.0); + Tensor *m=new Tensor({1,M},delta->device); + //1 gamma + Tensor::el_mult(dp,opa,A,0); + cmean(A,m,ones); + Tensor::add(1,gbn_g,1,m,gbn_g,1); + + //2 Beta + cmean(dp,m,ones); + Tensor::add(1,gbn_b,1,m,gbn_b,1); + + // delta=dE/dY + // Obtain dE/dY from delta: + rmult(dp,bn_g,ones,A); + delete A; + delete ones; + delete m; + } + + BN_backward(dp,bn_var,opa); + + // Inc parent delta + if (input->ndim==4) { + tensorNN::permute_channels_first(dp,delta); + Tensor::inc(delta, parent[0]->delta); + } + else Tensor::inc(dp, parent[0]->delta); + + delete dp; + + } } - else { - b=input->shape[0]; - M=z=input->shape[1]; - r=input->shape[2]; - c=input->shape[3]; - - N=b*r*c; - - // permute input and delta - dp=new Tensor({b,r,c,z},input->device); - - tensorNN::permute_channels_last(delta,dp); - - dp->reshape_({N,M}); - - } - - // Affine - if (affine) { - Tensor *A=new Tensor({N,M},delta->device); - Tensor *ones=new Tensor({N},delta->device); - ones->fill_(1.0); - Tensor *m=new Tensor({1,M},delta->device); - //1 gamma - Tensor::el_mult(dp,opa,A,0); - cmean(A,m,ones); - Tensor::add(1,gbn_g,1,m,gbn_g,1); - - //2 Beta - cmean(dp,m,ones); - Tensor::add(1,gbn_b,1,m,gbn_b,1); - - // delta=dE/dY - // Obtain dE/dY from delta: - rmult(dp,bn_g,ones,A); - delete A; - delete ones; - delete m; - } - - BN_backward(dp,bn_var,opa); - - // Inc parent delta - if (input->ndim==4) { - 
tensorNN::permute_channels_first(dp,delta); - Tensor::inc(delta, parent[0]->delta); - } - else Tensor::inc(dp, parent[0]->delta); - - delete dp; - - } - -#else + else { //GPU + #ifdef cCUDNN float alphaDataDiff = 1.0; float betaDataDiff = 0.0; float alphaParamDiff = 1.0; @@ -319,8 +321,13 @@ void LBatchNorm::backward(){ bnScaleBiasMeanVarDesc,bn_g->ptr, gbn_g->ptr, gbn_b->ptr, epsilon, bn_mean->ptr, bn_var->ptr); if(nnn != CUDNN_STATUS_SUCCESS) std::cout<<"Error bwd BN "<< cudnnGetErrorString(nnn) <delta, + affine ? gbn_g : NULL, affine ? gbn_b : NULL, affine ? bn_g : NULL, + bn_var, work1, work2); #endif + } } diff --git a/src/layers/pool/layer_avgpool.cpp b/src/layers/pool/layer_avgpool.cpp index dbb4c29d9..b42aebb9c 100644 --- a/src/layers/pool/layer_avgpool.cpp +++ b/src/layers/pool/layer_avgpool.cpp @@ -28,7 +28,10 @@ LAveragePool::LAveragePool(Layer *parent, const vector &pool_size, const ve LAveragePool::LAveragePool(Layer *parent, PoolDescriptor *D, const string& name, int dev, int mem) : LPool(parent, D, name, dev, mem) { if(name.empty()) this->name = "avgpool" + to_string(++total_layers); - + + // Params + D->indX = new Tensor(D->O->shape, dev); // Is this needed here? 
+ D->indY = new Tensor(D->O->shape, dev); #ifdef cCUDNN D->mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; D->maxpoolingNanOpt = CUDNN_NOT_PROPAGATE_NAN; @@ -41,6 +44,12 @@ LAveragePool::LAveragePool(Layer *parent, PoolDescriptor *D, const string& name, void LAveragePool::resize(int batch){ LPool::resize(batch); + + delete pd->indX; + pd->indX = new Tensor(pd->O->shape, dev); + + delete pd->indY; + pd->indY = new Tensor(pd->O->shape, dev); } void LAveragePool::forward() { @@ -52,7 +61,7 @@ void LAveragePool::backward() { } Layer *LAveragePool::share(int c, int bs, vector p) { - auto *n = new LAveragePool(p[0], this->pd, "share_"+to_string(c)+this->name, this->dev, this->mem_level); + auto *n = new LAveragePool(p[0], new PoolDescriptor(pd->ksize, pd->stride, pd->pad, pd->mem_level), "share_"+to_string(c)+this->name, this->dev, this->mem_level); n->orig = this; return n; @@ -60,7 +69,7 @@ Layer *LAveragePool::share(int c, int bs, vector p) { Layer *LAveragePool::clone(int c, int bs, vector p, int todev) { - auto *n = new LMaxPool(p[0], new PoolDescriptor(pd->ksize, pd->stride, pd->pad, pd->mem_level), "share_"+to_string(c)+this->name, todev, this->mem_level); + auto *n = new LAveragePool(p[0], new PoolDescriptor(pd->ksize, pd->stride, pd->pad, pd->mem_level), "share_"+to_string(c)+this->name, todev, this->mem_level); n->orig = this; diff --git a/src/layers/pool/layer_avgpool1D.cpp b/src/layers/pool/layer_avgpool1D.cpp index 616b8cb40..b129d9d1f 100644 --- a/src/layers/pool/layer_avgpool1D.cpp +++ b/src/layers/pool/layer_avgpool1D.cpp @@ -51,7 +51,7 @@ void LAveragePool1D::backward() { } Layer *LAveragePool1D::share(int c, int bs, vector p) { - auto *n = new LAveragePool1D(p[0], this->pd, "share_"+to_string(c)+this->name, this->dev, this->mem_level); + auto *n = new LAveragePool1D(p[0], new PoolDescriptor(pd->ksize, pd->stride, pd->pad, pd->mem_level), "share_"+to_string(c)+this->name, this->dev, this->mem_level); n->orig = this; return n; diff --git 
a/src/layers/pool/layer_maxpool.cpp b/src/layers/pool/layer_maxpool.cpp index 998022b6d..72a263ddc 100644 --- a/src/layers/pool/layer_maxpool.cpp +++ b/src/layers/pool/layer_maxpool.cpp @@ -34,12 +34,13 @@ LMaxPool::LMaxPool(Layer *parent, PoolDescriptor *D, const string& name, int dev D->indY = new Tensor(D->O->shape, dev); #ifdef cCUDNN + if(!D->I->isCPU()){ D->mode = CUDNN_POOLING_MAX; D->maxpoolingNanOpt = CUDNN_NOT_PROPAGATE_NAN; cudnnStatus_t bbb = cudnnSetPooling2dDescriptor(D->poolingDesc, D->mode, D->maxpoolingNanOpt, D->windowHeight, D->windowWidth, D->verticalPadding, D->horizontalPadding, D->verticalStride, D->horizontalStride); if(bbb != CUDNN_STATUS_SUCCESS) std::cout<<"Error create pooling descriptor "<< cudnnGetErrorString(bbb) < p) { - auto *n = new LMaxPool(p[0], this->pd, "share_"+to_string(c)+this->name, this->dev, this->mem_level); + auto *n = new LMaxPool(p[0], new PoolDescriptor(pd->ksize, pd->stride, pd->pad, pd->mem_level), "share_"+to_string(c)+this->name, this->dev, this->mem_level); n->orig = this; return n; diff --git a/src/layers/pool/layer_maxpool1D.cpp b/src/layers/pool/layer_maxpool1D.cpp index 1634b01da..bb69d8606 100644 --- a/src/layers/pool/layer_maxpool1D.cpp +++ b/src/layers/pool/layer_maxpool1D.cpp @@ -32,6 +32,20 @@ LMaxPool1D::LMaxPool1D(Layer *parent, PoolDescriptor *D, const string& name, int // Params D->indX = new Tensor(D->O->shape, dev); // Is this needed here? 
D->indY = new Tensor(D->O->shape, dev); + +#ifdef cCUDNN + if(!D->I->isCPU()){ + D->mode = CUDNN_POOLING_MAX; + D->maxpoolingNanOpt = CUDNN_NOT_PROPAGATE_NAN; +// std::cout<<"wH: "<windowHeight<<" wW: " << D->windowWidth<<", vp: "<verticalPadding<<",hp: " << D->horizontalPadding<<", vs" <verticalStride<<", hS" <horizontalStride<poolingDesc, D->mode, D->maxpoolingNanOpt, D->windowHeight, D->windowWidth, + D->verticalPadding, D->horizontalPadding, D->verticalStride, D->horizontalStride); + if(bbb != CUDNN_STATUS_SUCCESS) std::cout<<"Error create pooling descriptor "<< cudnnGetErrorString(bbb) < p) { - auto *n = new LMaxPool1D(p[0], this->pd, "share_"+to_string(c)+this->name, this->dev, this->mem_level); + auto *n = new LMaxPool1D(p[0], new PoolDescriptor(pd->ksize, pd->stride, pd->pad, pd->mem_level), "share_"+to_string(c)+this->name, this->dev, this->mem_level); n->orig = this; return n; diff --git a/src/layers/pool/layer_maxpool3D.cpp b/src/layers/pool/layer_maxpool3D.cpp index 2e1c8be31..ca6b7c74d 100644 --- a/src/layers/pool/layer_maxpool3D.cpp +++ b/src/layers/pool/layer_maxpool3D.cpp @@ -31,6 +31,16 @@ LMaxPool3D::LMaxPool3D(Layer *parent, PoolDescriptor3D *D, const string& name, i // Params D->indX = new Tensor(D->O->shape, dev); // Is this needed here? 
D->indY = new Tensor(D->O->shape, dev); +#ifdef cCUDNN + if(!D->I->isCPU()){ + + D->mode = CUDNN_POOLING_MAX; + D->maxpoolingNanOpt = CUDNN_NOT_PROPAGATE_NAN; + cudnnStatus_t bbb = cudnnSetPoolingNdDescriptor(D->poolingDesc, D->mode, D->maxpoolingNanOpt, 3, D->cwindow, D->cpadding, + D->cstride); + if(bbb != CUDNN_STATUS_SUCCESS) std::cout<<"Error create pooling3D descriptor "<< cudnnGetErrorString(bbb) < p) { - auto *n = new LMaxPool3D(p[0], this->pd, "share_"+to_string(c)+this->name, this->dev, this->mem_level); + auto *n = new LMaxPool3D(p[0], new PoolDescriptor3D(pd->ksize, pd->stride, pd->pad, pd->mem_level), "share_"+to_string(c)+this->name, this->dev, this->mem_level); n->orig = this; return n; diff --git a/src/layers/pool/layer_pool1D.cpp b/src/layers/pool/layer_pool1D.cpp index c63a17069..b5658e969 100644 --- a/src/layers/pool/layer_pool1D.cpp +++ b/src/layers/pool/layer_pool1D.cpp @@ -47,20 +47,7 @@ LPool1D::LPool1D(Layer *parent, PoolDescriptor *D, string name, int dev, int mem } LPool1D::~LPool1D(){ - - // deleting pd->O here can provoque double delete/free problems - if (this->pd->O != nullptr) delete this->pd->O; - this->pd->O = nullptr; - - // deleting pd->D here can provoque double delete/free problems - if (this->pd->D != nullptr) delete this->pd->D; - this->pd->D = nullptr; - - delete this->pd; - this->pd = nullptr; - - if (this->input_reshaped != nullptr) delete this->input_reshaped; - this->input_reshaped = nullptr; + delete pd; } void LPool1D::mem_delta(){ diff --git a/src/net/compserv.cpp b/src/net/compserv.cpp index 79289e5f1..c78363a3d 100644 --- a/src/net/compserv.cpp +++ b/src/net/compserv.cpp @@ -33,6 +33,10 @@ CompServ::CompServ(int t, const vector g, const vector &f, int lsb, in for (auto _ : g) this->local_gpus.push_back(_); for (auto _ : f) this->local_fpgas.push_back(_); + if (local_fpgas.size()>0) hw="FPGA"; + else if (local_gpus.size()>0) hw="GPU"; + else hw="CPU"; + this->lsb = lsb; if (lsb < 0) { @@ -64,6 +68,18 @@ CompServ * 
CompServ::share() { return n; } +CompServ * CompServ::clone() { + CompServ *n = new CompServ(); + + n->type = this->type; + n->local_threads = this->local_threads; + for (auto _ : this->local_gpus) n->local_gpus.push_back(_); + for (auto _ : this->local_fpgas) n->local_fpgas.push_back(_); + n->lsb = this->lsb; + n->mem_level = this->mem_level; + + return n; +} // for Distributed diff --git a/src/net/net_api.cpp b/src/net/net_api.cpp index 562d8b99b..f4c76fd32 100644 --- a/src/net/net_api.cpp +++ b/src/net/net_api.cpp @@ -271,7 +271,6 @@ void Net::forward(vector in) msg("size missmatch in list of tensors","Net.forward(vtensor)"); if (batch_size!=in[0]->shape[0]) { - cout<shape[0]<shape[0]); } diff --git a/src/net/net_build.cpp b/src/net/net_build.cpp index d6744a1d2..7fb8e3172 100644 --- a/src/net/net_build.cpp +++ b/src/net/net_build.cpp @@ -188,7 +188,7 @@ void Net::make_graph(Optimizer *opt, vloss lo, vmetrics me, bool initialize) { decsize=lout.size()/lo.size(); for(int i=0;iclone()); } else losses = vloss(lo); @@ -200,7 +200,7 @@ void Net::make_graph(Optimizer *opt, vloss lo, vmetrics me, bool initialize) { if (isdecoder) { for(int i=0;imetrics.push_back(me[j]); + this->metrics.push_back(me[j]->clone()); } else { for(int j=0;jmetrics.push_back(me[j]); @@ -502,10 +502,11 @@ Layer * Net::getLayer(string lname) return nullptr; } - - void Net::enable_distributed(){ - for(Layer* l : layers){ - l->enable_distributed(); - } + for(Layer* l : layers) + l->enable_distributed(); + + for (int i = 0; i < snets.size(); i++) + for(Layer* l : snets[i]->layers) + l->enable_distributed(); } diff --git a/src/net/net_func.cpp b/src/net/net_func.cpp index 1ff9748ed..952e0a08c 100644 --- a/src/net/net_func.cpp +++ b/src/net/net_func.cpp @@ -68,7 +68,7 @@ void Net::do_delta() { for (int i = 0; i < lout.size(); i++) { lout[i]->mem_delta(); if (losses.size()>=(i+1)) { - losses[i]->delta(lout[i]->target, lout[i]->output, lout[i]->delta); + losses[i]->delta(lout[i]->target, 
lout[i]->output, lout[i]->delta); } } } @@ -91,6 +91,16 @@ void Net::do_applygrads() { optimizer->applygrads(batch_size); } +void Net::collect_acc_grads() { + for (int j = 0; j < layers.size(); j++) + for (int k = 0; k < layers[j]->acc_gradients.size(); k++) { + // Taking average + layers[j]->acc_gradients[k]->fill_(0.0); + for (int i = 0; i < snets.size(); i++) + Tensor::inc(snets[i]->layers[j]->acc_gradients[k], layers[j]->acc_gradients[k]); + layers[j]->acc_gradients[k]->div_(snets.size()); + } +} void Net::sync_weights() { for (int j = 0; j < layers.size(); j++) diff --git a/src/net/netloss.cpp b/src/net/netloss.cpp index f3aa5c1d2..c29420adf 100644 --- a/src/net/netloss.cpp +++ b/src/net/netloss.cpp @@ -32,7 +32,10 @@ NetLoss::NetLoss(const std::function)>& f, vector Net *sn=in[0]->net; - graph->build(sn->optimizer->clone(),{new LMin()},{new MSum()},sn->cs); + CompServ *cs=sn->cs->clone(); + cs->mem_level=0; //delta must stay to backward netinput layers + + graph->build(sn->optimizer->clone(),{new LMin()},{new MSum()},cs); cout<<"Loss graph:"<summary(); @@ -52,8 +55,12 @@ NetLoss::NetLoss(const std::function& f, Layer *in, string name) graph=new Net(ginput,{fout}); Net *sn=in->net; + + CompServ * cs=sn->cs->clone(); + cs->mem_level=0; //delta must stay to backward netinput layers + - graph->build(sn->optimizer->clone(),{new LMin()},{new MSum()},sn->cs); + graph->build(sn->optimizer->clone(),{new LMin()},{new MSum()},cs); cout<<"Loss graph:"<summary(); diff --git a/src/serialization/onnx/eddl_onnx_export.cpp b/src/serialization/onnx/eddl_onnx_export.cpp index 11598800b..7e86e92b1 100644 --- a/src/serialization/onnx/eddl_onnx_export.cpp +++ b/src/serialization/onnx/eddl_onnx_export.cpp @@ -171,6 +171,12 @@ void build_unsqueeze_node(string node_name, string input, string output, vector< // OPSET: 7, 1 void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph); +// OPSET: 7, 3, 1 +void build_gru_node(LGRU *layer, onnx::GraphProto *graph); + +// OPSET: 7, 1 +void 
build_rnn_node(LRNN *layer, onnx::GraphProto *graph); + // OPSET: 13 void build_resize_node(LScale *layer, onnx::GraphProto *graph); @@ -244,6 +250,7 @@ void save_net_to_onnx_file(Net *net, string path) { // The serialization is automated by the protobuf library cerr << "Failed to write the model in onnx." << endl; } + ofs.close(); } size_t serialize_net_to_onnx_pointer(Net *net, void *&serialized_model, bool gradients) @@ -327,7 +334,7 @@ void set_graph(onnx::ModelProto *model, Net *net, bool gradients) /* * We get all the input layers from the layers vector of the model * instead of taking them from net->lin. Beacause for the case of - * a recurrent net with decoder the input layer that is connected + * a recurrent net with decoder the input layer that is connected * to the decoder is not added in the lin vector of the model. * With this way we ensure that we are taking all the input layers * of the model. @@ -604,6 +611,14 @@ void build_node_from_layer(Layer *layer, onnx::GraphProto *graph, bool gradients { build_lstm_node((LLSTM *)(MLayer *)layer, graph); } + else if (LGRU *t = dynamic_cast(layer)) + { + build_gru_node((LGRU *)(MLayer *)layer, graph); + } + else if (LRNN *t = dynamic_cast(layer)) + { + build_rnn_node((LRNN *)(MLayer *)layer, graph); + } else if (LCopyStates *t = dynamic_cast(layer)) { handle_copy_states((LCopyStates *)(MLayer *)layer, graph); @@ -1093,10 +1108,9 @@ void build_reshape_node(LReshape *layer, onnx::GraphProto *graph) target_shape_tensor->set_data_type(onnx::TensorProto::INT64); target_shape_tensor->add_dims(layer->ls.size()); // Set the target shape - for (int i : layer->ls) - { - target_shape_tensor->add_int64_data(i); - } + target_shape_tensor->add_int64_data(-1); // For batch_size + for (int i = 1; i < layer->ls.size(); ++i) + target_shape_tensor->add_int64_data(layer->ls[i]); // Add an empty node to the graph onnx::NodeProto *node = graph->add_node(); @@ -1486,14 +1500,14 @@ void build_mul_node(LMult *layer, onnx::GraphProto 
*graph) } /* - void build_pow_node(LPow *layer, onnx::GraphProto *graph) + void build_pow_node(LPow *layer, onnx::GraphProto *graph) { // Add an empty node to the graph onnx::NodeProto* node = graph->add_node(); node->set_op_type("Pow"); node->set_name(layer->name); // Set the inputs names of the node from the parents of the layer - for (Layer* parentl : layer->parent) + for (Layer* parentl : layer->parent) { node->add_input(parentl->name); } @@ -1907,6 +1921,13 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) input_forget_attr->set_type(onnx::AttributeProto::INT); input_forget_attr->set_i(0); // To not couple the input and forget gates + // Warn if the layer uses mask zeros. Not supported in ONNX + if (layer->mask_zeros) { + cout << "[ONNX::Export] Warning: The LSTM layer " << layer->name << " has mask_zeros=true. " + << "This attribute is not supported in ONNX, so the model exported will not have this attribute." + << endl; + } + // W input (weights for all the layers W[iofc]) onnx::TensorProto *w = graph->add_initializer(); w->set_name(layer->name + "_W"); @@ -1918,12 +1939,16 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) */ Tensor *Wix = layer->Wix->permute({1, 0}); w->mutable_float_data()->Add(Wix->ptr, Wix->ptr + Wix->size); // i weights + delete Wix; Tensor *Wox = layer->Wox->permute({1, 0}); w->mutable_float_data()->Add(Wox->ptr, Wox->ptr + Wox->size); // o weights + delete Wox; Tensor *Wfx = layer->Wfx->permute({1, 0}); w->mutable_float_data()->Add(Wfx->ptr, Wfx->ptr + Wfx->size); // f weights + delete Wfx; Tensor *Wcx = layer->Wcx->permute({1, 0}); w->mutable_float_data()->Add(Wcx->ptr, Wcx->ptr + Wcx->size); // c weights + delete Wcx; // R input (recurrent weights for all the layers W[iofc]) onnx::TensorProto *r = graph->add_initializer(); @@ -1936,12 +1961,16 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) */ Tensor *Wih = layer->Wih->permute({1, 0}); r->mutable_float_data()->Add(Wih->ptr, Wih->ptr + 
Wih->size); // i recurrent weights + delete Wih; Tensor *Woh = layer->Woh->permute({1, 0}); r->mutable_float_data()->Add(Woh->ptr, Woh->ptr + Woh->size); // o recurrent weights + delete Woh; Tensor *Wfh = layer->Wfh->permute({1, 0}); r->mutable_float_data()->Add(Wfh->ptr, Wfh->ptr + Wfh->size); // f recurrent weights + delete Wfh; Tensor *Wch = layer->Wch->permute({1, 0}); r->mutable_float_data()->Add(Wch->ptr, Wch->ptr + Wch->size); // c recurrent weights + delete Wch; // B input (biases for all the layers) onnx::TensorProto *b = graph->add_initializer(); @@ -2002,6 +2031,328 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) } } +void build_gru_node(LGRU *layer, onnx::GraphProto *graph) +{ + // Add an empty node to the graph + onnx::NodeProto *node = graph->add_node(); + node->set_op_type("GRU"); + node->set_name(layer->name); + // Set the input sequence of the GRU + node->add_input(layer->parent[0]->name); + node->add_input(layer->name + "_W"); + node->add_input(layer->name + "_R"); + node->add_input(layer->name + "_B"); + node->add_input(""); // Empty str to skip the sequence_lens input + // Check if we have to copy states for a decoder GRU + if (layer->parent.size() > 1 && layer->isdecoder) + { + string l_copyStates_name = layer->parent[1]->name; + node->add_input(l_copyStates_name + "_h"); + } + + // Attr activation alpha (for GRU activation functions) + // Not used in EDDL + //onnx::AttributeProto* activation_alpha_attr = node->add_attribute(); + //activation_alpha_attr->set_name( "activation_alpha" ); + //activation_alpha_attr->set_type( onnx::AttributeProto::FLOATS ); + + // Attr activation beta + // Not used in EDDL + //onnx::AttributeProto* activation_beta_attr = node->add_attribute(); + //activation_beta_attr->set_name( "activation_beta" ); // Not used in EDDL + //activation_beta_attr->set_type( onnx::AttributeProto::FLOATS ); + + // Attr activations + onnx::AttributeProto *activations_attr = node->add_attribute(); + 
activations_attr->set_name("activations"); + activations_attr->set_type(onnx::AttributeProto::STRINGS); + activations_attr->add_strings("Sigmoid"); // For gates z, r + activations_attr->add_strings("Tanh"); // For gate n + + // Attr clip (cell clip threshold, [-threshold, +threshold]) + // Not used in EDDL + //onnx::AttributeProto* hidden_size_attr = node->add_attribute(); + //hidden_size_attr->set_name( "clip" ); + //hidden_size_attr->set_type( onnx::AttributeProto::FLOAT ); + //hidden_size_attr->set_i( /*?*/ ); + + // Attr direction + onnx::AttributeProto *direction_attr = node->add_attribute(); + direction_attr->set_name("direction"); + direction_attr->set_type(onnx::AttributeProto::STRING); + direction_attr->set_s("forward"); // Current implementation of GRU + + // Attr hidden size + onnx::AttributeProto *hidden_size_attr = node->add_attribute(); + hidden_size_attr->set_name("hidden_size"); + hidden_size_attr->set_type(onnx::AttributeProto::INT); + hidden_size_attr->set_i(layer->units); + + // Attr linear transformation before reset + onnx::AttributeProto *linear_trans_attr = node->add_attribute(); + linear_trans_attr->set_name("linear_before_reset"); + linear_trans_attr->set_type(onnx::AttributeProto::INT); + // We apply the linear transformation before the r gate. + // See "linear_before_reset" attribute in https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU + linear_trans_attr->set_i(1); + + // Warn if the layer uses mask zeros. Not supported in ONNX + if (layer->mask_zeros) { + cout << "[ONNX::Export] Warning: The GRU layer " << layer->name << " has mask_zeros=true. " + << "This attribute is not supported in ONNX, so the model exported will not have this attribute." 
+ << endl; + } + + // W input (weights for all the layers W[zrn]) + onnx::TensorProto *w = graph->add_initializer(); + w->set_name(layer->name + "_W"); + w->set_data_type(onnx::TensorProto::FLOAT); + vector w_dims{1, 3 * layer->units, layer->input->shape[1]}; // w_dims shape[0] = 1 beacuse is only forward + w->mutable_dims()->Add(w_dims.begin(), w_dims.end()); // Set the shape of the weights + /* + * The Weights are permuted before saving them (required by ONNX standad) + */ + Tensor *Wz_x = layer->Wz_x->permute({1, 0}); + w->mutable_float_data()->Add(Wz_x->ptr, Wz_x->ptr + Wz_x->size); // z weights + delete Wz_x; + Tensor *Wr_x = layer->Wr_x->permute({1, 0}); + w->mutable_float_data()->Add(Wr_x->ptr, Wr_x->ptr + Wr_x->size); // r weights + delete Wr_x; + Tensor *Wn_x = layer->Wn_x->permute({1, 0}); + w->mutable_float_data()->Add(Wn_x->ptr, Wn_x->ptr + Wn_x->size); // n weights + delete Wn_x; + + // R input (recurrent weights for all the layers W[zrh]) + onnx::TensorProto *r = graph->add_initializer(); + r->set_name(layer->name + "_R"); + r->set_data_type(onnx::TensorProto::FLOAT); + vector r_dims{1, 3 * layer->units, layer->units}; // r_dims shape[0] = 1 beacuse is only forward + r->mutable_dims()->Add(r_dims.begin(), r_dims.end()); // Set the shape of the weights + /* + * The Weights are permuted before saving them (required by ONNX standad) + */ + Tensor *Wz_hidden = layer->Uz_h->permute({1, 0}); + r->mutable_float_data()->Add(Wz_hidden->ptr, Wz_hidden->ptr + Wz_hidden->size); // z recurrent weights + delete Wz_hidden; + Tensor *Wr_hidden = layer->Ur_h->permute({1, 0}); + r->mutable_float_data()->Add(Wr_hidden->ptr, Wr_hidden->ptr + Wr_hidden->size); // r recurrent weights + delete Wr_hidden; + Tensor *Wn_hidden = layer->Un_h->permute({1, 0}); + r->mutable_float_data()->Add(Wn_hidden->ptr, Wn_hidden->ptr + Wn_hidden->size); // n recurrent weights + delete Wn_hidden; + + // B input (biases for all the layers) + onnx::TensorProto *b = graph->add_initializer(); + 
b->set_name(layer->name + "_B"); + b->set_data_type(onnx::TensorProto::FLOAT); + vector b_dims{1, 6 * layer->units}; // b_dims shape[0] = 1 for weights in one directions + b->mutable_dims()->Add(b_dims.begin(), b_dims.end()); // Set the shape of the weights + + b->mutable_float_data()->Add(layer->bias_z_t->ptr, layer->bias_z_t->ptr + layer->bias_z_t->size); // z bias + b->mutable_float_data()->Add(layer->bias_r_t->ptr, layer->bias_r_t->ptr + layer->bias_r_t->size); // r bias + b->mutable_float_data()->Add(layer->bias_n_t->ptr, layer->bias_n_t->ptr + layer->bias_n_t->size); // n bias + + // Set recurrent forward biases to 0 for gates z and r + for (int i = 0; i < 2 * layer->units; ++i) + b->add_float_data(0.0); + + // The recurrent bias for n is set. Because we need it for applying the linear transformation before the + // r gate. See "linear_before_reset" attribute in https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU + b->mutable_float_data()->Add(layer->bias_n_t_hidden->ptr, layer->bias_n_t_hidden->ptr + layer->bias_n_t_hidden->size); // n recurrent bias + + /* Set the outputs of the node to link with the other nodes + * - In ONNX the GRU operator can have up to 2 outputs: + * * Y -> [seq_len, num_directions, batch_size, hidden_size] + * * Y_h (optional) -> [num_directions, batch_size, hidden_size] + * - If the layer is encoder we select Y_h as output + * - If the layer is encoder but there are more stacked GRU, we select Y as output + * - If the layer is decoder we select Y as output + * + * Note: To select the output of the GRU that the next layer in the graph takes as input + * we have to set that output name to the layer name (layer->name) + */ + node->add_output(layer->name + "_Y"); + node->add_output(layer->name + "_Y_h"); + if (layer->isdecoder || layer->child[0]->isrecurrent /*To detect stacked GRU*/) + { + // Squeeze: [seq_length, num_directions, batch_size, hidden_size] -> [seq_length, batch_size, hidden_size] + // Note: The EDDL 
only supports one-directional GRU, so num_directions=1 + build_squeeze_node( + layer->name + "_outputSqueeze", // node name + layer->name + "_Y", // input name + layer->name, // Output name + {1}, // axes to squeeze + graph); + } + else + { // is encoder + // Squeeze: [num_directions, batch_size, hidden_size] -> [batch_size, hidden_size] + // Note: The EDDL only supports one-directional GRU, so num_directions=1 + build_squeeze_node( + layer->name + "_outputSqueeze", // node name + layer->name + "_Y_h", // input name + layer->name, // Output name + {0}, // axes to squeeze + graph); + } +} + +void build_rnn_node(LRNN *layer, onnx::GraphProto *graph) +{ + // Add an empty node to the graph + onnx::NodeProto *node = graph->add_node(); + node->set_op_type("RNN"); + node->set_name(layer->name); + // Set the input sequence of the RNN + node->add_input(layer->parent[0]->name); + node->add_input(layer->name + "_W"); + node->add_input(layer->name + "_R"); + if (layer->use_bias) node->add_input(layer->name + "_B"); + else node->add_input(""); + node->add_input(""); // Empty str to skip the sequence_lens input + // Check if we have to copy states for a decoder RNN + if (layer->parent.size() > 1 && layer->isdecoder) + { + string l_copyStates_name = layer->parent[1]->name; + node->add_input(l_copyStates_name + "_h"); + } + + // Attr activations + float alpha, beta; // optional auxiliary parameters + bool activation_with_params = false; + onnx::AttributeProto *activations_attr = node->add_attribute(); + activations_attr->set_name("activations"); + activations_attr->set_type(onnx::AttributeProto::STRINGS); + if (layer->activation == "relu") + activations_attr->add_strings("Relu"); + else if (layer->activation == "sigmoid") + activations_attr->add_strings("Sigmoid"); + else if (layer->activation == "hard_sigmoid") { + activations_attr->add_strings("HardSigmoid"); + alpha = 0.2; + beta = 0.5; + activation_with_params = true; + } else if (layer->activation == "tanh") + 
activations_attr->add_strings("Tanh"); + else if (layer->activation == "none") { + activations_attr->add_strings("Affine"); + // Achieve linear activation: alpha * x + beta -> where alpha = 1.0 and beta = 0.0 + alpha = 1.0; + beta = 0.0; + activation_with_params = true; + } else + msg("Activation not supported for RNN", "ONNX::ExportNet"); + + if (activation_with_params) { + // Auxiliary alpha attribute for the activation functions + onnx::AttributeProto* activation_alpha_attr = node->add_attribute(); + activation_alpha_attr->set_name("activation_alpha"); + activation_alpha_attr->set_type(onnx::AttributeProto::FLOATS); + activation_alpha_attr->add_floats(alpha); + + // Auxiliary beta attribute for the activation functions + onnx::AttributeProto* activation_beta_attr = node->add_attribute(); + activation_beta_attr->set_name("activation_beta"); + activation_beta_attr->set_type(onnx::AttributeProto::FLOATS); + activation_beta_attr->add_floats(beta); + } + + // Attr clip (cell clip threshold, [-threshold, +threshold]) + // Not used in EDDL + //onnx::AttributeProto* hidden_size_attr = node->add_attribute(); + //hidden_size_attr->set_name( "clip" ); + //hidden_size_attr->set_type( onnx::AttributeProto::FLOAT ); + //hidden_size_attr->set_i( /*?*/ ); + + // Attr direction + onnx::AttributeProto *direction_attr = node->add_attribute(); + direction_attr->set_name("direction"); + direction_attr->set_type(onnx::AttributeProto::STRING); + direction_attr->set_s("forward"); // Current implementation of RNN + + // Attr hidden size + onnx::AttributeProto *hidden_size_attr = node->add_attribute(); + hidden_size_attr->set_name("hidden_size"); + hidden_size_attr->set_type(onnx::AttributeProto::INT); + hidden_size_attr->set_i(layer->units); + + // Weights for input + onnx::TensorProto *w = graph->add_initializer(); + w->set_name(layer->name + "_W"); + w->set_data_type(onnx::TensorProto::FLOAT); + vector w_dims{1, layer->units, layer->input->shape[1]}; // w_dims shape[0] = 1 beacuse is 
only forward + w->mutable_dims()->Add(w_dims.begin(), w_dims.end()); // Set the shape of the weights + /* + * The Weights are permuted before saving them (required by ONNX standad) + */ + Tensor *Wx = layer->Wx->permute({1, 0}); + w->mutable_float_data()->Add(Wx->ptr, Wx->ptr + Wx->size); + delete Wx; + + // Recurrent weights + onnx::TensorProto *r = graph->add_initializer(); + r->set_name(layer->name + "_R"); + r->set_data_type(onnx::TensorProto::FLOAT); + vector r_dims{1, layer->units, layer->units}; // r_dims shape[0] = 1 beacuse is only forward + r->mutable_dims()->Add(r_dims.begin(), r_dims.end()); // Set the shape of the weights + /* + * The Weights are permuted before saving them (required by ONNX standad) + */ + Tensor *Wy = layer->Wy->permute({1, 0}); + r->mutable_float_data()->Add(Wy->ptr, Wy->ptr + Wy->size); + delete Wy; + + // Bias + if (layer->use_bias) { + onnx::TensorProto *b = graph->add_initializer(); + b->set_name(layer->name + "_B"); + b->set_data_type(onnx::TensorProto::FLOAT); + vector b_dims{1, 2 * layer->units}; // b_dims shape[0] = 1 for weights in one directions + b->mutable_dims()->Add(b_dims.begin(), b_dims.end()); // Set the shape of the weights + b->mutable_float_data()->Add(layer->bias->ptr, layer->bias->ptr + layer->bias->size); + // Set recurrent biases to 0 + for (int i = 0; i < layer->units; ++i) + b->add_float_data(0.0); + } + + /* Set the outputs of the node to link with the other nodes + * - In ONNX the LSTM operator can have up to 2 outputs: + * * Y -> [seq_len, num_directions, batch_size, hidden_size] + * * Y_h (optional) -> [num_directions, batch_size, hidden_size] + * - If the layer is encoder we select Y_h as output + * - If the layer is encoder but there are more stacked RNN, we select Y as output + * - If the layer is decoder we select Y as output + * + * Note: To select the output of the RNN that the next layer in the graph takes as input + * we have to set that output name to the layer name (layer->name) + */ + 
node->add_output(layer->name + "_Y"); + node->add_output(layer->name + "_Y_h"); + if (layer->isdecoder || layer->child[0]->isrecurrent /*To detect stacked RNN*/) + { + // Squeeze: [seq_length, num_directions, batch_size, hidden_size] -> [seq_length, batch_size, hidden_size] + // Note: The EDDL only supports one-directional RNN, so num_directions=1 + build_squeeze_node( + layer->name + "_outputSqueeze", // node name + layer->name + "_Y", // input name + layer->name, // Output name + {1}, // axes to squeeze + graph); + } + else + { // is encoder + // Squeeze: [num_directions, batch_size, hidden_size] -> [batch_size, hidden_size] + // Note: The EDDL only supports one-directional RNN, so num_directions=1 + build_squeeze_node( + layer->name + "_outputSqueeze", // node name + layer->name + "_Y_h", // input name + layer->name, // Output name + {0}, // axes to squeeze + graph); + } +} + + void build_resize_node(LScale *layer, onnx::GraphProto *graph) { // Add an empty node to the graph @@ -2165,27 +2516,27 @@ void handle_copy_states(LCopyStates *layer, onnx::GraphProto *graph) { string parent_name = layer->parent[0]->name; string child_name = layer->child[0]->name; - // Check the type of the parent layer to copy the states + + // Set the node to copy the hidden (h) state + string node_name = parent_name + "_to_" + child_name + "_CopyState_h"; + string input_name = parent_name; + string output_name = layer->name + "_h"; + /* + * Add an Unsqueeze layer to reshape the h state to the desired shape for LSTM. 
+ * + * Note: The h state coming from the previous LSTM has been squeezed, so we + * have to unsqueeze it to get the desired shape for the decoder LSTM + */ + build_unsqueeze_node( + layer->name + "_h_unsqueeze", // node name + input_name, // input name + output_name, // Output name + {0}, // axes to squeeze + graph); + + // Set the node to copy the cell (c) state in case of LSTM if (LLSTM *l = dynamic_cast(layer->parent[0])) { - // Set the node to copy the hidden (h) state - string node_name = parent_name + "_to_" + child_name + "_CopyState_h"; - string input_name = parent_name; - string output_name = layer->name + "_h"; - /* - * Add an Unsqueeze layer to reshape the h state to the desired shape for LSTM. - * - * Note: The h state coming from the previous LSTM has been squeezed, so we - * have to unsqueeze it to get the desired shape for the decoder LSTM - */ - build_unsqueeze_node( - layer->name + "_h_unsqueeze", // node name - input_name, // input name - output_name, // Output name - {0}, // axes to squeeze - graph); - - // Set the node to copy the cell (c) state node_name = parent_name + "_to_" + child_name + "_CopyState_c"; input_name = parent_name + "_Y_c"; output_name = layer->name + "_c"; @@ -2225,9 +2576,9 @@ void prepare_recurrent_input(string input_name, string output_name, vector void prepare_recurrent_output(string input_name, string output_name, vector output_shape, onnx::GraphProto *graph) { /* - * This functions takes a graph of a recurrent net and adds a transpose operator - * to fix the output shape from (seq_len, batch_size, out_shape) to (batch_size, seq_len, out_shape) - */ + * This functions takes a graph of a recurrent net and adds a transpose operator + * to fix the output shape from (seq_len, batch_size, out_shape) to (batch_size, seq_len, out_shape) + */ // Add an empty node to the graph onnx::NodeProto *node = graph->add_node(); node->set_op_type("Transpose"); @@ -2253,7 +2604,9 @@ void prepare_recurrent_output(string input_name, string 
output_name, vector // End: Exporting Module //---------------------------------------------------------------------------------------- + #else + void save_net_to_onnx_file(Net *net, string path) { cerr << "Not compiled for ONNX. Missing Protobuf" << endl; diff --git a/src/serialization/onnx/eddl_onnx_import.cpp b/src/serialization/onnx/eddl_onnx_import.cpp index ef963b94e..63192cc10 100644 --- a/src/serialization/onnx/eddl_onnx_import.cpp +++ b/src/serialization/onnx/eddl_onnx_import.cpp @@ -75,6 +75,8 @@ enum ONNX_LAYERS ADD, // OPSET: 13, 7 MAT_MUL, // OPSET: 13, 9, 1 (Only for MatMul+Add Dense layer) LSTM, // OPSET: 7, 1 + GRU, // OPSET: 7, 3, 1 + RNN, // OPSET: 7, 1 IDENTITY, // We skip this layer when found GATHER, // OPSET: 13, 11, 1 CAST, // We skip this layer when found @@ -208,6 +210,8 @@ map create_enum_map() map_layers["MatMul"] = ONNX_LAYERS::MAT_MUL; map_layers["LSTM"] = ONNX_LAYERS::LSTM; + map_layers["GRU"] = ONNX_LAYERS::GRU; + map_layers["RNN"] = ONNX_LAYERS::RNN; map_layers["Identity"] = ONNX_LAYERS::IDENTITY; map_layers["Gather"] = ONNX_LAYERS::GATHER; map_layers["Cast"] = ONNX_LAYERS::CAST; @@ -502,6 +506,7 @@ Net *import_net_from_onnx_file(std::string path, int mem, int log_level) cerr << "Failed to parse model." 
<< endl; //return; } + input.close(); } return build_net_onnx(model, mem, log_level); } @@ -560,7 +565,9 @@ Layer *get_model_input_layer(Layer *l) bool node_is_recurrent(onnx::NodeProto *node, map &map_layers) { ONNX_LAYERS layer_type = map_layers[node->op_type()]; - if (layer_type == ONNX_LAYERS::LSTM) + if (layer_type == ONNX_LAYERS::LSTM || + layer_type == ONNX_LAYERS::GRU || + layer_type == ONNX_LAYERS::RNN) return true; return false; @@ -879,12 +886,13 @@ Net *build_net_onnx(onnx::ModelProto model, int mem, int log_level) int filters; vector kernel_shape; vector strides; - vector pads; - string auto_pad_option = ""; - bool auto_pad = false; + vector pads = {}; + string auto_pad_option = "custom"; vector *bias; bool use_bias = node->input_size() > 2; bool conv1d = false; + int groups = 1; + vector dilation_rate = {1, 1}; for (int j = 0; j < node->attribute_size(); j++) { // Set the attributes @@ -892,14 +900,10 @@ Net *build_net_onnx(onnx::ModelProto model, int mem, int log_level) string attr_name = attribute.name(); if (!attr_name.compare("auto_pad")) { - auto_pad = true; if (!attribute.s().compare("NOTSET")) - { - auto_pad = false; - continue; - } + auto_pad_option = "custom"; else if (!attribute.s().compare("VALID")) - auto_pad_option = "none"; + auto_pad_option = "valid"; else if (!attribute.s().compare("SAME_UPPER")) auto_pad_option = "same"; } @@ -964,14 +968,15 @@ Net *build_net_onnx(onnx::ModelProto model, int mem, int log_level) filters = dims[0]; string name = node->name(); - ConvolDescriptor *cd; - - // TODO: REVIEW!!!! 
- int groups = 1; - vector dilation_rate = {1, 1}; - //auto_pad_option == "custom"; - //cd->pad = pads; - cd = new ConvolDescriptor(filters, kernel_shape, strides, auto_pad_option, pads, groups, dilation_rate, use_bias, mem); + ConvolDescriptor *cd = new ConvolDescriptor(filters, + kernel_shape, + strides, + auto_pad_option, + pads, + groups, + dilation_rate, + use_bias, + mem); if (conv1d) actual_layer = new LConv1D(parent, cd, name, dev, mem); @@ -2087,8 +2092,13 @@ Net *build_net_onnx(onnx::ModelProto model, int mem, int log_level) clip = attribute.f(); } else if (!attr_name.compare("direction")) - { // Not used yet in eddl but implemented. We default to forward + { direction = attribute.s(); + if (direction.compare("forward")) + { + msg("LSTM layer " + name + " is not forward direction. EDDL only supports one-directional LSTM", + "ONNX::ImportNet"); + } } else if (!attr_name.compare("hidden_size")) { @@ -2227,30 +2237,47 @@ Net *build_net_onnx(onnx::ModelProto model, int mem, int log_level) delete recurrence_weights_cell_tensor; delete recurrence_weights_cell_g; - string biases_name = node->input(3); //Get weights and dims - vector *biases = &(map_init_values[biases_name]); - vector biases_dims = map_init_dims[biases_name]; + /* + * Set bias values + */ vector bias_dims = {hidden_size}; - + // Vectors to store the imported weights vector *bias_input = new vector; vector *bias_output = new vector; vector *bias_forget = new vector; vector *bias_cell = new vector; - vector *bias_recurrence_input = new vector; vector *bias_recurrence_output = new vector; vector *bias_recurrence_forget = new vector; vector *bias_recurrence_cell = new vector; - bias_input->assign(biases->begin() + hidden_size * 0, biases->begin() + hidden_size * 1); - bias_output->assign(biases->begin() + hidden_size * 1, biases->begin() + hidden_size * 2); - bias_forget->assign(biases->begin() + hidden_size * 2, biases->begin() + hidden_size * 3); - bias_cell->assign(biases->begin() + hidden_size * 
3, biases->begin() + hidden_size * 4); - - bias_recurrence_input->assign(biases->begin() + hidden_size * 4, biases->begin() + hidden_size * 5); - bias_recurrence_output->assign(biases->begin() + hidden_size * 5, biases->begin() + hidden_size * 6); - bias_recurrence_forget->assign(biases->begin() + hidden_size * 6, biases->begin() + hidden_size * 7); - bias_recurrence_cell->assign(biases->begin() + hidden_size * 7, biases->begin() + hidden_size * 8); + if (node->input_size() > 3) { + string biases_name = node->input(3); //Get weights and dims + vector *biases = &(map_init_values[biases_name]); + + bias_input->assign(biases->begin() + hidden_size * 0, biases->begin() + hidden_size * 1); + bias_output->assign(biases->begin() + hidden_size * 1, biases->begin() + hidden_size * 2); + bias_forget->assign(biases->begin() + hidden_size * 2, biases->begin() + hidden_size * 3); + bias_cell->assign(biases->begin() + hidden_size * 3, biases->begin() + hidden_size * 4); + bias_recurrence_input->assign(biases->begin() + hidden_size * 4, biases->begin() + hidden_size * 5); + bias_recurrence_output->assign(biases->begin() + hidden_size * 5, biases->begin() + hidden_size * 6); + bias_recurrence_forget->assign(biases->begin() + hidden_size * 6, biases->begin() + hidden_size * 7); + bias_recurrence_cell->assign(biases->begin() + hidden_size * 7, biases->begin() + hidden_size * 8); + } else { + // Set bias values to 0.0 + // Note: In EDDL we don't have use_bias option for LSTM so to achieve the same + // result we set the bias values to 0.0 + vector zero_bias(hidden_size, 0.0); + bias_input->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_output->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_forget->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_cell->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_input->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + 
bias_recurrence_output->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_forget->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_cell->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + } + Tensor *bias_input_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_input), dev); Tensor::copy(bias_input_tensor, lstm->inbias); @@ -2298,6 +2325,449 @@ Net *build_net_onnx(onnx::ModelProto model, int mem, int log_level) } break; + case ONNX_LAYERS::GRU: + { + log_string("GRU layer detected", log_level, LOG_LEVEL::DEBUG); + vector activation_alpha; // Values for configuring some activations with extra parameters + vector activation_beta; // Values for configuring some activations with extra parameters + vector activations; // Activation functions in order for each gate + float clip = -1; // Value for clipping + string direction = ""; // Forward, backward or reverse (Forward by default) + int hidden_size = -1; // Number of neurons in the hidden layer + + for (int j = 0; j < node->attribute_size(); j++) + { // Set the attributes + onnx::AttributeProto attribute = node->attribute(j); + string attr_name = attribute.name(); + if (!attr_name.compare("activation_alpha")) + { // Not used yet in eddl but implemented + for (int h = 0; h < attribute.floats_size(); h++) + { + activation_alpha.push_back(attribute.floats(h)); + } + } + else if (!attr_name.compare("activation_beta")) + { // Not used yet in eddl but implemented + for (int h = 0; h < attribute.floats_size(); h++) + { + activation_beta.push_back(attribute.floats(h)); + } + } + else if (!attr_name.compare("activations")) + { // Not used yet in eddl but implemented. 
We default to Sigmoid, TanH + for (int h = 0; h < attribute.strings_size(); h++) + { + activations.push_back(attribute.strings(h)); + } + } + else if (!attr_name.compare("clip")) + { // Not used yet in eddl but implemented + clip = attribute.f(); + } + else if (!attr_name.compare("direction")) + { + direction = attribute.s(); + if (direction.compare("forward")) + { + msg("GRU layer " + name + " is not forward direction. EDDL only supports one-directional GRU", "ONNX::ImportNet"); + } + } + else if (!attr_name.compare("hidden_size")) + { + hidden_size = attribute.i(); + } + //else if (!attr_name.compare("linear_before_reset")) {} + } + + if (hidden_size < 0) + msg("GRU layer " + name + " doesn't have the number of neurons.", "ONNX::ImportNet"); + + string parent_name = node->input(0); // Get parent + Layer *parent = output_node_map[parent_name]; + vector parent_shape = parent->output->shape; + vector parents = {parent}; + + /* + * Check if the layer is Decoder by checking if there is not a recurrent layer after this one. To avoid + * conflicts with the stacked GRU layers that are encoders. + */ + bool is_decoder = node_is_decoder(node, input_node_map); + + if (is_decoder) + { + log_string("The layer " + name + " is decoder", log_level, LOG_LEVEL::DEBUG); + // We have to create the copy states layer for the decoder + Layer *parent_hstate = output_node_map[node->input(5)]; // 5: hidden state + Layer *cps = new LCopyStates({parent_hstate}, "", dev, mem); + parents.push_back(cps); // Add the layer to the parents for the GRU + } + + string weights_gates = node->input(1); // Get weights and dims + vector *weights_g = &(map_init_values[weights_gates]); + vector dims_g = map_init_dims[weights_gates]; + int input_size = dims_g[2]; + + // Load input weights with shape [hidden_size, input_size]. 
After load we transpose + // Note: EDDL input weights are of shape [input_size, hidden_size] + vector dims_input_gru = {dims_g[1] / 3, input_size}; + + vector *weights_z_g = new vector; + vector *weights_r_g = new vector; + vector *weights_n_g = new vector; + int w_size = input_size * hidden_size; + weights_z_g->assign(weights_g->begin() + w_size * 0, weights_g->begin() + w_size * 1); + weights_r_g->assign(weights_g->begin() + w_size * 1, weights_g->begin() + w_size * 2); + weights_n_g->assign(weights_g->begin() + w_size * 2, weights_g->begin() + w_size * 3); + + string recurrence_weights_gates = node->input(2); // Get weights and dims + vector *recurrence_weights_g = &(map_init_values[recurrence_weights_gates]); + vector recurrence_dims_g = map_init_dims[recurrence_weights_gates]; + + vector dims_recurrent_gru = {recurrence_dims_g[2], recurrence_dims_g[2]}; + + vector *recurrence_weights_z_g = new vector; + vector *recurrence_weights_r_g = new vector; + vector *recurrence_weights_n_g = new vector; + w_size = hidden_size * hidden_size; + recurrence_weights_z_g->assign(recurrence_weights_g->begin() + w_size * 0, recurrence_weights_g->begin() + w_size * 1); + recurrence_weights_r_g->assign(recurrence_weights_g->begin() + w_size * 1, recurrence_weights_g->begin() + w_size * 2); + recurrence_weights_n_g->assign(recurrence_weights_g->begin() + w_size * 2, recurrence_weights_g->begin() + w_size * 3); + + LGRU *gru = new LGRU(parents, hidden_size, 0, 0, name, dev, mem); + + if (is_decoder) + { + // Set attribute for unrolling + gru->isdecoder = true; + set_decoder(gru->parent[0]); + // We also have to remove the input layer that feeds the decoder from the input layers of the model + // First we search the corresponding input layer for the decoder + Layer *dec_linput = get_model_input_layer(gru); + if (dec_linput != nullptr) + inputs2remove.push_back(dec_linput->name); + else + msg("Input layer for decoder " + name + " not found", "ONNX::ImportNet"); + } + + /* + * The 
Weights are permuted before copying them to the GRU layer (mismatch between ONNX standad and EDDL implementation) + */ + Tensor *weights_z_tensor = new Tensor(dims_input_gru, NEW_FROM_VECTOR_PTR(weights_z_g), dev); + weights_z_tensor->permute_({1, 0}); + Tensor::copy(weights_z_tensor, gru->Wz_x); + delete weights_z_tensor; + delete weights_z_g; + + Tensor *weights_r_tensor = new Tensor(dims_input_gru, NEW_FROM_VECTOR_PTR(weights_r_g), dev); + weights_r_tensor->permute_({1, 0}); + Tensor::copy(weights_r_tensor, gru->Wr_x); + delete weights_r_tensor; + delete weights_r_g; + + Tensor *weights_n_tensor = new Tensor(dims_input_gru, NEW_FROM_VECTOR_PTR(weights_n_g), dev); + weights_n_tensor->permute_({1, 0}); + Tensor::copy(weights_n_tensor, gru->Wn_x); + delete weights_n_tensor; + delete weights_n_g; + + Tensor *recurrence_weights_z_tensor = new Tensor(dims_recurrent_gru, NEW_FROM_VECTOR_PTR(recurrence_weights_z_g), dev); + recurrence_weights_z_tensor->permute_({1, 0}); + Tensor::copy(recurrence_weights_z_tensor, gru->Uz_h); + delete recurrence_weights_z_tensor; + delete recurrence_weights_z_g; + + Tensor *recurrence_weights_r_tensor = new Tensor(dims_recurrent_gru, NEW_FROM_VECTOR_PTR(recurrence_weights_r_g), dev); + recurrence_weights_r_tensor->permute_({1, 0}); + Tensor::copy(recurrence_weights_r_tensor, gru->Ur_h); + delete recurrence_weights_r_tensor; + delete recurrence_weights_r_g; + + Tensor *recurrence_weights_n_tensor = new Tensor(dims_recurrent_gru, NEW_FROM_VECTOR_PTR(recurrence_weights_n_g), dev); + recurrence_weights_n_tensor->permute_({1, 0}); + Tensor::copy(recurrence_weights_n_tensor, gru->Un_h); + delete recurrence_weights_n_tensor; + delete recurrence_weights_n_g; + + /* + * Set bias values + */ + vector bias_dims = {hidden_size}; + // Vectors to store the imported weights + vector *bias_z = new vector; + vector *bias_r = new vector; + vector *bias_n = new vector; + vector *bias_recurrence_z = new vector; + vector *bias_recurrence_r = new vector; + 
vector *bias_recurrence_n = new vector; + + if (node->input_size() > 3) { // Check that we have bias + string biases_name = node->input(3); + vector *biases = &(map_init_values[biases_name]); + // Forward bias (zrh) + bias_z->assign(biases->begin() + hidden_size * 0, biases->begin() + hidden_size * 1); + bias_r->assign(biases->begin() + hidden_size * 1, biases->begin() + hidden_size * 2); + bias_n->assign(biases->begin() + hidden_size * 2, biases->begin() + hidden_size * 3); + // Recurrent bias (zrh) + bias_recurrence_z->assign(biases->begin() + hidden_size * 3, biases->begin() + hidden_size * 4); + bias_recurrence_r->assign(biases->begin() + hidden_size * 4, biases->begin() + hidden_size * 5); + bias_recurrence_n->assign(biases->begin() + hidden_size * 5, biases->begin() + hidden_size * 6); + } else { + // Set bias values to 0.0 + // Note: In EDDL we don't have use_bias option for GRU so to achieve the same + // result we set the bias values to 0.0 + vector zero_bias(hidden_size, 0.0); + bias_z->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_r->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_n->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_z->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_r->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_n->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + } + + Tensor *bias_z_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_z), dev); + Tensor::copy(bias_z_tensor, gru->bias_z_t); + delete bias_z_tensor; + delete bias_z; + + Tensor *bias_r_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_r), dev); + Tensor::copy(bias_r_tensor, gru->bias_r_t); + delete bias_r_tensor; + delete bias_r; + + Tensor *bias_n_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_n), dev); + Tensor::copy(bias_n_tensor, gru->bias_n_t); + delete bias_n_tensor; + delete bias_n; + + // Add the recurrent 
bias values for gates z and r + Tensor *bias_recurrence_z_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_recurrence_z), dev); + Tensor::add(bias_recurrence_z_tensor, gru->bias_z_t, gru->bias_z_t); + delete bias_recurrence_z_tensor; + delete bias_recurrence_z; + + Tensor *bias_recurrence_r_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_recurrence_r), dev); + Tensor::add(bias_recurrence_r_tensor, gru->bias_r_t, gru->bias_r_t); + delete bias_recurrence_r_tensor; + delete bias_recurrence_r; + + // The recurrent bias for h goes to its own tensor beacuse we need it for applying the linear transformation + // before the r gate. See "linear_before_reset" attribute in https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU + Tensor *bias_recurrence_n_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_recurrence_n), dev); + Tensor::copy(bias_recurrence_n_tensor, gru->bias_n_t_hidden); + delete bias_recurrence_n_tensor; + delete bias_recurrence_n; + + actual_layer = gru; + log_string("GRU layer created", log_level, LOG_LEVEL::DEBUG); + } + break; + + case ONNX_LAYERS::RNN: + { + log_string("RNN layer detected", log_level, LOG_LEVEL::DEBUG); + vector activation_alpha; // Values for configuring some activations with extra parameters + vector activation_beta; // Values for configuring some activations with extra parameters + vector activations; // Activation functions in order for each gate + float clip = -1; // Value for clipping + string direction = ""; // Forward, backward or reverse (Forward by default) + int hidden_size = -1; // Number of neurons in the hidden layer + bool use_bias = node->input_size() > 3; + + for (int j = 0; j < node->attribute_size(); j++) + { // Set the attributes + onnx::AttributeProto attribute = node->attribute(j); + string attr_name = attribute.name(); + if (!attr_name.compare("activation_alpha")) + { // Not used yet in eddl but implemented + for (int h = 0; h < attribute.floats_size(); h++) + { + 
activation_alpha.push_back(attribute.floats(h)); + } + } + else if (!attr_name.compare("activation_beta")) + { // Not used yet in eddl but implemented + for (int h = 0; h < attribute.floats_size(); h++) + { + activation_beta.push_back(attribute.floats(h)); + } + } + else if (!attr_name.compare("activations")) + { + for (int h = 0; h < attribute.strings_size(); h++) + { + activations.push_back(attribute.strings(h)); + } + } + else if (!attr_name.compare("clip")) + { // Not used yet in eddl but implemented + clip = attribute.f(); + } + else if (!attr_name.compare("direction")) + { + direction = attribute.s(); + if (direction.compare("forward")) + { + msg("RNN layer " + name + " is not forward direction. EDDL only supports one-directional RNN", "ONNX::ImportNet"); + } + } + else if (!attr_name.compare("hidden_size")) + { + hidden_size = attribute.i(); + } + //else if (!attr_name.compare("linear_before_reset")) {} + } + + // Take forward activation function + string activation; + if (activations.size() > 0) { + string forward_activation = activations[0]; + if (forward_activation == "Relu") + activation = "relu"; + else if (forward_activation == "Sigmoid") + activation = "sigmoid"; + else if (forward_activation == "HardSigmoid") { + float epsilon = 1e-5; + float alpha = 0.2; + float beta = 0.5; + if (activation_alpha.size() > 0) alpha = activation_alpha[0]; + if (activation_beta.size() > 0) beta = activation_beta[0]; + bool is_not_valid = abs(alpha - 0.2) > epsilon; + is_not_valid |= abs(beta - 0.5) > epsilon; + // Check that is equivalent to our hard sigmoid implementation + if (is_not_valid) { + msg("The HardSigmoid activation function with alpha != 0.2 or beta != 0.5 is not supported for RNN.", + "ONNX::ImportNet"); + } else { + activation = "hard_sigmoid"; + } + } else if (forward_activation == "Tanh") + activation = "tanh"; + else if (forward_activation == "Affine") { + float alpha = 1.0; + float beta = 0.0; + if (activation_alpha.size() > 0) alpha = 
activation_alpha[0]; + if (activation_beta.size() > 0) beta = activation_beta[0]; + // Check that is equivalent to linear activation function + if (alpha != 1.0 || beta != 0.0) { + msg("The Affine activation function with alpha != 1.0 or beta != 0.0 is not supported for RNN.", + "ONNX::ImportNet"); + } else { + activation = "none"; + } + } else + msg("Activation function \"" + forward_activation + "\" is not supported for RNN.", + "ONNX::ImportNet"); + } else { + msg("RNN layer " + name + " doesn't provide an activation function.", + "ONNX::ImportNet"); + } + + if (hidden_size < 0) + msg("RNN layer " + name + " doesn't have the number of neurons.", "ONNX::ImportNet"); + + string parent_name = node->input(0); // Get parent + Layer *parent = output_node_map[parent_name]; + vector parent_shape = parent->output->shape; + vector parents = {parent}; + + /* + * Check if the layer is Decoder by checking if there is not a recurrent layer after this one. To avoid + * conflicts with the stacked RNN layers that are encoders. + */ + bool is_decoder = node_is_decoder(node, input_node_map); + + if (is_decoder) + { + log_string("The layer " + name + " is decoder", log_level, LOG_LEVEL::DEBUG); + // We have to create the copy states layer for the decoder + Layer *parent_hstate = output_node_map[node->input(5)]; // 5: hidden state + Layer *cps = new LCopyStates({parent_hstate}, "", dev, mem); + parents.push_back(cps); // Add the layer to the parents for the RNN + } + + string weights_gates = node->input(1); // Get weights and dims + vector *weights_g = &(map_init_values[weights_gates]); + vector dims_g = map_init_dims[weights_gates]; + int input_size = dims_g[2]; + + // Load input weights with shape [hidden_size, input_size]. 
After load we transpose + // Note: EDDL input weights are of shape [input_size, hidden_size] + vector dims_input_gru = {dims_g[1], input_size}; + + vector *weights_x = new vector; + int w_size = input_size * hidden_size; + weights_x->assign(weights_g->begin() , weights_g->begin() + w_size); + + string recurrence_weights_gates = node->input(2); // Get weights and dims + vector *recurrence_weights_g = &(map_init_values[recurrence_weights_gates]); + vector recurrence_dims_g = map_init_dims[recurrence_weights_gates]; + + vector dims_recurrent_gru = {recurrence_dims_g[2], recurrence_dims_g[2]}; + + vector *weights_h = new vector; + w_size = hidden_size * hidden_size; + weights_h->assign(recurrence_weights_g->begin(), recurrence_weights_g->begin() + w_size); + + LRNN *rnn = new LRNN(parents, hidden_size, activation, use_bias, false, name, dev, mem); + + if (is_decoder) + { + // Set attribute for unrolling + rnn->isdecoder = true; + set_decoder(rnn->parent[0]); + // We also have to remove the input layer that feeds the decoder from the input layers of the model + // First we search the corresponding input layer for the decoder + Layer *dec_linput = get_model_input_layer(rnn); + if (dec_linput != nullptr) + inputs2remove.push_back(dec_linput->name); + else + msg("Input layer for decoder " + name + " not found", "ONNX::ImportNet"); + } + + /* + * The Weights are permuted before copying them to the RNN layer (mismatch between ONNX standad and EDDL implementation) + */ + Tensor *weights_x_tensor = new Tensor(dims_input_gru, NEW_FROM_VECTOR_PTR(weights_x), dev); + weights_x_tensor->permute_({1, 0}); + Tensor::copy(weights_x_tensor, rnn->Wx); + delete weights_x_tensor; + delete weights_x; + + Tensor *weights_h_tensor = new Tensor(dims_recurrent_gru, NEW_FROM_VECTOR_PTR(weights_h), dev); + weights_h_tensor->permute_({1, 0}); + Tensor::copy(weights_h_tensor, rnn->Wy); + delete weights_h_tensor; + delete weights_h; + + if (use_bias) { + string biases_name = node->input(3); + 
vector *biases = &(map_init_values[biases_name]); + vector bias_dims = {hidden_size}; + + vector *bias_x = new vector; + vector *bias_h = new vector; + + bias_x->assign(biases->begin() + hidden_size * 0, biases->begin() + hidden_size * 1); + bias_h->assign(biases->begin() + hidden_size * 1, biases->begin() + hidden_size * 2); + + Tensor *bias_x_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_x), dev); + Tensor::copy(bias_x_tensor, rnn->bias); + delete bias_x_tensor; + delete bias_x; + + // Add the recurrent bias values for gates z and r + Tensor *bias_h_tensor = new Tensor(bias_dims, NEW_FROM_VECTOR_PTR(bias_h), dev); + Tensor::add(bias_h_tensor, rnn->bias, rnn->bias); + delete bias_h_tensor; + delete bias_h; + } + + actual_layer = rnn; + log_string("RNN layer created", log_level, LOG_LEVEL::DEBUG); + } + break; + + case ONNX_LAYERS::IDENTITY: { log_string("Identity layer detected", log_level, LOG_LEVEL::DEBUG); diff --git a/src/serialization/onnx/onnx.pb.h b/src/serialization/onnx/onnx.pb.h new file mode 100644 index 000000000..5313a753d --- /dev/null +++ b/src/serialization/onnx/onnx.pb.h @@ -0,0 +1,8784 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: onnx.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_onnx_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_onnx_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3011000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3011004 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. 
+#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_onnx_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_onnx_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxillaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[17] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_onnx_2eproto; +namespace onnx { +class AttributeProto; +class AttributeProtoDefaultTypeInternal; +extern AttributeProtoDefaultTypeInternal _AttributeProto_default_instance_; +class GraphProto; +class GraphProtoDefaultTypeInternal; +extern GraphProtoDefaultTypeInternal _GraphProto_default_instance_; +class ModelProto; +class ModelProtoDefaultTypeInternal; +extern ModelProtoDefaultTypeInternal _ModelProto_default_instance_; +class NodeProto; +class NodeProtoDefaultTypeInternal; +extern NodeProtoDefaultTypeInternal _NodeProto_default_instance_; +class OperatorSetIdProto; +class OperatorSetIdProtoDefaultTypeInternal; +extern OperatorSetIdProtoDefaultTypeInternal _OperatorSetIdProto_default_instance_; +class SparseTensorProto; +class SparseTensorProtoDefaultTypeInternal; +extern 
SparseTensorProtoDefaultTypeInternal _SparseTensorProto_default_instance_; +class StringStringEntryProto; +class StringStringEntryProtoDefaultTypeInternal; +extern StringStringEntryProtoDefaultTypeInternal _StringStringEntryProto_default_instance_; +class TensorAnnotation; +class TensorAnnotationDefaultTypeInternal; +extern TensorAnnotationDefaultTypeInternal _TensorAnnotation_default_instance_; +class TensorProto; +class TensorProtoDefaultTypeInternal; +extern TensorProtoDefaultTypeInternal _TensorProto_default_instance_; +class TensorProto_Segment; +class TensorProto_SegmentDefaultTypeInternal; +extern TensorProto_SegmentDefaultTypeInternal _TensorProto_Segment_default_instance_; +class TensorShapeProto; +class TensorShapeProtoDefaultTypeInternal; +extern TensorShapeProtoDefaultTypeInternal _TensorShapeProto_default_instance_; +class TensorShapeProto_Dimension; +class TensorShapeProto_DimensionDefaultTypeInternal; +extern TensorShapeProto_DimensionDefaultTypeInternal _TensorShapeProto_Dimension_default_instance_; +class TypeProto; +class TypeProtoDefaultTypeInternal; +extern TypeProtoDefaultTypeInternal _TypeProto_default_instance_; +class TypeProto_Map; +class TypeProto_MapDefaultTypeInternal; +extern TypeProto_MapDefaultTypeInternal _TypeProto_Map_default_instance_; +class TypeProto_Sequence; +class TypeProto_SequenceDefaultTypeInternal; +extern TypeProto_SequenceDefaultTypeInternal _TypeProto_Sequence_default_instance_; +class TypeProto_Tensor; +class TypeProto_TensorDefaultTypeInternal; +extern TypeProto_TensorDefaultTypeInternal _TypeProto_Tensor_default_instance_; +class ValueInfoProto; +class ValueInfoProtoDefaultTypeInternal; +extern ValueInfoProtoDefaultTypeInternal _ValueInfoProto_default_instance_; +} // namespace onnx +PROTOBUF_NAMESPACE_OPEN +template<> ::onnx::AttributeProto* Arena::CreateMaybeMessage<::onnx::AttributeProto>(Arena*); +template<> ::onnx::GraphProto* Arena::CreateMaybeMessage<::onnx::GraphProto>(Arena*); +template<> 
::onnx::ModelProto* Arena::CreateMaybeMessage<::onnx::ModelProto>(Arena*); +template<> ::onnx::NodeProto* Arena::CreateMaybeMessage<::onnx::NodeProto>(Arena*); +template<> ::onnx::OperatorSetIdProto* Arena::CreateMaybeMessage<::onnx::OperatorSetIdProto>(Arena*); +template<> ::onnx::SparseTensorProto* Arena::CreateMaybeMessage<::onnx::SparseTensorProto>(Arena*); +template<> ::onnx::StringStringEntryProto* Arena::CreateMaybeMessage<::onnx::StringStringEntryProto>(Arena*); +template<> ::onnx::TensorAnnotation* Arena::CreateMaybeMessage<::onnx::TensorAnnotation>(Arena*); +template<> ::onnx::TensorProto* Arena::CreateMaybeMessage<::onnx::TensorProto>(Arena*); +template<> ::onnx::TensorProto_Segment* Arena::CreateMaybeMessage<::onnx::TensorProto_Segment>(Arena*); +template<> ::onnx::TensorShapeProto* Arena::CreateMaybeMessage<::onnx::TensorShapeProto>(Arena*); +template<> ::onnx::TensorShapeProto_Dimension* Arena::CreateMaybeMessage<::onnx::TensorShapeProto_Dimension>(Arena*); +template<> ::onnx::TypeProto* Arena::CreateMaybeMessage<::onnx::TypeProto>(Arena*); +template<> ::onnx::TypeProto_Map* Arena::CreateMaybeMessage<::onnx::TypeProto_Map>(Arena*); +template<> ::onnx::TypeProto_Sequence* Arena::CreateMaybeMessage<::onnx::TypeProto_Sequence>(Arena*); +template<> ::onnx::TypeProto_Tensor* Arena::CreateMaybeMessage<::onnx::TypeProto_Tensor>(Arena*); +template<> ::onnx::ValueInfoProto* Arena::CreateMaybeMessage<::onnx::ValueInfoProto>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace onnx { + +enum AttributeProto_AttributeType : int { + AttributeProto_AttributeType_UNDEFINED = 0, + AttributeProto_AttributeType_FLOAT = 1, + AttributeProto_AttributeType_INT = 2, + AttributeProto_AttributeType_STRING = 3, + AttributeProto_AttributeType_TENSOR = 4, + AttributeProto_AttributeType_GRAPH = 5, + AttributeProto_AttributeType_SPARSE_TENSOR = 11, + AttributeProto_AttributeType_FLOATS = 6, + AttributeProto_AttributeType_INTS = 7, + AttributeProto_AttributeType_STRINGS = 8, + 
AttributeProto_AttributeType_TENSORS = 9, + AttributeProto_AttributeType_GRAPHS = 10, + AttributeProto_AttributeType_SPARSE_TENSORS = 12 +}; +bool AttributeProto_AttributeType_IsValid(int value); +constexpr AttributeProto_AttributeType AttributeProto_AttributeType_AttributeType_MIN = AttributeProto_AttributeType_UNDEFINED; +constexpr AttributeProto_AttributeType AttributeProto_AttributeType_AttributeType_MAX = AttributeProto_AttributeType_SPARSE_TENSORS; +constexpr int AttributeProto_AttributeType_AttributeType_ARRAYSIZE = AttributeProto_AttributeType_AttributeType_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* AttributeProto_AttributeType_descriptor(); +template +inline const std::string& AttributeProto_AttributeType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function AttributeProto_AttributeType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + AttributeProto_AttributeType_descriptor(), enum_t_value); +} +inline bool AttributeProto_AttributeType_Parse( + const std::string& name, AttributeProto_AttributeType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + AttributeProto_AttributeType_descriptor(), name, value); +} +enum TensorProto_DataType : int { + TensorProto_DataType_UNDEFINED = 0, + TensorProto_DataType_FLOAT = 1, + TensorProto_DataType_UINT8 = 2, + TensorProto_DataType_INT8 = 3, + TensorProto_DataType_UINT16 = 4, + TensorProto_DataType_INT16 = 5, + TensorProto_DataType_INT32 = 6, + TensorProto_DataType_INT64 = 7, + TensorProto_DataType_STRING = 8, + TensorProto_DataType_BOOL = 9, + TensorProto_DataType_FLOAT16 = 10, + TensorProto_DataType_DOUBLE = 11, + TensorProto_DataType_UINT32 = 12, + TensorProto_DataType_UINT64 = 13, + TensorProto_DataType_COMPLEX64 = 14, + TensorProto_DataType_COMPLEX128 = 15, + TensorProto_DataType_BFLOAT16 = 16 +}; +bool TensorProto_DataType_IsValid(int value); +constexpr TensorProto_DataType 
TensorProto_DataType_DataType_MIN = TensorProto_DataType_UNDEFINED; +constexpr TensorProto_DataType TensorProto_DataType_DataType_MAX = TensorProto_DataType_BFLOAT16; +constexpr int TensorProto_DataType_DataType_ARRAYSIZE = TensorProto_DataType_DataType_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* TensorProto_DataType_descriptor(); +template +inline const std::string& TensorProto_DataType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function TensorProto_DataType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + TensorProto_DataType_descriptor(), enum_t_value); +} +inline bool TensorProto_DataType_Parse( + const std::string& name, TensorProto_DataType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + TensorProto_DataType_descriptor(), name, value); +} +enum TensorProto_DataLocation : int { + TensorProto_DataLocation_DEFAULT = 0, + TensorProto_DataLocation_EXTERNAL = 1 +}; +bool TensorProto_DataLocation_IsValid(int value); +constexpr TensorProto_DataLocation TensorProto_DataLocation_DataLocation_MIN = TensorProto_DataLocation_DEFAULT; +constexpr TensorProto_DataLocation TensorProto_DataLocation_DataLocation_MAX = TensorProto_DataLocation_EXTERNAL; +constexpr int TensorProto_DataLocation_DataLocation_ARRAYSIZE = TensorProto_DataLocation_DataLocation_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* TensorProto_DataLocation_descriptor(); +template +inline const std::string& TensorProto_DataLocation_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function TensorProto_DataLocation_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + TensorProto_DataLocation_descriptor(), enum_t_value); +} +inline bool TensorProto_DataLocation_Parse( + const std::string& name, TensorProto_DataLocation* value) { + return 
::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + TensorProto_DataLocation_descriptor(), name, value); +} +enum Version : int { + _START_VERSION = 0, + IR_VERSION_2017_10_10 = 1, + IR_VERSION_2017_10_30 = 2, + IR_VERSION_2017_11_3 = 3, + IR_VERSION_2019_1_22 = 4, + IR_VERSION_2019_3_18 = 5, + IR_VERSION = 6 +}; +bool Version_IsValid(int value); +constexpr Version Version_MIN = _START_VERSION; +constexpr Version Version_MAX = IR_VERSION; +constexpr int Version_ARRAYSIZE = Version_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Version_descriptor(); +template +inline const std::string& Version_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Version_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + Version_descriptor(), enum_t_value); +} +inline bool Version_Parse( + const std::string& name, Version* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + Version_descriptor(), name, value); +} +// =================================================================== + +class AttributeProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.AttributeProto) */ { + public: + AttributeProto(); + virtual ~AttributeProto(); + + AttributeProto(const AttributeProto& from); + AttributeProto(AttributeProto&& from) noexcept + : AttributeProto() { + *this = ::std::move(from); + } + + inline AttributeProto& operator=(const AttributeProto& from) { + CopyFrom(from); + return *this; + } + inline AttributeProto& operator=(AttributeProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + 
return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const AttributeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const AttributeProto* internal_default_instance() { + return reinterpret_cast( + &_AttributeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(AttributeProto& a, AttributeProto& b) { + a.Swap(&b); + } + inline void Swap(AttributeProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline AttributeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + AttributeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const AttributeProto& from); + void MergeFrom(const AttributeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(AttributeProto* 
other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.AttributeProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef AttributeProto_AttributeType AttributeType; + static constexpr AttributeType UNDEFINED = + AttributeProto_AttributeType_UNDEFINED; + static constexpr AttributeType FLOAT = + AttributeProto_AttributeType_FLOAT; + static constexpr AttributeType INT = + AttributeProto_AttributeType_INT; + static constexpr AttributeType STRING = + AttributeProto_AttributeType_STRING; + static constexpr AttributeType TENSOR = + AttributeProto_AttributeType_TENSOR; + static constexpr AttributeType GRAPH = + AttributeProto_AttributeType_GRAPH; + static constexpr AttributeType SPARSE_TENSOR = + AttributeProto_AttributeType_SPARSE_TENSOR; + static constexpr AttributeType FLOATS = + AttributeProto_AttributeType_FLOATS; + static constexpr AttributeType INTS = + AttributeProto_AttributeType_INTS; + static constexpr AttributeType STRINGS = + AttributeProto_AttributeType_STRINGS; + static constexpr AttributeType TENSORS = + AttributeProto_AttributeType_TENSORS; + static constexpr AttributeType GRAPHS = + AttributeProto_AttributeType_GRAPHS; + static constexpr AttributeType SPARSE_TENSORS = + AttributeProto_AttributeType_SPARSE_TENSORS; + static inline bool AttributeType_IsValid(int value) { + return AttributeProto_AttributeType_IsValid(value); + } 
+ static constexpr AttributeType AttributeType_MIN = + AttributeProto_AttributeType_AttributeType_MIN; + static constexpr AttributeType AttributeType_MAX = + AttributeProto_AttributeType_AttributeType_MAX; + static constexpr int AttributeType_ARRAYSIZE = + AttributeProto_AttributeType_AttributeType_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + AttributeType_descriptor() { + return AttributeProto_AttributeType_descriptor(); + } + template + static inline const std::string& AttributeType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function AttributeType_Name."); + return AttributeProto_AttributeType_Name(enum_t_value); + } + static inline bool AttributeType_Parse(const std::string& name, + AttributeType* value) { + return AttributeProto_AttributeType_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kFloatsFieldNumber = 7, + kIntsFieldNumber = 8, + kStringsFieldNumber = 9, + kTensorsFieldNumber = 10, + kGraphsFieldNumber = 11, + kSparseTensorsFieldNumber = 23, + kNameFieldNumber = 1, + kSFieldNumber = 4, + kDocStringFieldNumber = 13, + kRefAttrNameFieldNumber = 21, + kTFieldNumber = 5, + kGFieldNumber = 6, + kSparseTensorFieldNumber = 22, + kIFieldNumber = 3, + kFFieldNumber = 2, + kTypeFieldNumber = 20, + }; + // repeated float floats = 7; + int floats_size() const; + private: + int _internal_floats_size() const; + public: + void clear_floats(); + private: + float _internal_floats(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_floats() const; + void _internal_add_floats(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_floats(); + public: + float floats(int index) const; + void set_floats(int index, float value); + void add_floats(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + floats() const; + 
::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_floats(); + + // repeated int64 ints = 8; + int ints_size() const; + private: + int _internal_ints_size() const; + public: + void clear_ints(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_ints(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_ints() const; + void _internal_add_ints(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_ints(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 ints(int index) const; + void set_ints(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_ints(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + ints() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_ints(); + + // repeated bytes strings = 9; + int strings_size() const; + private: + int _internal_strings_size() const; + public: + void clear_strings(); + const std::string& strings(int index) const; + std::string* mutable_strings(int index); + void set_strings(int index, const std::string& value); + void set_strings(int index, std::string&& value); + void set_strings(int index, const char* value); + void set_strings(int index, const void* value, size_t size); + std::string* add_strings(); + void add_strings(const std::string& value); + void add_strings(std::string&& value); + void add_strings(const char* value); + void add_strings(const void* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& strings() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_strings(); + private: + const std::string& _internal_strings(int index) const; + std::string* _internal_add_strings(); + public: + + // repeated .onnx.TensorProto tensors = 10; + int tensors_size() const; + private: + int _internal_tensors_size() const; + public: + void 
clear_tensors(); + ::onnx::TensorProto* mutable_tensors(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* + mutable_tensors(); + private: + const ::onnx::TensorProto& _internal_tensors(int index) const; + ::onnx::TensorProto* _internal_add_tensors(); + public: + const ::onnx::TensorProto& tensors(int index) const; + ::onnx::TensorProto* add_tensors(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& + tensors() const; + + // repeated .onnx.GraphProto graphs = 11; + int graphs_size() const; + private: + int _internal_graphs_size() const; + public: + void clear_graphs(); + ::onnx::GraphProto* mutable_graphs(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >* + mutable_graphs(); + private: + const ::onnx::GraphProto& _internal_graphs(int index) const; + ::onnx::GraphProto* _internal_add_graphs(); + public: + const ::onnx::GraphProto& graphs(int index) const; + ::onnx::GraphProto* add_graphs(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >& + graphs() const; + + // repeated .onnx.SparseTensorProto sparse_tensors = 23; + int sparse_tensors_size() const; + private: + int _internal_sparse_tensors_size() const; + public: + void clear_sparse_tensors(); + ::onnx::SparseTensorProto* mutable_sparse_tensors(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >* + mutable_sparse_tensors(); + private: + const ::onnx::SparseTensorProto& _internal_sparse_tensors(int index) const; + ::onnx::SparseTensorProto* _internal_add_sparse_tensors(); + public: + const ::onnx::SparseTensorProto& sparse_tensors(int index) const; + ::onnx::SparseTensorProto* add_sparse_tensors(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >& + sparse_tensors() const; + + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void 
set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional bytes s = 4; + bool has_s() const; + private: + bool _internal_has_s() const; + public: + void clear_s(); + const std::string& s() const; + void set_s(const std::string& value); + void set_s(std::string&& value); + void set_s(const char* value); + void set_s(const void* value, size_t size); + std::string* mutable_s(); + std::string* release_s(); + void set_allocated_s(std::string* s); + private: + const std::string& _internal_s() const; + void _internal_set_s(const std::string& value); + std::string* _internal_mutable_s(); + public: + + // optional string doc_string = 13; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional string ref_attr_name = 21; + bool has_ref_attr_name() const; + private: + bool _internal_has_ref_attr_name() const; + public: + void clear_ref_attr_name(); + const std::string& ref_attr_name() const; + void set_ref_attr_name(const std::string& value); + void set_ref_attr_name(std::string&& value); + 
void set_ref_attr_name(const char* value); + void set_ref_attr_name(const char* value, size_t size); + std::string* mutable_ref_attr_name(); + std::string* release_ref_attr_name(); + void set_allocated_ref_attr_name(std::string* ref_attr_name); + private: + const std::string& _internal_ref_attr_name() const; + void _internal_set_ref_attr_name(const std::string& value); + std::string* _internal_mutable_ref_attr_name(); + public: + + // optional .onnx.TensorProto t = 5; + bool has_t() const; + private: + bool _internal_has_t() const; + public: + void clear_t(); + const ::onnx::TensorProto& t() const; + ::onnx::TensorProto* release_t(); + ::onnx::TensorProto* mutable_t(); + void set_allocated_t(::onnx::TensorProto* t); + private: + const ::onnx::TensorProto& _internal_t() const; + ::onnx::TensorProto* _internal_mutable_t(); + public: + + // optional .onnx.GraphProto g = 6; + bool has_g() const; + private: + bool _internal_has_g() const; + public: + void clear_g(); + const ::onnx::GraphProto& g() const; + ::onnx::GraphProto* release_g(); + ::onnx::GraphProto* mutable_g(); + void set_allocated_g(::onnx::GraphProto* g); + private: + const ::onnx::GraphProto& _internal_g() const; + ::onnx::GraphProto* _internal_mutable_g(); + public: + + // optional .onnx.SparseTensorProto sparse_tensor = 22; + bool has_sparse_tensor() const; + private: + bool _internal_has_sparse_tensor() const; + public: + void clear_sparse_tensor(); + const ::onnx::SparseTensorProto& sparse_tensor() const; + ::onnx::SparseTensorProto* release_sparse_tensor(); + ::onnx::SparseTensorProto* mutable_sparse_tensor(); + void set_allocated_sparse_tensor(::onnx::SparseTensorProto* sparse_tensor); + private: + const ::onnx::SparseTensorProto& _internal_sparse_tensor() const; + ::onnx::SparseTensorProto* _internal_mutable_sparse_tensor(); + public: + + // optional int64 i = 3; + bool has_i() const; + private: + bool _internal_has_i() const; + public: + void clear_i(); + ::PROTOBUF_NAMESPACE_ID::int64 i() const; 
+ void set_i(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_i() const; + void _internal_set_i(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional float f = 2; + bool has_f() const; + private: + bool _internal_has_f() const; + public: + void clear_f(); + float f() const; + void set_f(float value); + private: + float _internal_f() const; + void _internal_set_f(float value); + public: + + // optional .onnx.AttributeProto.AttributeType type = 20; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + ::onnx::AttributeProto_AttributeType type() const; + void set_type(::onnx::AttributeProto_AttributeType value); + private: + ::onnx::AttributeProto_AttributeType _internal_type() const; + void _internal_set_type(::onnx::AttributeProto_AttributeType value); + public: + + // @@protoc_insertion_point(class_scope:onnx.AttributeProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > floats_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > ints_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField strings_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto > tensors_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto > graphs_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto > sparse_tensors_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr s_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr ref_attr_name_; + ::onnx::TensorProto* t_; + ::onnx::GraphProto* g_; + ::onnx::SparseTensorProto* sparse_tensor_; + 
::PROTOBUF_NAMESPACE_ID::int64 i_; + float f_; + int type_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class ValueInfoProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.ValueInfoProto) */ { + public: + ValueInfoProto(); + virtual ~ValueInfoProto(); + + ValueInfoProto(const ValueInfoProto& from); + ValueInfoProto(ValueInfoProto&& from) noexcept + : ValueInfoProto() { + *this = ::std::move(from); + } + + inline ValueInfoProto& operator=(const ValueInfoProto& from) { + CopyFrom(from); + return *this; + } + inline ValueInfoProto& operator=(ValueInfoProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ValueInfoProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ValueInfoProto* internal_default_instance() { + return reinterpret_cast( + &_ValueInfoProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(ValueInfoProto& a, ValueInfoProto& b) { + a.Swap(&b); + } + inline void Swap(ValueInfoProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message 
---------------------------------------------- + + inline ValueInfoProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + ValueInfoProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ValueInfoProto& from); + void MergeFrom(const ValueInfoProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ValueInfoProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.ValueInfoProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kDocStringFieldNumber = 3, + 
kTypeFieldNumber = 2, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string doc_string = 3; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional .onnx.TypeProto type = 2; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + const ::onnx::TypeProto& type() const; + ::onnx::TypeProto* release_type(); + ::onnx::TypeProto* mutable_type(); + void set_allocated_type(::onnx::TypeProto* type); + private: + const ::onnx::TypeProto& _internal_type() const; + ::onnx::TypeProto* _internal_mutable_type(); + public: + + // @@protoc_insertion_point(class_scope:onnx.ValueInfoProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; 
+ mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::onnx::TypeProto* type_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class NodeProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.NodeProto) */ { + public: + NodeProto(); + virtual ~NodeProto(); + + NodeProto(const NodeProto& from); + NodeProto(NodeProto&& from) noexcept + : NodeProto() { + *this = ::std::move(from); + } + + inline NodeProto& operator=(const NodeProto& from) { + CopyFrom(from); + return *this; + } + inline NodeProto& operator=(NodeProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const NodeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const NodeProto* internal_default_instance() { + return reinterpret_cast( + &_NodeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(NodeProto& a, NodeProto& b) { + a.Swap(&b); + } + inline void Swap(NodeProto* other) { + if (other == this) return; + 
InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline NodeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + NodeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const NodeProto& from); + void MergeFrom(const NodeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(NodeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.NodeProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kInputFieldNumber = 1, + 
kOutputFieldNumber = 2, + kAttributeFieldNumber = 5, + kNameFieldNumber = 3, + kOpTypeFieldNumber = 4, + kDocStringFieldNumber = 6, + kDomainFieldNumber = 7, + }; + // repeated string input = 1; + int input_size() const; + private: + int _internal_input_size() const; + public: + void clear_input(); + const std::string& input(int index) const; + std::string* mutable_input(int index); + void set_input(int index, const std::string& value); + void set_input(int index, std::string&& value); + void set_input(int index, const char* value); + void set_input(int index, const char* value, size_t size); + std::string* add_input(); + void add_input(const std::string& value); + void add_input(std::string&& value); + void add_input(const char* value); + void add_input(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& input() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_input(); + private: + const std::string& _internal_input(int index) const; + std::string* _internal_add_input(); + public: + + // repeated string output = 2; + int output_size() const; + private: + int _internal_output_size() const; + public: + void clear_output(); + const std::string& output(int index) const; + std::string* mutable_output(int index); + void set_output(int index, const std::string& value); + void set_output(int index, std::string&& value); + void set_output(int index, const char* value); + void set_output(int index, const char* value, size_t size); + std::string* add_output(); + void add_output(const std::string& value); + void add_output(std::string&& value); + void add_output(const char* value); + void add_output(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& output() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_output(); + private: + const std::string& _internal_output(int index) const; + std::string* _internal_add_output(); + public: + + // repeated .onnx.AttributeProto attribute = 5; + int 
attribute_size() const; + private: + int _internal_attribute_size() const; + public: + void clear_attribute(); + ::onnx::AttributeProto* mutable_attribute(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >* + mutable_attribute(); + private: + const ::onnx::AttributeProto& _internal_attribute(int index) const; + ::onnx::AttributeProto* _internal_add_attribute(); + public: + const ::onnx::AttributeProto& attribute(int index) const; + ::onnx::AttributeProto* add_attribute(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >& + attribute() const; + + // optional string name = 3; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string op_type = 4; + bool has_op_type() const; + private: + bool _internal_has_op_type() const; + public: + void clear_op_type(); + const std::string& op_type() const; + void set_op_type(const std::string& value); + void set_op_type(std::string&& value); + void set_op_type(const char* value); + void set_op_type(const char* value, size_t size); + std::string* mutable_op_type(); + std::string* release_op_type(); + void set_allocated_op_type(std::string* op_type); + private: + const std::string& _internal_op_type() const; + void _internal_set_op_type(const std::string& value); + std::string* _internal_mutable_op_type(); + public: + + // optional string doc_string = 6; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + 
void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional string domain = 7; + bool has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void set_allocated_domain(std::string* domain); + private: + const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // @@protoc_insertion_point(class_scope:onnx.NodeProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField input_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField output_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto > attribute_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr op_type_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + friend struct ::TableStruct_onnx_2eproto; +}; +// 
------------------------------------------------------------------- + +class ModelProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.ModelProto) */ { + public: + ModelProto(); + virtual ~ModelProto(); + + ModelProto(const ModelProto& from); + ModelProto(ModelProto&& from) noexcept + : ModelProto() { + *this = ::std::move(from); + } + + inline ModelProto& operator=(const ModelProto& from) { + CopyFrom(from); + return *this; + } + inline ModelProto& operator=(ModelProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ModelProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ModelProto* internal_default_instance() { + return reinterpret_cast( + &_ModelProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(ModelProto& a, ModelProto& b) { + a.Swap(&b); + } + inline void Swap(ModelProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ModelProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + ModelProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return 
CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ModelProto& from); + void MergeFrom(const ModelProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ModelProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.ModelProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kOpsetImportFieldNumber = 8, + kMetadataPropsFieldNumber = 14, + kProducerNameFieldNumber = 2, + kProducerVersionFieldNumber = 3, + kDomainFieldNumber = 4, + kDocStringFieldNumber = 6, + kGraphFieldNumber = 7, + kIrVersionFieldNumber = 1, + kModelVersionFieldNumber = 5, + }; + // repeated 
.onnx.OperatorSetIdProto opset_import = 8; + int opset_import_size() const; + private: + int _internal_opset_import_size() const; + public: + void clear_opset_import(); + ::onnx::OperatorSetIdProto* mutable_opset_import(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >* + mutable_opset_import(); + private: + const ::onnx::OperatorSetIdProto& _internal_opset_import(int index) const; + ::onnx::OperatorSetIdProto* _internal_add_opset_import(); + public: + const ::onnx::OperatorSetIdProto& opset_import(int index) const; + ::onnx::OperatorSetIdProto* add_opset_import(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >& + opset_import() const; + + // repeated .onnx.StringStringEntryProto metadata_props = 14; + int metadata_props_size() const; + private: + int _internal_metadata_props_size() const; + public: + void clear_metadata_props(); + ::onnx::StringStringEntryProto* mutable_metadata_props(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* + mutable_metadata_props(); + private: + const ::onnx::StringStringEntryProto& _internal_metadata_props(int index) const; + ::onnx::StringStringEntryProto* _internal_add_metadata_props(); + public: + const ::onnx::StringStringEntryProto& metadata_props(int index) const; + ::onnx::StringStringEntryProto* add_metadata_props(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& + metadata_props() const; + + // optional string producer_name = 2; + bool has_producer_name() const; + private: + bool _internal_has_producer_name() const; + public: + void clear_producer_name(); + const std::string& producer_name() const; + void set_producer_name(const std::string& value); + void set_producer_name(std::string&& value); + void set_producer_name(const char* value); + void set_producer_name(const char* value, size_t size); + std::string* mutable_producer_name(); + std::string* release_producer_name(); + 
void set_allocated_producer_name(std::string* producer_name); + private: + const std::string& _internal_producer_name() const; + void _internal_set_producer_name(const std::string& value); + std::string* _internal_mutable_producer_name(); + public: + + // optional string producer_version = 3; + bool has_producer_version() const; + private: + bool _internal_has_producer_version() const; + public: + void clear_producer_version(); + const std::string& producer_version() const; + void set_producer_version(const std::string& value); + void set_producer_version(std::string&& value); + void set_producer_version(const char* value); + void set_producer_version(const char* value, size_t size); + std::string* mutable_producer_version(); + std::string* release_producer_version(); + void set_allocated_producer_version(std::string* producer_version); + private: + const std::string& _internal_producer_version() const; + void _internal_set_producer_version(const std::string& value); + std::string* _internal_mutable_producer_version(); + public: + + // optional string domain = 4; + bool has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void set_allocated_domain(std::string* domain); + private: + const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // optional string doc_string = 6; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void 
set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional .onnx.GraphProto graph = 7; + bool has_graph() const; + private: + bool _internal_has_graph() const; + public: + void clear_graph(); + const ::onnx::GraphProto& graph() const; + ::onnx::GraphProto* release_graph(); + ::onnx::GraphProto* mutable_graph(); + void set_allocated_graph(::onnx::GraphProto* graph); + private: + const ::onnx::GraphProto& _internal_graph() const; + ::onnx::GraphProto* _internal_mutable_graph(); + public: + + // optional int64 ir_version = 1; + bool has_ir_version() const; + private: + bool _internal_has_ir_version() const; + public: + void clear_ir_version(); + ::PROTOBUF_NAMESPACE_ID::int64 ir_version() const; + void set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_ir_version() const; + void _internal_set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional int64 model_version = 5; + bool has_model_version() const; + private: + bool _internal_has_model_version() const; + public: + void clear_model_version(); + ::PROTOBUF_NAMESPACE_ID::int64 model_version() const; + void set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_model_version() const; + void _internal_set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.ModelProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable 
::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto > opset_import_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto > metadata_props_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr producer_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr producer_version_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::onnx::GraphProto* graph_; + ::PROTOBUF_NAMESPACE_ID::int64 ir_version_; + ::PROTOBUF_NAMESPACE_ID::int64 model_version_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class StringStringEntryProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.StringStringEntryProto) */ { + public: + StringStringEntryProto(); + virtual ~StringStringEntryProto(); + + StringStringEntryProto(const StringStringEntryProto& from); + StringStringEntryProto(StringStringEntryProto&& from) noexcept + : StringStringEntryProto() { + *this = ::std::move(from); + } + + inline StringStringEntryProto& operator=(const StringStringEntryProto& from) { + CopyFrom(from); + return *this; + } + inline StringStringEntryProto& operator=(StringStringEntryProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return 
GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const StringStringEntryProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const StringStringEntryProto* internal_default_instance() { + return reinterpret_cast( + &_StringStringEntryProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(StringStringEntryProto& a, StringStringEntryProto& b) { + a.Swap(&b); + } + inline void Swap(StringStringEntryProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline StringStringEntryProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + StringStringEntryProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const StringStringEntryProto& from); + void MergeFrom(const StringStringEntryProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(StringStringEntryProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return 
"onnx.StringStringEntryProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kKeyFieldNumber = 1, + kValueFieldNumber = 2, + }; + // optional string key = 1; + bool has_key() const; + private: + bool _internal_has_key() const; + public: + void clear_key(); + const std::string& key() const; + void set_key(const std::string& value); + void set_key(std::string&& value); + void set_key(const char* value); + void set_key(const char* value, size_t size); + std::string* mutable_key(); + std::string* release_key(); + void set_allocated_key(std::string* key); + private: + const std::string& _internal_key() const; + void _internal_set_key(const std::string& value); + std::string* _internal_mutable_key(); + public: + + // optional string value = 2; + bool has_value() const; + private: + bool _internal_has_value() const; + public: + void clear_value(); + const std::string& value() const; + void set_value(const std::string& value); + void set_value(std::string&& value); + void set_value(const char* value); + void set_value(const char* value, size_t size); + std::string* mutable_value(); + std::string* release_value(); + void set_allocated_value(std::string* value); + private: + const std::string& _internal_value() const; + void _internal_set_value(const std::string& value); + std::string* _internal_mutable_value(); + public: + + // 
@@protoc_insertion_point(class_scope:onnx.StringStringEntryProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr key_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr value_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorAnnotation : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TensorAnnotation) */ { + public: + TensorAnnotation(); + virtual ~TensorAnnotation(); + + TensorAnnotation(const TensorAnnotation& from); + TensorAnnotation(TensorAnnotation&& from) noexcept + : TensorAnnotation() { + *this = ::std::move(from); + } + + inline TensorAnnotation& operator=(const TensorAnnotation& from) { + CopyFrom(from); + return *this; + } + inline TensorAnnotation& operator=(TensorAnnotation&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorAnnotation& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const 
TensorAnnotation* internal_default_instance() { + return reinterpret_cast( + &_TensorAnnotation_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + friend void swap(TensorAnnotation& a, TensorAnnotation& b) { + a.Swap(&b); + } + inline void Swap(TensorAnnotation* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorAnnotation* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorAnnotation* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorAnnotation& from); + void MergeFrom(const TensorAnnotation& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorAnnotation* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorAnnotation"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + 
::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kQuantParameterTensorNamesFieldNumber = 2, + kTensorNameFieldNumber = 1, + }; + // repeated .onnx.StringStringEntryProto quant_parameter_tensor_names = 2; + int quant_parameter_tensor_names_size() const; + private: + int _internal_quant_parameter_tensor_names_size() const; + public: + void clear_quant_parameter_tensor_names(); + ::onnx::StringStringEntryProto* mutable_quant_parameter_tensor_names(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* + mutable_quant_parameter_tensor_names(); + private: + const ::onnx::StringStringEntryProto& _internal_quant_parameter_tensor_names(int index) const; + ::onnx::StringStringEntryProto* _internal_add_quant_parameter_tensor_names(); + public: + const ::onnx::StringStringEntryProto& quant_parameter_tensor_names(int index) const; + ::onnx::StringStringEntryProto* add_quant_parameter_tensor_names(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& + quant_parameter_tensor_names() const; + + // optional string tensor_name = 1; + bool has_tensor_name() const; + private: + bool _internal_has_tensor_name() const; + public: + void clear_tensor_name(); + const std::string& tensor_name() const; + void set_tensor_name(const std::string& value); + void set_tensor_name(std::string&& value); + void set_tensor_name(const char* value); + void set_tensor_name(const char* value, size_t size); + std::string* mutable_tensor_name(); + std::string* release_tensor_name(); + void set_allocated_tensor_name(std::string* tensor_name); + private: + const std::string& _internal_tensor_name() const; + void 
_internal_set_tensor_name(const std::string& value); + std::string* _internal_mutable_tensor_name(); + public: + + // @@protoc_insertion_point(class_scope:onnx.TensorAnnotation) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto > quant_parameter_tensor_names_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr tensor_name_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class GraphProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.GraphProto) */ { + public: + GraphProto(); + virtual ~GraphProto(); + + GraphProto(const GraphProto& from); + GraphProto(GraphProto&& from) noexcept + : GraphProto() { + *this = ::std::move(from); + } + + inline GraphProto& operator=(const GraphProto& from) { + CopyFrom(from); + return *this; + } + inline GraphProto& operator=(GraphProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const GraphProto& default_instance(); + + 
static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const GraphProto* internal_default_instance() { + return reinterpret_cast( + &_GraphProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 6; + + friend void swap(GraphProto& a, GraphProto& b) { + a.Swap(&b); + } + inline void Swap(GraphProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline GraphProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + GraphProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const GraphProto& from); + void MergeFrom(const GraphProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(GraphProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.GraphProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() 
{ + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNodeFieldNumber = 1, + kInitializerFieldNumber = 5, + kInputFieldNumber = 11, + kOutputFieldNumber = 12, + kValueInfoFieldNumber = 13, + kQuantizationAnnotationFieldNumber = 14, + kSparseInitializerFieldNumber = 15, + kNameFieldNumber = 2, + kDocStringFieldNumber = 10, + }; + // repeated .onnx.NodeProto node = 1; + int node_size() const; + private: + int _internal_node_size() const; + public: + void clear_node(); + ::onnx::NodeProto* mutable_node(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >* + mutable_node(); + private: + const ::onnx::NodeProto& _internal_node(int index) const; + ::onnx::NodeProto* _internal_add_node(); + public: + const ::onnx::NodeProto& node(int index) const; + ::onnx::NodeProto* add_node(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >& + node() const; + + // repeated .onnx.TensorProto initializer = 5; + int initializer_size() const; + private: + int _internal_initializer_size() const; + public: + void clear_initializer(); + ::onnx::TensorProto* mutable_initializer(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* + mutable_initializer(); + private: + const ::onnx::TensorProto& _internal_initializer(int index) const; + ::onnx::TensorProto* _internal_add_initializer(); + public: + const ::onnx::TensorProto& initializer(int index) const; + ::onnx::TensorProto* add_initializer(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& + initializer() const; + + // repeated .onnx.ValueInfoProto input = 11; + int input_size() const; + private: + int _internal_input_size() const; + public: + 
void clear_input(); + ::onnx::ValueInfoProto* mutable_input(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* + mutable_input(); + private: + const ::onnx::ValueInfoProto& _internal_input(int index) const; + ::onnx::ValueInfoProto* _internal_add_input(); + public: + const ::onnx::ValueInfoProto& input(int index) const; + ::onnx::ValueInfoProto* add_input(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& + input() const; + + // repeated .onnx.ValueInfoProto output = 12; + int output_size() const; + private: + int _internal_output_size() const; + public: + void clear_output(); + ::onnx::ValueInfoProto* mutable_output(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* + mutable_output(); + private: + const ::onnx::ValueInfoProto& _internal_output(int index) const; + ::onnx::ValueInfoProto* _internal_add_output(); + public: + const ::onnx::ValueInfoProto& output(int index) const; + ::onnx::ValueInfoProto* add_output(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& + output() const; + + // repeated .onnx.ValueInfoProto value_info = 13; + int value_info_size() const; + private: + int _internal_value_info_size() const; + public: + void clear_value_info(); + ::onnx::ValueInfoProto* mutable_value_info(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* + mutable_value_info(); + private: + const ::onnx::ValueInfoProto& _internal_value_info(int index) const; + ::onnx::ValueInfoProto* _internal_add_value_info(); + public: + const ::onnx::ValueInfoProto& value_info(int index) const; + ::onnx::ValueInfoProto* add_value_info(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& + value_info() const; + + // repeated .onnx.TensorAnnotation quantization_annotation = 14; + int quantization_annotation_size() const; + private: + int _internal_quantization_annotation_size() const; + public: + void 
clear_quantization_annotation(); + ::onnx::TensorAnnotation* mutable_quantization_annotation(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorAnnotation >* + mutable_quantization_annotation(); + private: + const ::onnx::TensorAnnotation& _internal_quantization_annotation(int index) const; + ::onnx::TensorAnnotation* _internal_add_quantization_annotation(); + public: + const ::onnx::TensorAnnotation& quantization_annotation(int index) const; + ::onnx::TensorAnnotation* add_quantization_annotation(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorAnnotation >& + quantization_annotation() const; + + // repeated .onnx.SparseTensorProto sparse_initializer = 15; + int sparse_initializer_size() const; + private: + int _internal_sparse_initializer_size() const; + public: + void clear_sparse_initializer(); + ::onnx::SparseTensorProto* mutable_sparse_initializer(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >* + mutable_sparse_initializer(); + private: + const ::onnx::SparseTensorProto& _internal_sparse_initializer(int index) const; + ::onnx::SparseTensorProto* _internal_add_sparse_initializer(); + public: + const ::onnx::SparseTensorProto& sparse_initializer(int index) const; + ::onnx::SparseTensorProto* add_sparse_initializer(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >& + sparse_initializer() const; + + // optional string name = 2; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* 
_internal_mutable_name(); + public: + + // optional string doc_string = 10; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // @@protoc_insertion_point(class_scope:onnx.GraphProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto > node_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto > initializer_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto > input_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto > output_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto > value_info_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorAnnotation > quantization_annotation_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto > sparse_initializer_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorProto_Segment : + public ::PROTOBUF_NAMESPACE_ID::Message /* 
@@protoc_insertion_point(class_definition:onnx.TensorProto.Segment) */ { + public: + TensorProto_Segment(); + virtual ~TensorProto_Segment(); + + TensorProto_Segment(const TensorProto_Segment& from); + TensorProto_Segment(TensorProto_Segment&& from) noexcept + : TensorProto_Segment() { + *this = ::std::move(from); + } + + inline TensorProto_Segment& operator=(const TensorProto_Segment& from) { + CopyFrom(from); + return *this; + } + inline TensorProto_Segment& operator=(TensorProto_Segment&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorProto_Segment& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorProto_Segment* internal_default_instance() { + return reinterpret_cast( + &_TensorProto_Segment_default_instance_); + } + static constexpr int kIndexInFileMessages = + 7; + + friend void swap(TensorProto_Segment& a, TensorProto_Segment& b) { + a.Swap(&b); + } + inline void Swap(TensorProto_Segment* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorProto_Segment* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorProto_Segment* 
New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorProto_Segment& from); + void MergeFrom(const TensorProto_Segment& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorProto_Segment* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorProto.Segment"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kBeginFieldNumber = 1, + kEndFieldNumber = 2, + }; + // optional int64 begin = 1; + bool has_begin() const; + private: + bool _internal_has_begin() const; + public: + void clear_begin(); + 
::PROTOBUF_NAMESPACE_ID::int64 begin() const; + void set_begin(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_begin() const; + void _internal_set_begin(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional int64 end = 2; + bool has_end() const; + private: + bool _internal_has_end() const; + public: + void clear_end(); + ::PROTOBUF_NAMESPACE_ID::int64 end() const; + void set_end(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_end() const; + void _internal_set_end(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TensorProto.Segment) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::int64 begin_; + ::PROTOBUF_NAMESPACE_ID::int64 end_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TensorProto) */ { + public: + TensorProto(); + virtual ~TensorProto(); + + TensorProto(const TensorProto& from); + TensorProto(TensorProto&& from) noexcept + : TensorProto() { + *this = ::std::move(from); + } + + inline TensorProto& operator=(const TensorProto& from) { + CopyFrom(from); + return *this; + } + inline TensorProto& operator=(TensorProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return 
_internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorProto* internal_default_instance() { + return reinterpret_cast( + &_TensorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 8; + + friend void swap(TensorProto& a, TensorProto& b) { + a.Swap(&b); + } + inline void Swap(TensorProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorProto& from); + void MergeFrom(const TensorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorProto* other); + friend class 
::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TensorProto_Segment Segment; + + typedef TensorProto_DataType DataType; + static constexpr DataType UNDEFINED = + TensorProto_DataType_UNDEFINED; + static constexpr DataType FLOAT = + TensorProto_DataType_FLOAT; + static constexpr DataType UINT8 = + TensorProto_DataType_UINT8; + static constexpr DataType INT8 = + TensorProto_DataType_INT8; + static constexpr DataType UINT16 = + TensorProto_DataType_UINT16; + static constexpr DataType INT16 = + TensorProto_DataType_INT16; + static constexpr DataType INT32 = + TensorProto_DataType_INT32; + static constexpr DataType INT64 = + TensorProto_DataType_INT64; + static constexpr DataType STRING = + TensorProto_DataType_STRING; + static constexpr DataType BOOL = + TensorProto_DataType_BOOL; + static constexpr DataType FLOAT16 = + TensorProto_DataType_FLOAT16; + static constexpr DataType DOUBLE = + TensorProto_DataType_DOUBLE; + static constexpr DataType UINT32 = + TensorProto_DataType_UINT32; + static constexpr DataType UINT64 = + TensorProto_DataType_UINT64; + static constexpr DataType COMPLEX64 = + TensorProto_DataType_COMPLEX64; + static constexpr DataType COMPLEX128 = + TensorProto_DataType_COMPLEX128; + static constexpr DataType BFLOAT16 = + TensorProto_DataType_BFLOAT16; + static inline bool 
DataType_IsValid(int value) { + return TensorProto_DataType_IsValid(value); + } + static constexpr DataType DataType_MIN = + TensorProto_DataType_DataType_MIN; + static constexpr DataType DataType_MAX = + TensorProto_DataType_DataType_MAX; + static constexpr int DataType_ARRAYSIZE = + TensorProto_DataType_DataType_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + DataType_descriptor() { + return TensorProto_DataType_descriptor(); + } + template + static inline const std::string& DataType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function DataType_Name."); + return TensorProto_DataType_Name(enum_t_value); + } + static inline bool DataType_Parse(const std::string& name, + DataType* value) { + return TensorProto_DataType_Parse(name, value); + } + + typedef TensorProto_DataLocation DataLocation; + static constexpr DataLocation DEFAULT = + TensorProto_DataLocation_DEFAULT; + static constexpr DataLocation EXTERNAL = + TensorProto_DataLocation_EXTERNAL; + static inline bool DataLocation_IsValid(int value) { + return TensorProto_DataLocation_IsValid(value); + } + static constexpr DataLocation DataLocation_MIN = + TensorProto_DataLocation_DataLocation_MIN; + static constexpr DataLocation DataLocation_MAX = + TensorProto_DataLocation_DataLocation_MAX; + static constexpr int DataLocation_ARRAYSIZE = + TensorProto_DataLocation_DataLocation_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + DataLocation_descriptor() { + return TensorProto_DataLocation_descriptor(); + } + template + static inline const std::string& DataLocation_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function DataLocation_Name."); + return TensorProto_DataLocation_Name(enum_t_value); + } + static inline bool DataLocation_Parse(const std::string& name, + DataLocation* value) { + return 
TensorProto_DataLocation_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kDimsFieldNumber = 1, + kFloatDataFieldNumber = 4, + kInt32DataFieldNumber = 5, + kStringDataFieldNumber = 6, + kInt64DataFieldNumber = 7, + kDoubleDataFieldNumber = 10, + kUint64DataFieldNumber = 11, + kExternalDataFieldNumber = 13, + kNameFieldNumber = 8, + kRawDataFieldNumber = 9, + kDocStringFieldNumber = 12, + kSegmentFieldNumber = 3, + kDataTypeFieldNumber = 2, + kDataLocationFieldNumber = 14, + }; + // repeated int64 dims = 1; + int dims_size() const; + private: + int _internal_dims_size() const; + public: + void clear_dims(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_dims(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_dims() const; + void _internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_dims(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 dims(int index) const; + void set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + dims() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_dims(); + + // repeated float float_data = 4 [packed = true]; + int float_data_size() const; + private: + int _internal_float_data_size() const; + public: + void clear_float_data(); + private: + float _internal_float_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_float_data() const; + void _internal_add_float_data(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_float_data(); + public: + float float_data(int index) const; + void set_float_data(int index, float value); + void add_float_data(float value); + 
const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + float_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_float_data(); + + // repeated int32 int32_data = 5 [packed = true]; + int int32_data_size() const; + private: + int _internal_int32_data_size() const; + public: + void clear_int32_data(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_int32_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_int32_data() const; + void _internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_int32_data(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 int32_data(int index) const; + void set_int32_data(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + int32_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_int32_data(); + + // repeated bytes string_data = 6; + int string_data_size() const; + private: + int _internal_string_data_size() const; + public: + void clear_string_data(); + const std::string& string_data(int index) const; + std::string* mutable_string_data(int index); + void set_string_data(int index, const std::string& value); + void set_string_data(int index, std::string&& value); + void set_string_data(int index, const char* value); + void set_string_data(int index, const void* value, size_t size); + std::string* add_string_data(); + void add_string_data(const std::string& value); + void add_string_data(std::string&& value); + void add_string_data(const char* value); + void add_string_data(const void* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& string_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_string_data(); + private: + const 
std::string& _internal_string_data(int index) const; + std::string* _internal_add_string_data(); + public: + + // repeated int64 int64_data = 7 [packed = true]; + int int64_data_size() const; + private: + int _internal_int64_data_size() const; + public: + void clear_int64_data(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_int64_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_int64_data() const; + void _internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_int64_data(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 int64_data(int index) const; + void set_int64_data(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + int64_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_int64_data(); + + // repeated double double_data = 10 [packed = true]; + int double_data_size() const; + private: + int _internal_double_data_size() const; + public: + void clear_double_data(); + private: + double _internal_double_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_double_data() const; + void _internal_add_double_data(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_double_data(); + public: + double double_data(int index) const; + void set_double_data(int index, double value); + void add_double_data(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + double_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_double_data(); + + // repeated uint64 uint64_data = 11 [packed = true]; + int uint64_data_size() const; + private: + int _internal_uint64_data_size() const; + public: + void 
clear_uint64_data(); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_uint64_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + _internal_uint64_data() const; + void _internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + _internal_mutable_uint64_data(); + public: + ::PROTOBUF_NAMESPACE_ID::uint64 uint64_data(int index) const; + void set_uint64_data(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value); + void add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + uint64_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + mutable_uint64_data(); + + // repeated .onnx.StringStringEntryProto external_data = 13; + int external_data_size() const; + private: + int _internal_external_data_size() const; + public: + void clear_external_data(); + ::onnx::StringStringEntryProto* mutable_external_data(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* + mutable_external_data(); + private: + const ::onnx::StringStringEntryProto& _internal_external_data(int index) const; + ::onnx::StringStringEntryProto* _internal_add_external_data(); + public: + const ::onnx::StringStringEntryProto& external_data(int index) const; + ::onnx::StringStringEntryProto* add_external_data(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& + external_data() const; + + // optional string name = 8; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + 
void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional bytes raw_data = 9; + bool has_raw_data() const; + private: + bool _internal_has_raw_data() const; + public: + void clear_raw_data(); + const std::string& raw_data() const; + void set_raw_data(const std::string& value); + void set_raw_data(std::string&& value); + void set_raw_data(const char* value); + void set_raw_data(const void* value, size_t size); + std::string* mutable_raw_data(); + std::string* release_raw_data(); + void set_allocated_raw_data(std::string* raw_data); + private: + const std::string& _internal_raw_data() const; + void _internal_set_raw_data(const std::string& value); + std::string* _internal_mutable_raw_data(); + public: + + // optional string doc_string = 12; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional .onnx.TensorProto.Segment segment = 3; + bool has_segment() const; + private: + bool _internal_has_segment() const; + public: + void clear_segment(); + const ::onnx::TensorProto_Segment& segment() const; + ::onnx::TensorProto_Segment* release_segment(); + ::onnx::TensorProto_Segment* mutable_segment(); + void set_allocated_segment(::onnx::TensorProto_Segment* segment); + private: + const 
::onnx::TensorProto_Segment& _internal_segment() const; + ::onnx::TensorProto_Segment* _internal_mutable_segment(); + public: + + // optional int32 data_type = 2; + bool has_data_type() const; + private: + bool _internal_has_data_type() const; + public: + void clear_data_type(); + ::PROTOBUF_NAMESPACE_ID::int32 data_type() const; + void set_data_type(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_data_type() const; + void _internal_set_data_type(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional .onnx.TensorProto.DataLocation data_location = 14; + bool has_data_location() const; + private: + bool _internal_has_data_location() const; + public: + void clear_data_location(); + ::onnx::TensorProto_DataLocation data_location() const; + void set_data_location(::onnx::TensorProto_DataLocation value); + private: + ::onnx::TensorProto_DataLocation _internal_data_location() const; + void _internal_set_data_location(::onnx::TensorProto_DataLocation value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TensorProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > dims_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > float_data_; + mutable std::atomic _float_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > int32_data_; + mutable std::atomic _int32_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField string_data_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > int64_data_; + mutable std::atomic _int64_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > double_data_; + mutable std::atomic _double_data_cached_byte_size_; + 
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 > uint64_data_; + mutable std::atomic _uint64_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto > external_data_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr raw_data_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::onnx::TensorProto_Segment* segment_; + ::PROTOBUF_NAMESPACE_ID::int32 data_type_; + int data_location_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class SparseTensorProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.SparseTensorProto) */ { + public: + SparseTensorProto(); + virtual ~SparseTensorProto(); + + SparseTensorProto(const SparseTensorProto& from); + SparseTensorProto(SparseTensorProto&& from) noexcept + : SparseTensorProto() { + *this = ::std::move(from); + } + + inline SparseTensorProto& operator=(const SparseTensorProto& from) { + CopyFrom(from); + return *this; + } + inline SparseTensorProto& operator=(SparseTensorProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const 
SparseTensorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SparseTensorProto* internal_default_instance() { + return reinterpret_cast( + &_SparseTensorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 9; + + friend void swap(SparseTensorProto& a, SparseTensorProto& b) { + a.Swap(&b); + } + inline void Swap(SparseTensorProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SparseTensorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + SparseTensorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SparseTensorProto& from); + void MergeFrom(const SparseTensorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SparseTensorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.SparseTensorProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + 
::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDimsFieldNumber = 3, + kValuesFieldNumber = 1, + kIndicesFieldNumber = 2, + }; + // repeated int64 dims = 3; + int dims_size() const; + private: + int _internal_dims_size() const; + public: + void clear_dims(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_dims(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_dims() const; + void _internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_dims(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 dims(int index) const; + void set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + dims() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_dims(); + + // optional .onnx.TensorProto values = 1; + bool has_values() const; + private: + bool _internal_has_values() const; + public: + void clear_values(); + const ::onnx::TensorProto& values() const; + ::onnx::TensorProto* release_values(); + ::onnx::TensorProto* mutable_values(); + void set_allocated_values(::onnx::TensorProto* values); + private: + const ::onnx::TensorProto& _internal_values() const; + ::onnx::TensorProto* _internal_mutable_values(); + public: + + // optional .onnx.TensorProto indices = 2; + bool has_indices() const; + private: + bool 
_internal_has_indices() const; + public: + void clear_indices(); + const ::onnx::TensorProto& indices() const; + ::onnx::TensorProto* release_indices(); + ::onnx::TensorProto* mutable_indices(); + void set_allocated_indices(::onnx::TensorProto* indices); + private: + const ::onnx::TensorProto& _internal_indices() const; + ::onnx::TensorProto* _internal_mutable_indices(); + public: + + // @@protoc_insertion_point(class_scope:onnx.SparseTensorProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > dims_; + ::onnx::TensorProto* values_; + ::onnx::TensorProto* indices_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorShapeProto_Dimension : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TensorShapeProto.Dimension) */ { + public: + TensorShapeProto_Dimension(); + virtual ~TensorShapeProto_Dimension(); + + TensorShapeProto_Dimension(const TensorShapeProto_Dimension& from); + TensorShapeProto_Dimension(TensorShapeProto_Dimension&& from) noexcept + : TensorShapeProto_Dimension() { + *this = ::std::move(from); + } + + inline TensorShapeProto_Dimension& operator=(const TensorShapeProto_Dimension& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto_Dimension& operator=(TensorShapeProto_Dimension&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* 
mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto_Dimension& default_instance(); + + enum ValueCase { + kDimValue = 1, + kDimParam = 2, + VALUE_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto_Dimension* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_Dimension_default_instance_); + } + static constexpr int kIndexInFileMessages = + 10; + + friend void swap(TensorShapeProto_Dimension& a, TensorShapeProto_Dimension& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto_Dimension* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto_Dimension* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto_Dimension* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto_Dimension& from); + void MergeFrom(const TensorShapeProto_Dimension& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* 
stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto_Dimension* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorShapeProto.Dimension"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDenotationFieldNumber = 3, + kDimValueFieldNumber = 1, + kDimParamFieldNumber = 2, + }; + // optional string denotation = 3; + bool has_denotation() const; + private: + bool _internal_has_denotation() const; + public: + void clear_denotation(); + const std::string& denotation() const; + void set_denotation(const std::string& value); + void set_denotation(std::string&& value); + void set_denotation(const char* value); + void set_denotation(const char* value, size_t size); + std::string* mutable_denotation(); + std::string* release_denotation(); + void set_allocated_denotation(std::string* denotation); + private: + const std::string& _internal_denotation() const; + void _internal_set_denotation(const std::string& value); + std::string* _internal_mutable_denotation(); + public: + + // optional int64 dim_value = 1; + bool has_dim_value() const; + private: + bool _internal_has_dim_value() 
const; + public: + void clear_dim_value(); + ::PROTOBUF_NAMESPACE_ID::int64 dim_value() const; + void set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_dim_value() const; + void _internal_set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional string dim_param = 2; + bool has_dim_param() const; + private: + bool _internal_has_dim_param() const; + public: + void clear_dim_param(); + const std::string& dim_param() const; + void set_dim_param(const std::string& value); + void set_dim_param(std::string&& value); + void set_dim_param(const char* value); + void set_dim_param(const char* value, size_t size); + std::string* mutable_dim_param(); + std::string* release_dim_param(); + void set_allocated_dim_param(std::string* dim_param); + private: + const std::string& _internal_dim_param() const; + void _internal_set_dim_param(const std::string& value); + std::string* _internal_mutable_dim_param(); + public: + + void clear_value(); + ValueCase value_case() const; + // @@protoc_insertion_point(class_scope:onnx.TensorShapeProto.Dimension) + private: + class _Internal; + void set_has_dim_value(); + void set_has_dim_param(); + + inline bool has_value() const; + inline void clear_has_value(); + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr denotation_; + union ValueUnion { + ValueUnion() {} + ::PROTOBUF_NAMESPACE_ID::int64 dim_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr dim_param_; + } value_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorShapeProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* 
@@protoc_insertion_point(class_definition:onnx.TensorShapeProto) */ { + public: + TensorShapeProto(); + virtual ~TensorShapeProto(); + + TensorShapeProto(const TensorShapeProto& from); + TensorShapeProto(TensorShapeProto&& from) noexcept + : TensorShapeProto() { + *this = ::std::move(from); + } + + inline TensorShapeProto& operator=(const TensorShapeProto& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto& operator=(TensorShapeProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 11; + + friend void swap(TensorShapeProto& a, TensorShapeProto& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return 
CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto& from); + void MergeFrom(const TensorShapeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorShapeProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TensorShapeProto_Dimension Dimension; + + // accessors ------------------------------------------------------- + + enum : int { + kDimFieldNumber = 1, + }; + // repeated .onnx.TensorShapeProto.Dimension dim = 1; + int dim_size() const; + private: + int _internal_dim_size() const; + public: + void clear_dim(); + ::onnx::TensorShapeProto_Dimension* 
mutable_dim(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >* + mutable_dim(); + private: + const ::onnx::TensorShapeProto_Dimension& _internal_dim(int index) const; + ::onnx::TensorShapeProto_Dimension* _internal_add_dim(); + public: + const ::onnx::TensorShapeProto_Dimension& dim(int index) const; + ::onnx::TensorShapeProto_Dimension* add_dim(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >& + dim() const; + + // @@protoc_insertion_point(class_scope:onnx.TensorShapeProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension > dim_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Tensor : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.Tensor) */ { + public: + TypeProto_Tensor(); + virtual ~TypeProto_Tensor(); + + TypeProto_Tensor(const TypeProto_Tensor& from); + TypeProto_Tensor(TypeProto_Tensor&& from) noexcept + : TypeProto_Tensor() { + *this = ::std::move(from); + } + + inline TypeProto_Tensor& operator=(const TypeProto_Tensor& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Tensor& operator=(TypeProto_Tensor&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + 
static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Tensor& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Tensor* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Tensor_default_instance_); + } + static constexpr int kIndexInFileMessages = + 12; + + friend void swap(TypeProto_Tensor& a, TypeProto_Tensor& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Tensor* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Tensor* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Tensor* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Tensor& from); + void MergeFrom(const TypeProto_Tensor& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Tensor* other); + friend class 
::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.Tensor"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kShapeFieldNumber = 2, + kElemTypeFieldNumber = 1, + }; + // optional .onnx.TensorShapeProto shape = 2; + bool has_shape() const; + private: + bool _internal_has_shape() const; + public: + void clear_shape(); + const ::onnx::TensorShapeProto& shape() const; + ::onnx::TensorShapeProto* release_shape(); + ::onnx::TensorShapeProto* mutable_shape(); + void set_allocated_shape(::onnx::TensorShapeProto* shape); + private: + const ::onnx::TensorShapeProto& _internal_shape() const; + ::onnx::TensorShapeProto* _internal_mutable_shape(); + public: + + // optional int32 elem_type = 1; + bool has_elem_type() const; + private: + bool _internal_has_elem_type() const; + public: + void clear_elem_type(); + ::PROTOBUF_NAMESPACE_ID::int32 elem_type() const; + void set_elem_type(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_elem_type() const; + void _internal_set_elem_type(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.Tensor) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + 
::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::onnx::TensorShapeProto* shape_; + ::PROTOBUF_NAMESPACE_ID::int32 elem_type_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Sequence : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.Sequence) */ { + public: + TypeProto_Sequence(); + virtual ~TypeProto_Sequence(); + + TypeProto_Sequence(const TypeProto_Sequence& from); + TypeProto_Sequence(TypeProto_Sequence&& from) noexcept + : TypeProto_Sequence() { + *this = ::std::move(from); + } + + inline TypeProto_Sequence& operator=(const TypeProto_Sequence& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Sequence& operator=(TypeProto_Sequence&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Sequence& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Sequence* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Sequence_default_instance_); + } + static constexpr int kIndexInFileMessages = + 13; + + friend void 
swap(TypeProto_Sequence& a, TypeProto_Sequence& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Sequence* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Sequence* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Sequence* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Sequence& from); + void MergeFrom(const TypeProto_Sequence& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Sequence* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.Sequence"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + 
public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kElemTypeFieldNumber = 1, + }; + // optional .onnx.TypeProto elem_type = 1; + bool has_elem_type() const; + private: + bool _internal_has_elem_type() const; + public: + void clear_elem_type(); + const ::onnx::TypeProto& elem_type() const; + ::onnx::TypeProto* release_elem_type(); + ::onnx::TypeProto* mutable_elem_type(); + void set_allocated_elem_type(::onnx::TypeProto* elem_type); + private: + const ::onnx::TypeProto& _internal_elem_type() const; + ::onnx::TypeProto* _internal_mutable_elem_type(); + public: + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.Sequence) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::onnx::TypeProto* elem_type_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Map : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.Map) */ { + public: + TypeProto_Map(); + virtual ~TypeProto_Map(); + + TypeProto_Map(const TypeProto_Map& from); + TypeProto_Map(TypeProto_Map&& from) noexcept + : TypeProto_Map() { + *this = ::std::move(from); + } + + inline TypeProto_Map& operator=(const TypeProto_Map& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Map& operator=(TypeProto_Map&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline 
::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Map& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Map* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Map_default_instance_); + } + static constexpr int kIndexInFileMessages = + 14; + + friend void swap(TypeProto_Map& a, TypeProto_Map& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Map* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Map* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Map* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Map& from); + void MergeFrom(const TypeProto_Map& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void 
SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Map* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.Map"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kValueTypeFieldNumber = 2, + kKeyTypeFieldNumber = 1, + }; + // optional .onnx.TypeProto value_type = 2; + bool has_value_type() const; + private: + bool _internal_has_value_type() const; + public: + void clear_value_type(); + const ::onnx::TypeProto& value_type() const; + ::onnx::TypeProto* release_value_type(); + ::onnx::TypeProto* mutable_value_type(); + void set_allocated_value_type(::onnx::TypeProto* value_type); + private: + const ::onnx::TypeProto& _internal_value_type() const; + ::onnx::TypeProto* _internal_mutable_value_type(); + public: + + // optional int32 key_type = 1; + bool has_key_type() const; + private: + bool _internal_has_key_type() const; + public: + void clear_key_type(); + ::PROTOBUF_NAMESPACE_ID::int32 key_type() const; + void set_key_type(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_key_type() const; + void _internal_set_key_type(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.Map) + private: + class _Internal; + + 
::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::onnx::TypeProto* value_type_; + ::PROTOBUF_NAMESPACE_ID::int32 key_type_; + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto) */ { + public: + TypeProto(); + virtual ~TypeProto(); + + TypeProto(const TypeProto& from); + TypeProto(TypeProto&& from) noexcept + : TypeProto() { + *this = ::std::move(from); + } + + inline TypeProto& operator=(const TypeProto& from) { + CopyFrom(from); + return *this; + } + inline TypeProto& operator=(TypeProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto& default_instance(); + + enum ValueCase { + kTensorType = 1, + kSequenceType = 4, + kMapType = 5, + VALUE_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_default_instance_); + } + static constexpr int 
kIndexInFileMessages = + 15; + + friend void swap(TypeProto& a, TypeProto& b) { + a.Swap(&b); + } + inline void Swap(TypeProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto& from); + void MergeFrom(const TypeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types 
---------------------------------------------------- + + typedef TypeProto_Tensor Tensor; + typedef TypeProto_Sequence Sequence; + typedef TypeProto_Map Map; + + // accessors ------------------------------------------------------- + + enum : int { + kDenotationFieldNumber = 6, + kTensorTypeFieldNumber = 1, + kSequenceTypeFieldNumber = 4, + kMapTypeFieldNumber = 5, + }; + // optional string denotation = 6; + bool has_denotation() const; + private: + bool _internal_has_denotation() const; + public: + void clear_denotation(); + const std::string& denotation() const; + void set_denotation(const std::string& value); + void set_denotation(std::string&& value); + void set_denotation(const char* value); + void set_denotation(const char* value, size_t size); + std::string* mutable_denotation(); + std::string* release_denotation(); + void set_allocated_denotation(std::string* denotation); + private: + const std::string& _internal_denotation() const; + void _internal_set_denotation(const std::string& value); + std::string* _internal_mutable_denotation(); + public: + + // optional .onnx.TypeProto.Tensor tensor_type = 1; + bool has_tensor_type() const; + private: + bool _internal_has_tensor_type() const; + public: + void clear_tensor_type(); + const ::onnx::TypeProto_Tensor& tensor_type() const; + ::onnx::TypeProto_Tensor* release_tensor_type(); + ::onnx::TypeProto_Tensor* mutable_tensor_type(); + void set_allocated_tensor_type(::onnx::TypeProto_Tensor* tensor_type); + private: + const ::onnx::TypeProto_Tensor& _internal_tensor_type() const; + ::onnx::TypeProto_Tensor* _internal_mutable_tensor_type(); + public: + + // optional .onnx.TypeProto.Sequence sequence_type = 4; + bool has_sequence_type() const; + private: + bool _internal_has_sequence_type() const; + public: + void clear_sequence_type(); + const ::onnx::TypeProto_Sequence& sequence_type() const; + ::onnx::TypeProto_Sequence* release_sequence_type(); + ::onnx::TypeProto_Sequence* mutable_sequence_type(); + void 
set_allocated_sequence_type(::onnx::TypeProto_Sequence* sequence_type); + private: + const ::onnx::TypeProto_Sequence& _internal_sequence_type() const; + ::onnx::TypeProto_Sequence* _internal_mutable_sequence_type(); + public: + + // optional .onnx.TypeProto.Map map_type = 5; + bool has_map_type() const; + private: + bool _internal_has_map_type() const; + public: + void clear_map_type(); + const ::onnx::TypeProto_Map& map_type() const; + ::onnx::TypeProto_Map* release_map_type(); + ::onnx::TypeProto_Map* mutable_map_type(); + void set_allocated_map_type(::onnx::TypeProto_Map* map_type); + private: + const ::onnx::TypeProto_Map& _internal_map_type() const; + ::onnx::TypeProto_Map* _internal_mutable_map_type(); + public: + + void clear_value(); + ValueCase value_case() const; + // @@protoc_insertion_point(class_scope:onnx.TypeProto) + private: + class _Internal; + void set_has_tensor_type(); + void set_has_sequence_type(); + void set_has_map_type(); + + inline bool has_value() const; + inline void clear_has_value(); + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr denotation_; + union ValueUnion { + ValueUnion() {} + ::onnx::TypeProto_Tensor* tensor_type_; + ::onnx::TypeProto_Sequence* sequence_type_; + ::onnx::TypeProto_Map* map_type_; + } value_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_onnx_2eproto; +}; +// ------------------------------------------------------------------- + +class OperatorSetIdProto : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.OperatorSetIdProto) */ { + public: + OperatorSetIdProto(); + virtual ~OperatorSetIdProto(); + + OperatorSetIdProto(const OperatorSetIdProto& from); + OperatorSetIdProto(OperatorSetIdProto&& from) noexcept + : 
OperatorSetIdProto() { + *this = ::std::move(from); + } + + inline OperatorSetIdProto& operator=(const OperatorSetIdProto& from) { + CopyFrom(from); + return *this; + } + inline OperatorSetIdProto& operator=(OperatorSetIdProto&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields(); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const OperatorSetIdProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const OperatorSetIdProto* internal_default_instance() { + return reinterpret_cast( + &_OperatorSetIdProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 16; + + friend void swap(OperatorSetIdProto& a, OperatorSetIdProto& b) { + a.Swap(&b); + } + inline void Swap(OperatorSetIdProto* other) { + if (other == this) return; + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline OperatorSetIdProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + OperatorSetIdProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const OperatorSetIdProto& from); + void 
MergeFrom(const OperatorSetIdProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(OperatorSetIdProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.OperatorSetIdProto"; + } + private: + inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2eproto); + return ::descriptor_table_onnx_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDomainFieldNumber = 1, + kVersionFieldNumber = 2, + }; + // optional string domain = 1; + bool has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void 
set_allocated_domain(std::string* domain); + private: + const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // optional int64 version = 2; + bool has_version() const; + private: + bool _internal_has_version() const; + public: + void clear_version(); + ::PROTOBUF_NAMESPACE_ID::int64 version() const; + void set_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_version() const; + void _internal_set_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.OperatorSetIdProto) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + ::PROTOBUF_NAMESPACE_ID::int64 version_; + friend struct ::TableStruct_onnx_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// AttributeProto + +// optional string name = 1; +inline bool AttributeProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool AttributeProto::has_name() const { + return _internal_has_name(); +} +inline void AttributeProto::clear_name() { + name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& AttributeProto::name() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.name) + return _internal_name(); +} +inline void AttributeProto::set_name(const std::string& 
value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.name) +} +inline std::string* AttributeProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.name) + return _internal_mutable_name(); +} +inline const std::string& AttributeProto::_internal_name() const { + return name_.GetNoArena(); +} +inline void AttributeProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void AttributeProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.name) +} +inline void AttributeProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.name) +} +inline void AttributeProto::set_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.name) +} +inline std::string* AttributeProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* AttributeProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} 
+inline void AttributeProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.name) +} + +// optional string ref_attr_name = 21; +inline bool AttributeProto::_internal_has_ref_attr_name() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool AttributeProto::has_ref_attr_name() const { + return _internal_has_ref_attr_name(); +} +inline void AttributeProto::clear_ref_attr_name() { + ref_attr_name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& AttributeProto::ref_attr_name() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.ref_attr_name) + return _internal_ref_attr_name(); +} +inline void AttributeProto::set_ref_attr_name(const std::string& value) { + _internal_set_ref_attr_name(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.ref_attr_name) +} +inline std::string* AttributeProto::mutable_ref_attr_name() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.ref_attr_name) + return _internal_mutable_ref_attr_name(); +} +inline const std::string& AttributeProto::_internal_ref_attr_name() const { + return ref_attr_name_.GetNoArena(); +} +inline void AttributeProto::_internal_set_ref_attr_name(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void AttributeProto::set_ref_attr_name(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // 
@@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.ref_attr_name) +} +inline void AttributeProto::set_ref_attr_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.ref_attr_name) +} +inline void AttributeProto::set_ref_attr_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.ref_attr_name) +} +inline std::string* AttributeProto::_internal_mutable_ref_attr_name() { + _has_bits_[0] |= 0x00000008u; + return ref_attr_name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* AttributeProto::release_ref_attr_name() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.ref_attr_name) + if (!_internal_has_ref_attr_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return ref_attr_name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void AttributeProto::set_allocated_ref_attr_name(std::string* ref_attr_name) { + if (ref_attr_name != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + ref_attr_name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ref_attr_name); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.ref_attr_name) +} + +// optional string doc_string = 13; +inline bool AttributeProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool AttributeProto::has_doc_string() const { + return 
_internal_has_doc_string(); +} +inline void AttributeProto::clear_doc_string() { + doc_string_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& AttributeProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.doc_string) + return _internal_doc_string(); +} +inline void AttributeProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.doc_string) +} +inline std::string* AttributeProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& AttributeProto::_internal_doc_string() const { + return doc_string_.GetNoArena(); +} +inline void AttributeProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void AttributeProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.doc_string) +} +inline void AttributeProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.doc_string) +} +inline void AttributeProto::set_doc_string(const char* value, size_t size) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // 
@@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.doc_string) +} +inline std::string* AttributeProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000004u; + return doc_string_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* AttributeProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return doc_string_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void AttributeProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.doc_string) +} + +// optional .onnx.AttributeProto.AttributeType type = 20; +inline bool AttributeProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000200u) != 0; + return value; +} +inline bool AttributeProto::has_type() const { + return _internal_has_type(); +} +inline void AttributeProto::clear_type() { + type_ = 0; + _has_bits_[0] &= ~0x00000200u; +} +inline ::onnx::AttributeProto_AttributeType AttributeProto::_internal_type() const { + return static_cast< ::onnx::AttributeProto_AttributeType >(type_); +} +inline ::onnx::AttributeProto_AttributeType AttributeProto::type() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.type) + return _internal_type(); +} +inline void AttributeProto::_internal_set_type(::onnx::AttributeProto_AttributeType value) { + assert(::onnx::AttributeProto_AttributeType_IsValid(value)); + _has_bits_[0] |= 0x00000200u; + type_ = value; +} +inline void 
AttributeProto::set_type(::onnx::AttributeProto_AttributeType value) { + _internal_set_type(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.type) +} + +// optional float f = 2; +inline bool AttributeProto::_internal_has_f() const { + bool value = (_has_bits_[0] & 0x00000100u) != 0; + return value; +} +inline bool AttributeProto::has_f() const { + return _internal_has_f(); +} +inline void AttributeProto::clear_f() { + f_ = 0; + _has_bits_[0] &= ~0x00000100u; +} +inline float AttributeProto::_internal_f() const { + return f_; +} +inline float AttributeProto::f() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.f) + return _internal_f(); +} +inline void AttributeProto::_internal_set_f(float value) { + _has_bits_[0] |= 0x00000100u; + f_ = value; +} +inline void AttributeProto::set_f(float value) { + _internal_set_f(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.f) +} + +// optional int64 i = 3; +inline bool AttributeProto::_internal_has_i() const { + bool value = (_has_bits_[0] & 0x00000080u) != 0; + return value; +} +inline bool AttributeProto::has_i() const { + return _internal_has_i(); +} +inline void AttributeProto::clear_i() { + i_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000080u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::_internal_i() const { + return i_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::i() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.i) + return _internal_i(); +} +inline void AttributeProto::_internal_set_i(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000080u; + i_ = value; +} +inline void AttributeProto::set_i(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_i(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.i) +} + +// optional bytes s = 4; +inline bool AttributeProto::_internal_has_s() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool 
AttributeProto::has_s() const { + return _internal_has_s(); +} +inline void AttributeProto::clear_s() { + s_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& AttributeProto::s() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.s) + return _internal_s(); +} +inline void AttributeProto::set_s(const std::string& value) { + _internal_set_s(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.s) +} +inline std::string* AttributeProto::mutable_s() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.s) + return _internal_mutable_s(); +} +inline const std::string& AttributeProto::_internal_s() const { + return s_.GetNoArena(); +} +inline void AttributeProto::_internal_set_s(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + s_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void AttributeProto::set_s(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + s_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.s) +} +inline void AttributeProto::set_s(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + s_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.s) +} +inline void AttributeProto::set_s(const void* value, size_t size) { + _has_bits_[0] |= 0x00000002u; + s_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.s) +} +inline std::string* AttributeProto::_internal_mutable_s() { + _has_bits_[0] |= 0x00000002u; + return 
s_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* AttributeProto::release_s() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.s) + if (!_internal_has_s()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return s_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void AttributeProto::set_allocated_s(std::string* s) { + if (s != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + s_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), s); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.s) +} + +// optional .onnx.TensorProto t = 5; +inline bool AttributeProto::_internal_has_t() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + PROTOBUF_ASSUME(!value || t_ != nullptr); + return value; +} +inline bool AttributeProto::has_t() const { + return _internal_has_t(); +} +inline void AttributeProto::clear_t() { + if (t_ != nullptr) t_->Clear(); + _has_bits_[0] &= ~0x00000010u; +} +inline const ::onnx::TensorProto& AttributeProto::_internal_t() const { + const ::onnx::TensorProto* p = t_; + return p != nullptr ? 
*p : *reinterpret_cast( + &::onnx::_TensorProto_default_instance_); +} +inline const ::onnx::TensorProto& AttributeProto::t() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.t) + return _internal_t(); +} +inline ::onnx::TensorProto* AttributeProto::release_t() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.t) + _has_bits_[0] &= ~0x00000010u; + ::onnx::TensorProto* temp = t_; + t_ = nullptr; + return temp; +} +inline ::onnx::TensorProto* AttributeProto::_internal_mutable_t() { + _has_bits_[0] |= 0x00000010u; + if (t_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorProto>(GetArenaNoVirtual()); + t_ = p; + } + return t_; +} +inline ::onnx::TensorProto* AttributeProto::mutable_t() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.t) + return _internal_mutable_t(); +} +inline void AttributeProto::set_allocated_t(::onnx::TensorProto* t) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete t_; + } + if (t) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + t = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, t, submessage_arena); + } + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + t_ = t; + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.t) +} + +// optional .onnx.GraphProto g = 6; +inline bool AttributeProto::_internal_has_g() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + PROTOBUF_ASSUME(!value || g_ != nullptr); + return value; +} +inline bool AttributeProto::has_g() const { + return _internal_has_g(); +} +inline void AttributeProto::clear_g() { + if (g_ != nullptr) g_->Clear(); + _has_bits_[0] &= ~0x00000020u; +} +inline const ::onnx::GraphProto& AttributeProto::_internal_g() const { + const ::onnx::GraphProto* p = g_; + return p != nullptr ? 
*p : *reinterpret_cast( + &::onnx::_GraphProto_default_instance_); +} +inline const ::onnx::GraphProto& AttributeProto::g() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.g) + return _internal_g(); +} +inline ::onnx::GraphProto* AttributeProto::release_g() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.g) + _has_bits_[0] &= ~0x00000020u; + ::onnx::GraphProto* temp = g_; + g_ = nullptr; + return temp; +} +inline ::onnx::GraphProto* AttributeProto::_internal_mutable_g() { + _has_bits_[0] |= 0x00000020u; + if (g_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::GraphProto>(GetArenaNoVirtual()); + g_ = p; + } + return g_; +} +inline ::onnx::GraphProto* AttributeProto::mutable_g() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.g) + return _internal_mutable_g(); +} +inline void AttributeProto::set_allocated_g(::onnx::GraphProto* g) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete g_; + } + if (g) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + g = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, g, submessage_arena); + } + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + g_ = g; + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.g) +} + +// optional .onnx.SparseTensorProto sparse_tensor = 22; +inline bool AttributeProto::_internal_has_sparse_tensor() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + PROTOBUF_ASSUME(!value || sparse_tensor_ != nullptr); + return value; +} +inline bool AttributeProto::has_sparse_tensor() const { + return _internal_has_sparse_tensor(); +} +inline void AttributeProto::clear_sparse_tensor() { + if (sparse_tensor_ != nullptr) sparse_tensor_->Clear(); + _has_bits_[0] &= ~0x00000040u; +} +inline const ::onnx::SparseTensorProto& AttributeProto::_internal_sparse_tensor() const { + 
const ::onnx::SparseTensorProto* p = sparse_tensor_; + return p != nullptr ? *p : *reinterpret_cast( + &::onnx::_SparseTensorProto_default_instance_); +} +inline const ::onnx::SparseTensorProto& AttributeProto::sparse_tensor() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.sparse_tensor) + return _internal_sparse_tensor(); +} +inline ::onnx::SparseTensorProto* AttributeProto::release_sparse_tensor() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.sparse_tensor) + _has_bits_[0] &= ~0x00000040u; + ::onnx::SparseTensorProto* temp = sparse_tensor_; + sparse_tensor_ = nullptr; + return temp; +} +inline ::onnx::SparseTensorProto* AttributeProto::_internal_mutable_sparse_tensor() { + _has_bits_[0] |= 0x00000040u; + if (sparse_tensor_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::SparseTensorProto>(GetArenaNoVirtual()); + sparse_tensor_ = p; + } + return sparse_tensor_; +} +inline ::onnx::SparseTensorProto* AttributeProto::mutable_sparse_tensor() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.sparse_tensor) + return _internal_mutable_sparse_tensor(); +} +inline void AttributeProto::set_allocated_sparse_tensor(::onnx::SparseTensorProto* sparse_tensor) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete sparse_tensor_; + } + if (sparse_tensor) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + sparse_tensor = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, sparse_tensor, submessage_arena); + } + _has_bits_[0] |= 0x00000040u; + } else { + _has_bits_[0] &= ~0x00000040u; + } + sparse_tensor_ = sparse_tensor; + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.sparse_tensor) +} + +// repeated float floats = 7; +inline int AttributeProto::_internal_floats_size() const { + return floats_.size(); +} +inline int AttributeProto::floats_size() const { + return 
_internal_floats_size(); +} +inline void AttributeProto::clear_floats() { + floats_.Clear(); +} +inline float AttributeProto::_internal_floats(int index) const { + return floats_.Get(index); +} +inline float AttributeProto::floats(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.floats) + return _internal_floats(index); +} +inline void AttributeProto::set_floats(int index, float value) { + floats_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.floats) +} +inline void AttributeProto::_internal_add_floats(float value) { + floats_.Add(value); +} +inline void AttributeProto::add_floats(float value) { + _internal_add_floats(value); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.floats) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +AttributeProto::_internal_floats() const { + return floats_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +AttributeProto::floats() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.floats) + return _internal_floats(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +AttributeProto::_internal_mutable_floats() { + return &floats_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +AttributeProto::mutable_floats() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.floats) + return _internal_mutable_floats(); +} + +// repeated int64 ints = 8; +inline int AttributeProto::_internal_ints_size() const { + return ints_.size(); +} +inline int AttributeProto::ints_size() const { + return _internal_ints_size(); +} +inline void AttributeProto::clear_ints() { + ints_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::_internal_ints(int index) const { + return ints_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::ints(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.ints) + return _internal_ints(index); +} +inline 
void AttributeProto::set_ints(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + ints_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.ints) +} +inline void AttributeProto::_internal_add_ints(::PROTOBUF_NAMESPACE_ID::int64 value) { + ints_.Add(value); +} +inline void AttributeProto::add_ints(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_ints(value); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.ints) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +AttributeProto::_internal_ints() const { + return ints_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +AttributeProto::ints() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.ints) + return _internal_ints(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +AttributeProto::_internal_mutable_ints() { + return &ints_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +AttributeProto::mutable_ints() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.ints) + return _internal_mutable_ints(); +} + +// repeated bytes strings = 9; +inline int AttributeProto::_internal_strings_size() const { + return strings_.size(); +} +inline int AttributeProto::strings_size() const { + return _internal_strings_size(); +} +inline void AttributeProto::clear_strings() { + strings_.Clear(); +} +inline std::string* AttributeProto::add_strings() { + // @@protoc_insertion_point(field_add_mutable:onnx.AttributeProto.strings) + return _internal_add_strings(); +} +inline const std::string& AttributeProto::_internal_strings(int index) const { + return strings_.Get(index); +} +inline const std::string& AttributeProto::strings(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.strings) + return _internal_strings(index); +} +inline std::string* 
AttributeProto::mutable_strings(int index) { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.strings) + return strings_.Mutable(index); +} +inline void AttributeProto::set_strings(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.AttributeProto.strings) + strings_.Mutable(index)->assign(value); +} +inline void AttributeProto::set_strings(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.AttributeProto.strings) + strings_.Mutable(index)->assign(std::move(value)); +} +inline void AttributeProto::set_strings(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + strings_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.strings) +} +inline void AttributeProto::set_strings(int index, const void* value, size_t size) { + strings_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.strings) +} +inline std::string* AttributeProto::_internal_add_strings() { + return strings_.Add(); +} +inline void AttributeProto::add_strings(const std::string& value) { + strings_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.strings) +} +inline void AttributeProto::add_strings(std::string&& value) { + strings_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.strings) +} +inline void AttributeProto::add_strings(const char* value) { + GOOGLE_DCHECK(value != nullptr); + strings_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.AttributeProto.strings) +} +inline void AttributeProto::add_strings(const void* value, size_t size) { + strings_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.AttributeProto.strings) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +AttributeProto::strings() const { + // 
@@protoc_insertion_point(field_list:onnx.AttributeProto.strings) + return strings_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +AttributeProto::mutable_strings() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.strings) + return &strings_; +} + +// repeated .onnx.TensorProto tensors = 10; +inline int AttributeProto::_internal_tensors_size() const { + return tensors_.size(); +} +inline int AttributeProto::tensors_size() const { + return _internal_tensors_size(); +} +inline void AttributeProto::clear_tensors() { + tensors_.Clear(); +} +inline ::onnx::TensorProto* AttributeProto::mutable_tensors(int index) { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.tensors) + return tensors_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* +AttributeProto::mutable_tensors() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.tensors) + return &tensors_; +} +inline const ::onnx::TensorProto& AttributeProto::_internal_tensors(int index) const { + return tensors_.Get(index); +} +inline const ::onnx::TensorProto& AttributeProto::tensors(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.tensors) + return _internal_tensors(index); +} +inline ::onnx::TensorProto* AttributeProto::_internal_add_tensors() { + return tensors_.Add(); +} +inline ::onnx::TensorProto* AttributeProto::add_tensors() { + // @@protoc_insertion_point(field_add:onnx.AttributeProto.tensors) + return _internal_add_tensors(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& +AttributeProto::tensors() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.tensors) + return tensors_; +} + +// repeated .onnx.GraphProto graphs = 11; +inline int AttributeProto::_internal_graphs_size() const { + return graphs_.size(); +} +inline int AttributeProto::graphs_size() const { + return _internal_graphs_size(); +} +inline void 
AttributeProto::clear_graphs() { + graphs_.Clear(); +} +inline ::onnx::GraphProto* AttributeProto::mutable_graphs(int index) { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.graphs) + return graphs_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >* +AttributeProto::mutable_graphs() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.graphs) + return &graphs_; +} +inline const ::onnx::GraphProto& AttributeProto::_internal_graphs(int index) const { + return graphs_.Get(index); +} +inline const ::onnx::GraphProto& AttributeProto::graphs(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.graphs) + return _internal_graphs(index); +} +inline ::onnx::GraphProto* AttributeProto::_internal_add_graphs() { + return graphs_.Add(); +} +inline ::onnx::GraphProto* AttributeProto::add_graphs() { + // @@protoc_insertion_point(field_add:onnx.AttributeProto.graphs) + return _internal_add_graphs(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >& +AttributeProto::graphs() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.graphs) + return graphs_; +} + +// repeated .onnx.SparseTensorProto sparse_tensors = 23; +inline int AttributeProto::_internal_sparse_tensors_size() const { + return sparse_tensors_.size(); +} +inline int AttributeProto::sparse_tensors_size() const { + return _internal_sparse_tensors_size(); +} +inline void AttributeProto::clear_sparse_tensors() { + sparse_tensors_.Clear(); +} +inline ::onnx::SparseTensorProto* AttributeProto::mutable_sparse_tensors(int index) { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.sparse_tensors) + return sparse_tensors_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >* +AttributeProto::mutable_sparse_tensors() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.sparse_tensors) + return 
&sparse_tensors_; +} +inline const ::onnx::SparseTensorProto& AttributeProto::_internal_sparse_tensors(int index) const { + return sparse_tensors_.Get(index); +} +inline const ::onnx::SparseTensorProto& AttributeProto::sparse_tensors(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.sparse_tensors) + return _internal_sparse_tensors(index); +} +inline ::onnx::SparseTensorProto* AttributeProto::_internal_add_sparse_tensors() { + return sparse_tensors_.Add(); +} +inline ::onnx::SparseTensorProto* AttributeProto::add_sparse_tensors() { + // @@protoc_insertion_point(field_add:onnx.AttributeProto.sparse_tensors) + return _internal_add_sparse_tensors(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >& +AttributeProto::sparse_tensors() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.sparse_tensors) + return sparse_tensors_; +} + +// ------------------------------------------------------------------- + +// ValueInfoProto + +// optional string name = 1; +inline bool ValueInfoProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ValueInfoProto::has_name() const { + return _internal_has_name(); +} +inline void ValueInfoProto::clear_name() { + name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ValueInfoProto::name() const { + // @@protoc_insertion_point(field_get:onnx.ValueInfoProto.name) + return _internal_name(); +} +inline void ValueInfoProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.ValueInfoProto.name) +} +inline std::string* ValueInfoProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.ValueInfoProto.name) + return _internal_mutable_name(); +} +inline const std::string& ValueInfoProto::_internal_name() const { + return 
name_.GetNoArena(); +} +inline void ValueInfoProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void ValueInfoProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.ValueInfoProto.name) +} +inline void ValueInfoProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.ValueInfoProto.name) +} +inline void ValueInfoProto::set_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.ValueInfoProto.name) +} +inline std::string* ValueInfoProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* ValueInfoProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.ValueInfoProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void ValueInfoProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name); + // 
@@protoc_insertion_point(field_set_allocated:onnx.ValueInfoProto.name) +} + +// optional .onnx.TypeProto type = 2; +inline bool ValueInfoProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || type_ != nullptr); + return value; +} +inline bool ValueInfoProto::has_type() const { + return _internal_has_type(); +} +inline void ValueInfoProto::clear_type() { + if (type_ != nullptr) type_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +inline const ::onnx::TypeProto& ValueInfoProto::_internal_type() const { + const ::onnx::TypeProto* p = type_; + return p != nullptr ? *p : *reinterpret_cast( + &::onnx::_TypeProto_default_instance_); +} +inline const ::onnx::TypeProto& ValueInfoProto::type() const { + // @@protoc_insertion_point(field_get:onnx.ValueInfoProto.type) + return _internal_type(); +} +inline ::onnx::TypeProto* ValueInfoProto::release_type() { + // @@protoc_insertion_point(field_release:onnx.ValueInfoProto.type) + _has_bits_[0] &= ~0x00000004u; + ::onnx::TypeProto* temp = type_; + type_ = nullptr; + return temp; +} +inline ::onnx::TypeProto* ValueInfoProto::_internal_mutable_type() { + _has_bits_[0] |= 0x00000004u; + if (type_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TypeProto>(GetArenaNoVirtual()); + type_ = p; + } + return type_; +} +inline ::onnx::TypeProto* ValueInfoProto::mutable_type() { + // @@protoc_insertion_point(field_mutable:onnx.ValueInfoProto.type) + return _internal_mutable_type(); +} +inline void ValueInfoProto::set_allocated_type(::onnx::TypeProto* type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete type_; + } + if (type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, type, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + 
} + type_ = type; + // @@protoc_insertion_point(field_set_allocated:onnx.ValueInfoProto.type) +} + +// optional string doc_string = 3; +inline bool ValueInfoProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool ValueInfoProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void ValueInfoProto::clear_doc_string() { + doc_string_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& ValueInfoProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.ValueInfoProto.doc_string) + return _internal_doc_string(); +} +inline void ValueInfoProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.ValueInfoProto.doc_string) +} +inline std::string* ValueInfoProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.ValueInfoProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& ValueInfoProto::_internal_doc_string() const { + return doc_string_.GetNoArena(); +} +inline void ValueInfoProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void ValueInfoProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.ValueInfoProto.doc_string) +} +inline void ValueInfoProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // 
@@protoc_insertion_point(field_set_char:onnx.ValueInfoProto.doc_string) +} +inline void ValueInfoProto::set_doc_string(const char* value, size_t size) { + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.ValueInfoProto.doc_string) +} +inline std::string* ValueInfoProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000002u; + return doc_string_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* ValueInfoProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.ValueInfoProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return doc_string_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void ValueInfoProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + doc_string_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string); + // @@protoc_insertion_point(field_set_allocated:onnx.ValueInfoProto.doc_string) +} + +// ------------------------------------------------------------------- + +// NodeProto + +// repeated string input = 1; +inline int NodeProto::_internal_input_size() const { + return input_.size(); +} +inline int NodeProto::input_size() const { + return _internal_input_size(); +} +inline void NodeProto::clear_input() { + input_.Clear(); +} +inline std::string* NodeProto::add_input() { + // @@protoc_insertion_point(field_add_mutable:onnx.NodeProto.input) + return _internal_add_input(); +} +inline const std::string& NodeProto::_internal_input(int index) const { + return input_.Get(index); +} +inline const std::string& 
NodeProto::input(int index) const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.input) + return _internal_input(index); +} +inline std::string* NodeProto::mutable_input(int index) { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.input) + return input_.Mutable(index); +} +inline void NodeProto::set_input(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.input) + input_.Mutable(index)->assign(value); +} +inline void NodeProto::set_input(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.input) + input_.Mutable(index)->assign(std::move(value)); +} +inline void NodeProto::set_input(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + input_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.input) +} +inline void NodeProto::set_input(int index, const char* value, size_t size) { + input_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.input) +} +inline std::string* NodeProto::_internal_add_input() { + return input_.Add(); +} +inline void NodeProto::add_input(const std::string& value) { + input_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.NodeProto.input) +} +inline void NodeProto::add_input(std::string&& value) { + input_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.NodeProto.input) +} +inline void NodeProto::add_input(const char* value) { + GOOGLE_DCHECK(value != nullptr); + input_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.NodeProto.input) +} +inline void NodeProto::add_input(const char* value, size_t size) { + input_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.NodeProto.input) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +NodeProto::input() const { + // 
@@protoc_insertion_point(field_list:onnx.NodeProto.input) + return input_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +NodeProto::mutable_input() { + // @@protoc_insertion_point(field_mutable_list:onnx.NodeProto.input) + return &input_; +} + +// repeated string output = 2; +inline int NodeProto::_internal_output_size() const { + return output_.size(); +} +inline int NodeProto::output_size() const { + return _internal_output_size(); +} +inline void NodeProto::clear_output() { + output_.Clear(); +} +inline std::string* NodeProto::add_output() { + // @@protoc_insertion_point(field_add_mutable:onnx.NodeProto.output) + return _internal_add_output(); +} +inline const std::string& NodeProto::_internal_output(int index) const { + return output_.Get(index); +} +inline const std::string& NodeProto::output(int index) const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.output) + return _internal_output(index); +} +inline std::string* NodeProto::mutable_output(int index) { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.output) + return output_.Mutable(index); +} +inline void NodeProto::set_output(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.output) + output_.Mutable(index)->assign(value); +} +inline void NodeProto::set_output(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.output) + output_.Mutable(index)->assign(std::move(value)); +} +inline void NodeProto::set_output(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + output_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.output) +} +inline void NodeProto::set_output(int index, const char* value, size_t size) { + output_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.output) +} +inline std::string* NodeProto::_internal_add_output() { + return output_.Add(); +} +inline 
void NodeProto::add_output(const std::string& value) { + output_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.NodeProto.output) +} +inline void NodeProto::add_output(std::string&& value) { + output_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.NodeProto.output) +} +inline void NodeProto::add_output(const char* value) { + GOOGLE_DCHECK(value != nullptr); + output_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.NodeProto.output) +} +inline void NodeProto::add_output(const char* value, size_t size) { + output_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.NodeProto.output) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +NodeProto::output() const { + // @@protoc_insertion_point(field_list:onnx.NodeProto.output) + return output_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +NodeProto::mutable_output() { + // @@protoc_insertion_point(field_mutable_list:onnx.NodeProto.output) + return &output_; +} + +// optional string name = 3; +inline bool NodeProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool NodeProto::has_name() const { + return _internal_has_name(); +} +inline void NodeProto::clear_name() { + name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& NodeProto::name() const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.name) + return _internal_name(); +} +inline void NodeProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.NodeProto.name) +} +inline std::string* NodeProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.name) + return _internal_mutable_name(); +} +inline const std::string& NodeProto::_internal_name() const { + return name_.GetNoArena(); +} 
+inline void NodeProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void NodeProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.name) +} +inline void NodeProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.name) +} +inline void NodeProto::set_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.name) +} +inline std::string* NodeProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* NodeProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void NodeProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.name) +} + +// optional string op_type = 4; +inline bool 
NodeProto::_internal_has_op_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool NodeProto::has_op_type() const { + return _internal_has_op_type(); +} +inline void NodeProto::clear_op_type() { + op_type_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& NodeProto::op_type() const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.op_type) + return _internal_op_type(); +} +inline void NodeProto::set_op_type(const std::string& value) { + _internal_set_op_type(value); + // @@protoc_insertion_point(field_set:onnx.NodeProto.op_type) +} +inline std::string* NodeProto::mutable_op_type() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.op_type) + return _internal_mutable_op_type(); +} +inline const std::string& NodeProto::_internal_op_type() const { + return op_type_.GetNoArena(); +} +inline void NodeProto::_internal_set_op_type(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + op_type_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void NodeProto::set_op_type(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + op_type_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.op_type) +} +inline void NodeProto::set_op_type(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + op_type_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.op_type) +} +inline void NodeProto::set_op_type(const char* value, size_t size) { + _has_bits_[0] |= 0x00000002u; + op_type_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // 
@@protoc_insertion_point(field_set_pointer:onnx.NodeProto.op_type) +} +inline std::string* NodeProto::_internal_mutable_op_type() { + _has_bits_[0] |= 0x00000002u; + return op_type_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* NodeProto::release_op_type() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.op_type) + if (!_internal_has_op_type()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return op_type_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void NodeProto::set_allocated_op_type(std::string* op_type) { + if (op_type != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + op_type_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), op_type); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.op_type) +} + +// optional string domain = 7; +inline bool NodeProto::_internal_has_domain() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool NodeProto::has_domain() const { + return _internal_has_domain(); +} +inline void NodeProto::clear_domain() { + domain_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& NodeProto::domain() const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.domain) + return _internal_domain(); +} +inline void NodeProto::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:onnx.NodeProto.domain) +} +inline std::string* NodeProto::mutable_domain() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.domain) + return _internal_mutable_domain(); +} +inline const std::string& NodeProto::_internal_domain() const { + return domain_.GetNoArena(); +} +inline void NodeProto::_internal_set_domain(const std::string& 
value) { + _has_bits_[0] |= 0x00000008u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void NodeProto::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + domain_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.domain) +} +inline void NodeProto::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.domain) +} +inline void NodeProto::set_domain(const char* value, size_t size) { + _has_bits_[0] |= 0x00000008u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.domain) +} +inline std::string* NodeProto::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000008u; + return domain_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* NodeProto::release_domain() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return domain_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void NodeProto::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + domain_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.domain) +} + +// repeated .onnx.AttributeProto attribute = 5; +inline int 
NodeProto::_internal_attribute_size() const { + return attribute_.size(); +} +inline int NodeProto::attribute_size() const { + return _internal_attribute_size(); +} +inline void NodeProto::clear_attribute() { + attribute_.Clear(); +} +inline ::onnx::AttributeProto* NodeProto::mutable_attribute(int index) { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.attribute) + return attribute_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >* +NodeProto::mutable_attribute() { + // @@protoc_insertion_point(field_mutable_list:onnx.NodeProto.attribute) + return &attribute_; +} +inline const ::onnx::AttributeProto& NodeProto::_internal_attribute(int index) const { + return attribute_.Get(index); +} +inline const ::onnx::AttributeProto& NodeProto::attribute(int index) const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.attribute) + return _internal_attribute(index); +} +inline ::onnx::AttributeProto* NodeProto::_internal_add_attribute() { + return attribute_.Add(); +} +inline ::onnx::AttributeProto* NodeProto::add_attribute() { + // @@protoc_insertion_point(field_add:onnx.NodeProto.attribute) + return _internal_add_attribute(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >& +NodeProto::attribute() const { + // @@protoc_insertion_point(field_list:onnx.NodeProto.attribute) + return attribute_; +} + +// optional string doc_string = 6; +inline bool NodeProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool NodeProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void NodeProto::clear_doc_string() { + doc_string_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& NodeProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.doc_string) + return 
_internal_doc_string(); +} +inline void NodeProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.NodeProto.doc_string) +} +inline std::string* NodeProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& NodeProto::_internal_doc_string() const { + return doc_string_.GetNoArena(); +} +inline void NodeProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void NodeProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.doc_string) +} +inline void NodeProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.doc_string) +} +inline void NodeProto::set_doc_string(const char* value, size_t size) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.doc_string) +} +inline std::string* NodeProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000004u; + return doc_string_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* NodeProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.doc_string) + if (!_internal_has_doc_string()) { + return 
nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return doc_string_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void NodeProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.doc_string) +} + +// ------------------------------------------------------------------- + +// ModelProto + +// optional int64 ir_version = 1; +inline bool ModelProto::_internal_has_ir_version() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool ModelProto::has_ir_version() const { + return _internal_has_ir_version(); +} +inline void ModelProto::clear_ir_version() { + ir_version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000020u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::_internal_ir_version() const { + return ir_version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::ir_version() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.ir_version) + return _internal_ir_version(); +} +inline void ModelProto::_internal_set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000020u; + ir_version_ = value; +} +inline void ModelProto::set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_ir_version(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.ir_version) +} + +// repeated .onnx.OperatorSetIdProto opset_import = 8; +inline int ModelProto::_internal_opset_import_size() const { + return opset_import_.size(); +} +inline int ModelProto::opset_import_size() const { + return _internal_opset_import_size(); +} +inline void ModelProto::clear_opset_import() { + opset_import_.Clear(); +} +inline ::onnx::OperatorSetIdProto* 
ModelProto::mutable_opset_import(int index) { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.opset_import) + return opset_import_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >* +ModelProto::mutable_opset_import() { + // @@protoc_insertion_point(field_mutable_list:onnx.ModelProto.opset_import) + return &opset_import_; +} +inline const ::onnx::OperatorSetIdProto& ModelProto::_internal_opset_import(int index) const { + return opset_import_.Get(index); +} +inline const ::onnx::OperatorSetIdProto& ModelProto::opset_import(int index) const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.opset_import) + return _internal_opset_import(index); +} +inline ::onnx::OperatorSetIdProto* ModelProto::_internal_add_opset_import() { + return opset_import_.Add(); +} +inline ::onnx::OperatorSetIdProto* ModelProto::add_opset_import() { + // @@protoc_insertion_point(field_add:onnx.ModelProto.opset_import) + return _internal_add_opset_import(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >& +ModelProto::opset_import() const { + // @@protoc_insertion_point(field_list:onnx.ModelProto.opset_import) + return opset_import_; +} + +// optional string producer_name = 2; +inline bool ModelProto::_internal_has_producer_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ModelProto::has_producer_name() const { + return _internal_has_producer_name(); +} +inline void ModelProto::clear_producer_name() { + producer_name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ModelProto::producer_name() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.producer_name) + return _internal_producer_name(); +} +inline void ModelProto::set_producer_name(const std::string& value) { + _internal_set_producer_name(value); + // 
@@protoc_insertion_point(field_set:onnx.ModelProto.producer_name) +} +inline std::string* ModelProto::mutable_producer_name() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.producer_name) + return _internal_mutable_producer_name(); +} +inline const std::string& ModelProto::_internal_producer_name() const { + return producer_name_.GetNoArena(); +} +inline void ModelProto::_internal_set_producer_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + producer_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void ModelProto::set_producer_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + producer_name_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.producer_name) +} +inline void ModelProto::set_producer_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + producer_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.producer_name) +} +inline void ModelProto::set_producer_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + producer_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.producer_name) +} +inline std::string* ModelProto::_internal_mutable_producer_name() { + _has_bits_[0] |= 0x00000001u; + return producer_name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* ModelProto::release_producer_name() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.producer_name) + if (!_internal_has_producer_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return 
producer_name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void ModelProto::set_allocated_producer_name(std::string* producer_name) { + if (producer_name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + producer_name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), producer_name); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.producer_name) +} + +// optional string producer_version = 3; +inline bool ModelProto::_internal_has_producer_version() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool ModelProto::has_producer_version() const { + return _internal_has_producer_version(); +} +inline void ModelProto::clear_producer_version() { + producer_version_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& ModelProto::producer_version() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.producer_version) + return _internal_producer_version(); +} +inline void ModelProto::set_producer_version(const std::string& value) { + _internal_set_producer_version(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.producer_version) +} +inline std::string* ModelProto::mutable_producer_version() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.producer_version) + return _internal_mutable_producer_version(); +} +inline const std::string& ModelProto::_internal_producer_version() const { + return producer_version_.GetNoArena(); +} +inline void ModelProto::_internal_set_producer_version(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + producer_version_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void ModelProto::set_producer_version(std::string&& value) { + _has_bits_[0] |= 0x00000002u; 
+ producer_version_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.producer_version) +} +inline void ModelProto::set_producer_version(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + producer_version_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.producer_version) +} +inline void ModelProto::set_producer_version(const char* value, size_t size) { + _has_bits_[0] |= 0x00000002u; + producer_version_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.producer_version) +} +inline std::string* ModelProto::_internal_mutable_producer_version() { + _has_bits_[0] |= 0x00000002u; + return producer_version_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* ModelProto::release_producer_version() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.producer_version) + if (!_internal_has_producer_version()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return producer_version_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void ModelProto::set_allocated_producer_version(std::string* producer_version) { + if (producer_version != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + producer_version_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), producer_version); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.producer_version) +} + +// optional string domain = 4; +inline bool ModelProto::_internal_has_domain() const { + bool value = 
(_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool ModelProto::has_domain() const { + return _internal_has_domain(); +} +inline void ModelProto::clear_domain() { + domain_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& ModelProto::domain() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.domain) + return _internal_domain(); +} +inline void ModelProto::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.domain) +} +inline std::string* ModelProto::mutable_domain() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.domain) + return _internal_mutable_domain(); +} +inline const std::string& ModelProto::_internal_domain() const { + return domain_.GetNoArena(); +} +inline void ModelProto::_internal_set_domain(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void ModelProto::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + domain_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.domain) +} +inline void ModelProto::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.domain) +} +inline void ModelProto::set_domain(const char* value, size_t size) { + _has_bits_[0] |= 0x00000004u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.domain) +} +inline 
std::string* ModelProto::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000004u; + return domain_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* ModelProto::release_domain() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return domain_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void ModelProto::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + domain_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.domain) +} + +// optional int64 model_version = 5; +inline bool ModelProto::_internal_has_model_version() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool ModelProto::has_model_version() const { + return _internal_has_model_version(); +} +inline void ModelProto::clear_model_version() { + model_version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000040u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::_internal_model_version() const { + return model_version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::model_version() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.model_version) + return _internal_model_version(); +} +inline void ModelProto::_internal_set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000040u; + model_version_ = value; +} +inline void ModelProto::set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_model_version(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.model_version) +} + +// optional string doc_string = 6; +inline bool 
ModelProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool ModelProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void ModelProto::clear_doc_string() { + doc_string_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& ModelProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.doc_string) + return _internal_doc_string(); +} +inline void ModelProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.doc_string) +} +inline std::string* ModelProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& ModelProto::_internal_doc_string() const { + return doc_string_.GetNoArena(); +} +inline void ModelProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void ModelProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + doc_string_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.doc_string) +} +inline void ModelProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.doc_string) +} +inline void ModelProto::set_doc_string(const char* value, size_t size) { + _has_bits_[0] |= 0x00000008u; + 
doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.doc_string) +} +inline std::string* ModelProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000008u; + return doc_string_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* ModelProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return doc_string_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void ModelProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + doc_string_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.doc_string) +} + +// optional .onnx.GraphProto graph = 7; +inline bool ModelProto::_internal_has_graph() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + PROTOBUF_ASSUME(!value || graph_ != nullptr); + return value; +} +inline bool ModelProto::has_graph() const { + return _internal_has_graph(); +} +inline void ModelProto::clear_graph() { + if (graph_ != nullptr) graph_->Clear(); + _has_bits_[0] &= ~0x00000010u; +} +inline const ::onnx::GraphProto& ModelProto::_internal_graph() const { + const ::onnx::GraphProto* p = graph_; + return p != nullptr ? 
*p : *reinterpret_cast( + &::onnx::_GraphProto_default_instance_); +} +inline const ::onnx::GraphProto& ModelProto::graph() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.graph) + return _internal_graph(); +} +inline ::onnx::GraphProto* ModelProto::release_graph() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.graph) + _has_bits_[0] &= ~0x00000010u; + ::onnx::GraphProto* temp = graph_; + graph_ = nullptr; + return temp; +} +inline ::onnx::GraphProto* ModelProto::_internal_mutable_graph() { + _has_bits_[0] |= 0x00000010u; + if (graph_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::GraphProto>(GetArenaNoVirtual()); + graph_ = p; + } + return graph_; +} +inline ::onnx::GraphProto* ModelProto::mutable_graph() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.graph) + return _internal_mutable_graph(); +} +inline void ModelProto::set_allocated_graph(::onnx::GraphProto* graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete graph_; + } + if (graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + graph = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, graph, submessage_arena); + } + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + graph_ = graph; + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.graph) +} + +// repeated .onnx.StringStringEntryProto metadata_props = 14; +inline int ModelProto::_internal_metadata_props_size() const { + return metadata_props_.size(); +} +inline int ModelProto::metadata_props_size() const { + return _internal_metadata_props_size(); +} +inline void ModelProto::clear_metadata_props() { + metadata_props_.Clear(); +} +inline ::onnx::StringStringEntryProto* ModelProto::mutable_metadata_props(int index) { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.metadata_props) + return 
metadata_props_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* +ModelProto::mutable_metadata_props() { + // @@protoc_insertion_point(field_mutable_list:onnx.ModelProto.metadata_props) + return &metadata_props_; +} +inline const ::onnx::StringStringEntryProto& ModelProto::_internal_metadata_props(int index) const { + return metadata_props_.Get(index); +} +inline const ::onnx::StringStringEntryProto& ModelProto::metadata_props(int index) const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.metadata_props) + return _internal_metadata_props(index); +} +inline ::onnx::StringStringEntryProto* ModelProto::_internal_add_metadata_props() { + return metadata_props_.Add(); +} +inline ::onnx::StringStringEntryProto* ModelProto::add_metadata_props() { + // @@protoc_insertion_point(field_add:onnx.ModelProto.metadata_props) + return _internal_add_metadata_props(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& +ModelProto::metadata_props() const { + // @@protoc_insertion_point(field_list:onnx.ModelProto.metadata_props) + return metadata_props_; +} + +// ------------------------------------------------------------------- + +// StringStringEntryProto + +// optional string key = 1; +inline bool StringStringEntryProto::_internal_has_key() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool StringStringEntryProto::has_key() const { + return _internal_has_key(); +} +inline void StringStringEntryProto::clear_key() { + key_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& StringStringEntryProto::key() const { + // @@protoc_insertion_point(field_get:onnx.StringStringEntryProto.key) + return _internal_key(); +} +inline void StringStringEntryProto::set_key(const std::string& value) { + _internal_set_key(value); + // 
@@protoc_insertion_point(field_set:onnx.StringStringEntryProto.key) +} +inline std::string* StringStringEntryProto::mutable_key() { + // @@protoc_insertion_point(field_mutable:onnx.StringStringEntryProto.key) + return _internal_mutable_key(); +} +inline const std::string& StringStringEntryProto::_internal_key() const { + return key_.GetNoArena(); +} +inline void StringStringEntryProto::_internal_set_key(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + key_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void StringStringEntryProto::set_key(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + key_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.StringStringEntryProto.key) +} +inline void StringStringEntryProto::set_key(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + key_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.StringStringEntryProto.key) +} +inline void StringStringEntryProto::set_key(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + key_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.StringStringEntryProto.key) +} +inline std::string* StringStringEntryProto::_internal_mutable_key() { + _has_bits_[0] |= 0x00000001u; + return key_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* StringStringEntryProto::release_key() { + // @@protoc_insertion_point(field_release:onnx.StringStringEntryProto.key) + if (!_internal_has_key()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return 
key_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void StringStringEntryProto::set_allocated_key(std::string* key) { + if (key != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + key_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), key); + // @@protoc_insertion_point(field_set_allocated:onnx.StringStringEntryProto.key) +} + +// optional string value = 2; +inline bool StringStringEntryProto::_internal_has_value() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool StringStringEntryProto::has_value() const { + return _internal_has_value(); +} +inline void StringStringEntryProto::clear_value() { + value_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& StringStringEntryProto::value() const { + // @@protoc_insertion_point(field_get:onnx.StringStringEntryProto.value) + return _internal_value(); +} +inline void StringStringEntryProto::set_value(const std::string& value) { + _internal_set_value(value); + // @@protoc_insertion_point(field_set:onnx.StringStringEntryProto.value) +} +inline std::string* StringStringEntryProto::mutable_value() { + // @@protoc_insertion_point(field_mutable:onnx.StringStringEntryProto.value) + return _internal_mutable_value(); +} +inline const std::string& StringStringEntryProto::_internal_value() const { + return value_.GetNoArena(); +} +inline void StringStringEntryProto::_internal_set_value(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + value_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void StringStringEntryProto::set_value(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + value_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // 
@@protoc_insertion_point(field_set_rvalue:onnx.StringStringEntryProto.value) +} +inline void StringStringEntryProto::set_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + value_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.StringStringEntryProto.value) +} +inline void StringStringEntryProto::set_value(const char* value, size_t size) { + _has_bits_[0] |= 0x00000002u; + value_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.StringStringEntryProto.value) +} +inline std::string* StringStringEntryProto::_internal_mutable_value() { + _has_bits_[0] |= 0x00000002u; + return value_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* StringStringEntryProto::release_value() { + // @@protoc_insertion_point(field_release:onnx.StringStringEntryProto.value) + if (!_internal_has_value()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return value_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void StringStringEntryProto::set_allocated_value(std::string* value) { + if (value != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + value_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set_allocated:onnx.StringStringEntryProto.value) +} + +// ------------------------------------------------------------------- + +// TensorAnnotation + +// optional string tensor_name = 1; +inline bool TensorAnnotation::_internal_has_tensor_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorAnnotation::has_tensor_name() const { + 
return _internal_has_tensor_name(); +} +inline void TensorAnnotation::clear_tensor_name() { + tensor_name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TensorAnnotation::tensor_name() const { + // @@protoc_insertion_point(field_get:onnx.TensorAnnotation.tensor_name) + return _internal_tensor_name(); +} +inline void TensorAnnotation::set_tensor_name(const std::string& value) { + _internal_set_tensor_name(value); + // @@protoc_insertion_point(field_set:onnx.TensorAnnotation.tensor_name) +} +inline std::string* TensorAnnotation::mutable_tensor_name() { + // @@protoc_insertion_point(field_mutable:onnx.TensorAnnotation.tensor_name) + return _internal_mutable_tensor_name(); +} +inline const std::string& TensorAnnotation::_internal_tensor_name() const { + return tensor_name_.GetNoArena(); +} +inline void TensorAnnotation::_internal_set_tensor_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + tensor_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void TensorAnnotation::set_tensor_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + tensor_name_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorAnnotation.tensor_name) +} +inline void TensorAnnotation::set_tensor_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + tensor_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.TensorAnnotation.tensor_name) +} +inline void TensorAnnotation::set_tensor_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + tensor_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + 
::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorAnnotation.tensor_name) +} +inline std::string* TensorAnnotation::_internal_mutable_tensor_name() { + _has_bits_[0] |= 0x00000001u; + return tensor_name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* TensorAnnotation::release_tensor_name() { + // @@protoc_insertion_point(field_release:onnx.TensorAnnotation.tensor_name) + if (!_internal_has_tensor_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return tensor_name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void TensorAnnotation::set_allocated_tensor_name(std::string* tensor_name) { + if (tensor_name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + tensor_name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), tensor_name); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorAnnotation.tensor_name) +} + +// repeated .onnx.StringStringEntryProto quant_parameter_tensor_names = 2; +inline int TensorAnnotation::_internal_quant_parameter_tensor_names_size() const { + return quant_parameter_tensor_names_.size(); +} +inline int TensorAnnotation::quant_parameter_tensor_names_size() const { + return _internal_quant_parameter_tensor_names_size(); +} +inline void TensorAnnotation::clear_quant_parameter_tensor_names() { + quant_parameter_tensor_names_.Clear(); +} +inline ::onnx::StringStringEntryProto* TensorAnnotation::mutable_quant_parameter_tensor_names(int index) { + // @@protoc_insertion_point(field_mutable:onnx.TensorAnnotation.quant_parameter_tensor_names) + return quant_parameter_tensor_names_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* +TensorAnnotation::mutable_quant_parameter_tensor_names() { + // 
@@protoc_insertion_point(field_mutable_list:onnx.TensorAnnotation.quant_parameter_tensor_names) + return &quant_parameter_tensor_names_; +} +inline const ::onnx::StringStringEntryProto& TensorAnnotation::_internal_quant_parameter_tensor_names(int index) const { + return quant_parameter_tensor_names_.Get(index); +} +inline const ::onnx::StringStringEntryProto& TensorAnnotation::quant_parameter_tensor_names(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorAnnotation.quant_parameter_tensor_names) + return _internal_quant_parameter_tensor_names(index); +} +inline ::onnx::StringStringEntryProto* TensorAnnotation::_internal_add_quant_parameter_tensor_names() { + return quant_parameter_tensor_names_.Add(); +} +inline ::onnx::StringStringEntryProto* TensorAnnotation::add_quant_parameter_tensor_names() { + // @@protoc_insertion_point(field_add:onnx.TensorAnnotation.quant_parameter_tensor_names) + return _internal_add_quant_parameter_tensor_names(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& +TensorAnnotation::quant_parameter_tensor_names() const { + // @@protoc_insertion_point(field_list:onnx.TensorAnnotation.quant_parameter_tensor_names) + return quant_parameter_tensor_names_; +} + +// ------------------------------------------------------------------- + +// GraphProto + +// repeated .onnx.NodeProto node = 1; +inline int GraphProto::_internal_node_size() const { + return node_.size(); +} +inline int GraphProto::node_size() const { + return _internal_node_size(); +} +inline void GraphProto::clear_node() { + node_.Clear(); +} +inline ::onnx::NodeProto* GraphProto::mutable_node(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.node) + return node_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >* +GraphProto::mutable_node() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.node) + return &node_; +} +inline const 
::onnx::NodeProto& GraphProto::_internal_node(int index) const { + return node_.Get(index); +} +inline const ::onnx::NodeProto& GraphProto::node(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.node) + return _internal_node(index); +} +inline ::onnx::NodeProto* GraphProto::_internal_add_node() { + return node_.Add(); +} +inline ::onnx::NodeProto* GraphProto::add_node() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.node) + return _internal_add_node(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >& +GraphProto::node() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.node) + return node_; +} + +// optional string name = 2; +inline bool GraphProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool GraphProto::has_name() const { + return _internal_has_name(); +} +inline void GraphProto::clear_name() { + name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& GraphProto::name() const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.name) + return _internal_name(); +} +inline void GraphProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.GraphProto.name) +} +inline std::string* GraphProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.name) + return _internal_mutable_name(); +} +inline const std::string& GraphProto::_internal_name() const { + return name_.GetNoArena(); +} +inline void GraphProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void GraphProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena( + 
&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.GraphProto.name) +} +inline void GraphProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.GraphProto.name) +} +inline void GraphProto::set_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast<const char*>(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.GraphProto.name) +} +inline std::string* GraphProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* GraphProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.GraphProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void GraphProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name); + // @@protoc_insertion_point(field_set_allocated:onnx.GraphProto.name) +} + +// repeated .onnx.TensorProto initializer = 5; +inline int GraphProto::_internal_initializer_size() const { + return initializer_.size(); +} +inline int GraphProto::initializer_size() const { + return _internal_initializer_size(); +} +inline void GraphProto::clear_initializer() { + initializer_.Clear(); +} +inline ::onnx::TensorProto* GraphProto::mutable_initializer(int
index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.initializer) + return initializer_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* +GraphProto::mutable_initializer() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.initializer) + return &initializer_; +} +inline const ::onnx::TensorProto& GraphProto::_internal_initializer(int index) const { + return initializer_.Get(index); +} +inline const ::onnx::TensorProto& GraphProto::initializer(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.initializer) + return _internal_initializer(index); +} +inline ::onnx::TensorProto* GraphProto::_internal_add_initializer() { + return initializer_.Add(); +} +inline ::onnx::TensorProto* GraphProto::add_initializer() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.initializer) + return _internal_add_initializer(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& +GraphProto::initializer() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.initializer) + return initializer_; +} + +// repeated .onnx.SparseTensorProto sparse_initializer = 15; +inline int GraphProto::_internal_sparse_initializer_size() const { + return sparse_initializer_.size(); +} +inline int GraphProto::sparse_initializer_size() const { + return _internal_sparse_initializer_size(); +} +inline void GraphProto::clear_sparse_initializer() { + sparse_initializer_.Clear(); +} +inline ::onnx::SparseTensorProto* GraphProto::mutable_sparse_initializer(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.sparse_initializer) + return sparse_initializer_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >* +GraphProto::mutable_sparse_initializer() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.sparse_initializer) + return &sparse_initializer_; +} +inline const 
::onnx::SparseTensorProto& GraphProto::_internal_sparse_initializer(int index) const { + return sparse_initializer_.Get(index); +} +inline const ::onnx::SparseTensorProto& GraphProto::sparse_initializer(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.sparse_initializer) + return _internal_sparse_initializer(index); +} +inline ::onnx::SparseTensorProto* GraphProto::_internal_add_sparse_initializer() { + return sparse_initializer_.Add(); +} +inline ::onnx::SparseTensorProto* GraphProto::add_sparse_initializer() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.sparse_initializer) + return _internal_add_sparse_initializer(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::SparseTensorProto >& +GraphProto::sparse_initializer() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.sparse_initializer) + return sparse_initializer_; +} + +// optional string doc_string = 10; +inline bool GraphProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool GraphProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void GraphProto::clear_doc_string() { + doc_string_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& GraphProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.doc_string) + return _internal_doc_string(); +} +inline void GraphProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.GraphProto.doc_string) +} +inline std::string* GraphProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& GraphProto::_internal_doc_string() const { + return doc_string_.GetNoArena(); +} +inline void 
GraphProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void GraphProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.GraphProto.doc_string) +} +inline void GraphProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.GraphProto.doc_string) +} +inline void GraphProto::set_doc_string(const char* value, size_t size) { + _has_bits_[0] |= 0x00000002u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast<const char*>(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.GraphProto.doc_string) +} +inline std::string* GraphProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000002u; + return doc_string_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* GraphProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.GraphProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return doc_string_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void GraphProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + doc_string_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string); + //
@@protoc_insertion_point(field_set_allocated:onnx.GraphProto.doc_string) +} + +// repeated .onnx.ValueInfoProto input = 11; +inline int GraphProto::_internal_input_size() const { + return input_.size(); +} +inline int GraphProto::input_size() const { + return _internal_input_size(); +} +inline void GraphProto::clear_input() { + input_.Clear(); +} +inline ::onnx::ValueInfoProto* GraphProto::mutable_input(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.input) + return input_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* +GraphProto::mutable_input() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.input) + return &input_; +} +inline const ::onnx::ValueInfoProto& GraphProto::_internal_input(int index) const { + return input_.Get(index); +} +inline const ::onnx::ValueInfoProto& GraphProto::input(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.input) + return _internal_input(index); +} +inline ::onnx::ValueInfoProto* GraphProto::_internal_add_input() { + return input_.Add(); +} +inline ::onnx::ValueInfoProto* GraphProto::add_input() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.input) + return _internal_add_input(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& +GraphProto::input() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.input) + return input_; +} + +// repeated .onnx.ValueInfoProto output = 12; +inline int GraphProto::_internal_output_size() const { + return output_.size(); +} +inline int GraphProto::output_size() const { + return _internal_output_size(); +} +inline void GraphProto::clear_output() { + output_.Clear(); +} +inline ::onnx::ValueInfoProto* GraphProto::mutable_output(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.output) + return output_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* 
+GraphProto::mutable_output() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.output) + return &output_; +} +inline const ::onnx::ValueInfoProto& GraphProto::_internal_output(int index) const { + return output_.Get(index); +} +inline const ::onnx::ValueInfoProto& GraphProto::output(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.output) + return _internal_output(index); +} +inline ::onnx::ValueInfoProto* GraphProto::_internal_add_output() { + return output_.Add(); +} +inline ::onnx::ValueInfoProto* GraphProto::add_output() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.output) + return _internal_add_output(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& +GraphProto::output() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.output) + return output_; +} + +// repeated .onnx.ValueInfoProto value_info = 13; +inline int GraphProto::_internal_value_info_size() const { + return value_info_.size(); +} +inline int GraphProto::value_info_size() const { + return _internal_value_info_size(); +} +inline void GraphProto::clear_value_info() { + value_info_.Clear(); +} +inline ::onnx::ValueInfoProto* GraphProto::mutable_value_info(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.value_info) + return value_info_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* +GraphProto::mutable_value_info() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.value_info) + return &value_info_; +} +inline const ::onnx::ValueInfoProto& GraphProto::_internal_value_info(int index) const { + return value_info_.Get(index); +} +inline const ::onnx::ValueInfoProto& GraphProto::value_info(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.value_info) + return _internal_value_info(index); +} +inline ::onnx::ValueInfoProto* GraphProto::_internal_add_value_info() { + return 
value_info_.Add(); +} +inline ::onnx::ValueInfoProto* GraphProto::add_value_info() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.value_info) + return _internal_add_value_info(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& +GraphProto::value_info() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.value_info) + return value_info_; +} + +// repeated .onnx.TensorAnnotation quantization_annotation = 14; +inline int GraphProto::_internal_quantization_annotation_size() const { + return quantization_annotation_.size(); +} +inline int GraphProto::quantization_annotation_size() const { + return _internal_quantization_annotation_size(); +} +inline void GraphProto::clear_quantization_annotation() { + quantization_annotation_.Clear(); +} +inline ::onnx::TensorAnnotation* GraphProto::mutable_quantization_annotation(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.quantization_annotation) + return quantization_annotation_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorAnnotation >* +GraphProto::mutable_quantization_annotation() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.quantization_annotation) + return &quantization_annotation_; +} +inline const ::onnx::TensorAnnotation& GraphProto::_internal_quantization_annotation(int index) const { + return quantization_annotation_.Get(index); +} +inline const ::onnx::TensorAnnotation& GraphProto::quantization_annotation(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.quantization_annotation) + return _internal_quantization_annotation(index); +} +inline ::onnx::TensorAnnotation* GraphProto::_internal_add_quantization_annotation() { + return quantization_annotation_.Add(); +} +inline ::onnx::TensorAnnotation* GraphProto::add_quantization_annotation() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.quantization_annotation) + return 
_internal_add_quantization_annotation(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorAnnotation >& +GraphProto::quantization_annotation() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.quantization_annotation) + return quantization_annotation_; +} + +// ------------------------------------------------------------------- + +// TensorProto_Segment + +// optional int64 begin = 1; +inline bool TensorProto_Segment::_internal_has_begin() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorProto_Segment::has_begin() const { + return _internal_has_begin(); +} +inline void TensorProto_Segment::clear_begin() { + begin_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000001u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::_internal_begin() const { + return begin_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::begin() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.Segment.begin) + return _internal_begin(); +} +inline void TensorProto_Segment::_internal_set_begin(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000001u; + begin_ = value; +} +inline void TensorProto_Segment::set_begin(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_begin(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.Segment.begin) +} + +// optional int64 end = 2; +inline bool TensorProto_Segment::_internal_has_end() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TensorProto_Segment::has_end() const { + return _internal_has_end(); +} +inline void TensorProto_Segment::clear_end() { + end_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::_internal_end() const { + return end_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::end() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.Segment.end) 
+ return _internal_end(); +} +inline void TensorProto_Segment::_internal_set_end(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000002u; + end_ = value; +} +inline void TensorProto_Segment::set_end(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_end(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.Segment.end) +} + +// ------------------------------------------------------------------- + +// TensorProto + +// repeated int64 dims = 1; +inline int TensorProto::_internal_dims_size() const { + return dims_.size(); +} +inline int TensorProto::dims_size() const { + return _internal_dims_size(); +} +inline void TensorProto::clear_dims() { + dims_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::_internal_dims(int index) const { + return dims_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::dims(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.dims) + return _internal_dims(index); +} +inline void TensorProto::set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.dims) +} +inline void TensorProto::_internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Add(value); +} +inline void TensorProto::add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_dims(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.dims) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::_internal_dims() const { + return dims_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::dims() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.dims) + return _internal_dims(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::_internal_mutable_dims() { + return &dims_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< 
::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::mutable_dims() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.dims) + return _internal_mutable_dims(); +} + +// optional int32 data_type = 2; +inline bool TensorProto::_internal_has_data_type() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool TensorProto::has_data_type() const { + return _internal_has_data_type(); +} +inline void TensorProto::clear_data_type() { + data_type_ = 0; + _has_bits_[0] &= ~0x00000010u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::_internal_data_type() const { + return data_type_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::data_type() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.data_type) + return _internal_data_type(); +} +inline void TensorProto::_internal_set_data_type(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000010u; + data_type_ = value; +} +inline void TensorProto::set_data_type(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_data_type(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.data_type) +} + +// optional .onnx.TensorProto.Segment segment = 3; +inline bool TensorProto::_internal_has_segment() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + PROTOBUF_ASSUME(!value || segment_ != nullptr); + return value; +} +inline bool TensorProto::has_segment() const { + return _internal_has_segment(); +} +inline void TensorProto::clear_segment() { + if (segment_ != nullptr) segment_->Clear(); + _has_bits_[0] &= ~0x00000008u; +} +inline const ::onnx::TensorProto_Segment& TensorProto::_internal_segment() const { + const ::onnx::TensorProto_Segment* p = segment_; + return p != nullptr ? 
*p : *reinterpret_cast<const ::onnx::TensorProto_Segment*>( + &::onnx::_TensorProto_Segment_default_instance_); +} +inline const ::onnx::TensorProto_Segment& TensorProto::segment() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.segment) + return _internal_segment(); +} +inline ::onnx::TensorProto_Segment* TensorProto::release_segment() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.segment) + _has_bits_[0] &= ~0x00000008u; + ::onnx::TensorProto_Segment* temp = segment_; + segment_ = nullptr; + return temp; +} +inline ::onnx::TensorProto_Segment* TensorProto::_internal_mutable_segment() { + _has_bits_[0] |= 0x00000008u; + if (segment_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorProto_Segment>(GetArenaNoVirtual()); + segment_ = p; + } + return segment_; +} +inline ::onnx::TensorProto_Segment* TensorProto::mutable_segment() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.segment) + return _internal_mutable_segment(); +} +inline void TensorProto::set_allocated_segment(::onnx::TensorProto_Segment* segment) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete segment_; + } + if (segment) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + segment = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, segment, submessage_arena); + } + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + segment_ = segment; + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.segment) +} + +// repeated float float_data = 4 [packed = true]; +inline int TensorProto::_internal_float_data_size() const { + return float_data_.size(); +} +inline int TensorProto::float_data_size() const { + return _internal_float_data_size(); +} +inline void TensorProto::clear_float_data() { + float_data_.Clear(); +} +inline float TensorProto::_internal_float_data(int index) const { + return
float_data_.Get(index); +} +inline float TensorProto::float_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.float_data) + return _internal_float_data(index); +} +inline void TensorProto::set_float_data(int index, float value) { + float_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.float_data) +} +inline void TensorProto::_internal_add_float_data(float value) { + float_data_.Add(value); +} +inline void TensorProto::add_float_data(float value) { + _internal_add_float_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.float_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::_internal_float_data() const { + return float_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::float_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.float_data) + return _internal_float_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::_internal_mutable_float_data() { + return &float_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::mutable_float_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.float_data) + return _internal_mutable_float_data(); +} + +// repeated int32 int32_data = 5 [packed = true]; +inline int TensorProto::_internal_int32_data_size() const { + return int32_data_.size(); +} +inline int TensorProto::int32_data_size() const { + return _internal_int32_data_size(); +} +inline void TensorProto::clear_int32_data() { + int32_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::_internal_int32_data(int index) const { + return int32_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::int32_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.int32_data) + return _internal_int32_data(index); +} +inline void TensorProto::set_int32_data(int index, 
::PROTOBUF_NAMESPACE_ID::int32 value) { + int32_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.int32_data) +} +inline void TensorProto::_internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value) { + int32_data_.Add(value); +} +inline void TensorProto::add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_int32_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.int32_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::_internal_int32_data() const { + return int32_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::int32_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.int32_data) + return _internal_int32_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::_internal_mutable_int32_data() { + return &int32_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::mutable_int32_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.int32_data) + return _internal_mutable_int32_data(); +} + +// repeated bytes string_data = 6; +inline int TensorProto::_internal_string_data_size() const { + return string_data_.size(); +} +inline int TensorProto::string_data_size() const { + return _internal_string_data_size(); +} +inline void TensorProto::clear_string_data() { + string_data_.Clear(); +} +inline std::string* TensorProto::add_string_data() { + // @@protoc_insertion_point(field_add_mutable:onnx.TensorProto.string_data) + return _internal_add_string_data(); +} +inline const std::string& TensorProto::_internal_string_data(int index) const { + return string_data_.Get(index); +} +inline const std::string& TensorProto::string_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.string_data) + return 
_internal_string_data(index); +} +inline std::string* TensorProto::mutable_string_data(int index) { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.string_data) + return string_data_.Mutable(index); +} +inline void TensorProto::set_string_data(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.TensorProto.string_data) + string_data_.Mutable(index)->assign(value); +} +inline void TensorProto::set_string_data(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.TensorProto.string_data) + string_data_.Mutable(index)->assign(std::move(value)); +} +inline void TensorProto::set_string_data(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + string_data_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.TensorProto.string_data) +} +inline void TensorProto::set_string_data(int index, const void* value, size_t size) { + string_data_.Mutable(index)->assign( + reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.string_data) +} +inline std::string* TensorProto::_internal_add_string_data() { + return string_data_.Add(); +} +inline void TensorProto::add_string_data(const std::string& value) { + string_data_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.string_data) +} +inline void TensorProto::add_string_data(std::string&& value) { + string_data_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.TensorProto.string_data) +} +inline void TensorProto::add_string_data(const char* value) { + GOOGLE_DCHECK(value != nullptr); + string_data_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.TensorProto.string_data) +} +inline void TensorProto::add_string_data(const void* value, size_t size) { + string_data_.Add()->assign(reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.TensorProto.string_data) +} +inline const
::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& +TensorProto::string_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.string_data) + return string_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* +TensorProto::mutable_string_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.string_data) + return &string_data_; +} + +// repeated int64 int64_data = 7 [packed = true]; +inline int TensorProto::_internal_int64_data_size() const { + return int64_data_.size(); +} +inline int TensorProto::int64_data_size() const { + return _internal_int64_data_size(); +} +inline void TensorProto::clear_int64_data() { + int64_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::_internal_int64_data(int index) const { + return int64_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::int64_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.int64_data) + return _internal_int64_data(index); +} +inline void TensorProto::set_int64_data(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.int64_data) +} +inline void TensorProto::_internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_data_.Add(value); +} +inline void TensorProto::add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_int64_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.int64_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::_internal_int64_data() const { + return int64_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::int64_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.int64_data) + return _internal_int64_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >*
+TensorProto::_internal_mutable_int64_data() { + return &int64_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::mutable_int64_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.int64_data) + return _internal_mutable_int64_data(); +} + +// optional string name = 8; +inline bool TensorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorProto::has_name() const { + return _internal_has_name(); +} +inline void TensorProto::clear_name() { + name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TensorProto::name() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.name) + return _internal_name(); +} +inline void TensorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.name) +} +inline std::string* TensorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.name) + return _internal_mutable_name(); +} +inline const std::string& TensorProto::_internal_name() const { + return name_.GetNoArena(); +} +inline void TensorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void TensorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorProto.name) +} +inline void TensorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // 
@@protoc_insertion_point(field_set_char:onnx.TensorProto.name) +} +inline void TensorProto::set_name(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.name) +} +inline std::string* TensorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* TensorProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void TensorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.name) +} + +// optional string doc_string = 12; +inline bool TensorProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool TensorProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void TensorProto::clear_doc_string() { + doc_string_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& TensorProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.doc_string) + return _internal_doc_string(); +} +inline void TensorProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // 
@@protoc_insertion_point(field_set:onnx.TensorProto.doc_string) +} +inline std::string* TensorProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& TensorProto::_internal_doc_string() const { + return doc_string_.GetNoArena(); +} +inline void TensorProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void TensorProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorProto.doc_string) +} +inline void TensorProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.TensorProto.doc_string) +} +inline void TensorProto::set_doc_string(const char* value, size_t size) { + _has_bits_[0] |= 0x00000004u; + doc_string_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.doc_string) +} +inline std::string* TensorProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000004u; + return doc_string_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* TensorProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return 
doc_string_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void TensorProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.doc_string) +} + +// optional bytes raw_data = 9; +inline bool TensorProto::_internal_has_raw_data() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TensorProto::has_raw_data() const { + return _internal_has_raw_data(); +} +inline void TensorProto::clear_raw_data() { + raw_data_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& TensorProto::raw_data() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.raw_data) + return _internal_raw_data(); +} +inline void TensorProto::set_raw_data(const std::string& value) { + _internal_set_raw_data(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.raw_data) +} +inline std::string* TensorProto::mutable_raw_data() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.raw_data) + return _internal_mutable_raw_data(); +} +inline const std::string& TensorProto::_internal_raw_data() const { + return raw_data_.GetNoArena(); +} +inline void TensorProto::_internal_set_raw_data(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + raw_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void TensorProto::set_raw_data(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + raw_data_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // 
@@protoc_insertion_point(field_set_rvalue:onnx.TensorProto.raw_data) +} +inline void TensorProto::set_raw_data(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + raw_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.TensorProto.raw_data) +} +inline void TensorProto::set_raw_data(const void* value, size_t size) { + _has_bits_[0] |= 0x00000002u; + raw_data_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.raw_data) +} +inline std::string* TensorProto::_internal_mutable_raw_data() { + _has_bits_[0] |= 0x00000002u; + return raw_data_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* TensorProto::release_raw_data() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.raw_data) + if (!_internal_has_raw_data()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return raw_data_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void TensorProto::set_allocated_raw_data(std::string* raw_data) { + if (raw_data != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + raw_data_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), raw_data); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.raw_data) +} + +// repeated .onnx.StringStringEntryProto external_data = 13; +inline int TensorProto::_internal_external_data_size() const { + return external_data_.size(); +} +inline int TensorProto::external_data_size() const { + return _internal_external_data_size(); +} +inline void TensorProto::clear_external_data() { + external_data_.Clear(); +} +inline ::onnx::StringStringEntryProto* 
TensorProto::mutable_external_data(int index) { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.external_data) + return external_data_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* +TensorProto::mutable_external_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.external_data) + return &external_data_; +} +inline const ::onnx::StringStringEntryProto& TensorProto::_internal_external_data(int index) const { + return external_data_.Get(index); +} +inline const ::onnx::StringStringEntryProto& TensorProto::external_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.external_data) + return _internal_external_data(index); +} +inline ::onnx::StringStringEntryProto* TensorProto::_internal_add_external_data() { + return external_data_.Add(); +} +inline ::onnx::StringStringEntryProto* TensorProto::add_external_data() { + // @@protoc_insertion_point(field_add:onnx.TensorProto.external_data) + return _internal_add_external_data(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& +TensorProto::external_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.external_data) + return external_data_; +} + +// optional .onnx.TensorProto.DataLocation data_location = 14; +inline bool TensorProto::_internal_has_data_location() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool TensorProto::has_data_location() const { + return _internal_has_data_location(); +} +inline void TensorProto::clear_data_location() { + data_location_ = 0; + _has_bits_[0] &= ~0x00000020u; +} +inline ::onnx::TensorProto_DataLocation TensorProto::_internal_data_location() const { + return static_cast< ::onnx::TensorProto_DataLocation >(data_location_); +} +inline ::onnx::TensorProto_DataLocation TensorProto::data_location() const { + // 
@@protoc_insertion_point(field_get:onnx.TensorProto.data_location) + return _internal_data_location(); +} +inline void TensorProto::_internal_set_data_location(::onnx::TensorProto_DataLocation value) { + assert(::onnx::TensorProto_DataLocation_IsValid(value)); + _has_bits_[0] |= 0x00000020u; + data_location_ = value; +} +inline void TensorProto::set_data_location(::onnx::TensorProto_DataLocation value) { + _internal_set_data_location(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.data_location) +} + +// repeated double double_data = 10 [packed = true]; +inline int TensorProto::_internal_double_data_size() const { + return double_data_.size(); +} +inline int TensorProto::double_data_size() const { + return _internal_double_data_size(); +} +inline void TensorProto::clear_double_data() { + double_data_.Clear(); +} +inline double TensorProto::_internal_double_data(int index) const { + return double_data_.Get(index); +} +inline double TensorProto::double_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.double_data) + return _internal_double_data(index); +} +inline void TensorProto::set_double_data(int index, double value) { + double_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.double_data) +} +inline void TensorProto::_internal_add_double_data(double value) { + double_data_.Add(value); +} +inline void TensorProto::add_double_data(double value) { + _internal_add_double_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.double_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::_internal_double_data() const { + return double_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::double_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.double_data) + return _internal_double_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* 
+TensorProto::_internal_mutable_double_data() { + return &double_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::mutable_double_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.double_data) + return _internal_mutable_double_data(); +} + +// repeated uint64 uint64_data = 11 [packed = true]; +inline int TensorProto::_internal_uint64_data_size() const { + return uint64_data_.size(); +} +inline int TensorProto::uint64_data_size() const { + return _internal_uint64_data_size(); +} +inline void TensorProto::clear_uint64_data() { + uint64_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::_internal_uint64_data(int index) const { + return uint64_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::uint64_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.uint64_data) + return _internal_uint64_data(index); +} +inline void TensorProto::set_uint64_data(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.uint64_data) +} +inline void TensorProto::_internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_data_.Add(value); +} +inline void TensorProto::add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_add_uint64_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.uint64_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::_internal_uint64_data() const { + return uint64_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::uint64_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.uint64_data) + return _internal_uint64_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::_internal_mutable_uint64_data() { + return 
&uint64_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::mutable_uint64_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.uint64_data) + return _internal_mutable_uint64_data(); +} + +// ------------------------------------------------------------------- + +// SparseTensorProto + +// optional .onnx.TensorProto values = 1; +inline bool SparseTensorProto::_internal_has_values() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || values_ != nullptr); + return value; +} +inline bool SparseTensorProto::has_values() const { + return _internal_has_values(); +} +inline void SparseTensorProto::clear_values() { + if (values_ != nullptr) values_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TensorProto& SparseTensorProto::_internal_values() const { + const ::onnx::TensorProto* p = values_; + return p != nullptr ? *p : *reinterpret_cast( + &::onnx::_TensorProto_default_instance_); +} +inline const ::onnx::TensorProto& SparseTensorProto::values() const { + // @@protoc_insertion_point(field_get:onnx.SparseTensorProto.values) + return _internal_values(); +} +inline ::onnx::TensorProto* SparseTensorProto::release_values() { + // @@protoc_insertion_point(field_release:onnx.SparseTensorProto.values) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TensorProto* temp = values_; + values_ = nullptr; + return temp; +} +inline ::onnx::TensorProto* SparseTensorProto::_internal_mutable_values() { + _has_bits_[0] |= 0x00000001u; + if (values_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorProto>(GetArenaNoVirtual()); + values_ = p; + } + return values_; +} +inline ::onnx::TensorProto* SparseTensorProto::mutable_values() { + // @@protoc_insertion_point(field_mutable:onnx.SparseTensorProto.values) + return _internal_mutable_values(); +} +inline void SparseTensorProto::set_allocated_values(::onnx::TensorProto* values) { + 
::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete values_; + } + if (values) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + values = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, values, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + values_ = values; + // @@protoc_insertion_point(field_set_allocated:onnx.SparseTensorProto.values) +} + +// optional .onnx.TensorProto indices = 2; +inline bool SparseTensorProto::_internal_has_indices() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || indices_ != nullptr); + return value; +} +inline bool SparseTensorProto::has_indices() const { + return _internal_has_indices(); +} +inline void SparseTensorProto::clear_indices() { + if (indices_ != nullptr) indices_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::onnx::TensorProto& SparseTensorProto::_internal_indices() const { + const ::onnx::TensorProto* p = indices_; + return p != nullptr ? 
*p : *reinterpret_cast( + &::onnx::_TensorProto_default_instance_); +} +inline const ::onnx::TensorProto& SparseTensorProto::indices() const { + // @@protoc_insertion_point(field_get:onnx.SparseTensorProto.indices) + return _internal_indices(); +} +inline ::onnx::TensorProto* SparseTensorProto::release_indices() { + // @@protoc_insertion_point(field_release:onnx.SparseTensorProto.indices) + _has_bits_[0] &= ~0x00000002u; + ::onnx::TensorProto* temp = indices_; + indices_ = nullptr; + return temp; +} +inline ::onnx::TensorProto* SparseTensorProto::_internal_mutable_indices() { + _has_bits_[0] |= 0x00000002u; + if (indices_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorProto>(GetArenaNoVirtual()); + indices_ = p; + } + return indices_; +} +inline ::onnx::TensorProto* SparseTensorProto::mutable_indices() { + // @@protoc_insertion_point(field_mutable:onnx.SparseTensorProto.indices) + return _internal_mutable_indices(); +} +inline void SparseTensorProto::set_allocated_indices(::onnx::TensorProto* indices) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete indices_; + } + if (indices) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + indices = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, indices, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + indices_ = indices; + // @@protoc_insertion_point(field_set_allocated:onnx.SparseTensorProto.indices) +} + +// repeated int64 dims = 3; +inline int SparseTensorProto::_internal_dims_size() const { + return dims_.size(); +} +inline int SparseTensorProto::dims_size() const { + return _internal_dims_size(); +} +inline void SparseTensorProto::clear_dims() { + dims_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 SparseTensorProto::_internal_dims(int index) const { + return dims_.Get(index); +} +inline 
::PROTOBUF_NAMESPACE_ID::int64 SparseTensorProto::dims(int index) const { + // @@protoc_insertion_point(field_get:onnx.SparseTensorProto.dims) + return _internal_dims(index); +} +inline void SparseTensorProto::set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.SparseTensorProto.dims) +} +inline void SparseTensorProto::_internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Add(value); +} +inline void SparseTensorProto::add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_dims(value); + // @@protoc_insertion_point(field_add:onnx.SparseTensorProto.dims) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +SparseTensorProto::_internal_dims() const { + return dims_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +SparseTensorProto::dims() const { + // @@protoc_insertion_point(field_list:onnx.SparseTensorProto.dims) + return _internal_dims(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +SparseTensorProto::_internal_mutable_dims() { + return &dims_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +SparseTensorProto::mutable_dims() { + // @@protoc_insertion_point(field_mutable_list:onnx.SparseTensorProto.dims) + return _internal_mutable_dims(); +} + +// ------------------------------------------------------------------- + +// TensorShapeProto_Dimension + +// optional int64 dim_value = 1; +inline bool TensorShapeProto_Dimension::_internal_has_dim_value() const { + return value_case() == kDimValue; +} +inline bool TensorShapeProto_Dimension::has_dim_value() const { + return _internal_has_dim_value(); +} +inline void TensorShapeProto_Dimension::set_has_dim_value() { + _oneof_case_[0] = kDimValue; +} +inline void TensorShapeProto_Dimension::clear_dim_value() { + if (_internal_has_dim_value()) { + value_.dim_value_ = 
PROTOBUF_LONGLONG(0); + clear_has_value(); + } +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dimension::_internal_dim_value() const { + if (_internal_has_dim_value()) { + return value_.dim_value_; + } + return PROTOBUF_LONGLONG(0); +} +inline void TensorShapeProto_Dimension::_internal_set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value) { + if (!_internal_has_dim_value()) { + clear_value(); + set_has_dim_value(); + } + value_.dim_value_ = value; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dimension::dim_value() const { + // @@protoc_insertion_point(field_get:onnx.TensorShapeProto.Dimension.dim_value) + return _internal_dim_value(); +} +inline void TensorShapeProto_Dimension::set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_dim_value(value); + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.dim_value) +} + +// optional string dim_param = 2; +inline bool TensorShapeProto_Dimension::_internal_has_dim_param() const { + return value_case() == kDimParam; +} +inline bool TensorShapeProto_Dimension::has_dim_param() const { + return _internal_has_dim_param(); +} +inline void TensorShapeProto_Dimension::set_has_dim_param() { + _oneof_case_[0] = kDimParam; +} +inline void TensorShapeProto_Dimension::clear_dim_param() { + if (_internal_has_dim_param()) { + value_.dim_param_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + clear_has_value(); + } +} +inline const std::string& TensorShapeProto_Dimension::dim_param() const { + // @@protoc_insertion_point(field_get:onnx.TensorShapeProto.Dimension.dim_param) + return _internal_dim_param(); +} +inline void TensorShapeProto_Dimension::set_dim_param(const std::string& value) { + _internal_set_dim_param(value); + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.dim_param) +} +inline std::string* TensorShapeProto_Dimension::mutable_dim_param() { + // 
@@protoc_insertion_point(field_mutable:onnx.TensorShapeProto.Dimension.dim_param) + return _internal_mutable_dim_param(); +} +inline const std::string& TensorShapeProto_Dimension::_internal_dim_param() const { + if (_internal_has_dim_param()) { + return value_.dim_param_.GetNoArena(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void TensorShapeProto_Dimension::_internal_set_dim_param(const std::string& value) { + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void TensorShapeProto_Dimension::set_dim_param(std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.dim_param) + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorShapeProto.Dimension.dim_param) +} +inline void TensorShapeProto_Dimension::set_dim_param(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.TensorShapeProto.Dimension.dim_param) +} +inline void TensorShapeProto_Dimension::set_dim_param(const char* value, size_t size) { + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + 
value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorShapeProto.Dimension.dim_param) +} +inline std::string* TensorShapeProto_Dimension::_internal_mutable_dim_param() { + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return value_.dim_param_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* TensorShapeProto_Dimension::release_dim_param() { + // @@protoc_insertion_point(field_release:onnx.TensorShapeProto.Dimension.dim_param) + if (_internal_has_dim_param()) { + clear_has_value(); + return value_.dim_param_.ReleaseNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } else { + return nullptr; + } +} +inline void TensorShapeProto_Dimension::set_allocated_dim_param(std::string* dim_param) { + if (has_value()) { + clear_value(); + } + if (dim_param != nullptr) { + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(dim_param); + } + // @@protoc_insertion_point(field_set_allocated:onnx.TensorShapeProto.Dimension.dim_param) +} + +// optional string denotation = 3; +inline bool TensorShapeProto_Dimension::_internal_has_denotation() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorShapeProto_Dimension::has_denotation() const { + return _internal_has_denotation(); +} +inline void TensorShapeProto_Dimension::clear_denotation() { + denotation_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TensorShapeProto_Dimension::denotation() const { + // 
@@protoc_insertion_point(field_get:onnx.TensorShapeProto.Dimension.denotation) + return _internal_denotation(); +} +inline void TensorShapeProto_Dimension::set_denotation(const std::string& value) { + _internal_set_denotation(value); + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.denotation) +} +inline std::string* TensorShapeProto_Dimension::mutable_denotation() { + // @@protoc_insertion_point(field_mutable:onnx.TensorShapeProto.Dimension.denotation) + return _internal_mutable_denotation(); +} +inline const std::string& TensorShapeProto_Dimension::_internal_denotation() const { + return denotation_.GetNoArena(); +} +inline void TensorShapeProto_Dimension::_internal_set_denotation(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void TensorShapeProto_Dimension::set_denotation(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorShapeProto.Dimension.denotation) +} +inline void TensorShapeProto_Dimension::set_denotation(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.TensorShapeProto.Dimension.denotation) +} +inline void TensorShapeProto_Dimension::set_denotation(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorShapeProto.Dimension.denotation) +} +inline std::string* TensorShapeProto_Dimension::_internal_mutable_denotation() { + 
_has_bits_[0] |= 0x00000001u; + return denotation_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* TensorShapeProto_Dimension::release_denotation() { + // @@protoc_insertion_point(field_release:onnx.TensorShapeProto.Dimension.denotation) + if (!_internal_has_denotation()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return denotation_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void TensorShapeProto_Dimension::set_allocated_denotation(std::string* denotation) { + if (denotation != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + denotation_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), denotation); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorShapeProto.Dimension.denotation) +} + +inline bool TensorShapeProto_Dimension::has_value() const { + return value_case() != VALUE_NOT_SET; +} +inline void TensorShapeProto_Dimension::clear_has_value() { + _oneof_case_[0] = VALUE_NOT_SET; +} +inline TensorShapeProto_Dimension::ValueCase TensorShapeProto_Dimension::value_case() const { + return TensorShapeProto_Dimension::ValueCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// TensorShapeProto + +// repeated .onnx.TensorShapeProto.Dimension dim = 1; +inline int TensorShapeProto::_internal_dim_size() const { + return dim_.size(); +} +inline int TensorShapeProto::dim_size() const { + return _internal_dim_size(); +} +inline void TensorShapeProto::clear_dim() { + dim_.Clear(); +} +inline ::onnx::TensorShapeProto_Dimension* TensorShapeProto::mutable_dim(int index) { + // @@protoc_insertion_point(field_mutable:onnx.TensorShapeProto.dim) + return dim_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >* +TensorShapeProto::mutable_dim() { + // 
@@protoc_insertion_point(field_mutable_list:onnx.TensorShapeProto.dim) + return &dim_; +} +inline const ::onnx::TensorShapeProto_Dimension& TensorShapeProto::_internal_dim(int index) const { + return dim_.Get(index); +} +inline const ::onnx::TensorShapeProto_Dimension& TensorShapeProto::dim(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorShapeProto.dim) + return _internal_dim(index); +} +inline ::onnx::TensorShapeProto_Dimension* TensorShapeProto::_internal_add_dim() { + return dim_.Add(); +} +inline ::onnx::TensorShapeProto_Dimension* TensorShapeProto::add_dim() { + // @@protoc_insertion_point(field_add:onnx.TensorShapeProto.dim) + return _internal_add_dim(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >& +TensorShapeProto::dim() const { + // @@protoc_insertion_point(field_list:onnx.TensorShapeProto.dim) + return dim_; +} + +// ------------------------------------------------------------------- + +// TypeProto_Tensor + +// optional int32 elem_type = 1; +inline bool TypeProto_Tensor::_internal_has_elem_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TypeProto_Tensor::has_elem_type() const { + return _internal_has_elem_type(); +} +inline void TypeProto_Tensor::clear_elem_type() { + elem_type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TypeProto_Tensor::_internal_elem_type() const { + return elem_type_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TypeProto_Tensor::elem_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Tensor.elem_type) + return _internal_elem_type(); +} +inline void TypeProto_Tensor::_internal_set_elem_type(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + elem_type_ = value; +} +inline void TypeProto_Tensor::set_elem_type(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_elem_type(value); + // 
@@protoc_insertion_point(field_set:onnx.TypeProto.Tensor.elem_type) +} + +// optional .onnx.TensorShapeProto shape = 2; +inline bool TypeProto_Tensor::_internal_has_shape() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || shape_ != nullptr); + return value; +} +inline bool TypeProto_Tensor::has_shape() const { + return _internal_has_shape(); +} +inline void TypeProto_Tensor::clear_shape() { + if (shape_ != nullptr) shape_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TensorShapeProto& TypeProto_Tensor::_internal_shape() const { + const ::onnx::TensorShapeProto* p = shape_; + return p != nullptr ? *p : *reinterpret_cast( + &::onnx::_TensorShapeProto_default_instance_); +} +inline const ::onnx::TensorShapeProto& TypeProto_Tensor::shape() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Tensor.shape) + return _internal_shape(); +} +inline ::onnx::TensorShapeProto* TypeProto_Tensor::release_shape() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Tensor.shape) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TensorShapeProto* temp = shape_; + shape_ = nullptr; + return temp; +} +inline ::onnx::TensorShapeProto* TypeProto_Tensor::_internal_mutable_shape() { + _has_bits_[0] |= 0x00000001u; + if (shape_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorShapeProto>(GetArenaNoVirtual()); + shape_ = p; + } + return shape_; +} +inline ::onnx::TensorShapeProto* TypeProto_Tensor::mutable_shape() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Tensor.shape) + return _internal_mutable_shape(); +} +inline void TypeProto_Tensor::set_allocated_shape(::onnx::TensorShapeProto* shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete shape_; + } + if (shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + shape = 
::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, shape, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + shape_ = shape; + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Tensor.shape) +} + +// ------------------------------------------------------------------- + +// TypeProto_Sequence + +// optional .onnx.TypeProto elem_type = 1; +inline bool TypeProto_Sequence::_internal_has_elem_type() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || elem_type_ != nullptr); + return value; +} +inline bool TypeProto_Sequence::has_elem_type() const { + return _internal_has_elem_type(); +} +inline void TypeProto_Sequence::clear_elem_type() { + if (elem_type_ != nullptr) elem_type_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TypeProto& TypeProto_Sequence::_internal_elem_type() const { + const ::onnx::TypeProto* p = elem_type_; + return p != nullptr ? *p : *reinterpret_cast( + &::onnx::_TypeProto_default_instance_); +} +inline const ::onnx::TypeProto& TypeProto_Sequence::elem_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Sequence.elem_type) + return _internal_elem_type(); +} +inline ::onnx::TypeProto* TypeProto_Sequence::release_elem_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Sequence.elem_type) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TypeProto* temp = elem_type_; + elem_type_ = nullptr; + return temp; +} +inline ::onnx::TypeProto* TypeProto_Sequence::_internal_mutable_elem_type() { + _has_bits_[0] |= 0x00000001u; + if (elem_type_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TypeProto>(GetArenaNoVirtual()); + elem_type_ = p; + } + return elem_type_; +} +inline ::onnx::TypeProto* TypeProto_Sequence::mutable_elem_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Sequence.elem_type) + return _internal_mutable_elem_type(); +} +inline void 
TypeProto_Sequence::set_allocated_elem_type(::onnx::TypeProto* elem_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete elem_type_; + } + if (elem_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + elem_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, elem_type, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + elem_type_ = elem_type; + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Sequence.elem_type) +} + +// ------------------------------------------------------------------- + +// TypeProto_Map + +// optional int32 key_type = 1; +inline bool TypeProto_Map::_internal_has_key_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TypeProto_Map::has_key_type() const { + return _internal_has_key_type(); +} +inline void TypeProto_Map::clear_key_type() { + key_type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TypeProto_Map::_internal_key_type() const { + return key_type_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TypeProto_Map::key_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Map.key_type) + return _internal_key_type(); +} +inline void TypeProto_Map::_internal_set_key_type(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + key_type_ = value; +} +inline void TypeProto_Map::set_key_type(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_key_type(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.Map.key_type) +} + +// optional .onnx.TypeProto value_type = 2; +inline bool TypeProto_Map::_internal_has_value_type() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || value_type_ != nullptr); + return value; +} +inline bool TypeProto_Map::has_value_type() const { + return 
_internal_has_value_type(); +} +inline void TypeProto_Map::clear_value_type() { + if (value_type_ != nullptr) value_type_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TypeProto& TypeProto_Map::_internal_value_type() const { + const ::onnx::TypeProto* p = value_type_; + return p != nullptr ? *p : *reinterpret_cast( + &::onnx::_TypeProto_default_instance_); +} +inline const ::onnx::TypeProto& TypeProto_Map::value_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Map.value_type) + return _internal_value_type(); +} +inline ::onnx::TypeProto* TypeProto_Map::release_value_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Map.value_type) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TypeProto* temp = value_type_; + value_type_ = nullptr; + return temp; +} +inline ::onnx::TypeProto* TypeProto_Map::_internal_mutable_value_type() { + _has_bits_[0] |= 0x00000001u; + if (value_type_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TypeProto>(GetArenaNoVirtual()); + value_type_ = p; + } + return value_type_; +} +inline ::onnx::TypeProto* TypeProto_Map::mutable_value_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Map.value_type) + return _internal_mutable_value_type(); +} +inline void TypeProto_Map::set_allocated_value_type(::onnx::TypeProto* value_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete value_type_; + } + if (value_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + value_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, value_type, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + value_type_ = value_type; + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Map.value_type) +} + +// ------------------------------------------------------------------- + +// 
TypeProto + +// optional .onnx.TypeProto.Tensor tensor_type = 1; +inline bool TypeProto::_internal_has_tensor_type() const { + return value_case() == kTensorType; +} +inline bool TypeProto::has_tensor_type() const { + return _internal_has_tensor_type(); +} +inline void TypeProto::set_has_tensor_type() { + _oneof_case_[0] = kTensorType; +} +inline void TypeProto::clear_tensor_type() { + if (_internal_has_tensor_type()) { + delete value_.tensor_type_; + clear_has_value(); + } +} +inline ::onnx::TypeProto_Tensor* TypeProto::release_tensor_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.tensor_type) + if (_internal_has_tensor_type()) { + clear_has_value(); + ::onnx::TypeProto_Tensor* temp = value_.tensor_type_; + value_.tensor_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_Tensor& TypeProto::_internal_tensor_type() const { + return _internal_has_tensor_type() + ? *value_.tensor_type_ + : *reinterpret_cast< ::onnx::TypeProto_Tensor*>(&::onnx::_TypeProto_Tensor_default_instance_); +} +inline const ::onnx::TypeProto_Tensor& TypeProto::tensor_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.tensor_type) + return _internal_tensor_type(); +} +inline ::onnx::TypeProto_Tensor* TypeProto::_internal_mutable_tensor_type() { + if (!_internal_has_tensor_type()) { + clear_value(); + set_has_tensor_type(); + value_.tensor_type_ = CreateMaybeMessage< ::onnx::TypeProto_Tensor >( + GetArenaNoVirtual()); + } + return value_.tensor_type_; +} +inline ::onnx::TypeProto_Tensor* TypeProto::mutable_tensor_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.tensor_type) + return _internal_mutable_tensor_type(); +} + +// optional .onnx.TypeProto.Sequence sequence_type = 4; +inline bool TypeProto::_internal_has_sequence_type() const { + return value_case() == kSequenceType; +} +inline bool TypeProto::has_sequence_type() const { + return _internal_has_sequence_type(); +} +inline void 
TypeProto::set_has_sequence_type() { + _oneof_case_[0] = kSequenceType; +} +inline void TypeProto::clear_sequence_type() { + if (_internal_has_sequence_type()) { + delete value_.sequence_type_; + clear_has_value(); + } +} +inline ::onnx::TypeProto_Sequence* TypeProto::release_sequence_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.sequence_type) + if (_internal_has_sequence_type()) { + clear_has_value(); + ::onnx::TypeProto_Sequence* temp = value_.sequence_type_; + value_.sequence_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_Sequence& TypeProto::_internal_sequence_type() const { + return _internal_has_sequence_type() + ? *value_.sequence_type_ + : *reinterpret_cast< ::onnx::TypeProto_Sequence*>(&::onnx::_TypeProto_Sequence_default_instance_); +} +inline const ::onnx::TypeProto_Sequence& TypeProto::sequence_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.sequence_type) + return _internal_sequence_type(); +} +inline ::onnx::TypeProto_Sequence* TypeProto::_internal_mutable_sequence_type() { + if (!_internal_has_sequence_type()) { + clear_value(); + set_has_sequence_type(); + value_.sequence_type_ = CreateMaybeMessage< ::onnx::TypeProto_Sequence >( + GetArenaNoVirtual()); + } + return value_.sequence_type_; +} +inline ::onnx::TypeProto_Sequence* TypeProto::mutable_sequence_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.sequence_type) + return _internal_mutable_sequence_type(); +} + +// optional .onnx.TypeProto.Map map_type = 5; +inline bool TypeProto::_internal_has_map_type() const { + return value_case() == kMapType; +} +inline bool TypeProto::has_map_type() const { + return _internal_has_map_type(); +} +inline void TypeProto::set_has_map_type() { + _oneof_case_[0] = kMapType; +} +inline void TypeProto::clear_map_type() { + if (_internal_has_map_type()) { + delete value_.map_type_; + clear_has_value(); + } +} +inline ::onnx::TypeProto_Map* 
TypeProto::release_map_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.map_type) + if (_internal_has_map_type()) { + clear_has_value(); + ::onnx::TypeProto_Map* temp = value_.map_type_; + value_.map_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_Map& TypeProto::_internal_map_type() const { + return _internal_has_map_type() + ? *value_.map_type_ + : *reinterpret_cast< ::onnx::TypeProto_Map*>(&::onnx::_TypeProto_Map_default_instance_); +} +inline const ::onnx::TypeProto_Map& TypeProto::map_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.map_type) + return _internal_map_type(); +} +inline ::onnx::TypeProto_Map* TypeProto::_internal_mutable_map_type() { + if (!_internal_has_map_type()) { + clear_value(); + set_has_map_type(); + value_.map_type_ = CreateMaybeMessage< ::onnx::TypeProto_Map >( + GetArenaNoVirtual()); + } + return value_.map_type_; +} +inline ::onnx::TypeProto_Map* TypeProto::mutable_map_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.map_type) + return _internal_mutable_map_type(); +} + +// optional string denotation = 6; +inline bool TypeProto::_internal_has_denotation() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TypeProto::has_denotation() const { + return _internal_has_denotation(); +} +inline void TypeProto::clear_denotation() { + denotation_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TypeProto::denotation() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.denotation) + return _internal_denotation(); +} +inline void TypeProto::set_denotation(const std::string& value) { + _internal_set_denotation(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.denotation) +} +inline std::string* TypeProto::mutable_denotation() { + // 
@@protoc_insertion_point(field_mutable:onnx.TypeProto.denotation) + return _internal_mutable_denotation(); +} +inline const std::string& TypeProto::_internal_denotation() const { + return denotation_.GetNoArena(); +} +inline void TypeProto::_internal_set_denotation(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void TypeProto::set_denotation(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.TypeProto.denotation) +} +inline void TypeProto::set_denotation(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.TypeProto.denotation) +} +inline void TypeProto::set_denotation(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + denotation_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.TypeProto.denotation) +} +inline std::string* TypeProto::_internal_mutable_denotation() { + _has_bits_[0] |= 0x00000001u; + return denotation_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* TypeProto::release_denotation() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.denotation) + if (!_internal_has_denotation()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return denotation_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void TypeProto::set_allocated_denotation(std::string* denotation) { + if (denotation != 
nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + denotation_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), denotation); + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.denotation) +} + +inline bool TypeProto::has_value() const { + return value_case() != VALUE_NOT_SET; +} +inline void TypeProto::clear_has_value() { + _oneof_case_[0] = VALUE_NOT_SET; +} +inline TypeProto::ValueCase TypeProto::value_case() const { + return TypeProto::ValueCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// OperatorSetIdProto + +// optional string domain = 1; +inline bool OperatorSetIdProto::_internal_has_domain() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool OperatorSetIdProto::has_domain() const { + return _internal_has_domain(); +} +inline void OperatorSetIdProto::clear_domain() { + domain_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& OperatorSetIdProto::domain() const { + // @@protoc_insertion_point(field_get:onnx.OperatorSetIdProto.domain) + return _internal_domain(); +} +inline void OperatorSetIdProto::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:onnx.OperatorSetIdProto.domain) +} +inline std::string* OperatorSetIdProto::mutable_domain() { + // @@protoc_insertion_point(field_mutable:onnx.OperatorSetIdProto.domain) + return _internal_mutable_domain(); +} +inline const std::string& OperatorSetIdProto::_internal_domain() const { + return domain_.GetNoArena(); +} +inline void OperatorSetIdProto::_internal_set_domain(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value); +} +inline void 
OperatorSetIdProto::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + domain_.SetNoArena( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:onnx.OperatorSetIdProto.domain) +} +inline void OperatorSetIdProto::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:onnx.OperatorSetIdProto.domain) +} +inline void OperatorSetIdProto::set_domain(const char* value, size_t size) { + _has_bits_[0] |= 0x00000001u; + domain_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:onnx.OperatorSetIdProto.domain) +} +inline std::string* OperatorSetIdProto::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000001u; + return domain_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline std::string* OperatorSetIdProto::release_domain() { + // @@protoc_insertion_point(field_release:onnx.OperatorSetIdProto.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return domain_.ReleaseNonDefaultNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +inline void OperatorSetIdProto::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + domain_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain); + // @@protoc_insertion_point(field_set_allocated:onnx.OperatorSetIdProto.domain) +} + +// optional int64 version = 2; +inline bool OperatorSetIdProto::_internal_has_version() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; 
+ return value; +} +inline bool OperatorSetIdProto::has_version() const { + return _internal_has_version(); +} +inline void OperatorSetIdProto::clear_version() { + version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 OperatorSetIdProto::_internal_version() const { + return version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 OperatorSetIdProto::version() const { + // @@protoc_insertion_point(field_get:onnx.OperatorSetIdProto.version) + return _internal_version(); +} +inline void OperatorSetIdProto::_internal_set_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000002u; + version_ = value; +} +inline void OperatorSetIdProto::set_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_version(value); + // @@protoc_insertion_point(field_set:onnx.OperatorSetIdProto.version) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// 
------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace onnx + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< ::onnx::AttributeProto_AttributeType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::onnx::AttributeProto_AttributeType>() { + return ::onnx::AttributeProto_AttributeType_descriptor(); +} +template <> struct is_proto_enum< ::onnx::TensorProto_DataType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::onnx::TensorProto_DataType>() { + return ::onnx::TensorProto_DataType_descriptor(); +} +template <> struct is_proto_enum< ::onnx::TensorProto_DataLocation> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::onnx::TensorProto_DataLocation>() { + return ::onnx::TensorProto_DataLocation_descriptor(); +} +template <> struct is_proto_enum< ::onnx::Version> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::onnx::Version>() { + return ::onnx::Version_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_onnx_2eproto diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6136aa9c7..58292db1d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,10 +6,10 @@ project(eddl-tests) enable_testing() if(GTEST_ROOT) # Find libraries (need absolute paths) - find_library(GTEST_LIBRARY gtest HINTS ${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") - find_library(GTEST_MAIN_LIBRARY gtest_main HINTS ${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") -# find_library(GTESTD_LIBRARY gtestd HINTS ${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") -# find_library(GTESTD_MAIN_LIBRARY gtest_maind HINTS 
${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") + find_library(GTEST_LIBRARY NAMES gtest HINTS ${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") + find_library(GTEST_MAIN_LIBRARY NAMES gtest_main HINTS ${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") + find_library(GTESTD_LIBRARY NAMES gtestd HINTS ${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") + find_library(GTEST_MAIND_LIBRARY NAMES gtest_maind HINTS ${GTEST_ROOT} PATHS ${GTEST_ROOT} PATH_SUFFIXES "lib" "lib64") else() find_package(GTest REQUIRED) endif() @@ -17,25 +17,28 @@ endif() # Find tests (recursively, from here) file(GLOB_RECURSE CPP_TESTS_FILES "${PROJECT_SOURCE_DIR}/*" *.{h, cpp}) +# Filter ONNX files if they are not needed if(NOT BUILD_PROTOBUF) list(FILTER CPP_TESTS_FILES EXCLUDE REGEX ".*/onnx/*") endif() # Build test and target libraries add_executable(unit_tests ${CPP_TESTS_FILES}) -target_include_directories(unit_tests PUBLIC $) # TODO: Why build interface? - +target_include_directories(unit_tests PUBLIC $) +# Add libraries if(MSVC) - target_link_libraries(unit_tests PUBLIC eddl ${GTEST_LIBRARY} ${GTEST_MAIN_LIBRARY}) + target_link_libraries(unit_tests PUBLIC eddl + optimized ${GTEST_LIBRARY} optimized ${GTEST_MAIN_LIBRARY} + debug ${GTESTD_LIBRARY} debug ${GTEST_MAIND_LIBRARY} + ) else() - find_package(Threads) + find_package(Threads REQUIRED) target_link_libraries(unit_tests PUBLIC eddl ${GTEST_LIBRARY} ${GTEST_MAIN_LIBRARY} ${CMAKE_THREAD_LIBS_INIT}) endif() -# CUDA (set in parent scope) -if(USE_CUDA) +if (USE_CUDA) add_definitions(-DcGPU) endif() if(USE_CUDNN) @@ -45,6 +48,77 @@ if (USE_FPGA) add_definitions(-DcFPGA) endif() + +# CUDA (TEMP! 
We shouldn't need this) +if(USE_CUDA) + # Check if cuda is available + include(CheckLanguage) + check_language(CUDA) + + if (CMAKE_CUDA_COMPILER) + enable_language(CUDA) + find_package(CUDAToolkit) + + # NVCC needs GCC versions to be less or equal than 8 (GCC < 9.0;) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9) + message(FATAL_ERROR "[WARNING] The nvcc compiler in CUDA 10 (or later) does not support gcc versions later than 8 (DETECTED: ${CMAKE_CXX_COMPILER_VERSION}). + Hint: Use g++-8 (or older); Set other compiler version by using '-D CMAKE_CXX_COMPILER='$(which g++-8)' or creating a symbolic link.") + endif() + endif() + + if(USE_CUDNN) + add_definitions(-DcCUDNN) + endif() + add_definitions(-DcGPU) + + # Set standard CUDA variables + if(NOT DEFINED CMAKE_CUDA_STANDARD) + set(CMAKE_CUDA_STANDARD 11) + set(CMAKE_CUDA_STANDARD_REQUIRED ON) + endif() + + # Target properties + set_target_properties(unit_tests PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + CUDA_RESOLVE_DEVICE_SYMBOLS ON + ) + + # Add source files + target_sources(unit_tests PRIVATE ${CUDA_HEADERS} ${CUDA_SOURCES}) + + # Add includes + target_include_directories(unit_tests PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + + # # Find libraries (need absolute paths) + # find_library(CUBLAS_LIBRARY cublas HINTS ${CUDA_TOOLKIT_ROOT_DIR}) + # find_library(CUDART_LIBRARY cudart HINTS ${CUDA_TOOLKIT_ROOT_DIR}) + # find_library(CURAND_LIBRARY curand HINTS ${CUDA_TOOLKIT_ROOT_DIR}) + # target_link_libraries(unit_tests PRIVATE ${CUBLAS_LIBRARY} ${CUDART_LIBRARY} ${CURAND_LIBRARY}) + target_link_libraries(unit_tests PRIVATE CUDA::cublas CUDA::cudart CUDA::curand) + if(USE_CUDNN) + find_library(CUDNN_LIBRARY cudnn HINTS ${CUDAToolkit_LIBRARY_DIR}) + target_link_libraries(unit_tests PRIVATE ${CUDNN_LIBRARY}) + endif() + + if(APPLE) + # We need to add the path to the driver (libcuda.dylib) as an rpath, + # so that the static cuda runtime can find it at runtime. 
+ set_property(TARGET unit_tests PROPERTY BUILD_RPATH ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES}) + endif() + + if(NOT DEFINED CMAKE_CUDA_STANDARD) + # Make EDDL works for cuda 7.5 + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr -D_MWAITXINTRIN_H_INCLUDED -D_FORCE_INLINES -D__STRICT_ANSI__") + endif() + else() + message(WARNING "[WARNING] CUDA compiler not found but requested during compilation. (Falling back to: '-D USE_CUDA=OFF') + Hint: Set the NVCC path using '-D CMAKE_CUDA_COMPILER=path' or creating a symbolic link.") + SET(BUILD_TARGET "CPU") # Local's scope + endif() +endif() + + # Add test add_test(NAME unit_tests COMMAND unit_tests) diff --git a/tests/layers/conv/test_conv2d.cpp b/tests/layers/conv/test_conv2d.cpp index 2af239c34..0a79ffdca 100644 --- a/tests/layers/conv/test_conv2d.cpp +++ b/tests/layers/conv/test_conv2d.cpp @@ -171,7 +171,6 @@ TEST(Conv2DTestSuite, conv2d_k2x2_s2x2_pad_valid) } - TEST(Conv2DTestSuite, conv2d_k2x2_s2x2_pad_same) { // Image diff --git a/tests/layers/normalization/test_batchnorm.cpp b/tests/layers/normalization/test_batchnorm.cpp index 7a0e7f993..e4ceda68f 100644 --- a/tests/layers/normalization/test_batchnorm.cpp +++ b/tests/layers/normalization/test_batchnorm.cpp @@ -1,94 +1,95 @@ -//#include -//#include -// -//#include "eddl/tensor/tensor.h" -//#include "eddl/tensor/nn/tensor_nn.h" -//#include "eddl/descriptors/descriptors.h" -//#include "eddl/layers/normalization/layer_normalization.h" -// -//using namespace std; -// -// -//TEST(NormalizationTestSuite, batchnorm){ -// -// // Image -// auto *ptr_img = new float[1*3*5*5]{-0.66, 1.88, -0.09, 2.00, -1.26, -// -0.96, 1.49, -0.34, -0.12, -0.09, -// -0.19, -0.60, -1.60, -0.84, -1.44, -// -0.83, -0.06, 0.01, -0.81, -0.90, -// 0.43, 0.82, -0.46, -0.10, -0.17, -// -// 0.32, -1.09, 0.52, 1.19, 0.76, -// 0.16, 1.07, -1.08, 0.14, -2.00, -// 0.94, 1.24, 0.23, -0.77, 0.23, -// 0.09, -1.64, 2.31, 0.09, 0.98, -// 0.23, -2.14, -1.47, 1.18, -0.02, -// -// 0.75, -0.97, 
0.47, 0.67, -0.03, -// 0.77, 0.27, 1.16, 0.62, 1.39, -// -0.23, 0.51, 0.26, 0.75, -0.08, -// 0.14, 0.17, 0.89, 0.19, -0.44, -// 0.98, 0.66, 0.48, -0.20, 0.10,}; -// auto* t_image = new Tensor({1, 3, 5, 5}, ptr_img, DEV_CPU); -// -// // Forward -// auto *ptr_fwrd_ref = new float[1*3*5*5]{ -// -0.51, 2.27, 0.12, 2.41, -1.16, -// -0.84, 1.84, -0.15, 0.08, 0.12, -// 0.00, -0.45, -1.53, -0.70, -1.36, -// -0.70, 0.15, 0.23, -0.67, -0.77, -// 0.68, 1.11, -0.29, 0.10, 0.02, -// -// 0.24, -1.05, 0.42, 1.03, 0.64, -// 0.09, 0.92, -1.04, 0.07, -1.87, -// 0.81, 1.07, 0.15, -0.75, 0.15, -// 0.02, -1.55, 2.05, 0.03, 0.84, -// 0.15, -2.00, -1.40, 1.02, -0.07, -// -// 0.73, -2.58, 0.19, 0.58, -0.78, -// 0.77, -0.20, 1.52, 0.47, 1.96, -// -1.16, 0.26, -0.21, 0.73, -0.86, -// -0.44, -0.38, 1.00, -0.35, -1.56, -// 1.17, 0.56, 0.20, -1.09, -0.53, -// }; -// auto* t_fwrd_ref = new Tensor({1, 3, 5, 5}, ptr_fwrd_ref, DEV_CPU); -// -// // Mean -// auto *ptr_mean_ref = new float[3]{-0.0196, 0.0058, 0.0370}; -// auto* t_mean_ref = new Tensor({3}, ptr_mean_ref, DEV_CPU); -// -// // Var -// auto *ptr_var_ref = new float[3]{0.9870, 1.0254, 0.9280}; -// auto* t_var_ref = new Tensor({3}, ptr_var_ref, DEV_CPU); -// t_var_ref->add(10e-5); -// t_var_ref->sqrt(); -// t_var_ref->inv(); -// -// -// // Forward -// auto* t_output = Tensor::empty_like(t_image); -// auto* t_opa = Tensor::empty_like(t_image); -// -// auto* t_mean_acc = Tensor::zeros({3}); -// auto* t_var_acc = Tensor::ones({3}); -// -// auto* t_mean = Tensor::zeros({3}); -// auto* t_var = Tensor::ones({3}); -// -// auto* t_gamma = Tensor::ones({3}); -// auto* t_beta = Tensor::zeros({3}); -//// -//// -//// -//// cout << "Mean" << endl; -//// t_mean_ref->print(3); -//// t_mean->print(3); -//// -//// cout << "Var" << endl; -//// t_var_ref->print(3); -//// t_var->print(3); -//// ASSERT_TRUE((bool) Tensor::equivalent(t_fwrd_ref, t_output, 10e-2f)); -//// ASSERT_TRUE((bool) Tensor::equivalent(t_mean_ref, t_mean, 10e-2f)); -//// 
ASSERT_TRUE((bool) Tensor::equivalent(t_var_ref, t_var, 10e-2f)); -// int asd = 33; -//} +#include +#include + +#include "eddl/tensor/tensor.h" +#include "eddl/tensor/nn/tensor_nn.h" +#include "eddl/descriptors/descriptors.h" +#include "eddl/layers/normalization/layer_normalization.h" + +using namespace std; + + +TEST(NormalizationTestSuite, batchnorm){ + + // Image + auto *ptr_img = new float[1*3*5*5]{-0.66, 1.88, -0.09, 2.00, -1.26, + -0.96, 1.49, -0.34, -0.12, -0.09, + -0.19, -0.60, -1.60, -0.84, -1.44, + -0.83, -0.06, 0.01, -0.81, -0.90, + 0.43, 0.82, -0.46, -0.10, -0.17, + + 0.32, -1.09, 0.52, 1.19, 0.76, + 0.16, 1.07, -1.08, 0.14, -2.00, + 0.94, 1.24, 0.23, -0.77, 0.23, + 0.09, -1.64, 2.31, 0.09, 0.98, + 0.23, -2.14, -1.47, 1.18, -0.02, + + 0.75, -0.97, 0.47, 0.67, -0.03, + 0.77, 0.27, 1.16, 0.62, 1.39, + -0.23, 0.51, 0.26, 0.75, -0.08, + 0.14, 0.17, 0.89, 0.19, -0.44, + 0.98, 0.66, 0.48, -0.20, 0.10,}; + auto* t_image = new Tensor({1, 3, 5, 5}, ptr_img, DEV_CPU); + + // Forward + auto *ptr_fwrd_ref = new float[1*3*5*5]{ + -0.51, 2.27, 0.12, 2.41, -1.16, + -0.84, 1.84, -0.15, 0.08, 0.12, + 0.00, -0.45, -1.53, -0.70, -1.36, + -0.70, 0.15, 0.23, -0.67, -0.77, + 0.68, 1.11, -0.29, 0.10, 0.02, + + 0.24, -1.05, 0.42, 1.03, 0.64, + 0.09, 0.92, -1.04, 0.07, -1.87, + 0.81, 1.07, 0.15, -0.75, 0.15, + 0.02, -1.55, 2.05, 0.03, 0.84, + 0.15, -2.00, -1.40, 1.02, -0.07, + + 0.73, -2.58, 0.19, 0.58, -0.78, + 0.77, -0.20, 1.52, 0.47, 1.96, + -1.16, 0.26, -0.21, 0.73, -0.86, + -0.44, -0.38, 1.00, -0.35, -1.56, + 1.17, 0.56, 0.20, -1.09, -0.53, + }; + auto* t_fwrd_ref = new Tensor({1, 3, 5, 5}, ptr_fwrd_ref, DEV_CPU); + + // Mean + auto *ptr_mean_ref = new float[3]{-0.196, 0.058, 0.370}; + auto* t_mean_ref = new Tensor({3}, ptr_mean_ref, DEV_CPU); + + // Var + const float epsilon = 1e-5; + auto *ptr_var_ref = new float[3]{0.914, 1.098, 0.520}; + auto* t_var_ref = new Tensor({3}, ptr_var_ref, DEV_CPU); + t_var_ref->add(10e-5); + t_var_ref->sqrt(); + t_var_ref->inv(); + + + // 
Forward + auto* t_output = Tensor::empty_like(t_image); + auto* t_opa = Tensor::empty_like(t_image); + + auto* t_mean_acc = Tensor::zeros({3}); + auto* t_var_acc = Tensor::ones({3}); + + auto* t_mean = Tensor::zeros({3}); + auto* t_var = Tensor::ones({3}); + + auto* t_gamma = Tensor::ones({3}); + auto* t_beta = Tensor::zeros({3}); + + + + // BN_forward(t_image, t_gamma, t_beta, t_mean, t_var,0.1f,10e-5, TRMODE); + tensorNN::BatchNormForward(t_image, t_output, t_opa, + t_mean_acc, t_var_acc, + t_gamma, t_beta, + t_mean, t_var, TRMODE==TRMODE, 1e-5, 0.1f); + + + ASSERT_TRUE((bool) Tensor::equivalent(t_fwrd_ref, t_output, 1e-2f)); + ASSERT_TRUE((bool) Tensor::equivalent(t_mean_ref, t_mean, 1e-2f)); + ASSERT_TRUE((bool) Tensor::equivalent(t_var_ref, t_var, 1e-2f)); + int asd = 33; +}