diff --git a/INSTALL.rst b/INSTALL.rst index 4bd0faccdb87..65abe7aacf82 100644 --- a/INSTALL.rst +++ b/INSTALL.rst @@ -93,7 +93,7 @@ Intermediate MPI instructions (enables a few packages) make -j install -Intermediate MPI instructions (enables a few packages, explict compilers) +Intermediate MPI instructions (enables a few packages, explicit compilers) ------------------------------------------------------------------------- :: diff --git a/cmake/ctest/drivers/geminga/ctest_linux_nightly_serial_debug_muelu_epetra_geminga.cmake b/cmake/ctest/drivers/geminga/ctest_linux_nightly_serial_debug_muelu_epetra_geminga.cmake index 145721094165..c7d80e890e5d 100644 --- a/cmake/ctest/drivers/geminga/ctest_linux_nightly_serial_debug_muelu_epetra_geminga.cmake +++ b/cmake/ctest/drivers/geminga/ctest_linux_nightly_serial_debug_muelu_epetra_geminga.cmake @@ -76,7 +76,6 @@ SET(Trilinos_PACKAGES MueLu Xpetra) SET(EXTRA_CONFIGURE_OPTIONS "-DTrilinos_ENABLE_DEPENDENCY_UNIT_TESTS=OFF" - "-DTrilinos_ENABLE_Kokkos=OFF" "-DTrilinos_ENABLE_Tpetra=OFF" "-DTrilinos_ENABLE_ML=OFF" "-DTPL_ENABLE_SuperLU=ON" diff --git a/cmake/tribits/CHANGELOG.md b/cmake/tribits/CHANGELOG.md index de889ddac29c..558c787b124e 100644 --- a/cmake/tribits/CHANGELOG.md +++ b/cmake/tribits/CHANGELOG.md @@ -2,6 +2,20 @@ ChangeLog for TriBITS ---------------------------------------- +## 2022-08-22: + +* **Added:** Added support for exporting cache variables for packages in their + `Config.cmake` files using the new function + `tribits_pkg_export_cache_var()`. + +## 2022-08-18: + +* **Changed:** Made setting parent package tests/examples enable/disable + correctly propagate down to subpackages in a more intuitive way (see + [TriBITSPub/TriBITS#268](https://github.com/TriBITSPub/TriBITS/issues/268)). + This also results in not enabling tests for subpackages that are not + explicitly enabled or enabled as part of the forward sweep of packages + enables due to `_ENABLE_ALL_FORWARD_DEP_PACKAGES=ON`. ## 2022-08-11: @@ -11,13 +25,6 @@ ChangeLog for TriBITS and [TriBITSPub/TriBITS#510](https://github.com/TriBITSPub/TriBITS/issues/510)). -* **Changed:** Made setting parent package tests/examples enables correctly - propagate down to subpackages in a more intuitive way (see - [TriBITSPub/TriBITS#268](https://github.com/TriBITSPub/TriBITS/issues/268)). - This also results in not enabling tests for subpackages that are not - explicitly enabled or enabled as part of the forward sweep of packages - enables due to `_ENABLE_ALL_FORWARD_DEP_PACKAGES=ON`. - ## 2022-07-20: * **Changed:** Fixed TriBITS generated and installed `Config.cmake` diff --git a/cmake/tribits/common_tpls/FindTPLNetcdf.cmake b/cmake/tribits/common_tpls/FindTPLNetcdf.cmake index 9ffd66c06a63..13730f0457eb 100644 --- a/cmake/tribits/common_tpls/FindTPLNetcdf.cmake +++ b/cmake/tribits/common_tpls/FindTPLNetcdf.cmake @@ -130,7 +130,7 @@ if ("${TPL_Netcdf_PARALLEL}" STREQUAL "") string(REGEX MATCH "[01]" netcdf_par_val "${netcdf_par_string}") if (netcdf_par_val EQUAL 1) set(TPL_Netcdf_PARALLEL True CACHE INTERNAL - "True if netcdf compiled with parallel enabled") + "True if netcdf compiled with parallel enabled") endif() endif() if ("${TPL_Netcdf_PARALLEL}" STREQUAL "") diff --git a/cmake/tribits/common_tpls/find_modules/FindHDF5.cmake b/cmake/tribits/common_tpls/find_modules/FindHDF5.cmake index bed6cb04b1ca..1ec93541032f 100644 --- a/cmake/tribits/common_tpls/find_modules/FindHDF5.cmake +++ b/cmake/tribits/common_tpls/find_modules/FindHDF5.cmake @@ -389,10 +389,10 @@ else() foreach( _component ${HDF5_VALID_COMPONENTS} ) set(target ${HDF5_${_component}_TARGET}) - if ( TARGET ${target} ) - set(HDF5_${_component}_LIBRARY ${target}) - list(APPEND HDF5_LIBRARIES ${HDF5_${_component}_LIBRARY}) - endif() + if ( TARGET ${target} ) + set(HDF5_${_component}_LIBRARY ${target}) + list(APPEND HDF5_LIBRARIES ${HDF5_${_component}_LIBRARY}) + endif() endforeach() # Define HDF5_C_LIBRARIES to contain hdf5 and hdf5_hl C libraries @@ -470,7 +470,7 @@ else() LOCATION ${_HDF5_C_LIBRARY} LINK_LANGUAGES "C" LINK_INTERFACE_LIBRARIES "${HDF5_LINK_LIBRARIES}") - set(HDF5_C_LIBRARY ${HDF5_C_TARGET}) + set(HDF5_C_LIBRARY ${HDF5_C_TARGET}) # --- Search for the other possible component libraries @@ -495,7 +495,7 @@ else() # Define the HDF5__LIBRARY to point to the target foreach ( _component ${HDF5_VALID_COMPONENTS} ) if ( TARGET ${HDF5_${_component}_TARGET} ) - set(HDF5_${_component}_LIBRARY ${HDF5_${_component}_TARGET}) + set(HDF5_${_component}_LIBRARY ${HDF5_${_component}_TARGET}) endif() endforeach() @@ -513,7 +513,7 @@ else() set(HDF5_LIBRARIES) foreach (_component ${HDF5_VALID_COMPONENTS}) if ( TARGET ${HDF5_${_component}_TARGET} ) - list(APPEND HDF5_LIBRARIES ${_HDF5_${_component}_LIBRARY}) + list(APPEND HDF5_LIBRARIES ${_HDF5_${_component}_LIBRARY}) endif() endforeach() list(APPEND HDF5_LIBRARIES ${HDF5_LINK_LIBRARIES}) @@ -581,8 +581,8 @@ if ( NOT HDF5_FIND_QUIETLY ) set(HDF5_COMPONENTS_NOTFOUND) foreach (_component ${HDF5_VALID_COMPONENTS} ) if ( HDF5_${_component}_FOUND ) - #message(STATUS "\t HDF5_${_component}_LIBRARY\t\t=${HDF5_${_component}_LIBRARY}") - message(STATUS "\t${HDF5_${_component}_LIBRARY}") + #message(STATUS "\t HDF5_${_component}_LIBRARY\t\t=${HDF5_${_component}_LIBRARY}") + message(STATUS "\t${HDF5_${_component}_LIBRARY}") else() list(APPEND HDF5_COMPONENTS_NOTFOUND ${_component}) endif() diff --git a/cmake/tribits/common_tpls/find_modules/FindNetCDF.cmake b/cmake/tribits/common_tpls/find_modules/FindNetCDF.cmake index 960088f3e4cf..1e97c3e237bc 100644 --- a/cmake/tribits/common_tpls/find_modules/FindNetCDF.cmake +++ b/cmake/tribits/common_tpls/find_modules/FindNetCDF.cmake @@ -186,21 +186,21 @@ else(NetCDF_LIBRARIES AND NetCDF_INCLUDE_DIRS) set(NetCDF_LARGE_DIMS FALSE) endif() - set(NetCDF_PARALLEL False) + set(NetCDF_PARALLEL False) find_path(meta_path - NAMES "netcdf_meta.h" + NAMES "netcdf_meta.h" HINTS ${NetCDF_INCLUDE_DIR} NO_DEFAULT_PATH) if(meta_path) - # Search meta for NC_HAS_PARALLEL setting... - # Note that there is both NC_HAS_PARALLEL4 and NC_HAS_PARALLEL, only want NC_HAS_PARALLEL - # so add a space to end to avoid getting NC_HAS_PARALLEL4 - file(STRINGS "${meta_path}/netcdf_meta.h" netcdf_par_string REGEX "NC_HAS_PARALLEL ") - string(REGEX REPLACE "[^0-9]" "" netcdf_par_val "${netcdf_par_string}") - # NOTE: The line for NC_HAS_PARALLEL has an hdf5 string in it which results + # Search meta for NC_HAS_PARALLEL setting... + # Note that there is both NC_HAS_PARALLEL4 and NC_HAS_PARALLEL, only want NC_HAS_PARALLEL + # so add a space to end to avoid getting NC_HAS_PARALLEL4 + file(STRINGS "${meta_path}/netcdf_meta.h" netcdf_par_string REGEX "NC_HAS_PARALLEL ") + string(REGEX REPLACE "[^0-9]" "" netcdf_par_val "${netcdf_par_string}") + # NOTE: The line for NC_HAS_PARALLEL has an hdf5 string in it which results # netcdf_par_val being set to 05 or 15 above... - if (netcdf_par_val EQUAL 15) - set(NetCDF_PARALLEL True) + if (netcdf_par_val EQUAL 15) + set(NetCDF_PARALLEL True) endif() endif() @@ -291,8 +291,8 @@ else(NetCDF_LIBRARIES AND NetCDF_INCLUDE_DIRS) message(STATUS "\tNetCDF_ROOT is ${NetCDF_ROOT}") find_program(netcdf_config nc-config PATHS ${NetCDF_ROOT}/bin ${NetCDF_BIN_DIR} - NO_DEFAULT_PATH - NO_CMAKE_SYSTEM_PATH + NO_DEFAULT_PATH + NO_CMAKE_SYSTEM_PATH DOC "NetCDF configuration script") if (netcdf_config) diff --git a/cmake/tribits/core/installation/TribitsPackageConfigTemplate.cmake.in b/cmake/tribits/core/installation/TribitsPackageConfigTemplate.cmake.in index 15025561c294..465648d63ac3 100644 --- a/cmake/tribits/core/installation/TribitsPackageConfigTemplate.cmake.in +++ b/cmake/tribits/core/installation/TribitsPackageConfigTemplate.cmake.in @@ -45,6 +45,15 @@ # ############################################################################## +if(CMAKE_VERSION VERSION_LESS 3.3) + set(${PDOLLAR}{CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE + "${PACKAGE_NAME} requires CMake 3.3 or later for 'if (... IN_LIST ...)'" + ) + set(${PDOLLAR}{CMAKE_FIND_PACKAGE_NAME}_FOUND FALSE) + return() +endif() +cmake_minimum_required(VERSION 3.3...${TRIBITS_CMAKE_MINIMUM_REQUIRED}) + ## --------------------------------------------------------------------------- ## Compilers used by ${PROJECT_NAME}/${PACKAGE_NAME} build ## --------------------------------------------------------------------------- diff --git a/cmake/tribits/core/installation/TribitsProjectConfigTemplate.cmake.in b/cmake/tribits/core/installation/TribitsProjectConfigTemplate.cmake.in index 5872986017f6..67eb685ed483 100644 --- a/cmake/tribits/core/installation/TribitsProjectConfigTemplate.cmake.in +++ b/cmake/tribits/core/installation/TribitsProjectConfigTemplate.cmake.in @@ -46,6 +46,15 @@ # ############################################################################## +if(CMAKE_VERSION VERSION_LESS 3.3) + set(${PDOLLAR}{CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE + "${PROJECT_NAME} requires CMake 3.3 or later for 'if (... IN_LIST ...)'" + ) + set(${PDOLLAR}{CMAKE_FIND_PACKAGE_NAME}_FOUND FALSE) + return() +endif() +cmake_minimum_required(VERSION 3.3...${TRIBITS_CMAKE_MINIMUM_REQUIRED}) + ## --------------------------------------------------------------------------- ## Compilers used by ${PROJECT_NAME} build ## --------------------------------------------------------------------------- diff --git a/cmake/tribits/core/package_arch/TribitsAddAdvancedTest.cmake b/cmake/tribits/core/package_arch/TribitsAddAdvancedTest.cmake index 3f4d02a35ac4..2f4928c8f0b8 100644 --- a/cmake/tribits/core/package_arch/TribitsAddAdvancedTest.cmake +++ b/cmake/tribits/core/package_arch/TribitsAddAdvancedTest.cmake @@ -1097,7 +1097,7 @@ function(tribits_add_advanced_test TEST_NAME_IN) "" # multi_value_keywords "COPY_FILES_TO_TEST_DIR;SOURCE_DIR;DEST_DIR" - # Arguments to parse + # Arguments to parse ${PARSE_TEST_${TEST_CMND_IDX}} ) tribits_check_for_unparsed_arguments() @@ -1113,20 +1113,20 @@ function(tribits_add_advanced_test TEST_NAME_IN) # Parse TEST_ block args for types EXEC and CMND set(testBlockOptionsList NOEXEPREFIX NOEXESUFFIX NO_ECHO_OUTPUT PASS_ANY - STANDARD_PASS_OUTPUT ALWAYS_FAIL_ON_NONZERO_RETURN ALWAYS_FAIL_ON_ZERO_RETURN - WILL_FAIL ADD_DIR_TO_NAME SKIP_CLEAN_WORKING_DIRECTORY + STANDARD_PASS_OUTPUT ALWAYS_FAIL_ON_NONZERO_RETURN ALWAYS_FAIL_ON_ZERO_RETURN + WILL_FAIL ADD_DIR_TO_NAME SKIP_CLEAN_WORKING_DIRECTORY ) set(testBlockMultiValueKeywordsList EXEC CMND ARGS DIRECTORY MESSAGE - WORKING_DIRECTORY OUTPUT_FILE NUM_MPI_PROCS NUM_TOTAL_CORES_USED - PASS_REGULAR_EXPRESSION_ALL FAIL_REGULAR_EXPRESSION PASS_REGULAR_EXPRESSION - ) + WORKING_DIRECTORY OUTPUT_FILE NUM_MPI_PROCS NUM_TOTAL_CORES_USED + PASS_REGULAR_EXPRESSION_ALL FAIL_REGULAR_EXPRESSION PASS_REGULAR_EXPRESSION + ) cmake_parse_arguments( PARSE #prefix - "${testBlockOptionsList}" - "" # one_value_keywords - "${testBlockMultiValueKeywordsList}" + "${testBlockOptionsList}" + "" # one_value_keywords + "${testBlockMultiValueKeywordsList}" ${PARSE_TEST_${TEST_CMND_IDX}} ) diff --git a/cmake/tribits/core/package_arch/TribitsAddLibrary.cmake b/cmake/tribits/core/package_arch/TribitsAddLibrary.cmake index 26bddd91931e..bb32b1c3ae8e 100644 --- a/cmake/tribits/core/package_arch/TribitsAddLibrary.cmake +++ b/cmake/tribits/core/package_arch/TribitsAddLibrary.cmake @@ -701,7 +701,7 @@ function(tribits_add_library_assert_deplibs) else() message(WARNING "WARNING: The case PARSE_TESTONLY=${PARSE_TESTONLY}," " depLibAlreadyInPkgLibs=${depLibAlreadyInPkgLibs}," - " depLibIsTestOnlyLib=${depLibIsTestOnlyLib}, has" + " depLibIsTestOnlyLib=${depLibIsTestOnlyLib}, has" " not yet been handled!") endif() @@ -781,7 +781,7 @@ function(tribits_add_library_determine_install_lib_and_or_headers if (${PROJECT_NAME}_VERBOSE_CONFIGURE) message("-- " "Skipping installation of headers and libraries" " because ${PROJECT_NAME}_INSTALL_LIBRARIES_AND_HEADERS=FALSE and" - " BUILD_SHARED_LIBS=FALSE ...") + " BUILD_SHARED_LIBS=FALSE ...") endif() set(installLib OFF) set(installHeaders OFF) @@ -789,7 +789,7 @@ function(tribits_add_library_determine_install_lib_and_or_headers if (${PROJECT_NAME}_VERBOSE_CONFIGURE) message("-- " "Skipping installation of headers but installing libraries" " because ${PROJECT_NAME}_INSTALL_LIBRARIES_AND_HEADERS=FALSE and" - " BUILD_SHARED_LIBS=TRUE ...") + " BUILD_SHARED_LIBS=TRUE ...") endif() set(installHeaders OFF) endif() diff --git a/cmake/tribits/core/package_arch/TribitsAddOptionAndDefine.cmake b/cmake/tribits/core/package_arch/TribitsAddOptionAndDefine.cmake index 8c98b1717d74..ea812e51cded 100644 --- a/cmake/tribits/core/package_arch/TribitsAddOptionAndDefine.cmake +++ b/cmake/tribits/core/package_arch/TribitsAddOptionAndDefine.cmake @@ -37,12 +37,13 @@ # ************************************************************************ # @HEADER +include(TribitsPkgExportCacheVars) include(GlobalSet) # @MACRO: tribits_add_option_and_define() # -# Add an option and a define variable in one shot. +# Add an option and an optional macro define variable in one shot. # # Usage:: # @@ -58,6 +59,18 @@ include(GlobalSet) # # #cmakedefine # +# NOTE: This also calls `tribits_pkg_export_cache_var()`_ to export the +# variables ```` and ````. This also +# requires that local variables with the same names of these cache variables +# not be assigned with a different value from these cache variables. If they +# are, then an error will occur later when these variables are read. +# +# NOTE: The define var name ```` can be empty "" in which +# case all logic related to ```` is skipped. (But in this +# case, it would be better to just call:: +# +# set( CACHE BOOL "") +# macro(tribits_add_option_and_define USER_OPTION_NAME MACRO_DEFINE_NAME DOCSTRING DEFAULT_VALUE ) @@ -70,6 +83,10 @@ macro(tribits_add_option_and_define USER_OPTION_NAME MACRO_DEFINE_NAME global_set(${MACRO_DEFINE_NAME} OFF) endif() endif() + tribits_pkg_export_cache_var(${USER_OPTION_NAME}) + if(NOT ${MACRO_DEFINE_NAME} STREQUAL "") + tribits_pkg_export_cache_var(${MACRO_DEFINE_NAME}) + endif() endmacro() # 2008/10/05: rabartl: ToDo: Add an option to automatically add the macro diff --git a/cmake/tribits/core/package_arch/TribitsAddTest.cmake b/cmake/tribits/core/package_arch/TribitsAddTest.cmake index bdf500b5bf1b..5853d99a08f6 100644 --- a/cmake/tribits/core/package_arch/TribitsAddTest.cmake +++ b/cmake/tribits/core/package_arch/TribitsAddTest.cmake @@ -1009,7 +1009,7 @@ function(tribits_add_test EXE_NAME) "${EXECUTABLE_PATH}" "${PARSE_CATEGORIES}" "${NUM_PROCS_USED}" "${NUM_TOTAL_CORES_USED}" "${SET_RUN_SERIAL}" "${SET_DISABLED_AND_MSG}" ADDED_TEST_NAME ${INARGS} - "${${TEST_NAME_INSTANCE}_EXTRA_ARGS}" ) + "${${TEST_NAME_INSTANCE}_EXTRA_ARGS}" ) if(PARSE_ADDED_TESTS_NAMES_OUT AND ADDED_TEST_NAME) list(APPEND ADDED_TESTS_NAMES_OUT ${ADDED_TEST_NAME}) endif() @@ -1054,7 +1054,7 @@ function(tribits_add_test EXE_NAME) "${EXECUTABLE_PATH}" "${PARSE_CATEGORIES}" "${NUM_PROCS_USED}" "${NUM_TOTAL_CORES_USED}" "${SET_RUN_SERIAL}" "${SET_DISABLED_AND_MSG}" ADDED_TEST_NAME ${INARGS} - "${${TEST_NAME_INSTANCE}_EXTRA_ARGS}" + "${${TEST_NAME_INSTANCE}_EXTRA_ARGS}" ) if(PARSE_ADDED_TESTS_NAMES_OUT AND ADDED_TEST_NAME) list(APPEND ADDED_TESTS_NAMES_OUT ${ADDED_TEST_NAME}) diff --git a/cmake/tribits/core/package_arch/TribitsAddTestHelpers.cmake b/cmake/tribits/core/package_arch/TribitsAddTestHelpers.cmake index ca8d1b0bd9c8..453de5432672 100644 --- a/cmake/tribits/core/package_arch/TribitsAddTestHelpers.cmake +++ b/cmake/tribits/core/package_arch/TribitsAddTestHelpers.cmake @@ -658,10 +658,10 @@ function(tribits_add_test_process_skip_ctest_add_test ADD_THE_TEST_OUT) if(${PACKAGE_NAME}_SKIP_CTEST_ADD_TEST OR ${PARENT_PACKAGE_NAME}_SKIP_CTEST_ADD_TEST) if (PARENT_PACKAGE_NAME STREQUAL PACKAGE_NAME) set(DISABLE_VAR_MSG - "${PACKAGE_NAME}_SKIP_CTEST_ADD_TEST='${${PACKAGE_NAME}_SKIP_CTEST_ADD_TEST}'") + "${PACKAGE_NAME}_SKIP_CTEST_ADD_TEST='${${PACKAGE_NAME}_SKIP_CTEST_ADD_TEST}'") else() set(DISABLE_VAR_MSG - "${PARENT_PACKAGE_NAME}_SKIP_CTEST_ADD_TEST='${${PARENT_PACKAGE_NAME}_SKIP_CTEST_ADD_TEST}'") + "${PARENT_PACKAGE_NAME}_SKIP_CTEST_ADD_TEST='${${PARENT_PACKAGE_NAME}_SKIP_CTEST_ADD_TEST}'") endif() message_wrapper( "-- ${TEST_NAME}: NOT added test because ${DISABLE_VAR_MSG}!") diff --git a/cmake/tribits/core/package_arch/TribitsAdjustPackageEnables.cmake b/cmake/tribits/core/package_arch/TribitsAdjustPackageEnables.cmake index 84a5994320fa..5b26b5549781 100644 --- a/cmake/tribits/core/package_arch/TribitsAdjustPackageEnables.cmake +++ b/cmake/tribits/core/package_arch/TribitsAdjustPackageEnables.cmake @@ -999,10 +999,59 @@ macro(tribits_apply_test_example_enables PACKAGE_NAME) endmacro() -# Macro to set ${TRIBITS_SUBPACKAGE)_ENABLE_TESTS and +# Macro to disable ${PARENT_PACKAGE_NAME)_ENABLE_ENABLES by default if +# ${PARENT_PACKAGE_NAME)_ENABLE_TESTS is explicitly disabled. +# +macro(tribits_apply_package_examples_disable PARENT_PACKAGE_NAME) + if (NOT "${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}" STREQUAL "" + AND NOT ${PARENT_PACKAGE_NAME}_ENABLE_TESTS + AND "${${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES}" STREQUAL "" + ) + message("-- " "Setting" + " ${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES" + "=${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}" + " because" + " ${PARENT_PACKAGE_NAME}_ENABLE_TESTS" + "=${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}" ) + set(${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES ${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}) + endif() +endmacro() +# NOTE: Above, the top-level package ${PARENT_PACKAGE_NAME} may not even be +# enabled yet when this gets called but its subpackages might and we need to +# process this default disable in case their are any enabled subpackages. + + +# Macro to disable ${TRIBITS_SUBPACKAGE)_ENABLE_TESTS and # ${TRIBITS_SUBPACKAGE)_ENABLE_EXAMPLES based on # ${TRIBITS_PARENTPACKAGE)_ENABLE_TESTS or # ${TRIBITS_PARENTPACKAGE)_ENABLE_EXAMPLES +# +macro(tribits_apply_subpackage_tests_or_examples_disables PARENT_PACKAGE_NAME + TESTS_OR_EXAMPLES + ) + set(parentPkgEnableVar ${PARENT_PACKAGE_NAME}_ENABLE_${TESTS_OR_EXAMPLES}) + if (NOT "${${parentPkgEnableVar}}" STREQUAL "" AND NOT ${parentPkgEnableVar}) + foreach(spkg IN LISTS ${PARENT_PACKAGE_NAME}_SUBPACKAGES) + set(fullSpkgName ${PARENT_PACKAGE_NAME}${spkg}) + if (${PROJECT_NAME}_ENABLE_${fullSpkgName} AND NOT ${parentPkgEnableVar}) + if ("${${fullSpkgName}_ENABLE_${TESTS_OR_EXAMPLES}}" STREQUAL "") + message("-- " "Setting" + " ${fullSpkgName}_ENABLE_${TESTS_OR_EXAMPLES}=${${parentPkgEnableVar}}" + " because parent package" + " ${parentPkgEnableVar}=${${parentPkgEnableVar}}") + set(${fullSpkgName}_ENABLE_${TESTS_OR_EXAMPLES} ${${parentPkgEnableVar}}) + endif() + endif() + endforeach() + endif() +endmacro() + + +# Macro to enable ${TRIBITS_SUBPACKAGE)_ENABLE_TESTS and +# ${TRIBITS_SUBPACKAGE)_ENABLE_EXAMPLES based on +# ${TRIBITS_PARENTPACKAGE)_ENABLE_TESTS or +# ${TRIBITS_PARENTPACKAGE)_ENABLE_EXAMPLES +# macro(tribits_apply_subpackage_tests_examples_enables PARENT_PACKAGE_NAME) if ("${${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES}" STREQUAL "" AND ${PARENT_PACKAGE_NAME}_ENABLE_TESTS @@ -1019,20 +1068,20 @@ macro(tribits_apply_subpackage_tests_examples_enables PARENT_PACKAGE_NAME) if (${PARENT_PACKAGE_NAME}_ENABLE_TESTS) if ("${${fullSpkgName}_ENABLE_TESTS}" STREQUAL "") message("-- " "Setting" - " ${fullSpkgName}_ENABLE_TESTS=${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}" - " because parent package" - " ${PARENT_PACKAGE_NAME}_ENABLE_TESTS" - "=${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}") + " ${fullSpkgName}_ENABLE_TESTS=${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}" + " because parent package" + " ${PARENT_PACKAGE_NAME}_ENABLE_TESTS" + "=${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}") set(${fullSpkgName}_ENABLE_TESTS ${${PARENT_PACKAGE_NAME}_ENABLE_TESTS}) endif() endif() if (${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES) if ("${${fullSpkgName}_ENABLE_EXAMPLES}" STREQUAL "") message("-- " "Setting" - " ${fullSpkgName}_ENABLE_EXAMPLES=${${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES}" - " because parent package" - " ${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES" - "=${${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES}") + " ${fullSpkgName}_ENABLE_EXAMPLES=${${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES}" + " because parent package" + " ${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES" + "=${${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES}") set(${fullSpkgName}_ENABLE_EXAMPLES ${${PARENT_PACKAGE_NAME}_ENABLE_EXAMPLES}) endif() endif() @@ -1399,9 +1448,18 @@ macro(tribits_adjust_package_enables) ${PROJECT_NAME}_ENABLED_SE_PACKAGES "") # - # C) Enable tests for currently enabled SE packages + # C) Disable and enable tests for currently enabled SE packages # + message("") + message("Disabling subpackage tests/examples based on parent package tests/examples disables ...") + message("") + foreach(TRIBITS_PACKAGE ${${PROJECT_NAME}_PACKAGES}) + tribits_apply_package_examples_disable(${TRIBITS_PACKAGE} TESTS) + tribits_apply_subpackage_tests_or_examples_disables(${TRIBITS_PACKAGE} TESTS) + tribits_apply_subpackage_tests_or_examples_disables(${TRIBITS_PACKAGE} EXAMPLES) + endforeach() + if (${PROJECT_NAME}_ENABLE_TESTS OR ${PROJECT_NAME}_ENABLE_EXAMPLES) message("") message("Enabling all tests and/or examples that have not been" diff --git a/cmake/tribits/core/package_arch/TribitsExternalPackageWriteConfigFile.cmake b/cmake/tribits/core/package_arch/TribitsExternalPackageWriteConfigFile.cmake index 3c1c6763d4f5..2d041a2ca424 100644 --- a/cmake/tribits/core/package_arch/TribitsExternalPackageWriteConfigFile.cmake +++ b/cmake/tribits/core/package_arch/TribitsExternalPackageWriteConfigFile.cmake @@ -328,7 +328,7 @@ function(tribits_extpkg_add_find_upstream_dependencies_str ) foreach (upstreamTplDepEntry IN LISTS ${tplName}_LIB_ENABLED_DEPENDENCIES) tribits_extpkg_get_dep_name_and_vis( - "${upstreamTplDepEntry}" upstreamTplDepName upstreamTplDepVis) + "${upstreamTplDepEntry}" upstreamTplDepName upstreamTplDepVis) if ("${${upstreamTplDepName}_DIR}" STREQUAL "") message(FATAL_ERROR "ERROR: ${upstreamTplDepName}_DIR is empty!") endif() @@ -336,9 +336,9 @@ function(tribits_extpkg_add_find_upstream_dependencies_str "if (NOT TARGET ${upstreamTplDepName}::all_libs)\n" " set(${upstreamTplDepName}_DIR \"\${CMAKE_CURRENT_LIST_DIR}/../${upstreamTplDepName}\")\n" " find_dependency(${upstreamTplDepName} REQUIRED CONFIG \${${tplName}_SearchNoOtherPathsArgs})\n" - " unset(${upstreamTplDepName}_DIR)\n" + " unset(${upstreamTplDepName}_DIR)\n" "endif()\n" - "\n" + "\n" ) endforeach() string(APPEND configFileFragStr @@ -705,7 +705,7 @@ function(tribits_extpkg_append_upstream_target_link_libraries_str "target_link_libraries(${prefix_libname}\n") foreach (upstreamTplDepEntry IN LISTS ${tplName}_LIB_ENABLED_DEPENDENCIES) tribits_extpkg_get_dep_name_and_vis( - "${upstreamTplDepEntry}" upstreamTplDepName upstreamTplDepVis) + "${upstreamTplDepEntry}" upstreamTplDepName upstreamTplDepVis) if (upstreamTplDepVis STREQUAL "PUBLIC") string(APPEND configFileStr " INTERFACE ${upstreamTplDepName}::all_libs # i.e. PUBLIC\n") diff --git a/cmake/tribits/core/package_arch/TribitsGlobalMacros.cmake b/cmake/tribits/core/package_arch/TribitsGlobalMacros.cmake index ca763c91480b..c269cf63b9ae 100644 --- a/cmake/tribits/core/package_arch/TribitsGlobalMacros.cmake +++ b/cmake/tribits/core/package_arch/TribitsGlobalMacros.cmake @@ -204,12 +204,12 @@ function(assert_project_set_group_and_permissions_on_install_base_dir) "***\n" "*** ERROR in ${PROJECT_NAME}_SET_GROUP_AND_PERMISSIONS_ON_INSTALL_BASE_DIR!\n" "***\n" - "\n" - "${PROJECT_NAME}_SET_GROUP_AND_PERMISSIONS_ON_INSTALL_BASE_DIR=${${PROJECT_NAME}_SET_GROUP_AND_PERMISSIONS_ON_INSTALL_BASE_DIR}\n" + "\n" + "${PROJECT_NAME}_SET_GROUP_AND_PERMISSIONS_ON_INSTALL_BASE_DIR=${${PROJECT_NAME}_SET_GROUP_AND_PERMISSIONS_ON_INSTALL_BASE_DIR}\n" "\n" "is not a strict base dir of:\n" - "\n" - "CMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}\n" + "\n" + "CMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}\n" "\n" "Either remove ${PROJECT_NAME}_SET_GROUP_AND_PERMISSIONS_ON_INSTALL_BASE_DIR from the cache or set it to be a base dir of CMAKE_INSTALL_PREFIX!\n" "\n" @@ -2079,7 +2079,7 @@ macro(tribits_configure_enabled_packages) endif() else() set(${TRIBITS_PACKAGE}_BINARY_DIR - ${CMAKE_CURRENT_BINARY_DIR}/${${TRIBITS_PACKAGE}_REL_SOURCE_DIR}) + ${CMAKE_CURRENT_BINARY_DIR}/${${TRIBITS_PACKAGE}_REL_SOURCE_DIR}) endif() if (${PROJECT_NAME}_VERBOSE_CONFIGURE) print_var(${TRIBITS_PACKAGE}_BINARY_DIR) @@ -2139,7 +2139,7 @@ macro(tribits_configure_enabled_packages) "${TRIBITS_PACKAGE_CMAKELIST_FILE}") if (NOT ${TRIBITS_PACKAGE}_SOURCE_DIR STREQUAL ${PROJECT_NAME}_SOURCE_DIR) add_subdirectory(${${TRIBITS_PACKAGE}_SOURCE_DIR} ${${TRIBITS_PACKAGE}_BINARY_DIR}) - else() + else() include("${TRIBITS_PACKAGE_CMAKELIST_FILE}") endif() if (NOT ${PACKAGE_NAME}_TRIBITS_PACKAGE_POSTPROCESS) @@ -2307,15 +2307,15 @@ macro(tribits_setup_packaging_and_distribution) # package has to have this file to work correctly it should be # guaranteed to be there. string(REGEX MATCH "[.][.]/" RELATIVE_PATH_CHARS_MATCH - ${${TRIBITS_PACKAGE}_REL_SOURCE_DIR}) + ${${TRIBITS_PACKAGE}_REL_SOURCE_DIR}) if ("${RELATIVE_PATH_CHARS_MATCH}" STREQUAL "") set(CPACK_SOURCE_IGNORE_FILES - "${PROJECT_SOURCE_DIR}/${${TRIBITS_PACKAGE}_REL_SOURCE_DIR}/" + "${PROJECT_SOURCE_DIR}/${${TRIBITS_PACKAGE}_REL_SOURCE_DIR}/" ${CPACK_SOURCE_IGNORE_FILES}) else() find_path(ABSOLUTE_PATH CMakeLists.txt PATHS "${PROJECT_SOURCE_DIR}/${${TRIBITS_PACKAGE}_REL_SOURCE_DIR}" - NO_DEFAULT_PATH) + NO_DEFAULT_PATH) if ("${ABSOLUTE_PATH}" STREQUAL "ABSOLUTE_PATH-NOTFOUND") message(AUTHOR_WARNING "Relative path found for disabled package" " ${TRIBITS_PACKAGE} but package was missing a CMakeLists.txt file." diff --git a/cmake/tribits/core/package_arch/TribitsPackageMacros.cmake b/cmake/tribits/core/package_arch/TribitsPackageMacros.cmake index 99191e850611..62d3791e1b67 100644 --- a/cmake/tribits/core/package_arch/TribitsPackageMacros.cmake +++ b/cmake/tribits/core/package_arch/TribitsPackageMacros.cmake @@ -52,6 +52,7 @@ include(RemoveGlobalDuplicates) include(TribitsGatherBuildTargets) include(TribitsAddOptionAndDefine) +include(TribitsPkgExportCacheVars) include(TribitsLibraryMacros) include(TribitsAddExecutable) include(TribitsAddExecutableAndTest) @@ -180,25 +181,7 @@ macro(tribits_package_decl PACKAGE_NAME_IN) message("\nTRIBITS_PACKAGE_DECL: ${PACKAGE_NAME_IN}") endif() - if (CURRENTLY_PROCESSING_SUBPACKAGE) - tribits_report_invalid_tribits_usage( - "Cannot call tribits_package_decl() in a subpackage." - " Use tribits_subpackage() instead" - " error in ${CURRENT_SUBPACKAGE_CMAKELIST_FILE}") - endif() - - if(${PACKAGE_NAME}_TRIBITS_PACKAGE_DECL_CALLED) - tribits_report_invalid_tribits_usage( - "tribits_package_decl() called more than once in Package ${PACKAGE_NAME}" - " This may be because tribits_package_decl() was explicitly called more than once or" - " TRIBITS_PACKAGE_DECL was called after TRIBITS_PACKAGE. You do not need both." - " If your package has subpackages then do not call tribits_package() instead call:" - " tribits_pacakge_decl() then tribits_process_subpackages() then tribits package_def()" - ) - endif() - - # Set flag to check that macros are called in the correct order - set(${PACKAGE_NAME}_TRIBITS_PACKAGE_DECL_CALLED TRUE) + tribits_package_decl_assert_call_context() # # A) Parse the input arguments @@ -236,6 +219,7 @@ macro(tribits_package_decl PACKAGE_NAME_IN) # tribits_set_common_vars(${PACKAGE_NAME_IN}) + tribits_pkg_init_exported_vars(${PACKAGE_NAME_IN}) set(${PACKAGE_NAME_IN}_DISABLE_STRONG_WARNINGS OFF CACHE BOOL @@ -268,6 +252,31 @@ macro(tribits_package_decl PACKAGE_NAME_IN) endmacro() +macro(tribits_package_decl_assert_call_context) + + if (CURRENTLY_PROCESSING_SUBPACKAGE) + tribits_report_invalid_tribits_usage( + "Cannot call tribits_package_decl() in a subpackage." + " Use tribits_subpackage() instead" + " error in ${CURRENT_SUBPACKAGE_CMAKELIST_FILE}") + endif() + + if(${PACKAGE_NAME}_TRIBITS_PACKAGE_DECL_CALLED) + tribits_report_invalid_tribits_usage( + "tribits_package_decl() called more than once in Package ${PACKAGE_NAME}" + " This may be because tribits_package_decl() was explicitly called more than once or" + " TRIBITS_PACKAGE_DECL was called after TRIBITS_PACKAGE. You do not need both." + " If your package has subpackages then do not call tribits_package() instead call:" + " tribits_pacakge_decl() then tribits_process_subpackages() then tribits package_def()" + ) + endif() + + # Set flag to check that macros are called in the correct order + set(${PACKAGE_NAME}_TRIBITS_PACKAGE_DECL_CALLED TRUE) + +endmacro() + + # @MACRO: tribits_package_def() # # Macro called in `/CMakeLists.txt`_ after subpackages are @@ -291,6 +300,30 @@ endmacro() # macro(tribits_package_def) + if (${PROJECT_NAME}_VERBOSE_CONFIGURE) + message("\nTRIBITS_PACKAGE_DEF: ${PACKAGE_NAME}") + endif() + + tribits_package_def_assert_call_context() + + if (NOT ${PROJECT_NAME}_ENABLE_${PACKAGE_NAME}) + if (${PROJECT_NAME}_VERBOSE_CONFIGURE) + message("\n${PACKAGE_NAME} not enabled so exiting package processing") + endif() + return() + endif() + + # Reset in case were changed by subpackages + tribits_set_common_vars(${PACKAGE_NAME}) + + # Define package linkage variables + tribits_define_linkage_vars(${PACKAGE_NAME}) + +endmacro() + + +macro(tribits_package_def_assert_call_context) + # check that this is not being called from a subpackage if(NOT ${SUBPACKAGE_FULLNAME}_TRIBITS_SUBPACKAGE_POSTPROCESS_CALLED) if (CURRENTLY_PROCESSING_SUBPACKAGE) @@ -311,23 +344,6 @@ macro(tribits_package_def) "${CURRENT_SUBPACKAGE_CMAKELIST_FILE}") endif() - if (${PROJECT_NAME}_VERBOSE_CONFIGURE) - message("\nTRIBITS_PACKAGE_DEF: ${PACKAGE_NAME}") - endif() - - if (NOT ${PROJECT_NAME}_ENABLE_${PACKAGE_NAME}) - if (${PROJECT_NAME}_VERBOSE_CONFIGURE) - message("\n${PACKAGE_NAME} not enabled so exiting package processing") - endif() - return() - endif() - - # Reset in case were changed by subpackages - tribits_set_common_vars(${PACKAGE_NAME}) - - # Define package linkage variables - tribits_define_linkage_vars(${PACKAGE_NAME}) - set(${PACKAGE_NAME}_TRIBITS_PACKAGE_DEF_CALLED TRUE) endmacro() @@ -353,6 +369,13 @@ endmacro() # side-effects (and variables set) after calling this macro. # macro(tribits_package PACKAGE_NAME_IN) + tribits_package_assert_call_context() + tribits_package_decl(${PACKAGE_NAME_IN} ${ARGN}) + tribits_package_def() +endmacro() + + +macro(tribits_package_assert_call_context) if (CURRENTLY_PROCESSING_SUBPACKAGE) if (NOT ${SUBPACKAGE_FULLNAME}_TRIBITS_SUBPACKAGE_POSTPROCESS_CALLED) @@ -381,8 +404,6 @@ macro(tribits_package PACKAGE_NAME_IN) set(${PACKAGE_NAME}_TRIBITS_PACKAGE_CALLED TRUE) - tribits_package_decl(${PACKAGE_NAME_IN} ${ARGN}) - tribits_package_def() endmacro() @@ -440,6 +461,9 @@ endmacro() # typically called in the package's `/CMakeLists.txt`_ file (see # the example ``SimpleCxx/CMakeLists.txt``). # +# NOTE: This also calls `tribits_pkg_export_cache_var()`_ to export the +# variable ``${PACKAGE_NAME}_ENABLE_DEBUG``. +# macro(tribits_add_debug_option) tribits_add_option_and_define( ${PACKAGE_NAME}_ENABLE_DEBUG @@ -737,7 +761,7 @@ macro(tribits_package_postprocess) NOT ${PACKAGE_NAME}_TRIBITS_PROCESS_SUBPACKAGES_CALLED ) tribits_report_invalid_tribits_usage( - "Must call tribits_package_decl(), tribits_process_subpackages()" + "Must call tribits_package_decl(), tribits_process_subpackages()" " and tribits_package_def() before tribits_package_postprocess()." " Because this package has subpackages you cannot use tribits_package()" " you must call these in the following order:" @@ -755,16 +779,16 @@ macro(tribits_package_postprocess) # This is a package without subpackages if ( - (NOT ${PACKAGE_NAME}_TRIBITS_PACKAGE_CALLED) - AND - (NOT ${PACKAGE_NAME}_TRIBITS_PACKAGE_DEF_CALLED) + (NOT ${PACKAGE_NAME}_TRIBITS_PACKAGE_CALLED) + AND + (NOT ${PACKAGE_NAME}_TRIBITS_PACKAGE_DEF_CALLED) ) tribits_report_invalid_tribits_usage( "Must call tribits_package() or tribits_package_def() before" - " tribits_package_postprocess()" - " at the top of the file:\n" - " ${TRIBITS_PACKAGE_CMAKELIST_FILE}" - ) + " tribits_package_postprocess()" + " at the top of the file:\n" + " ${TRIBITS_PACKAGE_CMAKELIST_FILE}" + ) endif() endif() diff --git a/cmake/tribits/core/package_arch/TribitsPkgExportCacheVars.cmake b/cmake/tribits/core/package_arch/TribitsPkgExportCacheVars.cmake new file mode 100644 index 000000000000..6fb85e0ef8de --- /dev/null +++ b/cmake/tribits/core/package_arch/TribitsPkgExportCacheVars.cmake @@ -0,0 +1,130 @@ +# @HEADER +# ************************************************************************ +# +# TriBITS: Tribal Build, Integrate, and Test System +# Copyright 2013 Sandia Corporation +# +# Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, +# the U.S. Government retains certain rights in this software. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the Corporation nor the names of the +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ************************************************************************ +# @HEADER + + +# @MACRO: tribits_pkg_export_cache_var() +# +# Macro that registers a package-level cache var to be exported in the +# ``Config.cmake`` file +# +# Usage:: +# +# tribits_pkg_export_cache_var() +# +# where ```` must be the name of a cache variable (or an error +# will occur). +# +# NOTE: This will also export this variable to the +# ``Config.cmake`` file for every enabled subpackage (if this +# is called from a ``CMakeLists.txt`` file of a top-level package that has +# subpackages). That way, any top-level package cache vars are provided by +# any of the subpackages' ``Config.cmake`` files. +# +macro(tribits_pkg_export_cache_var cacheVarName) + if (DEFINED ${PACKAGE_NAME}_PKG_VARS_TO_EXPORT) + # Assert this is a cache var + get_property(cacheVarIsCacheVar CACHE ${cacheVarName} PROPERTY VALUE SET) + if (NOT cacheVarIsCacheVar) + message(SEND_ERROR + "ERROR: The variable ${cacheVarName} is NOT a cache var and cannot" + " be exported!") + endif() + # Add to the list of package cache vars to export + append_global_set(${PACKAGE_NAME}_PKG_VARS_TO_EXPORT + ${cacheVarName}) + endif() +endmacro() + + +# @MACRO: tribits_assert_cache_and_local_vars_same_value() +# +# Asset that a cache variable and a possible local variable (if it exists) +# have the same value. +# +# Usage:: +# +# tribits_assert_cache_and_local_vars_same_value() +# +# If the local var ```` and the cache var ```` +# both exist but have different values, then ``message(SEND_ERROR ...)`` is +# called with an informative error message. +# +macro(tribits_assert_cache_and_local_vars_same_value cacheVarName) + set(cacheVarValue "$CACHE{${cacheVarName}}") + set(localValue "${${cacheVarName}}") + if (NOT localValue STREQUAL cacheVarValue) + message_wrapper(SEND_ERROR "ERROR: The cache variable ${cacheVarName} with the" + " cache var value '${cacheVarValue}' is not the same value as the local" + " variable ${cacheVarName} with value '${localValue}'!") + endif() +endmacro() + + +# Function that sets up data-structures for package-level cache var to be +# exported +# +function(tribits_pkg_init_exported_vars PACKAGE_NAME_IN) + global_set(${PACKAGE_NAME_IN}_PKG_VARS_TO_EXPORT "") +endfunction() + + +# Function that injects set() statements for a package's exported cache vars into +# a string. +# +# This is used to create set() statements to be injected into a package's +# ``Config.cmake`` file. +# +function(tribits_pkg_append_set_commands_for_exported_vars packageName + configFileStrInOut + ) + set(configFileStr "${${configFileStrInOut}}") + if (NOT "${${packageName}_PARENT_PACKAGE}" STREQUAL "") + foreach(exportedCacheVar IN LISTS ${${packageName}_PARENT_PACKAGE}_PKG_VARS_TO_EXPORT) + tribits_assert_cache_and_local_vars_same_value(${exportedCacheVar}) + string(APPEND configFileStr + "set(${exportedCacheVar} \"${${exportedCacheVar}}\")\n") + endforeach() + endif() + foreach(exportedCacheVar IN LISTS ${packageName}_PKG_VARS_TO_EXPORT) + tribits_assert_cache_and_local_vars_same_value(${exportedCacheVar}) + string(APPEND configFileStr + "set(${exportedCacheVar} \"${${exportedCacheVar}}\")\n") + endforeach() + set(${configFileStrInOut} "${configFileStr}" PARENT_SCOPE) +endfunction() diff --git a/cmake/tribits/core/package_arch/TribitsSubPackageMacros.cmake b/cmake/tribits/core/package_arch/TribitsSubPackageMacros.cmake index 60a7c668e9a5..27383b1907b6 100644 --- a/cmake/tribits/core/package_arch/TribitsSubPackageMacros.cmake +++ b/cmake/tribits/core/package_arch/TribitsSubPackageMacros.cmake @@ -88,6 +88,7 @@ macro(tribits_subpackage SUBPACKAGE_NAME_IN) # Now override the package-like variables tribits_set_common_vars(${SUBPACKAGE_FULLNAME}) tribits_define_linkage_vars(${SUBPACKAGE_FULLNAME}) + tribits_pkg_init_exported_vars(${SUBPACKAGE_FULLNAME}) tribits_append_package_specific_compiler_flags() if(${PROJECT_NAME}_VERBOSE_CONFIGURE) @@ -123,23 +124,22 @@ function(tribits_subpackage_assert_call_context) if(${SUBPACKAGE_FULLNAME}_TRIBITS_SUBPACKAGE_CALLED) tribits_report_invalid_tribits_usage( "Already called tribits_subpackge() for the" - " ${PARENT_PACKAGE_NAME} subpackage ${TRIBITS_SUBPACKAGE}") + " ${PARENT_PACKAGE_NAME} subpackage ${TRIBITS_SUBPACKAGE}") endif() # make sure the name in the macro call matches the name in the packages cmake file if (NOT ${SUBPACKAGE_NAME_IN} STREQUAL ${SUBPACKAGE_NAME}) tribits_report_invalid_tribits_usage( "Error, the package-defined subpackage name" - " '${SUBPACKAGE_NAME_IN}' is not the same as the subpackage name" - " '${SUBPACKAGE_NAME}' defined in the parent packages's" - " Dependencies.cmake file") + " '${SUBPACKAGE_NAME_IN}' is not the same as the subpackage name" + " '${SUBPACKAGE_NAME}' defined in the parent packages's" + " Dependencies.cmake file") endif() endif() endfunction() - # @MACRO: tribits_subpackage_postprocess() # # Macro that performs standard post-processing after defining a `TriBITS @@ -158,20 +158,22 @@ endfunction() # this macro but limitations of the CMake language make it necessary to do so. # macro(tribits_subpackage_postprocess) + tribits_subpackage_postprocess_assert_call_context() + tribits_package_postprocess_common() +endmacro() - # check that this is not being called from a package - if (NOT CURRENTLY_PROCESSING_SUBPACKAGE) - # This is being called from a package +macro(tribits_subpackage_postprocess_assert_call_context) + # check that this is not being called from a package + if (NOT CURRENTLY_PROCESSING_SUBPACKAGE) + # This is being called from a package tribits_report_invalid_tribits_usage( "Cannot call tribits_subpackage_postprocess() from a package." " Use tribits_package_postprocess() instead" " ${CURRENT_PACKAGE_CMAKELIST_FILE}") - else() - # This is being caleld from a subpackage - + # This is being called from a subpackage # check to make sure this has not already been called if (${SUBPACKAGE_FULLNAME}_TRIBITS_SUBPACKAGE_POSTPROCESS_CALLED) tribits_report_invalid_tribits_usage( @@ -185,12 +187,9 @@ macro(tribits_subpackage_postprocess) "tribits_subpackage() must be called before tribits_subpackage_postprocess()" " for the ${PARENT_PACKAGE_NAME} subpackage ${TRIBITS_SUBPACKAGE}") endif() - endif() # Set flags that are used to check that macros are called in the correct order dual_scope_set(${SUBPACKAGE_FULLNAME}_TRIBITS_SUBPACKAGE_POSTPROCESS_CALLED TRUE) - tribits_package_postprocess_common() - endmacro() diff --git a/cmake/tribits/core/package_arch/TribitsTplFindIncludeDirsAndLibraries.cmake b/cmake/tribits/core/package_arch/TribitsTplFindIncludeDirsAndLibraries.cmake index 987aaf8dd3b9..8a5eeb35f686 100644 --- a/cmake/tribits/core/package_arch/TribitsTplFindIncludeDirsAndLibraries.cmake +++ b/cmake/tribits/core/package_arch/TribitsTplFindIncludeDirsAndLibraries.cmake @@ -505,7 +505,7 @@ function(tribits_tpl_find_include_dirs_and_libraries TPL_NAME) "-- ${LIB_NOT_FOUND_MSG_PREFIX} Did not find a lib in the lib set \"${LIBNAME_SET}\"" " for the TPL '${TPL_NAME}'!") if (MUST_FIND_ALL_LIBS) - set(_${TPL_NAME}_ENABLE_SUCCESS FALSE) + set(_${TPL_NAME}_ENABLE_SUCCESS FALSE) else() break() endif() diff --git a/cmake/tribits/core/package_arch/TribitsWriteClientExportFiles.cmake b/cmake/tribits/core/package_arch/TribitsWriteClientExportFiles.cmake index 9b1f967af96b..cf3e9ef2f37b 100644 --- a/cmake/tribits/core/package_arch/TribitsWriteClientExportFiles.cmake +++ b/cmake/tribits/core/package_arch/TribitsWriteClientExportFiles.cmake @@ -38,6 +38,7 @@ # @HEADER include(TribitsGeneralMacros) +include(TribitsPkgExportCacheVars) ### ### WARNING: See "NOTES TO DEVELOPERS" at the bottom of the file @@ -545,8 +546,7 @@ function(tribits_append_dependent_package_config_file_includes_and_enables packa # Parse input cmake_parse_arguments( - PARSE #prefix - "" #options + PARSE "" # prefix, options #one_value_keywords "EXPORT_FILE_VAR_PREFIX;EXT_PKG_CONFIG_FILE_BASE_DIR;PKG_CONFIG_FILE_BASE_DIR;CONFIG_FILE_STR_INOUT" "" #multi_value_keywords @@ -577,6 +577,11 @@ function(tribits_append_dependent_package_config_file_includes_and_enables packa "set(${EXPORT_FILE_VAR_PREFIX}_ENABLE_${depPkg} ${enableVal})\n") endforeach() + # Put in set() statements for exported cache vars + string(APPEND configFileStr + "\n# Exported cache variables\n") + tribits_pkg_append_set_commands_for_exported_vars(${packageName} configFileStr) + # Include configurations of dependent packages string(APPEND configFileStr "\n# Include configuration of dependent packages\n") @@ -876,7 +881,7 @@ include(\"${${TRIBITS_PACKAGE}_BINARY_DIR}/${TRIBITS_PACKAGE}Config.cmake\")") set(TRIBITS_PROJECT_INSTALL_INCLUDE_DIR "${${PROJECT_NAME}_INSTALL_INCLUDE_DIR}") else() set(TRIBITS_PROJECT_INSTALL_INCLUDE_DIR - "${CMAKE_INSTALL_PREFIX}/${${PROJECT_NAME}_INSTALL_INCLUDE_DIR}") + "${CMAKE_INSTALL_PREFIX}/${${PROJECT_NAME}_INSTALL_INCLUDE_DIR}") endif() configure_file( diff --git a/cmake/tribits/ctest_driver/TribitsAddDashboardTarget.cmake b/cmake/tribits/ctest_driver/TribitsAddDashboardTarget.cmake index a9a8b6992f21..723725deb3d7 100644 --- a/cmake/tribits/ctest_driver/TribitsAddDashboardTarget.cmake +++ b/cmake/tribits/ctest_driver/TribitsAddDashboardTarget.cmake @@ -196,7 +196,7 @@ macro(tribits_add_dashboard_target) # NOTE: Above, if ${PROJECT_NAME}_ENABLE_ALL_PACKAGES was set in CMakeCache.txt, then setting # -D${PROJECT_NAME}_ENABLE_ALL_PACKAGES:BOOL=OFF will turn it off in the cache. Note that it will # never be turned on again which means that the list of packages will be set explicitly below. - ) + ) set(DASHBOARD_TARGET_CTEST_DRIVER_CMND_NUM "B) ") @@ -226,7 +226,7 @@ macro(tribits_add_dashboard_target) COMMAND echo COMMAND echo "See the results at http://${CTEST_DROP_SITE}${CTEST_DROP_LOCATION}&display=project\#Experimental" COMMAND echo - ) + ) endif() diff --git a/cmake/tribits/ctest_driver/TribitsCTestDriverCore.cmake b/cmake/tribits/ctest_driver/TribitsCTestDriverCore.cmake index da12e79f96fe..33f2c886b88a 100644 --- a/cmake/tribits/ctest_driver/TribitsCTestDriverCore.cmake +++ b/cmake/tribits/ctest_driver/TribitsCTestDriverCore.cmake @@ -2128,18 +2128,18 @@ function(tribits_ctest_driver) if (EXISTS "${CTEST_TESTING_TAG_FILE}") file(READ "${CTEST_TESTING_TAG_FILE}" TAG_FILE_CONTENTS_STR) message( - "\nPrevious file:" - "\n" - "\n '${CTEST_TESTING_TAG_FILE}'" - "\n" - "\nexists with contents:\n" - "\n" - "${TAG_FILE_CONTENTS_STR}\n") + "\nPrevious file:" + "\n" + "\n '${CTEST_TESTING_TAG_FILE}'" + "\n" + "\nexists with contents:\n" + "\n" + "${TAG_FILE_CONTENTS_STR}\n") else() message(FATAL_ERROR - "ERROR: Previous file '${CTEST_TESTING_TAG_FILE}' does NOT exist!" - " A previous ctest_start() was not called. Please call again" - " this time setting CTEST_DO_NEW_START=TRUE") + "ERROR: Previous file '${CTEST_TESTING_TAG_FILE}' does NOT exist!" + " A previous ctest_start() was not called. Please call again" + " this time setting CTEST_DO_NEW_START=TRUE") endif() list(APPEND CTEST_START_ARGS APPEND) diff --git a/cmake/tribits/ctest_driver/TribitsCTestDriverCoreHelpers.cmake b/cmake/tribits/ctest_driver/TribitsCTestDriverCoreHelpers.cmake index f54285bb8c1f..1d62d292e8ea 100644 --- a/cmake/tribits/ctest_driver/TribitsCTestDriverCoreHelpers.cmake +++ b/cmake/tribits/ctest_driver/TribitsCTestDriverCoreHelpers.cmake @@ -50,7 +50,7 @@ macro(extrarepo_execute_process_wrapper) if (NOT EXTRAREPO_EXECUTE_PROCESS_WRAPPER_RTN_VAL STREQUAL "0") message(SEND_ERROR "Error: execute_process(${ARGN}) returned" - " '${EXTRAREPO_EXECUTE_PROCESS_WRAPPER_RTN_VAL}'") + " '${EXTRAREPO_EXECUTE_PROCESS_WRAPPER_RTN_VAL}'") endif() else() message("execute_process(${ARGN})") @@ -136,12 +136,12 @@ function(tribits_clone_or_update_extrarepo EXTRAREPO_NAME_IN EXTRAREPO_DIR_IN set(CLONE_CMND_ARGS COMMAND "${GIT_EXECUTABLE}" clone ${CHECKOUT_BRANCH_ARG} -o ${${PROJECT_NAME}_GIT_REPOSITORY_REMOTE} - "${EXTRAREPO_REPOURL}" ${EXTRAREPO_DIR_IN} + "${EXTRAREPO_REPOURL}" ${EXTRAREPO_DIR_IN} WORKING_DIRECTORY "${${PROJECT_NAME}_SOURCE_DIRECTORY}" OUTPUT_FILE "${EXTRAREPO_CLONE_OUT_FILE}" ) else() message(SEND_ERROR - "Error, Invalid EXTRAREPO_REPOTYPE_IN='${EXTRAREPO_REPOTYPE_IN}'!") + "Error, Invalid EXTRAREPO_REPOTYPE_IN='${EXTRAREPO_REPOTYPE_IN}'!") endif() # Do the clone @@ -508,7 +508,7 @@ macro(enable_only_modified_packages) if (${PROJECT_NAME}_ENABLE_ALL_PACKAGES) if (NOT ${PROJECT_NAME}_CTEST_DO_ALL_AT_ONCE) message(FATAL_ERROR - "Error, failing 'ALL_PACKAGES' only allowed with all-at-once mode!") + "Error, failing 'ALL_PACKAGES' only allowed with all-at-once mode!") endif() message("\nDirectly modified or failing non-disabled packages that need" " to be tested: ALL_PACKAGES") @@ -954,7 +954,7 @@ macro(tribits_ctest_package_by_package) if (CTEST_DEPENDENCY_HANDLING_UNIT_TESTING) message("${TRIBITS_PACKAGE}: Skipping configure due" - " to running in unit testing mode!") + " to running in unit testing mode!") else() @@ -1001,7 +1001,7 @@ macro(tribits_ctest_package_by_package) if (NOT CTEST_DO_CONFIGURE AND CTEST_DO_SUBMIT) message("${TRIBITS_PACKAGE}: Skipping submitting configure" - " and notes due to CTEST_DO_CONFIGURE='${CTEST_DO_CONFIGURE}'!") + " and notes due to CTEST_DO_CONFIGURE='${CTEST_DO_CONFIGURE}'!") elseif (CTEST_DO_SUBMIT) message("\nSubmitting configure and notes ...") tribits_ctest_submit( PARTS configure notes ) @@ -1025,7 +1025,7 @@ macro(tribits_ctest_package_by_package) if ( NOT PBP_CONFIGURE_PASSED AND CTEST_DO_BUILD ) message("\n${TRIBITS_PACKAGE}: Skipping build due" - " to configure failing!") + " to configure failing!") set(PBP_BUILD_PASSED FALSE) set(PBP_BUILD_LIBS_PASSED FALSE) @@ -1033,14 +1033,14 @@ macro(tribits_ctest_package_by_package) elseif (NOT CTEST_DO_BUILD) message("\n${TRIBITS_PACKAGE}: Skipping build due" - " to CTEST_DO_BUILD='${CTEST_DO_BUILD}'!") + " to CTEST_DO_BUILD='${CTEST_DO_BUILD}'!") elseif (CTEST_DEPENDENCY_HANDLING_UNIT_TESTING OR CTEST_CONFIGURATION_UNIT_TESTING ) message("\n${TRIBITS_PACKAGE}: Skipping build due" - " to running in unit testing mode!") + " to running in unit testing mode!") else() @@ -1152,7 +1152,7 @@ macro(tribits_ctest_package_by_package) " exists so there were failed tests!") else() message("\n${TRIBITS_PACKAGE}: File '${FAILED_TEST_LOG_FILE}'" - " does NOT exist so all tests passed!") + " does NOT exist so all tests passed!") set(PBP_TESTS_PASSED TRUE) endif() # 2009/12/05: ToDo: We need to add an argument to ctest_test(...) @@ -1190,7 +1190,7 @@ macro(tribits_ctest_package_by_package) if (NOT PBP_BUILD_LIBS_PASSED AND CTEST_DO_MEMORY_TESTING) message("\n${TRIBITS_PACKAGE}: Skipping running memory checking" - "tests since library build failed!\n") + "tests since library build failed!\n") elseif (NOT CTEST_DO_MEMORY_TESTING) @@ -1209,7 +1209,7 @@ macro(tribits_ctest_package_by_package) BUILD "${CTEST_BINARY_DIRECTORY}" PARALLEL_LEVEL "${CTEST_PARALLEL_LEVEL}" INCLUDE_LABEL "^${TRIBITS_PACKAGE}$" - ) + ) # ToDo: Determine if memory testing passed or not and affect overall # pass/fail! diff --git a/cmake/tribits/doc/guides/TribitsGuidesBody.rst b/cmake/tribits/doc/guides/TribitsGuidesBody.rst index 848f5aea1d9d..293fcca3ea1e 100644 --- a/cmake/tribits/doc/guides/TribitsGuidesBody.rst +++ b/cmake/tribits/doc/guides/TribitsGuidesBody.rst @@ -3378,7 +3378,7 @@ management system are: 9) `TPL disable triggers auto-disables of downstream dependencies`_ 10) `Disables trump enables where there is a conflict`_ 11) `Enable/disable of parent package is enable/disable for subpackages`_ -12) `Enable of parent package tests/examples is enable for subpackages tests/examples`_ +12) `Enable/disable of parent package tests/examples is enable/disable for subpackages tests/examples`_ 13) `Subpackage enable does not auto-enable the parent package`_ 14) `Support for optional SE package/TPL is enabled by default`_ 15) `Support for optional SE package/TPL can be explicitly disabled`_ @@ -3588,17 +3588,19 @@ In more detail, these rules/behaviors are: see `Explicit enable of a package, its tests, an optional TPL, with ST enabled`_. -.. _Enable of parent package tests/examples is enable for subpackages tests/examples: - -12) **Enable of parent package tests/examples is enable for subpackages - tests/examples**: Setting ``_ENABLE_TESTS=ON`` is - equivalent to setting the default for - ``_ENABLE_TESTS=ON`` for each subpackage ```` of - the parent package ```` (if ```` has - subpackages). Same is true for ``_ENABLE_EXAMPLES=ON`` - setting the default for ``_ENABLE_EXAMPLES=ON``. In - addition, setting ``_ENABLE_TESTS=ON`` will set - ``_ENABLE_EXAMPLES=ON`` by default as well. +.. _Enable/disable of parent package tests/examples is enable/disable for subpackages tests/examples: + +12) **Enable/disable of parent package tests/examples is enable/disable for + subpackages tests/examples**: Setting + ``_ENABLE_TESTS=[ON|OFF]`` is equivalent to setting the + default for ``_ENABLE_TESTS=[ON|OFF]`` for each + subpackage ```` of the parent package ```` (if + ```` has subpackages). Same is true for + ``_ENABLE_EXAMPLES=[ON|OFF]`` setting the default for + ``_ENABLE_EXAMPLES=[ON|OFF]``. In addition, setting + ``_ENABLE_TESTS=[ON|OFF]`` will set + ``_ENABLE_EXAMPLES=[ON|OFF]`` by default as well (but not + vice versa). .. _Subpackage enable does not auto-enable the parent package: diff --git a/cmake/tribits/doc/guides/TribitsMacroFunctionDocTemplate.rst b/cmake/tribits/doc/guides/TribitsMacroFunctionDocTemplate.rst index b87a485d85de..c045c1962e54 100644 --- a/cmake/tribits/doc/guides/TribitsMacroFunctionDocTemplate.rst +++ b/cmake/tribits/doc/guides/TribitsMacroFunctionDocTemplate.rst @@ -14,6 +14,7 @@ @FUNCTION: tribits_add_test() + @MACRO: tribits_add_test_directories() + @FUNCTION: tribits_allow_missing_external_packages() + +@MACRO: tribits_assert_cache_and_local_vars_same_value() + @FUNCTION: tribits_configure_file() + @FUNCTION: tribits_copy_files_to_binary_dir() + @FUNCTION: tribits_ctest_driver() + @@ -30,6 +31,7 @@ @FUNCTION: tribits_find_most_recent_source_file_timestamp() + @FUNCTION: tribits_install_headers() + @MACRO: tribits_include_directories() + +@MACRO: tribits_pkg_export_cache_var() + @MACRO: tribits_package() + @MACRO: tribits_package_decl() + @MACRO: tribits_package_def() + diff --git a/cmake/tribits/examples/TribitsExampleApp/CMakeLists.txt b/cmake/tribits/examples/TribitsExampleApp/CMakeLists.txt index 4de561d8acbb..562a38ffd4ab 100644 --- a/cmake/tribits/examples/TribitsExampleApp/CMakeLists.txt +++ b/cmake/tribits/examples/TribitsExampleApp/CMakeLists.txt @@ -17,6 +17,9 @@ include(AppHelperFuncs) getTribitsExProjStuffForApp() +# Show that we can see exported cache vars +message("-- WithSubpackagesA_SPECIAL_VALUE = '${WithSubpackagesA_SPECIAL_VALUE}'") + # Enable the compilers now that we have gotten them from the *Config.cmake file enable_language(C) enable_language(CXX) diff --git a/cmake/tribits/examples/TribitsExampleProject/packages/simple_cxx/CMakeLists.txt b/cmake/tribits/examples/TribitsExampleProject/packages/simple_cxx/CMakeLists.txt index d9e5c1c2bd1a..4eb715a09576 100644 --- a/cmake/tribits/examples/TribitsExampleProject/packages/simple_cxx/CMakeLists.txt +++ b/cmake/tribits/examples/TribitsExampleProject/packages/simple_cxx/CMakeLists.txt @@ -8,6 +8,7 @@ tribits_package( SimpleCxx ENABLE_SHADOWING_WARNINGS CLEANED ) # include(CheckFor__int64) check_for___int64(HAVE_SIMPLECXX___INT64) +tribits_pkg_export_cache_var(HAVE_SIMPLECXX___INT64) # # C) Set up package-specific options diff --git a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.cpp b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.cpp index f9db8c90a125..85286e850a7c 100644 --- a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.cpp +++ b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.cpp @@ -1,4 +1,5 @@ #include "A.hpp" +#include "WithSubpackagesA_config.h" #include "SimpleCxx_HelloWorld.hpp" @@ -9,3 +10,7 @@ std::string WithSubpackages::getA() { std::string WithSubpackages::depsA() { return "SimpleCxx "+SimpleCxx::deps(); } + +int WithSubpackages::specialValue() { + return WITHSUBPACKAGESA_SPECIAL_VALUE; +} diff --git a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.hpp b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.hpp index 3476ee5b17ff..36364a10af12 100644 --- a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.hpp +++ b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/A.hpp @@ -11,6 +11,9 @@ namespace WithSubpackages { // return a string describing the dependencies of "A", recursively std::string depsA(); + // return special value + int specialValue(); + } diff --git a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/CMakeLists.txt b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/CMakeLists.txt index f9ede481a956..e7cdf9cf6f63 100644 --- a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/CMakeLists.txt +++ b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/CMakeLists.txt @@ -6,18 +6,22 @@ tribits_subpackage(A) # # B) Set up subpackage-specific options # -# Typically there are none or are few as most options are picked up from the -# parent package's CMakeLists.txt file! + +set(${PACKAGE_NAME}_SPECIAL_VALUE 3 CACHE STRING "Integer special value") +tribits_pkg_export_cache_var(${PACKAGE_NAME}_SPECIAL_VALUE) # # C) Add the libraries, tests, and examples # +tribits_configure_file(${PACKAGE_NAME}_config.h) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}) tribits_add_library(pws_a SOURCES A.cpp - HEADERS A.hpp - NOINSTALLHEADERS + HEADERS A.hpp ${CMAKE_CURRENT_BINARY_DIR}/${PACKAGE_NAME}_config.h ) tribits_add_test_directories(tests) diff --git a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/cmake/WithSubpackagesA_config.h.in b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/cmake/WithSubpackagesA_config.h.in new file mode 100644 index 000000000000..4536208a03c8 --- /dev/null +++ b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/cmake/WithSubpackagesA_config.h.in @@ -0,0 +1,6 @@ +#ifndef WITHSUBPACKAGESA_CONFIG_H +#define WITHSUBPACKAGESA_CONFIG_H + +#define WITHSUBPACKAGESA_SPECIAL_VALUE ${WithSubpackagesA_SPECIAL_VALUE} + +#endif // WITHSUBPACKAGESA_CONFIG_H diff --git a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/CMakeLists.txt b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/CMakeLists.txt index 788b93eac65d..c891750a7e94 100644 --- a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/CMakeLists.txt +++ b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/CMakeLists.txt @@ -6,4 +6,5 @@ tribits_add_advanced_test( test_of_a PASS_REGULAR_EXPRESSION_ALL "A label is: A" "A deps are: ${EXPECTED_SIMPLECXX_AND_DEPS}" + "A special value: ${WithSubpackagesA_SPECIAL_VALUE}" ) diff --git a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/a_test.cpp b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/a_test.cpp index 4c48a7838eae..6497b4d8ebee 100644 --- a/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/a_test.cpp +++ b/cmake/tribits/examples/TribitsExampleProject/packages/with_subpackages/a/tests/a_test.cpp @@ -3,13 +3,11 @@ #include "A.hpp" -using namespace WithSubpackages; - int main() { - std::string label_A = getA(); - std::string deps_A = depsA(); + std::string label_A = WithSubpackages::getA(); + std::string deps_A = WithSubpackages::depsA(); std::cout << "A label is: " << label_A << std::endl; std::cout << "A deps are: " << deps_A << std::endl; - + std::cout << "A special value: " << WithSubpackages::specialValue() << std::endl; return 0; } diff --git a/demos/simpleBuildAgainstTrilinos/CMakeLists.txt b/demos/simpleBuildAgainstTrilinos/CMakeLists.txt index 23c4f2e34ea9..990287789e66 100644 --- a/demos/simpleBuildAgainstTrilinos/CMakeLists.txt +++ b/demos/simpleBuildAgainstTrilinos/CMakeLists.txt @@ -1,6 +1,6 @@ # CMAKE File for "MyApp" application building against an installed Trilinos -cmake_minimum_required(VERSION 3.17.1) +cmake_minimum_required(VERSION 3.0) # Declare project but don't process compilers yet # @@ -22,11 +22,8 @@ MESSAGE(" Trilinos_VERSION = ${Trilinos_VERSION}") MESSAGE(" Trilinos_PACKAGE_LIST = ${Trilinos_PACKAGE_LIST}") MESSAGE(" Trilinos_LIBRARIES = ${Trilinos_LIBRARIES}") MESSAGE(" Trilinos_INCLUDE_DIRS = ${Trilinos_INCLUDE_DIRS}") -MESSAGE(" Trilinos_LIBRARY_DIRS = ${Trilinos_LIBRARY_DIRS}") MESSAGE(" Trilinos_TPL_LIST = ${Trilinos_TPL_LIST}") -MESSAGE(" Trilinos_TPL_INCLUDE_DIRS = ${Trilinos_TPL_INCLUDE_DIRS}") MESSAGE(" Trilinos_TPL_LIBRARIES = ${Trilinos_TPL_LIBRARIES}") -MESSAGE(" Trilinos_TPL_LIBRARY_DIRS = ${Trilinos_TPL_LIBRARY_DIRS}") MESSAGE(" Trilinos_BUILD_SHARED_LIBS = ${Trilinos_BUILD_SHARED_LIBS}") MESSAGE("End of Trilinos details\n") diff --git a/doc/DocumentingParameterLists/memo/SMemo_PListDoc.tex b/doc/DocumentingParameterLists/memo/SMemo_PListDoc.tex index 92a4a30c7e5c..f7214e45483d 100644 --- a/doc/DocumentingParameterLists/memo/SMemo_PListDoc.tex +++ b/doc/DocumentingParameterLists/memo/SMemo_PListDoc.tex @@ -30,7 +30,7 @@ \from{Bill Spotz, Org 1446; Dena Vigil, Org 1441} % Fill in a subject - \subject{How to Document \texttt{Teuchos::ParameterList}s with collapsable HTML} + \subject{How to Document \texttt{Teuchos::ParameterList}s with collapsible HTML} diff --git a/doc/DocumentingParameterLists/memo/smemo.cls b/doc/DocumentingParameterLists/memo/smemo.cls index 71c958daf00b..ed2f3ac9f6d6 100644 --- a/doc/DocumentingParameterLists/memo/smemo.cls +++ b/doc/DocumentingParameterLists/memo/smemo.cls @@ -918,7 +918,7 @@ % **************************************** % % The distribution environment switches to two column mode if necessary, -% shrinks up the \baselineskip and \parskip, allowing disribution list +% shrinks up the \baselineskip and \parskip, allowing distribution list % to be set in a minimal amount of space. \def\distribution#1{% @@ -1029,8 +1029,8 @@ % of counter CTR. It is defined in terms of the following macros: % % \arabic{COUNTER} : The value of COUNTER printed as an arabic numeral. -% \roman{COUNTER} : Its value printed as a lower-case roman numberal. -% \Roman{COUNTER} : Its value printed as an upper-case roman numberal. +% \roman{COUNTER} : Its value printed as a lower-case roman numeral. +% \Roman{COUNTER} : Its value printed as an upper-case roman numeral. % \alph{COUNTER} : Value of COUNTER printed as a lower-case letter: % 1 = a, 2 = b, etc. % \Alph{COUNTER} : Value of COUNTER printed as an upper-case letter: diff --git a/doc/build_docs.pl b/doc/build_docs.pl index 3efb64e550d7..8fc3769079ec 100755 --- a/doc/build_docs.pl +++ b/doc/build_docs.pl @@ -3,7 +3,7 @@ ############################################################################### # Trilinos/doc/build_docs.pl # -# - You must run this script from this directoy! +# - You must run this script from this directory! # - Run any build_docs in any doc directory # - Create html file with links to each set of documentation # diff --git a/doc/build_ref/TrilinosBuildReferenceTemplate.rst b/doc/build_ref/TrilinosBuildReferenceTemplate.rst index 488087a4ba95..87841989a7b1 100644 --- a/doc/build_ref/TrilinosBuildReferenceTemplate.rst +++ b/doc/build_ref/TrilinosBuildReferenceTemplate.rst @@ -48,17 +48,17 @@ various Trilinos packages can be enabled using the following options: ``-DTrilinos_ENABLE_FLOAT=ON`` - Enables suppport and explicit instantiations for the ``float`` scalar + Enables support and explicit instantiations for the ``float`` scalar data-type in all supported Trilinos packages. ``-DTrilinos_ENABLE_COMPLEX=ON`` - Enables suppport and explicit instantiations for the ``std::complex`` + Enables support and explicit instantiations for the ``std::complex`` scalar data-type in all supported Trilinos packages. ``-DTrilinos_ENABLE_COMPLEX_FLOAT=ON`` - Enables suppport and explicit instantiations for the + Enables support and explicit instantiations for the ``std::complex`` scalar data-type in all supported Trilinos packages. This is set to ``ON`` by default when ``-DTrilinos_ENABLE_FLOAT=ON`` and ``-DTrilinos_ENABLE_COMPLEX=ON`` are @@ -66,7 +66,7 @@ various Trilinos packages can be enabled using the following options: ``-DTrilinos_ENABLE_COMPLEX_DOUBLE=ON`` - Enables suppport and explicit instantiations for the + Enables support and explicit instantiations for the ``std::complex`` scalar data-type in all supported Trilinos packages. This is set to ``ON`` by default when ``-DTrilinos_ENABLE_COMPLEX=ON`` is set. @@ -125,7 +125,7 @@ target machine. These build-related flags are selected to create correct and perforamnt code and for C++ software that uses Kokkos. ============================ ====================================== -Functionality CMake Cache Varaible +Functionality CMake Cache Variable ============================ ====================================== Specify architecture ``KOKKOS_ARCH`` Debug builds ``KOKKOS_DEBUG`` @@ -218,7 +218,7 @@ Addressing problems with large builds of Trilinos ------------------------------------------------- Trilinos is a large collection of complex software. Depending on what gets -enbaled when building Trlinos, one can experience build and installation +enabled when building Trlinos, one can experience build and installation problems due to this large size. When running into problems like these, the first thing that should be tried is diff --git a/packages/TrilinosInstallTests/CMakeLists.txt b/packages/TrilinosInstallTests/CMakeLists.txt index 595a7866ecb5..41e32007d1dd 100644 --- a/packages/TrilinosInstallTests/CMakeLists.txt +++ b/packages/TrilinosInstallTests/CMakeLists.txt @@ -66,13 +66,18 @@ tribits_add_advanced_test(doInstall -P "${CMAKE_CURRENT_SOURCE_DIR}/remove_dir_if_exists.cmake" TEST_1 - MESSAGE "Install whatever Trilinos packages have been enabled" + MESSAGE "Install enabled and built Trilinos packages (NOTE: This test will fail if the project has **any** build errors!)" CMND "${CMAKE_COMMAND}" ARGS --install ${PROJECT_BINARY_DIR} --prefix ${PROJECT_BINARY_DIR}/install OUTPUT_FILE doInstall.out NO_ECHO_OUTPUT + TEST_2 + MESSAGE "Grep doInstall.out file produced above to see any errors" + CMND grep ARGS -A 50 "CMake Error" doInstall.out + PASS_ANY + ADDED_TEST_NAME_OUT doInstall_name ) # NOTE: Above works even if Trilinos was configured without setting @@ -81,6 +86,11 @@ tribits_add_advanced_test(doInstall # the source dir and the build dir will still be sticking around in the # below example build. +if (doInstall_name) + set_tests_properties(${doInstall_name} + PROPERTIES FIXTURES_SETUP doInstall_passed) +endif() + tribits_add_advanced_test(find_package_Trilinos OVERALL_NUM_MPI_PROCS 1 @@ -104,8 +114,10 @@ tribits_add_advanced_test(find_package_Trilinos if (find_package_Trilinos_name) set_tests_properties(${find_package_Trilinos_name} - PROPERTIES DEPENDS ${doInstall_name} ) + PROPERTIES FIXTURES_REQUIRED doInstall_passed) endif() +# NOTE: Above, only attempt to run the find_package() test if the install +# command passed or it is guaranteed to fail. tribits_add_advanced_test(simpleBuildAgainstTrilinos @@ -144,8 +156,17 @@ tribits_add_advanced_test(simpleBuildAgainstTrilinos if (simpleBuildAgainstTrilinos_name) set_tests_properties(${simpleBuildAgainstTrilinos_name} - PROPERTIES DEPENDS ${doInstall_name} ) + PROPERTIES FIXTURES_REQUIRED doInstall_passed) endif() +# NOTE: Above, only attempt to build and test the simpleBuildAgainstTrilinos +# project if the install command passed or it is guaranteed to fail. Also +# note that we could have blocked this based on the find_package() test but +# that runs find_package(Trilinos) for all of Trilinos while the +# simpleBuildAgainstTrilinos/CMakeLists.txt file only calls +# find_package(Trilinos COMPONENTS Tpetra) so it could pass when the full +# find_package(Trilinos) call fails. Therefore, it makes sense to run the +# this test for simpleBuildAgainstTrilinos even if the test for the full +# find_package(Trilinos) command fails. tribits_package_postprocess() diff --git a/packages/TrilinosInstallTests/find_package_Trilinos/CMakeLists.txt b/packages/TrilinosInstallTests/find_package_Trilinos/CMakeLists.txt index 65661f2b7d3d..4b4d8cb0f549 100644 --- a/packages/TrilinosInstallTests/find_package_Trilinos/CMakeLists.txt +++ b/packages/TrilinosInstallTests/find_package_Trilinos/CMakeLists.txt @@ -1,4 +1,7 @@ -cmake_minimum_required(VERSION 3.17) +cmake_minimum_required(VERSION 3.0) + +# Disable Kokkos warning about not supporting C++ extensions +set(CMAKE_CXX_EXTENSIONS OFF) project(find_package_Trilinos NONE) diff --git a/packages/adelus/README.md b/packages/adelus/README.md index 1e5b46b692e1..26127ed2b169 100755 --- a/packages/adelus/README.md +++ b/packages/adelus/README.md @@ -98,6 +98,8 @@ We organize the directories as follows: * ```Adelus::GetDistribution()```: gives the distribution information that is required by the dense solver to the user that defines the matrix block and right hand side information. +* ```Adelus::AdelusHandle<...>```: an application must create a handle to the Adelus communicator and necessary metadata (the handle is passed to every subsequent Adelus function call) + * ```Adelus::FactorSolve()```: factors and solves the dense matrix in which the matrix and rhs are packed in Kokkos View @@ -105,10 +107,14 @@ and rhs are packed in Kokkos View * ```Adelus::FactorSolve_hostPtr()```: matrix and rhs are packed and passed as host pointer +* ```Adelus::Factor()```: factors the dense matrix for later solve + +* ```Adelus::Solve()```: solves the previously factored dense matrix for provided RHS + 2. Implementations of the phases of the solver (i.e. factor, solve, permutation) and other utility functions also locate in the ```src/``` subdirectory. -3. A correctness test is in the ```test/``` subdirectory. +3. Correctness tests is in the ```test/``` subdirectory. 4. A simple example that generates a random matrix and a right-hand-side to exercise the solver is in the ```example/``` subdirectory. @@ -249,12 +255,14 @@ the solver can be called. In this example, the portion of matrix on each MPI process and the reference solution vector are randomly generated. Then, the assigned RHS vectors on MPI processes can be computed. -3. Launch Adelus using ```Adelus::FactorSolve```, or ```Adelus::FactorSolve_devPtr```, +3. Create a handle to the Adelus communicator and necessary metadata + +4. Launch Adelus using ```Adelus::FactorSolve```, or ```Adelus::FactorSolve_devPtr```, or ```Adelus::FactorSolve_hostPtr```. -4. Gather results. +5. Gather results. -5. Compare the returned solution vector with the reference vector. +6. Compare the returned solution vector with the reference vector. ### Compile with Makefile diff --git a/packages/adelus/example/CMakeLists.txt b/packages/adelus/example/CMakeLists.txt index 44e3bd839536..3fdcc9bfa670 100644 --- a/packages/adelus/example/CMakeLists.txt +++ b/packages/adelus/example/CMakeLists.txt @@ -1,5 +1,7 @@ # CMAKE File for "adelus_driver" application building against an installed Trilinos -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.12) + +cmake_policy(SET CMP0057 NEW) # Use Trilinos_PREFIX, if the user set it, to help find Trilinos. # The final location will actually be held in Trilinos_DIR which must @@ -28,7 +30,7 @@ MESSAGE(" Trilinos_Fortran_COMPILER = ${Trilinos_Fortran_COMPILER}") MESSAGE(" Trilinos_CXX_COMPILER_FLAGS = ${Trilinos_CXX_COMPILER_FLAGS}") MESSAGE(" Trilinos_C_COMPILER_FLAGS = ${Trilinos_C_COMPILER_FLAGS}") MESSAGE(" Trilinos_Fortran_COMPILER_FLAGS = ${Trilinos_Fortran_COMPILER_FLAGS}") -MESSAGE(" Trilinos_EXTRA_LINK_FLAGS = ${Trilinos_EXTRA_LINK_FLAGS}") +MESSAGE(" Trilinos_EXTRA_LD_FLAGS = ${Trilinos_EXTRA_LD_FLAGS}") MESSAGE("End of Trilinos details\n") # Make sure to use same compilers and flags as Trilinos @@ -36,7 +38,10 @@ SET(CMAKE_CXX_COMPILER ${Trilinos_CXX_COMPILER} ) SET(CMAKE_C_COMPILER ${Trilinos_C_COMPILER} ) SET(CMAKE_Fortran_COMPILER ${Trilinos_Fortran_COMPILER} ) -SET(CMAKE_CXX_FLAGS "${Trilinos_CXX_COMPILER_FLAGS} ${CMAKE_CXX_FLAGS} -L$ENV{MPI_ROOT}/lib -lmpi_ibm -fopenmp") +#For older versions of Trilinos +#SET(CMAKE_CXX_FLAGS "${Trilinos_CXX_COMPILER_FLAGS} ${CMAKE_CXX_FLAGS} --remove-duplicate-link-files") +#For Trilinos versions after the merge of PR#10614 +SET(CMAKE_CXX_FLAGS "${Trilinos_CXX_COMPILER_FLAGS} ${CMAKE_CXX_FLAGS}") SET(CMAKE_C_FLAGS "${Trilinos_C_COMPILER_FLAGS} ${CMAKE_C_FLAGS}") SET(CMAKE_Fortran_FLAGS "${Trilinos_Fortran_COMPILER_FLAGS} ${CMAKE_Fortran_FLAGS}") @@ -56,5 +61,7 @@ ADD_EXECUTABLE(adelus_driver adelus_driver.cpp) set_property(TARGET adelus_driver PROPERTY CXX_STANDARD 14) -TARGET_LINK_LIBRARIES(adelus_driver ${Trilinos_LIBRARIES} ${Trilinos_TPL_LIBRARIES} ${Trilinos_EXTRA_LINK_FLAGS}) - +#For older versions of Trilinos +#TARGET_LINK_LIBRARIES(adelus_driver ${Trilinos_LIBRARIES} ${Trilinos_TPL_LIBRARIES} ${Trilinos_EXTRA_LD_FLAGS}) +#For Trilinos versions after the merge of PR#10614 +TARGET_LINK_LIBRARIES(adelus_driver Trilinos::all_selected_libs) diff --git a/packages/adelus/example/adelus_driver.cpp b/packages/adelus/example/adelus_driver.cpp index d5c8c0896bb5..aa4df48d3914 100644 --- a/packages/adelus/example/adelus_driver.cpp +++ b/packages/adelus/example/adelus_driver.cpp @@ -151,16 +151,10 @@ int main( int argc, char* argv[] ) int my_rows_max; int my_cols_max; - Adelus::GetDistribution( &nprocs_row, - &matrix_size, - &nrhs, - &my_rows, - &my_cols, - &my_first_row, - &my_first_col, - &my_rhs, - &my_row, - &my_col ); + Adelus::GetDistribution( MPI_COMM_WORLD, + nprocs_row, matrix_size, nrhs, + my_rows, my_cols, my_first_row, my_first_col, + my_rhs, my_row, my_col ); MPI_Allreduce( &my_rows, &my_rows_max, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD); MPI_Allreduce( &my_cols, &my_cols_max, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD); @@ -340,6 +334,10 @@ int main( int argc, char* argv[] ) #endif } + // Create handle + Adelus::AdelusHandle + ahandle(0, MPI_COMM_WORLD, matrix_size, nprocs_row, nrhs ); + double time = 0.0; MPI_Barrier (MPI_COMM_WORLD); @@ -358,16 +356,16 @@ int main( int argc, char* argv[] ) gettimeofday( &begin, NULL ); #ifdef KKVIEW_API - Adelus::FactorSolve (my_A, my_rows, my_cols, &matrix_size, &nprocs_row, &nrhs, &secs); + Adelus::FactorSolve (ahandle, my_A, &secs); #endif #if defined(DEVPTR_API) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) - Adelus::FactorSolve_devPtr (reinterpret_cast(my_A.data()),my_rows,my_cols,my_rhs,&matrix_size,&nprocs_row,&nrhs,&secs); + Adelus::FactorSolve_devPtr (ahandle, reinterpret_cast(my_A.data()), my_rows, my_cols, my_rhs, &matrix_size, &nprocs_row, &nrhs, &secs); #endif #if defined(HOSTPTR_API) && !(defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP))//KOKKOS_ENABLE_OPENMP - Adelus::FactorSolve_hostPtr (reinterpret_cast(my_A.data()),my_rows,my_cols,my_rhs,&matrix_size,&nprocs_row,&nrhs,&secs); + Adelus::FactorSolve_hostPtr (ahandle, reinterpret_cast(my_A.data()), my_rows, my_cols, my_rhs, &matrix_size, &nprocs_row, &nrhs, &secs); #endif #if defined(HOSTPTR_API) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) - Adelus::FactorSolve_hostPtr (reinterpret_cast(h_my_A_hptr.data()),my_rows,my_cols,my_rhs,&matrix_size,&nprocs_row,&nrhs,&secs); + Adelus::FactorSolve_hostPtr (ahandle, reinterpret_cast(h_my_A_hptr.data()), my_rows, my_cols, my_rhs, &matrix_size, &nprocs_row, &nrhs, &secs); #endif Kokkos::fence(); diff --git a/packages/adelus/src/Adelus.hpp b/packages/adelus/src/Adelus.hpp index e9bcb31c2039..f3c091935370 100644 --- a/packages/adelus/src/Adelus.hpp +++ b/packages/adelus/src/Adelus.hpp @@ -44,12 +44,18 @@ //@HEADER */ +#ifndef __ADELUS_HPP__ +#define __ADELUS_HPP__ + #pragma once -#include +#include #include -#include #include +#include +#include +#include + #include // Adelus: provides the functionality to interface to a dense LU solver @@ -59,38 +65,41 @@ namespace Adelus { /// Adelus GetDistirbution /// Gives the distribution information that is required by the dense solver - /// \param nprocs_row_ (In) - number of processors for a row + /// \param comm (In) - communicator that Adelus runs on + /// \param nprocs_row (In) - number of processors for a row /// \param number_of_unknowns (In) - order of the dense matrix - /// \param nrhs_ (In) - number of right hand sides - /// \param my_rows_ (Out) - number of rows of the matrix on this processor - /// \param my_cols_ (Out) - number of columns of the matrix on this processor - /// \param my_first_row_ (Out) - first (global) row number on this processor (array starts at index 1) - /// \param my_first_col_ (Out) - first (global) column number on this processor (array starts at index 1) - /// \param my_rhs_ (Out) - number of right hand sides on this processor + /// \param nrhs (In) - number of right hand sides + /// \param my_rows (Out) - number of rows of the matrix on this processor + /// \param my_cols (Out) - number of columns of the matrix on this processor + /// \param my_first_row (Out) - first (global) row number on this processor (array starts at index 1) + /// \param my_first_col (Out) - first (global) column number on this processor (array starts at index 1) + /// \param my_rhs (Out) - number of right hand sides on this processor /// \param my_row (Out) - row number in processor mesh, 0 to the number of processors for a column -1 /// \param my_col (Out) - column number in processor mesh, 0 to the number of processors for a row -1 inline - int GetDistribution( int* nprocs_row_, - int* number_of_unknowns, - int* nrhs_, - int* my_rows_, - int* my_cols_, - int* my_first_row_, - int* my_first_col_, - int* my_rhs_, - int* my_row, - int* my_col ) { + int GetDistribution( MPI_Comm comm, + const int nprocs_row, + const int number_of_unknowns, + const int nrhs, + int& my_rows, + int& my_cols, + int& my_first_row, + int& my_first_col, + int& my_rhs, + int& my_row, + int& my_col ) { // This function echoes the multiprocessor distribution of the matrix - distmat_(nprocs_row_, + distmat_(comm, + nprocs_row, number_of_unknowns, - nrhs_, - my_rows_, - my_cols_, - my_first_row_, - my_first_col_, - my_rhs_, + nrhs, + my_rows, + my_cols, + my_first_row, + my_first_col, + my_rhs, my_row, my_col); @@ -101,36 +110,71 @@ namespace Adelus { /// Adelus FactorSolve /// Factors and solves the dense matrix - /// \param AA (InOut) -- Kokkos View that has the matrix and rhs packed (Note: matrix and rhs are overwritten) - /// \param my_rows_ (In) -- number of rows of the matrix on this processor - /// \param my_cols_ (In) -- number of columns of the matrix on this processor - /// \param matrix_size (In) -- order of the dense matrix - /// \param num_procsr (In) -- number of processors for a row - /// \param num_rhs (In) -- number of right hand sides + /// \param ahandle (In) -- handle that contains metadata needed by the Adelus solver + /// \param AA (InOut) -- Kokkos View that has the matrix and rhs packed in this processor + /// (Note: matrix and rhs are overwritten) /// \param secs (Out) -- factor and solve time in seconds - template + template inline - void FactorSolve( ZDView AA, - int my_rows_, - int my_cols_, - int* matrix_size, - int* num_procsr, - int* num_rhs, + void FactorSolve( HandleType& ahandle, + ZRHSViewType& AA, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank); #ifdef PRINT_STATUS - printf("FactorSolve (Kokkos View interface) in rank %d -- my_rows %u , my_cols %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve (Kokkos View interface) in rank %d\n", ahandle.get_myrank()); +#endif + + lusolve_(ahandle, AA, secs); + + } + + /// Adelus Factor + /// Factors the dense matrix for later solve + + /// \param ahandle (In) -- handle that contains metadata needed by the Adelus solver + /// \param AA (InOut) -- Kokkos View that has the matrix in this processor (Note: matrix is overwritten) + /// \param permute (In) -- Kokkos View that has the global pivot vector + /// \param secs (Out) -- factor and solve time in seconds + + template + inline + void Factor( HandleType& ahandle, + ZViewType& AA, + PViewType& permute, + double* secs ) { + +#ifdef PRINT_STATUS + printf("Factor (Kokkos View interface) in rank %d\n", ahandle.get_myrank()); +#endif + + lu_(ahandle, AA, permute, secs); + + } + + /// Adelus Solve + /// Solves the previously factored dense matrix for provided RHS + + /// \param ahandle (In) -- handle that contains metadata needed by the Adelus solver + /// \param AA (In) -- Kokkos View that has the LU-factorized matrix + /// \param BB (InOut) -- Kokkos View that has the rhs and solution (Note: rhs are overwritten) + /// \param permute (In) -- Kokkos View that has the global pivot vector + /// \param secs (Out) -- factor and solve time in seconds + + template + inline + void Solve( HandleType& ahandle, + ZViewType& AA, + RHSViewType& BB, + PViewType& permute, + double* secs ) { + +#ifdef PRINT_STATUS + printf("Solve (Kokkos View interface) in rank %d\n", ahandle.get_myrank()); #endif - lusolve_(AA, - matrix_size, - num_procsr, - num_rhs, - secs); + solve_(ahandle, AA, BB, permute, secs); } @@ -138,8 +182,10 @@ namespace Adelus { /// Adelus FactorSolve_devPtr /// Matrix and rhs are packed and passed as device pointer + template inline - void FactorSolve_devPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_devPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -147,9 +193,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) @@ -165,14 +208,10 @@ namespace Adelus { AA_Internal AA_i(reinterpret_cast *>(AA), my_rows_, my_cols_ + my_rhs_ + 6); #ifdef PRINT_STATUS - printf("FactorSolve_devPtr (double complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_devPtr (double complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif - lusolve_(AA_i, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i, secs); #endif } } @@ -180,8 +219,10 @@ namespace Adelus { /// Adelus FactorSolve_hostPtr /// Matrix and rhs are packed and passed as host pointer + template inline - void FactorSolve_hostPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_hostPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -189,9 +230,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. typedef Kokkos::View**, @@ -213,28 +251,20 @@ namespace Adelus { AA_Internal_dev AA_i_dev( "AA_i_dev", my_rows_, my_cols_ + my_rhs_ + 6 ); #ifdef PRINT_STATUS - printf("FactorSolve_hostPtr with CUDA solve (double complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_hostPtr with CUDA solve (double complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif Kokkos::deep_copy( AA_i_dev, AA_i ); - lusolve_(AA_i_dev, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i_dev, secs); Kokkos::deep_copy( AA_i, AA_i_dev ); #else//OpenMP #ifdef PRINT_STATUS - printf("FactorSolve_hostPtr with host solve (double complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_hostPtr with host solve (double complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif - lusolve_(AA_i, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i, secs); #endif } } @@ -244,8 +274,10 @@ namespace Adelus { /// Adelus FactorSolve_devPtr /// Matrix and rhs are packed and passed as device pointer + template inline - void FactorSolve_devPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_devPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -253,9 +285,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) @@ -271,14 +300,10 @@ namespace Adelus { AA_Internal AA_i(reinterpret_cast(AA), my_rows_, my_cols_ + my_rhs_ + 6); #ifdef PRINT_STATUS - printf("FactorSolve_devPtr (double pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_devPtr (double pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif - lusolve_(AA_i, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i, secs); #endif } } @@ -286,8 +311,10 @@ namespace Adelus { /// Adelus FactorSolve_hostPtr /// Matrix and rhs are packed and passed as host pointer + template inline - void FactorSolve_hostPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_hostPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -295,9 +322,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. typedef Kokkos::View inline - void FactorSolve_devPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_devPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -359,9 +377,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) @@ -377,14 +392,10 @@ namespace Adelus { AA_Internal AA_i(reinterpret_cast *>(AA), my_rows_, my_cols_ + my_rhs_ + 6); #ifdef PRINT_STATUS - printf("FactorSolve_devPtr (float complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_devPtr (float complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif - lusolve_(AA_i, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i, secs); #endif } } @@ -392,8 +403,10 @@ namespace Adelus { /// Adelus FactorSolve_hostPtr /// Matrix and rhs are packed and passed as host pointer + template inline - void FactorSolve_hostPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_hostPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -401,9 +414,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. typedef Kokkos::View**, @@ -425,28 +435,20 @@ namespace Adelus { AA_Internal_dev AA_i_dev( "AA_i_dev", my_rows_, my_cols_ + my_rhs_ + 6 ); #ifdef PRINT_STATUS - printf("FactorSolve_hostPtr with CUDA solve (float complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_hostPtr with CUDA solve (float complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif Kokkos::deep_copy( AA_i_dev, AA_i ); - lusolve_(AA_i_dev, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i_dev, secs); Kokkos::deep_copy( AA_i, AA_i_dev ); #else//OpenMP #ifdef PRINT_STATUS - printf("FactorSolve_hostPtr with host solve (float complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_hostPtr with host solve (float complex pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif - lusolve_(AA_i, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i, secs); #endif } } @@ -456,8 +458,10 @@ namespace Adelus { /// Adelus FactorSolve_devPtr /// Matrix and rhs are packed and passed as device pointer + template inline - void FactorSolve_devPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_devPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -465,9 +469,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) @@ -483,14 +484,10 @@ namespace Adelus { AA_Internal AA_i(reinterpret_cast(AA), my_rows_, my_cols_ + my_rhs_ + 6); #ifdef PRINT_STATUS - printf("FactorSolve_devPtr (float pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", rank, my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); + printf("FactorSolve_devPtr (float pointer interface) in rank %d -- my_rows %u , my_cols %u, my_rhs %u , matrix_size %u, num_procs_per_row %u, num_rhs %u\n", ahandle.get_myrank(), my_rows_, my_cols_, my_rhs_, *matrix_size, *num_procsr, *num_rhs); #endif - lusolve_(AA_i, - matrix_size, - num_procsr, - num_rhs, - secs); + lusolve_(ahandle, AA_i, secs); #endif } } @@ -498,8 +495,10 @@ namespace Adelus { /// Adelus FactorSolve_hostPtr /// Matrix and rhs are packed and passed as host pointer + template inline - void FactorSolve_hostPtr( ADELUS_DATA_TYPE* AA, + void FactorSolve_hostPtr( HandleType& ahandle, + ADELUS_DATA_TYPE* AA, int my_rows_, int my_cols_, int my_rhs_, @@ -507,9 +506,6 @@ namespace Adelus { int* num_procsr, int* num_rhs, double* secs ) { - int rank; - - MPI_Comm_rank(MPI_COMM_WORLD, &rank) ; { // Note: To avoid segmentation fault when FactorSolve is called multiple times with the unmanaged View, it's safest to make sure unmanaged View falls out of scope before freeing its memory. typedef Kokkos::View (nrows%nprocs_col_)) ? *my_first_row_ + (nrows%nprocs_col_) : - *my_first_row_ + (*my_row); + my_first_row = (my_row > (nrows%nprocs_col)) ? my_first_row + (nrows%nprocs_col) : + my_first_row + my_row; - if (*my_row < nrows % nprocs_col_) - ++(*my_rows_); + if (my_row < (nrows%nprocs_col)) ++my_rows; - *my_cols_ = nrows / nprocs_row_; + // + my_cols = ncols / nprocs_row; - *my_first_col_ = (*my_col)*(*my_cols_) + 1; + my_first_col = my_col * my_cols + 1; - *my_first_col_ = ((*my_col) > (nrows%nprocs_row_)) ? *my_first_col_ + (nrows%nprocs_row_) : - *my_first_col_ + (*my_col); + my_first_col = (my_col > (ncols%nprocs_row)) ? my_first_col + (ncols%nprocs_row) : + my_first_col + my_col; - *my_cols_ = *ncols / *nprocsr; - if (*my_col < *ncols % (*nprocsr)) - ++(*my_cols_); + if (my_col < (ncols%nprocs_row)) ++my_cols; // Distribute the RHS per processor - *my_rhs_ = *nrhs_ / *nprocsr; - if (*my_col < *nrhs_ % (*nprocsr)) ++(*my_rhs_); + my_rhs = nrhs / nprocs_row; + if (my_col < (nrhs%nprocs_row)) ++my_rhs; } diff --git a/packages/adelus/src/Adelus_distribute.hpp b/packages/adelus/src/Adelus_distribute.hpp index 7405f8a67937..39bc1822395a 100644 --- a/packages/adelus/src/Adelus_distribute.hpp +++ b/packages/adelus/src/Adelus_distribute.hpp @@ -54,6 +54,7 @@ //jdkotul@sandia.gov // Variables INPUT +// comm --- communicator that Adelus is running on // nprocsr --- number of processors assigned to a row // ncols --- number of columns(=rows) for the matrix // nrhs --- number of right hand sides @@ -74,16 +75,17 @@ namespace Adelus { -void distmat_( int *nprocsr, - int *ncols, - int *nrhs_, - int *my_rows_, - int *my_cols_, - int *my_first_row_, - int *my_first_col_, - int *my_rhs_, - int *my_row, - int *my_col ); +void distmat_( MPI_Comm comm, + const int nprocsr, + const int ncols, + const int nrhs, + int& my_rows, + int& my_cols, + int& my_first_row, + int& my_first_col, + int& my_rhs, + int& my_row, + int& my_col ); }//namespace Adelus diff --git a/packages/adelus/src/Adelus_factor.hpp b/packages/adelus/src/Adelus_factor.hpp index 51907824510b..686e7e37f055 100644 --- a/packages/adelus/src/Adelus_factor.hpp +++ b/packages/adelus/src/Adelus_factor.hpp @@ -54,7 +54,6 @@ #include "Adelus_defines.h" #include "Adelus_macros.h" -#include "Adelus_pcomm.hpp" #include "Adelus_mytime.hpp" #include "Kokkos_Core.hpp" @@ -63,22 +62,8 @@ #include "KokkosBlas1_iamax.hpp" #include "KokkosBlas3_gemm.hpp" -extern int myrow; -extern int mycol; -extern int me; // processor id information -extern int nprocs_row; // num of procs to which a row is assigned -extern int nprocs_col; // num of procs to which a col is assigned -extern int nrows_matrix; // number of rows in the matrix -extern int ncols_matrix; // number of cols in the matrix -extern int my_rows; // num of rows I own -extern int my_cols; // num of cols I own -extern int my_rhs; // num of right hand side I own -extern int blksz; // block size for BLAS 3 operations - #define LUSTATUSINT 64 -extern MPI_Comm col_comm; - // Message tags #define LUPIVOTTYPE (1<<13) #define LUCOLTYPE (1<<14) @@ -92,14 +77,21 @@ extern MPI_Comm col_comm; namespace Adelus { -template +template inline -void factor(ZDView& ZV, // matrix and rhs +void factor(HandleType& ahandle, // handle containg metadata + ZDView& ZV, // matrix and rhs ViewType2D& col1_view, // col used for updating a col ViewType2D& row1_view, // diagonal row ViewType1D& row2_view, // pivot row ViewType1D& row3_view, // temporary vector for rows - ViewIntType1D& pivot_vec_view) // vector storing list of pivot rows + ViewIntType1D& pivot_vec_view, // vector storing list of pivot rows + int nrhs, // total num of RHS (note: set to 0 if factoring matrix only) + int my_rhs) // num of RHS I own (note: set to 0 if factoring matrix only) { typedef typename ZDView::value_type value_type; #ifdef PRINT_STATUS @@ -113,6 +105,16 @@ void factor(ZDView& ZV, // matrix and rhs typedef Kokkos::View View1DHostPinnType;//HIPHostPinnedSpace #endif #endif + + MPI_Comm comm = ahandle.get_comm(); + MPI_Comm col_comm = ahandle.get_col_comm(); + int me = ahandle.get_myrank(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int ncols_matrix = ahandle.get_ncols_matrix(); + int my_rows = ahandle.get_my_rows(); + int my_cols = ahandle.get_my_cols(); + int blksz = ahandle.get_blksz(); int j,k; // loop counters @@ -140,6 +142,7 @@ void factor(ZDView& ZV, // matrix and rhs int cur_col_i, cur_col_j, cur_row_i, cur_row_j, act_col_i, act_row_j, update_i, update_j; int sav_col_i, sav_col_j, sav_piv_row_i, sav_piv_row_j, act_piv_row_i, piv_row_i; int cur_col1_row_i, piv_col1_row_i; + int sav_pivot_vec_i; int ringdist,rdist; long type,bytes; @@ -177,13 +180,13 @@ void factor(ZDView& ZV, // matrix and rhs // Distribution for the matrix on me - MPI_Comm_size(MPI_COMM_WORLD,&numprocs); + MPI_Comm_size(comm,&numprocs); if ( (numprocs/nprocs_row) * nprocs_row != numprocs ) { if (me == 0) { printf("nprocs_row must go into numprocs perfectly!\n"); printf("Try a different value of nprocs_row.\n"); } - MPI_Barrier(MPI_COMM_WORLD); + MPI_Barrier(comm); exit(0); } @@ -217,6 +220,8 @@ void factor(ZDView& ZV, // matrix and rhs sav_piv_row_i=0; sav_piv_row_j=0; // location for next row being saved for gemm update update_i=0; update_j=0; // location of remaining local matrix + sav_pivot_vec_i = 0; // location to store name of pivot row + #ifdef GET_TIMING xpivmsgtime=bcastpivstime=bcastpivrtime=bcastcolstime=bcastcolrtime=bcastrowtime=sendrowtime=recvrowtime=0.0; copycoltime=copyrowtime=copyrow1time=copypivrowtime=copypivrow1time=pivotswaptime=0.0; @@ -344,11 +349,15 @@ void factor(ZDView& ZV, // matrix and rhs xpivmsgtime += (MPI_Wtime()-t1); #endif + pivot_vec_view(sav_pivot_vec_i) = pivot.row; gpivot_row = pivot.row; pivot_mag = abs(pivot.entry); if (pivot_mag == 0.0) { - printf("Node %d error -- zero pivot found in column %d -- exiting\n",me,j); - return; + //printf("Node %d error -- zero pivot found in column %d -- exiting\n",me,j); + //return; + std::ostringstream os; + os << "Adelus::factor: rank " << me << " error -- zero pivot found in column "<< j; + Kokkos::Impl::throw_runtime_exception (os.str ()); } // divide everything including the diagonal by the pivot entry @@ -409,7 +418,7 @@ void factor(ZDView& ZV, // matrix and rhs for (rdist = 1;rdist <= MAXDIST;rdist++){ if (rowplus(rdist) == c_owner) break; bytes = sizeof(gpivot_row); - MPI_Send(&gpivot_row,bytes,MPI_BYTE,rowplus(rdist),LUPIVROWTYPE+j,MPI_COMM_WORLD); + MPI_Send(&gpivot_row,bytes,MPI_BYTE,rowplus(rdist),LUPIVROWTYPE+j,comm); } #ifdef GET_TIMING bcastpivstime += (MPI_Wtime()-t1); @@ -432,9 +441,9 @@ void factor(ZDView& ZV, // matrix and rhs if (rowplus(rdist) == c_owner) break; bytes=sizeof(ADELUS_DATA_TYPE)*col_len; #if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined (KOKKOS_ENABLE_HIP)) - MPI_Send(h_coltmp.data(),bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,MPI_COMM_WORLD); + MPI_Send(h_coltmp.data(),bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,comm); #else //GPU-aware MPI - MPI_Send(col1_view.data()+sav_col_j*col1_view.stride(1)+sav_col_i,bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,MPI_COMM_WORLD); + MPI_Send(col1_view.data()+sav_col_j*col1_view.stride(1)+sav_col_i,bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,comm); #endif } #ifdef GET_TIMING @@ -450,6 +459,7 @@ void factor(ZDView& ZV, // matrix and rhs act_row_j++; sav_piv_row_j++; cols_used++; + sav_pivot_vec_i++; } else { @@ -457,10 +467,10 @@ void factor(ZDView& ZV, // matrix and rhs bytes=col_len*sizeof(ADELUS_DATA_TYPE); #if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined (KOKKOS_ENABLE_HIP)) - MPI_Irecv(h_coltmp.data(),bytes,MPI_BYTE,MPI_ANY_SOURCE,LUROWTYPE+j,MPI_COMM_WORLD,&msgrequest); + MPI_Irecv(h_coltmp.data(),bytes,MPI_BYTE,MPI_ANY_SOURCE,LUROWTYPE+j,comm,&msgrequest); #else //GPU-aware MPI MPI_Irecv(col1_view.data()+sav_col_j*col1_view.stride(1)+sav_col_i,bytes,MPI_BYTE, - MPI_ANY_SOURCE,LUROWTYPE+j,MPI_COMM_WORLD,&msgrequest); + MPI_ANY_SOURCE,LUROWTYPE+j,comm,&msgrequest); #endif #ifdef GET_TIMING @@ -469,7 +479,7 @@ void factor(ZDView& ZV, // matrix and rhs bytes = 0; type = LUPIVROWTYPE+j; bytes=4; bytes = sizeof(gpivot_row); - MPI_Recv(&gpivot_row,bytes,MPI_BYTE,MPI_ANY_SOURCE,type,MPI_COMM_WORLD,&msgstatus); + MPI_Recv(&gpivot_row,bytes,MPI_BYTE,MPI_ANY_SOURCE,type,comm,&msgstatus); #ifdef GET_TIMING bcastpivrtime += (MPI_Wtime()-t1); #endif @@ -483,7 +493,7 @@ void factor(ZDView& ZV, // matrix and rhs for (rdist = 1;rdist <= MAXDIST;rdist++) { if (rowplus(rdist) == c_owner) break; bytes = sizeof(gpivot_row); - MPI_Send(&gpivot_row,bytes,MPI_BYTE,rowplus(rdist),LUPIVROWTYPE+j,MPI_COMM_WORLD); + MPI_Send(&gpivot_row,bytes,MPI_BYTE,rowplus(rdist),LUPIVROWTYPE+j,comm); } #ifdef GET_TIMING bcastpivstime += (MPI_Wtime()-t1); @@ -515,9 +525,9 @@ void factor(ZDView& ZV, // matrix and rhs if (rowplus(rdist) == c_owner) break; bytes=col_len*sizeof(ADELUS_DATA_TYPE); #if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined (KOKKOS_ENABLE_HIP)) - MPI_Send(h_coltmp.data(),bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,MPI_COMM_WORLD); + MPI_Send(h_coltmp.data(),bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,comm); #else //GPU-aware MPI - MPI_Send(col1_view.data()+sav_col_j*col1_view.stride(1)+sav_col_i,bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,MPI_COMM_WORLD); + MPI_Send(col1_view.data()+sav_col_j*col1_view.stride(1)+sav_col_i,bytes,MPI_BYTE,rowplus(rdist),LUROWTYPE+j,comm); #endif } #ifdef GET_TIMING @@ -723,9 +733,9 @@ void factor(ZDView& ZV, // matrix and rhs #endif bytes=(row_len+colcnt)*sizeof(ADELUS_DATA_TYPE); #if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined (KOKKOS_ENABLE_HIP)) - MPI_Send(h_row2.data(),bytes,MPI_BYTE,pivot_owner,LUSENDTYPE+j,MPI_COMM_WORLD); + MPI_Send(h_row2.data(),bytes,MPI_BYTE,pivot_owner,LUSENDTYPE+j,comm); #else //GPU-aware MPI - MPI_Send(row2_view.data(),bytes,MPI_BYTE,pivot_owner,LUSENDTYPE+j,MPI_COMM_WORLD); + MPI_Send(row2_view.data(),bytes,MPI_BYTE,pivot_owner,LUSENDTYPE+j,comm); #endif #ifdef GET_TIMING sendrowtime += (MPI_Wtime()-t1); @@ -740,9 +750,9 @@ void factor(ZDView& ZV, // matrix and rhs if (me != r_owner) { bytes=(row_len+colcnt)*sizeof(ADELUS_DATA_TYPE); #if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined (KOKKOS_ENABLE_HIP)) - MPI_Recv(h_row2.data(),bytes,MPI_BYTE,r_owner,LUSENDTYPE+j,MPI_COMM_WORLD,&msgstatus); + MPI_Recv(h_row2.data(),bytes,MPI_BYTE,r_owner,LUSENDTYPE+j,comm,&msgstatus); #else //GPU-aware MPI - MPI_Recv(row2_view.data(),bytes,MPI_BYTE,r_owner,LUSENDTYPE+j,MPI_COMM_WORLD,&msgstatus); + MPI_Recv(row2_view.data(),bytes,MPI_BYTE,r_owner,LUSENDTYPE+j,comm,&msgstatus); #endif } #ifdef GET_TIMING @@ -924,57 +934,57 @@ void factor(ZDView& ZV, // matrix and rhs copytime = pivotswaptime+copycoltime+copyrowtime+copyrow1time+copypivrowtime+copypivrow1time; dgemmtime = updatetime+colupdtime+rowupdtime+scaltime; #ifdef ADELUS_SHOW_TIMING_DETAILS - showtime("Time to do iamax",&iamaxtime); - showtime("Time to get local pivot",&getlocalpivtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do iamax",&iamaxtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to get local pivot",&getlocalpivtime); #endif - showtime("Total finding local pivot time",&localpivtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Total finding local pivot time",&localpivtime); double tmp = 100*localpivtime/totalfactortime; - showtime("Percent finding local pivot time",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Percent finding local pivot time",&tmp); #ifdef ADELUS_SHOW_TIMING_DETAILS - showtime("Time to xchgpivot",&xpivmsgtime); - showtime("Time to do send in bcast pivot",&bcastpivstime); - showtime("Time to do recv in bcast pivot",&bcastpivrtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to xchgpivot",&xpivmsgtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do send in bcast pivot",&bcastpivstime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do recv in bcast pivot",&bcastpivrtime); tmp = bcastpivrtime+bcastpivstime; - showtime("Time to do bcast pivot",&tmp); - showtime("Time to do send in bcast cur col",&bcastcolstime); - showtime("Time to do recv bcast cur col",&bcastcolrtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do bcast pivot",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do send in bcast cur col",&bcastcolstime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do recv bcast cur col",&bcastcolrtime); tmp = bcastcolrtime+bcastcolstime; - showtime("Time to do bcast cur col",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do bcast cur col",&tmp); tmp = bcastcolrtime+bcastcolstime+bcastpivrtime+bcastpivstime; - showtime("Time to do bcast cur col and pivot",&tmp); - showtime("Time to bcast piv row",&bcastrowtime); - showtime("Time to send cur row",&sendrowtime); - showtime("Time to recv cur row",&recvrowtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to do bcast cur col and pivot",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to bcast piv row",&bcastrowtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to send cur row",&sendrowtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to recv cur row",&recvrowtime); #endif - showtime("Total msg passing time",&msgtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Total msg passing time",&msgtime); tmp = 100*msgtime/totalfactortime; - showtime("Percent msg passing time",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Percent msg passing time",&tmp); #if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined (KOKKOS_ENABLE_HIP)) - showtime("Total copy between host pinned mem and dev mem time",©hostpinnedtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Total copy between host pinned mem and dev mem time",©hostpinnedtime); tmp = 100*copyhostpinnedtime/totalfactortime; - showtime("Percent copy between host pinned mem and dev mem time",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Percent copy between host pinned mem and dev mem time",&tmp); #endif #ifdef ADELUS_SHOW_TIMING_DETAILS - showtime("Time to swap pivot",&pivotswaptime); - showtime("Time to copy cur col",©coltime); - showtime("Time to copy cur row to sav row",©rowtime); - showtime("Time to copy piv row to sav piv",©pivrowtime); - showtime("Time to copy sav row to cur row",©row1time); - showtime("Time to copy sav piv to piv row",©pivrow1time); -#endif - showtime("Total copying time",©time); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to swap pivot",&pivotswaptime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to copy cur col",©coltime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to copy cur row to sav row",©rowtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to copy piv row to sav piv",©pivrowtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to copy sav row to cur row",©row1time); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to copy sav piv to piv row",©pivrow1time); +#endif + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Total copying time",©time); tmp = 100*copytime/totalfactortime; - showtime("Percent copying time",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Percent copying time",&tmp); #ifdef ADELUS_SHOW_TIMING_DETAILS - showtime("Time to scale cur col",&scaltime); - showtime("Time to update cur col",&colupdtime); - showtime("Time to update piv row",&rowupdtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to scale cur col",&scaltime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to update cur col",&colupdtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to update piv row",&rowupdtime); #endif - showtime("Time to update matrix",&updatetime); - showtime("Total update time",&dgemmtime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Time to update matrix",&updatetime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Total update time",&dgemmtime); tmp = 100*dgemmtime/totalfactortime; - showtime("Percent update time",&tmp); - showtime("Total time in factor",&totalfactortime); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Percent update time",&tmp); + showtime(ahandle.get_comm_id(),comm,me,numprocs,"Total time in factor",&totalfactortime); #endif } diff --git a/packages/adelus/src/Adelus_forward.hpp b/packages/adelus/src/Adelus_forward.hpp new file mode 100644 index 000000000000..d46b959378a1 --- /dev/null +++ b/packages/adelus/src/Adelus_forward.hpp @@ -0,0 +1,174 @@ +/* +//@HEADER +// ************************************************************************ +// +// Adelus v. 1.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of NTESS nor the names of the contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Vinh Dang (vqdang@sandia.gov) +// Joseph Kotulski (jdkotul@sandia.gov) +// Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __ADELUS_FORWARD_HPP__ +#define __ADELUS_FORWARD_HPP__ + +#include +#include +#include +#include +#include "Adelus_defines.h" +#include "Adelus_macros.h" +#include "Adelus_mytime.hpp" +#include "Kokkos_Core.hpp" +#include "KokkosBlas3_gemm.hpp" + +namespace Adelus { + +template +inline +void forward(HandleType& ahandle, ZViewType& Z, RHSViewType& RHS) +{ + using value_type = typename ZViewType::value_type ; + using execution_space = typename ZViewType::device_type::execution_space ; + using memory_space = typename ZViewType::device_type::memory_space ; + using ViewMatrixType = Kokkos::View; +#ifdef ADELUS_HOST_PINNED_MEM_MPI + #if defined(KOKKOS_ENABLE_CUDA) + using ViewMatrixHostPinnType = Kokkos::View;//CudaHostPinnedSpace + #elif defined(KOKKOS_ENABLE_HIP) + using ViewMatrixHostPinnType = Kokkos::View;//HIPHostPinnedSpace + #endif +#endif + + MPI_Comm row_comm = ahandle.get_row_comm(); + MPI_Comm col_comm = ahandle.get_col_comm(); + int myrow = ahandle.get_myrow(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int nrows_matrix = ahandle.get_nrows_matrix(); + int my_rows = ahandle.get_my_rows(); + + int k_row; // torus-wrap row corresponding to kth global row + int k_col; // torus-wrap column corresponding to kth global col + int istart; // Starting row index for pivot column + int count_row; // dummy index + value_type d_one = static_cast( 1.0); + value_type d_min_one = static_cast(-1.0); + + ViewMatrixType piv_col( "piv_col", my_rows, 1 ); // portion of pivot column I am sending + ViewMatrixType ck( "ck", 1, RHS.extent(1) ); // rhs corresponding to current column of the backsubstitution +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + ViewMatrixHostPinnType h_piv_col( "h_piv_col", my_rows, 1 ); + ViewMatrixHostPinnType h_ck( "h_ck", 1, RHS.extent(1) ); +#endif + +#ifdef PRINT_STATUS + printf("Rank %i -- forward() Begin forward solve with myrow %d, nprocs_row %d, nprocs_col %d, nrows_matrix %d, ncols_matrix %d, my_rows %d, my_cols %d, my_rhs %d, nrhs %d, value_type %s, execution_space %s, memory_space %s\n", ahandle.get_myrank(), myrow, nprocs_row, nprocs_col, nrows_matrix, ahandle.get_ncols_matrix(), my_rows, ahandle.get_my_cols(), ahandle.get_my_rhs(), ahandle.get_nrhs(), typeid(value_type).name(), typeid(execution_space).name(), typeid(memory_space).name()); +#endif + +#ifdef GET_TIMING + double t1, fwdsolvetime; + t1 = MPI_Wtime(); +#endif + + // Perform the Forward Substitution + for (int k=0; k<= nrows_matrix-2; k++) { + k_row=k%nprocs_col; + k_col=k%nprocs_row; + istart = (k+1-myrow)/nprocs_col; + if (istart * nprocs_col < k+1-myrow) istart++; + + if (istart < my_rows) { + Kokkos::deep_copy( subview(piv_col, Kokkos::make_pair(0, my_rows - istart), 0), + subview(Z, Kokkos::make_pair(istart, my_rows), k/nprocs_row) ); + } + count_row = my_rows - istart; + + //Note: replace MPI_Send/MPI_Irecv with MPI_Bcast + // Rank k_col broadcasts the pivot_col to all + // other ranks in the row_comm +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + Kokkos::deep_copy(h_piv_col,piv_col); + MPI_Bcast(reinterpret_cast(h_piv_col.data()), count_row*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_col, row_comm); + Kokkos::deep_copy(piv_col,h_piv_col); +#else //GPU-aware MPI + MPI_Bcast(reinterpret_cast(piv_col.data()), count_row*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_col, row_comm); +#endif + + if (ahandle.get_my_rhs() > 0) { + //ck = RHS(k/nprocs_col,0); + //MPI_Bcast((char *)(&ck),sizeof(ADELUS_DATA_TYPE),MPI_CHAR,k_row,col_comm); + //count_row=0; + //printf("Point 2: k %d, istart %d, my_rows %d\n", k, istart, my_rows); + //for (int i=istart;i<=my_rows-1;i++) { + // RHS(i,0) = RHS(i,0) - piv_col(count_row) * ck; + // count_row++; + //} + int curr_lrid = k/nprocs_col;//note: nprocs_col (global var) cannot be read in a device function + Kokkos::parallel_for(Kokkos::RangePolicy(0,RHS.extent(1)), KOKKOS_LAMBDA (const int i) { + ck(0,i) = RHS(curr_lrid,i); + }); + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + Kokkos::deep_copy(h_ck,ck); + MPI_Bcast(reinterpret_cast(h_ck.data()), RHS.extent(1)*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_row, col_comm); + Kokkos::deep_copy(ck,h_ck); +#else //GPU-aware MPI + Kokkos::fence(); + MPI_Bcast(reinterpret_cast(ck.data()), RHS.extent(1)*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_row, col_comm); +#endif + + auto sub_pivot_col = subview(piv_col, Kokkos::make_pair(0, my_rows - istart), Kokkos::ALL()); + auto sub_rhs = subview(RHS, Kokkos::make_pair(istart, my_rows), Kokkos::ALL()); + if (istart < my_rows) { + KokkosBlas::gemm("N", "N", d_min_one, sub_pivot_col, ck, d_one, sub_rhs); + } + } + MPI_Barrier(ahandle.get_comm()); + }// end of for (k=0; k<= nrows_matrix-2; k++) + +#ifdef GET_TIMING + fwdsolvetime = MPI_Wtime() - t1; + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Total time in forward solve", &fwdsolvetime); +#endif +} + +}//namespace Adelus + +#endif diff --git a/packages/adelus/src/Adelus_macros.h b/packages/adelus/src/Adelus_macros.h index 8a0b89e4c9c9..5acf4dddfb7c 100644 --- a/packages/adelus/src/Adelus_macros.h +++ b/packages/adelus/src/Adelus_macros.h @@ -44,19 +44,22 @@ //@HEADER */ -#define grey_c(P) ((P)^((P)>>1)) +#ifndef __ADELUS_MACROS_H__ +#define __ADELUS_MACROS_H__ + +//#define grey_c(P) ((P)^((P)>>1)) #define lrow_to_grow(R) ( (mesh_row(me) + nprocs_col*(R)) ) #define grow_to_lrow(R) ( (R/nprocs_col) ) -// #define col_owner(C) (((C)%nprocs_row) + (me - me%nprocs_row)) +//// #define col_owner(C) (((C)%nprocs_row) + (me - me%nprocs_row)) #define col_owner(C) ( proc_num(mesh_row(me) , (C)%nprocs_row) ) -// #define row_owner(R) ((((R)%nprocs_col)*nprocs_row) + (me%nprocs_row)) +//// #define row_owner(R) ((((R)%nprocs_col)*nprocs_row) + (me%nprocs_row)) #define row_owner(R) ( proc_num((R)%nprocs_col , mesh_col(me)) ) -#define owner(R, C) ((((R)%nprocs_col)*nprocs_row) + ((C)%nprocs_row)) +//#define owner(R, C) ((((R)%nprocs_col)*nprocs_row) + ((C)%nprocs_row)) #define mesh_col(P) ((P)%nprocs_row) @@ -64,4 +67,6 @@ #define proc_num(R,C) ((R)*nprocs_row + (C)) -#define mac_send_msg(D,B,S,T) MPI_Send(B,S,MPI_CHAR,D,T,MPI_COMM_WORLD) +//#define mac_send_msg(D,B,S,T) MPI_Send(B,S,MPI_CHAR,D,T,MPI_COMM_WORLD) + +#endif diff --git a/packages/adelus/src/Adelus_mytime.hpp b/packages/adelus/src/Adelus_mytime.hpp index eb397a8929a4..30dc4ee9bb37 100644 --- a/packages/adelus/src/Adelus_mytime.hpp +++ b/packages/adelus/src/Adelus_mytime.hpp @@ -64,11 +64,8 @@ double get_seconds(double start) // Exchange and calculate max, min, and average timing information -void showtime(const char *label, double *value) +void showtime(int comm_id, MPI_Comm comm, int me, int nprocs_cube, const char *label, double *value) { - extern int me; // current processor rank - extern int nprocs_cube; - double avgtime; struct { @@ -77,18 +74,18 @@ void showtime(const char *label, double *value) } max_in, max_out, min_in, min_out; max_in.val = *value; max_in.proc = me; - MPI_Allreduce(&max_in,&max_out,1,MPI_DOUBLE_INT,MPI_MAXLOC,MPI_COMM_WORLD); + MPI_Allreduce(&max_in,&max_out,1,MPI_DOUBLE_INT,MPI_MAXLOC,comm); min_in.val = *value; min_in.proc = me; - MPI_Allreduce(&min_in,&min_out,1,MPI_DOUBLE_INT,MPI_MINLOC,MPI_COMM_WORLD); + MPI_Allreduce(&min_in,&min_out,1,MPI_DOUBLE_INT,MPI_MINLOC,comm); - MPI_Allreduce(value,&avgtime,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD); + MPI_Allreduce(value,&avgtime,1,MPI_DOUBLE,MPI_SUM,comm); avgtime /= nprocs_cube; if (me == 0) { - fprintf(stderr, "%s = %.4f (min, on proc %d), %.4f (avg), %.4f (max, on proc %d).\n", - label,min_out.val,min_out.proc,avgtime, max_out.val,max_out.proc); + fprintf(stderr, "Communicator %d -- %s = %.4f (min, on proc %d), %.4f (avg), %.4f (max, on proc %d).\n", + comm_id,label,min_out.val,min_out.proc,avgtime, max_out.val,max_out.proc); } } diff --git a/packages/adelus/src/Adelus_pcomm.cpp b/packages/adelus/src/Adelus_pcomm.cpp deleted file mode 100644 index 973bf96fe1ec..000000000000 --- a/packages/adelus/src/Adelus_pcomm.cpp +++ /dev/null @@ -1,127 +0,0 @@ -/* -//@HEADER -// ************************************************************************ -// -// Adelus v. 1.0 -// Copyright (2020) National Technology & Engineering -// Solutions of Sandia, LLC (NTESS). -// -// Under the terms of Contract DE-NA0003525 with NTESS, -// the U.S. Government retains certain rights in this software. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// 1. Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of NTESS nor the names of the contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED -// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING -// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Questions? Contact Vinh Dang (vqdang@sandia.gov) -// Joseph Kotulski (jdkotul@sandia.gov) -// Siva Rajamanickam (srajama@sandia.gov) -// -// ************************************************************************ -//@HEADER -*/ - -#include -#include -#include -#include "Adelus_defines.h" -#include "Adelus_macros.h" - -#define DEBUG1 0 - -namespace Adelus { - -// define variables to avoid compiler error - -int one = 1; -double d_one = 1.; - -int ringnext,ringprev,hbit,rmbit,my_col_id,my_row_id; -int ringnex2,ringpre2,ringnex3,ringpre3,ringnex4,ringpre4; -//typedef struct { -// DATA_TYPE entry; -// DATA_TYPE current; -// int row; -//} pivot_type; - -void initcomm(){ - extern int nprocs_col, nprocs_row, me, hbit, my_col_id, my_row_id, rmbit; - extern int ringnext,ringprev,ringnex2,ringpre2,ringnex3,ringpre3,ringnex4,ringpre4; - int col_id,bit; - - my_col_id = mesh_col(me); - my_row_id = mesh_row(me); - - - col_id = my_col_id + 1; - if (col_id >= nprocs_row) col_id = 0; - ringnext = proc_num(my_row_id,col_id); - - col_id = my_col_id + 2; - if (col_id >= nprocs_row) col_id -= nprocs_row; - ringnex2 = proc_num(my_row_id,col_id); - - col_id = my_col_id + 3; - if (col_id >= nprocs_row) col_id -= nprocs_row; - ringnex3 = proc_num(my_row_id,col_id); - - col_id = my_col_id + 4; - if (col_id >= nprocs_row) col_id -= nprocs_row; - ringnex4 = proc_num(my_row_id,col_id); - - col_id = my_col_id - 1; - if (col_id < 0) col_id = nprocs_row - 1; - ringprev = proc_num(my_row_id,col_id); - - col_id = my_col_id - 2; - if (col_id < 0) col_id += nprocs_row; - ringpre2 = proc_num(my_row_id,col_id); - - col_id = my_col_id - 3; - if (col_id < 0) col_id += nprocs_row; - ringpre3 = proc_num(my_row_id,col_id); - - col_id = my_col_id - 4; - if (col_id < 0) col_id += nprocs_row; - ringpre4 = proc_num(my_row_id,col_id); - - // calculate first power of two bigger or equal to the number of rows, - // and low order one bit in own name - - for (hbit = 1; nprocs_col > hbit ; hbit = hbit << 1); - - rmbit = 0; - for (bit = 1; bit < hbit; bit = bit << 1) { - if ((my_row_id & bit) == bit) { - rmbit = bit; break;} - } - -#if (DEBUG1 > 0) - printf("In initcomm, node %d: my_col_id = %d, my_row_id = %d, hbit = %d, rmbit = %d, ringnext = %d, ringprev = %d\n",me,my_col_id,my_row_id,hbit,rmbit,ringnext,ringprev); -#endif -} - -}//namespace Adelus diff --git a/packages/adelus/src/Adelus_pcomm.hpp b/packages/adelus/src/Adelus_pcomm.hpp deleted file mode 100644 index 402d8bbcdb81..000000000000 --- a/packages/adelus/src/Adelus_pcomm.hpp +++ /dev/null @@ -1,56 +0,0 @@ -/* -//@HEADER -// ************************************************************************ -// -// Adelus v. 1.0 -// Copyright (2020) National Technology & Engineering -// Solutions of Sandia, LLC (NTESS). -// -// Under the terms of Contract DE-NA0003525 with NTESS, -// the U.S. Government retains certain rights in this software. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// 1. Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of NTESS nor the names of the contributors may be -// used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED -// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING -// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Questions? Contact Vinh Dang (vqdang@sandia.gov) -// Joseph Kotulski (jdkotul@sandia.gov) -// Siva Rajamanickam (srajama@sandia.gov) -// -// ************************************************************************ -//@HEADER -*/ - -#ifndef __ADELUS_PCOMM_HPP__ -#define __ADELUS_PCOMM_HPP__ - -namespace Adelus { - -void initcomm( ); - -}//namespace Adelus - -#endif diff --git a/packages/adelus/src/Adelus_perm1.hpp b/packages/adelus/src/Adelus_perm1.hpp index 99da2e6cf4dd..6a96aaa8aff3 100644 --- a/packages/adelus/src/Adelus_perm1.hpp +++ b/packages/adelus/src/Adelus_perm1.hpp @@ -60,17 +60,6 @@ #define IBM_MPI_WRKAROUND -extern int me; // processor id information -extern int nprocs_row; // num of procs to which a row is assigned -extern int nprocs_col; // num of procs to which a col is assigned -extern int nrows_matrix; // number of rows in the matrix -extern int ncols_matrix; // number of cols in the matrix -extern int my_rows; // num of rows I own -extern int my_cols; // num of cols I own -extern int myrow; -extern int mycol; -extern MPI_Comm col_comm; - namespace Adelus { #ifndef IBM_MPI_WRKAROUND @@ -101,14 +90,23 @@ namespace Adelus { // Permutes -- unwraps the torus-wrap for the solution // using the communication buffer - template + template inline - void perm1_(ZDView& ZV, int *num_my_rhs) { - + void perm1_(HandleType& ahandle, ZDView& ZV) { + + MPI_Comm comm = ahandle.get_comm(); + int me = ahandle.get_myrank(); + int my_rhs_ = ahandle.get_my_rhs(); + int my_rows = ahandle.get_my_rows(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int nrows_matrix = ahandle.get_nrows_matrix(); + int ncols_matrix = ahandle.get_ncols_matrix(); + int my_first_row = ahandle.get_my_first_row(); + int my_first_col = ahandle.get_my_first_col(); + int i; - int my_rhs_; - - + int bytes; int dest; int type; @@ -139,8 +137,6 @@ namespace Adelus { #ifdef GET_TIMING t2 = MPI_Wtime(); #endif - - my_rhs_=*num_my_rhs; typedef typename ZDView::value_type value_type; typedef typename ZDView::device_type::execution_space execution_space; @@ -217,19 +213,19 @@ namespace Adelus { } - if( dest !=me ) { + if( dest != me ) { bytes = (my_rhs_ + 1)*sizeof(ADELUS_DATA_TYPE); MPI_Irecv( (char *)(reinterpret_cast(rhs_temp.data())+next_s),bytes,MPI_CHAR,MPI_ANY_SOURCE, - MPI_ANY_TAG,MPI_COMM_WORLD,&msgrequest); + MPI_ANY_TAG,comm,&msgrequest); auto sub_ZV = subview(ZV, ptr1_idx, Kokkos::ALL()); zcopy_wr_local_index(my_rhs_, sub_ZV, temp_s, local_index); type = PERMTYPE+change_send; MPI_Send((char *)(reinterpret_cast(temp_s.data())),bytes,MPI_CHAR,dest, - type,MPI_COMM_WORLD); + type,comm); change_send++; next_s = change_send * (my_rhs_+1); @@ -270,7 +266,7 @@ namespace Adelus { totalpermtime = MPI_Wtime() - t2; #endif #ifdef GET_TIMING - showtime("Total time in perm",&totalpermtime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Total time in perm", &totalpermtime); #endif } @@ -286,15 +282,24 @@ namespace Adelus { // Permutes -- unwraps the torus-wrap for the solution // using the communication buffer - template + template inline - void perm1_(ZDView& ZV, int *num_my_rhs) { - + void perm1_(HandleType& ahandle, ZDView& ZV) { + + MPI_Comm col_comm = ahandle.get_col_comm(); + int myrow = ahandle.get_myrow(); + int my_rhs_ = ahandle.get_my_rhs(); + int my_rows = ahandle.get_my_rows(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int nrows_matrix = ahandle.get_nrows_matrix(); + int ncols_matrix = ahandle.get_ncols_matrix(); + int my_first_row = ahandle.get_my_first_row(); + int i; - int my_rhs_; int dest, global_index, local_index; - + int row_offset; int ncols_proc1, ncols_proc2, nprocs_row1; int ptr1_idx, myfirstrow; @@ -310,8 +315,6 @@ namespace Adelus { #ifdef GET_TIMING t2 = MPI_Wtime(); #endif - - my_rhs_=*num_my_rhs; typedef typename ZDView::value_type value_type; #ifdef PRINT_STATUS @@ -346,11 +349,11 @@ namespace Adelus { myfirstrow = myrow * (nrows_matrix / nprocs_col) + 1; myfirstrow = ( myrow > (nrows_matrix%nprocs_col) ) ? myfirstrow + (nrows_matrix%nprocs_col) : myfirstrow + myrow; - + ptr1_idx = 0; #ifdef PRINT_STATUS - printf("Rank %i -- perm1_() Begin permutation, execution_space %s, memory_space %s\n",me,typeid(execution_space).name(),typeid(memory_space).name()); + printf("Rank %i -- perm1_() Begin permutation, execution_space %s, memory_space %s\n",ahandle.get_myrank(),typeid(execution_space).name(),typeid(memory_space).name()); #endif for (i=0; i host pinned mem",©hostpinnedtime); + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Time to copy dev mem --> host pinned mem", ©hostpinnedtime); #endif - showtime("Total time in perm",&totalpermtime); + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Total time in perm", &totalpermtime); #endif } diff --git a/packages/adelus/src/Adelus_perm_mat.hpp b/packages/adelus/src/Adelus_perm_mat.hpp new file mode 100644 index 000000000000..1380f7bfff29 --- /dev/null +++ b/packages/adelus/src/Adelus_perm_mat.hpp @@ -0,0 +1,297 @@ +/* +//@HEADER +// ************************************************************************ +// +// Adelus v. 1.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of NTESS nor the names of the contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Vinh Dang (vqdang@sandia.gov) +// Joseph Kotulski (jdkotul@sandia.gov) +// Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __ADELUS_PERMMAT_HPP__ +#define __ADELUS_PERMMAT_HPP__ + +#include +#include +#include +#include +#include "Adelus_defines.h" +#include "Adelus_macros.h" +#include "Adelus_mytime.hpp" +#include "Kokkos_Core.hpp" + +namespace Adelus { + + template + inline + void exchange_pivots(HandleType& ahandle, PViewType& lpiv_view, PViewType& permute) { + + MPI_Comm comm = ahandle.get_comm(); + MPI_Comm row_comm = ahandle.get_row_comm(); + int me = ahandle.get_myrank(); + int myrow = ahandle.get_myrow(); + int mycol = ahandle.get_mycol(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int nrows_matrix = ahandle.get_nrows_matrix(); + int my_rows = ahandle.get_my_rows(); + + MPI_Status msgstatus; + int rank_row,k_row,pivot_col; + + // First gather the permutation vector to processor 0 in row_comm + if (myrow == 0 || mycol == 0) { + for (int k=0;k<=nrows_matrix-1;k++) { + pivot_col = k%nprocs_row; + k_row = k%nprocs_col; + rank_row = k_row*nprocs_row; + if (rank_row == pivot_col) {//on the same rank + if (me == rank_row) {//I am the right process to do + int j=k/nprocs_row; + int i=k/nprocs_col; + permute(i) = lpiv_view(j); + } + } + else {//on different ranks + if (me == pivot_col) { + int j=k/nprocs_row; + MPI_Send(lpiv_view.data()+j,1,MPI_INT,rank_row,0,comm); + } + if (me == rank_row) { + int i=k/nprocs_col; + MPI_Recv(permute.data()+i,1,MPI_INT,pivot_col,0,comm,&msgstatus); + } + } + } + } + MPI_Barrier(comm); + // Broadcast to the rest of the processors in row_comm + MPI_Bcast(permute.data(),my_rows,MPI_INT,0,row_comm); + + }// End of function exchange_pivots + + template + inline + void permute_mat(HandleType& ahandle, ZViewType& Z, PViewType& lpiv_view, PViewType& permute) { + using value_type = typename ZViewType::value_type; +#ifndef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + using execution_space = typename ZViewType::device_type::execution_space ; + using memory_space = typename ZViewType::device_type::memory_space ; + using ViewVectorType = Kokkos::View; +#ifdef ADELUS_HOST_PINNED_MEM_MPI + #if defined(KOKKOS_ENABLE_CUDA) + using ViewVectorHostPinnType = Kokkos::View;//CudaHostPinnedSpace + #elif defined(KOKKOS_ENABLE_HIP) + using ViewVectorHostPinnType = Kokkos::View;//HIPHostPinnedSpace + #endif +#endif + MPI_Comm col_comm = ahandle.get_col_comm(); + int myrow = ahandle.get_myrow(); + int mycol = ahandle.get_mycol(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int nrows_matrix = ahandle.get_nrows_matrix(); + +#ifdef PRINT_STATUS + printf("Rank %i -- permute_mat() Begin permute mat with myrow %d, mycol %d, nprocs_row %d, nprocs_col %d, nrows_matrix %d, ncols_matrix %d, my_rows %d, my_cols %d, my_rhs %d, nrhs %d, value_type %s, execution_space %s, memory_space %s\n", ahandle.get_myrank(), myrow, mycol, nprocs_row, nprocs_col, nrows_matrix, ahandle.get_ncols_matrix(), ahandle.get_my_rows(), ahandle.get_my_cols(), ahandle.get_my_rhs(), ahandle.get_nrhs(), typeid(value_type).name(), typeid(execution_space).name(), typeid(memory_space).name()); +#endif +#endif + + MPI_Status msgstatus; + + int pivot_row, k_row; +#ifdef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + value_type tmpr, tmps; +#else + ViewVectorType tmpr( "tmpr", Z.extent(1) ); + ViewVectorType tmps( "tmps", Z.extent(1) ); +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + ViewVectorHostPinnType h_tmpr( "h_tmpr", Z.extent(1) ); + ViewVectorHostPinnType h_tmps( "h_tmps", Z.extent(1) ); +#endif +#endif + +#ifdef GET_TIMING + double exchpivtime,permutemattime,t1; + + t1 = MPI_Wtime(); +#endif + + exchange_pivots(ahandle, lpiv_view, permute); + +#ifdef GET_TIMING + exchpivtime = MPI_Wtime()-t1; + + t1 = MPI_Wtime(); +#endif + +#ifdef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + for (int j=0;j<=ahandle.get_my_cols()-1;j++) { + int J=j*nprocs_row+mycol; // global column index + for (int k=J+1;k<=nrows_matrix-1;k++) { + k_row=k%nprocs_col; + if (myrow==k_row) + pivot_row=permute(k/nprocs_col); + MPI_Bcast(&pivot_row,1,MPI_INT,k_row,col_comm); + if (k != pivot_row) { + if (myrow == k_row) { + tmps = Z(k/nprocs_col, J/nprocs_row); + MPI_Send((char *)(&tmps),sizeof(value_type),MPI_CHAR,pivot_row%nprocs_col,2,col_comm); + } + if (myrow == pivot_row%nprocs_col) { + tmps = Z(pivot_row/nprocs_col, J/nprocs_row); + MPI_Send((char *)(&tmps),sizeof(value_type),MPI_CHAR,k_row,3,col_comm); + } + if (myrow == k_row) { + MPI_Recv((char *)(&tmpr),sizeof(value_type),MPI_CHAR,pivot_row%nprocs_col,3,col_comm,&msgstatus); + Z(k/nprocs_col, J/nprocs_row) = tmpr; + } + if (myrow == pivot_row%nprocs_col) { + MPI_Recv((char *)(&tmpr),sizeof(value_type),MPI_CHAR,k_row,2,col_comm,&msgstatus); + Z(pivot_row/nprocs_col, J/nprocs_row) = tmpr; + } + }// End of if (k != pivot_row) + }// End of for (k=J+1;k<=nrows_matrix-1;k++) + }// End of for (j=0;j<=my_cols-1;j++) +#else + for (int k = 1 + mycol; k <= nrows_matrix - 1; k++) { + int max_gcol_k=k-1; // max. global column index in the k row + int max_lcol_k=0; // max. local column index in the k row + k_row=k%nprocs_col; // mesh row id (in the MPI process mesh) of the process that holds k + + if (myrow==k_row) pivot_row = permute(k/nprocs_col); + MPI_Bcast(&pivot_row,1,MPI_INT,k_row,col_comm); + + int max_gcol_pivot=pivot_row-1; // max. global column index in the pivot row + int max_lcol_pivot=0; // max. local column index in the pivot row + int pivot_row_pid = pivot_row%nprocs_col;// mesh row id (in the MPI process mesh) of the process that holds pivot_row + + //Find max. local column index in the k row that covers the lower triangular part + if ( mycol <= max_gcol_k%nprocs_row) + max_lcol_k = max_gcol_k/nprocs_row; + else + max_lcol_k = max_gcol_k/nprocs_row - 1;//one element less + + //Find max. local column index in the pivot row that covers the lower triangular part + if ( mycol <= max_gcol_pivot%nprocs_row) + max_lcol_pivot = max_gcol_pivot/nprocs_row; + else + max_lcol_pivot = max_gcol_pivot/nprocs_row - 1;//one element less + + //Find the number of columns needs to be exchanged + int min_len = std::min(max_lcol_k,max_lcol_pivot) + 1; + + if (k != pivot_row) {//k row is differrent from pivot_row, i.e. needs permutation + if (k_row == pivot_row_pid) {//pivot row is in the same rank + if (myrow == k_row) {//I am the right process to do permutation + int curr_lrid = k/nprocs_col; + int piv_lrid = pivot_row/nprocs_col; + Kokkos::parallel_for(Kokkos::RangePolicy(0,min_len), KOKKOS_LAMBDA (const int i) { + value_type tmp = Z(curr_lrid,i); + Z(curr_lrid,i) = Z(piv_lrid,i); + Z(piv_lrid,i) = tmp; + }); + Kokkos::fence(); + } + } + else {//k row and pivot row are in different processes (rank) + if (myrow == k_row) {//I am holding k row + int curr_lrid = k/nprocs_col; + Kokkos::parallel_for(Kokkos::RangePolicy(0,min_len), KOKKOS_LAMBDA (const int i) { + tmps(i) = Z(curr_lrid,i); + }); + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + Kokkos::deep_copy(h_tmps,tmps); + MPI_Send(reinterpret_cast(h_tmps.data()),min_len*sizeof(value_type),MPI_CHAR,pivot_row_pid,2,col_comm); + MPI_Recv(reinterpret_cast(h_tmpr.data()),min_len*sizeof(value_type),MPI_CHAR,pivot_row_pid,3,col_comm,&msgstatus); + Kokkos::deep_copy(tmpr,h_tmpr); +#else //GPU-aware MPI + Kokkos::fence(); + + MPI_Send(reinterpret_cast(tmps.data()),min_len*sizeof(value_type),MPI_CHAR,pivot_row_pid,2,col_comm); + MPI_Recv(reinterpret_cast(tmpr.data()),min_len*sizeof(value_type),MPI_CHAR,pivot_row_pid,3,col_comm,&msgstatus); +#endif + + Kokkos::parallel_for(Kokkos::RangePolicy(0,min_len), KOKKOS_LAMBDA (const int i) { + Z(curr_lrid,i) = tmpr(i); + }); + Kokkos::fence(); + } + if (myrow == pivot_row_pid) {//I am holding the pivot row + int piv_lrid = pivot_row/nprocs_col; + Kokkos::parallel_for(Kokkos::RangePolicy(0,min_len), KOKKOS_LAMBDA (const int i) { + tmps(i) = Z(piv_lrid,i); + }); + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + Kokkos::deep_copy(h_tmps,tmps); + MPI_Recv(reinterpret_cast(h_tmpr.data()),min_len*sizeof(value_type),MPI_CHAR,k_row,2,col_comm,&msgstatus); + MPI_Send(reinterpret_cast(h_tmps.data()),min_len*sizeof(value_type),MPI_CHAR,k_row,3,col_comm); + Kokkos::deep_copy(tmpr,h_tmpr); +#else // GPU-aware MPI + Kokkos::fence(); + + MPI_Recv(reinterpret_cast(tmpr.data()),min_len*sizeof(value_type),MPI_CHAR,k_row,2,col_comm,&msgstatus); + MPI_Send(reinterpret_cast(tmps.data()),min_len*sizeof(value_type),MPI_CHAR,k_row,3,col_comm); +#endif + + Kokkos::parallel_for(Kokkos::RangePolicy(0,min_len), KOKKOS_LAMBDA (const int i) { + Z(piv_lrid,i) = tmpr(i); + }); + Kokkos::fence(); + } + }//End of k row and pivot row are in different processes (rank) + }// End of if (k != pivot_row) + }// End of for (int k=1+mycol;k<=nrows_matrix-1;k++) { +#endif + +#ifdef GET_TIMING + permutemattime = MPI_Wtime()-t1; + + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Time to exchange pivot information", &exchpivtime); + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Time to permute matrix", &permutemattime); +#endif + }// End of function permute_mat + +}//namespace Adelus + +#endif diff --git a/packages/adelus/src/Adelus_perm_rhs.hpp b/packages/adelus/src/Adelus_perm_rhs.hpp new file mode 100644 index 000000000000..9d2000268df3 --- /dev/null +++ b/packages/adelus/src/Adelus_perm_rhs.hpp @@ -0,0 +1,191 @@ +/* +//@HEADER +// ************************************************************************ +// +// Adelus v. 1.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of NTESS nor the names of the contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Vinh Dang (vqdang@sandia.gov) +// Joseph Kotulski (jdkotul@sandia.gov) +// Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __ADELUS_PERMRHS_HPP__ +#define __ADELUS_PERMRHS_HPP__ + +#include +#include +#include +#include +#include "Adelus_defines.h" +#include "Adelus_macros.h" +#include "Adelus_mytime.hpp" +#include "Kokkos_Core.hpp" + +namespace Adelus { + + template + inline + void permute_rhs(HandleType& ahandle, ZViewType& RHS, PViewType& permute) { + using value_type = typename ZViewType::value_type; + using execution_space = typename ZViewType::device_type::execution_space ; + using memory_space = typename ZViewType::device_type::memory_space ; + using ViewVectorType = Kokkos::View; +#ifdef ADELUS_HOST_PINNED_MEM_MPI + #if defined(KOKKOS_ENABLE_CUDA) + using ViewVectorHostPinnType = Kokkos::View;//CudaHostPinnedSpace + #elif defined(KOKKOS_ENABLE_HIP) + using ViewVectorHostPinnType = Kokkos::View;//HIPHostPinnedSpace + #endif +#endif + + MPI_Comm col_comm = ahandle.get_col_comm(); + int myrow = ahandle.get_myrow(); + int nprocs_col = ahandle.get_nprocs_col(); + int nrows_matrix = ahandle.get_nrows_matrix(); + + int pivot_row, k_row; + ViewVectorType tmpr( "tmpr", RHS.extent(1) ); + ViewVectorType tmps( "tmps", RHS.extent(1) ); +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + ViewVectorHostPinnType h_tmpr( "h_tmpr", RHS.extent(1) ); + ViewVectorHostPinnType h_tmps( "h_tmps", RHS.extent(1) ); +#endif + + MPI_Status msgstatus; + + //TODO: try this later + //MPI_Datatype strided_vec_type; + //int strided_vec_nblocks = RHS.extent(1); + //int strided_vec_blocklen = 1; + //int strided_vec_stride = RHS.extent(0); + //MPI_Type_vector( strided_vec_nblocks, strided_vec_blocklen, strided_vec_stride, + // ADELUS_MPI_DATA_TYPE, &strided_vec_type); + //MPI_Type_commit(&strided_vec_type); + +#ifdef GET_TIMING + double permuterhstime,t1; + + t1 = MPI_Wtime(); +#endif + + for (int k=0;k<=nrows_matrix-2;k++) { + k_row=k%nprocs_col; + + if (ahandle.get_my_rhs() > 0) { + if (myrow==k_row) pivot_row = permute(k/nprocs_col); + MPI_Bcast(&pivot_row,1,MPI_INT,k_row,col_comm); + int pivot_row_pid = pivot_row%nprocs_col; + + if (k != pivot_row) { + if (k_row == pivot_row_pid) {//pivot row is in the same rank + if (myrow == k_row) { + int curr_lrid = k/nprocs_col; + int piv_lrid = pivot_row/nprocs_col; + Kokkos::parallel_for(Kokkos::RangePolicy(0,RHS.extent(1)), KOKKOS_LAMBDA (const int i) { + value_type tmp = RHS(curr_lrid,i); + RHS(curr_lrid,i) = RHS(piv_lrid,i); + RHS(piv_lrid,i) = tmp; + }); + Kokkos::fence(); + } + } + else {//pivot row is is a different rank + if (myrow == k_row) { + int curr_lrid = k/nprocs_col; + Kokkos::parallel_for(Kokkos::RangePolicy(0,RHS.extent(1)), KOKKOS_LAMBDA (const int i) { + tmps(i) = RHS(curr_lrid,i); + }); + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + Kokkos::deep_copy(h_tmps,tmps); + MPI_Send(reinterpret_cast(h_tmps.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,pivot_row_pid,2,col_comm); + MPI_Recv(reinterpret_cast(h_tmpr.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,pivot_row_pid,3,col_comm,&msgstatus); + Kokkos::deep_copy(tmpr,h_tmpr); +#else //GPU-aware MPI + Kokkos::fence(); + + MPI_Send(reinterpret_cast(tmps.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,pivot_row_pid,2,col_comm); + MPI_Recv(reinterpret_cast(tmpr.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,pivot_row_pid,3,col_comm,&msgstatus); +#endif + + Kokkos::parallel_for(Kokkos::RangePolicy(0,RHS.extent(1)), KOKKOS_LAMBDA (const int i) { + RHS(curr_lrid,i) = tmpr(i); + }); + Kokkos::fence(); + } + if (myrow == pivot_row_pid) { + int piv_lrid = pivot_row/nprocs_col; + Kokkos::parallel_for(Kokkos::RangePolicy(0,RHS.extent(1)), KOKKOS_LAMBDA (const int i) { + tmps(i) = RHS(piv_lrid,i); + }); + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + Kokkos::deep_copy(h_tmps,tmps); + MPI_Recv(reinterpret_cast(h_tmpr.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,k_row,2,col_comm,&msgstatus); + MPI_Send(reinterpret_cast(h_tmps.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,k_row,3,col_comm); + Kokkos::deep_copy(tmpr,h_tmpr); +#else // GPU-aware MPI + Kokkos::fence(); + + MPI_Recv(reinterpret_cast(tmpr.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,k_row,2,col_comm,&msgstatus); + MPI_Send(reinterpret_cast(tmps.data()),RHS.extent(1)*sizeof(value_type),MPI_CHAR,k_row,3,col_comm); +#endif + + Kokkos::parallel_for(Kokkos::RangePolicy(0,RHS.extent(1)), KOKKOS_LAMBDA (const int i) { + RHS(piv_lrid,i) = tmpr(i); + }); + Kokkos::fence(); + } + }//End of pivot row is is a different rank + }// End of if (k != pivot_row) + + }// End of if (my_num_rhs > 0) + + }// End of for (k=0;k<=nrows_matrix-2;k++) + +#ifdef GET_TIMING + permuterhstime = MPI_Wtime()-t1; + + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Time to permute rhs", &permuterhstime); +#endif + }// End of function permute_rhs + +}//namespace Adelus + +#endif diff --git a/packages/adelus/src/Adelus_solve.hpp b/packages/adelus/src/Adelus_solve.hpp index b23ccdcaec5d..fe20098ef03e 100644 --- a/packages/adelus/src/Adelus_solve.hpp +++ b/packages/adelus/src/Adelus_solve.hpp @@ -53,32 +53,12 @@ #include #include "Adelus_defines.h" #include "Adelus_macros.h" -#include "Adelus_pcomm.hpp" #include "Adelus_mytime.hpp" #include "Kokkos_Core.hpp" #include "KokkosBlas3_gemm.hpp" #define IBM_MPI_WRKAROUND2 -extern int me; - -extern int ncols_matrix; // number of cols in the matrix - -extern int nprocs_col; // num of procs to which a col is assigned -extern int nprocs_row; // num of procs to which a row is assigned - -extern int my_first_col; // proc position in a col -extern int my_first_row; // proc position in a row - -extern int my_rows; // num of rows I own -extern int my_cols; // num of cols I own - -extern int nrhs; // number of right hand sides -extern int my_rhs; // number of right hand sides that I own - -extern MPI_Comm col_comm; - - #define SOSTATUSINT 32768 // Message tags @@ -89,10 +69,10 @@ extern MPI_Comm col_comm; namespace Adelus { // Customized elimination on the rhs that I own -template -void elimination_rhs(int N, ZDView& ptr3, ZDView& ptr2, RView& ptr4, int act_col) { +template +void elimination_rhs(int N, ZView& ptr2, RHSView& ptr3, DView& ptr4, int act_col) { #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) - Kokkos::parallel_for(Kokkos::RangePolicy(0,N), KOKKOS_LAMBDA (const int i) { + Kokkos::parallel_for(Kokkos::RangePolicy(0,N), KOKKOS_LAMBDA (const int i) { ptr4(0,i) = ptr3(i)/ptr2(act_col); ptr3(i) = ptr4(0,i); }); @@ -104,25 +84,38 @@ void elimination_rhs(int N, ZDView& ptr3, ZDView& ptr2, RView& ptr4, int act_col #endif } -template +template inline -void back_solve6(ZDView& ZV) +void back_solve_rhs_pipelined_comm(HandleType& ahandle, ZViewType& Z, RHSViewType& RHS) { - typedef typename ZDView::value_type value_type; + using value_type = typename ZViewType::value_type; #ifdef PRINT_STATUS - typedef typename ZDView::device_type::execution_space execution_space; + using execution_space = typename ZViewType::device_type::execution_space; #endif - typedef typename ZDView::device_type::memory_space memory_space; - typedef Kokkos::View ViewMatrixType; + using memory_space = typename ZViewType::device_type::memory_space; + using View2DType = Kokkos::View; #if defined(ADELUS_HOST_PINNED_MEM_MPI) || defined(IBM_MPI_WRKAROUND2) #if defined(KOKKOS_ENABLE_CUDA) - typedef Kokkos::View View2DHostPinnType;//CudaHostPinnedSpace + using View2DHostPinnType = Kokkos::View;//CudaHostPinnedSpace #elif defined(KOKKOS_ENABLE_HIP) - typedef Kokkos::View View2DHostPinnType;//HIPHostPinnedSpace + using View2DHostPinnType = Kokkos::View;//HIPHostPinnedSpace #endif #endif + MPI_Comm comm = ahandle.get_comm(); + MPI_Comm col_comm = ahandle.get_col_comm(); + int me = ahandle.get_myrank(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int ncols_matrix = ahandle.get_ncols_matrix(); + int my_rows = ahandle.get_my_rows(); + int my_cols = ahandle.get_my_cols(); + int my_first_row = ahandle.get_my_first_row(); + int my_first_col = ahandle.get_my_first_col(); + int nrhs = ahandle.get_nrhs(); + int my_rhs = ahandle.get_my_rhs(); + int j; // loop counters int end_row; // row num to end column operations int bytes[16]; // number of bytes in messages @@ -192,11 +185,11 @@ void back_solve6(ZDView& ZV) t1 = MPI_Wtime(); #endif - ViewMatrixType row1( "row1", one, nrhs ); // row1: diagonal row (temp variables) + View2DType row1( "row1", one, nrhs ); // row1: diagonal row (temp variables) #if (defined(ADELUS_HOST_PINNED_MEM_MPI) || defined(IBM_MPI_WRKAROUND2)) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) View2DHostPinnType h_row2( "h_row2", my_rows, max_bytes/sizeof(ADELUS_DATA_TYPE)/my_rows ); #else - ViewMatrixType row2( "row2", my_rows, max_bytes/sizeof(ADELUS_DATA_TYPE)/my_rows ); + View2DType row2( "row2", my_rows, max_bytes/sizeof(ADELUS_DATA_TYPE)/my_rows ); #endif #if (defined(ADELUS_HOST_PINNED_MEM_MPI) || defined(IBM_MPI_WRKAROUND2)) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) View2DHostPinnType h_row1( "h_row1", one, nrhs ); @@ -238,7 +231,7 @@ void back_solve6(ZDView& ZV) // do an elimination step on the rhs that I own - //auto ptr2_view = subview(ZV, end_row-1, Kokkos::ALL()); + //auto ptr2_view = subview(Z, end_row-1, Kokkos::ALL()); root = row_owner(global_col); @@ -246,9 +239,9 @@ void back_solve6(ZDView& ZV) #ifdef GET_TIMING t1 = MPI_Wtime(); #endif - auto ptr2_view = subview(ZV, end_row-1, Kokkos::ALL()); - auto ptr3_view = subview(ZV, end_row-1, Kokkos::make_pair(my_cols, my_cols+n_rhs_this)); - elimination_rhs(n_rhs_this, ptr3_view, ptr2_view, row1, act_col);//note: row1 = ptr4 + auto ptr2_view = subview(Z, end_row-1, Kokkos::ALL()); + auto ptr3_view = subview(RHS, end_row-1, Kokkos::make_pair(0, n_rhs_this)); + elimination_rhs(n_rhs_this, ptr2_view, ptr3_view, row1, act_col);//note: row1 = ptr4 end_row--; #ifdef GET_TIMING eliminaterhstime += (MPI_Wtime()-t1); @@ -298,8 +291,8 @@ void back_solve6(ZDView& ZV) t1 = MPI_Wtime(); #endif - auto A_view = subview(ZV, Kokkos::make_pair(0, end_row), Kokkos::make_pair(act_col, act_col+one)); - auto C_view = subview(ZV, Kokkos::make_pair(0, end_row), Kokkos::make_pair(my_cols, my_cols+n_rhs_this)); + auto A_view = subview(Z, Kokkos::make_pair(0, end_row), Kokkos::make_pair(act_col, act_col+one)); + auto C_view = subview(RHS, Kokkos::make_pair(0, end_row), Kokkos::make_pair(0, n_rhs_this)); auto B_view = subview(row1, Kokkos::ALL(), Kokkos::make_pair(0, n_rhs_this)); KokkosBlas::gemm("N","N",d_min_one, @@ -324,15 +317,16 @@ void back_solve6(ZDView& ZV) type[0] = SOROWTYPE+j; #if (defined(ADELUS_HOST_PINNED_MEM_MPI) || defined(IBM_MPI_WRKAROUND2)) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) - MPI_Irecv(reinterpret_cast(h_row2.data()), bytes[0], MPI_CHAR, MPI_ANY_SOURCE, type[0], MPI_COMM_WORLD, &msgrequest); + MPI_Irecv(reinterpret_cast(h_row2.data()), bytes[0], MPI_CHAR, MPI_ANY_SOURCE, type[0], comm, &msgrequest); #else - MPI_Irecv(reinterpret_cast( row2.data()), bytes[0], MPI_CHAR, MPI_ANY_SOURCE, type[0], MPI_COMM_WORLD, &msgrequest); + MPI_Irecv(reinterpret_cast( row2.data()), bytes[0], MPI_CHAR, MPI_ANY_SOURCE, type[0], comm, &msgrequest); #endif n_rhs_this = bytes[0]/sizeof(ADELUS_DATA_TYPE)/my_rows; #if (defined(ADELUS_HOST_PINNED_MEM_MPI) || defined(IBM_MPI_WRKAROUND2)) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) - Kokkos::deep_copy(subview(h_rhs, Kokkos::ALL(), Kokkos::make_pair(0, n_rhs_this)), subview(ZV, Kokkos::ALL(), Kokkos::make_pair(my_cols, my_cols+n_rhs_this))); + Kokkos::deep_copy(subview(h_rhs, Kokkos::ALL(), Kokkos::make_pair(0, n_rhs_this)), + subview(RHS, Kokkos::ALL(), Kokkos::make_pair(0, n_rhs_this))); #endif dest[1] = dest_left; @@ -340,9 +334,9 @@ void back_solve6(ZDView& ZV) type[1] = SOROWTYPE+j; #if (defined(ADELUS_HOST_PINNED_MEM_MPI) || defined(IBM_MPI_WRKAROUND2)) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) - MPI_Send(reinterpret_cast(h_rhs.data()), bytes[1], MPI_CHAR, dest[1], type[1], MPI_COMM_WORLD); + MPI_Send(reinterpret_cast(h_rhs.data()), bytes[1], MPI_CHAR, dest[1], type[1], comm); #else //GPU-aware MPI - MPI_Send(reinterpret_cast(ZV.data()+my_rows*my_cols), bytes[1], MPI_CHAR, dest[1], type[1], MPI_COMM_WORLD); + MPI_Send(reinterpret_cast(RHS.data()), bytes[1], MPI_CHAR, dest[1], type[1], comm); #endif MPI_Wait(&msgrequest,&msgstatus); @@ -351,18 +345,18 @@ void back_solve6(ZDView& ZV) int blas_length = n_rhs_this*my_rows; #if (defined(ADELUS_HOST_PINNED_MEM_MPI) || defined(IBM_MPI_WRKAROUND2)) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) //Use memcpy for now, can use deep_copy in the future //deep_copy is slower than BLAS XCOPY #if defined(KOKKOS_ENABLE_CUDA) - //Kokkos::deep_copy(subview(ZV, Kokkos::ALL(), Kokkos::make_pair(my_cols, my_cols+n_rhs_this)), subview(h_row2, Kokkos::ALL(), Kokkos::make_pair(0, n_rhs_this))); - cudaMemcpy(reinterpret_cast(ZV.data()+my_rows*my_cols), reinterpret_cast(h_row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), cudaMemcpyHostToDevice); + //Kokkos::deep_copy(subview(RHS, Kokkos::ALL(), Kokkos::make_pair(0, n_rhs_this)), subview(h_row2, Kokkos::ALL(), Kokkos::make_pair(0, n_rhs_this))); + cudaMemcpy(reinterpret_cast(RHS.data()), reinterpret_cast(h_row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), cudaMemcpyHostToDevice); #elif defined(KOKKOS_ENABLE_HIP) - hipMemcpy(reinterpret_cast(ZV.data()+my_rows*my_cols), reinterpret_cast(h_row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), hipMemcpyHostToDevice); + hipMemcpy(reinterpret_cast(RHS.data()), reinterpret_cast(h_row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), hipMemcpyHostToDevice); #endif #else #if defined(KOKKOS_ENABLE_CUDA) - cudaMemcpy(reinterpret_cast(ZV.data()+my_rows*my_cols), reinterpret_cast(row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), cudaMemcpyDeviceToDevice); + cudaMemcpy(reinterpret_cast(RHS.data()), reinterpret_cast(row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), cudaMemcpyDeviceToDevice); #elif defined(KOKKOS_ENABLE_HIP) - hipMemcpy(reinterpret_cast(ZV.data()+my_rows*my_cols), reinterpret_cast(row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), hipMemcpyDeviceToDevice); + hipMemcpy(reinterpret_cast(RHS.data()), reinterpret_cast(row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE), hipMemcpyDeviceToDevice); #else - memcpy(reinterpret_cast(ZV.data()+my_rows*my_cols), reinterpret_cast(row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE)); + memcpy(reinterpret_cast(RHS.data()), reinterpret_cast(row2.data()), blas_length*sizeof(ADELUS_DATA_TYPE)); #endif #endif } @@ -382,15 +376,250 @@ void back_solve6(ZDView& ZV) totalsolvetime = MPI_Wtime() - t2; #endif #ifdef GET_TIMING - showtime("Time to alloc view",&allocviewtime); - showtime("Time to eliminate rhs",&eliminaterhstime); - showtime("Time to bcast temp row",&bcastrowtime); - showtime("Time to update rhs",&updrhstime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to alloc view", &allocviewtime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to eliminate rhs",&eliminaterhstime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to bcast temp row",&bcastrowtime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to update rhs",&updrhstime); +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to copy host pinned mem <--> dev mem",©hostpinnedtime); +#endif + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to xchg rhs",&xchgrhstime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Total time in solve",&totalsolvetime); +#endif +} + +template +inline +void back_solve_currcol_bcast(HandleType& ahandle, ZViewType& Z, RHSViewType& RHS) +{ + using value_type = typename ZViewType::value_type; +#ifdef PRINT_STATUS + using execution_space = typename ZViewType::device_type::execution_space; +#endif + using memory_space = typename ZViewType::device_type::memory_space; + using View2DType = Kokkos::View; + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) +#if defined(KOKKOS_ENABLE_CUDA) + using View2DHostPinnType = Kokkos::View;//CudaHostPinnedSpace +#elif defined(KOKKOS_ENABLE_HIP) + using View2DHostPinnType = Kokkos::View;//HIPHostPinnedSpace +#endif +#endif + +#if defined(GET_TIMING) || defined(PRINT_STATUS) + int me = ahandle.get_myrank(); +#ifdef GET_TIMING + MPI_Comm comm = ahandle.get_comm(); +#endif +#endif + MPI_Comm col_comm = ahandle.get_col_comm(); + MPI_Comm row_comm = ahandle.get_row_comm(); + int myrow = ahandle.get_myrow(); + int mycol = ahandle.get_mycol(); + int nprocs_row = ahandle.get_nprocs_row(); + int nprocs_col = ahandle.get_nprocs_col(); + int ncols_matrix = ahandle.get_ncols_matrix(); + int my_rows = ahandle.get_my_rows(); + int my_rhs = ahandle.get_my_rhs(); + + value_type d_one = 1.0; + value_type d_min_one = -1.0; + +#ifdef GET_TIMING + double t1,t2; + double allocviewtime,eliminaterhstime,bcastrowtime,updrhstime,bcastcoltime,copycoltime; + double totalsolvetime; +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + double copyhostpinnedtime; +#endif +#endif + +#ifdef GET_TIMING + t2 = MPI_Wtime(); +#endif + +#ifdef GET_TIMING + allocviewtime=eliminaterhstime=bcastrowtime=updrhstime=bcastcoltime=copycoltime=0.0; +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + copyhostpinnedtime=0.0; +#endif + + t1 = MPI_Wtime(); +#endif + + View2DType curr_col( "curr_col", my_rows, 1 ); //current column + View2DType rhs_row ( "rhs_row", 1, my_rhs ); //current row of RHS to hold the elimination results (i.e row of solution) #if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) - showtime("Time to copy host pinned mem <--> dev mem",©hostpinnedtime); + View2DHostPinnType h_curr_col( "h_curr_col", my_rows, 1 ); + View2DHostPinnType h_rhs_row( "h_rhs_row", 1, my_rhs ); +#endif + + //Kokkos::fence();//NOTE: Should we need this? + +#ifdef GET_TIMING + allocviewtime += (MPI_Wtime()-t1); #endif - showtime("Time to xchg rhs",&xchgrhstime); - showtime("Total time in solve",&totalsolvetime); + +#ifdef PRINT_STATUS + printf("Rank %i -- back_solve6() Begin back solve, execution_space %s, memory_space %s\n",me, typeid(execution_space).name(), typeid(memory_space).name()); +#endif + + for (int k = ncols_matrix-1; k >= 0; k--) { + int k_row = k%nprocs_col;//proc. id (in the col_comm) having global k + int k_col = k%nprocs_row;//proc. id (in the row_comm) having global k + int end_row = k/nprocs_col; + if (myrow <= k_row) end_row++; + +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + //Step 1: copy the current column of Z to a temporary view + if (mycol == k_col) { //only deep_copy if holding the current column + Kokkos::deep_copy( Kokkos::subview(curr_col, Kokkos::make_pair(0, end_row), 0), + Kokkos::subview(Z, Kokkos::make_pair(0, end_row), k/nprocs_row) ); + } +#ifdef GET_TIMING + copycoltime += (MPI_Wtime()-t1); +#endif + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + Kokkos::deep_copy(h_curr_col,curr_col); +#ifdef GET_TIMING + copyhostpinnedtime += (MPI_Wtime()-t1); +#endif +#endif + +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + //Step 2: broadcast the current column to all ranks in the row_comm +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + MPI_Bcast(reinterpret_cast(h_curr_col.data()), end_row*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_col, row_comm); +#else //GPU-aware MPI + MPI_Bcast(reinterpret_cast(curr_col.data()), end_row*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_col, row_comm); +#endif +#ifdef GET_TIMING + bcastcoltime += (MPI_Wtime()-t1); +#endif + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + Kokkos::deep_copy(curr_col,h_curr_col); +#ifdef GET_TIMING + copyhostpinnedtime += (MPI_Wtime()-t1); +#endif +#endif + +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + //Step 3: do rhs elimination to get solution x + if (myrow == k_row) {//only on ranks having row k + if (my_rhs > 0) { //only on ranks having some rhs + auto sub_curr_col = Kokkos::subview(curr_col, end_row-1, Kokkos::ALL()); + auto sub_rhs = Kokkos::subview(RHS, end_row-1, Kokkos::make_pair(0, my_rhs)); + int act_col = 0; + elimination_rhs(my_rhs, sub_curr_col, sub_rhs, rhs_row, act_col); Kokkos::fence(); + end_row--;//do not count the eliminated row in Step 5 + } + } +#ifdef GET_TIMING + eliminaterhstime += (MPI_Wtime()-t1); +#endif + + //MPI_Barrier(comm);//NOTE: Should we need this? + + if (my_rhs > 0) { //only on ranks having rhs + if (k >= 1) {//still have row(s) to do rhs updates with elimination results +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + Kokkos::deep_copy(h_rhs_row,rhs_row); +#ifdef GET_TIMING + copyhostpinnedtime += (MPI_Wtime()-t1); +#endif +#endif + +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + //Step 4: broadcast elimination results to all ranks in col_comm +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + MPI_Bcast(reinterpret_cast(h_rhs_row.data()), my_rhs*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_row, col_comm); +#else //GPU-aware MPI + MPI_Bcast(reinterpret_cast(rhs_row.data()), my_rhs*sizeof(ADELUS_DATA_TYPE), MPI_CHAR, k_row, col_comm); +#endif + +#ifdef GET_TIMING + bcastrowtime += (MPI_Wtime()-t1); +#endif + +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + Kokkos::deep_copy(rhs_row,h_rhs_row); +#ifdef GET_TIMING + copyhostpinnedtime += (MPI_Wtime()-t1); +#endif +#endif + +#ifdef GET_TIMING + t1 = MPI_Wtime(); +#endif + //Step 5: call gemm to update RHS with partial solution + auto A_view = Kokkos::subview(curr_col, Kokkos::make_pair(0, end_row), Kokkos::ALL()); + auto B_view = Kokkos::subview(rhs_row, Kokkos::ALL(), Kokkos::make_pair(0, my_rhs)); + auto C_view = Kokkos::subview(RHS, Kokkos::make_pair(0, end_row), Kokkos::make_pair(0, my_rhs)); + + KokkosBlas::gemm("N","N",d_min_one, A_view, B_view, d_one, C_view); Kokkos::fence(); +#ifdef GET_TIMING + updrhstime += (MPI_Wtime()-t1); +#endif + }//end of (k >= 1) + }//end of (my_rhs > 0) + + //MPI_Barrier(comm);//NOTE: Should we need this? + + }//end of for (int k = ncols_matrix-1; k >= 0; k--) + +#ifdef GET_TIMING + totalsolvetime = MPI_Wtime() - t2; +#endif +#ifdef GET_TIMING + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to alloc view", &allocviewtime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to copy matrix column",©coltime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to bcast matrix column",&bcastcoltime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to eliminate rhs",&eliminaterhstime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to bcast temp row",&bcastrowtime); + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to update rhs",&updrhstime); +#if defined(ADELUS_HOST_PINNED_MEM_MPI) && (defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)) + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Time to copy host pinned mem <--> dev mem",©hostpinnedtime); +#endif + showtime(ahandle.get_comm_id(), comm, me, ahandle.get_nprocs_cube(), "Total time in solve",&totalsolvetime); +#endif +} + +template +inline +void back_solve6(HandleType& ahandle, ZViewType& Z, RHSViewType& RHS) +{ +#if 0 + back_solve_rhs_pipelined_comm(ahandle, Z, RHS); +#else + if (ahandle.get_nrhs() <= ahandle.get_nprocs_row()) { + back_solve_rhs_pipelined_comm(ahandle, Z, RHS); + } + else { + back_solve_currcol_bcast(ahandle, Z, RHS); + } #endif } diff --git a/packages/adelus/src/Adelus_vars.hpp b/packages/adelus/src/Adelus_vars.hpp index e58f5aeb255b..47b35a41afda 100644 --- a/packages/adelus/src/Adelus_vars.hpp +++ b/packages/adelus/src/Adelus_vars.hpp @@ -49,31 +49,149 @@ namespace Adelus { - -int me; // processor id information -int nprocs_cube; // num of procs in the allocated cube -int nprocs_row; // num of procs to which a row is assigned -int nprocs_col; // num of procs to which a col is assigned -int max_procs; // max num of procs in any dimension +template +class AdelusHandle { + public: + using value_type = ScalarType; + using execution_space = ExecutionSpace; + using memory_space = MemorySpace; -int nrows_matrix; // number of rows in the matrix -int ncols_matrix; // number of cols in the matrix + private: -int my_first_row; // proc position in a row -int my_first_col; // proc position in a col + int comm_id; // communicator id + MPI_Comm comm; // communicator that I belong to + MPI_Comm row_comm; // row sub-communicator that I belong to + MPI_Comm col_comm; // column sub-communicator that I belong to -int my_rows; // num of rows I own -int my_cols; // num of cols I own + int myrank; // process id information -int nrhs; // number of right hand sides in the matrix -int my_rhs; // number of right hand sides that I own + int nrows_matrix; // number of rows in the matrix + int ncols_matrix; // number of cols in the matrix -int blksz; // block size for BLAS 3 operations + int nprocs_cube; // num of procs in the allocated cube + int nprocs_row; // num of procs to which a row is assigned + int nprocs_col; // num of procs to which a col is assigned -int myrow,mycol; + int my_first_row; // proc position in a row + int my_first_col; // proc position in a col + + int my_rows; // num of rows I own + int my_cols; // num of cols I own + + int nrhs; // number of right hand sides in the matrix + int my_rhs; // number of right hand sides that I own + + int blksz; // block size for matrix update (matrix-matrix multiply) + // (e.g. blksz = 128 for GPU, or blksz = 96 for CPU) -MPI_Comm row_comm,col_comm; + int myrow; // process id in the col_comm + int mycol; // process id in the row_comm + + + + public: + AdelusHandle( const int comm_id_, + MPI_Comm comm_, + const int matrix_size_, + const int num_procsr_, + const int num_rhs_, + const int blksz_ = 128 ) + : comm_id(comm_id_), + comm(comm_), + nrows_matrix(matrix_size_), + ncols_matrix(matrix_size_), + nprocs_row(num_procsr_), + nrhs(num_rhs_), + blksz(blksz_) { + // Determine who I am (myrank) and the total number of processes (nprocs_cube) + MPI_Comm_size(comm, &nprocs_cube); + MPI_Comm_rank(comm, &myrank); + nprocs_col = nprocs_cube/nprocs_row; + + // Set up communicators for rows and columns + mycol = myrank % nprocs_row; + myrow = myrank / nprocs_row; + + MPI_Comm_split(comm, myrow, mycol, &row_comm); + + MPI_Comm_split(comm, mycol, myrow, &col_comm); + + // Distribution for the matrix on myrank + my_first_col = myrank % nprocs_row; + my_first_row = myrank / nprocs_row; + + my_rows = nrows_matrix / nprocs_col; + if (my_first_row < nrows_matrix % nprocs_col) my_rows++; + my_cols = ncols_matrix / nprocs_row; + if (my_first_col < ncols_matrix % nprocs_row) my_cols++; + + // Distribution for the rhs on myrank + my_rhs = nrhs / nprocs_row; + if (my_first_col < nrhs % nprocs_row) my_rhs++; + } + + ~AdelusHandle(){} + + KOKKOS_INLINE_FUNCTION + int get_comm_id() const { return comm_id; } + + KOKKOS_INLINE_FUNCTION + MPI_Comm get_comm() const { return comm; } + + KOKKOS_INLINE_FUNCTION + MPI_Comm get_row_comm() const { return row_comm; } + + KOKKOS_INLINE_FUNCTION + MPI_Comm get_col_comm() const { return col_comm; } + + KOKKOS_INLINE_FUNCTION + int get_myrank() const { return myrank; } + + KOKKOS_INLINE_FUNCTION + int get_myrow() const { return myrow; } + + KOKKOS_INLINE_FUNCTION + int get_mycol() const { return mycol; } + + KOKKOS_INLINE_FUNCTION + int get_nprocs_cube() const { return nprocs_cube; } + + KOKKOS_INLINE_FUNCTION + int get_nprocs_row() const { return nprocs_row; } + + KOKKOS_INLINE_FUNCTION + int get_nprocs_col() const { return nprocs_col; } + + KOKKOS_INLINE_FUNCTION + int get_nrows_matrix() const { return nrows_matrix; } + + KOKKOS_INLINE_FUNCTION + int get_ncols_matrix() const { return ncols_matrix; } + + KOKKOS_INLINE_FUNCTION + int get_my_first_row() const { return my_first_row; } + + KOKKOS_INLINE_FUNCTION + int get_my_first_col() const { return my_first_col; } + + KOKKOS_INLINE_FUNCTION + int get_my_rows() const { return my_rows; } + + KOKKOS_INLINE_FUNCTION + int get_my_cols() const { return my_cols; } + + KOKKOS_INLINE_FUNCTION + int get_nrhs() const { return nrhs; } + + KOKKOS_INLINE_FUNCTION + int get_my_rhs() const { return my_rhs; } + + KOKKOS_INLINE_FUNCTION + int get_blksz() const { return blksz; } +}; }//namespace Adelus diff --git a/packages/adelus/src/Adelus_x_factor.hpp b/packages/adelus/src/Adelus_x_factor.hpp new file mode 100644 index 000000000000..57f42f725156 --- /dev/null +++ b/packages/adelus/src/Adelus_x_factor.hpp @@ -0,0 +1,166 @@ +/* +//@HEADER +// ************************************************************************ +// +// Adelus v. 1.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of NTESS nor the names of the contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Vinh Dang (vqdang@sandia.gov) +// Joseph Kotulski (jdkotul@sandia.gov) +// Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __ADELUS_XLU_HPP__ +#define __ADELUS_XLU_HPP__ + +#include +#include +#include +#include "mpi.h" +#include "Kokkos_Core.hpp" +#include "Adelus_defines.h" +#include "Adelus_macros.h" +#include "Adelus_vars.hpp" +#include "Adelus_mytime.hpp" +#include "Adelus_factor.hpp" +#include "Adelus_perm_mat.hpp" + + +#ifdef ADELUS_HAVE_TIME_MONITOR +#include "Teuchos_TimeMonitor.hpp" +#endif + +namespace Adelus { + +template +inline +void lu_(HandleType& ahandle, ZViewType& Z, PViewType& permute, double *secs) +{ +#ifdef ADELUS_HAVE_TIME_MONITOR + using Teuchos::TimeMonitor; +#endif + + using value_type = typename ZViewType::value_type; +#ifdef PRINT_STATUS + using execution_space = typename ZViewType::device_type::execution_space; +#endif + using memory_space = typename ZViewType::device_type::memory_space; + + int blksz = ahandle.get_blksz(); + int my_rows = ahandle.get_my_rows(); + int my_cols = ahandle.get_my_cols(); + + double run_secs; // time (in secs) during which the prog ran + double tsecs; // intermediate storage of timing info + int totmem = 0; // Initialize the total memory used + +#ifdef PRINT_STATUS + printf("Rank %i -- factor_() Begin LU with blksz %d, myrow %d, mycol %d, nprocs_row %d, nprocs_col %d, nrows_matrix %d, ncols_matrix %d, my_rows %d, my_cols %d, my_rhs %d, nrhs %d, value_type %s, execution_space %s, memory_space %s\n", ahandle.get_myrank(), blksz, ahandle.get_myrow(), ahandle.get_mycol(), ahandle.get_nprocs_row(), ahandle.get_nprocs_col(), ahandle.get_nrows_matrix(), ahandle.get_ncols_matrix(), my_rows, my_cols, ahandle.get_my_rhs(), ahandle.get_nrhs(), typeid(value_type).name(), typeid(execution_space).name(), typeid(memory_space).name()); +#endif + + // Allocate arrays for factor + using ViewType1D = Kokkos::View; + using ViewType2D = Kokkos::View; + + totmem += (blksz) * (my_rows) * sizeof(ADELUS_DATA_TYPE); //col1_view + totmem += blksz * (my_cols + blksz + 0) * sizeof(ADELUS_DATA_TYPE);//row1_view + totmem += (my_cols + blksz + 0) * sizeof(ADELUS_DATA_TYPE); //row2_view + totmem += (my_cols + blksz + 0) * sizeof(ADELUS_DATA_TYPE); //row3_view + totmem += my_cols * sizeof(int); //lpiv_view + + ViewType2D col1_view ( "col1_view", my_rows, blksz ); + ViewType2D row1_view ( "row1_view", blksz, my_cols + blksz + 0 ); + ViewType1D row2_view ( "row2_view", my_cols + blksz + 0 ); + ViewType1D row3_view ( "row3_view", my_cols + blksz + 0 ); + PViewType lpiv_view ( "lpiv_view", my_cols ); + + { + // Factor the system + + tsecs = get_seconds(0.0); + +#ifdef PRINT_STATUS + printf("OpenMP or Cuda: Rank %i -- factor() starts ...\n", ahandle.get_myrank()); +#endif +#ifdef ADELUS_HAVE_TIME_MONITOR + { + TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: factor")); +#endif + factor(ahandle, + Z, + col1_view, + row1_view, + row2_view, + row3_view, + lpiv_view, + 0, 0); +#ifdef ADELUS_HAVE_TIME_MONITOR + } +#endif + + // Permute the lower triangular matrix +#ifdef ADELUS_HAVE_TIME_MONITOR + { + TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: matrix permutation")); +#endif +#ifdef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + typename ZViewType::HostMirror h_Z = Kokkos::create_mirror_view( Z ); + Kokkos::deep_copy (h_Z, Z); + + permute_mat(ahandle, h_Z, lpiv_view, permute); + + Kokkos::deep_copy (Z, h_Z); +#else + permute_mat(ahandle, Z, lpiv_view, permute); +#endif +#ifdef ADELUS_HAVE_TIME_MONITOR + } +#endif + + tsecs = get_seconds(tsecs); + + run_secs = (double) tsecs; + + *secs = run_secs; + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Total time in Factor (inl. matrix permutation)", &run_secs ); + } +} + +}//namespace Adelus + +#endif diff --git a/packages/adelus/src/Adelus_x_solve.hpp b/packages/adelus/src/Adelus_x_solve.hpp new file mode 100644 index 000000000000..2d996297742a --- /dev/null +++ b/packages/adelus/src/Adelus_x_solve.hpp @@ -0,0 +1,172 @@ +/* +//@HEADER +// ************************************************************************ +// +// Adelus v. 1.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of NTESS nor the names of the contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Vinh Dang (vqdang@sandia.gov) +// Joseph Kotulski (jdkotul@sandia.gov) +// Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __ADELUS_XSOLVE_HPP__ +#define __ADELUS_XSOLVE_HPP__ + +#include +#include +#include +#include "mpi.h" +#include "Kokkos_Core.hpp" +#include "Adelus_defines.h" +#include "Adelus_macros.h" +#include "Adelus_vars.hpp" +#include "Adelus_mytime.hpp" +#include "Adelus_perm_rhs.hpp" +#include "Adelus_forward.hpp" +#include "Adelus_solve.hpp" +#include "Adelus_perm1.hpp" + +#ifdef ADELUS_HAVE_TIME_MONITOR +#include "Teuchos_TimeMonitor.hpp" +#endif + +namespace Adelus { + +template +inline +void solve_(HandleType& ahandle, ZViewType& Z, RHSViewType& RHS, PViewType& permute, double *secs) +{ +#ifdef ADELUS_HAVE_TIME_MONITOR + using Teuchos::TimeMonitor; +#endif + + using value_type = typename ZViewType::value_type; +#ifdef PRINT_STATUS + using execution_space = typename ZViewType::device_type::execution_space; + using memory_space = typename ZViewType::device_type::memory_space; +#endif + + double run_secs; // time (in secs) during which the prog ran + double tsecs; // intermediate storage of timing info + +#ifdef PRINT_STATUS + printf("Rank %i -- solve_() Begin FwdSolve+BwdSolve+Perm with blksz %d, myrow %d, mycol %d, nprocs_row %d, nprocs_col %d, nrows_matrix %d, ncols_matrix %d, my_rows %d, my_cols %d, my_rhs %d, nrhs %d, value_type %s, execution_space %s, memory_space %s\n", ahandle.get_myrank(), ahandle.get_blksz(), ahandle.get_myrow(), ahandle.get_mycol(), ahandle.get_nprocs_row(), ahandle.get_nprocs_col(), ahandle.get_nrows_matrix(), ahandle.get_ncols_matrix(), ahandle.get_my_rows(), ahandle.get_my_cols(), ahandle.get_my_rhs(), ahandle.get_nrhs(), typeid(value_type).name(), typeid(execution_space).name(), typeid(memory_space).name()); +#endif + + { + tsecs = get_seconds(0.0); + +#ifdef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + typename ZViewType::HostMirror h_Z = Kokkos::create_mirror_view( Z ); + typename RHSViewType::HostMirror h_RHS = Kokkos::create_mirror_view( RHS ); + // Bring data to host memory + Kokkos::deep_copy (h_Z, Z); + Kokkos::deep_copy (h_RHS, RHS); +#endif + +#ifdef ADELUS_HAVE_TIME_MONITOR + { + TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: rhs permutation")); +#endif + // Permute the RHS +#ifdef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + permute_rhs(ahandle, h_RHS, permute); +#else + permute_rhs(ahandle, RHS, permute); +#endif +#ifdef ADELUS_HAVE_TIME_MONITOR + } +#endif + +#ifdef ADELUS_HAVE_TIME_MONITOR + { + TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: forward solve")); +#endif + //Forward Solve +#ifdef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + forward(ahandle, h_Z, h_RHS); +#else + forward(ahandle, Z, RHS); +#endif +#ifdef ADELUS_HAVE_TIME_MONITOR + } +#endif + +#ifdef ADELUS_PERM_MAT_FORWARD_COPY_TO_HOST + // Copy back to device memory + Kokkos::deep_copy (Z, h_Z); + Kokkos::deep_copy (RHS, h_RHS); +#endif + + MPI_Barrier(ahandle.get_comm()); + +#ifdef ADELUS_HAVE_TIME_MONITOR + { + TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: backsolve")); +#endif + back_solve6(ahandle, Z, RHS); +#ifdef ADELUS_HAVE_TIME_MONITOR + } +#endif + + MPI_Barrier(ahandle.get_comm()); + +#ifdef ADELUS_HAVE_TIME_MONITOR + { + TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: permutation")); +#endif + perm1_(ahandle, RHS); +#ifdef ADELUS_HAVE_TIME_MONITOR + } +#endif + + MPI_Barrier(ahandle.get_comm()); + + tsecs = get_seconds(tsecs); + + run_secs = (double) tsecs; + + *secs = run_secs; + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Total time in Solve", &run_secs ); + } +} + +}//namespace Adelus + +#endif diff --git a/packages/adelus/src/Adelus_xlu_solve.hpp b/packages/adelus/src/Adelus_xlu_solve.hpp index d88ad034f91e..eb3474f857a9 100644 --- a/packages/adelus/src/Adelus_xlu_solve.hpp +++ b/packages/adelus/src/Adelus_xlu_solve.hpp @@ -50,17 +50,15 @@ #include #include #include -#include "Adelus_defines.h" #include "mpi.h" -#include "Adelus_vars.hpp" +#include "Kokkos_Core.hpp" +#include "Adelus_defines.h" #include "Adelus_macros.h" -#include "Adelus_block.h" +#include "Adelus_vars.hpp" +#include "Adelus_mytime.hpp" #include "Adelus_solve.hpp" #include "Adelus_factor.hpp" #include "Adelus_perm1.hpp" -#include "Adelus_pcomm.hpp" -#include "Adelus_mytime.hpp" -#include "Kokkos_Core.hpp" #ifdef ADELUS_HAVE_TIME_MONITOR #include "Teuchos_TimeMonitor.hpp" @@ -68,71 +66,38 @@ namespace Adelus { -template +template inline -void lusolve_(ZDView& ZV, int *matrix_size, int *num_procsr, int *num_rhs, double *secs) +void lusolve_(HandleType& ahandle, ZRHSViewType& ZRHS, double *secs) { #ifdef ADELUS_HAVE_TIME_MONITOR using Teuchos::TimeMonitor; #endif - using value_type = typename ZDView::value_type; + using value_type = typename ZRHSViewType::value_type; #ifdef PRINT_STATUS - using execution_space = typename ZDView::device_type::execution_space; + using execution_space = typename ZRHSViewType::device_type::execution_space; #endif - using memory_space = typename ZDView::device_type::memory_space; - - double run_secs; // time (in secs) during which the prog ran - double tsecs; // intermediate storage of timing info - int totmem; - - // Determine who I am (me ) and the total number of nodes (nprocs_cube) - MPI_Comm_size(MPI_COMM_WORLD,&nprocs_cube); - MPI_Comm_rank(MPI_COMM_WORLD, &me); + using memory_space = typename ZRHSViewType::device_type::memory_space; - nrows_matrix = *matrix_size; - ncols_matrix = *matrix_size; - nprocs_row = *num_procsr; + int blksz = ahandle.get_blksz(); + int my_rows = ahandle.get_my_rows(); + int my_cols = ahandle.get_my_cols(); + int nrhs = ahandle.get_nrhs(); + int my_rhs = ahandle.get_my_rhs(); - totmem=0; // Initialize the total memory used - nprocs_col = nprocs_cube/nprocs_row; - max_procs = (nprocs_row < nprocs_col) ? nprocs_col : nprocs_row; - - // Set up communicators for rows and columns - myrow = mesh_row(me); - mycol = mesh_col(me); - - MPI_Comm_split(MPI_COMM_WORLD,myrow,mycol,&row_comm); - - MPI_Comm_split(MPI_COMM_WORLD,mycol,myrow,&col_comm); - - // Distribution for the matrix on me - my_first_col = mesh_col(me); - my_first_row = mesh_row(me); - - my_rows = nrows_matrix / nprocs_col; - if (my_first_row < nrows_matrix % nprocs_col) - ++my_rows; - my_cols = ncols_matrix / nprocs_row; - if (my_first_col < ncols_matrix % nprocs_row) - ++my_cols; - - // blksz parameter must be set - blksz = DEFBLKSZ; - - // Distribution for the rhs on me - nrhs = *num_rhs; - my_rhs = nrhs / nprocs_row; - if (my_first_col < nrhs % nprocs_row) ++my_rhs; + double run_secs; // time (in secs) during which the prog ran + double tsecs; // intermediate storage of timing info + int totmem = 0; // Initialize the total memory used #ifdef PRINT_STATUS - printf("Rank %i -- lusolve_() Begin LU+Solve+Perm with blksz %d, value_type %s, execution_space %s, memory_space %s\n", me, blksz, typeid(value_type).name(), typeid(execution_space).name(), typeid(memory_space).name()); + printf("Rank %i -- lusolve_() Begin LU+Solve+Perm with blksz %d, value_type %s, execution_space %s, memory_space %s\n", ahandle.get_myrank(), blksz, typeid(value_type).name(), typeid(execution_space).name(), typeid(memory_space).name()); #endif // Allocate arrays for factor/solve typedef Kokkos::View ViewType1D; typedef Kokkos::View ViewType2D; - typedef Kokkos::View ViewIntType1D; + typedef Kokkos::View ViewIntType1D; totmem += (blksz) * (my_rows) * sizeof(ADELUS_DATA_TYPE); //col1_view totmem += blksz * (my_cols + blksz + nrhs) * sizeof(ADELUS_DATA_TYPE);//row1_view @@ -146,42 +111,44 @@ void lusolve_(ZDView& ZV, int *matrix_size, int *num_procsr, int *num_rhs, doubl ViewType1D row3_view ( "row3_view", my_cols + blksz + nrhs ); ViewIntType1D pivot_vec_view ( "pivot_vec_view", my_cols ); - { // Factor and Solve the system tsecs = get_seconds(0.0); - initcomm(); - #ifdef PRINT_STATUS - printf("OpenMP or Cuda: Rank %i -- factor() starts ...\n", me); + printf("OpenMP or Cuda: Rank %i -- factor() starts ...\n", ahandle.get_myrank()); #endif #ifdef ADELUS_HAVE_TIME_MONITOR { TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: factor")); #endif - factor(ZV, + factor(ahandle, + ZRHS, col1_view, row1_view, row2_view, row3_view, - pivot_vec_view); + pivot_vec_view, + nrhs, my_rhs); #ifdef ADELUS_HAVE_TIME_MONITOR } #endif if (nrhs > 0) { + auto Z = subview(ZRHS, Kokkos::ALL(), Kokkos::make_pair(0, my_cols)); + auto RHS = subview(ZRHS, Kokkos::ALL(), Kokkos::make_pair(my_cols, my_cols + my_rhs + 6)); + // Perform the backsolve #ifdef PRINT_STATUS - printf("OpenMP or Cuda: Rank %i -- back_solve6() starts ...\n", me); + printf("OpenMP or Cuda: Rank %i -- back_solve6() starts ...\n", ahandle.get_myrank()); #endif #ifdef ADELUS_HAVE_TIME_MONITOR { TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: backsolve")); #endif - back_solve6(ZV); + back_solve6(ahandle, Z, RHS); #ifdef ADELUS_HAVE_TIME_MONITOR } #endif @@ -189,14 +156,13 @@ void lusolve_(ZDView& ZV, int *matrix_size, int *num_procsr, int *num_rhs, doubl // Permute the results -- undo the torus map #ifdef PRINT_STATUS - printf("OpenMP or Cuda: Rank %i -- perm1_()(permute the results -- undo the torus map) starts ...\n", me); + printf("OpenMP or Cuda: Rank %i -- perm1_()(permute the results -- undo the torus map) starts ...\n", ahandle.get_myrank()); #endif #ifdef ADELUS_HAVE_TIME_MONITOR { TimeMonitor t(*TimeMonitor::getNewTimer("Adelus: permutation")); #endif - auto sub_ZV = subview(ZV, Kokkos::ALL(), Kokkos::make_pair(my_cols, my_cols + my_rhs + 6)); - perm1_(sub_ZV, &my_rhs); + perm1_(ahandle, RHS); #ifdef ADELUS_HAVE_TIME_MONITOR } #endif @@ -209,7 +175,8 @@ void lusolve_(ZDView& ZV, int *matrix_size, int *num_procsr, int *num_rhs, doubl // Solve time secs *secs = run_secs; - showtime("Total time in Factor and Solve",&run_secs); + showtime(ahandle.get_comm_id(), ahandle.get_comm(), ahandle.get_myrank(), ahandle.get_nprocs_cube(), + "Total time in Factor and Solve", &run_secs); } } diff --git a/packages/adelus/src/CMakeLists.txt b/packages/adelus/src/CMakeLists.txt index 9ce34c2b37da..81717dd70d3b 100644 --- a/packages/adelus/src/CMakeLists.txt +++ b/packages/adelus/src/CMakeLists.txt @@ -55,21 +55,23 @@ IF (TPL_ENABLE_MPI) # APPEND_SET(HEADERS - Adelus_block.h Adelus_distribute.hpp Adelus_factor.hpp - Adelus_pcomm.hpp + Adelus_forward.hpp Adelus_perm1.hpp + Adelus_perm_mat.hpp + Adelus_perm_rhs.hpp Adelus_solve.hpp Adelus_vars.hpp Adelus_xlu_solve.hpp + Adelus_x_factor.hpp + Adelus_x_solve.hpp Adelus.hpp BlasWrapper_copy_spec.hpp BlasWrapper_copy.hpp ) APPEND_SET(SOURCES - Adelus_pcomm.cpp Adelus_distribute.cpp ) diff --git a/packages/adelus/test/CMakeLists.txt b/packages/adelus/test/CMakeLists.txt index bcdf73fde286..a1f92123a73c 100644 --- a/packages/adelus/test/CMakeLists.txt +++ b/packages/adelus/test/CMakeLists.txt @@ -1,5 +1,7 @@ IF(Adelus_ENABLE_ZCPLX OR Adelus_ENABLE_DREAL) ADD_SUBDIRECTORY(vector_random) + ADD_SUBDIRECTORY(vector_random_fs) + ADD_SUBDIRECTORY(vector_random_mc) ENDIF() IF(Adelus_ENABLE_Teuchos AND (Adelus_ENABLE_ZCPLX OR Adelus_ENABLE_DREAL)) diff --git a/packages/adelus/test/definition b/packages/adelus/test/definition index 6e41b37ee2a7..acfeeeda2a27 100644 --- a/packages/adelus/test/definition +++ b/packages/adelus/test/definition @@ -6,13 +6,25 @@ PACKAGE_NAME=Adelus (FRAMEWORK, INSTALL) { DIRS = vector_random; - ARGS = 1000 1; + ARGS = 3500 1; COMM = MPI(1 2 3); } (FRAMEWORK, INSTALL) { DIRS = vector_random; - ARGS = 1000 2; + ARGS = 3500 2; + COMM = MPI(4); +} + +(FRAMEWORK, INSTALL) { + DIRS = vector_random_fs; + ARGS = 3500 1; + COMM = MPI(1 2 3); +} + +(FRAMEWORK, INSTALL) { + DIRS = vector_random_fs; + ARGS = 3500 2; COMM = MPI(4); } diff --git a/packages/adelus/test/perf_test/cxx_main.cpp b/packages/adelus/test/perf_test/cxx_main.cpp index 226f5fb4709d..b159c757a444 100644 --- a/packages/adelus/test/perf_test/cxx_main.cpp +++ b/packages/adelus/test/perf_test/cxx_main.cpp @@ -109,7 +109,7 @@ int main(int argc, char *argv[]) double rhs_nrm, m_nrm; - int result; + int result=1; // Enroll into MPI @@ -180,16 +180,10 @@ int main(int argc, char *argv[]) // Get Info to build the matrix on a processor - Adelus::GetDistribution( &nprocs_per_row, - &matrix_size, - &numrhs, - &myrows, - &mycols, - &myfirstrow, - &myfirstcol, - &myrhs, - &my_row, - &my_col ); + Adelus::GetDistribution( MPI_COMM_WORLD, + nprocs_per_row, matrix_size, numrhs, + myrows, mycols, myfirstrow, myfirstcol, + myrhs, my_row, my_col ); // Define a new communicator @@ -228,30 +222,30 @@ int main(int argc, char *argv[]) { // Local size -- myrows * (mycols + myrhs) - typedef Kokkos::LayoutLeft Layout; + using Layout = Kokkos::LayoutLeft; #if defined(KOKKOS_ENABLE_CUDA) - typedef Kokkos::CudaSpace TestSpace; + using TestSpace = Kokkos::CudaSpace; #elif defined(KOKKOS_ENABLE_HIP) - typedef Kokkos::Experimental::HIPSpace TestSpace; + using TestSpace = Kokkos::Experimental::HIPSpace; #else - typedef Kokkos::HostSpace TestSpace; + using TestSpace = Kokkos::HostSpace; #endif #ifdef DREAL - typedef Kokkos::View ViewMatrixType; - typedef Kokkos::View ViewVectorType_Host; + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; #elif defined(SREAL) - typedef Kokkos::View ViewMatrixType; - typedef Kokkos::View ViewVectorType_Host; + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; #elif defined(SCPLX) - typedef Kokkos::View**, Layout, TestSpace> ViewMatrixType; - typedef Kokkos::View*, Layout, Kokkos::HostSpace> ViewVectorType_Host; + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; #else - typedef Kokkos::View**, Layout, TestSpace> ViewMatrixType; - typedef Kokkos::View*, Layout, Kokkos::HostSpace> ViewVectorType_Host; + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; #endif - typedef typename ViewMatrixType::device_type::execution_space execution_space; - typedef typename ViewMatrixType::device_type::memory_space memory_space; - typedef typename ViewMatrixType::value_type ScalarA; + using execution_space = typename ViewMatrixType::device_type::execution_space; + using memory_space = typename ViewMatrixType::device_type::memory_space; + using ScalarA = typename ViewMatrixType::value_type; printf("Rank %d, ViewMatrixType execution_space %s, memory_space %s, value_type %s\n",rank, typeid(execution_space).name(), typeid(memory_space).name(), typeid(ScalarA).name()); @@ -325,6 +319,10 @@ int main(int argc, char *argv[]) Kokkos::deep_copy( subview(A,Kokkos::ALL(),mycols), subview(h_A,Kokkos::ALL(),mycols) ); + // Create handle + Adelus::AdelusHandle + ahandle(0, MPI_COMM_WORLD, matrix_size, nprocs_per_row, numrhs ); + // Now Solve the Problem RCP timer = rcp(new StackedTimer("Adelus: total")); TimeMonitor::setStackedTimer(timer); @@ -332,7 +330,7 @@ int main(int argc, char *argv[]) if( rank == 0 ) std::cout << " **** Beginning Matrix Solve ****" << std::endl; - Adelus::FactorSolve (A, myrows, mycols, &matrix_size, &nprocs_per_row, &numrhs, &secs); + Adelus::FactorSolve (ahandle, A, &secs); if( rank == 0) { std::cout << " ---- Solution time ---- " << secs << " in secs. " << std::endl; diff --git a/packages/adelus/test/vector_random/CMakeLists.txt b/packages/adelus/test/vector_random/CMakeLists.txt index 2d12bfcf5cea..c91f95131d41 100644 --- a/packages/adelus/test/vector_random/CMakeLists.txt +++ b/packages/adelus/test/vector_random/CMakeLists.txt @@ -1,29 +1,196 @@ - +#1 RANK TRIBITS_ADD_EXECUTABLE_AND_TEST( vector_random SOURCES cxx_main.cpp + NAME vector_random_npr1_rhs1 + NUM_MPI_PROCS 1 + ARGS "3501 1 1 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs4 + NUM_MPI_PROCS 1 + ARGS "3501 1 1 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs5 NUM_MPI_PROCS 1 - ARGS "1000 1 1" + ARGS "3501 1 1 5" + COMM mpi + ) + +#2 RANKS +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs1 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs4 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs5 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 5" COMM mpi ) TRIBITS_ADD_TEST( vector_random + NAME vector_random_npr2_rhs1 NUM_MPI_PROCS 2 - ARGS "1000 1 2" + ARGS "3501 2 2 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr2_rhs4 + NUM_MPI_PROCS 2 + ARGS "3501 2 2 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr2_rhs5 + NUM_MPI_PROCS 2 + ARGS "3501 2 2 5" + COMM mpi + ) + +#3 RANKS +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs1 + NUM_MPI_PROCS 3 + ARGS "3501 1 3 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs4 + NUM_MPI_PROCS 3 + ARGS "3501 1 3 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs5 + NUM_MPI_PROCS 3 + ARGS "3501 1 3 5" COMM mpi ) TRIBITS_ADD_TEST( vector_random + NAME vector_random_npr3_rhs1 NUM_MPI_PROCS 3 - ARGS "1000 1 3" + ARGS "3501 3 3 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr3_rhs4 + NUM_MPI_PROCS 3 + ARGS "3501 3 3 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr3_rhs5 + NUM_MPI_PROCS 3 + ARGS "3501 3 3 5" + COMM mpi + ) + +#4 RANKS +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr1_rhs5 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 5" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr2_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr2_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr2_rhs5 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 5" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr4_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 4 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random + NAME vector_random_npr4_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 4 4 4" COMM mpi ) TRIBITS_ADD_TEST( vector_random + NAME vector_random_npr4_rhs5 NUM_MPI_PROCS 4 - ARGS "1000 2 4" + ARGS "3501 4 4 5" COMM mpi ) diff --git a/packages/adelus/test/vector_random/cxx_main.cpp b/packages/adelus/test/vector_random/cxx_main.cpp index 2cbd9cbf4642..8e4f2723fa28 100644 --- a/packages/adelus/test/vector_random/cxx_main.cpp +++ b/packages/adelus/test/vector_random/cxx_main.cpp @@ -54,8 +54,9 @@ #include #include #include +#include #include -#include +#include #include int main(int argc, char *argv[]) @@ -74,13 +75,13 @@ int main(int argc, char *argv[]) int matrix_size; int nprocs_per_row; int nptile = 1; // number of processors per node + int numrhs = 1; double mflops; - MPI_Comm rowcomm; + MPI_Comm rowcomm, colcomm; static int buf[4]; - int numrhs; int i, m, k; @@ -98,9 +99,7 @@ int main(int argc, char *argv[]) double tempc; - double rhs_nrm, m_nrm; - - int result; + int result=1; // Enroll into MPI @@ -126,10 +125,12 @@ int main(int argc, char *argv[]) buf[1] = atoi(argv[2]); // argv[3] should be #procs per node buf[2] = atoi(argv[3]); + // argv[4] should be #rhs + buf[3] = atoi(argv[4]); } else { // default is 1, but sqrt(p) would be better - buf[1] = 1; buf[2] = 1; + buf[1] = 1; buf[2] = 1; buf[3] = 1; } } else { @@ -147,10 +148,14 @@ int main(int argc, char *argv[]) std::cout << "Enter number of processors per node " << std::endl; std::cin >> buf[2]; } + if (buf[3] < 0) { + std::cout << "Enter number of rhs vectors " << std::endl; + std::cin >> buf[3]; + } } } - /* Send the initilization data to each processor */ + // Send the initilization data to each processor mlen = 4*sizeof(int); MPI_Bcast(reinterpret_cast(buf), mlen, MPI_CHAR, 0, MPI_COMM_WORLD); @@ -163,36 +168,30 @@ int main(int argc, char *argv[]) nptile = buf[2]; + numrhs = buf[3]; + if( rank == 0 ) { std::cout << " Matrix Size " << matrix_size << std::endl; std::cout << " Processors in a row " << nprocs_per_row << std::endl; std::cout << " Processors in a node " << nptile << std::endl; + std::cout << " Number of RHS vectors " << numrhs << std::endl; } - // Example for 1 RHS - - numrhs = 1; - if( rank == 0) { std::cout << " ---- Building Adelus solver ----" << std::endl; } // Get Info to build the matrix on a processor - Adelus::GetDistribution( &nprocs_per_row, - &matrix_size, - &numrhs, - &myrows, - &mycols, - &myfirstrow, - &myfirstcol, - &myrhs, - &my_row, - &my_col ); + Adelus::GetDistribution( MPI_COMM_WORLD, + nprocs_per_row, matrix_size, numrhs, + myrows, mycols, myfirstrow, myfirstcol, + myrhs, my_row, my_col ); - // Define a new communicator + // Define new communicators: rowcomm and colcomm MPI_Comm_split(MPI_COMM_WORLD,my_row,my_col,&rowcomm); + MPI_Comm_split(MPI_COMM_WORLD,my_col,my_row,&colcomm); std::cout << " ------ PARALLEL Distribution Info for : ---------" < ViewMatrixType; - typedef Kokkos::View ViewVectorType_Host; + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; + using ViewMatrixType_Host = Kokkos::View; + using ViewNrmVectorType_Host = Kokkos::View; #elif defined(SREAL) - typedef Kokkos::View ViewMatrixType; - typedef Kokkos::View ViewVectorType_Host; + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; + using ViewMatrixType_Host = Kokkos::View; + using ViewNrmVectorType_Host = Kokkos::View; #elif defined(SCPLX) - typedef Kokkos::View**, Layout, TestSpace> ViewMatrixType; - typedef Kokkos::View*, Layout, Kokkos::HostSpace> ViewVectorType_Host; + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; + using ViewMatrixType_Host = Kokkos::View**, Layout, Kokkos::HostSpace>; + using ViewNrmVectorType_Host = Kokkos::View; #else - typedef Kokkos::View**, Layout, TestSpace> ViewMatrixType; - typedef Kokkos::View*, Layout, Kokkos::HostSpace> ViewVectorType_Host; + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; + using ViewMatrixType_Host = Kokkos::View**, Layout, Kokkos::HostSpace>; + using ViewNrmVectorType_Host = Kokkos::View; #endif - typedef typename ViewMatrixType::device_type::execution_space execution_space; - typedef typename ViewMatrixType::device_type::memory_space memory_space; - typedef typename ViewMatrixType::value_type ScalarA; + using execution_space = typename ViewMatrixType::device_type::execution_space; + using memory_space = typename ViewMatrixType::device_type::memory_space; + using ScalarA = typename ViewMatrixType::value_type; printf("Rank %d, ViewMatrixType execution_space %s, memory_space %s, value_type %s\n",rank, typeid(execution_space).name(), typeid(memory_space).name(), typeid(ScalarA).name()); @@ -275,15 +282,19 @@ int main(int argc, char *argv[]) ViewVectorType_Host temp2 ( "temp2", myrows ); - ViewVectorType_Host rhs ( "rhs", matrix_size ); + ViewMatrixType_Host rhs ( "rhs", matrix_size, numrhs ); - ViewVectorType_Host temp3 ( "temp3", matrix_size ); + ViewMatrixType_Host temp3 ( "temp3", matrix_size, numrhs ); - ViewVectorType_Host temp4 ( "temp4", matrix_size ); + ViewMatrixType_Host temp4 ( "temp4", matrix_size, numrhs ); - ViewVectorType_Host tempp ( "tempp", matrix_size ); + ViewMatrixType_Host tempp ( "tempp", matrix_size, numrhs ); + + ViewMatrixType_Host temp22( "temp22", matrix_size, numrhs ); - ViewVectorType_Host temp22( "temp22", matrix_size ); + ViewNrmVectorType_Host rhs_nrm( "rhs_nrm", numrhs ); + + ViewNrmVectorType_Host m_nrm ( "m_nrm", numrhs ); // Set Random values @@ -309,38 +320,73 @@ int main(int argc, char *argv[]) } } - // Sum to Processor 0 + // Sum from all processes and distribute the result back to all processes in rowcomm MPI_Allreduce(temp.data(), temp2.data(), myrows, ADELUS_MPI_DATA_TYPE, MPI_SUM, rowcomm); + // Find the location of my RHS in the global RHS + + int *nrhs_procs_rowcomm; + int my_rhs_offset = 0; + + nrhs_procs_rowcomm = (int*)malloc( nprocs_per_row * sizeof(int)); + MPI_Allgather(&myrhs, 1, MPI_INT, nrhs_procs_rowcomm, 1, MPI_INT, rowcomm);//gather numbers of rhs of other processes + + for (i=0; i 0 ) { - Kokkos::deep_copy( subview(h_A,Kokkos::ALL(),mycols), temp2 ); - Kokkos::deep_copy( subview(rhs,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows)), temp2 ); + for (k = 0; k < myrhs; k++) { +#if defined(DREAL) || defined(ZCPLX) + ScalarA scal_factor = static_cast(my_rhs_offset+k+1); +#else + ScalarA scal_factor = static_cast(my_rhs_offset+k+1); +#endif + auto cur_rhs_vec_1d = subview(h_A,Kokkos::ALL(),mycols+k); + Kokkos::deep_copy( cur_rhs_vec_1d, temp2 ); + KokkosBlas::scal(cur_rhs_vec_1d,scal_factor,cur_rhs_vec_1d); + } + for (k = 0; k < numrhs; k++) { +#if defined(DREAL) || defined(ZCPLX) + ScalarA scal_factor = static_cast(k+1); +#else + ScalarA scal_factor = static_cast(k+1); +#endif + auto cur_rhs_vec_1d = subview(rhs,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows),k); + Kokkos::deep_copy( cur_rhs_vec_1d, temp2 ); + KokkosBlas::scal(cur_rhs_vec_1d,scal_factor,cur_rhs_vec_1d); + } } // Globally Sum the RHS needed for testing later - MPI_Allreduce(rhs.data(), temp4.data(), matrix_size, ADELUS_MPI_DATA_TYPE, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(rhs.data(), temp4.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, colcomm); // Pack back into RHS Kokkos::deep_copy( rhs, temp4 ); - rhs_nrm = KokkosBlas::nrm2(rhs); + KokkosBlas::nrm2(rhs_nrm, rhs); - Kokkos::deep_copy( subview(A,Kokkos::ALL(),mycols), subview(h_A,Kokkos::ALL(),mycols) ); + Kokkos::deep_copy( subview(A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)), + subview(h_A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)) ); + + // Create handle + Adelus::AdelusHandle + ahandle(0, MPI_COMM_WORLD, matrix_size, nprocs_per_row, numrhs ); // Now Solve the Problem if( rank == 0 ) std::cout << " **** Beginning Matrix Solve ****" << std::endl; - Adelus::FactorSolve (A, myrows, mycols, &matrix_size, &nprocs_per_row, &numrhs, &secs); + Adelus::FactorSolve (ahandle, A, &secs); if( rank == 0) { std::cout << " ---- Solution time ---- " << secs << " in secs. " << std::endl; @@ -352,28 +398,32 @@ int main(int argc, char *argv[]) // Now Check the Solution - Kokkos::deep_copy( subview(h_A,Kokkos::ALL(),mycols), subview(A,Kokkos::ALL(),mycols) ); + Kokkos::deep_copy( subview(h_A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)), + subview(A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)) ); // Pack the Answer into the apropriate position - if ( myrhs > 0) { - Kokkos::deep_copy( subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows)), subview(h_A,Kokkos::ALL(),mycols) ); + if ( myrhs > 0 ) { + Kokkos::deep_copy( subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows), + Kokkos::make_pair(my_rhs_offset, my_rhs_offset + myrhs)), + subview(h_A,Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)) ); } // All processors get the answer - MPI_Allreduce(tempp.data(), temp22.data(), matrix_size, ADELUS_MPI_DATA_TYPE, MPI_SUM, MPI_COMM_WORLD); - - // perform the Matrix vector product + MPI_Allreduce(tempp.data(), temp22.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, MPI_COMM_WORLD); + // Perform the Matrix vector product + ScalarA alpha = 1.0; ScalarA beta = 0.0; - KokkosBlas::gemv("N", alpha, subview(h_A,Kokkos::ALL(),Kokkos::make_pair(0, mycols)), - subview(temp22,Kokkos::make_pair(myfirstcol - 1, myfirstcol - 1 + mycols)), - beta, subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows))); + KokkosBlas::gemm("N", "N", alpha, + subview(h_A,Kokkos::ALL(),Kokkos::make_pair(0, mycols)), + subview(temp22,Kokkos::make_pair(myfirstcol - 1, myfirstcol - 1 + mycols),Kokkos::ALL()), + beta, subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows),Kokkos::ALL())); - MPI_Allreduce(tempp.data(), temp3.data(), matrix_size, ADELUS_MPI_DATA_TYPE, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(tempp.data(), temp3.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, MPI_COMM_WORLD); if( rank == 0) { std::cout << "======================================" << std::endl; @@ -381,9 +431,9 @@ int main(int argc, char *argv[]) ScalarA alpha_ = -1.0; - KokkosBlas::axpy(alpha_,rhs,temp3);//temp3=temp3-rhs + KokkosBlas::axpy(alpha_, rhs, temp3);//temp3=temp3-rhs - m_nrm = KokkosBlas::nrm2(temp3); + KokkosBlas::nrm2(m_nrm, temp3); } // Machine epsilon Calculation @@ -395,32 +445,34 @@ int main(int argc, char *argv[]) eps = fabs(tempc-1.0); if ( rank == 0 ) { - std::cout << " Machine eps " << eps << std::endl; - } - - if ( rank == 0 ) { + std::cout << " Machine eps " << eps << std::endl; - std::cout << " ||Ax - b||_2 = " << m_nrm << std::endl; + std::cout << " Threshold = " << eps*1e4 << std::endl; - std::cout << " ||b||_2 = " << rhs_nrm << std::endl; + for (k = 0; k < numrhs; k++) { + std::cout << " Solution " << k << ": ||Ax - b||_2 = " << m_nrm(k) << std::endl; - std::cout << " ||Ax - b||_2 / ||b||_2 = " << m_nrm/rhs_nrm << std::endl; + std::cout << " Solution " << k << ": ||b||_2 = " << rhs_nrm(k) << std::endl; - std::cout << " Threshold = " << eps*1e4 << std::endl; + std::cout << " Solution " << k << ": ||Ax - b||_2 / ||b||_2 = " << m_nrm(k)/rhs_nrm(k) << std::endl; - if ( m_nrm/rhs_nrm > (eps*1e4)) { - std::cout << " **** Solution Fails ****" << std::endl; - result = 1; - } - else { - std::cout << " **** Solution Passes ****" << std::endl; - result = 0; + if ( m_nrm(k)/rhs_nrm(k) > (eps*1e4)) { + std::cout << " **** Solution " << k << " Fails ****" << std::endl; + result = 1; + break; + } + else { + std::cout << " **** Solution " << k << " Passes ****" << std::endl; + result = 0; + } } std::cout << "======================================" << std::endl; } MPI_Bcast(&result, 1, MPI_INT, 0, MPI_COMM_WORLD); + free(nrhs_procs_rowcomm); + } Kokkos::finalize(); diff --git a/packages/adelus/test/vector_random_fs/CMakeLists.txt b/packages/adelus/test/vector_random_fs/CMakeLists.txt new file mode 100644 index 000000000000..b529a64dac62 --- /dev/null +++ b/packages/adelus/test/vector_random_fs/CMakeLists.txt @@ -0,0 +1,196 @@ +#1 RANK +TRIBITS_ADD_EXECUTABLE_AND_TEST( + vector_random_fs + SOURCES cxx_main.cpp + NAME vector_random_fs_npr1_rhs1 + NUM_MPI_PROCS 1 + ARGS "3501 1 1 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs4 + NUM_MPI_PROCS 1 + ARGS "3501 1 1 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs5 + NUM_MPI_PROCS 1 + ARGS "3501 1 1 5" + COMM mpi + ) + +#2 RANKS +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs1 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs4 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs5 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 5" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr2_rhs1 + NUM_MPI_PROCS 2 + ARGS "3501 2 2 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr2_rhs4 + NUM_MPI_PROCS 2 + ARGS "3501 2 2 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr2_rhs5 + NUM_MPI_PROCS 2 + ARGS "3501 2 2 5" + COMM mpi + ) + +#3 RANKS +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs1 + NUM_MPI_PROCS 3 + ARGS "3501 1 3 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs4 + NUM_MPI_PROCS 3 + ARGS "3501 1 3 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs5 + NUM_MPI_PROCS 3 + ARGS "3501 1 3 5" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr3_rhs1 + NUM_MPI_PROCS 3 + ARGS "3501 3 3 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr3_rhs4 + NUM_MPI_PROCS 3 + ARGS "3501 3 3 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr3_rhs5 + NUM_MPI_PROCS 3 + ARGS "3501 3 3 5" + COMM mpi + ) + +#4 RANKS +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr1_rhs5 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 5" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr2_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr2_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr2_rhs5 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 5" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr4_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 4 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr4_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 4 4 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_fs + NAME vector_random_fs_npr4_rhs5 + NUM_MPI_PROCS 4 + ARGS "3501 4 4 5" + COMM mpi + ) diff --git a/packages/adelus/test/vector_random_fs/cxx_main.cpp b/packages/adelus/test/vector_random_fs/cxx_main.cpp new file mode 100644 index 000000000000..e0d49021d905 --- /dev/null +++ b/packages/adelus/test/vector_random_fs/cxx_main.cpp @@ -0,0 +1,511 @@ +/* +//@HEADER +// ************************************************************************ +// +// Adelus v. 1.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of NTESS nor the names of the contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Vinh Dang (vqdang@sandia.gov) +// Joseph Kotulski (jdkotul@sandia.gov) +// Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char *argv[]) +{ + char processor_name[MPI_MAX_PROCESSOR_NAME]; + int name_len; + int rank, size; + + int myrows; + int mycols; + int myfirstrow; + int myfirstcol; + int myrhs; + int my_row; + int my_col; + int matrix_size; + int nprocs_per_row; + int nptile = 1; // number of processors per node + int numrhs = 1; + + double mflops; + + MPI_Comm rowcomm, colcomm; + + static int buf[4]; + + int i, m, k; + + int mlen; // Message length for input data + + unsigned int seed= 10; + + double secs; + + double eps; + + double othird; + + double four_thirds = 4./3.; + + double tempc; + + int result=1; + + // Enroll into MPI + + MPI_Init(&argc,&argv); /* starts MPI */ + MPI_Comm_rank (MPI_COMM_WORLD, &rank); /* get current process id */ + MPI_Comm_size (MPI_COMM_WORLD, &size); /* get number of processes */ + MPI_Get_processor_name(processor_name, &name_len); /* get name of the processor */ + + // Initialize Input buffer + + for(i=0;i<4;i++) buf[i]=-1; + + std::cout << "proc " << rank << " (" << processor_name << ") is alive of " << size << " Processors" << std::endl; + + if( rank == 0 ) { + // Check for commandline input + + if (argc > 1) { + // argv[1] should be size of matrix + buf[0] = atoi(argv[1]); + if (argc > 2) { + // argv[2] should be #procs per row + buf[1] = atoi(argv[2]); + // argv[3] should be #procs per node + buf[2] = atoi(argv[3]); + // argv[4] should be #rhs + buf[3] = atoi(argv[4]); + } + else { + // default is 1, but sqrt(p) would be better + buf[1] = 1; buf[2] = 1; buf[3] = 1; + } + } + else { + // Input Data about matrix and distribution + + if (buf[0] < 0) { + std::cout << "Enter size of matrix " << std::endl; + std::cin >> buf[0]; + } + if (buf[1] < 0) { + std::cout << "Enter number of processors to which each row is assigned " << std::endl; + std::cin >> buf[1]; + } + if (buf[2] < 0) { + std::cout << "Enter number of processors per node " << std::endl; + std::cin >> buf[2]; + } + if (buf[3] < 0) { + std::cout << "Enter number of rhs vectors " << std::endl; + std::cin >> buf[3]; + } + } + } + + // Send the initilization data to each processor + mlen = 4*sizeof(int); + + MPI_Bcast(reinterpret_cast(buf), mlen, MPI_CHAR, 0, MPI_COMM_WORLD); + + // Set the values where needed + + matrix_size = buf[0]; + + nprocs_per_row = buf[1]; + + nptile = buf[2]; + + numrhs = buf[3]; + + if( rank == 0 ) { + std::cout << " Matrix Size " << matrix_size << std::endl; + std::cout << " Processors in a row " << nprocs_per_row << std::endl; + std::cout << " Processors in a node " << nptile << std::endl; + std::cout << " Number of RHS vectors " << numrhs << std::endl; + } + + if( rank == 0) { + std::cout << " ---- Building Adelus solver ----" << std::endl; + } + + // Get Info to build the matrix on a processor + + Adelus::GetDistribution( MPI_COMM_WORLD, + nprocs_per_row, matrix_size, numrhs, + myrows, mycols, myfirstrow, myfirstcol, + myrhs, my_row, my_col ); + + // Define new communicators: rowcomm and colcomm + + MPI_Comm_split(MPI_COMM_WORLD,my_row,my_col,&rowcomm); + MPI_Comm_split(MPI_COMM_WORLD,my_col,my_row,&colcomm); + + std::cout << " ------ PARALLEL Distribution Info for : ---------" < gpu_count) { + if( rank == 0 ) { + std::cout << "Request more GPUs than the number of GPUs available " + << "to MPI processes (requested: " << nptile + << " vs. available: " << gpu_count + << "). Exit without test." << std::endl; + } + MPI_Finalize() ; + return 0; + } + + Kokkos::InitArguments args; + args.num_threads = 0; + args.num_numa = 0; + args.device_id = rank%nptile; + std::cout << " Processor " << rank << " (" << processor_name << "), GPU: " + << args.device_id << "/" << gpu_count << std::endl; + Kokkos::initialize( args ); +#else + Kokkos::initialize( argc, argv ); +#endif + { + // Local size -- myrows * (mycols + myrhs) + + using Layout = Kokkos::LayoutLeft; +#if defined(KOKKOS_ENABLE_CUDA) + using TestSpace = Kokkos::CudaSpace; +#elif defined(KOKKOS_ENABLE_HIP) + using TestSpace = Kokkos::Experimental::HIPSpace; +#else + using TestSpace = Kokkos::HostSpace; +#endif +#ifdef DREAL + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; + using ViewMatrixType_Host = Kokkos::View; + using ViewNrmVectorType_Host = Kokkos::View; +#elif defined(SREAL) + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; + using ViewMatrixType_Host = Kokkos::View; + using ViewNrmVectorType_Host = Kokkos::View; +#elif defined(SCPLX) + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; + using ViewMatrixType_Host = Kokkos::View**, Layout, Kokkos::HostSpace>; + using ViewNrmVectorType_Host = Kokkos::View; +#else + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; + using ViewMatrixType_Host = Kokkos::View**, Layout, Kokkos::HostSpace>; + using ViewNrmVectorType_Host = Kokkos::View; +#endif + + using ViewIntType_Host= Kokkos::View; + + using execution_space = typename ViewMatrixType::device_type::execution_space; + using memory_space = typename ViewMatrixType::device_type::memory_space; + using ScalarA = typename ViewMatrixType::value_type; + + printf("Rank %d, ViewMatrixType execution_space %s, memory_space %s, value_type %s\n",rank, typeid(execution_space).name(), typeid(memory_space).name(), typeid(ScalarA).name()); + + ViewMatrixType A( "A", myrows, mycols ); + ViewMatrixType B( "B", myrows, myrhs + 6 ); + + ViewMatrixType::HostMirror h_A = Kokkos::create_mirror( A ); + ViewMatrixType::HostMirror h_B = Kokkos::create_mirror( B ); + + // Some temp arrays + + ViewVectorType_Host temp ( "temp", myrows ); + + ViewVectorType_Host temp2 ( "temp2", myrows ); + + ViewMatrixType_Host rhs ( "rhs", matrix_size, numrhs ); + + ViewMatrixType_Host temp3 ( "temp3", matrix_size, numrhs ); + + ViewMatrixType_Host temp4 ( "temp4", matrix_size, numrhs ); + + ViewMatrixType_Host tempp ( "tempp", matrix_size, numrhs ); + + ViewMatrixType_Host temp22( "temp22", matrix_size, numrhs ); + + ViewNrmVectorType_Host rhs_nrm( "rhs_nrm", numrhs ); + + ViewNrmVectorType_Host m_nrm ( "m_nrm", numrhs ); + + ViewIntType_Host h_permute( "h_permute", matrix_size);// Permutation array for factor and solve done independently + + // Set Random values + + if( rank == 0 ) + std::cout << " **** Setting Random Matrix ****" << std::endl; + + Kokkos::Random_XorShift64_Pool rand_pool(seed+rank); + Kokkos::fill_random(A, rand_pool,Kokkos::rand,ScalarA >::max()); + + Kokkos::deep_copy( h_A, A ); + + // Now Create the RHS + + if( rank == 0 ) + std::cout << " **** Creating RHS ****" << std::endl; + + // Sum the portion of the row that I have + + for (k= 0; k < myrows; k++) { + temp(k) = 0; + for (m=0; m < mycols; m++) { + temp(k) = temp(k) + h_A(k,m); + } + } + + // Sum from all processes and distribute the result back to all processes in rowcomm + + MPI_Allreduce(temp.data(), temp2.data(), myrows, ADELUS_MPI_DATA_TYPE, MPI_SUM, rowcomm); + + // Find the location of my RHS in the global RHS + + int *nrhs_procs_rowcomm; + int my_rhs_offset = 0; + + nrhs_procs_rowcomm = (int*)malloc( nprocs_per_row * sizeof(int)); + MPI_Allgather(&myrhs, 1, MPI_INT, nrhs_procs_rowcomm, 1, MPI_INT, rowcomm);//gather numbers of rhs of other processes + + for (i=0; i 0 ) { + for (k = 0; k < myrhs; k++) { +#if defined(DREAL) || defined(ZCPLX) + ScalarA scal_factor = static_cast(my_rhs_offset+k+1); +#else + ScalarA scal_factor = static_cast(my_rhs_offset+k+1); +#endif + auto cur_rhs_vec_1d = subview(h_B,Kokkos::ALL(),k); + Kokkos::deep_copy( cur_rhs_vec_1d, temp2 ); + KokkosBlas::scal(cur_rhs_vec_1d,scal_factor,cur_rhs_vec_1d); + } + for (k = 0; k < numrhs; k++) { +#if defined(DREAL) || defined(ZCPLX) + ScalarA scal_factor = static_cast(k+1); +#else + ScalarA scal_factor = static_cast(k+1); +#endif + auto cur_rhs_vec_1d = subview(rhs,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows),k); + Kokkos::deep_copy( cur_rhs_vec_1d, temp2 ); + KokkosBlas::scal(cur_rhs_vec_1d,scal_factor,cur_rhs_vec_1d); + } + } + + // Globally Sum the RHS needed for testing later + + MPI_Allreduce(rhs.data(), temp4.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, colcomm); + + // Pack back into RHS + + Kokkos::deep_copy( rhs, temp4 ); + + KokkosBlas::nrm2(rhs_nrm, rhs); + + Kokkos::deep_copy( B, h_B ); + + + // Create handle + Adelus::AdelusHandle + ahandle(0, MPI_COMM_WORLD, matrix_size, nprocs_per_row, numrhs ); + + // Now Factor the matrix + + if( rank == 0 ) + std::cout << " **** Beginning Matrix Factor ****" << std::endl; + + Adelus::Factor (ahandle, A, h_permute, &secs); + + if( rank == 0) { + std::cout << " ---- Factor time ---- " << secs << " in secs. " << std::endl; + + mflops = 2./3.*pow(matrix_size,3.)/secs/1000000.; + + std::cout << " ***** MFLOPS ***** " << mflops << std::endl; + } + + // Call Solve (1st time) + + if( rank == 0 ) + std::cout << " **** Beginning Matrix Solve (1st) ****" << std::endl; + + Adelus::Solve (ahandle, A, B, h_permute, &secs); + + if( rank == 0) + std::cout << " ---- Solution time (1st) ---- " << secs << " in secs. " << std::endl; + + // Restore the orig. RHS for testing Adelus::Solve() on a pre-computed LU factorization + Kokkos::deep_copy( B, h_B ); + + // Call Solve (2nd time) + if( rank == 0 ) + std::cout << " **** Beginning Matrix Solve (2nd) ****" << std::endl; + + Adelus::Solve (ahandle, A, B, h_permute, &secs); + + if( rank == 0) + std::cout << " ---- Solution time (2nd) ---- " << secs << " in secs. " << std::endl; + + // Now Check the Solution + + Kokkos::deep_copy( h_B, B ); + + + // Pack the Answer into the apropriate position + + if ( myrhs > 0 ) { + Kokkos::deep_copy( subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows), + Kokkos::make_pair(my_rhs_offset, my_rhs_offset + myrhs)), + subview(h_B,Kokkos::ALL(),Kokkos::make_pair(0, myrhs)) ); + } + + // All processors get the answer + + MPI_Allreduce(tempp.data(), temp22.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, MPI_COMM_WORLD); + + // Perform the Matrix vector product + + ScalarA alpha = 1.0; + ScalarA beta = 0.0; + + KokkosBlas::gemm("N", "N", alpha, + subview(h_A,Kokkos::ALL(),Kokkos::make_pair(0, mycols)), + subview(temp22,Kokkos::make_pair(myfirstcol - 1, myfirstcol - 1 + mycols),Kokkos::ALL()), + beta, subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows),Kokkos::ALL())); + + MPI_Allreduce(tempp.data(), temp3.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, MPI_COMM_WORLD); + + if( rank == 0) { + std::cout << "======================================" << std::endl; + std::cout << " ---- Error Calculation ----" << std::endl; + + ScalarA alpha_ = -1.0; + + KokkosBlas::axpy(alpha_, rhs, temp3);//temp3=temp3-rhs + + KokkosBlas::nrm2(m_nrm, temp3); + } + + // Machine epsilon Calculation + + othird = four_thirds - 1.; + + tempc = othird + othird + othird; + + eps = fabs(tempc-1.0); + + if ( rank == 0 ) { + std::cout << " Machine eps " << eps << std::endl; + + std::cout << " Threshold = " << eps*1e4 << std::endl; + + for (k = 0; k < numrhs; k++) { + std::cout << " Solution " << k << ": ||Ax - b||_2 = " << m_nrm(k) << std::endl; + + std::cout << " Solution " << k << ": ||b||_2 = " << rhs_nrm(k) << std::endl; + + std::cout << " Solution " << k << ": ||Ax - b||_2 / ||b||_2 = " << m_nrm(k)/rhs_nrm(k) << std::endl; + + if ( m_nrm(k)/rhs_nrm(k) > (eps*1e4)) { + std::cout << " **** Solution " << k << " Fails ****" << std::endl; + result = 1; + break; + } + else { + std::cout << " **** Solution " << k << " Passes ****" << std::endl; + result = 0; + } + } + std::cout << "======================================" << std::endl; + } + + MPI_Bcast(&result, 1, MPI_INT, 0, MPI_COMM_WORLD); + + free(nrhs_procs_rowcomm); + + } + Kokkos::finalize(); + + MPI_Finalize() ; + + return (result); +} diff --git a/packages/adelus/test/vector_random_mc/CMakeLists.txt b/packages/adelus/test/vector_random_mc/CMakeLists.txt new file mode 100644 index 000000000000..40758e01a44b --- /dev/null +++ b/packages/adelus/test/vector_random_mc/CMakeLists.txt @@ -0,0 +1,74 @@ +#2 RANKS -- 2 COMMS, EACH COMM has 1 RANK +TRIBITS_ADD_EXECUTABLE_AND_TEST( + vector_random_mc + SOURCES cxx_main.cpp + NAME vector_random_mc_npr1_rhs1 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr1_rhs4 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr1_rhs5 + NUM_MPI_PROCS 2 + ARGS "3501 1 2 5" + COMM mpi + ) + +#4 RANKS -- 2 COMMS, EACH COMM has 2 RANKS +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr1_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr1_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr1_rhs5 + NUM_MPI_PROCS 4 + ARGS "3501 1 4 5" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr2_rhs1 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 1" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr2_rhs4 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 4" + COMM mpi + ) + +TRIBITS_ADD_TEST( + vector_random_mc + NAME vector_random_mc_npr2_rhs5 + NUM_MPI_PROCS 4 + ARGS "3501 2 4 5" + COMM mpi + ) diff --git a/packages/adelus/test/vector_random_mc/cxx_main.cpp b/packages/adelus/test/vector_random_mc/cxx_main.cpp new file mode 100644 index 000000000000..c7904f361b29 --- /dev/null +++ b/packages/adelus/test/vector_random_mc/cxx_main.cpp @@ -0,0 +1,496 @@ +/* +//@HEADER +// ************************************************************************ +// +// Adelus v. 1.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of NTESS nor the names of the contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +// IN NO EVENT SHALL NTESS OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Vinh Dang (vqdang@sandia.gov) +// Joseph Kotulski (jdkotul@sandia.gov) +// Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char *argv[]) +{ + char processor_name[MPI_MAX_PROCESSOR_NAME]; + int name_len; + int rank, size; + int sub_rank/*, sub_size*/; + + int myrows; + int mycols; + int myfirstrow; + int myfirstcol; + int myrhs; + int my_row; + int my_col; + int matrix_size; + int nprocs_per_row; + int nptile = 1; // number of processors per node + int numrhs = 1; + + double mflops; + + MPI_Comm sub_comm, rowcomm, colcomm; + + static int buf[4]; + + int i, m, k; + + int mlen; // Message length for input data + + unsigned int seed= 10; + + double secs; + + double eps; + + double othird; + + double four_thirds = 4./3.; + + double tempc; + + int result=1; + + // Enroll into MPI + + MPI_Init(&argc,&argv); /* starts MPI */ + MPI_Comm_rank (MPI_COMM_WORLD, &rank); /* get current process id */ + MPI_Comm_size (MPI_COMM_WORLD, &size); /* get number of processes */ + MPI_Get_processor_name(processor_name, &name_len); /* get name of the processor */ + + // Divide the global comm into 2 halves communicators + int my_color = rank/(size/2);//NOTE: colors for first and second communicators + int my_key = rank%(size/2);//NOTE: rank in each new communicator + MPI_Comm_split (MPI_COMM_WORLD, my_color, my_key, &sub_comm); + MPI_Comm_rank (sub_comm, &sub_rank); + //MPI_Comm_size (sub_comm, &sub_size); + + // Initialize Input buffer + + for(i=0;i<4;i++) buf[i]=-1; + + std::cout << "proc " << rank << " (sub rank " << sub_rank << ") (" << processor_name << ") is alive of " << size << " Processors" << std::endl; + + if( rank == 0 ) { + // Check for commandline input + + if (argc > 1) { + // argv[1] should be size of matrix + buf[0] = atoi(argv[1]); + if (argc > 2) { + // argv[2] should be #procs per row + buf[1] = atoi(argv[2]); + // argv[3] should be #procs per node + buf[2] = atoi(argv[3]); + // argv[4] should be #rhs + buf[3] = atoi(argv[4]); + } + else { + // default is 1, but sqrt(p) would be better + buf[1] = 1; buf[2] = 1; buf[3] = 1; + } + } + else { + // Input Data about matrix and distribution + + if (buf[0] < 0) { + std::cout << "Enter size of matrix " << std::endl; + std::cin >> buf[0]; + } + if (buf[1] < 0) { + std::cout << "Enter number of processors to which each row is assigned " << std::endl; + std::cin >> buf[1]; + } + if (buf[2] < 0) { + std::cout << "Enter number of processors per node " << std::endl; + std::cin >> buf[2]; + } + if (buf[3] < 0) { + std::cout << "Enter number of rhs vectors " << std::endl; + std::cin >> buf[3]; + } + } + } + + // Send the initilization data to each processor + mlen = 4*sizeof(int); + + MPI_Bcast(reinterpret_cast(buf), mlen, MPI_CHAR, 0, MPI_COMM_WORLD); + + // Set the values where needed + + matrix_size = buf[0]; + + nprocs_per_row = buf[1]; + + nptile = buf[2]; + + numrhs = buf[3]; + + if( rank == 0 ) { + std::cout << " Matrix Size " << matrix_size << std::endl; + std::cout << " Processors in a row " << nprocs_per_row << std::endl; + std::cout << " Processors in a node " << nptile << std::endl; + std::cout << " Number of RHS vectors " << numrhs << std::endl; + } + + if( rank == 0) { + std::cout << " ---- Building Adelus solver ----" << std::endl; + } + + // Get Info to build the matrix on a processor + + Adelus::GetDistribution( sub_comm, + nprocs_per_row, matrix_size, numrhs, + myrows, mycols, myfirstrow, myfirstcol, + myrhs, my_row, my_col ); + + // Define new communicators: rowcomm and colcomm + + MPI_Comm_split(sub_comm,my_row,my_col,&rowcomm); + MPI_Comm_split(sub_comm,my_col,my_row,&colcomm); + + std::cout << " ------ PARALLEL Distribution Info for : ---------" < gpu_count) { + if( rank == 0 ) { + std::cout << "Request more GPUs than the number of GPUs available " + << "to MPI processes (requested: " << nptile + << " vs. available: " << gpu_count + << "). Exit without test." << std::endl; + } + MPI_Finalize() ; + return 0; + } + + Kokkos::InitArguments args; + args.num_threads = 0; + args.num_numa = 0; + args.device_id = rank%nptile; + std::cout << " Processor " << rank << " (" << processor_name << "), GPU: " + << args.device_id << "/" << gpu_count << std::endl; + Kokkos::initialize( args ); +#else + Kokkos::initialize( argc, argv ); +#endif + { + // Local size -- myrows * (mycols + myrhs) + + using Layout = Kokkos::LayoutLeft; +#if defined(KOKKOS_ENABLE_CUDA) + using TestSpace = Kokkos::CudaSpace; +#elif defined(KOKKOS_ENABLE_HIP) + using TestSpace = Kokkos::Experimental::HIPSpace; +#else + using TestSpace = Kokkos::HostSpace; +#endif +#ifdef DREAL + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; + using ViewMatrixType_Host = Kokkos::View; + using ViewNrmVectorType_Host = Kokkos::View; +#elif defined(SREAL) + using ViewMatrixType = Kokkos::View; + using ViewVectorType_Host = Kokkos::View; + using ViewMatrixType_Host = Kokkos::View; + using ViewNrmVectorType_Host = Kokkos::View; +#elif defined(SCPLX) + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; + using ViewMatrixType_Host = Kokkos::View**, Layout, Kokkos::HostSpace>; + using ViewNrmVectorType_Host = Kokkos::View; +#else + using ViewMatrixType = Kokkos::View**, Layout, TestSpace>; + using ViewVectorType_Host = Kokkos::View*, Layout, Kokkos::HostSpace>; + using ViewMatrixType_Host = Kokkos::View**, Layout, Kokkos::HostSpace>; + using ViewNrmVectorType_Host = Kokkos::View; +#endif + using execution_space = typename ViewMatrixType::device_type::execution_space; + using memory_space = typename ViewMatrixType::device_type::memory_space; + using ScalarA = typename ViewMatrixType::value_type; + + printf("Rank %d, ViewMatrixType execution_space %s, memory_space %s, value_type %s\n",rank, typeid(execution_space).name(), typeid(memory_space).name(), typeid(ScalarA).name()); + + ViewMatrixType A( "A", myrows, mycols + myrhs + 6 ); + + ViewMatrixType::HostMirror h_A = Kokkos::create_mirror( A ); + + // Some temp arrays + + ViewVectorType_Host temp ( "temp", myrows ); + + ViewVectorType_Host temp2 ( "temp2", myrows ); + + ViewMatrixType_Host rhs ( "rhs", matrix_size, numrhs ); + + ViewMatrixType_Host temp3 ( "temp3", matrix_size, numrhs ); + + ViewMatrixType_Host temp4 ( "temp4", matrix_size, numrhs ); + + ViewMatrixType_Host tempp ( "tempp", matrix_size, numrhs ); + + ViewMatrixType_Host temp22( "temp22", matrix_size, numrhs ); + + ViewNrmVectorType_Host rhs_nrm( "rhs_nrm", numrhs ); + + ViewNrmVectorType_Host m_nrm ( "m_nrm", numrhs ); + + // Set Random values + + if( rank == 0 ) + std::cout << " **** Setting Random Matrix ****" << std::endl; + + Kokkos::Random_XorShift64_Pool rand_pool(seed+rank); + Kokkos::fill_random(A, rand_pool,Kokkos::rand,ScalarA >::max()); + + Kokkos::deep_copy( h_A, A ); + + // Now Create the RHS + + if( rank == 0 ) + std::cout << " **** Creating RHS ****" << std::endl; + + // Sum the portion of the row that I have + + for (k= 0; k < myrows; k++) { + temp(k) = 0; + for (m=0; m < mycols; m++) { + temp(k) = temp(k) + h_A(k,m); + } + } + + // Sum from all processes and distribute the result back to all processes in rowcomm + + MPI_Allreduce(temp.data(), temp2.data(), myrows, ADELUS_MPI_DATA_TYPE, MPI_SUM, rowcomm); + + // Find the location of my RHS in the global RHS + + int *nrhs_procs_rowcomm; + int my_rhs_offset = 0; + + nrhs_procs_rowcomm = (int*)malloc( nprocs_per_row * sizeof(int)); + MPI_Allgather(&myrhs, 1, MPI_INT, nrhs_procs_rowcomm, 1, MPI_INT, rowcomm);//gather numbers of rhs of other processes + + for (i=0; i 0 ) { + for (k = 0; k < myrhs; k++) { +#if defined(DREAL) || defined(ZCPLX) + ScalarA scal_factor = static_cast(my_rhs_offset+k+1); +#else + ScalarA scal_factor = static_cast(my_rhs_offset+k+1); +#endif + auto cur_rhs_vec_1d = subview(h_A,Kokkos::ALL(),mycols+k); + Kokkos::deep_copy( cur_rhs_vec_1d, temp2 ); + KokkosBlas::scal(cur_rhs_vec_1d,scal_factor,cur_rhs_vec_1d); + } + for (k = 0; k < numrhs; k++) { +#if defined(DREAL) || defined(ZCPLX) + ScalarA scal_factor = static_cast(k+1); +#else + ScalarA scal_factor = static_cast(k+1); +#endif + auto cur_rhs_vec_1d = subview(rhs,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows),k); + Kokkos::deep_copy( cur_rhs_vec_1d, temp2 ); + KokkosBlas::scal(cur_rhs_vec_1d,scal_factor,cur_rhs_vec_1d); + } + } + + // Globally Sum the RHS needed for testing later + + MPI_Allreduce(rhs.data(), temp4.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, colcomm); + + // Pack back into RHS + + Kokkos::deep_copy( rhs, temp4 ); + + KokkosBlas::nrm2(rhs_nrm, rhs); + + Kokkos::deep_copy( subview(A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)), + subview(h_A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)) ); + + // Create handle + Adelus::AdelusHandle + ahandle(my_color, sub_comm, matrix_size, nprocs_per_row, numrhs ); + + // Now Solve the Problem + + if( rank == 0 ) + std::cout << " **** Beginning Matrix Solve ****" << std::endl; + + Adelus::FactorSolve (ahandle, A, &secs); + + if( rank == 0) { + std::cout << " ---- Solution time ---- " << secs << " in secs. " << std::endl; + + mflops = 2./3.*pow(matrix_size,3.)/secs/1000000.; + + std::cout << " ***** MFLOPS ***** " << mflops << std::endl; + } + + // Now Check the Solution + + Kokkos::deep_copy( subview(h_A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)), + subview(A, Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)) ); + + // Pack the Answer into the apropriate position + + if ( myrhs > 0 ) { + Kokkos::deep_copy( subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows), + Kokkos::make_pair(my_rhs_offset, my_rhs_offset + myrhs)), + subview(h_A,Kokkos::ALL(),Kokkos::make_pair(mycols, mycols + myrhs)) ); + } + + // All processors get the answer + + MPI_Allreduce(tempp.data(), temp22.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, sub_comm); + + // Perform the Matrix vector product + + ScalarA alpha = 1.0; + ScalarA beta = 0.0; + + KokkosBlas::gemm("N", "N", alpha, + subview(h_A,Kokkos::ALL(),Kokkos::make_pair(0, mycols)), + subview(temp22,Kokkos::make_pair(myfirstcol - 1, myfirstcol - 1 + mycols),Kokkos::ALL()), + beta, subview(tempp,Kokkos::make_pair(myfirstrow - 1, myfirstrow - 1 + myrows),Kokkos::ALL())); + + MPI_Allreduce(tempp.data(), temp3.data(), matrix_size*numrhs, ADELUS_MPI_DATA_TYPE, MPI_SUM, sub_comm); + + if( rank == 0) { + std::cout << "======================================" << std::endl; + std::cout << " ---- Error Calculation ----" << std::endl; + } + if( sub_rank == 0) { + ScalarA alpha_ = -1.0; + + KokkosBlas::axpy(alpha_, rhs, temp3);//temp3=temp3-rhs + + KokkosBlas::nrm2(m_nrm, temp3); + } + + // Machine epsilon Calculation + + othird = four_thirds - 1.; + + tempc = othird + othird + othird; + + eps = fabs(tempc-1.0); + + if ( rank == 0 ) { + std::cout << " Machine eps " << eps << std::endl; + + std::cout << " Threshold = " << eps*1e4 << std::endl; + } + if ( sub_rank == 0 ) { + for (k = 0; k < numrhs; k++) { + std::cout << " Solution " << k << ": ||Ax - b||_2 = " << m_nrm(k) << " on comm " << my_color << std::endl; + + std::cout << " Solution " << k << ": ||b||_2 = " << rhs_nrm(k) << " on comm " << my_color << std::endl; + + std::cout << " Solution " << k << ": ||Ax - b||_2 / ||b||_2 = " << m_nrm(k)/rhs_nrm(k) << " on comm " << my_color << std::endl; + + if ( m_nrm(k)/rhs_nrm(k) > (eps*1e4)) { + std::cout << " **** Solution " << k << " Fails ****" << " on comm " << my_color << std::endl; + result = 1; + break; + } + else { + std::cout << " **** Solution " << k << " Passes ****" << " on comm " << my_color << std::endl; + result = 0; + } + } + } + if ( rank == 0 ) { + std::cout << "======================================" << std::endl; + } + + MPI_Bcast(&result, 1, MPI_INT, 0, sub_comm); + MPI_Allreduce(MPI_IN_PLACE, &result, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD); + + free(nrhs_procs_rowcomm); + + } + Kokkos::finalize(); + + MPI_Finalize() ; + + return (result); +} diff --git a/packages/amesos2/src/Amesos2_Control.cpp b/packages/amesos2/src/Amesos2_Control.cpp index b164de92f010..c3e8931fe4ff 100644 --- a/packages/amesos2/src/Amesos2_Control.cpp +++ b/packages/amesos2/src/Amesos2_Control.cpp @@ -97,7 +97,7 @@ void Control::setControlParameters( if( parameterList->isType("Iterative refinement") ){ useIterRefine_ = parameterList->get("Iterative refinement"); } - if( parameterList->isType("Number of iterative refinements") ){ + if( parameterList->isType("Number of iterative refinements") ){ maxNumIterRefines_ = parameterList->get("Number of iterative refinements"); } if( parameterList->isType("Verboes for iterative refinement") ){ diff --git a/packages/amesos2/src/Amesos2_PardisoMKL_TypeMap.hpp b/packages/amesos2/src/Amesos2_PardisoMKL_TypeMap.hpp index bb023bfb996a..7ea99656f9dc 100644 --- a/packages/amesos2/src/Amesos2_PardisoMKL_TypeMap.hpp +++ b/packages/amesos2/src/Amesos2_PardisoMKL_TypeMap.hpp @@ -60,27 +60,27 @@ #include #endif +#include +#include + #include #ifdef HAVE_TEUCHOS_COMPLEX #include #endif #include "Amesos2_TypeMap.hpp" -#ifdef _MKL_TYPES_H_ - #undef _MKL_TYPES_H_ - #define PARDISOMKL_PREVIOUS_MKL_TYPES_H -#endif namespace Amesos2{ namespace PMKL { - //Update JDB 6.25.15 - //MKL has changed _INTEGER_t to deprecated - //MKL has changed _INTEGER_t to define from typedef + #undef _MKL_TYPES_H_ #include - #ifdef __MKL_DSS_H + #undef __MKL_DSS_H - #endif #include + + //Update JDB 6.25.15 + //MKL has changed _INTEGER_t to deprecated + //MKL has changed _INTEGER_t to define from typedef #undef _INTEGER_t typedef MKL_INT _INTEGER_t; } // end namespace PMKL @@ -286,8 +286,4 @@ namespace Amesos2 { } // end namespace Amesos -#ifndef PARDISOMKL_PREVIOUS_MKL_TYPES_H - // first time including mkl_types.h - #undef _MKL_TYPES_H_ -#endif #endif // AMESOS2_PARDISOMKL_TYPEMAP_HPP diff --git a/packages/amesos2/src/Amesos2_SolverCore_def.hpp b/packages/amesos2/src/Amesos2_SolverCore_def.hpp index ff8b195ffb33..bb31a4d57605 100644 --- a/packages/amesos2/src/Amesos2_SolverCore_def.hpp +++ b/packages/amesos2/src/Amesos2_SolverCore_def.hpp @@ -309,6 +309,11 @@ SolverCore::solve_ir(const Teuchos::Ptr< Vect crsmat = host_crsmat_t("CrsMatrix", nrows, values_view, static_graph); } + // + // ** First Solve ** + static_cast(this)->solve_impl(Teuchos::outArg(*X), Teuchos::ptrInArg(*B)); + + // auxiliary scalar Kokkos views const int ldx = (this->root_ ? X->getGlobalLength() : 0); const int ldb = (this->root_ ? B->getGlobalLength() : 0); @@ -336,10 +341,6 @@ SolverCore::solve_ir(const Teuchos::Ptr< Vect do_get(not_initialize_data, Eptr, E_view, lde, CONTIGUOUS_AND_ROOTED, rowIndexBase); - // - // first solve - static_cast(this)->solve_impl(Teuchos::outArg(*X), Teuchos::ptrInArg(*B)); - host_magni_view x0norms("x0norms", nrhs); host_magni_view bnorms("bnorms", nrhs); host_magni_view enorms("enorms", nrhs); @@ -372,7 +373,7 @@ SolverCore::solve_ir(const Teuchos::Ptr< Vect // - // iterative refinement + // ** Iterative Refinement ** int numIters = 0; int converged = 0; // 0 = has not converged, 1 = converged for (numIters = 0; numIters < maxNumIters && converged == 0; ++numIters) { @@ -409,14 +410,19 @@ SolverCore::solve_ir(const Teuchos::Ptr< Vect if (this->root_) { KokkosBlas::axpy(one, E_view, X_view); - // compute norm of corrections for "convergence" check - converged = 1; - for (size_t j = 0; j < nrhs; j++) { - auto e_subview = Kokkos::subview(E_view, Kokkos::ALL(), j); - host_vector_t e_1d (const_cast(e_subview.data()), e_subview.extent(0)); - enorms(j) = KokkosBlas::nrm2(e_1d); - if (enorms(j) > eps * x0norms(j)) { - converged = 0; + if (numIters < maxNumIters-1) { + // compute norm of corrections for "convergence" check + converged = 1; + for (size_t j = 0; j < nrhs; j++) { + auto e_subview = Kokkos::subview(E_view, Kokkos::ALL(), j); + host_vector_t e_1d (const_cast(e_subview.data()), e_subview.extent(0)); + enorms(j) = KokkosBlas::nrm2(e_1d); + if (enorms(j) > eps * x0norms(j)) { + converged = 0; + } + } + if (verbose && converged) { + std::cout << " converged " << std::endl; } } } diff --git a/packages/framework/ini-files/config-specs.ini b/packages/framework/ini-files/config-specs.ini index f74a1dcaa657..0fd0f532f060 100644 --- a/packages/framework/ini-files/config-specs.ini +++ b/packages/framework/ini-files/config-specs.ini @@ -887,6 +887,12 @@ opt-set-cmake-var Trilinos_CUDA_NUM_GPUS STRING : ${KOKKOS_NUM_DEVICES|ENV} opt-set-cmake-var Trilinos_CUDA_SLOTS_PER_GPU STRING : 2 # See https://tribits.org/doc/TribitsBuildReference.html#spreading-out-and-limiting-tests-running-on-gpus opt-set-cmake-var Trilinos_AUTOGENERATE_TEST_RESOURCE_FILE BOOL FORCE : ON +opt-set-cmake-var Trilinos_ENABLE_Pliris BOOL FORCE : OFF +# NOTE: Above, FORCE is needed for Trilinos_ENABLE_Pliris=OFF in case somemoe +# changes a file under packages/pliris/ and results in adding +# set(Trilinos_ENABLE_Pliris ON CACHE BOOL "") in the packageEnables.cmake +# file. For more details on the complexity of this, see Trilinos GitHub Issue +# #10931. [ATS2-COMMON-OVERRIDES] # Override TPL enables from [COMMON] @@ -977,6 +983,8 @@ opt-set-cmake-var PanzerAdaptersSTK_main_driver_energy-ss-blocked-tp_DISABLE BOO opt-set-cmake-var PanzerDiscFE_integration_values2_MPI_1_DISABLE BOOL : ON opt-set-cmake-var PanzerMiniEM_MiniEM-BlockPrec_Augmentation_MPI_4_DISABLE BOOL : ON opt-set-cmake-var PanzerMiniEM_MiniEM-BlockPrec_RefMaxwell_MPI_4_DISABLE BOOL : ON +opt-set-cmake-var Pliris_vector_random_MPI_3_DISABLE BOOL : ON +opt-set-cmake-var Pliris_vector_random_MPI_4_DISABLE BOOL : ON opt-set-cmake-var ROL_NonlinearProblemTest_MPI_4_DISABLE BOOL : ON opt-set-cmake-var ROL_adapters_minitensor_test_function_test_01_MPI_4_DISABLE BOOL : ON opt-set-cmake-var ROL_adapters_minitensor_test_function_test_02_MPI_4_DISABLE BOOL : ON @@ -2351,6 +2359,67 @@ use USE-DEPRECATED|YES use PACKAGE-ENABLES|NO-PACKAGE-ENABLES use COMMON_SPACK_TPLS use SEMS_COMMON_CUDA_11 + +# TPL ENABLE/DISABLE settings +opt-set-cmake-var TPL_ENABLE_BLAS BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_BinUtils BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_Boost BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_CGNS BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_CUDA BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_CUSPARSE BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_DLlib BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_HDF5 BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_HWLOC BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_LAPACK BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_METIS BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_Matio BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_MPI BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_Netcdf BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_ParMETIS BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_Pthread BOOL FORCE : ON +opt-set-cmake-var TPL_ENABLE_Scotch BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_SuperLU BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_SuperLUDist BOOL FORCE : OFF +opt-set-cmake-var TPL_ENABLE_Zlib BOOL FORCE : ON + +#TPL_*_LIBRARIES +opt-set-cmake-var TPL_BLAS_LIBRARIES STRING FORCE : -L${BLAS_ROOT|ENV}/lib;-lopenblas;-lgfortran;-lgomp +opt-set-cmake-var TPL_BLAS_LIBRARY_DIRS STRING FORCE : ${BLAS_ROOT|ENV}/lib +opt-set-cmake-var TPL_BoostLib_LIBRARIES STRING FORCE : ${BOOST_LIB|ENV}/libboost_program_options.a;${BOOST_LIB|ENV}/libboost_system.a +opt-set-cmake-var TPL_Boost_LIBRARIES STRING FORCE : ${BOOST_LIB|ENV}/libboost_program_options.a;${BOOST_LIB|ENV}/libboost_system.a +opt-set-cmake-var TPL_DLlib_LIBRARIES FILEPATH FORCE : "-ldl" +opt-set-cmake-var TPL_HDF5_LIBRARIES STRING FORCE : ${HDF5_LIB|ENV}/libhdf5_hl.so;${HDF5_LIB|ENV}/libhdf5.a;${ZLIB_LIB|ENV}/libz.a;-ldl +opt-set-cmake-var TPL_LAPACK_LIBRARIES STRING FORCE : -L${BLAS_ROOT|ENV}/lib;-lopenblas;-lgfortran;-lgomp +opt-set-cmake-var TPL_LAPACK_LIBRARY_DIRS STRING FORCE : ${BLAS_ROOT|ENV}/lib +opt-set-cmake-var TPL_METIS_LIBRARIES STRING FORCE : ${METIS_LIB|ENV}/libmetis.so +opt-set-cmake-var TPL_Netcdf_LIBRARIES STRING FORCE : -L${NETCDF_C_ROOT|ENV}/lib64;${NETCDF_C_ROOT|ENV}/lib/libnetcdf.a;${PARALLEL_NETCDF_ROOT|ENV}/lib/libpnetcdf.a;${TPL_HDF5_LIBRARIES|CMAKE} + +#TPL_[INCLUDE|LIBRARY]_DIRS +opt-set-cmake-var Netcdf_INCLUDE_DIRS STRING FORCE : ${NETCDF_C_INC|ENV} +opt-set-cmake-var ParMETIS_INCLUDE_DIRS STRING FORCE : ${PARMETIS_INC|ENV} +opt-set-cmake-var ParMETIS_LIBRARY_DIRS STRING FORCE : ${PARMETIS_LIB|ENV} +opt-set-cmake-var Scotch_INCLUDE_DIRS STRING FORCE : ${SCOTCH_INC|ENV} +opt-set-cmake-var Scotch_LIBRARY_DIRS STRING FORCE : ${SCOTCH_LIB|ENV} +opt-set-cmake-var SuperLU_INCLUDE_DIRS STRING FORCE : ${SUPERLU_INC|ENV} +opt-set-cmake-var SuperLU_LIBRARY_DIRS STRING FORCE : ${SUPERLU_LIB|ENV} + +#CXX Settings +opt-set-cmake-var CMAKE_CXX_STANDARD STRING FORCE : 17 +opt-set-cmake-var CMAKE_CXX_FLAGS STRING : -fPIC -Wall -Warray-bounds -Wchar-subscripts -Wcomment -Wenum-compare -Wformat -Wuninitialized -Wmaybe-uninitialized -Wmain -Wnarrowing -Wnonnull -Wparentheses -Wreorder -Wreturn-type -Wsign-compare -Wsequence-point -Wtrigraphs -Wunused-function -Wunused-but-set-variable -Wunused-variable -Wwrite-strings + +#Package Options +opt-set-cmake-var EpetraExt_ENABLE_HDF5 BOOL FORCE : OFF +opt-set-cmake-var Kokkos_ENABLE_CUDA BOOL FORCE : ON +opt-set-cmake-var Kokkos_ENABLE_CUDA_LAMBDA BOOL FORCE : ON +opt-set-cmake-var Kokkos_ENABLE_CXX11_DISPATCH_LAMBDA BOOL FORCE : ON +#opt-set-cmake-var Kokkos_ENABLE_Debug_Bounds_Check BOOL FORCE : ON +opt-set-cmake-var MPI_EXEC_POST_NUMPROCS_FLAGS STRING FORCE : "-map-by;socket:PE=4" +opt-set-cmake-var Panzer_FADTYPE STRING FORCE : "Sacado::Fad::DFad" +opt-set-cmake-var Phalanx_KOKKOS_DEVICE_TYPE STRING FORCE : CUDA +opt-set-cmake-var Sacado_ENABLE_HIERARCHICAL_DFAD BOOL FORCE : ON +opt-set-cmake-var Tpetra_INST_SERIAL BOOL FORCE : ON +opt-set-cmake-var Zoltan_ENABLE_Scotch BOOL FORCE : OFF + use RHEL7_SEMS_CUDA_UVM_OFF_DISABLES use RHEL7_POST diff --git a/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_decl.hpp b/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_decl.hpp index 8c0a135e5451..310469f44008 100644 --- a/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_decl.hpp +++ b/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_decl.hpp @@ -809,7 +809,12 @@ class AdditiveSchwarz : mutable Teuchos::RCP validParams_; //! Combine mode for off-process elements (only if overlap is used) + //! To average values in overlap region, set CombineMode_ + //! to ADD and AvgOverlap_ to true (can be done via + //! param list by setting "schwarz: combine mode" to "AVG") + //! Don't average with CG as preconditioner is nonsymmetric. Tpetra::CombineMode CombineMode_ = Tpetra::ZERO; + bool AvgOverlap_ = false; //! If \c true, reorder the local matrix. bool UseReordering_ = false; //! Record reordering for output purposes. @@ -851,6 +856,8 @@ class AdditiveSchwarz : mutable std::unique_ptr overlapping_B_; //! Cached local (possibly) overlapping output (multi)vector. mutable std::unique_ptr overlapping_Y_; + //! Cached local (possibly) vector that indicates how many copies of a dof exist due to overlap + mutable std::unique_ptr num_overlap_copies_; //! Cached residual (multi)vector. mutable std::unique_ptr R_; //! Cached intermediate result (multi)vector. diff --git a/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_def.hpp b/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_def.hpp index c7d00563a275..4c1987d2ba3d 100644 --- a/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_def.hpp +++ b/packages/ifpack2/src/Ifpack2_AdditiveSchwarz_def.hpp @@ -425,6 +425,19 @@ apply (const Tpetra::MultiVectorgetNumVectors () != numVectors) { C_.reset (new MV (Y.getMap (), numVectors, false)); } + // If taking averages in overlap region, we need to compute + // the number of procs who have a copy of each overlap dof + Teuchos::ArrayRCP dataNumOverlapCopies; + if (IsOverlapping_ && AvgOverlap_) { + if (num_overlap_copies_.get() == nullptr) { + num_overlap_copies_.reset (new MV (Y.getMap (), 1, false)); + RCP onesVec( new MV(OverlappingMatrix_->getRowMap(), 1, false) ); + onesVec->putScalar(Teuchos::ScalarTraits::one()); + rcp_dynamic_cast> (OverlappingMatrix_)->exportMultiVector (*onesVec, *(num_overlap_copies_.get ()), CombineMode_); + } + dataNumOverlapCopies = num_overlap_copies_.get ()->getDataNonConst(0); + } + MV* R = R_.get (); MV* C = C_.get (); @@ -550,6 +563,14 @@ apply (const Tpetra::MultiVectorexportMultiVector (*OverlappingY, *C, CombineMode_); + + // average solution in overlap regions if requested via "schwarz: combine mode" "AVG" + if (AvgOverlap_) { + Teuchos::ArrayRCP dataC = C->getDataNonConst(0); + for (int i = 0; i < (int) C->getMap()->getLocalNumElements(); i++) { + dataC[i] = dataC[i]/dataNumOverlapCopies[i]; + } + } } else { // mfh 16 Apr 2014: Make a view of Y with the same Map as @@ -785,7 +806,15 @@ setParameterList (const Teuchos::RCP& plist) using vs2e_type = StringToIntegralParameterEntryValidator; RCP vs2e = rcp_dynamic_cast (v, true); - const ParameterEntry& inputEntry = plist->getEntry (cmParamName); + ParameterEntry& inputEntry = plist->getEntry (cmParamName); + // As AVG is only a Schwarz option and does not exist in Tpetra's + // version of CombineMode, we use a separate boolean local to + // Schwarz in conjunction with CombineMode_ == ADD to handle + // averaging. Here, we change input entry to ADD and set the boolean. + if (strncmp(Teuchos::getValue(inputEntry).c_str(),"AVG",3) == 0) { + inputEntry.template setValue("ADD"); + AvgOverlap_ = true; + } CombineMode_ = vs2e->getIntegralValue (inputEntry, cmParamName); } } @@ -802,6 +831,7 @@ setParameterList (const Teuchos::RCP& plist) if (plist->sublist("subdomain solver parameters").get("partitioner: type") == "user") { if (CombineMode_ == Tpetra::ADD) plist->sublist("subdomain solver parameters").set("partitioner: combine mode","ADD"); if (CombineMode_ == Tpetra::ZERO) plist->sublist("subdomain solver parameters").set("partitioner: combine mode","ZERO"); + AvgOverlap_ = false; // averaging already taken care of by the partitioner: nonsymmetric overlap combine option } } } diff --git a/packages/ifpack2/src/Ifpack2_BlockRelaxation_def.hpp b/packages/ifpack2/src/Ifpack2_BlockRelaxation_def.hpp index 53bd87de9a5d..78f5085479d0 100644 --- a/packages/ifpack2/src/Ifpack2_BlockRelaxation_def.hpp +++ b/packages/ifpack2/src/Ifpack2_BlockRelaxation_def.hpp @@ -1053,6 +1053,11 @@ description () const } else { out << "INVALID"; } + + // BlockCrs if we have that + if(hasBlockCrsMatrix_) + out<<", BlockCrs"; + // Print the approximate # rows per part int approx_rows_per_part = A_->getLocalNumRows()/Partitioner_->numLocalParts(); out <<", blocksize: "< #include #include +#include #ifndef _SPILUKHANDLE_HPP #define _SPILUKHANDLE_HPP @@ -87,6 +88,12 @@ class SPILUKHandle { typedef typename Kokkos::View nnz_lno_view_t; + typedef typename Kokkos::View + nnz_row_view_host_t; + + typedef typename Kokkos::View + nnz_lno_view_host_t; + typedef typename std::make_signed< typename nnz_row_view_t::non_const_value_type>::type signed_integral_t; typedef Kokkos::View signed_nnz_lno_view_t; + typedef Kokkos::View + work_view_t; + private: nnz_row_view_t level_list; // level IDs which the rows belong to nnz_lno_view_t level_idx; // the list of rows in each level nnz_lno_view_t level_ptr; // the starting index (into the view level_idx) of each level - nnz_lno_view_t level_nchunks; // number of chunks of rows at each level - nnz_lno_view_t + nnz_lno_view_host_t level_nchunks; // number of chunks of rows at each level + nnz_lno_view_host_t level_nrowsperchunk; // maximum number of rows among chunks at each level + work_view_t iw; // working view for mapping dense indices to sparse indices size_type nrows; size_type nlevels; @@ -128,6 +140,7 @@ class SPILUKHandle { level_ptr(), level_nchunks(), level_nrowsperchunk(), + iw(), nrows(nrows_), nlevels(0), nnzL(nnzL_), @@ -147,11 +160,12 @@ class SPILUKHandle { set_nnzU(nnzU_); set_level_maxrows(0); set_level_maxrowsperchunk(0); - level_list = nnz_row_view_t("level_list", nrows_), - level_idx = nnz_lno_view_t("level_idx", nrows_), - level_ptr = nnz_lno_view_t("level_ptr", nrows_ + 1), - level_nchunks = nnz_lno_view_t(), level_nrowsperchunk = nnz_lno_view_t(), - reset_symbolic_complete(); + level_list = nnz_row_view_t("level_list", nrows_), + level_idx = nnz_lno_view_t("level_idx", nrows_), + level_ptr = nnz_lno_view_t("level_ptr", nrows_ + 1), + level_nchunks = nnz_lno_view_host_t(), + level_nrowsperchunk = nnz_lno_view_host_t(), reset_symbolic_complete(), + iw = work_view_t(); } virtual ~SPILUKHandle(){}; @@ -170,17 +184,28 @@ class SPILUKHandle { nnz_lno_view_t get_level_ptr() const { return level_ptr; } KOKKOS_INLINE_FUNCTION - nnz_lno_view_t get_level_nchunks() const { return level_nchunks; } + nnz_lno_view_host_t get_level_nchunks() const { return level_nchunks; } void alloc_level_nchunks(const size_type nlevels_) { - level_nchunks = nnz_lno_view_t("level_nchunks", nlevels_); + level_nchunks = nnz_lno_view_host_t("level_nchunks", nlevels_); } KOKKOS_INLINE_FUNCTION - nnz_lno_view_t get_level_nrowsperchunk() const { return level_nrowsperchunk; } + nnz_lno_view_host_t get_level_nrowsperchunk() const { + return level_nrowsperchunk; + } void alloc_level_nrowsperchunk(const size_type nlevels_) { - level_nrowsperchunk = nnz_lno_view_t("level_nrowsperchunk", nlevels_); + level_nrowsperchunk = nnz_lno_view_host_t("level_nrowsperchunk", nlevels_); + } + + KOKKOS_INLINE_FUNCTION + work_view_t get_iw() const { return iw; } + + void alloc_iw(const size_type nrows_, const size_type ncols_) { + iw = work_view_t(Kokkos::view_alloc(Kokkos::WithoutInitializing, "iw"), + nrows_, ncols_); + Kokkos::deep_copy(iw, nnz_lno_t(-1)); } KOKKOS_INLINE_FUNCTION @@ -238,8 +263,7 @@ class SPILUKHandle { if (algm == SPILUKAlgorithm::SEQLVLSCHD_TP1) std::cout << "SEQLVLSCHD_TP1" << std::endl; - /* - if ( algm == SPILUKAlgorithm::SEQLVLSCHED_TP2 ) { + /*if ( algm == SPILUKAlgorithm::SEQLVLSCHED_TP2 ) { std::cout << "SEQLVLSCHED_TP2" << std::endl;; std::cout << "WARNING: With CUDA this is currently only reliable with int-int ordinal-offset pair" << std::endl; diff --git a/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_numeric_impl.hpp b/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_numeric_impl.hpp index d0b80ace6928..4af8606dfbde 100644 --- a/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_numeric_impl.hpp +++ b/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_numeric_impl.hpp @@ -242,52 +242,54 @@ struct ILUKLvlSchedTP1NumericFunctor { KOKKOS_INLINE_FUNCTION void operator()(const member_type &team) const { - auto my_league = team.league_rank(); // map to rowid - auto rowid = level_idx(my_league + lev_start); - auto my_team = team.team_rank(); + nnz_lno_t my_team = static_cast(team.league_rank()); + nnz_lno_t rowid = + static_cast(level_idx(my_team + lev_start)); // map to rowid - auto k1 = L_row_map(rowid); - auto k2 = L_row_map(rowid + 1); + size_type k1 = static_cast(L_row_map(rowid)); + size_type k2 = static_cast(L_row_map(rowid + 1)); #ifdef KEEP_DIAG Kokkos::parallel_for(Kokkos::TeamThreadRange(team, k1, k2 - 1), [&](const size_type k) { - auto col = L_entries(k); - L_values(k) = 0.0; - iw(my_league, col) = k; + nnz_lno_t col = static_cast(L_entries(k)); + L_values(k) = 0.0; + iw(my_team, col) = static_cast(k); }); #else Kokkos::parallel_for(Kokkos::TeamThreadRange(team, k1, k2), [&](const size_type k) { - auto col = L_entries(k); - L_values(k) = 0.0; - iw(my_league, col) = k; + nnz_lno_t col = static_cast(L_entries(k)); + L_values(k) = 0.0; + iw(my_team, col) = static_cast(k); }); #endif #ifdef KEEP_DIAG - if (my_team == 0) L_values(k2 - 1) = scalar_t(1.0); + // if (my_thread == 0) L_values(k2 - 1) = scalar_t(1.0); + Kokkos::single(Kokkos::PerTeam(team), + [&]() { L_values(k2 - 1) = scalar_t(1.0); }); #endif team.team_barrier(); - k1 = U_row_map(rowid); - k2 = U_row_map(rowid + 1); + k1 = static_cast(U_row_map(rowid)); + k2 = static_cast(U_row_map(rowid + 1)); Kokkos::parallel_for(Kokkos::TeamThreadRange(team, k1, k2), [&](const size_type k) { - auto col = U_entries(k); - U_values(k) = 0.0; - iw(my_league, col) = k; + nnz_lno_t col = static_cast(U_entries(k)); + U_values(k) = 0.0; + iw(my_team, col) = static_cast(k); }); team.team_barrier(); // Unpack the ith row of A - k1 = A_row_map(rowid); - k2 = A_row_map(rowid + 1); + k1 = static_cast(A_row_map(rowid)); + k2 = static_cast(A_row_map(rowid + 1)); Kokkos::parallel_for(Kokkos::TeamThreadRange(team, k1, k2), [&](const size_type k) { - auto col = A_entries(k); - auto ipos = iw(my_league, col); + nnz_lno_t col = static_cast(A_entries(k)); + nnz_lno_t ipos = iw(my_team, col); if (col < rowid) L_values(ipos) = A_values(k); else @@ -297,20 +299,22 @@ struct ILUKLvlSchedTP1NumericFunctor { team.team_barrier(); // Eliminate prev rows - k1 = L_row_map(rowid); - k2 = L_row_map(rowid + 1); + k1 = static_cast(L_row_map(rowid)); + k2 = static_cast(L_row_map(rowid + 1)); #ifdef KEEP_DIAG - for (auto k = k1; k < k2 - 1; ++k) { + for (size_type k = k1; k < k2 - 1; k++) #else - for (auto k = k1; k < k2; ++k) { + for (size_type k = k1; k < k2; k++) #endif - auto prev_row = L_entries(k); + { + nnz_lno_t prev_row = L_entries(k); #ifdef KEEP_DIAG - auto fact = L_values(k) / U_values(U_row_map(prev_row)); + scalar_t fact = L_values(k) / U_values(U_row_map(prev_row)); #else - auto fact = L_values(k) * U_values(U_row_map(prev_row)); + scalar_t fact = L_values(k) * U_values(U_row_map(prev_row)); #endif - if (my_team == 0) L_values(k) = fact; + // if (my_thread == 0) L_values(k) = fact; + Kokkos::single(Kokkos::PerTeam(team), [&]() { L_values(k) = fact; }); team.team_barrier(); @@ -318,10 +322,10 @@ struct ILUKLvlSchedTP1NumericFunctor { Kokkos::TeamThreadRange(team, U_row_map(prev_row) + 1, U_row_map(prev_row + 1)), [&](const size_type kk) { - auto col = U_entries(kk); - auto ipos = iw(my_league, col); + nnz_lno_t col = static_cast(U_entries(kk)); + nnz_lno_t ipos = iw(my_team, col); + auto lxu = -U_values(kk) * fact; if (ipos != -1) { - auto lxu = -U_values(kk) * fact; if (col < rowid) Kokkos::atomic_add(&L_values(ipos), lxu); else @@ -332,40 +336,49 @@ struct ILUKLvlSchedTP1NumericFunctor { team.team_barrier(); } // end for k - if (my_team == 0) { + // if (my_thread == 0) { + Kokkos::single(Kokkos::PerTeam(team), [&]() { + nnz_lno_t ipos = iw(my_team, rowid); #ifdef KEEP_DIAG - if (U_values(iw(my_league, rowid)) == 0.0) { - U_values(iw(my_league, rowid)) = 1e6; + if (U_values(ipos) == 0.0) { + U_values(ipos) = 1e6; } #else - if (U_values(iw(my_league, rowid)) == 0.0) { - U_values(iw(my_league, rowid)) = 1e6; + if (U_values(ipos) == 0.0) { + U_values(ipos) = 1e6; } else { - U_values(iw(my_league, rowid)) = 1.0 / U_values(iw(my_league, rowid)); + U_values(ipos) = 1.0 / U_values(ipos); } #endif - } + }); + //} team.team_barrier(); // Reset - k1 = L_row_map(rowid); - k2 = L_row_map(rowid + 1); + k1 = static_cast(L_row_map(rowid)); + k2 = static_cast(L_row_map(rowid + 1)); #ifdef KEEP_DIAG - Kokkos::parallel_for( - Kokkos::TeamThreadRange(team, k1, k2 - 1), - [&](const size_type k) { iw(my_league, L_entries(k)) = -1; }); + Kokkos::parallel_for(Kokkos::TeamThreadRange(team, k1, k2 - 1), + [&](const size_type k) { + nnz_lno_t col = static_cast(L_entries(k)); + iw(my_team, col) = -1; + }); #else - Kokkos::parallel_for( - Kokkos::TeamThreadRange(team, k1, k2), - [&](const size_type k) { iw(my_league, L_entries(k)) = -1; }); + Kokkos::parallel_for(Kokkos::TeamThreadRange(team, k1, k2), + [&](const size_type k) { + nnz_lno_t col = static_cast(L_entries(k)); + iw(my_team, col) = -1; + }); #endif - k1 = U_row_map(rowid); - k2 = U_row_map(rowid + 1); - Kokkos::parallel_for( - Kokkos::TeamThreadRange(team, k1, k2), - [&](const size_type k) { iw(my_league, U_entries(k)) = -1; }); + k1 = static_cast(U_row_map(rowid)); + k2 = static_cast(U_row_map(rowid + 1)); + Kokkos::parallel_for(Kokkos::TeamThreadRange(team, k1, k2), + [&](const size_type k) { + nnz_lno_t col = static_cast(U_entries(k)); + iw(my_team, col) = -1; + }); } }; @@ -379,23 +392,17 @@ void iluk_numeric(IlukHandle &thandle, const ARowMapType &A_row_map, LValuesType &L_values, const URowMapType &U_row_map, const UEntriesType &U_entries, UValuesType &U_values) { using execution_space = typename IlukHandle::execution_space; - using memory_space = typename IlukHandle::memory_space; using size_type = typename IlukHandle::size_type; using nnz_lno_t = typename IlukHandle::nnz_lno_t; using HandleDeviceEntriesType = typename IlukHandle::nnz_lno_view_t; - using WorkViewType = - Kokkos::View>; - using LevelHostViewType = Kokkos::View; + using WorkViewType = typename IlukHandle::work_view_t; + using LevelHostViewType = typename IlukHandle::nnz_lno_view_host_t; size_type nlevels = thandle.get_num_levels(); - size_type nrows = thandle.get_nrows(); // Keep these as host View, create device version and copy back to host - HandleDeviceEntriesType level_ptr = thandle.get_level_ptr(); - HandleDeviceEntriesType level_idx = thandle.get_level_idx(); - HandleDeviceEntriesType level_nchunks = thandle.get_level_nchunks(); - HandleDeviceEntriesType level_nrowsperchunk = - thandle.get_level_nrowsperchunk(); + HandleDeviceEntriesType level_ptr = thandle.get_level_ptr(); + HandleDeviceEntriesType level_idx = thandle.get_level_idx(); // Make level_ptr_h a separate allocation, since it will be accessed on host // between kernel launches. If a mirror were used and level_ptr is in UVM @@ -409,25 +416,13 @@ void iluk_numeric(IlukHandle &thandle, const ARowMapType &A_row_map, level_ptr.extent(0)); Kokkos::deep_copy(level_ptr_h, level_ptr); + //{ if (thandle.get_algorithm() == KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1) { - level_nchunks_h = LevelHostViewType( - Kokkos::view_alloc(Kokkos::WithoutInitializing, "Host level nchunks"), - level_nchunks.extent(0)); - level_nrowsperchunk_h = - LevelHostViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, - "Host level nrowsperchunk"), - level_nrowsperchunk.extent(0)); - Kokkos::deep_copy(level_nchunks_h, level_nchunks); - Kokkos::deep_copy(level_nrowsperchunk_h, level_nrowsperchunk); - iw = WorkViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "iw"), - thandle.get_level_maxrowsperchunk(), nrows); - Kokkos::deep_copy(iw, nnz_lno_t(-1)); - } else { - iw = WorkViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "iw"), - thandle.get_level_maxrows(), nrows); - Kokkos::deep_copy(iw, nnz_lno_t(-1)); + level_nchunks_h = thandle.get_level_nchunks(); + level_nrowsperchunk_h = thandle.get_level_nrowsperchunk(); } + iw = thandle.get_iw(); // Main loop must be performed sequential. Question: Try out Cuda's graph // stuff to reduce kernel launch overhead @@ -476,49 +471,13 @@ void iluk_numeric(IlukHandle &thandle, const ARowMapType &A_row_map, else Kokkos::parallel_for("parfor_l_team", policy_type(lvl_nrows_chunk, team_size), tstf); - + Kokkos::fence(); lvl_rowid_start += lvl_nrows_chunk; } } - // /* - // // TP2 algorithm has issues with some offset-ordinal combo to be - // addressed else if ( thandle.get_algorithm() == - // KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHED_TP2 ) { - // typedef Kokkos::TeamPolicy tvt_policy_type; - // - // int team_size = thandle.get_team_size(); - // if ( team_size == -1 ) { - // team_size = std::is_same< typename - // Kokkos::DefaultExecutionSpace::memory_space, Kokkos::HostSpace - // >::value ? 1 : 128; - // } - // int vector_size = thandle.get_team_size(); - // if ( vector_size == -1 ) { - // vector_size = std::is_same< typename - // Kokkos::DefaultExecutionSpace::memory_space, Kokkos::HostSpace - // >::value ? 1 : 4; - // } - // - // // This impl: "chunk" lvl_nodes into node_groups; a league_rank - // is responsible for processing that many nodes - // // TeamThreadRange over number of node_groups - // // To avoid masking threads, 1 thread (team) per node in - // node_group - // // ThreadVectorRange responsible for the actual solve - // computation const int node_groups = team_size; - // - // LowerTriLvlSchedTP2SolverFunctor - // tstf(row_map, entries, values, lhs, rhs, nodes_grouped_by_level, - // row_count, node_groups); - // Kokkos::parallel_for("parfor_u_team_vector", tvt_policy_type( - // (int)std::ceil((float)lvl_nodes/(float)node_groups) , team_size, - // vector_size ), tstf); - // } // end elseif - // */ - } // end if } // end for lvl + //} // Output check #ifdef NUMERIC_OUTPUT_INFO @@ -526,7 +485,7 @@ void iluk_numeric(IlukHandle &thandle, const ARowMapType &A_row_map, std::cout << " nnzL: " << thandle.get_nnzL() << std::endl; std::cout << " L_row_map = "; - for (size_type i = 0; i < nrows + 1; ++i) { + for (size_type i = 0; i < thandle.get_nrows() + 1; ++i) { std::cout << L_row_map(i) << " "; } std::cout << std::endl; @@ -545,7 +504,7 @@ void iluk_numeric(IlukHandle &thandle, const ARowMapType &A_row_map, std::cout << " nnzU: " << thandle.get_nnzU() << std::endl; std::cout << " U_row_map = "; - for (size_type i = 0; i < nrows + 1; ++i) { + for (size_type i = 0; i < thandle.get_nrows() + 1; ++i) { std::cout << U_row_map(i) << " "; } std::cout << std::endl; diff --git a/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_symbolic_impl.hpp b/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_symbolic_impl.hpp index 90bb88e05709..691d6249639f 100644 --- a/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_symbolic_impl.hpp +++ b/packages/kokkos-kernels/src/sparse/impl/KokkosSparse_spiluk_symbolic_impl.hpp @@ -121,15 +121,15 @@ void level_sched(IlukHandle& thandle, const RowMapType row_map, // SEQLVLSCHD_TP1 algorithm (chunks) template -void level_sched(IlukHandle& thandle, const RowMapType row_map, - const EntriesType entries, LevelType1& level_list, - LevelType2& level_ptr, LevelType2& level_idx, - LevelType3& level_nchunks, LevelType3& level_nrowsperchunk, - size_type& nlevels) { + class LevelType1, class LevelType2, class size_type> +void level_sched_tp(IlukHandle& thandle, const RowMapType row_map, + const EntriesType entries, LevelType1& level_list, + LevelType2& level_ptr, LevelType2& level_idx, + size_type& nlevels) { // Scheduling currently compute on host - using nnz_lno_t = typename IlukHandle::nnz_lno_t; + using nnz_lno_t = typename IlukHandle::nnz_lno_t; + using nnz_lno_view_host_t = typename IlukHandle::nnz_lno_view_host_t; size_type nrows = thandle.get_nrows(); @@ -168,11 +168,10 @@ void level_sched(IlukHandle& thandle, const RowMapType row_map, level_ptr(0) = 0; // Find max rows, number of chunks, max rows of chunks across levels - using HostViewType = - Kokkos::View; - - HostViewType lnchunks("lnchunks", nlevels); - HostViewType lnrowsperchunk("lnrowsperchunk", nlevels); + thandle.alloc_level_nchunks(nlevels); + thandle.alloc_level_nrowsperchunk(nlevels); + nnz_lno_view_host_t lnchunks = thandle.get_level_nchunks(); + nnz_lno_view_host_t lnrowsperchunk = thandle.get_level_nrowsperchunk(); #ifdef KOKKOS_ENABLE_CUDA using memory_space = typename IlukHandle::memory_space; @@ -214,9 +213,6 @@ void level_sched(IlukHandle& thandle, const RowMapType row_map, thandle.set_num_levels(nlevels); thandle.set_level_maxrows(maxrows); thandle.set_level_maxrowsperchunk(maxrowsperchunk); - - level_nchunks = lnchunks; - level_nrowsperchunk = lnrowsperchunk; } // Linear Search for the smallest row index @@ -326,7 +322,6 @@ void iluk_symbolic(IlukHandle& thandle, HostTmpViewType h_iw("h_iw", nrows); HostTmpViewType h_iL("h_iL", nrows); HostTmpViewType h_llev("h_llev", nrows); - HostTmpViewType level_nchunks, level_nrowsperchunk; size_type cntL = 0; size_type cntU = 0; @@ -472,19 +467,13 @@ void iluk_symbolic(IlukHandle& thandle, // Level scheduling on L if (thandle.get_algorithm() == KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1) { - level_sched(thandle, L_row_map, L_entries, level_list, level_ptr, - level_idx, level_nchunks, level_nrowsperchunk, nlev); - - thandle.alloc_level_nchunks(nlev); - thandle.alloc_level_nrowsperchunk(nlev); - HandleDeviceEntriesType dlevel_nchunks = thandle.get_level_nchunks(); - HandleDeviceEntriesType dlevel_nrowsperchunk = - thandle.get_level_nrowsperchunk(); - Kokkos::deep_copy(dlevel_nchunks, level_nchunks); - Kokkos::deep_copy(dlevel_nrowsperchunk, level_nrowsperchunk); + level_sched_tp(thandle, L_row_map, L_entries, level_list, level_ptr, + level_idx, nlev); + thandle.alloc_iw(thandle.get_level_maxrowsperchunk(), nrows); } else { level_sched(thandle, L_row_map, L_entries, level_list, level_ptr, level_idx, nlev); + thandle.alloc_iw(thandle.get_level_maxrows(), nrows); } Kokkos::deep_copy(dlevel_ptr, level_ptr); diff --git a/packages/krino/delete_small_elements/Akri_DeleteSmallElementsMain.cpp b/packages/krino/delete_small_elements/Akri_DeleteSmallElementsMain.cpp index 34aec55776d8..1b07ed4ee21e 100644 --- a/packages/krino/delete_small_elements/Akri_DeleteSmallElementsMain.cpp +++ b/packages/krino/delete_small_elements/Akri_DeleteSmallElementsMain.cpp @@ -108,6 +108,7 @@ static bool delete_small_elements(const DeleteSmallElementsInputData& inputData, { std::shared_ptr bulk = stk::mesh::MeshBuilder(comm).create(); stk::mesh::MetaData& meta = bulk->mesh_meta_data(); + meta.use_simple_fields(); stk::io::fill_mesh_with_auto_decomp(inputData.meshIn, *bulk); diff --git a/packages/krino/krino/CMakeLists.txt b/packages/krino/krino/CMakeLists.txt index 075eaee8f8f8..ae0b8379575d 100644 --- a/packages/krino/krino/CMakeLists.txt +++ b/packages/krino/krino/CMakeLists.txt @@ -4,7 +4,7 @@ add_subdirectory(adaptivity_interface) add_subdirectory(region) add_subdirectory(rebalance_utils) add_subdirectory(parser) -tribits_add_test_directories(unit_tests) +add_subdirectory(unit_tests) SET(SOURCES_MAIN Apps_krino.cpp) diff --git a/packages/krino/krino/krino_lib/Akri_AnalyticSurf.cpp b/packages/krino/krino/krino_lib/Akri_AnalyticSurf.cpp index 374396fb7bbf..c5523279a44b 100644 --- a/packages/krino/krino/krino_lib/Akri_AnalyticSurf.cpp +++ b/packages/krino/krino/krino_lib/Akri_AnalyticSurf.cpp @@ -270,9 +270,7 @@ Plane::point_signed_distance(const Vector3d &x) const BoundingBox Plane::get_bounding_box() { - //bounding box is entire domain - return BoundingBox(Vector3d(-std::numeric_limits::max(), -std::numeric_limits::max(), -std::numeric_limits::max()), - Vector3d(std::numeric_limits::max(), std::numeric_limits::max(), std::numeric_limits::max())); + return BoundingBox::ENTIRE_DOMAIN; } Random::Random(const unsigned long seed) @@ -292,32 +290,20 @@ Random::point_signed_distance(const Vector3d &x) const BoundingBox Random::get_bounding_box() { - //bounding box is entire domain - return BoundingBox(Vector3d(-std::numeric_limits::max(), -std::numeric_limits::max(), -std::numeric_limits::max()), - Vector3d(std::numeric_limits::max(), std::numeric_limits::max(), std::numeric_limits::max())); + return BoundingBox::ENTIRE_DOMAIN; } -Analytic_Isosurface::Analytic_Isosurface() - : SurfaceThatDoesntTakeAdvantageOfNarrowBandAndThereforeHasCorrectSign() -{ -} - -BoundingBox -Analytic_Isosurface::get_bounding_box() +LevelSet_String_Function::LevelSet_String_Function(const std::string & expression) + : SurfaceThatDoesntTakeAdvantageOfNarrowBandAndThereforeHasCorrectSign(), + myExpression(expression), + myBoundingBox(BoundingBox::ENTIRE_DOMAIN) { - return BoundingBox( - Vector3d(-1.,-1.,-1.), - Vector3d(1.,1.,1.) - ); } double -Analytic_Isosurface::point_signed_distance(const Vector3d &coord) const +LevelSet_String_Function::point_signed_distance(const Vector3d &coord) const { - const double x = coord[0]; - const double y = coord[1]; - const double z = coord[2]; - return 2.*y*(y*y-3.*x*x)*(1.-z*z) + std::pow(x*x+y*y,2) - (9.*z*z-1.)*(1.-z*z); + return myExpression.evaluate(coord); } } // namespace krino diff --git a/packages/krino/krino/krino_lib/Akri_AnalyticSurf.hpp b/packages/krino/krino/krino_lib/Akri_AnalyticSurf.hpp index 18f70a5bb0e1..efb7753341b0 100644 --- a/packages/krino/krino/krino_lib/Akri_AnalyticSurf.hpp +++ b/packages/krino/krino/krino_lib/Akri_AnalyticSurf.hpp @@ -17,6 +17,7 @@ #include #include +#include namespace stk { namespace mesh { class BulkData; } } namespace stk { namespace mesh { class Entity; } } @@ -168,17 +169,23 @@ class Random : public SurfaceThatDoesntTakeAdvantageOfNarrowBandAndThereforeHasC void my_srand(unsigned int seed) const {iseed = seed;} }; -class Analytic_Isosurface: public SurfaceThatDoesntTakeAdvantageOfNarrowBandAndThereforeHasCorrectSign { +class LevelSet_String_Function: public SurfaceThatDoesntTakeAdvantageOfNarrowBandAndThereforeHasCorrectSign { public: - Analytic_Isosurface(); + LevelSet_String_Function(const std::string & expression); - virtual ~Analytic_Isosurface() {} + virtual ~LevelSet_String_Function() {} - virtual Surface_Type type() const override { return SPHERE; } - virtual size_t storage_size() const override { return sizeof(Analytic_Isosurface); } + virtual Surface_Type type() const override { return STRING_FUNCTION; } + virtual size_t storage_size() const override { return sizeof(LevelSet_String_Function); } virtual double point_signed_distance(const Vector3d &x) const override; - virtual BoundingBox get_bounding_box() override; + virtual BoundingBox get_bounding_box() override { return myBoundingBox; } + + void set_bounding_box(const BoundingBox & bbox) { myBoundingBox = bbox; } + +private: + String_Function_Expression myExpression; + BoundingBox myBoundingBox; }; } // namespace krino diff --git a/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.cpp b/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.cpp index cd5019569038..5958df5bfb71 100644 --- a/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.cpp +++ b/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.cpp @@ -35,11 +35,12 @@ static std::function build_edge_distance_function(const Su return distanceFunction; } -static double find_crossing_position(const Surface & surface, const Segment3d & edge) +static double find_crossing_position(const Surface & surface, const Segment3d & edge, const double edgeTol) { const double phi0 = surface.point_signed_distance(edge.GetNode(0)); const double phi1 = surface.point_signed_distance(edge.GetNode(1)); - const auto result = find_root(build_edge_distance_function(surface, edge), 0., 1., phi0, phi1); + const int maxIters = 100; + const auto result = find_root(build_edge_distance_function(surface, edge), 0., 1., phi0, phi1, maxIters, edgeTol); ThrowRequire(result.first); return result.second; } @@ -69,9 +70,11 @@ static Vector3d get_centroid(const std::vector & elemNodesCoords) SurfaceElementCutter::SurfaceElementCutter(const stk::mesh::BulkData & mesh, stk::mesh::Entity element, - const Surface & surface) + const Surface & surface, + const double edgeTol) : myMasterElem(MasterElementDeterminer::getMasterElement(mesh.bucket(element).topology())), - mySurface(surface) + mySurface(surface), + myEdgeCrossingTol(edgeTol) { const FieldRef coordsField(mesh.mesh_meta_data().coordinate_field()); fill_element_node_coordinates(mesh, element, coordsField, myElementNodeCoords); @@ -95,7 +98,7 @@ bool SurfaceElementCutter::have_crossing(const InterfaceID interface, const Segm double SurfaceElementCutter::interface_crossing_position(const InterfaceID interface, const Segment3d & edge) const { const Segment3d globalEdge(parametric_to_global_coordinates(edge.GetNode(0)), parametric_to_global_coordinates(edge.GetNode(1))); - return find_crossing_position(mySurface, globalEdge); + return find_crossing_position(mySurface, globalEdge, myEdgeCrossingTol); } int SurfaceElementCutter::sign_at_position(const InterfaceID interface, const Vector3d & paramCoords) const @@ -116,6 +119,7 @@ Vector3d SurfaceElementCutter::parametric_to_global_coordinates(const Vector3d & static void append_surface_edge_intersection_points(const stk::mesh::BulkData & mesh, const std::vector & elementsToIntersect, const Surface & surface, + const double edgeCrossingTol, const IntersectionPointFilter & intersectionPointFilter, std::vector & intersectionPoints) { @@ -150,7 +154,7 @@ static void append_surface_edge_intersection_points(const stk::mesh::BulkData & if (haveCrossing) { const InterfaceID interface(0,0); - const double location = find_crossing_position(surface, Segment3d(node0Coords, node1Coords)); + const double location = find_crossing_position(surface, Segment3d(node0Coords, node1Coords), edgeCrossingTol); interface.fill_sorted_domains(intersectionPointSortedDomains); const std::vector intersectionPointNodes{node0,node1}; if (intersectionPointFilter(intersectionPointNodes, intersectionPointSortedDomains)) @@ -284,6 +288,21 @@ static void set_domains_for_element_if_it_will_be_uncut_after_snapping(const stk } } +AnalyticSurfaceInterfaceGeometry::AnalyticSurfaceInterfaceGeometry(const Surface_Identifier surfaceIdentifier, + const Surface & surface, + const stk::mesh::Part & activePart, + const CDFEM_Support & cdfemSupport, + const Phase_Support & phaseSupport) + : mySurface(surface), + myActivePart(activePart), + myCdfemSupport(cdfemSupport), + myPhaseSupport(phaseSupport), + mySurfaceIdentifiers({surfaceIdentifier}), + myEdgeCrossingTol(0.1*cdfemSupport.get_snapper().get_edge_tolerance()) +{ + ThrowRequireMsg(myEdgeCrossingTol > 0., "Invalid minimum edge crossing tolerance " << myEdgeCrossingTol); +} + void AnalyticSurfaceInterfaceGeometry::store_phase_for_elements_that_will_be_uncut_after_snapping(const stk::mesh::BulkData & mesh, const std::vector & intersectionPoints, const std::vector & snapInfos, @@ -306,7 +325,7 @@ std::vector AnalyticSurfaceInterfaceGeometry::get_edge_inters const IntersectionPointFilter intersectionPointFilter = keep_all_intersection_points_filter(); std::vector intersectionPoints; - append_surface_edge_intersection_points(mesh, myElementsToIntersect, mySurface, intersectionPointFilter, intersectionPoints); + append_surface_edge_intersection_points(mesh, myElementsToIntersect, mySurface, myEdgeCrossingTol, intersectionPointFilter, intersectionPoints); return intersectionPoints; } @@ -317,7 +336,7 @@ void AnalyticSurfaceInterfaceGeometry::append_element_intersection_points(const std::vector & intersectionPoints) const { prepare_to_process_elements(mesh, elementsToIntersect, nodesToCapturedDomains); - append_surface_edge_intersection_points(mesh, myElementsToIntersect, mySurface, intersectionPointFilter, intersectionPoints); + append_surface_edge_intersection_points(mesh, myElementsToIntersect, mySurface, myEdgeCrossingTol, intersectionPointFilter, intersectionPoints); } std::unique_ptr AnalyticSurfaceInterfaceGeometry::build_element_cutter(const stk::mesh::BulkData & mesh, @@ -325,7 +344,7 @@ std::unique_ptr AnalyticSurfaceInterfaceGeometry::build_element_c const std::function &)> & intersectingPlanesDiagonalPicker) const { std::unique_ptr cutter; - cutter.reset( new SurfaceElementCutter(mesh, element, mySurface) ); + cutter.reset( new SurfaceElementCutter(mesh, element, mySurface, myEdgeCrossingTol) ); return cutter; } diff --git a/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.hpp b/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.hpp index 37ee0aeee07a..e07d512e8bc5 100644 --- a/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.hpp +++ b/packages/krino/krino/krino_lib/Akri_AnalyticSurfaceInterfaceGeometry.hpp @@ -27,7 +27,8 @@ class SurfaceElementCutter : public ElementCutter public: SurfaceElementCutter(const stk::mesh::BulkData & mesh, stk::mesh::Entity element, - const Surface & surface); + const Surface & surface, + const double edgeTol); virtual ~SurfaceElementCutter() {} virtual bool might_have_interior_or_face_intersections() const override { return false; } @@ -55,6 +56,7 @@ class SurfaceElementCutter : public ElementCutter const MasterElement & myMasterElem; std::vector myElementNodeCoords; const Surface & mySurface; + double myEdgeCrossingTol; int myElementSign{0}; }; @@ -65,12 +67,7 @@ class AnalyticSurfaceInterfaceGeometry : public InterfaceGeometry { const Surface & surface, const stk::mesh::Part & activePart, const CDFEM_Support & cdfemSupport, - const Phase_Support & phaseSupport) - : mySurface(surface), - myActivePart(activePart), - myCdfemSupport(cdfemSupport), - myPhaseSupport(phaseSupport), - mySurfaceIdentifiers({surfaceIdentifier}) {} + const Phase_Support & phaseSupport); virtual ~AnalyticSurfaceInterfaceGeometry() {} @@ -113,6 +110,7 @@ class AnalyticSurfaceInterfaceGeometry : public InterfaceGeometry { const CDFEM_Support & myCdfemSupport; const Phase_Support & myPhaseSupport; std::vector mySurfaceIdentifiers; + double myEdgeCrossingTol; mutable ElementToDomainMap myUncutElementPhases; mutable std::vector myElementsToIntersect; }; diff --git a/packages/krino/krino/krino_lib/Akri_AuxMetaData.cpp b/packages/krino/krino/krino_lib/Akri_AuxMetaData.cpp index 80ee7b1c14d7..1c2430c3beab 100644 --- a/packages/krino/krino/krino_lib/Akri_AuxMetaData.cpp +++ b/packages/krino/krino/krino_lib/Akri_AuxMetaData.cpp @@ -297,15 +297,15 @@ AuxMetaData::declare_field( stk::mesh::FieldBase * field = NULL; const std::type_info & value_type = field_type.type_info(); if (value_type == typeid(int)) - field = &my_meta.declare_field< stk::mesh::Field >(entity_rank, fld_name, num_states); + field = &my_meta.declare_field(entity_rank, fld_name, num_states); else if (value_type == typeid(double)) - field = &my_meta.declare_field< stk::mesh::Field >(entity_rank, fld_name, num_states); + field = &my_meta.declare_field(entity_rank, fld_name, num_states); else if (value_type == typeid(unsigned)) - field = &my_meta.declare_field< stk::mesh::Field >(entity_rank, fld_name, num_states); + field = &my_meta.declare_field(entity_rank, fld_name, num_states); else if (value_type == typeid(int64_t)) - field = &my_meta.declare_field< stk::mesh::Field >(entity_rank, fld_name, num_states); + field = &my_meta.declare_field(entity_rank, fld_name, num_states); else if (value_type == typeid(uint64_t)) - field = &my_meta.declare_field< stk::mesh::Field >(entity_rank, fld_name, num_states); + field = &my_meta.declare_field(entity_rank, fld_name, num_states); else { ThrowRequireMsg(false, "Unhandled primitive type " << value_type.name()); } @@ -332,22 +332,23 @@ AuxMetaData::register_field( return FieldRef(fmwk_register_field(fld_name, field_type.name(), field_type.type_info(), field_type.dimension(), entity_rank, num_states, dimension, part, value_type_init)); } - const unsigned field_length = field_type.dimension()*dimension; if (field_type.name() == FieldType::VECTOR_2D.name()) { - auto & field = my_meta.declare_field< stk::mesh::Field >(entity_rank, fld_name, num_states); - stk::mesh::put_field_on_mesh(field, part, field_length, nullptr); + auto & field = my_meta.declare_field(entity_rank, fld_name, num_states); + stk::mesh::put_field_on_mesh(field, part, field_type.dimension(), dimension, nullptr); + stk::io::set_field_output_type(field, "Vector_2D"); return FieldRef(field); } else if (field_type.name() == FieldType::VECTOR_3D.name()) { - auto & field = my_meta.declare_field< stk::mesh::Field >(entity_rank, fld_name, num_states); - stk::mesh::put_field_on_mesh(field, part, field_length, nullptr); + auto & field = my_meta.declare_field(entity_rank, fld_name, num_states); + stk::mesh::put_field_on_mesh(field, part, field_type.dimension(), dimension, nullptr); + stk::io::set_field_output_type(field, "Vector_3D"); return FieldRef(field); } FieldRef field = declare_field(fld_name, field_type, entity_rank, num_states); - stk::mesh::put_field_on_mesh(field.field(), part, field_length, value_type_init); + stk::mesh::put_field_on_mesh(field.field(), part, field_type.dimension(), dimension, value_type_init); return field; } diff --git a/packages/krino/krino/krino_lib/Akri_BoundingBox.hpp b/packages/krino/krino/krino_lib/Akri_BoundingBox.hpp index e64ce337df88..016fbdad8576 100644 --- a/packages/krino/krino/krino_lib/Akri_BoundingBox.hpp +++ b/packages/krino/krino/krino_lib/Akri_BoundingBox.hpp @@ -30,6 +30,8 @@ class BoundingBox_T { VecType max; public: + static const BoundingBox_T ENTIRE_DOMAIN; + static void gather_bboxes( const BoundingBox_T & local_bbox, std::vector< BoundingBox_T > & all_bboxes ); @@ -205,6 +207,10 @@ BoundingBox_T::scale( const Real & scale_factor ) max += extension; } +template +const BoundingBox_T BoundingBox_T::ENTIRE_DOMAIN(VecType(-std::numeric_limits::max(), -std::numeric_limits::max(), -std::numeric_limits::max()), + VecType(std::numeric_limits::max(), std::numeric_limits::max(), std::numeric_limits::max())); + typedef BoundingBox_T BoundingBox; } // namespace krino diff --git a/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.cpp b/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.cpp index 32bd6e2b02b6..f968754d331c 100644 --- a/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.cpp +++ b/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -32,7 +33,10 @@ BoundingBoxMesh::BoundingBoxMesh(stk::topology element_topology, const std::vect element_topology == stk::topology::TETRAHEDRON_4 || element_topology == stk::topology::HEXAHEDRON_8); - m_meta = std::make_unique(element_topology.dimension(), entity_rank_names); + m_meta = stk::mesh::MeshBuilder().set_spatial_dimension(element_topology.dimension()) + .set_entity_rank_names(entity_rank_names) + .create_meta_data(); + AuxMetaData & aux_meta = AuxMetaData::create(*m_meta); stk::mesh::Part & block_part = m_meta->declare_part_with_topology( "block_1", element_topology ); stk::io::put_io_part_attribute(block_part); @@ -114,7 +118,7 @@ void BoundingBoxMesh::populate_mesh(stk::ParallelMachine pm, const stk::mesh::BulkData::AutomaticAuraOption auto_aura_option) { /* %TRACE[ON]% */ Trace trace__("krino::BoundingBoxMesh::populate_mesh()"); /* %TRACE% */ ThrowRequireMsg(m_mesh_bbox.valid(), "Must call set_domain() before populate_mesh()"); - m_mesh = std::make_unique(*m_meta, pm, auto_aura_option); + m_mesh = stk::mesh::MeshBuilder(pm).set_aura_option(auto_aura_option).create(m_meta); if (CUBIC_BOUNDING_BOX_MESH == myMeshStructureType) populate_cell_based_mesh(); else if (TRIANGULAR_LATTICE_BOUNDING_BOX_MESH == myMeshStructureType || FLAT_WALLED_TRIANGULAR_LATTICE_BOUNDING_BOX_MESH == myMeshStructureType) diff --git a/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.hpp b/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.hpp index 45572ed5ea79..2a1521dda3ef 100644 --- a/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.hpp +++ b/packages/krino/krino/krino_lib/Akri_BoundingBoxMesh.hpp @@ -116,7 +116,7 @@ class BoundingBoxMesh { void set_is_cell_edge_function_for_BCC_mesh() const; void set_is_cell_edge_function_for_cell_based_mesh() const; private: - std::unique_ptr m_meta; + std::shared_ptr m_meta; std::unique_ptr m_mesh; std::unique_ptr my_coord_mapping; stk::mesh::PartVector m_elem_parts; diff --git a/packages/krino/krino/krino_lib/Akri_CDFEM_Support.cpp b/packages/krino/krino/krino_lib/Akri_CDFEM_Support.cpp index b8c6201fc0b8..4ddd8e91dd41 100644 --- a/packages/krino/krino/krino_lib/Akri_CDFEM_Support.cpp +++ b/packages/krino/krino/krino_lib/Akri_CDFEM_Support.cpp @@ -60,6 +60,9 @@ CDFEM_Support::CDFEM_Support(stk::mesh::MetaData & meta) my_cdfem_snapper(), my_cdfem_dof_edge_tol(0.0), my_internal_face_stabilization_multiplier(0.0), + mySnappingSharpFeatureAngleInDegrees(0.), + myLengthScaleTypeForInterfaceCFL(LOCAL_LENGTH_SCALE), + myConstantLengthScaleForInterfaceCFL(0.), my_flag_use_hierarchical_dofs(false), my_flag_constrain_CDFEM_to_XFEM_space(false), my_flag_use_nonconformal_element_size(true), diff --git a/packages/krino/krino/krino_lib/Akri_CDFEM_Support.hpp b/packages/krino/krino/krino_lib/Akri_CDFEM_Support.hpp index 9a357f6d3d83..0892174ad98d 100644 --- a/packages/krino/krino/krino_lib/Akri_CDFEM_Support.hpp +++ b/packages/krino/krino/krino_lib/Akri_CDFEM_Support.hpp @@ -55,6 +55,14 @@ enum Simplex_Generation_Method MAX_SIMPLEX_GENERATION_METHOD }; +enum Interface_CFL_Length_Scale +{ + CONSTANT_LENGTH_SCALE=0, + LOCAL_LENGTH_SCALE, + L1_NORM_LENGTH_SCALE, + MAX_LENGTH_SCALE_TYPE +}; + class CDFEM_Support { public: @@ -82,6 +90,8 @@ class CDFEM_Support { void set_global_ids_are_NOT_parallel_consistent() { myGlobalIDsAreParallelConsistent = false; } void activate_interface_refinement(int minimum_level, int maximum_level); void activate_nonconformal_adaptivity(const int num_levels); + void set_snapping_sharp_feature_angle_in_degrees(const double snappingSharpFeatureAngleInDegrees) { mySnappingSharpFeatureAngleInDegrees = snappingSharpFeatureAngleInDegrees; } + double get_snapping_sharp_feature_angle_in_degrees() const { return mySnappingSharpFeatureAngleInDegrees; } void create_parts(); @@ -175,10 +185,14 @@ class CDFEM_Support { void set_use_hierarchical_dofs(bool flag) { my_flag_use_hierarchical_dofs = flag; } bool get_constrain_CDFEM_to_XFEM_space() const { return my_flag_constrain_CDFEM_to_XFEM_space; } void set_constrain_CDFEM_to_XFEM_space(bool flag) { my_flag_constrain_CDFEM_to_XFEM_space = flag; } + + void set_constant_length_scale_for_interface_CFL(double lengthScale) { myConstantLengthScaleForInterfaceCFL = lengthScale; } + double get_constant_length_scale_for_interface_CFL() const { return myConstantLengthScaleForInterfaceCFL; } + void set_length_scale_type_for_interface_CFL(Interface_CFL_Length_Scale lengthScaleType) { myLengthScaleTypeForInterfaceCFL = lengthScaleType; } + Interface_CFL_Length_Scale get_length_scale_type_for_interface_CFL() const { return myLengthScaleTypeForInterfaceCFL; } bool get_use_velocity_to_evaluate_interface_CFL() const { return myFlagUseVelocityToEvaluateInterfaceCFL; } void set_use_velocity_to_evaluate_interface_CFL(bool flag) { myFlagUseVelocityToEvaluateInterfaceCFL = flag; } - void force_ale_prolongation_for_field(const std::string & field_name); private: @@ -225,6 +239,9 @@ class CDFEM_Support { CDFEM_Snapper my_cdfem_snapper; double my_cdfem_dof_edge_tol; double my_internal_face_stabilization_multiplier; + double mySnappingSharpFeatureAngleInDegrees; + Interface_CFL_Length_Scale myLengthScaleTypeForInterfaceCFL; + double myConstantLengthScaleForInterfaceCFL; bool my_flag_use_hierarchical_dofs; bool my_flag_constrain_CDFEM_to_XFEM_space; bool my_flag_use_nonconformal_element_size; diff --git a/packages/krino/krino/krino_lib/Akri_CDMesh.cpp b/packages/krino/krino/krino_lib/Akri_CDMesh.cpp index 9d1be73a6c4c..af2b5193d980 100644 --- a/packages/krino/krino/krino_lib/Akri_CDMesh.cpp +++ b/packages/krino/krino/krino_lib/Akri_CDMesh.cpp @@ -291,7 +291,12 @@ void CDMesh::snap_and_update_fields_and_captured_domains(const InterfaceGeometry stk::mesh::field_copy(my_cdfem_support.get_coords_field(), cdfemSnapField); const stk::mesh::Selector parentElementSelector = get_parent_element_selector(get_active_part(), my_cdfem_support, my_phase_support); - nodesToCapturedDomains = snap_as_much_as_possible_while_maintaining_quality(stk_bulk(), parentElementSelector, snapFields, interfaceGeometry, my_cdfem_support.get_global_ids_are_parallel_consistent()); + nodesToCapturedDomains = snap_as_much_as_possible_while_maintaining_quality(stk_bulk(), + parentElementSelector, + snapFields, + interfaceGeometry, + my_cdfem_support.get_global_ids_are_parallel_consistent(), + my_cdfem_support.get_snapping_sharp_feature_angle_in_degrees()); if (cdfemSnapField.valid()) stk::mesh::field_axpby(+1.0, my_cdfem_support.get_coords_field(), -1.0, cdfemSnapField); @@ -386,6 +391,13 @@ CDMesh::decompose_mesh(stk::mesh::BulkData & mesh, return status; } +static void rebuild_mesh_sidesets(stk::mesh::BulkData & mesh) +{ + for (auto && part : mesh.mesh_meta_data().get_parts()) + if (part->primary_entity_rank() == mesh.mesh_meta_data().side_rank()) + stk::mesh::reconstruct_sideset(mesh, *part); +} + bool CDMesh::modify_mesh() {/* %TRACE[ON]% */ Trace trace__("krino::Mesh::modify_mesh()"); /* %TRACE% */ @@ -631,10 +643,16 @@ CDMesh::rebuild_from_restart_mesh(stk::mesh::BulkData & mesh) // rebuild conformal side parts the_new_mesh->stk_bulk().modification_begin(); + update_node_activation(the_new_mesh->stk_bulk(), the_new_mesh->aux_meta().active_part()); // we should be able to skip this step if there are no higher order elements the_new_mesh->update_element_side_parts(); the_new_mesh->stk_bulk().modification_end(); delete_extraneous_inactive_sides(mesh, the_new_mesh->get_parent_part(), the_new_mesh->get_active_part()); + + rebuild_mesh_sidesets(mesh); + + ParallelThrowAssert(mesh.parallel(), check_face_and_edge_ownership(mesh)); + ParallelThrowAssert(mesh.parallel(), check_face_and_edge_relations(mesh)); } static bool is_child_elem(const stk::mesh::BulkData & mesh, const stk::mesh::Part & childEdgeNodePart, stk::mesh::Entity elem) @@ -735,6 +753,19 @@ CDMesh::find_or_build_subelement_edge_node(const stk::mesh::Entity node, const M return build_subelement_edge_node(node, ownerMeshElem, idToSubElementNode); } +void +CDMesh::find_or_build_midside_nodes(const stk::topology & elemTopo, const Mesh_Element & ownerMeshElem, const stk::mesh::Entity * elemNodes, const NodeVec & subelemNodes ) +{ + if (elemTopo.num_nodes() > elemTopo.base().num_nodes()) + { + for (unsigned iEdge=0; iEdge & idToSubElementNode ) { @@ -747,8 +778,7 @@ CDMesh::build_subelement_edge_node(const stk::mesh::Entity node, const Mesh_Elem const double position = compute_child_position(mesh, node, immediateParent0->entity(), immediateParent1->entity()); - std::unique_ptr newNode = std::make_unique(&ownerMeshElem, position, immediateParent0, immediateParent1); - SubElementNode * edgeNode = add_managed_node(std::move(newNode)); + const SubElementNode * edgeNode = create_edge_node(&ownerMeshElem, immediateParent0, immediateParent1, position); edgeNode->set_entity(stk_bulk(), node); idToSubElementNode[mesh.identifier(node)] = edgeNode; @@ -771,8 +801,8 @@ CDMesh::restore_subelements() for(const auto & b_ptr : buckets) { const stk::topology & topo = b_ptr->topology(); - const unsigned num_nodes = topo.num_nodes(); - subelem_nodes.reserve(num_nodes); + const unsigned num_base_nodes = topo.base().num_nodes(); + subelem_nodes.reserve(num_base_nodes); for(const auto & elem : *b_ptr) { const stk::mesh::Entity parent = get_parent_element(elem); @@ -784,30 +814,54 @@ CDMesh::restore_subelements() subelem_nodes.clear(); // TODO: May need to create subelement edge nodes somehow const auto * elem_nodes = mesh.begin_nodes(elem); - ThrowAssert(mesh.num_nodes(elem) == num_nodes); - for(unsigned i=0; i < num_nodes; ++i) + for(unsigned i=0; i < num_base_nodes; ++i) { const SubElementNode * node = find_or_build_subelement_edge_node(elem_nodes[i], *parentMeshElem, idToSubElementNode); subelem_nodes.push_back(node); } + find_or_build_midside_nodes(topo, *parentMeshElem, elem_nodes, subelem_nodes); std::unique_ptr subelem; switch(topo) { case stk::topology::TRIANGLE_3_2D: + case stk::topology::TRIANGLE_6_2D: subelem = std::make_unique(subelem_nodes, std::vector{-1, -1, -1}, parentMeshElem); break; case stk::topology::TETRAHEDRON_4: + case stk::topology::TETRAHEDRON_10: subelem = std::make_unique(subelem_nodes, std::vector{-1, -1, -1, -1}, parentMeshElem); break; default: - ThrowRuntimeError("At present only Tri3 and Tet4 topologies are supported for restart of CDFEM problems."); + ThrowRuntimeError("At present only Tri3, Tri6, Tet4 and Tet10 topologies are supported for restart of CDFEM problems."); + } + + if (topo == stk::topology::TRIANGLE_6_2D || topo == stk::topology::TETRAHEDRON_10) + { + subelem->build_quadratic_subelements(*this); + std::vector highOrderSubElems; + subelem->get_subelements( highOrderSubElems ); + ThrowRequire(highOrderSubElems.size() == 1); + highOrderSubElems[0]->set_entity(stk_bulk(), elem); } - ThrowAssert(subelem); - subelem->set_entity(stk_bulk(), elem); + else + { + subelem->set_entity(stk_bulk(), elem); + } + const_cast(parentMeshElem)->add_subelement(std::move(subelem)); } } + + std::vector subelems; + for (auto && element : elements) + { + element->get_subelements( subelems ); + if (subelems.size() > 1) + { + element->set_have_interface(); + } + } } void @@ -1049,6 +1103,8 @@ CDMesh::stash_nodal_field_data(const CDMesh & new_mesh) const for(auto&& node_part_ptr : node_parts) { // This is designed to catch side with block_2 + block_1_air, block_1_air + block_1_solid, etc. + // These are included so that we can prolongate a node on the block_1_air + block_1_solid + block_2 + // from a node on that same part ownership. (This is needed in cases where block_2 has other vars). if (node_part_ptr->primary_entity_rank() == stk::topology::ELEMENT_RANK && !my_phase_support.is_nonconformal(node_part_ptr) && stk::io::is_part_io_part(*node_part_ptr)) @@ -1083,9 +1139,12 @@ CDMesh::stash_nodal_field_data(const CDMesh & new_mesh) const const stk::mesh::PartVector & side_parts = bucket_ptr->supersets(); for(auto&& side_part_ptr : side_parts) { - // This is designed to catch side with block_2 + block_1_air, block_1_air + block_1_solid, etc. + // This is designed to catch sides like block_1_air + block_1_solid, etc and not block_2 + block_1_air. + // If we include the nondecomposed blocks like block_2, this could result in prolongation of a node + // on the interface (block_1_air + block_1_solid) from a node on the boundary of the undecomposed block + // (block_1_air + block_2). if (side_part_ptr->primary_entity_rank() == stk::topology::ELEMENT_RANK && - !my_phase_support.is_nonconformal(side_part_ptr) && + my_phase_support.is_conformal(side_part_ptr) && stk::io::is_part_io_part(*side_part_ptr)) { ++num_conformal_parts; @@ -1344,6 +1403,7 @@ CDMesh::find_prolongation_node(const SubElementNode & dst_node) const krinolog << node->entityId() << " "; } krinolog << stk::diag::dendl; + krinolog << " with required fields " << print_fields(stk_meta(), required_fields) << stk::diag::dendl; } } @@ -1356,6 +1416,7 @@ CDMesh::find_prolongation_node(const SubElementNode & dst_node) const return nullptr; } // Search for facet failed. Now try nodes. This will handle triple points. Something better that handles an actual edge search might be better in 3d. + if (krinolog.shouldPrint(LOG_DEBUG)) krinolog << "Prolongation facet search failed for " << dst_node.entityId() << " with required fields " << print_fields(stk_meta(), required_fields) << stk::diag::dendl; const ProlongationNodeData * closest_node = nullptr; double closest_dist2 = std::numeric_limits::max(); for (auto && entry : my_prolong_node_map) @@ -2941,17 +3002,92 @@ CDMesh::get_parent_nodes_and_weights(stk::mesh::Entity child, stk::mesh::Entity //-------------------------------------------------------------------------------- -std::function build_get_element_volume_function(const CDMesh & cdmesh) +std::function build_get_local_length_scale_for_side_function(const CDMesh & cdmesh) { - auto get_element_size = - [&cdmesh](stk::mesh::Entity elem) + const stk::mesh::Selector elementSelector = selectUnion(cdmesh.get_phase_support().get_conformal_parts()) & cdmesh.get_active_part() & cdmesh.get_locally_owned_part(); + + auto get_length_scale_for_side = + [&cdmesh,elementSelector](stk::mesh::Entity side) + { + const stk::mesh::BulkData & mesh = cdmesh.stk_bulk(); + double minElemVolume = 0.; + for (auto elem : StkMeshEntities{mesh.begin_elements(side), mesh.end_elements(side)}) + { + if (elementSelector(mesh.bucket(elem))) + { + stk::mesh::Entity volumeElement = cdmesh.get_cdfem_support().use_nonconformal_element_size() ? cdmesh.get_parent_element(elem) : elem; + ThrowRequire(cdmesh.stk_bulk().is_valid(volumeElement)); + const double elemVol = ElementObj::volume( mesh, volumeElement, cdmesh.get_coords_field() ); + if (minElemVolume == 0. || elemVol < minElemVolume) + minElemVolume = elemVol; + } + } + double lengthScale = 0.; + if (minElemVolume > 0.) + { + const double invDim = 1.0 / mesh.mesh_meta_data().spatial_dimension(); + lengthScale = std::pow(minElemVolume, invDim); + } + + return lengthScale; + }; + return get_length_scale_for_side; +} + +std::function build_get_constant_length_scale_for_side_function(const double lengthScale) +{ + auto get_length_scale_for_side = + [lengthScale](stk::mesh::Entity side) { - stk::mesh::Entity volumeElement = cdmesh.get_cdfem_support().use_nonconformal_element_size() ? cdmesh.get_parent_element(elem) : elem; - ThrowRequire(cdmesh.stk_bulk().is_valid(volumeElement)); - const double elemVolume = ElementObj::volume( cdmesh.stk_bulk(), volumeElement, cdmesh.get_coords_field() ); - return elemVolume; + return lengthScale; }; - return get_element_size; + return get_length_scale_for_side; +} + +std::vector get_unique_owned_volume_elements_using_sides(const CDMesh & cdmesh, const stk::mesh::Selector & interfaceSideSelector) +{ + // Not exactly cheap + const stk::mesh::BulkData & mesh = cdmesh.stk_bulk(); + const stk::mesh::Selector elementSelector = selectUnion(cdmesh.get_phase_support().get_conformal_parts()) & cdmesh.get_active_part() & cdmesh.get_locally_owned_part(); + + std::vector volumeElements; + for( auto&& bucket : mesh.get_buckets(mesh.mesh_meta_data().side_rank(), interfaceSideSelector) ) + { + for (auto && side : *bucket) + { + for (auto elem : StkMeshEntities{mesh.begin_elements(side), mesh.end_elements(side)}) + { + if (elementSelector(mesh.bucket(elem))) + { + stk::mesh::Entity volumeElement = cdmesh.get_cdfem_support().use_nonconformal_element_size() ? cdmesh.get_parent_element(elem) : elem; + volumeElements.push_back(volumeElement); + } + } + } + } + stk::util::sort_and_unique(volumeElements); + return volumeElements; +} + +double compute_L1_norm_of_side_length_scales(const CDMesh & cdmesh, const stk::mesh::Selector & interfaceSideSelector) +{ + const std::vector elementsInNorm = get_unique_owned_volume_elements_using_sides(cdmesh, interfaceSideSelector); + + const double invDim = 1.0 / cdmesh.spatial_dim(); + + double sumLengths = 0.; + for (auto elem : elementsInNorm) + { + const double elemVolume = ElementObj::volume( cdmesh.stk_bulk(), elem, cdmesh.get_coords_field() ); + sumLengths += std::pow(elemVolume, invDim); + } + + const double sumCount = elementsInNorm.size(); + + const std::array localSum{sumLengths, sumCount}; + std::array globalSum; + stk::all_reduce_sum(cdmesh.stk_bulk().parallel(), localSum.data(), globalSum.data(), localSum.size()); + return globalSum[0]/globalSum[1]; } Vector3d get_side_average_of_vector(const stk::mesh::BulkData& mesh, @@ -2998,59 +3134,19 @@ std::function build_get_side_displacement_from_velo return get_element_size; } -Vector3d get_side_normal(const stk::mesh::BulkData& mesh, - const FieldRef coordsField, - stk::mesh::Entity side) -{ - const auto * sideNodes = mesh.begin_nodes(side); - const stk::topology sideTopology = mesh.bucket(side).topology(); - if (sideTopology == stk::topology::TRIANGLE_3 || sideTopology == stk::topology::TRIANGLE_6) - { - const Vector3d v0(field_data(coordsField, sideNodes[0])); - const Vector3d v1(field_data(coordsField, sideNodes[1])); - const Vector3d v2(field_data(coordsField, sideNodes[2])); - return Cross(v1-v0,v2-v0).unit_vector(); - } - else if (sideTopology == stk::topology::LINE_2 || sideTopology == stk::topology::LINE_3) - { - const Vector3d v0(field_data(coordsField, sideNodes[0]), 2); - const Vector3d v1(field_data(coordsField, sideNodes[1]), 2); - return crossZ(v1-v0).unit_vector(); - } - ThrowRequireMsg(false, "Unsupported topology " << sideTopology); - - return Vector3d::ZERO; -} - double get_side_cdfem_cfl(const stk::mesh::BulkData& mesh, const FieldRef coordsField, - const stk::mesh::Selector & elementSelector, const std::function & get_side_displacement, - const std::function & get_element_volume, + const std::function & get_length_scale_for_side, stk::mesh::Entity side) { const Vector3d sideCDFEMDisplacement = get_side_displacement(side); const Vector3d sideNormal = get_side_normal(mesh, coordsField, side); const double sideNormalDisplacement = Dot(sideCDFEMDisplacement, sideNormal); - double minElemVolume = 0.; - for (auto elem : StkMeshEntities{mesh.begin_elements(side), mesh.end_elements(side)}) - { - if (elementSelector(mesh.bucket(elem))) - { - const double elemVol = get_element_volume(elem); - if (minElemVolume == 0. || elemVol < minElemVolume) - minElemVolume = elemVol; - } - } - double sideCFL = 0.; - if (minElemVolume > 0.) - { - const double invDim = 1.0 / mesh.mesh_meta_data().spatial_dimension(); - const double lengthScale = std::pow(minElemVolume, invDim); - sideCFL = sideNormalDisplacement / lengthScale; - } - return sideCFL; + const double sideLengthScale = get_length_scale_for_side(side); + + return (sideLengthScale == 0.) ? 0. : sideNormalDisplacement/sideLengthScale; } double CDMesh::compute_cdfem_cfl(const std::function & get_side_displacement) const @@ -3058,16 +3154,30 @@ double CDMesh::compute_cdfem_cfl(const std::function get_length_scale_for_side; + if (my_cdfem_support.get_length_scale_type_for_interface_CFL() == CONSTANT_LENGTH_SCALE) + { + get_length_scale_for_side = build_get_constant_length_scale_for_side_function(my_cdfem_support.get_constant_length_scale_for_interface_CFL()); + } + else if (my_cdfem_support.get_length_scale_type_for_interface_CFL() == LOCAL_LENGTH_SCALE) + { + get_length_scale_for_side = build_get_local_length_scale_for_side_function(*this); + } + else + { + ThrowRequire(my_cdfem_support.get_length_scale_type_for_interface_CFL() == L1_NORM_LENGTH_SCALE); + const double lengthScaleNorm = compute_L1_norm_of_side_length_scales(*this, interfaceSideSelector); + krinolog << "Using L1 Norm length scale " << lengthScaleNorm << " to compute Interface CFL." << stk::diag::dendl; + get_length_scale_for_side = build_get_constant_length_scale_for_side_function(lengthScaleNorm); + } double cfl = 0.; for( auto&& bucket : stk_bulk().get_buckets(stk_bulk().mesh_meta_data().side_rank(), interfaceSideSelector) ) { for (auto && side : *bucket) { - const double sideCFL = get_side_cdfem_cfl(stk_bulk(), get_coords_field(), elementSelector, get_side_displacement, get_element_volume, side); + const double sideCFL = get_side_cdfem_cfl(stk_bulk(), get_coords_field(), get_side_displacement, get_length_scale_for_side, side); if (sideCFL > 0.) cfl = std::max(cfl, sideCFL); } @@ -3242,7 +3352,9 @@ CDMesh::create_element_and_side_entities(std::vector & side_request { std::vector conformal_subelems; elem->get_subelements(conformal_subelems); - num_local_subelems += conformal_subelems.size(); + for (auto && subelem : conformal_subelems) + if (0 == subelem->entityId()) + ++num_local_subelems; } } diff --git a/packages/krino/krino/krino_lib/Akri_CDMesh.hpp b/packages/krino/krino/krino_lib/Akri_CDMesh.hpp index 7a702be5cd7f..bb285b19dc75 100644 --- a/packages/krino/krino/krino_lib/Akri_CDMesh.hpp +++ b/packages/krino/krino/krino_lib/Akri_CDMesh.hpp @@ -281,6 +281,7 @@ class CDMesh { const SubElementNode * build_subelement_edge_node(const stk::mesh::Entity node, const Mesh_Element & ownerMeshElem, std::map & idToSubElementNode); const SubElementNode * find_or_build_subelement_edge_node_with_id(const stk::mesh::EntityId nodeId, const Mesh_Element & ownerMeshElem, std::map & idToSubElementNode); const SubElementNode * find_or_build_subelement_edge_node(const stk::mesh::Entity node, const Mesh_Element & ownerMeshElem, std::map & idToSubElementNode); + void find_or_build_midside_nodes(const stk::topology & elemTopo, const Mesh_Element & ownerMeshElem, const stk::mesh::Entity * elemNodes, const NodeVec & subelemNodes); stk::mesh::MetaData& my_meta; AuxMetaData& my_aux_meta; diff --git a/packages/krino/krino/krino_lib/Akri_CDMesh_Debug.cpp b/packages/krino/krino/krino_lib/Akri_CDMesh_Debug.cpp index eccd3249c341..46368b0288a2 100644 --- a/packages/krino/krino/krino_lib/Akri_CDMesh_Debug.cpp +++ b/packages/krino/krino/krino_lib/Akri_CDMesh_Debug.cpp @@ -128,6 +128,12 @@ debug_elem_parts_and_relations(const stk::mesh::BulkData & mesh, const Mesh_Elem } } +static double filter_negative_zero(const double val) +{ + if (val == 0.) return 0.; + return val; +} + void debug_nodal_parts_and_fields(const stk::mesh::BulkData & mesh, const SubElementNode * node) { @@ -176,13 +182,13 @@ debug_nodal_parts_and_fields(const stk::mesh::BulkData & mesh, const SubElementN { if (1 == field_length) { - krinolog << " Field: field_name=" << field.name() << ", field_state=" << field.state() << ", value=" << *data << "\n"; + krinolog << " Field: field_name=" << field.name() << ", field_state=" << field.state() << ", value=" << filter_negative_zero(*data) << "\n"; } else { for (unsigned i=0; i(phaseSupport.find_nonconformal_part(*part)); } -static bool is_part_to_check(const Phase_Support & phaseSupport, const AuxMetaData & auxMeta, const stk::mesh::Part & part) +bool is_part_to_check_for_snapping_compatibility(const Phase_Support & phaseSupport, const AuxMetaData & auxMeta, const stk::mesh::EntityRank targetRank, const stk::mesh::Part & part) { const stk::mesh::Part & exposedBoundaryPart = auxMeta.exposed_boundary_part(); - return part.primary_entity_rank() != stk::topology::INVALID_RANK && - (&part == &exposedBoundaryPart || stk::io::is_part_io_part(part)) && + return part.primary_entity_rank() == targetRank && + (&part == &exposedBoundaryPart || stk::io::is_part_io_part(part) || phaseSupport.is_nonconformal(&part)) && part.name().compare(0,7,"refine_") != 0 && !phaseSupport.is_interface(&part); } -static stk::mesh::PartVector get_nonconformal_parts_to_check(const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, const stk::mesh::PartVector & inputParts) +static stk::mesh::PartVector get_nonconformal_parts_to_check(const stk::mesh::BulkData & mesh, const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, const stk::mesh::EntityRank targetRank, const std::vector & targetEntities) { stk::mesh::PartVector partsToCheck; - partsToCheck.reserve(inputParts.size()); - for (auto && part : inputParts) - if (is_part_to_check(phaseSupport, auxMeta, *part)) - partsToCheck.push_back(get_nonconformal_part(phaseSupport, part)); + for (auto && targetEntity : targetEntities) + for (auto && part : mesh.bucket(targetEntity).supersets()) + if (is_part_to_check_for_snapping_compatibility(phaseSupport, auxMeta, targetRank, *part)) + partsToCheck.push_back(get_nonconformal_part(phaseSupport, part)); stk::util::sort_and_unique(partsToCheck, stk::mesh::PartLess()); return partsToCheck; } bool -parts_are_compatible_for_snapping_when_ignoring_phase(const stk::mesh::BulkData & mesh, const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, stk::mesh::Entity possibleSnapNode, stk::mesh::Entity fixedNode) +parts_are_compatible_for_snapping_when_ignoring_phase(const stk::mesh::BulkData & mesh, + const AuxMetaData & auxMeta, + const Phase_Support & phaseSupport, + const stk::mesh::Entity possibleSnapNode, + const stk::mesh::EntityRank targetRank, + const stk::mesh::PartVector & nonconformalPartsToCheck) { - const stk::mesh::PartVector & possibleSnapNodeParts = mesh.bucket(possibleSnapNode).supersets(); - const stk::mesh::PartVector nonconformalPartsToCheck = get_nonconformal_parts_to_check(auxMeta, phaseSupport, mesh.bucket(fixedNode).supersets()); - for (auto && possibleSnapNodePart : possibleSnapNodeParts) + for (auto && possibleSnapNodePart : mesh.bucket(possibleSnapNode).supersets()) { - if (is_part_to_check(phaseSupport, auxMeta, *possibleSnapNodePart)) + if (is_part_to_check_for_snapping_compatibility(phaseSupport, auxMeta, targetRank, *possibleSnapNodePart)) { stk::mesh::Part * nonconformalPart = get_nonconformal_part(phaseSupport, possibleSnapNodePart); if (!stk::mesh::contain(nonconformalPartsToCheck, *nonconformalPart)) @@ -77,6 +80,40 @@ parts_are_compatible_for_snapping_when_ignoring_phase(const stk::mesh::BulkData return true; } +static stk::topology get_simplex_element_topology(const stk::mesh::BulkData & mesh) +{ + return ((mesh.mesh_meta_data().spatial_dimension() == 2) ? stk::topology::TRIANGLE_3_2D : stk::topology::TETRAHEDRON_4); +} + +static void fill_topology_entities(const stk::mesh::BulkData & mesh, const stk::topology & topology, const std::vector & nodes, std::vector & topologyEntities) +{ + topologyEntities.clear(); + if (nodes.size() <= topology.num_nodes()) + { + stk::mesh::get_entities_through_relations(mesh, nodes, topology.rank(), topologyEntities); + } +} + +std::vector which_intersection_point_nodes_are_compatible_for_snapping(const stk::mesh::BulkData & mesh, const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, const std::vector & intersectionPointNodes) +{ + std::vector areIntersectionPointsCompatibleForSnapping(intersectionPointNodes.size(),true); + std::vector topologyEntities; + stk::topology elemTopology = get_simplex_element_topology(mesh); + std::array sideAndElementTopology{{elemTopology.side_topology(), elemTopology}}; + for (stk::topology topo : sideAndElementTopology) + { + fill_topology_entities(mesh, topo, intersectionPointNodes, topologyEntities); + const stk::mesh::PartVector nonconformalPartsToCheck = get_nonconformal_parts_to_check(mesh, auxMeta, phaseSupport, topo.rank(), topologyEntities); + for(size_t iNode=0; iNode & surfaceIDs, const PhaseTag & phase, const InterfaceID interface) { if(surfaceIDs.size() > 1 && oneLSPerPhase) diff --git a/packages/krino/krino/krino_lib/Akri_CDMesh_Utils.hpp b/packages/krino/krino/krino_lib/Akri_CDMesh_Utils.hpp index d04d40168f2f..2dfb18d8462f 100644 --- a/packages/krino/krino/krino_lib/Akri_CDMesh_Utils.hpp +++ b/packages/krino/krino/krino_lib/Akri_CDMesh_Utils.hpp @@ -19,8 +19,10 @@ class AuxMetaData; class Phase_Support; class Surface_Identifier; +bool is_part_to_check_for_snapping_compatibility(const Phase_Support & phaseSupport, const AuxMetaData & auxMeta, const stk::mesh::EntityRank targetRank, const stk::mesh::Part & part); bool parts_are_compatible_for_snapping(const stk::mesh::BulkData & mesh, stk::mesh::Entity possible_snap_node, stk::mesh::Entity fixed_node); -bool parts_are_compatible_for_snapping_when_ignoring_phase(const stk::mesh::BulkData & mesh, const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, stk::mesh::Entity possible_snap_node, stk::mesh::Entity fixed_node); +std::vector which_intersection_point_nodes_are_compatible_for_snapping(const stk::mesh::BulkData & mesh, const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, const std::vector & intersectionPointNodes); +bool parts_are_compatible_for_snapping_when_ignoring_phase(const stk::mesh::BulkData & mesh, const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, const stk::mesh::Entity possibleSnapNode, const stk::mesh::EntityRank targetRank, const std::vector & targetEntities); bool phase_matches_interface(const bool oneLSPerPhase, const std::vector & surfaceIDs, const PhaseTag & phase, const InterfaceID interface); bool determine_phase_from_parts(PhaseTag & phase, const stk::mesh::PartVector & parts, const Phase_Support & phaseSupport); PhaseTag determine_phase_for_entity(const stk::mesh::BulkData & mesh, stk::mesh::Entity entity, const Phase_Support & phaseSupport); diff --git a/packages/krino/krino/krino_lib/Akri_CramersRuleSolver.cpp b/packages/krino/krino/krino_lib/Akri_CramersRuleSolver.cpp new file mode 100644 index 000000000000..b4d9651048c1 --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_CramersRuleSolver.cpp @@ -0,0 +1,139 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "Akri_CramersRuleSolver.hpp" +#include +#include + + +namespace krino { +namespace CramersRuleSolver { + +std::array solve3x3( + double a11, double a12, double a13, + double a21, double a22, double a23, + double a31, double a32, double a33, + double k1, double k2, double k3 ) +{ + double det = compute_determinant3x3( a11, a12, a13, a21, a22, a23, a31, a32, a33 ); + + ThrowRequireMsg(det != 0.0, "Bad determinant. Are the points really unique?"); + + std::array answer; + + answer[0] = compute_determinant3x3( k1, a12, a13, k2, a22, a23, k3, a32, a33 )/det; + answer[1] = compute_determinant3x3( a11, k1, a13, a21, k2, a23, a31, k3, a33 )/det; + answer[2] = compute_determinant3x3( a11, a12, k1, a21, a22, k2, a31, a32, k3 )/det; + return( answer ); +} + +std::array solve3x3(const std::array,3> &A, std::array & b) +{ + return solve3x3(A[0][0], A[0][1], A[0][2], A[1][0], A[1][1], A[1][2], A[2][0], A[2][1], A[2][2], b[0], b[1], b[2]); +} + +std::array solve5x5( + double a11, double a12, double a13, double a14, double a15, + double a21, double a22, double a23, double a24, double a25, + double a31, double a32, double a33, double a34, double a35, + double a41, double a42, double a43, double a44, double a45, + double a51, double a52, double a53, double a54, double a55, + double k1, double k2, double k3, double k4, double k5 ) +{ + const double det = compute_determinant5x5( + a11, a12, a13, a14, a15, + a21, a22, a23, a24, a25, + a31, a32, a33, a34, a35, + a41, a42, a43, a44, a45, + a51, a52, a53, a54, a55 ); + + ThrowRequireMsg(det != 0.0, "Bad determinant. Are the points really unique?"); + + std::array answer; + answer[0] = compute_determinant5x5( + k1, a12, a13, a14, a15, + k2, a22, a23, a24, a25, + k3, a32, a33, a34, a35, + k4, a42, a43, a44, a45, + k5, a52, a53, a54, a55)/det; + answer[1] = compute_determinant5x5( + a11, k1, a13, a14, a15, + a21, k2, a23, a24, a25, + a31, k3, a33, a34, a35, + a41, k4, a43, a44, a45, + a51, k5, a53, a54, a55)/det; + answer[2] = compute_determinant5x5( + a11, a12, k1, a14, a15, + a21, a22, k2, a24, a25, + a31, a32, k3, a34, a35, + a41, a42, k4, a44, a45, + a51, a52, k5, a54, a55)/det; + answer[3] = compute_determinant5x5( + a11, a12, a13, k1, a15, + a21, a22, a23, k2, a25, + a31, a32, a33, k3, a35, + a41, a42, a43, k4, a45, + a51, a52, a53, k5, a55)/det; + answer[4] = compute_determinant5x5( + a11, a12, a13, a14, k1, + a21, a22, a23, a24, k2, + a31, a32, a33, a34, k3, + a41, a42, a43, a44, k4, + a51, a52, a53, a54, k5)/det; + return( answer ); +} + +std::array solve5x5(const std::array,5> &A, std::array & b) +{ + return solve5x5( + A[0][0], A[0][1], A[0][2], A[0][3], A[0][4], + A[1][0], A[1][1], A[1][2], A[1][3], A[1][4], + A[2][0], A[2][1], A[2][2], A[2][3], A[2][4], + A[3][0], A[3][1], A[3][2], A[3][3], A[3][4], + A[4][0], A[4][1], A[4][2], A[4][3], A[4][4], + b[0], b[1], b[2], b[3], b[4]); +} + + +double compute_determinant3x3( + double a11, double a12, double a13, + double a21, double a22, double a23, + double a31, double a32, double a33 ) +{ + return( a11*a22*a33 + a12*a23*a31 + a13*a21*a32 - + a13*a22*a31 - a12*a21*a33 - a11*a23*a32 ); +} + +double compute_determinant4x4( + double a11, double a12, double a13, double a14, + double a21, double a22, double a23, double a24, + double a31, double a32, double a33, double a34, + double a41, double a42, double a43, double a44 ) +{ + return( a11*compute_determinant3x3(a22, a23, a24, a32, a33, a34, a42, a43, a44) - + a12*compute_determinant3x3(a21, a23, a24, a31, a33, a34, a41, a43, a44) + + a13*compute_determinant3x3(a21, a22, a24, a31, a32, a34, a41, a42, a44) - + a14*compute_determinant3x3(a21, a22, a23, a31, a32, a33, a41, a42, a43)); +} + +double compute_determinant5x5( + double a11, double a12, double a13, double a14, double a15, + double a21, double a22, double a23, double a24, double a25, + double a31, double a32, double a33, double a34, double a35, + double a41, double a42, double a43, double a44, double a45, + double a51, double a52, double a53, double a54, double a55 ) +{ + return( a11*compute_determinant4x4(a22, a23, a24, a25, a32, a33, a34, a35, a42, a43, a44, a45, a52, a53, a54, a55 ) - + a12*compute_determinant4x4(a21, a23, a24, a25, a31, a33, a34, a35, a41, a43, a44, a45, a51, a53, a54, a55 ) + + a13*compute_determinant4x4(a21, a22, a24, a25, a31, a32, a34, a35, a41, a42, a44, a45, a51, a52, a54, a55 ) - + a14*compute_determinant4x4(a21, a22, a23, a25, a31, a32, a33, a35, a41, a42, a43, a45, a51, a52, a53, a55 ) + + a15*compute_determinant4x4(a21, a22, a23, a24, a31, a32, a33, a34, a41, a42, a43, a44, a51, a52, a53, a54 )); +} + +} // namespace CramersRuleSolver +} // namespace krino diff --git a/packages/krino/krino/krino_lib/Akri_CramersRuleSolver.hpp b/packages/krino/krino/krino_lib/Akri_CramersRuleSolver.hpp new file mode 100644 index 000000000000..8d8634566f7b --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_CramersRuleSolver.hpp @@ -0,0 +1,56 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef KRINO_KRINO_KRINO_LIB_AKRI_CRAMERSRULESOLVER_HPP_ +#define KRINO_KRINO_KRINO_LIB_AKRI_CRAMERSRULESOLVER_HPP_ +#include + + +namespace krino { +namespace CramersRuleSolver { + +std::array solve3x3( + double a11, double a12, double a13, + double a21, double a22, double a23, + double a31, double a32, double a33, + double k1, double k2, double k3 ); + +std::array solve3x3(const std::array,3> &A, std::array & b); + +std::array solve5x5( + double a11, double a12, double a13, double a14, double a15, + double a21, double a22, double a23, double a24, double a25, + double a31, double a32, double a33, double a34, double a35, + double a41, double a42, double a43, double a44, double a45, + double a51, double a52, double a53, double a54, double a55, + double k1, double k2, double k3, double k4, double k5 ); + +std::array solve5x5(const std::array,5> &A, std::array & b); + +double compute_determinant3x3( + double a11, double a12, double a13, + double a21, double a22, double a23, + double a31, double a32, double a33 ); + +double compute_determinant4x4( + double a11, double a12, double a13, double a14, + double a21, double a22, double a23, double a24, + double a31, double a32, double a33, double a34, + double a41, double a42, double a43, double a44 ); + +double compute_determinant5x5( + double a11, double a12, double a13, double a14, double a15, + double a21, double a22, double a23, double a24, double a25, + double a31, double a32, double a33, double a34, double a35, + double a41, double a42, double a43, double a44, double a45, + double a51, double a52, double a53, double a54, double a55 ); + +} // namespace CramersRuleSolver +} // namespace krino + +#endif /* KRINO_KRINO_KRINO_LIB_AKRI_CRAMERSRULESOLVER_HPP_ */ diff --git a/packages/krino/krino/krino_lib/Akri_CurvatureLeastSquares.cpp b/packages/krino/krino/krino_lib/Akri_CurvatureLeastSquares.cpp new file mode 100644 index 000000000000..21046c6cc4bf --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_CurvatureLeastSquares.cpp @@ -0,0 +1,284 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include "Akri_CramersRuleSolver.hpp" +#include "Akri_Vec.hpp" + +namespace krino { + +static std::vector get_unique_halo_nodes(const std::vector> & haloSegments) +{ + std::vector uniqueHaloNodes; + for (auto && haloSegment : haloSegments) + { + uniqueHaloNodes.push_back(haloSegment[0]); + uniqueHaloNodes.push_back(haloSegment[1]); + } + + stk::util::sort_and_unique(uniqueHaloNodes); + + return uniqueHaloNodes; +} + +void set_rotation_matrix_for_rotating_normal_to_zDir(std::array,3> & m, const Vector3d & normalDir) +{ + const Vector3d normal = normalDir.unit_vector(); + static const Vector3d zDir(0.,0.,1.); + const double c = Dot(zDir, normal); + Vector3d v = Cross(normal, zDir); + const double s = v.length(); + if (s > 0.) v *= (1./s); + + const double c1 = 1.-c; + + m[0][0] = c + v[0]*v[0]*c1; + m[0][1] = v[0]*v[1]*c1 - v[2]*s; + m[0][2] = v[0]*v[2]*c1 + v[1]*s; + m[1][0] = v[1]*v[0]*c1 + v[2]*s; + m[1][1] = c + v[1]*v[1]*(1.-c); + m[1][2] = v[1]*v[2]*c1 - v[0]*s; + m[2][0] = v[2]*v[0]*c1 - v[1]*s; + m[2][1] = v[2]*v[1]*c1 + v[0]*s; + m[2][2] = c + v[2]*v[2]*c1; +} + +Vector3d compute_patch_normal(const std::vector & haloNodeLocs, const std::vector> & haloSegments) +{ + Vector3d patchNormal = Vector3d::ZERO; + for (auto && haloSegment : haloSegments) + { + const stk::math::Vector3d & xc0 = haloNodeLocs[haloSegment[0]]; + const stk::math::Vector3d & xc1 = haloNodeLocs[haloSegment[1]]; + const stk::math::Vector3d wtNormal = Cross(xc0, xc1) / (xc0.length_squared()*xc1.length_squared()); + patchNormal += wtNormal; + } + + return patchNormal.unit_vector(); +} + +static void fill_matrix_and_rhs_for_curvature_least_squares(const std::vector & rotatedUniqueHaloNodeLocs, std::array,3> & A, std::array & b) +{ + if (rotatedUniqueHaloNodeLocs.size() == 3) + { + for (int i=0; i<3; ++i) + { + A[i][0] = rotatedUniqueHaloNodeLocs[i][0]*rotatedUniqueHaloNodeLocs[i][0]; + A[i][1] = rotatedUniqueHaloNodeLocs[i][0]*rotatedUniqueHaloNodeLocs[i][1]; + A[i][2] = rotatedUniqueHaloNodeLocs[i][1]*rotatedUniqueHaloNodeLocs[i][1]; + b[i] = rotatedUniqueHaloNodeLocs[i][2]; + } + } + else + { + ThrowRequireMsg(rotatedUniqueHaloNodeLocs.size() == 4, "Unexpected vector size in fill_matrix_and_rhs_for_curvature_least_squares."); + std::array,4> Apts; + for (int i=0; i<4; ++i) + { + Apts[i][0] = rotatedUniqueHaloNodeLocs[i][0]*rotatedUniqueHaloNodeLocs[i][0]; + Apts[i][1] = rotatedUniqueHaloNodeLocs[i][0]*rotatedUniqueHaloNodeLocs[i][1]; + Apts[i][2] = rotatedUniqueHaloNodeLocs[i][1]*rotatedUniqueHaloNodeLocs[i][1]; + } + + for (int i=0; i<3; ++i) + { + b[i] = 0.; + for (int k=0; k<4; ++k) + b[i] += Apts[k][i] * rotatedUniqueHaloNodeLocs[k][2]; + + for (int j=0; j<3; ++j) + { + A[i][j] = 0.; + for (int k=0; k<4; ++k) + A[i][j] += Apts[k][j]*Apts[k][i]; + } + } + } +} + +static void fill_matrix_and_rhs_for_curvature_normal_least_squares(const std::vector & rotatedUniqueHaloNodeLocs, std::array,5> & A, std::array & b) +{ + if (rotatedUniqueHaloNodeLocs.size() == 5) + { + for (int i=0; i<5; ++i) + { + A[i][0] = rotatedUniqueHaloNodeLocs[i][0]*rotatedUniqueHaloNodeLocs[i][0]; + A[i][1] = rotatedUniqueHaloNodeLocs[i][0]*rotatedUniqueHaloNodeLocs[i][1]; + A[i][2] = rotatedUniqueHaloNodeLocs[i][1]*rotatedUniqueHaloNodeLocs[i][1]; + A[i][3] = rotatedUniqueHaloNodeLocs[i][0]; + A[i][4] = rotatedUniqueHaloNodeLocs[i][1]; + b[i] = rotatedUniqueHaloNodeLocs[i][2]; + } + } + else + { + std::vector> Apts; + Apts.resize(rotatedUniqueHaloNodeLocs.size()); + for (unsigned i=0; i & rotatedUniqueHaloNodeLocs) +{ + if (rotatedUniqueHaloNodeLocs.size() == 3 || rotatedUniqueHaloNodeLocs.size() == 4) + { + std::array,3> A; + std::array b; + fill_matrix_and_rhs_for_curvature_least_squares(rotatedUniqueHaloNodeLocs, A, b); + + const std::array soln = CramersRuleSolver::solve3x3(A,b); + + return Vector3d(0., 0., -2.*soln[0]-2.*soln[2]); + } + else if (rotatedUniqueHaloNodeLocs.size() >= 5) + { + std::array,5> A; + std::array b; + fill_matrix_and_rhs_for_curvature_normal_least_squares(rotatedUniqueHaloNodeLocs, A, b); + + const std::array soln = CramersRuleSolver::solve5x5(A,b); + + Vector3d normal(-soln[3],-soln[4],1.); + const double mag = normal.unitize(); + + const double curvature = + ((normal[0]*normal[0] - 1.) * 2.*soln[0] + + normal[0]*normal[1] * 2.*soln[1] + + (normal[1]*normal[1] - 1.) * 2.*soln[2]) / mag; + + return curvature*normal; + } + + return Vector3d::ZERO; +} + +static Vector3d compute_least_squares_normal(const std::vector & rotatedUniqueHaloNodeLocs) +{ + ThrowRequire(rotatedUniqueHaloNodeLocs.size() >= 5); + + std::array,5> A; + std::array b; + fill_matrix_and_rhs_for_curvature_normal_least_squares(rotatedUniqueHaloNodeLocs, A, b); + + const std::array soln = CramersRuleSolver::solve5x5(A,b); + + Vector3d normal(-soln[3],-soln[4],1.); + normal.unitize(); + + return normal; +} + +Vector3d rotate_3d_vector(const std::array,3> & m, const Vector3d & v) +{ + return Vector3d( + (m[0][0] * v[0] + m[0][1] * v[1] + m[0][2] * v[2]), + (m[1][0] * v[0] + m[1][1] * v[1] + m[1][2] * v[2]), + (m[2][0] * v[0] + m[2][1] * v[1] + m[2][2] * v[2])); +} + +Vector3d reverse_rotate_3d_vector(const std::array,3> & m, const Vector3d & v) +{ + return Vector3d( + (m[0][0] * v[0] + m[1][0] * v[1] + m[2][0] * v[2]), + (m[0][1] * v[0] + m[1][1] * v[1] + m[2][1] * v[2]), + (m[0][2] * v[0] + m[1][2] * v[1] + m[2][2] * v[2])); +} + +static std::vector get_rotated_neighbor_node_locations(const std::vector & neighborNodeLocs, const std::array,3> & m) +{ + std::vector rotatedUniqueHaloNodeLocs; + rotatedUniqueHaloNodeLocs.reserve(neighborNodeLocs.size()); + for (auto && loc : neighborNodeLocs) + rotatedUniqueHaloNodeLocs.push_back(rotate_3d_vector(m, loc)); + return rotatedUniqueHaloNodeLocs; +} + +static std::vector get_rotated_unique_halo_node_locations(const std::vector & haloNodeLocs, std::vector uniqueHaloNodes, const std::array,3> & m) +{ + std::vector rotatedUniqueHaloNodeLocs; + rotatedUniqueHaloNodeLocs.reserve(uniqueHaloNodes.size()); + for (int haloNode : uniqueHaloNodes) + rotatedUniqueHaloNodeLocs.push_back(rotate_3d_vector(m, haloNodeLocs[haloNode])); + return rotatedUniqueHaloNodeLocs; +} + +Vector3d compute_least_squares_curvature_times_normal(const std::vector & haloNodeLocs, const std::vector> & haloSegments) +{ + if (haloSegments.size() < 3) + return Vector3d::ZERO; + + const Vector3d patchNormal = compute_patch_normal(haloNodeLocs, haloSegments); + + std::vector uniqueHaloNodes = get_unique_halo_nodes(haloSegments); + + std::array,3> m; + set_rotation_matrix_for_rotating_normal_to_zDir(m, patchNormal); + + const std::vector rotatedUniqueHaloNodeLocs = get_rotated_unique_halo_node_locations(haloNodeLocs, uniqueHaloNodes, m); + + const Vector3d rotatedCurvatureNormal = compute_least_squares_curvature_times_normal(rotatedUniqueHaloNodeLocs); + + return reverse_rotate_3d_vector(m, rotatedCurvatureNormal); +} + +Vector3d compute_least_squares_curvature_times_normal(const Vector3d & approximateNormal, const std::vector & neighborNodeLocs) +{ + if (neighborNodeLocs.size() < 3) + return Vector3d::ZERO; + + std::array,3> m; + set_rotation_matrix_for_rotating_normal_to_zDir(m, approximateNormal); + + const std::vector rotatedNbrNodeLocs = get_rotated_neighbor_node_locations(neighborNodeLocs, m); + + const Vector3d rotatedCurvatureNormal = compute_least_squares_curvature_times_normal(rotatedNbrNodeLocs); + + return reverse_rotate_3d_vector(m, rotatedCurvatureNormal); +} + +Vector3d compute_least_squares_normal(const Vector3d & approximateNormal, const std::vector & neighborNodeLocs) +{ + if (neighborNodeLocs.size() < 5) + return approximateNormal; + + std::array,3> m; + set_rotation_matrix_for_rotating_normal_to_zDir(m, approximateNormal); + + const std::vector rotatedNbrNodeLocs = get_rotated_neighbor_node_locations(neighborNodeLocs, m); + + const Vector3d rotatedCurvatureNormal = compute_least_squares_normal(rotatedNbrNodeLocs); + + return reverse_rotate_3d_vector(m, rotatedCurvatureNormal).unit_vector(); +} + +} + + diff --git a/packages/krino/krino/krino_lib/Akri_CurvatureLeastSquares.hpp b/packages/krino/krino/krino_lib/Akri_CurvatureLeastSquares.hpp new file mode 100644 index 000000000000..932542f593e5 --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_CurvatureLeastSquares.hpp @@ -0,0 +1,30 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef KRINO_KRINO_KRINO_LIB_AKRI_CURVATURELEASTSQUARES_HPP_ +#define KRINO_KRINO_KRINO_LIB_AKRI_CURVATURELEASTSQUARES_HPP_ +#include +#include +#include "Akri_Vec.hpp" + +namespace krino { + +class Quaternion; + +void set_rotation_matrix_for_rotating_normal_to_zDir(std::array,3> & m, const Vector3d & normalDir); +Vector3d rotate_3d_vector(const std::array,3> & m, const Vector3d & v); +Vector3d reverse_rotate_3d_vector(const std::array,3> & m, const Vector3d & v); + +Vector3d compute_patch_normal(const std::vector & haloNodeLocs, const std::vector> & haloSegments); +Vector3d compute_least_squares_curvature_times_normal(const std::vector & haloNodeLocs, const std::vector> & haloSegments); +Vector3d compute_least_squares_curvature_times_normal(const Vector3d & approximateNormal, const std::vector & neighborNodeLocs); +Vector3d compute_least_squares_normal(const Vector3d & approximateNormal, const std::vector & neighborNodeLocs); + +} + +#endif /* KRINO_KRINO_KRINO_LIB_AKRI_CURVATURELEASTSQUARES_HPP_ */ diff --git a/packages/krino/krino/krino_lib/Akri_Element.hpp b/packages/krino/krino/krino_lib/Akri_Element.hpp index 55afc1bb7dad..1f564a9118b2 100644 --- a/packages/krino/krino/krino_lib/Akri_Element.hpp +++ b/packages/krino/krino/krino_lib/Akri_Element.hpp @@ -174,6 +174,7 @@ class Mesh_Element : public ElementObj { int get_interface_index(const InterfaceID interface) const; const std::vector & get_sorted_cutting_interfaces() const { return myCuttingInterfaces; } virtual void determine_decomposed_elem_phase(const std::vector & surfaceIDs) override; + void set_have_interface() { my_have_interface = true; } bool triangulate(const CDMesh & mesh, const InterfaceGeometry & interfaceGeometry); //return value indicates if any changes were made void create_cutter(const CDMesh & mesh, const InterfaceGeometry & interfaceGeometry); diff --git a/packages/krino/krino/krino_lib/Akri_IO_Helpers.cpp b/packages/krino/krino/krino_lib/Akri_IO_Helpers.cpp index 4b51faad2f62..a5c54a658ab7 100644 --- a/packages/krino/krino/krino_lib/Akri_IO_Helpers.cpp +++ b/packages/krino/krino/krino_lib/Akri_IO_Helpers.cpp @@ -8,11 +8,8 @@ #include -#include #include #include -#include -#include #include #include #include @@ -39,48 +36,44 @@ Block_Surface_Connectivity::Block_Surface_Connectivity(const stk::mesh::MetaData } } -Block_Surface_Connectivity::Block_Surface_Connectivity(const stk::mesh::MetaData & meta, const Ioss::Region & io_region) +void Block_Surface_Connectivity::dump_surface_connectivity(const stk::mesh::MetaData & meta) { - /* %TRACE[ON]% */ - Trace trace__("Block_Surface_Connectivity::Block_Surface_Connectivity(const Ioss::Region & reg)"); /* %TRACE% */ + const std::vector surfacesInMap = meta.get_surfaces_in_surface_to_block_map(); + for(auto && surface : surfacesInMap) + { + krinolog << "Surface " << surface->name() << " touches blocks "; + for (auto && touchingBlock : meta.get_blocks_touching_surface(surface)) + krinolog << touchingBlock->name() << " "; + krinolog << stk::diag::dendl; + } +} - std::vector side_block_names; - std::vector side_block_ordinals; +std::set Block_Surface_Connectivity::get_surfaces_touching_block(const stk::mesh::PartOrdinal & blockOrdinal) const +{ + auto it = block_to_surface_map.find(blockOrdinal); + if(it != block_to_surface_map.end()) + return it->second; - for(auto sideset : io_region.get_sidesets()) - { - side_block_names.clear(); - sideset->block_membership(side_block_names); - side_block_ordinals.clear(); - for (auto && block_name : side_block_names) - { - const stk::mesh::Part * side_block_part = meta.get_part(block_name); - ThrowRequire(nullptr != side_block_part); - side_block_ordinals.push_back(side_block_part->mesh_meta_data_ordinal()); - } - const stk::mesh::Part * side_part = meta.get_part(sideset->name()); - ThrowRequire(nullptr != side_part); - add_surface(side_part->mesh_meta_data_ordinal(), std::set(side_block_ordinals.begin(), side_block_ordinals.end())); + std::set emptySurfaces; + return emptySurfaces; +} - if (!sideset->get_side_blocks().empty()) - { - for (auto&& side_subset : sideset->get_side_blocks()) - { - // Fmwk only creates subset if more than 1 sideblock, but stk always creates them, so just check. - const stk::mesh::Part * side_subset_part = meta.get_part(side_subset->name()); - if (nullptr == side_subset_part) continue; - side_block_names.clear(); - side_subset->block_membership(side_block_names); - side_block_ordinals.clear(); - for (auto && block_name : side_block_names) - { - const stk::mesh::Part * side_block_part = meta.get_part(block_name); - ThrowRequire(nullptr != side_block_part); - side_block_ordinals.push_back(side_block_part->mesh_meta_data_ordinal()); - } - add_surface(side_subset_part->mesh_meta_data_ordinal(), std::set(side_block_ordinals.begin(), side_block_ordinals.end())); - } - } +std::set Block_Surface_Connectivity::get_blocks_touching_surface(const stk::mesh::PartOrdinal & surfaceOrdinal) const +{ + auto it = surface_to_block_map.find(surfaceOrdinal); + if(it != surface_to_block_map.end()) + return it->second; + + std::set emptyBlocks; + return emptyBlocks; +} + +void Block_Surface_Connectivity::add_surface(const stk::mesh::PartOrdinal & surf_ordinal, const std::set touching_blocks) +{ + surface_to_block_map[surf_ordinal].insert(touching_blocks.begin(), touching_blocks.end()); + for(auto && block : touching_blocks) + { + block_to_surface_map[block].insert(surf_ordinal); } } diff --git a/packages/krino/krino/krino_lib/Akri_IO_Helpers.hpp b/packages/krino/krino/krino_lib/Akri_IO_Helpers.hpp index 2ec5005ffd08..0af0bdb1d083 100644 --- a/packages/krino/krino/krino_lib/Akri_IO_Helpers.hpp +++ b/packages/krino/krino/krino_lib/Akri_IO_Helpers.hpp @@ -8,53 +8,27 @@ #ifndef Akri_IO_Helpers_h #define Akri_IO_Helpers_h -// -#include -#include - -#include +#include #include -namespace stk { namespace mesh { class BulkData; } } namespace stk { namespace mesh { class MetaData; } } -namespace stk { namespace diag { class Timer; } } namespace Ioss { class Region; } namespace krino { -class AuxMetaData; - class Block_Surface_Connectivity { public: Block_Surface_Connectivity() {} Block_Surface_Connectivity(const stk::mesh::MetaData & meta); - Block_Surface_Connectivity(const stk::mesh::MetaData & meta, const Ioss::Region & io_region); - void get_surfaces_touching_block(const stk::mesh::PartOrdinal & block_ordinal, - std::set & surface_ordinals) const - { - auto it = block_to_surface_map.find(block_ordinal); - if(it == block_to_surface_map.end()) return; - surface_ordinals.insert(it->second.begin(), it->second.end()); - } - void get_blocks_touching_surface(const stk::mesh::PartOrdinal & surface_ordinal, - std::set & block_ordinals) const - { - block_ordinals.clear(); - auto it = surface_to_block_map.find(surface_ordinal); - if(it == surface_to_block_map.end()) return; - block_ordinals.insert(it->second.begin(), it->second.end()); - } - void add_surface(const stk::mesh::PartOrdinal & surf_ordinal, const std::set touching_blocks) - { - surface_to_block_map[surf_ordinal].insert(touching_blocks.begin(), touching_blocks.end()); - for(auto && block : touching_blocks) - { - block_to_surface_map[block].insert(surf_ordinal); - } - } + std::set get_surfaces_touching_block(const stk::mesh::PartOrdinal & blockOrdinal) const; + std::set get_blocks_touching_surface(const stk::mesh::PartOrdinal & surfaceOrdinal) const; + void add_surface(const stk::mesh::PartOrdinal & surf_ordinal, const std::set touching_blocks); + + static void dump_surface_connectivity(const stk::mesh::MetaData & meta); + private: std::map< stk::mesh::PartOrdinal, std::set > block_to_surface_map; std::map< stk::mesh::PartOrdinal, std::set > surface_to_block_map; diff --git a/packages/krino/krino/krino_lib/Akri_LevelSet.cpp b/packages/krino/krino/krino_lib/Akri_LevelSet.cpp index 93bfea805dc5..97e3860cdb78 100644 --- a/packages/krino/krino/krino_lib/Akri_LevelSet.cpp +++ b/packages/krino/krino/krino_lib/Akri_LevelSet.cpp @@ -204,7 +204,7 @@ void LevelSet::register_fields(void) const bool cdfem_is_active = krino::CDFEM_Support::is_active(meta()); if (cdfem_is_active) { - Phase_Support phase_support = Phase_Support::get(meta()); + Phase_Support & phase_support = Phase_Support::get(meta()); for (auto partPtr : meta().get_mesh_parts()) { if (partPtr->primary_entity_rank() == stk::topology::ELEMENT_RANK && @@ -1372,7 +1372,7 @@ LevelSet::simple_remove_wall_features() const void LevelSet::set_surface_parts_vector() { - Phase_Support my_phase_support = Phase_Support::get(meta()); + Phase_Support & my_phase_support = Phase_Support::get(meta()); std::vector conformal_parts = my_phase_support.get_conformal_parts(); diff --git a/packages/krino/krino/krino_lib/Akri_MeshClone.cpp b/packages/krino/krino/krino_lib/Akri_MeshClone.cpp index b1cfdc570951..833aeb72b12b 100644 --- a/packages/krino/krino/krino_lib/Akri_MeshClone.cpp +++ b/packages/krino/krino/krino_lib/Akri_MeshClone.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include #include @@ -83,16 +85,13 @@ MeshClone::MeshClone( stk::mesh::BulkData & orig_mesh, stk::diag::Timer parent_t { stk::diag::TimeBlock timer__(my_timer); const stk::mesh::MetaData & in_meta = my_orig_mesh->mesh_meta_data(); - my_meta = std::make_unique(); + my_meta = stk::mesh::MeshBuilder().create_meta_data(); clone_meta_data_parts_and_fields(in_meta, *my_meta); - my_mesh = std::make_unique(*my_meta, - my_orig_mesh->parallel(), - stk::mesh::BulkData::NO_AUTO_AURA -#ifdef SIERRA_MIGRATION - ,my_orig_mesh->add_fmwk_data() -#endif - ); + my_mesh = stk::mesh::MeshBuilder(my_orig_mesh->parallel()) + .set_aura_option(stk::mesh::BulkData::NO_AUTO_AURA) + .set_add_fmwk_data(my_orig_mesh->add_fmwk_data()) + .create(my_meta); my_meta->commit(); @@ -134,26 +133,8 @@ void MeshClone::clone_mesh(const stk::mesh::BulkData & in_mesh, stk::mesh::BulkD { /* %TRACE[ON]% */ Trace trace__("krino::MeshClone::clone_mesh(const stk::mesh::BulkData & in_mesh, stk::mesh::BulkData & out_mesh, const bool full_overwrite)"); /* %TRACE% */ if (full_overwrite) { - // Ugly, but legal and effective. - stk::mesh::MetaData & out_meta = out_mesh.mesh_meta_data(); - out_mesh.~BulkData(); - - const stk::mesh::BulkData::AutomaticAuraOption aura_option = - in_mesh.is_automatic_aura_on() ? - stk::mesh::BulkData::AUTO_AURA : - stk::mesh::BulkData::NO_AUTO_AURA; - - new (&out_mesh) stk::mesh::BulkData(out_meta, - in_mesh.parallel(), - aura_option -#ifdef SIERRA_MIGRATION - ,in_mesh.add_fmwk_data() -#endif - ); - - out_mesh.modification_begin(); - clone_bulk_data_entities(in_mesh, out_mesh, false); - out_mesh.modification_end(); +// std::function op = [](stk::mesh::BulkData& outMesh_) {}; + stk::tools::replace_bulk_data(in_mesh, out_mesh/*, op*/); } else { @@ -172,9 +153,9 @@ void MeshClone::clone_mesh(const stk::mesh::BulkData & in_mesh, stk::mesh::BulkD out_mesh.modification_begin(); clone_bulk_data_entities(in_mesh, out_mesh, true); out_mesh.modification_end(); + copy_field_data(in_mesh, out_mesh); } - copy_field_data(in_mesh, out_mesh); } diff --git a/packages/krino/krino/krino_lib/Akri_MeshClone.hpp b/packages/krino/krino/krino_lib/Akri_MeshClone.hpp index d73789a38c97..ac7f4bc7cd37 100644 --- a/packages/krino/krino/krino_lib/Akri_MeshClone.hpp +++ b/packages/krino/krino/krino_lib/Akri_MeshClone.hpp @@ -54,7 +54,7 @@ class MeshClone { static stk::mesh::Entity get_entity_on_other_mesh(const stk::mesh::BulkData & mesh, stk::mesh::Entity entity, const stk::mesh::BulkData & other_mesh); stk::mesh::BulkData* my_orig_mesh; - std::unique_ptr my_meta; + std::shared_ptr my_meta; std::unique_ptr my_mesh; mutable stk::diag::Timer my_timer; diff --git a/packages/krino/krino/krino_lib/Akri_MeshHelpers.cpp b/packages/krino/krino/krino_lib/Akri_MeshHelpers.cpp index 30fe6f7a21ff..1758a920449d 100644 --- a/packages/krino/krino/krino_lib/Akri_MeshHelpers.cpp +++ b/packages/krino/krino/krino_lib/Akri_MeshHelpers.cpp @@ -56,6 +56,30 @@ void resize_container(CONTAINER & container, size_t size) resizer.resize(size); } +Vector3d get_side_normal(const stk::mesh::BulkData& mesh, + const FieldRef coordsField, + stk::mesh::Entity side) +{ + const auto * sideNodes = mesh.begin_nodes(side); + const stk::topology sideTopology = mesh.bucket(side).topology(); + if (sideTopology == stk::topology::TRIANGLE_3 || sideTopology == stk::topology::TRIANGLE_6) + { + const Vector3d v0(field_data(coordsField, sideNodes[0])); + const Vector3d v1(field_data(coordsField, sideNodes[1])); + const Vector3d v2(field_data(coordsField, sideNodes[2])); + return Cross(v1-v0,v2-v0).unit_vector(); + } + else if (sideTopology == stk::topology::LINE_2 || sideTopology == stk::topology::LINE_3) + { + const Vector3d v0(field_data(coordsField, sideNodes[0]), 2); + const Vector3d v1(field_data(coordsField, sideNodes[1]), 2); + return crossZ(v1-v0).unit_vector(); + } + ThrowRequireMsg(false, "Unsupported topology " << sideTopology); + + return Vector3d::ZERO; +} + void fill_procs_owning_or_sharing_or_ghosting_node(const stk::mesh::BulkData& bulkData, stk::mesh::Entity node, std::vector & procsOwningSharingOrGhostingNode) { ThrowAssert(bulkData.parallel_owner_rank(node)==bulkData.parallel_rank()); @@ -111,16 +135,24 @@ static std::array gather_tri_coordinates(const stk::mesh::Bul return elementNodeCoords; } -static double compute_tri_volume(const std::array & elementNodeCoords) +double compute_tri_volume(const std::array & elementNodeCoords) { return 0.5*(Cross(elementNodeCoords[1]-elementNodeCoords[0], elementNodeCoords[2]-elementNodeCoords[0]).length()); } -static double compute_tet_volume(const std::array & elementNodeCoords) +double compute_tet_volume(const std::array & elementNodeCoords) { return Dot(elementNodeCoords[3]-elementNodeCoords[0],Cross(elementNodeCoords[1]-elementNodeCoords[0], elementNodeCoords[2]-elementNodeCoords[0]))/6.0; } +double compute_tri_or_tet_volume(const std::vector & elementNodeCoords) +{ + ThrowAssert(elementNodeCoords.size() == 4 || elementNodeCoords.size() == 3); + if (elementNodeCoords.size() == 4) + return compute_tet_volume({{elementNodeCoords[0],elementNodeCoords[1],elementNodeCoords[2],elementNodeCoords[3]}}); + return compute_tri_volume({{elementNodeCoords[0],elementNodeCoords[1],elementNodeCoords[2]}}); +} + static double compute_tri_or_tet_volume(const stk::mesh::BulkData & mesh, stk::mesh::Entity element, const FieldRef coordsField) { stk::topology elemTopology = mesh.bucket(element).topology(); @@ -609,6 +641,41 @@ debug_entity(const stk::mesh::BulkData & mesh, stk::mesh::Entity entity) return debug_entity(mesh, entity, false); } +static void +debug_entity_1line(std::ostream & output, const stk::mesh::BulkData & mesh, stk::mesh::Entity entity) +{ + if (!mesh.is_valid(entity)) + { + output << "Invalid entity: " << mesh.entity_key(entity) << std::endl; + return; + } + output << mesh.entity_key(entity); + output << ", Connectivity: "; + const stk::mesh::EntityRank end_rank = static_cast(mesh.mesh_meta_data().entity_rank_count()); + for (stk::mesh::EntityRank r = stk::topology::BEGIN_RANK; r < end_rank; ++r) { + unsigned num_rels = mesh.num_connectivity(entity, r); + stk::mesh::Entity const *rel_entities = mesh.begin(entity, r); + stk::mesh::ConnectivityOrdinal const *rel_ordinals = mesh.begin_ordinals(entity, r); + for (unsigned i = 0; i < num_rels; ++i) { + output << " " << mesh.entity_key(rel_entities[i]) + << " @" << rel_ordinals[i] << " "; + } + } + output << ", Parts: "; + for(auto&& part : mesh.bucket(entity).supersets()) + { + output << part->name() << " "; + } +} + +std::string +debug_entity_1line(const stk::mesh::BulkData & mesh, stk::mesh::Entity entity) +{ + std::ostringstream out; + debug_entity_1line(out, mesh, entity); + return out.str(); +} + //-------------------------------------------------------------------------------- std::vector @@ -1100,6 +1167,14 @@ attach_sides_to_elements(stk::mesh::BulkData & mesh) mesh.modification_end(); } +static bool is_entity_attached_to_element(const stk::mesh::BulkData & mesh, const stk::mesh::EntityRank entityRank, const stk::mesh::Entity entity, const stk::mesh::Entity element) +{ + for (auto && elemEntity : StkMeshEntities{mesh.begin(element, entityRank), mesh.end(element, entityRank)}) + if (elemEntity == entity) + return true; + return false; +} + void attach_entity_to_elements(stk::mesh::BulkData & mesh, stk::mesh::Entity entity) { @@ -1134,18 +1209,8 @@ attach_entity_to_elements(stk::mesh::BulkData & mesh, stk::mesh::Entity entity) { continue; } - bool already_attached = false; - const unsigned num_elem_entities = mesh.num_connectivity(elem, entity_rank); - const stk::mesh::Entity* elem_entities = mesh.begin(elem, entity_rank); - for (unsigned it_s=0; it_s relationship(stk::mesh::INVALID_CONNECTIVITY_ORDINAL, stk::mesh::INVALID_PERMUTATION); if (!have_coincident_shell) @@ -1174,17 +1239,18 @@ attach_entity_to_elements(stk::mesh::BulkData & mesh, stk::mesh::Entity entity) } mesh.declare_relation( elem, entity, relationship.first, relationship.second, scratch1, scratch2, scratch3 ); - const bool successfully_attached = (find_entity_by_ordinal(mesh, elem, entity_rank, relationship.first) == entity); - if (!successfully_attached) + const bool successfullyAttached = (find_entity_by_ordinal(mesh, elem, entity_rank, relationship.first) == entity); + if (!successfullyAttached) { - krinolog << "Could not attach " << debug_entity(mesh,entity) << " to element " << debug_entity(mesh,elem) << stk::diag::dendl; - krinolog << "Existing attached entities:" << stk::diag::dendl; - for (unsigned it_s=0; it_s()) { const auto * stored_parent_ids = field_data(parent_id_field, edge_node_entity); - ThrowAssertMsg(stored_parent_ids, "No SubElementNode found for node " << mesh.identifier(edge_node_entity) + ThrowRequireMsg(stored_parent_ids, "No SubElementNode found for node " << mesh.identifier(edge_node_entity) << ", but it does not have the parent_ids field suggesting it is a mesh node."); parent_ids[0] = stored_parent_ids[0]; parent_ids[1] = stored_parent_ids[1]; @@ -1469,7 +1535,7 @@ get_edge_node_parent_ids(const stk::mesh::BulkData & mesh, else if (parent_id_field.type_is()) { const auto * stored_parent_ids = field_data(parent_id_field, edge_node_entity); - ThrowAssertMsg(stored_parent_ids, "No SubElementNode found for node " << mesh.identifier(edge_node_entity) + ThrowRequireMsg(stored_parent_ids, "No SubElementNode found for node " << mesh.identifier(edge_node_entity) << ", but it does not have the parent_ids field suggesting it is a mesh node."); parent_ids[0] = stored_parent_ids[0]; parent_ids[1] = stored_parent_ids[1]; @@ -1492,7 +1558,7 @@ void get_parent_nodes_from_child(const stk::mesh::BulkData & mesh, auto parent_ids = get_edge_node_parent_ids(mesh, parent_id_field, child); const stk::mesh::Entity parent0 = mesh.get_entity(stk::topology::NODE_RANK, parent_ids[0]); const stk::mesh::Entity parent1 = mesh.get_entity(stk::topology::NODE_RANK, parent_ids[1]); - ThrowAssert(mesh.is_valid(parent0) && mesh.is_valid(parent1)); + ThrowRequire(mesh.is_valid(parent0) && mesh.is_valid(parent1)); get_parent_nodes_from_child(mesh, parent0, parent_id_field, parent_nodes); get_parent_nodes_from_child(mesh, parent1, parent_id_field, parent_nodes); } diff --git a/packages/krino/krino/krino_lib/Akri_MeshHelpers.hpp b/packages/krino/krino/krino_lib/Akri_MeshHelpers.hpp index ccbb8a339e68..8f14d1210195 100644 --- a/packages/krino/krino/krino_lib/Akri_MeshHelpers.hpp +++ b/packages/krino/krino/krino_lib/Akri_MeshHelpers.hpp @@ -44,6 +44,10 @@ struct StkMeshEntities value_type operator[](int i) const { return *(mBegin + i); } }; +double compute_tri_volume(const std::array & elementNodeCoords); +double compute_tet_volume(const std::array & elementNodeCoords); +double compute_tri_or_tet_volume(const std::vector & elementNodeCoords); +Vector3d get_side_normal(const stk::mesh::BulkData& mesh, const FieldRef coordsField, stk::mesh::Entity side); void fill_element_node_coordinates(const stk::mesh::BulkData & mesh, stk::mesh::Entity element, const FieldRef coordsField, std::vector & elementNodeCoords); void fill_procs_owning_or_sharing_or_ghosting_node(const stk::mesh::BulkData& bulkData, stk::mesh::Entity node, std::vector & procsOwningSharingOrGhostingNode); double compute_maximum_element_size(stk::mesh::BulkData& mesh); @@ -79,6 +83,14 @@ stk::mesh::PartVector get_common_io_parts(const stk::mesh::BulkData & mesh, cons stk::mesh::PartVector get_removable_parts(const stk::mesh::BulkData & mesh, const stk::mesh::Bucket & bucket); stk::mesh::PartVector get_removable_parts(const stk::mesh::BulkData & mesh, const stk::mesh::Entity entity); +template +void fill_node_locations(const int dim, const FieldRef coordsField, const NODECONTAINER & nodes, std::vector & nodeLocations) +{ + nodeLocations.clear(); + for (auto node : nodes) + nodeLocations.emplace_back(field_data(coordsField, node), dim); +} + void store_edge_node_parent_ids(const stk::mesh::BulkData & mesh, const FieldRef & parent_id_field, @@ -104,6 +116,7 @@ const unsigned * get_edge_node_ordinals(stk::topology topology, unsigned edge_or std::string debug_entity(const stk::mesh::BulkData & mesh, stk::mesh::Entity entity); std::string debug_entity(const stk::mesh::BulkData & mesh, stk::mesh::Entity entity, const bool includeFields); +std::string debug_entity_1line(const stk::mesh::BulkData & mesh, stk::mesh::Entity entity); struct SideRequest { diff --git a/packages/krino/krino/krino_lib/Akri_Phase_Support.cpp b/packages/krino/krino/krino_lib/Akri_Phase_Support.cpp index 8b46db9b7303..1c6224baefc4 100644 --- a/packages/krino/krino/krino_lib/Akri_Phase_Support.cpp +++ b/packages/krino/krino/krino_lib/Akri_Phase_Support.cpp @@ -73,9 +73,7 @@ std::string build_part_name(const krino::Phase_Support & ps, std::string parent_part_name = parent_part->name(); std::string io_part_name = part.name(); - std::set touching_block_ordinals; - ps.get_input_blocks_touching_surface(ps.get_input_block_surface_connectivity(), part.mesh_meta_data_ordinal(), touching_block_ordinals); - + const std::set touching_block_ordinals = ps.get_input_block_surface_connectivity().get_blocks_touching_surface(part.mesh_meta_data_ordinal()); ThrowRequireMsg(touching_block_ordinals.size() > 0, "krino::Akri_Phase_Support: Side block must be touching at least 1 block"); @@ -262,12 +260,12 @@ Phase_Support::addPhasePart(stk::mesh::Part & io_part, PhasePartSet & phase_part } void -Phase_Support::create_nonconformal_parts(const PartSet & decomposed_ioparts) +Phase_Support::create_nonconformal_parts(const PartSet & decomposedIoParts) { const std::string nonconformal_part_suffix = "_nonconformal"; - for(PartSet::const_iterator it = decomposed_ioparts.begin(); it != decomposed_ioparts.end(); ++it) + for(auto && decomposedIoPart : decomposedIoParts) { - const stk::mesh::Part & iopart = aux_meta().get_part((*it)->name()); + const stk::mesh::Part & iopart = aux_meta().get_part(decomposedIoPart->name()); std::string nonconformal_part_name = build_part_name(*this, iopart, PhaseTag(), nonconformal_part_suffix); @@ -350,8 +348,7 @@ Phase_Support::get_blocks_and_touching_surfaces(const stk::mesh::MetaData & mesh for (auto && block_ptr : input_blocks) { blocks_and_touching_sides.insert(block_ptr); - std::set touching_surface_ordinals; - get_input_surfaces_touching_block(input_block_surface_info, block_ptr->mesh_meta_data_ordinal(), touching_surface_ordinals); + const std::set touching_surface_ordinals = input_block_surface_info.get_surfaces_touching_block(block_ptr->mesh_meta_data_ordinal()); for (auto && surf_ordinal : touching_surface_ordinals) { stk::mesh::Part & surf_part = mesh_meta.get_part(surf_ordinal); @@ -387,8 +384,6 @@ void Phase_Support::subset_and_alias_surface_phase_parts(const PhaseVec& ls_phases, const PartSet& decomposed_ioparts) { - std::set touching_block_ordinals; - for (auto && io_part : decomposed_ioparts) { if (!(io_part->subsets().empty())) @@ -403,10 +398,8 @@ Phase_Support::subset_and_alias_surface_phase_parts(const PhaseVec& ls_phases, stk::mesh::Part * nonconformal_iopart = const_cast(find_nonconformal_part(*io_part)); ThrowRequire(NULL != nonconformal_iopart); - for (stk::mesh::PartVector::const_iterator subset = io_part->subsets().begin(); - subset != io_part->subsets().end(); ++subset) + for (auto && io_part_subset : io_part->subsets()) { - stk::mesh::Part * io_part_subset = *subset; ThrowRequire(NULL != io_part_subset); addPhasePart(*io_part_subset, my_phase_parts, ls_phase_entry); @@ -422,7 +415,7 @@ Phase_Support::subset_and_alias_surface_phase_parts(const PhaseVec& ls_phases, if(krinolog.shouldPrint(LOG_PARTS)) krinolog << "Adding " << nonconformal_iopart_subset->name() << " as subset of " << nonconformal_iopart->name() << stk::diag::dendl; meta().declare_part_subset(*nonconformal_iopart, *nonconformal_iopart_subset); - get_input_blocks_touching_surface(my_input_block_surface_connectivity, io_part_subset->mesh_meta_data_ordinal(), touching_block_ordinals); + const std::set touching_block_ordinals = my_input_block_surface_connectivity.get_blocks_touching_surface(io_part_subset->mesh_meta_data_ordinal()); for (auto && touching_block_ordinal : touching_block_ordinals) { const std::string conformal_part_alias = conformal_iopart->name() + "_" + meta().get_part(touching_block_ordinal).name(); @@ -435,49 +428,80 @@ Phase_Support::subset_and_alias_surface_phase_parts(const PhaseVec& ls_phases, } } +void +Phase_Support::update_touching_parts_for_phase_part(const stk::mesh::Part & origPart, const stk::mesh::Part & phasePart, const PhaseTag & phase) +{ + const std::set & origTouchingBlockOrdinals = my_input_block_surface_connectivity.get_blocks_touching_surface(origPart.mesh_meta_data_ordinal()); + + std::vector phaseTouchingBlocks = meta().get_blocks_touching_surface(&phasePart); + + for (auto && origTouchingBlockOrdinal : origTouchingBlockOrdinals) + { + stk::mesh::Part & origTouchingBlock = meta().get_part(origTouchingBlockOrdinal); + const stk::mesh::Part * phaseTouchingBlock = (phase.empty()) ? find_nonconformal_part(origTouchingBlock) : find_conformal_io_part(origTouchingBlock, phase); + ThrowRequire(phaseTouchingBlock); + + if (std::find(phaseTouchingBlocks.begin(), phaseTouchingBlocks.end(), phaseTouchingBlock) == phaseTouchingBlocks.end()) + phaseTouchingBlocks.push_back(phaseTouchingBlock); + } + + if(krinolog.shouldPrint(LOG_PARTS)) + { + const std::string conformingType = (phase.empty()) ? "Nonconforming" : "Conforming"; + krinolog << conformingType << " surface " << phasePart.name() << " touches blocks "; + for (auto && phaseTouchingBlock : phaseTouchingBlocks) + krinolog << phaseTouchingBlock->name() << " "; + krinolog << "\n"; + } + + meta().set_surface_to_block_mapping(&phasePart, phaseTouchingBlocks); +} + void Phase_Support::build_decomposed_block_surface_connectivity() { + std::set> nonconformingAndOriginalPartOrdinalPairs; + for (auto && part : meta().get_mesh_parts()) { if (part->primary_entity_rank() != meta().side_rank()) continue; const PhasePartTag * phase_part = find_conformal_phase_part(*part); if (nullptr == phase_part) continue; - stk::mesh::Part & orig_part = meta().get_part(phase_part->get_original_part_ordinal()); - if (orig_part == meta().universal_part()) continue; + stk::mesh::Part & origPart = meta().get_part(phase_part->get_original_part_ordinal()); + if (origPart == meta().universal_part()) continue; if (phase_part->is_interface()) { - const stk::mesh::Part * conformal_touching_block = find_conformal_io_part(orig_part, phase_part->get_touching_phase()); + const stk::mesh::Part * conformal_touching_block = find_conformal_io_part(origPart, phase_part->get_touching_phase()); ThrowRequire(conformal_touching_block); - if(krinolog.shouldPrint(LOG_PARTS)) krinolog << "Surface " << part->name() << " touches block " << conformal_touching_block->name() << "\n"; + if(krinolog.shouldPrint(LOG_PARTS)) krinolog << "Interface surface " << part->name() << " touches block " << conformal_touching_block->name() << "\n"; std::vector touching_blocks = meta().get_blocks_touching_surface(part); if (std::find(touching_blocks.begin(), touching_blocks.end(), conformal_touching_block) == touching_blocks.end()) - { touching_blocks.push_back(conformal_touching_block); - } + meta().set_surface_to_block_mapping(part, touching_blocks); } else { - std::set touching_block_ordinals; - get_input_blocks_touching_surface(my_input_block_surface_connectivity, orig_part.mesh_meta_data_ordinal(), touching_block_ordinals); + update_touching_parts_for_phase_part(origPart, *part, phase_part->get_phase()); - for (auto && touching_block_ordinal : touching_block_ordinals) - { - stk::mesh::Part & touching_block = meta().get_part(touching_block_ordinal); - const stk::mesh::Part * conformal_touching_block = find_conformal_io_part(touching_block, phase_part->get_phase()); - ThrowRequire(conformal_touching_block); - if(krinolog.shouldPrint(LOG_PARTS)) krinolog << "Surface " << part->name() << " touches block " << conformal_touching_block->name() << "\n"; - std::vector touching_blocks = meta().get_blocks_touching_surface(part); - if (std::find(touching_blocks.begin(), touching_blocks.end(), conformal_touching_block) == touching_blocks.end()) - { - touching_blocks.push_back(conformal_touching_block); - } - meta().set_surface_to_block_mapping(part, touching_blocks); - } + // store off nonconforming and original parts for second pass below + nonconformingAndOriginalPartOrdinalPairs.emplace(phase_part->get_nonconformal_part_ordinal(), phase_part->get_original_part_ordinal()); } } + + const PhaseTag emptyPhaseToIndicateNoncoformingPart; + for (auto && nonconformingAndOriginalPartOrdinalPair : nonconformingAndOriginalPartOrdinalPairs) + { + const stk::mesh::Part & nonconformingPart = meta().get_part(nonconformingAndOriginalPartOrdinalPair.first); + const stk::mesh::Part & origPart = meta().get_part(nonconformingAndOriginalPartOrdinalPair.second); + update_touching_parts_for_phase_part(origPart, nonconformingPart, emptyPhaseToIndicateNoncoformingPart); + } + + if(krinolog.shouldPrint(LOG_PARTS)) + { + Block_Surface_Connectivity::dump_surface_connectivity(meta()); + } } void @@ -599,7 +623,7 @@ Phase_Support::decompose_blocks(std::vector(ls_set); PhaseVec & ls_phases = std::get<2>(ls_set); - if(std::get<2>(ls_set).empty()) continue; + if(ls_phases.empty()) continue; part_set_vec[i] = get_blocks_and_touching_surfaces(meta(), blocks_to_decompose, my_input_block_surface_connectivity); create_nonconformal_parts(part_set_vec[i]); @@ -881,21 +905,6 @@ Phase_Support::get_blocks_touching_surface(const std::string & surface_name, std } } //-------------------------------------------------------------------------------- -void -Phase_Support::get_input_surfaces_touching_block(const Block_Surface_Connectivity & input_block_surface_connectivity, - const stk::mesh::PartOrdinal block_ordinal, std::set & surface_ordinals) -{ - input_block_surface_connectivity.get_surfaces_touching_block(block_ordinal, surface_ordinals); -} - -//-------------------------------------------------------------------------------- -void -Phase_Support::get_input_blocks_touching_surface(const Block_Surface_Connectivity & input_block_surface_connectivity, - const stk::mesh::PartOrdinal surfaceOrdinal, std::set & blockOrdinals) const -{ - input_block_surface_connectivity.get_blocks_touching_surface(surfaceOrdinal, blockOrdinals); -} -//-------------------------------------------------------------------------------- const stk::mesh::Part * Phase_Support::find_conformal_io_part(const stk::mesh::Part & io_part, const PhaseTag & phase) const { @@ -1006,8 +1015,7 @@ void Phase_Support::register_blocks_for_level_set(const Surface_Identifier level lsUsedByParts_[levelSetIdentifier].insert(block_ptr); // Now get surfaces touching this block - std::set surfaceOrdinals; - get_input_surfaces_touching_block(my_input_block_surface_connectivity, block_ptr->mesh_meta_data_ordinal(), surfaceOrdinals); + const std::set surfaceOrdinals = my_input_block_surface_connectivity.get_surfaces_touching_block(block_ptr->mesh_meta_data_ordinal()); for (auto && surfaceOrdinal : surfaceOrdinals) { // For each surface, add IO Part/Level Set pairing to maps diff --git a/packages/krino/krino/krino_lib/Akri_Phase_Support.hpp b/packages/krino/krino/krino_lib/Akri_Phase_Support.hpp index a0f55b452646..a28b0aead806 100644 --- a/packages/krino/krino/krino_lib/Akri_Phase_Support.hpp +++ b/packages/krino/krino/krino_lib/Akri_Phase_Support.hpp @@ -56,6 +56,9 @@ struct LS_Field class Phase_Support { public: + Phase_Support (const Phase_Support&) = delete; + Phase_Support& operator= (const Phase_Support&) = delete; + typedef std::set PartSet; static bool exists_and_has_phases_defined(const stk::mesh::MetaData & meta); @@ -69,10 +72,6 @@ class Phase_Support { static void check_isovariable_field_existence_on_decomposed_blocks(const stk::mesh::MetaData & meta, const std::vector & lsFields, const bool conformal_parts_require_field); void get_blocks_touching_surface(const std::string & surface_name, std::vector & block_names); - static void get_input_surfaces_touching_block(const Block_Surface_Connectivity & input_block_surface_info, - const stk::mesh::PartOrdinal block_ordinal, std::set & surface_ordinals); - void get_input_blocks_touching_surface(const Block_Surface_Connectivity & input_block_surface_info, - const stk::mesh::PartOrdinal surfaceOrdinal, std::set & blockOrdinals) const; void check_phase_parts() const; @@ -131,6 +130,7 @@ class Phase_Support { const AuxMetaData & aux_meta() const { ThrowAssertMsg(myAuxMeta, "AuxMetaData not yet set on Phase_Support"); return *myAuxMeta; } AuxMetaData & aux_meta() { ThrowAssertMsg(myAuxMeta, "AuxMetaData not yet set on Phase_Support"); return *myAuxMeta; } + void update_touching_parts_for_phase_part(const stk::mesh::Part & origPart, const stk::mesh::Part & phasePart, const PhaseTag & phase); const PhasePartTag * find_conformal_phase_part(const stk::mesh::Part & conformal_part) const; void create_nonconformal_parts(const PartSet & decomposed_ioparts); void addPhasePart(stk::mesh::Part & io_part, PhasePartSet & phase_parts, const NamedPhase & ls_phase); diff --git a/packages/krino/krino/krino_lib/Akri_SharpFeature.cpp b/packages/krino/krino/krino_lib/Akri_SharpFeature.cpp new file mode 100644 index 000000000000..fd41503ed1ae --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_SharpFeature.cpp @@ -0,0 +1,352 @@ +#include "Akri_SharpFeature.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include "Akri_AuxMetaData.hpp" +#include "Akri_CDMesh_Utils.hpp" +#include "Akri_Phase_Support.hpp" + +namespace krino { + +uint64_t edge_from_edge_node_offsets(stk::mesh::Entity::entity_value_type edgeNodeOffset0, stk::mesh::Entity::entity_value_type edgeNodeOffset1) +{ + static_assert(std::is_same::value, "stk::mesh::Entity must be 32 bit."); + return (static_cast(edgeNodeOffset1) << 32) + edgeNodeOffset0; +} + +uint64_t edge_from_edge_nodes(const stk::mesh::BulkData & mesh, stk::mesh::Entity edgeNode0, stk::mesh::Entity edgeNode1) +{ + return (mesh.identifier(edgeNode0) < mesh.identifier(edgeNode1)) ? + edge_from_edge_node_offsets(edgeNode0.local_offset(), edgeNode1.local_offset()) : + edge_from_edge_node_offsets(edgeNode1.local_offset(), edgeNode0.local_offset()); +} + +std::array get_edge_nodes(uint64_t edge) +{ + static_assert(std::is_same::value, "stk::mesh::Entity must be 32 bit."); + return std::array{stk::mesh::Entity(edge & 0xFFFFFFFF), stk::mesh::Entity(edge >> 32)}; +} + +std::array get_tet_edges(const stk::mesh::BulkData & mesh, const stk::mesh::Entity element) +{ + StkMeshEntities elementNodes{mesh.begin_nodes(element), mesh.end_nodes(element)}; + return { edge_from_edge_nodes(mesh, elementNodes[0],elementNodes[1]), + edge_from_edge_nodes(mesh, elementNodes[1],elementNodes[2]), + edge_from_edge_nodes(mesh, elementNodes[2],elementNodes[0]), + edge_from_edge_nodes(mesh, elementNodes[3],elementNodes[0]), + edge_from_edge_nodes(mesh, elementNodes[3],elementNodes[1]), + edge_from_edge_nodes(mesh, elementNodes[3],elementNodes[2]) }; +} + +std::array get_tri_edges(const stk::mesh::BulkData & mesh, const stk::mesh::Entity element) +{ + StkMeshEntities elementNodes{mesh.begin_nodes(element), mesh.end_nodes(element)}; + return { edge_from_edge_nodes(mesh, elementNodes[0],elementNodes[1]), + edge_from_edge_nodes(mesh, elementNodes[1],elementNodes[2]), + edge_from_edge_nodes(mesh, elementNodes[2],elementNodes[0]) }; +} + +uint64_t get_segment_edge(const stk::mesh::BulkData & mesh, const stk::mesh::Entity element) +{ + StkMeshEntities elementNodes{mesh.begin_nodes(element), mesh.end_nodes(element)}; + return edge_from_edge_nodes(mesh, elementNodes[0],elementNodes[1]); +} + +void fill_element_edges(const stk::mesh::BulkData & mesh, const unsigned dim, const stk::mesh::Entity element, std::vector & elementEdges) +{ + if (dim == 2) + { + const std::array triEdges = get_tri_edges(mesh, element); + elementEdges.assign(triEdges.begin(), triEdges.end()); + return; + } + + const std::array tetEdges = get_tet_edges(mesh, element); + elementEdges.assign(tetEdges.begin(), tetEdges.end()); +} + +void fill_face_edges(const stk::mesh::BulkData & mesh, const stk::mesh::Entity face, std::vector & sideEdges) +{ + const std::array triEdges = get_tri_edges(mesh, face); + sideEdges.assign(triEdges.begin(), triEdges.end()); +} + +int get_edge_owner(const stk::mesh::BulkData & mesh, const uint64_t edge) +{ + const std::array & edgeNodes = get_edge_nodes(edge); + return std::min(mesh.parallel_owner_rank(edgeNodes[0]), mesh.parallel_owner_rank(edgeNodes[1])); +} + +std::vector get_owned_edges(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector) +{ + std::vector edges; + std::vector elementEdges; + const unsigned dim = mesh.mesh_meta_data().spatial_dimension(); + for(const auto & bucketPtr : mesh.get_buckets(stk::topology::ELEMENT_RANK, elementSelector)) + { + for(const auto & elem : *bucketPtr) + { + fill_element_edges(mesh, dim, elem, elementEdges); + for (auto edge : elementEdges) + if (get_edge_owner(mesh, edge) == mesh.parallel_rank()) + edges.push_back(edge); + } + } + + stk::util::sort_and_unique(edges); + + return edges; +} + +static bool does_entity_have_selected_element(const stk::mesh::BulkData & mesh, const stk::mesh::Entity entity, const stk::mesh::Selector & elementSelector) +{ + for (auto && elem : StkMeshEntities{mesh.begin_elements(entity), mesh.end_elements(entity)}) + if (elementSelector(mesh.bucket(elem))) + return true; + return false; +} + +static stk::mesh::Selector +build_side_selector(const stk::mesh::BulkData & mesh) +{ + const AuxMetaData & auxMeta = AuxMetaData::get(mesh.mesh_meta_data()); + const Phase_Support & phaseSupport = Phase_Support::get(mesh.mesh_meta_data()); + const stk::mesh::EntityRank sideRank = mesh.mesh_meta_data().side_rank(); + + stk::mesh::PartVector sideParts; + for (auto && part : mesh.mesh_meta_data().get_parts()) + if (is_part_to_check_for_snapping_compatibility(phaseSupport, auxMeta, sideRank, *part)) + sideParts.push_back(part); + + return stk::mesh::selectUnion(sideParts); +} + +bool edge_has_owned_node(const stk::mesh::BulkData & mesh, const uint64_t edge) +{ + const std::array & edgeNodes = get_edge_nodes(edge); + return mesh.parallel_rank() == mesh.parallel_owner_rank(edgeNodes[0]) || mesh.parallel_rank() == mesh.parallel_owner_rank(edgeNodes[1]); +} + +std::vector get_edges_with_owned_nodes_of_selected_faces(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector) +{ + std::vector edges; + std::vector sideEdges; + for(const auto & bucketPtr : mesh.buckets(stk::topology::FACE_RANK)) + { + if (sideSelector(*bucketPtr)) + { + for(const auto & side : *bucketPtr) + { + if (does_entity_have_selected_element(mesh, side, elementSelector)) + { + fill_face_edges(mesh, side, sideEdges); + for (auto edge : sideEdges) + if (edge_has_owned_node(mesh, edge)) + edges.push_back(edge); + } + } + } + } + + stk::util::sort_and_unique(edges); + + return edges; +} + +std::vector get_owned_nodes_of_edges_with_selected_sides(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector) +{ + std::vector edgeNodes; + for(const auto & bucketPtr : mesh.buckets(stk::topology::NODE_RANK)) + if (bucketPtr->owned() && sideSelector(*bucketPtr)) + for(const auto & node : *bucketPtr) + if (does_entity_have_selected_element(mesh, node, elementSelector)) + edgeNodes.push_back(node); + return edgeNodes; +} + +bool is_intersection_point_node_compatible_for_snapping_based_on_sharp_features(const SharpFeatureInfo & sharpFeatureInfo, const stk::mesh::Entity intPtNode, const std::vector & intPtNodes) +{ + const SharpFeatureConstraint * constraint = sharpFeatureInfo.get_constraint(intPtNode); + + if (constraint == nullptr) return true; + if (intPtNodes.size() != 2 || constraint->is_pinned()) return false; + + const std::array sharpEdgeNodes = constraint->get_sharp_edge_nodes(); + for (auto && sharpEdgeNode : sharpEdgeNodes) + if (intPtNodes[0] == sharpEdgeNode || intPtNodes[1] == sharpEdgeNode) + return true; + return false; +} + +void SharpFeatureInfo::find_sharp_features(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const double cosFeatureAngle) +{ + const stk::mesh::Selector sideSelector = build_side_selector(mesh); + + if (mesh.mesh_meta_data().spatial_dimension() == 2) + find_sharp_features_2D(mesh, coordsField, elementSelector, sideSelector, cosFeatureAngle); + find_sharp_features_3D(mesh, coordsField, elementSelector, sideSelector, cosFeatureAngle); + + if (krinolog.shouldPrint(LOG_DEBUG)) + { + for (auto && entry : myNodeToConstrainedNeighbors) + { + stk::mesh::Entity node = entry.first; + const SharpFeatureConstraint & constraint = entry.second; + krinolog << "Node " << mesh.identifier(node) << " is "; + if (constraint.is_pinned()) + { + krinolog << "pinned." << stk::diag::dendl; + } + else + { + const std::array sharpEdgeNbrs = constraint.get_sharp_edge_nodes(); + krinolog << "constrained to move along edge between nodes " << mesh.identifier(sharpEdgeNbrs[0]) << " and " << mesh.identifier(sharpEdgeNbrs[1]) << "." << std::endl; + } + } + } +} + +void SharpFeatureInfo::find_sharp_features_3D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle) +{ + std::map> nodeToSharpEdgeNeighbors; + const int parallelRank = mesh.parallel_rank(); + + const std::vector edgesWithOwnedNodes = get_edges_with_owned_nodes_of_selected_faces(mesh, elementSelector, sideSelector); + for (auto edge : edgesWithOwnedNodes) + { + if (edge_has_sharp_feature_3D(mesh, coordsField, elementSelector, sideSelector, cosFeatureAngle, edge)) + { + const std::array & edgeNodes = get_edge_nodes(edge); + if (parallelRank == mesh.parallel_owner_rank(edgeNodes[0])) + nodeToSharpEdgeNeighbors[edgeNodes[0]].push_back(edgeNodes[1]); + if (parallelRank == mesh.parallel_owner_rank(edgeNodes[1])) + nodeToSharpEdgeNeighbors[edgeNodes[1]].push_back(edgeNodes[0]); + } + } + + for (auto && entry : nodeToSharpEdgeNeighbors) + if (entry.second.size() == 2) + myNodeToConstrainedNeighbors.insert({entry.first, SharpFeatureConstraint::edge_constraint(entry.second[0], entry.second[1])}); + else if (entry.second.size() > 2) + myNodeToConstrainedNeighbors.insert({entry.first, SharpFeatureConstraint::pinned_constraint()}); +} + +void SharpFeatureInfo::find_sharp_features_2D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle) +{ + std::map> nodeToSharpEdgeNeighbors; + + const std::vector ownedSideNodes = get_owned_nodes_of_edges_with_selected_sides(mesh, elementSelector, sideSelector); + for (auto node : ownedSideNodes) + if (node_has_sharp_feature_2D(mesh, coordsField, elementSelector, sideSelector, cosFeatureAngle, node)) + myNodeToConstrainedNeighbors.insert({node, SharpFeatureConstraint::pinned_constraint()}); +} + +void filter_sides_based_on_attached_element_and_side_parts(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, std::vector & sides) +{ + size_t numRetainedSides = 0; + for (auto && side : sides) + if (sideSelector(mesh.bucket(side)) && does_entity_have_selected_element(mesh, side, elementSelector)) + sides[numRetainedSides++] = side; + sides.resize(numRetainedSides); +} + +bool SharpFeatureInfo::angle_is_sharp_between_any_two_sides_2D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const double cosFeatureAngle, const stk::mesh::Entity node, const std::vector & sidesOfEdge) +{ + if (sidesOfEdge.size() > 1) + { + const Vector3d nodeCoords(field_data(coordsField, node),2); + + std::vector sideVec; + sideVec.reserve(sidesOfEdge.size()); + for (auto && side : sidesOfEdge) + { + StkMeshEntities sideNodes{mesh.begin_nodes(side), mesh.end_nodes(side)}; + ThrowAssertMsg(sideNodes[0] == node || sideNodes[1] == node, "Did not find side node for segment."); + const stk::mesh::Entity sideNode = (sideNodes[1] == node) ? sideNodes[0] : sideNodes[1]; + const Vector3d coordsOfSideNode(field_data(coordsField, sideNode),2); + sideVec.push_back((coordsOfSideNode - nodeCoords).unit_vector()); + } + + for (size_t i=0; i cosFeatureAngle) + return true; + } + return false; +} + +double cosine_of_dihedral_angle_3D(const Vector3d & edgeVec, const Vector3d & faceTangent0, const Vector3d & faceTangent1) +{ + // https://en.wikipedia.org/wiki/Dihedral_angle + const Vector3d crossEdgeFace0 = Cross(edgeVec, faceTangent0); + const Vector3d crossEdgeFace1 = Cross(edgeVec, faceTangent1); + + return Dot(crossEdgeFace0,crossEdgeFace1) / (crossEdgeFace0.length()*crossEdgeFace1.length()); +} + +stk::mesh::Entity get_face_node_not_on_edge(const stk::mesh::BulkData & mesh, const std::array & edgeNodes, const stk::mesh::Entity sideOfEdge) +{ + StkMeshEntities faceNodes{mesh.begin_nodes(sideOfEdge), mesh.end_nodes(sideOfEdge)}; + for (auto && faceNode : faceNodes) + if (faceNode != edgeNodes[0] && faceNode != edgeNodes[1]) + return faceNode; + ThrowRuntimeError("Did not find face node not on edge."); +} + +bool SharpFeatureInfo::angle_is_sharp_between_any_two_sides_3D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const double cosFeatureAngle, const std::array & edgeNodes, const std::vector & sidesOfEdge) +{ + if (sidesOfEdge.size() > 1) + { + const Vector3d edgeNodeCoords0(field_data(coordsField, edgeNodes[0])); + const Vector3d edgeNodeCoords1(field_data(coordsField, edgeNodes[1])); + const Vector3d edgeVec = edgeNodeCoords1 - edgeNodeCoords0; + + std::vector faceTangent; + faceTangent.reserve(sidesOfEdge.size()); + for (auto && side : sidesOfEdge) + { + const Vector3d coordsOfNonEdgeNodeOfSide(field_data(coordsField, get_face_node_not_on_edge(mesh, edgeNodes, side))); + faceTangent.push_back(coordsOfNonEdgeNodeOfSide - edgeNodeCoords0); + } + + for (size_t i=0; i cosFeatureAngle) + return true; + } + return false; +} + +bool SharpFeatureInfo::edge_has_sharp_feature_3D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle, const uint64_t edge) +{ + const std::array & edgeNodes = get_edge_nodes(edge); + std::vector sidesOfEdge; + stk::mesh::get_entities_through_relations(mesh, {edgeNodes[0], edgeNodes[1]}, stk::topology::FACE_RANK, sidesOfEdge); + if (sidesOfEdge.size() > 1) + filter_sides_based_on_attached_element_and_side_parts(mesh, elementSelector, sideSelector, sidesOfEdge); + return angle_is_sharp_between_any_two_sides_3D(mesh, coordsField, cosFeatureAngle, edgeNodes, sidesOfEdge); +} + +bool SharpFeatureInfo::node_has_sharp_feature_2D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle, const stk::mesh::Entity node) +{ + std::vector sidesOfEdge(mesh.begin_edges(node), mesh.end_edges(node)); + if (sidesOfEdge.size() > 1) + filter_sides_based_on_attached_element_and_side_parts(mesh, elementSelector, sideSelector, sidesOfEdge); + return angle_is_sharp_between_any_two_sides_2D(mesh, coordsField, cosFeatureAngle, node, sidesOfEdge); +} + +const SharpFeatureConstraint * SharpFeatureInfo::get_constraint(const stk::mesh::Entity node) const +{ + const auto iter = myNodeToConstrainedNeighbors.find(node); + if (iter != myNodeToConstrainedNeighbors.end()) + return &(iter->second); + return nullptr; +} + +} diff --git a/packages/krino/krino/krino_lib/Akri_SharpFeature.hpp b/packages/krino/krino/krino_lib/Akri_SharpFeature.hpp new file mode 100644 index 000000000000..948c44ccb7ef --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_SharpFeature.hpp @@ -0,0 +1,44 @@ +#ifndef KRINO_KRINO_KRINO_LIB_AKRI_SHARPFEATURE_HPP_ +#define KRINO_KRINO_KRINO_LIB_AKRI_SHARPFEATURE_HPP_ +#include +#include +#include +#include +#include "Akri_FieldRef.hpp" + +namespace krino { + +class SharpFeatureConstraint +{ +public: + bool is_pinned() const { return myConstrainedEdgeNeighbors[0] == invalid_entity() && myConstrainedEdgeNeighbors[1] == invalid_entity(); } + bool is_constrained_on_edge() const { return myConstrainedEdgeNeighbors[0] != invalid_entity() && myConstrainedEdgeNeighbors[1] != invalid_entity(); } + const std::array & get_sharp_edge_nodes() const { ThrowAssert(is_constrained_on_edge()); return myConstrainedEdgeNeighbors; } + static SharpFeatureConstraint edge_constraint(const stk::mesh::Entity entity0, const stk::mesh::Entity entity1) { return SharpFeatureConstraint{entity0, entity1}; } + static SharpFeatureConstraint pinned_constraint() { return SharpFeatureConstraint(invalid_entity(),invalid_entity()); } +private: + static stk::mesh::Entity invalid_entity() { static const stk::mesh::Entity invalidEntity; return invalidEntity; } + SharpFeatureConstraint(const stk::mesh::Entity entity0, const stk::mesh::Entity entity1) : myConstrainedEdgeNeighbors{entity0, entity1} {} + std::array myConstrainedEdgeNeighbors; +}; + +class SharpFeatureInfo +{ +public: + void find_sharp_features(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const double cosFeatureAngle); + const SharpFeatureConstraint * get_constraint(const stk::mesh::Entity node) const; +private: + void find_sharp_features_2D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle); + void find_sharp_features_3D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle); + static bool edge_has_sharp_feature_3D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle, const uint64_t edge); + static bool node_has_sharp_feature_2D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const stk::mesh::Selector & elementSelector, const stk::mesh::Selector & sideSelector, const double cosFeatureAngle, const stk::mesh::Entity node ); + static bool angle_is_sharp_between_any_two_sides_3D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const double cosFeatureAngle, const std::array & edgeNodes, const std::vector & sidesOfEdge); + static bool angle_is_sharp_between_any_two_sides_2D(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const double cosFeatureAngle, const stk::mesh::Entity node, const std::vector & sidesOfEdge); + std::map myNodeToConstrainedNeighbors; +}; + +bool is_intersection_point_node_compatible_for_snapping_based_on_sharp_features(const SharpFeatureInfo & sharpFeatureInfo, const stk::mesh::Entity intPtNode, const std::vector & intPtNodes); + +} + +#endif /* KRINO_KRINO_KRINO_LIB_AKRI_SHARPFEATURE_HPP_ */ diff --git a/packages/krino/krino/krino_lib/Akri_Snap.cpp b/packages/krino/krino/krino_lib/Akri_Snap.cpp index cfe55cba430f..0a5f402ef34b 100644 --- a/packages/krino/krino/krino_lib/Akri_Snap.cpp +++ b/packages/krino/krino/krino_lib/Akri_Snap.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -26,18 +27,13 @@ namespace krino { -static void fill_node_locations(const int dim, const FieldRef coordsField, const std::vector & nodes, std::vector & nodeLocations) -{ - nodeLocations.clear(); - for (auto node : nodes) - nodeLocations.emplace_back(field_data(coordsField, node), dim); -} - -static stk::math::Vector3d compute_intersection_point_location(const int dim, const FieldRef coordsField, const IntersectionPoint & intersectionPoint) +static stk::math::Vector3d compute_intersection_point_location( + const int dim, + const FieldRef coordsField, + const std::vector & intPtNodes, + const std::vector & intPtWeights) { - const auto & intPtNodes = intersectionPoint.get_nodes(); - const auto & intPtWeights = intersectionPoint.get_weights(); - stk::math::Vector3d snapLocation{stk::math::Vector3d::ZERO}; + stk::math::Vector3d snapLocation = stk::math::Vector3d::ZERO; for (size_t i=0; i(coordsField, intPtNodes[i]), dim); @@ -46,6 +42,11 @@ static stk::math::Vector3d compute_intersection_point_location(const int dim, co return snapLocation; } +static stk::math::Vector3d compute_intersection_point_location(const int dim, const FieldRef coordsField, const IntersectionPoint & intersectionPoint) +{ + return compute_intersection_point_location(dim, coordsField, intersectionPoint.get_nodes(), intersectionPoint.get_weights()); +} + static void fill_global_ids_of_elements_using_node(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, stk::mesh::Entity node, @@ -159,14 +160,6 @@ static double estimate_quality_of_cutting_intersection_points(const stk::mesh::B return qualityAfterCut; } -static bool parts_are_compatible_for_snapping(const stk::mesh::BulkData & mesh, const AuxMetaData & auxMeta, const Phase_Support & phaseSupport, stk::mesh::Entity node, const std::vector & interpNodes) -{ - for (auto && interpNode : interpNodes) - if (interpNode != node && !parts_are_compatible_for_snapping_when_ignoring_phase(mesh, auxMeta, phaseSupport, node, interpNode)) - return false; - return true; -} - static double get_node_intersection_point_weight(const IntersectionPoint & intersectionPoint, stk::mesh::Entity node) { const std::vector & nodes = intersectionPoint.get_nodes(); @@ -224,15 +217,16 @@ static void sort_intersection_points_for_cutting(const stk::mesh::BulkData & mes static void fill_sorted_intersection_point_indices_for_node_for_domains(const stk::mesh::BulkData & mesh, const FieldRef coordsField, const std::vector & intersectionPoints, - const std::vector & candidatesIntersectionPointIndices, + const std::vector> & nodeIntersectionPointIndicesAndWhichSnapsAllowed, const stk::mesh::Entity node, const std::vector & domains, const bool globalIDsAreParallelConsistent, std::vector & sortedIntersectionPointIndices) { sortedIntersectionPointIndices.clear(); - for (auto && intPtIndex : candidatesIntersectionPointIndices) + for (auto && intPtIndexAndIsSnapAllowed : nodeIntersectionPointIndicesAndWhichSnapsAllowed) { + const size_t intPtIndex = intPtIndexAndIsSnapAllowed.first; if (first_sorted_vector_of_domains_contains_all_domains_in_second_vector(domains, intersectionPoints[intPtIndex].get_sorted_domains())) sortedIntersectionPointIndices.push_back(intPtIndex); } @@ -257,22 +251,63 @@ static std::set get_intersected_elements(const stk::mesh::Bul return intersectedElements; } -static std::map> get_node_to_intersection_point_indices(const stk::mesh::BulkData & mesh, +static std::vector which_intersection_point_nodes_are_compatible_for_snapping_based_on_parts_and_sharp_features(const stk::mesh::BulkData & mesh, + const AuxMetaData & auxMeta, + const Phase_Support & phaseSupport, + const SharpFeatureInfo * sharpFeatureInfo, + const std::vector & intPtNodes) +{ + std::vector whichSnapsAreAllowed = which_intersection_point_nodes_are_compatible_for_snapping(mesh, auxMeta, phaseSupport, intPtNodes); + if (nullptr != sharpFeatureInfo) + { + for (size_t iNode=0; iNode>> mapFromEntityToIntPtIndexAndSnapAllowed; + +static mapFromEntityToIntPtIndexAndSnapAllowed get_node_to_intersection_point_indices_and_which_snaps_allowed(const stk::mesh::BulkData & mesh, + const SharpFeatureInfo * sharpFeatureInfo, const std::vector & intersectionPoints) { - std::map> nodeToInsersectionPointIndices; + const AuxMetaData & auxMeta = AuxMetaData::get(mesh.mesh_meta_data()); + const Phase_Support & phaseSupport = Phase_Support::get(mesh.mesh_meta_data()); + + mapFromEntityToIntPtIndexAndSnapAllowed nodeToIntPtIndicesAndWhichSnapsAllowed; for (size_t intersectionPointIndex=0; intersectionPointIndex whichSnapsAreAllowed = which_intersection_point_nodes_are_compatible_for_snapping_based_on_parts_and_sharp_features(mesh, auxMeta, phaseSupport, sharpFeatureInfo, intPtNodes); + for (size_t iNode=0; iNode, std::map> determine_quality_per_node_per_domain(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, const FieldRef coordsField, const std::vector & intersectionPoints, - const std::map> & nodeToInsersectionPointIndices, + const mapFromEntityToIntPtIndexAndSnapAllowed & nodeToIntPtIndicesAndWhichSnapsAllowed, const QualityMetric &qualityMetric, const bool globalIDsAreParallelConsistent) { @@ -283,18 +318,18 @@ std::map, std::map> determine_quali std::vector elemNodeCoords; std::map, std::map> domainsToNodesToQuality; - for (auto entry : nodeToInsersectionPointIndices) + for (auto entry : nodeToIntPtIndicesAndWhichSnapsAllowed) { stk::mesh::Entity node = entry.first; - const auto nodeIntersectionPointIndices = entry.second; + const auto nodeIntersectionPointIndicesAndWhichSnapsAllowed = entry.second; std::set> nodeIntPtDomains; - for (auto && intPtIndex : nodeIntersectionPointIndices) - nodeIntPtDomains.insert(intersectionPoints[intPtIndex].get_sorted_domains()); + for (auto && intPtIndexAndIsSnapAllowed : nodeIntersectionPointIndicesAndWhichSnapsAllowed) + nodeIntPtDomains.insert(intersectionPoints[intPtIndexAndIsSnapAllowed.first].get_sorted_domains()); for (auto && intPtDomains : nodeIntPtDomains) { - fill_sorted_intersection_point_indices_for_node_for_domains(mesh, coordsField, intersectionPoints, nodeIntersectionPointIndices, node, intPtDomains, globalIDsAreParallelConsistent, sortedIntersectionPointIndices); + fill_sorted_intersection_point_indices_for_node_for_domains(mesh, coordsField, intersectionPoints, nodeIntersectionPointIndicesAndWhichSnapsAllowed, node, intPtDomains, globalIDsAreParallelConsistent, sortedIntersectionPointIndices); const std::set intersectedElements = get_intersected_elements(mesh, elementSelector, intersectionPoints, sortedIntersectionPointIndices); double qualityAfterCut = qualityMetric.get_best_value_for_metric(); @@ -320,13 +355,11 @@ append_snap_infos_from_intersection_points(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, const NodeToCapturedDomainsMap & nodesToCapturedDomains, const std::vector & intersectionPoints, - const std::map> & nodeToInsersectionPointIndices, + const mapFromEntityToIntPtIndexAndSnapAllowed & nodeToIntPtIndicesAndWhichSnapsAllowed, const QualityMetric &qualityMetric, const bool globalIDsAreParallelConsistent, std::vector & snapInfos) { - const AuxMetaData & auxMeta = AuxMetaData::get(mesh.mesh_meta_data()); - const Phase_Support phaseSupport = Phase_Support::get(mesh.mesh_meta_data()); const FieldRef coordsField(mesh.mesh_meta_data().coordinate_field()); const int dim = mesh.mesh_meta_data().spatial_dimension(); std::vector procsThatNeedToKnowAboutThisInfo; @@ -334,24 +367,26 @@ append_snap_infos_from_intersection_points(const stk::mesh::BulkData & mesh, int owner = mesh.parallel_rank(); - const auto domainsToNodesToQuality = determine_quality_per_node_per_domain(mesh, elementSelector, coordsField, intersectionPoints, nodeToInsersectionPointIndices, qualityMetric, globalIDsAreParallelConsistent); + const auto domainsToNodesToQuality = determine_quality_per_node_per_domain(mesh, elementSelector, coordsField, intersectionPoints, nodeToIntPtIndicesAndWhichSnapsAllowed, qualityMetric, globalIDsAreParallelConsistent); - for (auto entry : nodeToInsersectionPointIndices) + for (auto entry : nodeToIntPtIndicesAndWhichSnapsAllowed) { stk::mesh::Entity node = entry.first; - const auto nodeIntersectionPointIndices = entry.second; + const auto nodeIntersectionPointIndicesAndWhichSnapsAllowed = entry.second; if (mesh.bucket(node).owned()) { const stk::math::Vector3d nodeLocation(field_data(coordsField, node), dim); - for (auto && intPtIndex : nodeIntersectionPointIndices) + for (auto && intPtIndexAndIsSnapAllowed : nodeIntersectionPointIndicesAndWhichSnapsAllowed) { + const size_t intPtIndex = intPtIndexAndIsSnapAllowed.first; + const bool isSnapAllowed = intPtIndexAndIsSnapAllowed.second; const IntersectionPoint & intersectionPoint = intersectionPoints[intPtIndex]; + const auto & intPtNodes = intersectionPoint.get_nodes(); - if (domains_already_snapped_to_node_are_also_at_intersection_point(nodesToCapturedDomains, node, intersectionPoint.get_sorted_domains()) && - parts_are_compatible_for_snapping(mesh, auxMeta, phaseSupport, node, intPtNodes)) + if (isSnapAllowed && domains_already_snapped_to_node_are_also_at_intersection_point(nodesToCapturedDomains, node, intersectionPoint.get_sorted_domains())) { const stk::math::Vector3d snapLocation = compute_intersection_point_location(dim, coordsField, intersectionPoint); const double cutQualityEstimate = domainsToNodesToQuality.at(intersectionPoint.get_sorted_domains()).at(mesh.identifier(node)); @@ -382,6 +417,7 @@ append_snap_infos_from_intersection_points(const stk::mesh::BulkData & mesh, std::vector build_snap_infos_from_intersection_points(const stk::mesh::BulkData & mesh, + const SharpFeatureInfo * sharpFeatureInfo, const stk::mesh::Selector & elementSelector, const NodeToCapturedDomainsMap & nodesToCapturedDomains, const std::vector & intersectionPoints, @@ -390,8 +426,8 @@ build_snap_infos_from_intersection_points(const stk::mesh::BulkData & mesh, { std::vector snapInfos; - const auto nodeToInsersectionPointIndices = get_node_to_intersection_point_indices(mesh, intersectionPoints); - append_snap_infos_from_intersection_points(mesh, elementSelector, nodesToCapturedDomains, intersectionPoints, nodeToInsersectionPointIndices, qualityMetric, globalIDsAreParallelConsistent, snapInfos); + const auto nodeToIntPtIndicesAndWhichSnapsAllowed = get_node_to_intersection_point_indices_and_which_snaps_allowed(mesh, sharpFeatureInfo, intersectionPoints); + append_snap_infos_from_intersection_points(mesh, elementSelector, nodesToCapturedDomains, intersectionPoints, nodeToIntPtIndicesAndWhichSnapsAllowed, qualityMetric, globalIDsAreParallelConsistent, snapInfos); return snapInfos; } @@ -649,23 +685,33 @@ static void prune_snap_infos_modified_by_snap_iteration(const stk::mesh::BulkDat snapInfos.erase(snapInfos.begin()+newNumSnapInfos, snapInfos.end()); } -static std::map> get_node_to_intersection_point_indices_for_nodes_that_need_new_snap_infos(const stk::mesh::BulkData & mesh, +static mapFromEntityToIntPtIndexAndSnapAllowed get_node_to_intersection_point_indices_and_which_snaps_allowed_for_nodes_that_need_new_snap_infos(const stk::mesh::BulkData & mesh, + const SharpFeatureInfo * sharpFeatureInfo, const std::vector & intersectionPoints, const std::vector & sortedIdsOfNodesThatNeedNewSnapInfos) { - std::map> nodeToInsersectionPointIndices; + const AuxMetaData & auxMeta = AuxMetaData::get(mesh.mesh_meta_data()); + const Phase_Support & phaseSupport = Phase_Support::get(mesh.mesh_meta_data()); + + mapFromEntityToIntPtIndexAndSnapAllowed nodeToIntPtIndicesAndWhichSnapsAllowed; for (size_t intPtIndex=0; intPtIndex whichSnapsAreAllowed = which_intersection_point_nodes_are_compatible_for_snapping_based_on_parts_and_sharp_features(mesh, auxMeta, phaseSupport, sharpFeatureInfo, intPtNodes); + for (size_t iNode=0; iNode & iterationSortedSnapNodes, const NodeToCapturedDomainsMap & nodesToCapturedDomains, const stk::mesh::Selector & elementSelector, @@ -680,27 +726,35 @@ void update_intersection_points_and_snap_infos_after_snap_iteration(const stk::m prune_snap_infos_modified_by_snap_iteration(mesh, oldToNewIntPts, sortedIdsOfNodesThatNeedNewSnapInfos, snapInfos); - const auto nodeToInsersectionPointIndices = get_node_to_intersection_point_indices_for_nodes_that_need_new_snap_infos(mesh, intersectionPoints, sortedIdsOfNodesThatNeedNewSnapInfos); + const auto nodeToIntPtIndicesAndWhichSnapsAllowed = get_node_to_intersection_point_indices_and_which_snaps_allowed_for_nodes_that_need_new_snap_infos(mesh, sharpFeatureInfo, intersectionPoints, sortedIdsOfNodesThatNeedNewSnapInfos); - append_snap_infos_from_intersection_points(mesh, elementSelector, nodesToCapturedDomains, intersectionPoints, nodeToInsersectionPointIndices, qualityMetric, globalIDsAreParallelConsistent, snapInfos); + append_snap_infos_from_intersection_points(mesh, elementSelector, nodesToCapturedDomains, intersectionPoints, nodeToIntPtIndicesAndWhichSnapsAllowed, qualityMetric, globalIDsAreParallelConsistent, snapInfos); } NodeToCapturedDomainsMap snap_as_much_as_possible_while_maintaining_quality(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, const FieldSet & interpolationFields, const InterfaceGeometry & geometry, - const bool globalIDsAreParallelConsistent) + const bool globalIDsAreParallelConsistent, + const double snappingSharpFeatureAngleInDegrees) {/* %TRACE[ON]% */ Trace trace__("krino::snap_as_much_as_possible_while_maintaining_quality()"); /* %TRACE% */ const ScaledJacobianQualityMetric qualityMetric; size_t iteration{0}; NodeToCapturedDomainsMap nodesToCapturedDomains; stk::ParallelMachine comm = mesh.parallel(); + std::unique_ptr sharpFeatureInfo; + if (snappingSharpFeatureAngleInDegrees > 0.) + { + sharpFeatureInfo = std::make_unique(); + const FieldRef coordsField(mesh.mesh_meta_data().coordinate_field()); + sharpFeatureInfo->find_sharp_features(mesh, coordsField, elementSelector, std::cos(snappingSharpFeatureAngleInDegrees*M_PI/180.)); + } std::vector intersectionPoints; geometry.store_phase_for_uncut_elements(mesh); intersectionPoints = build_all_intersection_points(mesh, geometry, nodesToCapturedDomains); - std::vector snapInfos = build_snap_infos_from_intersection_points(mesh, elementSelector, nodesToCapturedDomains, intersectionPoints, qualityMetric, globalIDsAreParallelConsistent); + std::vector snapInfos = build_snap_infos_from_intersection_points(mesh, sharpFeatureInfo.get(), elementSelector, nodesToCapturedDomains, intersectionPoints, qualityMetric, globalIDsAreParallelConsistent); while (true) { @@ -721,7 +775,7 @@ NodeToCapturedDomainsMap snap_as_much_as_possible_while_maintaining_quality(cons const std::vector iterationSortedSnapNodes = get_sorted_nodes_modified_in_current_snapping_iteration(mesh, independentSnapInfos); - update_intersection_points_and_snap_infos_after_snap_iteration(mesh, geometry, iterationSortedSnapNodes, nodesToCapturedDomains, elementSelector, qualityMetric, globalIDsAreParallelConsistent, intersectionPoints, snapInfos); + update_intersection_points_and_snap_infos_after_snap_iteration(mesh, geometry, sharpFeatureInfo.get(), iterationSortedSnapNodes, nodesToCapturedDomains, elementSelector, qualityMetric, globalIDsAreParallelConsistent, intersectionPoints, snapInfos); } krinolog << "After snapping quality is " << determine_quality(mesh, elementSelector, qualityMetric) << stk::diag::dendl; diff --git a/packages/krino/krino/krino_lib/Akri_Snap.hpp b/packages/krino/krino/krino_lib/Akri_Snap.hpp index f56bc31ecd38..9273b4d6e475 100644 --- a/packages/krino/krino/krino_lib/Akri_Snap.hpp +++ b/packages/krino/krino/krino_lib/Akri_Snap.hpp @@ -22,7 +22,8 @@ NodeToCapturedDomainsMap snap_as_much_as_possible_while_maintaining_quality(cons const stk::mesh::Selector & elementSelector, const FieldSet & interpolationFields, const InterfaceGeometry & geometry, - const bool globalIDsAreParallelConsistent); + const bool globalIDsAreParallelConsistent, + const double snappingSharpFeatureAngleInDegrees); double determine_quality(const stk::mesh::BulkData & mesh, const stk::mesh::Selector & elementSelector, diff --git a/packages/krino/krino/krino_lib/Akri_String_Function_Expression.cpp b/packages/krino/krino/krino_lib/Akri_String_Function_Expression.cpp new file mode 100644 index 000000000000..a843075bacb4 --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_String_Function_Expression.cpp @@ -0,0 +1,56 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include + +namespace krino { + +String_Function_Expression::String_Function_Expression(const std::string & expression) +: myEvaluator(*this) +{ + parse(expression); +} + +void String_Function_Expression::parse(const std::string & expression) +{ + try { + myEvaluator.parse(expression); + } + catch (std::runtime_error &x) { + stk::RuntimeDoomedSymmetric() << "In expression '" << expression << "':" << std::endl << x.what() << std::endl; + } +} + +void String_Function_Expression::resolve(stk::expreval::VariableMap::iterator & varIt) +{ + std::string name = (*varIt).first; + + if (!(name).compare("x")) + (*varIt).second->bind(myQueryCoords[0]); + else if (!(name).compare("y")) + (*varIt).second->bind(myQueryCoords[1]); + else if (!(name).compare("z")) + (*varIt).second->bind(myQueryCoords[2]); + else + { + std::ostringstream msg; + msg << " Unable to resolve symbol: " << name; + throw std::runtime_error(msg.str()); + } +} + +double +String_Function_Expression::evaluate(const Vector3d &coord) const +{ + myQueryCoords = coord; + return myEvaluator.evaluate(); +} + +} diff --git a/packages/krino/krino/krino_lib/Akri_String_Function_Expression.hpp b/packages/krino/krino/krino_lib/Akri_String_Function_Expression.hpp new file mode 100644 index 000000000000..0fb916b8842d --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_String_Function_Expression.hpp @@ -0,0 +1,31 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef KRINO_KRINO_KRINO_LIB_AKRI_STRING_FUNCTION_EXPRESSION_HPP_ +#define KRINO_KRINO_KRINO_LIB_AKRI_STRING_FUNCTION_EXPRESSION_HPP_ + +#include +#include "Akri_Vec.hpp" + +namespace krino { + +class String_Function_Expression : public stk::expreval::VariableMap::Resolver +{ +public: + String_Function_Expression(const std::string & expression); + void resolve(stk::expreval::VariableMap::iterator & varIt) override; + double evaluate(const Vector3d &coords) const; +private: + void parse(const std::string & expression); + stk::expreval::Eval myEvaluator; + mutable Vector3d myQueryCoords; +}; + +} + +#endif /* KRINO_KRINO_KRINO_LIB_AKRI_STRING_FUNCTION_EXPRESSION_HPP_ */ diff --git a/packages/krino/krino/krino_lib/Akri_Surface.hpp b/packages/krino/krino/krino_lib/Akri_Surface.hpp index 78019bafe845..99b0f5a28ed4 100644 --- a/packages/krino/krino/krino_lib/Akri_Surface.hpp +++ b/packages/krino/krino/krino_lib/Akri_Surface.hpp @@ -27,6 +27,7 @@ enum Surface_Type COMPOSITE_SURFACE, PLANE, RANDOM, + STRING_FUNCTION, FACETED_SURFACE, // Never, ever, ever add an entry after MAX_SURFACE_TYPE. Never. MAX_SURFACE_TYPE diff --git a/packages/krino/krino/krino_lib/Akri_Transformation.cpp b/packages/krino/krino/krino_lib/Akri_Transformation.cpp index d7c563f2a588..a6792173e442 100644 --- a/packages/krino/krino/krino_lib/Akri_Transformation.cpp +++ b/packages/krino/krino/krino_lib/Akri_Transformation.cpp @@ -84,15 +84,10 @@ Transformation::initialize() void Transformation::update( const double time ) const { - if (time == my_last_update) + if (my_last_update > 0. && time == my_last_update) { return; } - if (my_last_update < 0.0) - { - my_last_update = time; - return; - } const double dt = time - my_last_update; const Vector3d update_rotation_angle = dt*my_rotational_velocity; diff --git a/packages/krino/krino/krino_lib/Akri_Transformation.hpp b/packages/krino/krino/krino_lib/Akri_Transformation.hpp index 602d3510af9a..a0337f4a2fe4 100644 --- a/packages/krino/krino/krino_lib/Akri_Transformation.hpp +++ b/packages/krino/krino/krino_lib/Akri_Transformation.hpp @@ -30,7 +30,7 @@ class Transformation { public: Transformation() : my_translational_velocity(Vector3d::ZERO), my_rotational_velocity(Vector3d::ZERO), - my_reference_point(Vector3d::ZERO), my_last_update(-1.0), my_update_orientation(), my_update_offset(Vector3d::ZERO) {} + my_reference_point(Vector3d::ZERO), my_last_update(0.0), my_update_orientation(), my_update_offset(Vector3d::ZERO) {} virtual ~Transformation() {} void set_translational_velocity(const Vector3d & v) { my_translational_velocity = v; } diff --git a/packages/krino/krino/krino_lib/Akri_VolumePreservingSnappingLimiter.cpp b/packages/krino/krino/krino_lib/Akri_VolumePreservingSnappingLimiter.cpp new file mode 100644 index 000000000000..7158d09d99a8 --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_VolumePreservingSnappingLimiter.cpp @@ -0,0 +1,102 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace krino { + +static void replace_coordinates_of_node_with_new_location(const StkMeshEntities & elemNodes, const stk::mesh::Entity node, const Vector3d & newNodeLoc, std::vector & elemNodeCoords) +{ + for (size_t n=0; n & elemNodeCoordsWorkspace) +{ + const int dim = mesh.mesh_meta_data().spatial_dimension(); + + StkMeshEntities nodeElements{mesh.begin_elements(node), mesh.end_elements(node)}; + + double volumeBefore = 0.; + double volumeAfter = 0.; + unsigned numElements = 0.; + for (auto elem : nodeElements) + { + if (elementToBlockConverter(mesh, elem) == &blockPart) + { + ++numElements; + StkMeshEntities elemNodes{mesh.begin_nodes(elem), mesh.end_nodes(elem)}; + fill_node_locations(dim, coordsField, elemNodes, elemNodeCoordsWorkspace); + + volumeBefore += compute_tri_or_tet_volume(elemNodeCoordsWorkspace); + replace_coordinates_of_node_with_new_location(elemNodes, node, newNodeLoc, elemNodeCoordsWorkspace); + volumeAfter += compute_tri_or_tet_volume(elemNodeCoordsWorkspace); + } + } + + if (0 == numElements) + return 0.; + + const double elemAverageVol = std::max(volumeBefore,volumeAfter)/numElements; + return std::abs(volumeAfter-volumeBefore)/elemAverageVol; +} + +VolumePreservingSnappingLimiter::VolumePreservingSnappingLimiter( + const stk::mesh::BulkData & mesh, + const FieldRef coordsField, + const ElementToBlockConverter & elementToBlockConverter, + const double volumeConservationTol) + : myMesh(mesh), + myAuxMeta(AuxMetaData::get(mesh.mesh_meta_data())), + myElementToBlockConverter(elementToBlockConverter), + myCoordsField(coordsField), + myVolumeConservationTol(volumeConservationTol) +{ +} + +std::set VolumePreservingSnappingLimiter::get_blocks_to_consider(const stk::mesh::Entity node) const +{ + std::set blocksToConsider; + for (auto && elem : StkMeshEntities{myMesh.begin_elements(node), myMesh.end_elements(node)}) + { + stk::mesh::Part * blockPart = myElementToBlockConverter(myMesh, elem); + if (nullptr != blockPart) + blocksToConsider.insert(blockPart); + } + return blocksToConsider; +} + +bool VolumePreservingSnappingLimiter::is_snap_allowed(const stk::mesh::Entity node, const Vector3d & snapLocation) const +{ + const std::set blocksToConsider = get_blocks_to_consider(node); + if (blocksToConsider.size() == 1 && !myMesh.bucket(node).member(myAuxMeta.exposed_boundary_part())) + return true; + + std::vector elemNodeCoords; + for (auto && blockPart : blocksToConsider) + { + const double volChange = compute_relative_volume_change(myMesh, myCoordsField, myElementToBlockConverter, *blockPart, node, snapLocation, elemNodeCoords); + if (volChange > myVolumeConservationTol) + return false; + } + return true; +} + +} diff --git a/packages/krino/krino/krino_lib/Akri_VolumePreservingSnappingLimiter.hpp b/packages/krino/krino/krino_lib/Akri_VolumePreservingSnappingLimiter.hpp new file mode 100644 index 000000000000..8b1d373e66ed --- /dev/null +++ b/packages/krino/krino/krino_lib/Akri_VolumePreservingSnappingLimiter.hpp @@ -0,0 +1,36 @@ +#ifndef KRINO_KRINO_KRINO_LIB_AKRI_VOLUMEPRESERVINGSNAPPINGLIMITER_HPP_ +#define KRINO_KRINO_KRINO_LIB_AKRI_VOLUMEPRESERVINGSNAPPINGLIMITER_HPP_ +#include +#include +#include +#include + +namespace krino { + +class AuxMetaData; + +class VolumePreservingSnappingLimiter +{ +public: + typedef std::function ElementToBlockConverter; + + VolumePreservingSnappingLimiter( + const stk::mesh::BulkData & mesh, + const FieldRef coordsField, + const ElementToBlockConverter & elementToBlockConverter, + const double volumeConservationTol); + bool is_snap_allowed(const stk::mesh::Entity node, const Vector3d & snapLocation) const; +private: + std::set get_blocks_to_consider(const stk::mesh::Entity node) const; + const stk::mesh::BulkData & myMesh; + const AuxMetaData & myAuxMeta; + ElementToBlockConverter myElementToBlockConverter; + FieldRef myCoordsField; + double myVolumeConservationTol; +}; + +} + + + +#endif /* KRINO_KRINO_KRINO_LIB_AKRI_VOLUMEPRESERVINGSNAPPINGLIMITER_HPP_ */ diff --git a/packages/krino/krino/parser/Akri_Surface_Parser.cpp b/packages/krino/krino/parser/Akri_Surface_Parser.cpp index 2d4e3b288f60..b48be642ede5 100644 --- a/packages/krino/krino/parser/Akri_Surface_Parser.cpp +++ b/packages/krino/krino/parser/Akri_Surface_Parser.cpp @@ -295,6 +295,34 @@ parse_mesh_surface(const Parser::Node & ic_node, const stk::mesh::MetaData & met return new MeshSurface(meta, *coords, surface_selector, sign); } +LevelSet_String_Function * +parse_string_function(const Parser::Node & ic_node) +{ + std::string expression; + if (!ic_node.get_if_present("expression", expression)) + { + stk::RuntimeDoomedAdHoc() << "Missing expression for string_function.\n"; + } + + LevelSet_String_Function * surf = new LevelSet_String_Function(expression); + + std::vector bounds; + if (ic_node.get_if_present("bounding_box", bounds)) + { + if (bounds.size() == 6) + { + const BoundingBox surfBbox( Vector3d(bounds[0],bounds[1],bounds[2]), Vector3d(bounds[3],bounds[4],bounds[5]) ); + surf->set_bounding_box(surfBbox); + } + else + { + stk::RuntimeDoomedAdHoc() << "bounding_box for string_function must be a vector of length 6 (for both 2D or 3D) (xmin,ymin,zmin, xmax,ymax,zmax).\n"; + } + } + + return surf; +} + } Surface * @@ -308,10 +336,6 @@ Surface_Parser::parse(const Parser::Node & parserNode, const stk::mesh::MetaData { return new Random(0); } - else if (ic_type == "analytic_isosurface") - { - return new Analytic_Isosurface(); - } return nullptr; } @@ -331,6 +355,10 @@ Surface_Parser::parse(const Parser::Node & parserNode, const stk::mesh::MetaData { return parse_cylinder(parserNode); } + else if ( parserNode.get_null_if_present("string_function") ) + { + return parse_string_function(parserNode); + } else if ( parserNode.get_null_if_present("facets") ) { return parse_facets(parserNode, parentTimer); diff --git a/packages/krino/krino/rebalance_utils/Akri_RebalanceUtils.cpp b/packages/krino/krino/rebalance_utils/Akri_RebalanceUtils.cpp index 54d9b87504f3..c2a869e8e29f 100644 --- a/packages/krino/krino/rebalance_utils/Akri_RebalanceUtils.cpp +++ b/packages/krino/krino/rebalance_utils/Akri_RebalanceUtils.cpp @@ -30,7 +30,7 @@ class MultipleCriteriaSettings : public stk::balance::GraphCreationSettings m_critFields(critFields), m_defaultWeight(default_weight) { - method = "rcb"; + m_method = "rcb"; setUseNodeBalancer(true); setNodeBalancerTargetLoadBalance(getImbalanceTolerance()); setNodeBalancerMaxIterations(max_num_nodal_rebal_iters); @@ -46,8 +46,8 @@ class MultipleCriteriaSettings : public stk::balance::GraphCreationSettings virtual bool includeSearchResultsInGraph() const override { return false; } virtual int getGraphVertexWeight(stk::topology type) const override { return 1; } virtual double getImbalanceTolerance() const override { return 1.05; } - virtual void setDecompMethod(const std::string & input_method) override { method = input_method; } - virtual std::string getDecompMethod() const override { return method; } + virtual void setDecompMethod(const std::string & input_method) override { m_method = input_method; } + virtual std::string getDecompMethod() const override { return m_method; } virtual int getNumCriteria() const override { return m_critFields.size(); } virtual bool isMultiCriteriaRebalance() const override { return true; } virtual bool shouldFixMechanisms() const override { return false; } diff --git a/packages/krino/krino/region/Akri_Region.cpp b/packages/krino/krino/region/Akri_Region.cpp index ae4b9bf3eb3f..99f5a7883752 100644 --- a/packages/krino/krino/region/Akri_Region.cpp +++ b/packages/krino/krino/region/Akri_Region.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -58,6 +59,7 @@ Region::Region(Simulation & owning_simulation, const std::string & regionName) { /* %TRACE[ON]% */ Trace trace__("krino::Region::Region()"); /* %TRACE% */ my_simulation.add_region(this); myIOBroker = std::make_unique(stk::EnvData::parallel_comm()); + myIOBroker->use_simple_fields(); std::vector entity_rank_names = stk::mesh::entity_rank_names(); entity_rank_names.push_back("FAMILY_TREE"); @@ -156,8 +158,9 @@ void Region::commit() } else { - auto shared_bulk = std::make_shared(meta,stk::EnvData::parallel_comm(),auto_aura_option); + std::shared_ptr shared_bulk = stk::mesh::MeshBuilder(stk::EnvData::parallel_comm()).set_aura_option(auto_aura_option).create(std::shared_ptr(&meta,[](auto ptrWeWontDelete){})); my_bulk = shared_bulk.get(); + my_bulk->mesh_meta_data().use_simple_fields(); stk_IO().set_bulk_data( shared_bulk ); stk_IO().populate_bulk_data(); } @@ -630,7 +633,7 @@ Region::associate_input_mesh(const std::string & model_name, bool assert_32bit_i entity_rank_names.push_back("FAMILY_TREE"); my_generated_mesh = std::make_unique(generated_mesh_element_type,entity_rank_names); my_meta = &my_generated_mesh->meta_data(); - stk::mesh::Field & coords_field = my_meta->declare_field>(stk::topology::NODE_RANK, "coordinates", 1); + stk::mesh::Field & coords_field = my_meta->declare_field(stk::topology::NODE_RANK, "coordinates", 1); stk::mesh::put_field_on_mesh(coords_field, my_meta->universal_part(), generated_mesh_element_type.dimension(), nullptr); } else diff --git a/packages/krino/krino/unit_tests/Akri_StkMeshBuilder.cpp b/packages/krino/krino/unit_tests/Akri_StkMeshBuilder.cpp new file mode 100644 index 000000000000..48df04185151 --- /dev/null +++ b/packages/krino/krino/unit_tests/Akri_StkMeshBuilder.cpp @@ -0,0 +1,353 @@ +#include "Akri_StkMeshBuilder.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include "../../../stk/stk_mesh/stk_mesh/base/SkinBoundary.hpp" + +namespace krino +{ + +template +StkMeshBuilder::StkMeshBuilder(stk::mesh::BulkData & mesh, const stk::ParallelMachine comm) +: mMesh(mesh), mAuxMeta(AuxMetaData::create(mesh.mesh_meta_data())), mPhaseSupport(Phase_Support::get(mesh.mesh_meta_data())), mComm(comm) +{ + declare_coordinates(); + mMesh.mesh_meta_data().use_simple_fields(); +} + +template +void StkMeshBuilder::declare_coordinates() +{ + stk::mesh::Field & coordsField = mMesh.mesh_meta_data().template declare_field( + stk::topology::NODE_RANK, "coordinates", 1u); + stk::mesh::put_field_on_entire_mesh(coordsField, DIM); + stk::io::set_field_role(coordsField, Ioss::Field::MESH); +} + +template +void StkMeshBuilder::create_block_parts(const std::vector &elementBlockIDs) +{ + stk::topology simplexTopology = ((DIM == 2) ? stk::topology::TRIANGLE_3_2D : stk::topology::TETRAHEDRON_4); + + for (unsigned blockId : elementBlockIDs) + { + const std::string blockName = "block_"+std::to_string(blockId); + stk::mesh::Part &part = mMesh.mesh_meta_data().declare_part_with_topology(blockName, simplexTopology); + mMesh.mesh_meta_data().set_part_id(part, blockId); + stk::io::put_io_part_attribute(part); + } +} + +std::string get_surface_name(const unsigned sidesetId) +{ + const std::string surfaceName = "surface_"+std::to_string(sidesetId); + return surfaceName; +} + +template +void StkMeshBuilder::create_sideset_parts(const std::vector &sidesetIds) +{ + for (unsigned sidesetId : sidesetIds) + { + stk::mesh::Part &sidesetPart = mMesh.mesh_meta_data().declare_part(get_surface_name(sidesetId), mMesh.mesh_meta_data().side_rank()); + mMesh.mesh_meta_data().set_part_id(sidesetPart, sidesetId); + stk::io::put_io_part_attribute(sidesetPart); + } +} + +std::vector convert_vector_of_vector_of_sideset_ids_to_parts(const stk::mesh::MetaData & meta, const std::vector>& vectorOfVectorsOfSidesetIds) +{ + std::vector addParts {vectorOfVectorsOfSidesetIds.size()}; + for(size_t i {0}; i < vectorOfVectorsOfSidesetIds.size(); ++i) + { + addParts[i].reserve(vectorOfVectorsOfSidesetIds[i].size()); + for(const size_t sidesetId : vectorOfVectorsOfSidesetIds[i]) + addParts[i].push_back(meta.get_part(get_surface_name(sidesetId))); + } + return addParts; +} + +template +void StkMeshBuilder::add_sides_to_sidesets(const std::vector &sides, const std::vector> &sidesetIdsPerSide) +{ + ThrowRequireWithSierraHelpMsg(sides.size() == sidesetIdsPerSide.size()); + const std::vector addParts = convert_vector_of_vector_of_sideset_ids_to_parts(mMesh.mesh_meta_data(), sidesetIdsPerSide); + const std::vector remParts(sidesetIdsPerSide.size(), stk::mesh::PartVector{}); + mMesh.batch_change_entity_parts(sides, addParts, remParts); +} + +template +stk::mesh::Entity StkMeshBuilder::get_side_with_nodes(const std::vector &nodesOfSide) const +{ + std::vector sidesWithNodes; + + stk::mesh::get_entities_through_relations(mMesh, nodesOfSide, mMesh.mesh_meta_data().side_rank(), sidesWithNodes); + ThrowRequireMsg(sidesWithNodes.size() == 1, "Expected to find one side with nodes, but found " << sidesWithNodes.size()); + return sidesWithNodes[0]; +} + +template +void StkMeshBuilder::set_node_coordinates(const stk::mesh::Entity node, const stk::math::Vector3d &newLoc) +{ + double* node_coords = (double*)stk::mesh::field_data(*mMesh.mesh_meta_data().coordinate_field(), node); + node_coords[0] = newLoc[0]; + node_coords[1] = newLoc[1]; + if (mMesh.mesh_meta_data().spatial_dimension() == 3) node_coords[2] = newLoc[2]; +} + +template +stk::mesh::Entity StkMeshBuilder::create_node(const stk::math::Vector3d &loc, const std::vector &sharingProcs, stk::mesh::EntityId nodeId) +{ + stk::mesh::Entity node = mMesh.declare_node(nodeId); + + int proc = mMesh.parallel_rank(); + for(int sharingProc : sharingProcs) + { + if ( sharingProc != proc) + mMesh.add_node_sharing(node, sharingProc); + } + + set_node_coordinates(node, loc); + return node; +} + +stk::mesh::Part * get_block_part(const stk::mesh::MetaData &meta, const unsigned blockId) +{ + stk::mesh::Part *blockPart{nullptr}; + for (stk::mesh::Part * part : meta.get_parts()) + { + if (part->primary_entity_rank() == stk::topology::ELEM_RANK && (unsigned)part->id() == blockId) + { + blockPart = part; + break; + } + } + ThrowRequireMsg(blockPart!=nullptr, "Can't find a block with id " << blockId); + return blockPart; +} + +template +void StkMeshBuilder::create_boundary_sides() +{ + stk::mesh::create_exposed_block_boundary_sides(mMesh, mMesh.mesh_meta_data().universal_part(), {&mAuxMeta.exposed_boundary_part()}); +} + +template +bool StkMeshBuilder::check_boundary_sides() const +{ + return stk::mesh::check_exposed_block_boundary_sides(mMesh, mMesh.mesh_meta_data().universal_part(), mAuxMeta.exposed_boundary_part()); +} + +template +void StkMeshBuilder::create_block_boundary_sides() +{ + stk::mesh::create_exposed_block_boundary_sides(mMesh, mMesh.mesh_meta_data().universal_part(), {&mAuxMeta.block_boundary_part()}); +} + +template +bool StkMeshBuilder::check_block_boundary_sides() const +{ + return stk::mesh::check_interior_block_boundary_sides(mMesh, mMesh.mesh_meta_data().universal_part(), mAuxMeta.block_boundary_part()); +} + +template +stk::mesh::Entity StkMeshBuilder::create_element(const std::vector &nodes, stk::mesh::EntityId elementId, unsigned blockId) +{ + const stk::mesh::Part *blockPart = get_block_part(mMesh.mesh_meta_data(), blockId); + stk::mesh::Entity element = mMesh.declare_element(elementId, stk::mesh::ConstPartVector{blockPart}); + unsigned idx = 0; + for (auto nd : nodes) + mMesh.declare_relation(element, stk::mesh::Entity(nd), idx++); + return element; +} + +template +std::vector +StkMeshBuilder::create_parallel_nodes(const std::vector>& nodeLocs, + const std::map> &nodeIndicesWithSharingProcs, + const std::vector & assignedGlobalNodeIdsforAllNodes) +{ + std::vector nodesWhichAreValidIfTheyExistOnProc(nodeLocs.size(), stk::mesh::Entity()); + int curProc = stk::parallel_machine_rank(mComm); + for(auto &nodeIndexWithSharingProcs : nodeIndicesWithSharingProcs) + { + unsigned nodeIndex = nodeIndexWithSharingProcs.first; + stk::mesh::EntityId nodeGlobalId = assignedGlobalNodeIdsforAllNodes[nodeIndex]; + const std::vector &sharingProcs = nodeIndexWithSharingProcs.second; + + if (std::find(sharingProcs.begin(), sharingProcs.end(), curProc) != sharingProcs.end() ) + nodesWhichAreValidIfTheyExistOnProc[nodeIndex] = create_node(stk::math::Vector3d{nodeLocs[nodeIndex].data(),DIM}, + sharingProcs, + nodeGlobalId); + } + return nodesWhichAreValidIfTheyExistOnProc; +} + +std::vector get_ids_available_for_rank(stk::mesh::BulkData & mesh, stk::mesh::EntityRank rank, size_t numRequested) +{ + stk::mesh::EntityIdVector requestedIds; + mesh.generate_new_ids(rank, numRequested, requestedIds); + std::vector idsToReturn(requestedIds.begin(), requestedIds.end()); + std::reverse(idsToReturn.begin(), idsToReturn.end()); + return idsToReturn; +} + +template +std::vector +StkMeshBuilder::create_parallel_elements(const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &elementProcOwners, + const std::vector& nodesWhichAreValidIfTheyExistOnProc) +{ + const int proc = stk::parallel_machine_rank(mComm); + + size_t numOwnedElements = 0; + for (int elemProc : elementProcOwners) + if (elemProc == proc) ++numOwnedElements; + + std::vector elementIds = get_ids_available_for_rank(mMesh, stk::topology::ELEM_RANK, numOwnedElements); + + std::vector ownedElems; + for (size_t iElem=0; iElem oneElementConnWithLocalIds(NPE); + for(unsigned i = 0; i < NPE; i++) + oneElementConnWithLocalIds[i] = nodesWhichAreValidIfTheyExistOnProc[elementConn[iElem][i]]; + + stk::mesh::Entity elem = create_element(oneElementConnWithLocalIds, elementId, elementBlockIDs[iElem]); + ownedElems.push_back(elem); + } + } + + return ownedElems; +} + +template +std::map> +StkMeshBuilder::build_node_sharing_procs(const std::vector> &elementConn, + const std::vector &elementProcOwners) const +{ + std::map> nodeIndicesWithSharingProcs; + for (size_t iElem=0; iElem +std::map> +StkMeshBuilder::build_node_sharing_procs_for_all_nodes_on_all_procs(const unsigned numNodes, const unsigned numProcs) const +{ + std::map> nodeIndicesWithSharingProcs; + for (unsigned iNode{0}; iNode +void StkMeshBuilder::build_mesh(const std::vector> &nodeLocs, + const std::vector>> &elementConnPerProc, + const unsigned blockId) +{ + ThrowRequireWithSierraHelpMsg(elementConnPerProc.size() == (size_t)stk::parallel_machine_size(mComm)); + std::vector> elementConn; + std::vector elementBlockIDs; + std::vector elementProcOwners; + for (unsigned proc=0; proc +void StkMeshBuilder::build_mesh(const std::vector> &nodeLocs, + const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &specifiedElementProcOwners) +{ + std::vector allBlockIDs = elementBlockIDs; + stk::util::sort_and_unique(allBlockIDs); + + build_mesh_with_all_needed_block_ids(nodeLocs, elementConn, elementBlockIDs, allBlockIDs, specifiedElementProcOwners); +} + +template +void StkMeshBuilder::build_mesh_nodes_and_elements( + const std::vector> &nodeLocs, + const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &specifiedElementProcOwners +) +{ + create_block_parts(elementBlockIDs); + + const size_t numGlobalElems = elementConn.size(); + std::vector elementProcOwners = specifiedElementProcOwners; + if (elementProcOwners.empty()) // Put all elements on proc 0 if called with empty specifiedElementProcOwners + elementProcOwners.assign(numGlobalElems, 0); + + ThrowRequireWithSierraHelpMsg(elementBlockIDs.size() == numGlobalElems); + ThrowRequireWithSierraHelpMsg(elementProcOwners.size() == numGlobalElems); + + mAssignedGlobalNodeIdsforAllNodes.resize(nodeLocs.size()); + for (unsigned iNode=0; iNode> nodeIndicesWithSharingProcs = + (0 == numGlobalElems) ? + build_node_sharing_procs_for_all_nodes_on_all_procs(nodeLocs.size(), stk::parallel_machine_size(mComm)) : + build_node_sharing_procs(elementConn, elementProcOwners); + + mMesh.modification_begin(); + const auto nodeHandlesWhichAreValidForNodesThatExistOnProc = create_parallel_nodes(nodeLocs, nodeIndicesWithSharingProcs, mAssignedGlobalNodeIdsforAllNodes); + mOwnedElems = create_parallel_elements(elementConn, elementBlockIDs, elementProcOwners, nodeHandlesWhichAreValidForNodesThatExistOnProc); + mMesh.modification_end(); + + create_boundary_sides(); + create_block_boundary_sides(); +} + +template +void StkMeshBuilder::build_mesh_with_all_needed_block_ids +( + const std::vector> &nodeLocs, + const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &allBlocksIncludingThoseThatDontHaveElements, + const std::vector &specifiedElementProcOwners +) +{ + build_mesh_nodes_and_elements(nodeLocs, elementConn, elementBlockIDs, specifiedElementProcOwners); +} + +// Explicit template instantiation +template class StkMeshBuilder<2>; +template class StkMeshBuilder<3>; + +} diff --git a/packages/krino/krino/unit_tests/Akri_StkMeshBuilder.hpp b/packages/krino/krino/unit_tests/Akri_StkMeshBuilder.hpp new file mode 100644 index 000000000000..530ef55b30f8 --- /dev/null +++ b/packages/krino/krino/unit_tests/Akri_StkMeshBuilder.hpp @@ -0,0 +1,91 @@ +#ifndef KRINO_KRINO_UNIT_TESTS_AKRI_STKMESHBUILDER_HPP_ +#define KRINO_KRINO_UNIT_TESTS_AKRI_STKMESHBUILDER_HPP_ +#include +#include +#include + +namespace krino { + +class AuxMetaData; +class Phase_Support; + +template +class StkMeshBuilder +{ +public: + static constexpr int NPE = DIM+1; + + StkMeshBuilder(stk::mesh::BulkData & mesh, const stk::ParallelMachine comm); + + void build_mesh(const std::vector> &nodeLocs, + const std::vector>> &elemConnPerProc, + const unsigned blockId=1u); + + void build_mesh(const std::vector> &nodeLocs, + const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &specifiedElementProcOwners = {}); + + void build_mesh_with_all_needed_block_ids(const std::vector> &nodeLocs, + const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &allBlocksIncludingThoseThatDontHaveElements, + const std::vector &specifiedElementProcOwners); + + void build_mesh_nodes_and_elements( + const std::vector> &nodeLocs, + const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &specifiedElementProcOwners); + + const std::vector & get_owned_elements() const { return mOwnedElems; } + const std::vector & get_assigned_node_global_ids() const { return mAssignedGlobalNodeIdsforAllNodes; } + + bool check_boundary_sides() const; + bool check_block_boundary_sides() const; + + AuxMetaData & get_aux_meta() { return mAuxMeta; } + const AuxMetaData & get_aux_meta() const { return mAuxMeta; } + + Phase_Support & get_phase_support() { return mPhaseSupport; } + const Phase_Support & get_phase_support() const { return mPhaseSupport; } + + void create_sideset_parts(const std::vector &sidesetIds); + void add_sides_to_sidesets(const std::vector &sides, const std::vector> &sidesetIdsPerSide); + stk::mesh::Entity get_side_with_nodes(const std::vector &nodesOfSide) const; + void create_block_parts(const std::vector &elementBlockIDs); + +private: + stk::mesh::BulkData & mMesh; + AuxMetaData & mAuxMeta; + Phase_Support & mPhaseSupport; + const stk::ParallelMachine mComm; + std::vector mAssignedGlobalNodeIdsforAllNodes; + std::vector mOwnedElems; + + void declare_coordinates(); + + void create_boundary_sides(); + void create_block_boundary_sides(); + + void set_node_coordinates(const stk::mesh::Entity node, const stk::math::Vector3d &newLoc); + stk::mesh::Entity create_node(const stk::math::Vector3d &loc, const std::vector &sharingProcs, stk::mesh::EntityId nodeId); + stk::mesh::Entity create_element(const std::vector &nodes, stk::mesh::EntityId elementId, unsigned blockId); + + std::vector create_parallel_nodes(const std::vector>& nodeLocs, + const std::map> &nodeIndicesWithSharingProcs, + const std::vector & assignedGlobalNodeIdsforAllNodes); + + std::vector create_parallel_elements(const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &elementProcOwners, + const std::vector& nodesWhichAreValidIfTheyExistOnProc); + + std::map> build_node_sharing_procs(const std::vector> &elementConn, + const std::vector &elementProcOwners) const; + std::map> build_node_sharing_procs_for_all_nodes_on_all_procs(const unsigned numNodes, const unsigned numProcs) const; +}; + +} + +#endif /* KRINO_KRINO_UNIT_TESTS_AKRI_STKMESHBUILDER_HPP_ */ diff --git a/packages/krino/krino/unit_tests/Akri_StkMeshFixture.hpp b/packages/krino/krino/unit_tests/Akri_StkMeshFixture.hpp new file mode 100644 index 000000000000..3403920ce406 --- /dev/null +++ b/packages/krino/krino/unit_tests/Akri_StkMeshFixture.hpp @@ -0,0 +1,62 @@ +#ifndef KRINO_KRINO_UNIT_TESTS_AKRI_STKMESHFIXTURE_HPP_ +#define KRINO_KRINO_UNIT_TESTS_AKRI_STKMESHFIXTURE_HPP_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace krino +{ + +template +class StkMeshFixture : public ::testing::Test +{ +protected: + static constexpr int NPE = DIM+1; + static constexpr unsigned theBlockId = 1; + const stk::ParallelMachine mComm = MPI_COMM_WORLD; + const int mProc{stk::parallel_machine_rank(mComm)}; + std::unique_ptr mMeshPtr{stk::mesh::MeshBuilder(mComm).set_spatial_dimension(DIM).create()}; + stk::mesh::BulkData & mMesh{*mMeshPtr}; + StkMeshBuilder mBuilder{mMesh, mComm}; + + const std::vector & get_assigned_node_global_ids() const { return mBuilder.get_assigned_node_global_ids(); } + stk::mesh::Entity get_assigned_node_for_index(const size_t nodeIndex) const { return mMesh.get_entity(stk::topology::NODE_RANK, get_assigned_node_global_ids()[nodeIndex]); } + const std::vector & get_owned_elements() const { return mBuilder.get_owned_elements(); } + + template + void build_full_np1_mesh(const MeshSpecType &meshSpec) + { + build_mesh(meshSpec.mNodeLocs, {meshSpec.mAllTetConn}); + } + + void build_mesh(const std::vector> &nodeLocs, + const std::vector>> &elemConnPerProc) + { + mMesh.mesh_meta_data().use_simple_fields(); + mBuilder.build_mesh(nodeLocs, elemConnPerProc, theBlockId); + } + + void build_mesh(const std::vector> &nodeLocs, + const std::vector> &elementConn, + const std::vector &elementBlockIDs, + const std::vector &specifiedElementProcOwners = {}) + { + mMesh.mesh_meta_data().use_simple_fields(); + mBuilder.create_block_parts(elementBlockIDs); + mBuilder.build_mesh(nodeLocs, elementConn, elementBlockIDs, specifiedElementProcOwners); + } +}; + +typedef StkMeshFixture<3> StkMeshTetFixture; +typedef StkMeshFixture<2> StkMeshTriFixture; + +} + + +#endif /* KRINO_KRINO_UNIT_TESTS_AKRI_STKMESHFIXTURE_HPP_ */ diff --git a/packages/krino/krino/unit_tests/Akri_UnitTestUtils.cpp b/packages/krino/krino/unit_tests/Akri_UnitTestUtils.cpp index fbb85740d585..ab4dd88ee740 100644 --- a/packages/krino/krino/unit_tests/Akri_UnitTestUtils.cpp +++ b/packages/krino/krino/unit_tests/Akri_UnitTestUtils.cpp @@ -6,14 +6,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +#include #include #include namespace krino { -void expect_eq(const Vector3d & gold, const Vector3d & result, const double relativeTol=1.e-6) +void expect_eq(const Vector3d & gold, const Vector3d & result, const double relativeTol) { const double absoluteTol = relativeTol * (gold.length() + result.length()); + expect_eq_absolute(gold, result, absoluteTol); +} + +void expect_eq_absolute(const Vector3d & gold, const Vector3d & result, const double absoluteTol) +{ for (int i=0; i<3; ++i) EXPECT_NEAR(gold[i], result[i], absoluteTol) <<"gold: " << gold << " actual:" << result; } diff --git a/packages/krino/krino/unit_tests/Akri_UnitTestUtils.hpp b/packages/krino/krino/unit_tests/Akri_UnitTestUtils.hpp index ebb0e3b7dbe0..429b57ba7d27 100644 --- a/packages/krino/krino/unit_tests/Akri_UnitTestUtils.hpp +++ b/packages/krino/krino/unit_tests/Akri_UnitTestUtils.hpp @@ -13,6 +13,7 @@ namespace krino { void expect_eq(const Vector3d & gold, const Vector3d & result, const double relativeTol=1.e-6); +void expect_eq_absolute(const Vector3d & gold, const Vector3d & result, const double absoluteTol=1.e-6); } diff --git a/packages/krino/krino/unit_tests/Akri_Unit_Analytic_CDMesh.cpp b/packages/krino/krino/unit_tests/Akri_Unit_Analytic_CDMesh.cpp index eded15aa0c89..34543003a2a0 100644 --- a/packages/krino/krino/unit_tests/Akri_Unit_Analytic_CDMesh.cpp +++ b/packages/krino/krino/unit_tests/Akri_Unit_Analytic_CDMesh.cpp @@ -82,7 +82,12 @@ class AnalyticDecompositionFixture : public ::testing::Test { NodeToCapturedDomainsMap nodesToCapturedDomains; if (cdfemSupport.get_cdfem_edge_degeneracy_handling() == SNAP_TO_INTERFACE_WHEN_QUALITY_ALLOWS_THEN_SNAP_TO_NODE) - nodesToCapturedDomains = snap_as_much_as_possible_while_maintaining_quality(krino_mesh->stk_bulk(), krino_mesh->get_active_part(), cdfemSupport.get_interpolation_fields(), interfaceGeometry, cdfemSupport.get_global_ids_are_parallel_consistent()); + nodesToCapturedDomains = snap_as_much_as_possible_while_maintaining_quality(krino_mesh->stk_bulk(), + krino_mesh->get_active_part(), + cdfemSupport.get_interpolation_fields(), + interfaceGeometry, + cdfemSupport.get_global_ids_are_parallel_consistent(), + cdfemSupport.get_snapping_sharp_feature_angle_in_degrees()); interfaceGeometry.prepare_to_process_elements(krino_mesh->stk_bulk(), nodesToCapturedDomains); if(!krino_mesh->my_old_mesh) diff --git a/packages/krino/krino/unit_tests/Akri_Unit_CDMesh.cpp b/packages/krino/krino/unit_tests/Akri_Unit_CDMesh.cpp index 184e973dfbb9..664754051685 100644 --- a/packages/krino/krino/unit_tests/Akri_Unit_CDMesh.cpp +++ b/packages/krino/krino/unit_tests/Akri_Unit_CDMesh.cpp @@ -587,7 +587,14 @@ class CompleteDecompositionFixture : public ::testing::Test NodeToCapturedDomainsMap nodesToSnappedDomains; std::unique_ptr interfaceGeometry = create_levelset_geometry(krino_mesh->get_active_part(), cdfemSupport, Phase_Support::get(fixture.meta_data()), ls_policy.ls_fields()); if (cdfemSupport.get_cdfem_edge_degeneracy_handling() == SNAP_TO_INTERFACE_WHEN_QUALITY_ALLOWS_THEN_SNAP_TO_NODE) - nodesToSnappedDomains = snap_as_much_as_possible_while_maintaining_quality(krino_mesh->stk_bulk(), krino_mesh->get_active_part(), cdfemSupport.get_interpolation_fields(), *interfaceGeometry, cdfemSupport.get_global_ids_are_parallel_consistent()); + { + nodesToSnappedDomains = snap_as_much_as_possible_while_maintaining_quality(krino_mesh->stk_bulk(), + krino_mesh->get_active_part(), + cdfemSupport.get_interpolation_fields(), + *interfaceGeometry, + cdfemSupport.get_global_ids_are_parallel_consistent(), + cdfemSupport.get_snapping_sharp_feature_angle_in_degrees()); + } interfaceGeometry->prepare_to_process_elements(krino_mesh->stk_bulk(), nodesToSnappedDomains); if(!krino_mesh->my_old_mesh) diff --git a/packages/krino/krino/unit_tests/Akri_Unit_CurvatureLeastSquares.cpp b/packages/krino/krino/unit_tests/Akri_Unit_CurvatureLeastSquares.cpp new file mode 100644 index 000000000000..594a523442fe --- /dev/null +++ b/packages/krino/krino/unit_tests/Akri_Unit_CurvatureLeastSquares.cpp @@ -0,0 +1,164 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include + +#include +#include +#include "Akri_UnitTestUtils.hpp" + +namespace krino +{ + +class NodePatchInterface +{ +public: + virtual const std::vector & get_halo_node_locations() const = 0; + virtual const std::vector>& get_halo_segments() const = 0; +}; + +void test_rotation(const Vector3d & normal) +{ + static const Vector3d zDir(0.,0.,1.); + + std::array,3> rotationMatrix; + set_rotation_matrix_for_rotating_normal_to_zDir(rotationMatrix, normal); + + const Vector3d rotatedNormal = rotate_3d_vector(rotationMatrix, normal); + + expect_eq(zDir, rotatedNormal); + + const Vector3d reverseRotatedZdir = reverse_rotate_3d_vector(rotationMatrix, zDir); + + expect_eq(normal, reverseRotatedZdir); +} + +TEST(CurvatureLeastSquares, rotationTest) +{ + test_rotation(Vector3d(0.,0.,1.)); + test_rotation(Vector3d(1.,0.,0.)); + test_rotation(Vector3d(0.,1.,0.)); + + const double cos45 = std::sqrt(2.)/2.; + test_rotation(Vector3d(cos45, cos45, 0.)); + test_rotation(Vector3d(cos45, 0., cos45)); + test_rotation(Vector3d(0., cos45, cos45)); +} + +class PolygonalPatchOnSphere : public NodePatchInterface +{ +public: + PolygonalPatchOnSphere(const Vector3d & normalDir, const double halfCurvature, const int numHaloPts) + { + const double radius = 2.0/halfCurvature; + + std::array,3> rotationMatrix; + set_rotation_matrix_for_rotating_normal_to_zDir(rotationMatrix, normalDir); + + const Vector3d unrotatedNodeLoc(0.,0.,radius); + myNodeLoc = reverse_rotate_3d_vector(rotationMatrix, unrotatedNodeLoc); + + const double phi = 15.*M_PI/180.; + myHaloNodeLocs.reserve(numHaloPts); + myHaloSegments.reserve(numHaloPts); + const double dTheta = 2.*M_PI/numHaloPts; + for (int i=0; i & get_halo_node_locations() const override { return myHaloNodeLocs; } + virtual const std::vector> & get_halo_segments() const override { return myHaloSegments; } + +private: + Vector3d myNodeLoc; + std::vector myHaloNodeLocs; + std::vector> myHaloSegments; +}; + +class PolygonalPatchOnPlane : public NodePatchInterface +{ +public: + PolygonalPatchOnPlane(const Vector3d & normalDir, const int numHaloPts) + { + std::array,3> rotationMatrix; + set_rotation_matrix_for_rotating_normal_to_zDir(rotationMatrix, normalDir); + + myHaloNodeLocs.reserve(numHaloPts); + myHaloSegments.reserve(numHaloPts); + const double dTheta = 2.*M_PI/numHaloPts; + for (int i=0; i & get_halo_node_locations() const override { return myHaloNodeLocs; } + virtual const std::vector> & get_halo_segments() const override { return myHaloSegments; } + +private: + std::vector myHaloNodeLocs; + std::vector> myHaloSegments; +}; + +void test_flat_triangle_patch_with_normal_gives_zero_normalCurvature(const Vector3d & normal) +{ + PolygonalPatchOnPlane patch(normal, 3); + + const Vector3d normalCurvature = compute_least_squares_curvature_times_normal(patch.get_halo_node_locations(), patch.get_halo_segments()); + expect_eq_absolute(Vector3d::ZERO, normalCurvature, 1.e-6); +} + +TEST(CurvatureLeastSquares, Flat3TrianglePatches_zeroNormalCurvature) +{ + test_flat_triangle_patch_with_normal_gives_zero_normalCurvature(Vector3d(0.,0.,1.)); + test_flat_triangle_patch_with_normal_gives_zero_normalCurvature(Vector3d(1.,0.,0.)); + test_flat_triangle_patch_with_normal_gives_zero_normalCurvature(Vector3d(0.,1.,0.)); + + const double cos45 = std::sqrt(2.)/2.; + test_flat_triangle_patch_with_normal_gives_zero_normalCurvature(Vector3d(cos45, cos45, 0.)); + test_flat_triangle_patch_with_normal_gives_zero_normalCurvature(Vector3d(cos45, 0., cos45)); +} + +void test_normalCurvature_for_curved_patch(const Vector3d & normalDir, const double curvature, const int numHaloNodes) +{ + const Vector3d normal = normalDir.unit_vector(); + PolygonalPatchOnSphere patch(normal, curvature, numHaloNodes); + const Vector3d goldNormalCurvature = curvature*normal; + + const Vector3d normalCurvature = compute_least_squares_curvature_times_normal(patch.get_halo_node_locations(), patch.get_halo_segments()); + expect_eq(goldNormalCurvature, normalCurvature, 1.e-2); +} + +void test_normalCurvature_for_curved_patches(const int numHaloNodes) +{ + const double curvature = 0.1; + test_normalCurvature_for_curved_patch(Vector3d(0.,0.,1.), curvature, numHaloNodes); + test_normalCurvature_for_curved_patch(Vector3d(1.,0.,0.), curvature, numHaloNodes); + test_normalCurvature_for_curved_patch(Vector3d(0.,1.,0.), curvature, numHaloNodes); + + const double cos45 = std::sqrt(2.)/2.; + test_normalCurvature_for_curved_patch(Vector3d(cos45, cos45, 0.), curvature, numHaloNodes); + test_normalCurvature_for_curved_patch(Vector3d(cos45, 0., cos45), curvature, numHaloNodes); +} + +TEST(CurvatureLeastSquares, CurvedPatchesOfVariousSizes_correctNormalCurvature) +{ + // curvature only fit + test_normalCurvature_for_curved_patches(3); + test_normalCurvature_for_curved_patches(4); + + // curvature and normal fit + test_normalCurvature_for_curved_patches(5); + test_normalCurvature_for_curved_patches(7); +} + +} // namespace krino diff --git a/packages/krino/krino/unit_tests/Akri_Unit_Explicit_Hamilton_Jacobi.cpp b/packages/krino/krino/unit_tests/Akri_Unit_Explicit_Hamilton_Jacobi.cpp index 33e73c71a2e4..02e0c482180c 100644 --- a/packages/krino/krino/unit_tests/Akri_Unit_Explicit_Hamilton_Jacobi.cpp +++ b/packages/krino/krino/unit_tests/Akri_Unit_Explicit_Hamilton_Jacobi.cpp @@ -46,18 +46,18 @@ stk::mesh::BulkData & read_mesh(stk::io::StkMeshIoBroker & stkIo) void declare_fields(stk::mesh::MetaData & meta, ProblemFields & fields) { - fields.levelSetField = &meta.declare_field>(stk::topology::NODE_RANK, "LevelSet", 2); + fields.levelSetField = &meta.declare_field(stk::topology::NODE_RANK, "LevelSet", 2); stk::mesh::put_field_on_mesh(*fields.levelSetField, meta.universal_part(), nullptr); - fields.RHS = &meta.declare_field>(stk::topology::NODE_RANK, "RHS", 1); + fields.RHS = &meta.declare_field(stk::topology::NODE_RANK, "RHS", 1); stk::mesh::put_field_on_mesh(*fields.RHS, meta.universal_part(), nullptr); - fields.RHSNorm = &meta.declare_field>(stk::topology::NODE_RANK, "RHSNorm", 1); + fields.RHSNorm = &meta.declare_field(stk::topology::NODE_RANK, "RHSNorm", 1); stk::mesh::put_field_on_mesh(*fields.RHSNorm, meta.universal_part(), nullptr); auto constCoordsField = static_cast*>(meta.coordinate_field()); fields.coordsField = const_cast*>(constCoordsField); if (true) { - fields.speedField = &meta.declare_field>(stk::topology::ELEMENT_RANK, "Speed", 1); + fields.speedField = &meta.declare_field(stk::topology::ELEMENT_RANK, "Speed", 1); stk::mesh::put_field_on_mesh(*fields.speedField, meta.universal_part(), nullptr); } } diff --git a/packages/krino/krino/unit_tests/Akri_Unit_MeshHelpers.cpp b/packages/krino/krino/unit_tests/Akri_Unit_MeshHelpers.cpp index a7a04758b117..2b5de30e7b6c 100644 --- a/packages/krino/krino/unit_tests/Akri_Unit_MeshHelpers.cpp +++ b/packages/krino/krino/unit_tests/Akri_Unit_MeshHelpers.cpp @@ -8,6 +8,8 @@ #include +#include +#include #include #include #include @@ -104,6 +106,13 @@ void test_and_cleanup_internal_side(stk::mesh::BulkData & mesh, const stk::mesh: } } +auto create_2D_mesh(const stk::ParallelMachine & pm) +{ + std::unique_ptr bulk = stk::mesh::MeshBuilder(pm).set_spatial_dimension(2).create(); + bulk->mesh_meta_data().use_simple_fields(); + return bulk; +} + TEST(MeshHelpers, DeclareElementSide) { stk::ParallelMachine pm = MPI_COMM_WORLD; @@ -118,9 +127,10 @@ TEST(MeshHelpers, DeclareElementSide) * 1---2---3 */ - unsigned spatialDim = 2; - stk::mesh::MetaData meta(spatialDim); - stk::mesh::BulkData mesh(meta, pm); + + auto meshPtr = create_2D_mesh(pm); + stk::mesh::BulkData& mesh = *meshPtr; + stk::mesh::MetaData& meta = mesh.mesh_meta_data(); stk::mesh::Part& block_1 = meta.declare_part_with_topology("block_1", stk::topology::QUAD_4_2D); stk::mesh::Part& block_2 = meta.declare_part_with_topology("block_2", stk::topology::QUAD_4_2D); @@ -205,9 +215,9 @@ TEST(MeshHelpers, FullyCoincidentVolumeElements) // ranks 0 and 1 will have any elements. We test larger number of processors to ensure that // we get a parallel-consistent result to avoid potential parallel hangs in the full app. - unsigned spatialDim = 2; - stk::mesh::MetaData meta(spatialDim); - stk::mesh::BulkData mesh(meta, pm); + auto meshPtr = create_2D_mesh(pm); + stk::mesh::BulkData& mesh = *meshPtr; + stk::mesh::MetaData& meta = mesh.mesh_meta_data(); stk::mesh::Part& block_1 = meta.declare_part_with_topology("block_1", stk::topology::QUAD_4_2D); stk::mesh::Part& active_part = meta.declare_part("active"); @@ -231,9 +241,9 @@ TEST(MeshHelpers, PartiallyCoincidentActiveVolumeElements) // This test will create a two element mesh (quad4 elements) on 1 or 2 processors. - unsigned spatialDim = 2; - stk::mesh::MetaData meta(spatialDim); - stk::mesh::BulkData mesh(meta, pm); + auto meshPtr = create_2D_mesh(pm); + stk::mesh::BulkData& mesh = *meshPtr; + stk::mesh::MetaData& meta = mesh.mesh_meta_data(); stk::mesh::Part& block_1 = meta.declare_part_with_topology("block_1", stk::topology::QUAD_4_2D); stk::mesh::Part& active_part = meta.declare_part("active"); @@ -257,9 +267,9 @@ TEST(MeshHelpers, NotCoincidentActiveDegenerateVolumeElements) // This test will create a two element mesh (quad4 elements) on 1 or 2 processors. - unsigned spatialDim = 2; - stk::mesh::MetaData meta(spatialDim); - stk::mesh::BulkData mesh(meta, pm); + auto meshPtr = create_2D_mesh(pm); + stk::mesh::BulkData& mesh = *meshPtr; + stk::mesh::MetaData& meta = mesh.mesh_meta_data(); stk::mesh::Part& block_1 = meta.declare_part_with_topology("block_1", stk::topology::QUAD_4_2D); stk::mesh::Part& active_part = meta.declare_part("active"); diff --git a/packages/krino/krino/unit_tests/Akri_Unit_RebalanceUtils.cpp b/packages/krino/krino/unit_tests/Akri_Unit_RebalanceUtils.cpp index 76e36e1822de..c489a261b2c9 100644 --- a/packages/krino/krino/unit_tests/Akri_Unit_RebalanceUtils.cpp +++ b/packages/krino/krino/unit_tests/Akri_Unit_RebalanceUtils.cpp @@ -50,10 +50,8 @@ void create_block_and_register_fields(SimpleStkFixture & fixture) auto & meta = fixture.meta_data(); meta.declare_part_with_topology("block_1", stk::topology::TRIANGLE_3_2D); - meta.declare_field > - (stk::topology::NODE_RANK, "coordinates"); - auto & load_field = - meta.declare_field>(stk::topology::ELEMENT_RANK, "element_weights"); + meta.declare_field(stk::topology::NODE_RANK, "coordinates"); + auto & load_field = meta.declare_field(stk::topology::ELEMENT_RANK, "element_weights"); stk::mesh::put_field_on_mesh(load_field, meta.universal_part(), nullptr); fixture.commit(); @@ -359,16 +357,13 @@ TEST(Rebalance, MultipleWeightFields) stk::mesh::Part & block_1 = meta.declare_part_with_topology("block_1", stk::topology::QUAD_4_2D); stk::mesh::Part & block_2 = meta.declare_part_with_topology("block_2", stk::topology::QUAD_4_2D); - auto & coords_field = meta.declare_field>( - stk::topology::NODE_RANK, "coordinates"); - stk::mesh::put_field_on_mesh(coords_field, meta.universal_part(), nullptr); + auto & coords_field = meta.declare_field(stk::topology::NODE_RANK, "coordinates"); + stk::mesh::put_field_on_mesh(coords_field, meta.universal_part(), 2, nullptr); - auto & weights_field_1 = meta.declare_field>( - stk::topology::ELEMENT_RANK, "element_weights_1"); + auto & weights_field_1 = meta.declare_field(stk::topology::ELEMENT_RANK, "element_weights_1"); stk::mesh::put_field_on_mesh(weights_field_1, block_1, nullptr); - auto & weights_field_2 = meta.declare_field>( - stk::topology::ELEMENT_RANK, "element_weights_2"); + auto & weights_field_2 = meta.declare_field(stk::topology::ELEMENT_RANK, "element_weights_2"); stk::mesh::put_field_on_mesh(weights_field_2, block_2, nullptr); meta.commit(); diff --git a/packages/krino/krino/unit_tests/Akri_Unit_Single_Element_Fixtures.hpp b/packages/krino/krino/unit_tests/Akri_Unit_Single_Element_Fixtures.hpp index f3d969b484ba..fb0e9fa223af 100644 --- a/packages/krino/krino/unit_tests/Akri_Unit_Single_Element_Fixtures.hpp +++ b/packages/krino/krino/unit_tests/Akri_Unit_Single_Element_Fixtures.hpp @@ -9,6 +9,7 @@ #ifndef AKRI_UNIT_SINGLE_ELEMENT_FIXTURES_H_ #define AKRI_UNIT_SINGLE_ELEMENT_FIXTURES_H_ +#include #include // for BulkData #include // for MetaData @@ -27,22 +28,25 @@ inline std::vector entity_rank_names_with_ft() class SimpleStkFixture { public: - SimpleStkFixture(unsigned dimension, MPI_Comm comm = MPI_COMM_WORLD) - : meta(dimension, entity_rank_names_with_ft()), - bulk(meta, comm) - { - meta.set_mesh_bulk_data(&bulk); - AuxMetaData::create(meta); + SimpleStkFixture(unsigned dimension, MPI_Comm comm = MPI_COMM_WORLD) { + bulk = stk::mesh::MeshBuilder(comm) + .set_spatial_dimension(dimension) + .set_entity_rank_names(entity_rank_names_with_ft()) + .create(); + + meta = bulk->mesh_meta_data_ptr(); + meta->use_simple_fields(); + AuxMetaData::create(*meta); } - void commit() { meta.commit(); } - void write_results(const std::string & filename) { write_results(filename, bulk); } + void commit() { meta->commit(); } + void write_results(const std::string & filename) { write_results(filename, *bulk); } static void write_results(const std::string & filename, stk::mesh::BulkData & mesh, const bool use64bitIds = true); - stk::mesh::MetaData & meta_data() { return meta; } - stk::mesh::BulkData & bulk_data() { return bulk; } + stk::mesh::MetaData & meta_data() { return *meta; } + stk::mesh::BulkData & bulk_data() { return *bulk; } private: - stk::mesh::MetaData meta; - stk::mesh::BulkData bulk; + std::shared_ptr meta; + std::unique_ptr bulk; }; class SimpleStkFixture2d : public SimpleStkFixture diff --git a/packages/krino/krino/unit_tests/Akri_Unit_Snap.cpp b/packages/krino/krino/unit_tests/Akri_Unit_Snap.cpp new file mode 100644 index 000000000000..402653829453 --- /dev/null +++ b/packages/krino/krino/unit_tests/Akri_Unit_Snap.cpp @@ -0,0 +1,555 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace krino { + +struct RegularTri +{ + RegularTri() = default; + static constexpr int DIM = 2; + std::vector nodeLocs + {{ + { 0.000, 0.000 }, + { 1.000, 0.000 }, + { 0.500, std::sqrt(3.)/2. }, + }}; + + std::array TriConn{{0, 1, 2}}; + std::vector> allElementConn{TriConn}; +}; + +struct Tri306090 +{ + Tri306090() = default; + static constexpr int DIM = 2; + std::vector nodeLocs + {{ + { 0.000, 0.000 }, + { 0.500, 0.000 }, + { 0.000, std::sqrt(3.)/2. }, + }}; + + std::array TriConn{{0, 1, 2}}; + std::vector> allElementConn{TriConn}; +}; + +struct TwoTri306090 +{ + TwoTri306090() = default; + static constexpr int DIM = 2; + std::vector nodeLocs + {{ + { 0.000, 0.000 }, + { 0.500, 0.000 }, + { 0.000, std::sqrt(3.)/2. }, + {-0.500, 0.000 } + }}; + + std::array Tri1Conn{{0, 1, 2}}; + std::array Tri2Conn{{0, 2, 3}}; + std::vector> allElementConn{Tri1Conn, Tri2Conn}; +}; + +struct RegularTet +{ + RegularTet() = default; + static constexpr int DIM = 3; + std::vector nodeLocs + {{ + { 0.5, 0.0, -0.5/std::sqrt(2.) }, + {-0.5, 0.0, -0.5/std::sqrt(2.) }, + { 0.0, -0.5, 0.5/std::sqrt(2.) }, + { 0.0, 0.5, 0.5/std::sqrt(2.) }, + }}; + + std::array TetConn{{0, 1, 2, 3}}; + std::vector> allElementConn{TetConn}; +}; + +struct RightTet +{ + RightTet() = default; + static constexpr int DIM = 3; + std::vector nodeLocs + {{ + { 0.0, 0.0, 0.0 }, + { 1.0, 0.0, 0.0 }, + { 0.0, 1.0, 0.0 }, + { 0.0, 0.0, 1.0 }, + }}; + + std::array TetConn{{0, 1, 2, 3}}; + std::vector> allElementConn{TetConn}; +}; + +struct FourRightTets +{ + FourRightTets() = default; + static constexpr int DIM = 3; + std::vector nodeLocs + {{ + { 0.0, 0.0, 0.0 }, + { 1.0, 0.0, 0.0 }, + { 0.0, 1.0, 0.0 }, + {-1.0, 0.0, 0.0 }, + { 0.0,-1.0, 0.0 }, + { 0.0, 0.0, 1.0 }, + }}; + + std::array Tet1Conn{{0, 1, 2, 5}}; + std::array Tet2Conn{{0, 2, 3, 5}}; + std::array Tet3Conn{{0, 3, 4, 5}}; + std::array Tet4Conn{{0, 4, 1, 5}}; + std::vector> allElementConn{Tet1Conn, Tet2Conn, Tet3Conn, Tet4Conn}; +}; + +struct TwoRightTets +{ + TwoRightTets() = default; + static constexpr int DIM = 3; + std::vector nodeLocs + {{ + { 0.0, 0.0, 0.0 }, + { 1.0, 0.0, 0.0 }, + { 0.0, 1.0, 0.0 }, + {-1.0, 0.0, 0.0 }, + { 0.0, 0.0, 1.0 }, + }}; + + std::array Tet1Conn{{0, 1, 2, 4}}; + std::array Tet2Conn{{0, 2, 3, 4}}; + std::vector> allElementConn{Tet1Conn, Tet2Conn}; +}; + +struct TwoRightTris +{ + TwoRightTris() = default; + static constexpr int DIM = 2; + std::vector nodeLocs + {{ + { 0.0, 0.0 }, + { 1.0, 0.0 }, + { 0.0, 1.0 }, + {-1.0, 0.0 }, + }}; + + std::array Tri1Conn{{0, 1, 2}}; + std::array Tri2Conn{{0, 2, 3}}; + std::vector> allElementConn{Tri1Conn, Tri2Conn}; +}; + +class RegularTriWithSides : public StkMeshTriFixture +{ +protected: + void create_sides_and_build_mesh(const std::vector &sidesetIds) + { + mBuilder.create_sideset_parts(sidesetIds); + + RegularTri meshSpec; + build_mesh(meshSpec.nodeLocs, {meshSpec.allElementConn}); + } + + stk::mesh::Entity get_side_1() const { return mBuilder.get_side_with_nodes({get_assigned_node_for_index(0), get_assigned_node_for_index(1)}); } + stk::mesh::Entity get_side_2() const { return mBuilder.get_side_with_nodes({get_assigned_node_for_index(0), get_assigned_node_for_index(2)}); } + + void expect_which_snaps_are_allowed(const std::vector & goldWhichSnapsAreAllowed, const std::vector & intersectionNodeIndices) + { + std::vector intersectionNodes; + for (auto intersectionNodeIndex : intersectionNodeIndices) + intersectionNodes.push_back(get_assigned_node_for_index(intersectionNodeIndex)); + + const std::vector whichSnapsAreAllowed = which_intersection_point_nodes_are_compatible_for_snapping(mMesh, mBuilder.get_aux_meta(), mBuilder.get_phase_support(), intersectionNodes); + EXPECT_EQ(goldWhichSnapsAreAllowed, whichSnapsAreAllowed); + } +}; + +TEST_F(RegularTriWithSides, triMeshWithNoSidesets_attemptSnapToIntPointOnSide_snapsAllowed) +{ + if(stk::parallel_machine_size(mComm) == 1) + { + create_sides_and_build_mesh({}); + + expect_which_snaps_are_allowed({true,true}, {1,2}); + } +} + +TEST_F(RegularTriWithSides, triMeshWithOneSidesetOnOneSide_attemptSnapToIntPointOnThirdSide_oneSnapAllowed) +{ + if(stk::parallel_machine_size(mComm) == 1) + { + const unsigned sideset1Id = 1; + create_sides_and_build_mesh({sideset1Id}); + + mBuilder.add_sides_to_sidesets({get_side_1()}, {{sideset1Id}}); + + expect_which_snaps_are_allowed({false,true}, {1,2}); + } +} + +TEST_F(RegularTriWithSides, triMeshWithTwoSidesetOnTwoSides_attemptSnapToIntPointOnThirdSide_noSnapAllowed) +{ + if(stk::parallel_machine_size(mComm) == 1) + { + const unsigned sideset1Id = 1; + const unsigned sideset2Id = 2; + create_sides_and_build_mesh({sideset1Id, sideset2Id}); + + mBuilder.add_sides_to_sidesets({get_side_1(), get_side_2()}, {{sideset1Id},{sideset2Id}}); + + expect_which_snaps_are_allowed({false,false}, {1,2}); + } +} + +TEST_F(RegularTriWithSides, triMeshWithOneSidesetOnTwoSides_attemptSnapToIntPointOnThirdSide_noSnapAllowed) +{ + // This is the sideset keyhole problem. + if(stk::parallel_machine_size(mComm) == 1) + { + const unsigned sideset1Id = 1; + create_sides_and_build_mesh({sideset1Id}); + + mBuilder.add_sides_to_sidesets({get_side_1(), get_side_2()}, {{sideset1Id},{sideset1Id}}); + + expect_which_snaps_are_allowed({false,false}, {1,2}); + } +} + +TEST_F(RegularTriWithSides, triMeshWithOneSidesetOnTwoSides_attemptSnapToIntPointOnVolume_noSnapAllowed) +{ + // This is a volume intersection point version of the keyhole problem. + if(stk::parallel_machine_size(mComm) == 1) + { + const unsigned sideset1Id = 1; + create_sides_and_build_mesh({sideset1Id}); + + mBuilder.add_sides_to_sidesets({get_side_1(), get_side_2()}, {{sideset1Id},{sideset1Id}}); + + expect_which_snaps_are_allowed({false,false,false}, {0,1,2}); + } +} + +template +class SharpFeatureFixture : public StkMeshFixture +{ +protected: + using StkMeshFixture::mMesh; + + void find_sharp_features() + { + mySharpFeatureInfo.find_sharp_features(mMesh, mMesh.mesh_meta_data().coordinate_field(), mMesh.mesh_meta_data().universal_part(), myCosFeatureAngle); + } + + void build_mesh_and_find_sharp_features() + { + this->build_mesh(meshSpec.nodeLocs, {meshSpec.allElementConn}); + find_sharp_features(); + } + + void test_is_node_pinned(const stk::mesh::Entity node, const bool goldIsNodePinned) + { + const SharpFeatureConstraint * constraint = mySharpFeatureInfo.get_constraint(node); + if (goldIsNodePinned) + { + EXPECT_TRUE(constraint != nullptr && constraint->is_pinned()); + } + else + { + EXPECT_TRUE(constraint == nullptr || !constraint->is_pinned()); + } + } + + bool is_node_in_assigned_nodes_for_indices(const stk::mesh::BulkData & mesh, const stk::mesh::Entity node, const std::vector & nodeIndices) + { + for (auto nodeIndex : nodeIndices) + if (this->get_assigned_node_for_index(nodeIndex) == node) + return true; + return false; + } + + void test_are_nodes_pinned(const std::vector & goldPinnedNodeIndices) + { + std::vector ownedNodes; + stk::mesh::get_selected_entities( mMesh.mesh_meta_data().locally_owned_part(), mMesh.buckets( stk::topology::NODE_RANK ), ownedNodes ); + + for (auto && node : ownedNodes) + { + const bool goldIsNodePinned = is_node_in_assigned_nodes_for_indices(mMesh, node, goldPinnedNodeIndices); + test_is_node_pinned(node, goldIsNodePinned); + } + } + + void test_is_node_constrained_on_edge(const unsigned sharpEdgeNodeIndex, const std::array & goldSharpEdgeNodeNbrIndices) + { + stk::mesh::Entity sharpEdgeNode = this->get_assigned_node_for_index(sharpEdgeNodeIndex); + if (mMesh.is_valid(sharpEdgeNode) && mMesh.parallel_owner_rank(sharpEdgeNode) == mMesh.parallel_rank()) + { + const SharpFeatureConstraint * constraint = mySharpFeatureInfo.get_constraint(sharpEdgeNode); + ASSERT_TRUE(constraint != nullptr && constraint->is_constrained_on_edge()); + const std::array sharpEdgeNodes = constraint->get_sharp_edge_nodes(); + for (auto goldSharpEdgeNodeNbrIndex : goldSharpEdgeNodeNbrIndices) + { + stk::mesh::Entity goldSharpEdgeNodeNbr = this->get_assigned_node_for_index(goldSharpEdgeNodeNbrIndex); + EXPECT_TRUE(sharpEdgeNodes[0] == goldSharpEdgeNodeNbr || sharpEdgeNodes[1] == goldSharpEdgeNodeNbr); + } + } + } + + void test_are_all_nodes_pinned() + { + std::vector ownedNodes; + stk::mesh::get_selected_entities( mMesh.mesh_meta_data().locally_owned_part(), mMesh.buckets( stk::topology::NODE_RANK ), ownedNodes ); + + for (auto && node : ownedNodes) + test_is_node_pinned(node, true); + } + + MESHSPEC meshSpec; + double myCosFeatureAngle{std::cos(M_PI/180.*135.0)}; + SharpFeatureInfo mySharpFeatureInfo; +}; + +typedef SharpFeatureFixture SharpFeatureRegularTetFixture; + +TEST_F(SharpFeatureRegularTetFixture, meshWithAllNodesOnCorners_allNodesArePinned) +{ + if(stk::parallel_machine_size(mComm) == 1) + { + build_mesh_and_find_sharp_features(); + + test_are_all_nodes_pinned(); + } +} + +typedef SharpFeatureFixture SharpFeatureRightTetFixture; + +TEST_F(SharpFeatureRightTetFixture, meshWithAllNodesOnCorners_allNodesArePinned) +{ + if(stk::parallel_machine_size(mComm) == 1) + { + build_mesh_and_find_sharp_features(); + + test_are_all_nodes_pinned(); + } +} + +typedef SharpFeatureFixture SharpFeatureRegularTriFixture; + +TEST_F(SharpFeatureRegularTriFixture, meshWithAllNodesOnCorners_allNodesArePinned) +{ + if(stk::parallel_machine_size(mComm) == 1) + { + build_mesh_and_find_sharp_features(); + + test_are_all_nodes_pinned(); + } +} + +typedef SharpFeatureFixture SharpFeatureTri306090Fixture; + +TEST_F(SharpFeatureTri306090Fixture, meshWithAllNodesOnCorners_allNodesArePinned) +{ + if(stk::parallel_machine_size(mComm) == 1) + { + build_mesh_and_find_sharp_features(); + + test_are_all_nodes_pinned(); + } +} + +typedef SharpFeatureFixture SharpFeatureTwoTri306090Fixture; + +TEST_F(SharpFeatureTwoTri306090Fixture, meshWithCornerNodesAndUnconstrainedNode_constraintsAreCorrect) +{ + if(stk::parallel_machine_size(mComm) <= 2) + { + if(stk::parallel_machine_size(mComm) == 1) + this->build_mesh(meshSpec.nodeLocs, {meshSpec.allElementConn}); + else if(stk::parallel_machine_size(mComm) == 2) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,1}, {0,1}); + + find_sharp_features(); + + const std::vector goldPinnedNodeIndices{1,2,3}; + test_are_nodes_pinned(goldPinnedNodeIndices); + } +} + +typedef SharpFeatureFixture SharpFeatureFourRightTetsFixture; + +TEST_F(SharpFeatureFourRightTetsFixture, meshWithCornerNodesAndUnconstrainedNode_constraintsAreCorrect) +{ + if(stk::parallel_machine_size(mComm) <= 4) + { + if(stk::parallel_machine_size(mComm) == 1) + this->build_mesh(meshSpec.nodeLocs, {meshSpec.allElementConn}); + else if(stk::parallel_machine_size(mComm) == 2) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,1,1,1}, {0,0,1,1}); + else if(stk::parallel_machine_size(mComm) == 3) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,1,1,1}, {0,1,2,2}); + else if(stk::parallel_machine_size(mComm) == 4) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,1,1,1}, {0,1,2,3}); + + find_sharp_features(); + + const std::vector goldPinnedNodeIndices{1,2,3,4,5}; + test_are_nodes_pinned(goldPinnedNodeIndices); + } +} + +typedef SharpFeatureFixture SharpFeatureTwoRightTetsFixture; + +TEST_F(SharpFeatureTwoRightTetsFixture, meshWithCornerNodesAndEdgeNode_constraintsAreCorrect) +{ + if(stk::parallel_machine_size(mComm) <= 2) + { + if(stk::parallel_machine_size(mComm) == 1) + this->build_mesh(meshSpec.nodeLocs, {meshSpec.allElementConn}); + else if(stk::parallel_machine_size(mComm) == 2) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,1}, {0,1}); + + find_sharp_features(); + + const std::vector goldPinnedNodeIndices{1,2,3,4}; + test_are_nodes_pinned(goldPinnedNodeIndices); + + test_is_node_constrained_on_edge(0, {{1,3}}); + } +} + +template +class VolumePreservingSnappingLimiterFixture : public StkMeshFixture +{ +protected: + using StkMeshFixture::mMesh; + + Vector3d compute_snap_location(const std::vector & snapNodeIndices, const std::vector & snapNodeWeights) + { + stk::math::Vector3d snapLocation = stk::math::Vector3d::ZERO; + for (size_t i=0; iget_assigned_node_for_index(snapNodeIndices[i]); + const stk::math::Vector3d nodeLocation(field_data(*mMesh.mesh_meta_data().coordinate_field(), snapNode), mMesh.mesh_meta_data().spatial_dimension()); + snapLocation += snapNodeWeights[i] * nodeLocation; + } + return snapLocation; + } + + void test_is_snap_allowed_based_on_volume_change(const bool goldIsSnapAllowed, const unsigned nodeIndex, const std::vector & snapNodeIndices, const std::vector & snapNodeWeights) + { + stk::mesh::Entity node = this->get_assigned_node_for_index(nodeIndex); + if (mMesh.is_valid(node) && mMesh.parallel_owner_rank(node) == mMesh.parallel_rank()) + { + EXPECT_EQ(goldIsSnapAllowed, myVolumePreservingSnappingLimiter->is_snap_allowed(node, compute_snap_location(snapNodeIndices, snapNodeWeights))); + } + } + + VolumePreservingSnappingLimiter::ElementToBlockConverter build_element_to_block_converter() + { + auto converter = [](const stk::mesh::BulkData & mesh, const stk::mesh::Entity elem) + { + for (auto && part : mesh.bucket(elem).supersets()) + if (part->primary_entity_rank() == stk::topology::ELEMENT_RANK && !stk::mesh::is_auto_declared_part(*part)) + return part; + stk::mesh::Part * blockPart = nullptr; + return blockPart; + }; + return converter; + } + + void setup_volume_preserving_snapping_limiter() + { + myVolumePreservingSnappingLimiter = std::make_unique(mMesh, *mMesh.mesh_meta_data().coordinate_field(), build_element_to_block_converter(), myVolumeConservationTol); + } + + MESHSPEC meshSpec; + double myVolumeConservationTol{0.05}; + std::unique_ptr myVolumePreservingSnappingLimiter; +}; + +typedef VolumePreservingSnappingLimiterFixture VolumePreservingSnappingLimiterTwoRightTetsFixture; + +TEST_F(VolumePreservingSnappingLimiterTwoRightTetsFixture, meshWithOneBlockWithCornerNodesAndEdgeNode_constraintsAreCorrect) +{ + if(stk::parallel_machine_size(mComm) <= 2) + { + if(stk::parallel_machine_size(mComm) == 1) + this->build_mesh(meshSpec.nodeLocs, {meshSpec.allElementConn}); + else if(stk::parallel_machine_size(mComm) == 2) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,1}, {0,1}); + + setup_volume_preserving_snapping_limiter(); + + test_is_snap_allowed_based_on_volume_change(true, 1, {1,2}, {0.99,0.01}); + test_is_snap_allowed_based_on_volume_change(false, 1, {1,2}, {0.5,0.5}); + + test_is_snap_allowed_based_on_volume_change(true, 0, {0,1}, {0.5,0.5}); + test_is_snap_allowed_based_on_volume_change(true, 0, {0,3}, {0.5,0.5}); + } +} + +TEST_F(VolumePreservingSnappingLimiterTwoRightTetsFixture, meshWithTwoBlockWithCornerNodesAndEdgeNode_constraintsAreCorrect) +{ + if(stk::parallel_machine_size(mComm) <= 2) + { + if(stk::parallel_machine_size(mComm) == 1) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,2}, {0,0}); + else if(stk::parallel_machine_size(mComm) == 2) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,2}, {0,1}); + + setup_volume_preserving_snapping_limiter(); + + test_is_snap_allowed_based_on_volume_change(false, 0, {0,1}, {0.5,0.5}); + test_is_snap_allowed_based_on_volume_change(false, 0, {0,3}, {0.5,0.5}); + } +} + +typedef VolumePreservingSnappingLimiterFixture VolumePreservingSnappingLimiterTwoRightTrisFixture; + +TEST_F(VolumePreservingSnappingLimiterTwoRightTrisFixture, meshWithOneBlockWithCornerNodesAndEdgeNode_constraintsAreCorrect) +{ + if(stk::parallel_machine_size(mComm) <= 2) + { + if(stk::parallel_machine_size(mComm) == 1) + this->build_mesh(meshSpec.nodeLocs, {meshSpec.allElementConn}); + else if(stk::parallel_machine_size(mComm) == 2) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,1}, {0,1}); + + setup_volume_preserving_snapping_limiter(); + + test_is_snap_allowed_based_on_volume_change(true, 1, {1,2}, {0.99,0.01}); + test_is_snap_allowed_based_on_volume_change(false, 1, {1,2}, {0.5,0.5}); + + test_is_snap_allowed_based_on_volume_change(true, 0, {0,1}, {0.5,0.5}); + test_is_snap_allowed_based_on_volume_change(true, 0, {0,3}, {0.5,0.5}); + } +} + +TEST_F(VolumePreservingSnappingLimiterTwoRightTrisFixture, meshWithTwoBlockWithCornerNodesAndEdgeNode_constraintsAreCorrect) +{ + if(stk::parallel_machine_size(mComm) <= 2) + { + if(stk::parallel_machine_size(mComm) == 1) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,2}, {0,0}); + else if(stk::parallel_machine_size(mComm) == 2) + this->build_mesh(meshSpec.nodeLocs, meshSpec.allElementConn, {1,2}, {0,1}); + + setup_volume_preserving_snapping_limiter(); + + test_is_snap_allowed_based_on_volume_change(false, 0, {0,1}, {0.5,0.5}); + test_is_snap_allowed_based_on_volume_change(false, 0, {0,3}, {0.5,0.5}); + } +} + +} diff --git a/packages/krino/krino/unit_tests/Akri_Unit_main.cpp b/packages/krino/krino/unit_tests/Akri_Unit_main.cpp index fef305ede6b5..75f831ae3fde 100644 --- a/packages/krino/krino/unit_tests/Akri_Unit_main.cpp +++ b/packages/krino/krino/unit_tests/Akri_Unit_main.cpp @@ -26,7 +26,7 @@ int main(int argc, char **argv) { Kokkos::ScopeGuard guard(argc, argv); - stk::unit_test_util::create_parallel_output(sierra::Env::parallel_rank()); + stk::unit_test_util::simple_fields::create_parallel_output(sierra::Env::parallel_rank()); return RUN_ALL_TESTS(); } diff --git a/packages/muelu/adapters/tpetra/MueLu_CreateTpetraPreconditioner.hpp b/packages/muelu/adapters/tpetra/MueLu_CreateTpetraPreconditioner.hpp index db1c7777acb2..e41be8434d6a 100644 --- a/packages/muelu/adapters/tpetra/MueLu_CreateTpetraPreconditioner.hpp +++ b/packages/muelu/adapters/tpetra/MueLu_CreateTpetraPreconditioner.hpp @@ -171,6 +171,25 @@ namespace MueLu { MueLu::ReuseXpetraPreconditioner(A, H); } + template + void ReuseTpetraPreconditioner(const Teuchos::RCP >& inA, + MueLu::TpetraOperator& Op) { + typedef Scalar SC; + typedef LocalOrdinal LO; + typedef GlobalOrdinal GO; + typedef Node NO; + + typedef Xpetra::Matrix Matrix; + typedef MueLu ::Hierarchy Hierarchy; + + RCP H = Op.GetHierarchy(); + RCP > temp = rcp(new Xpetra::TpetraBlockCrsMatrix(inA)); + TEUCHOS_TEST_FOR_EXCEPTION(temp==Teuchos::null, Exceptions::RuntimeError, "ReuseTpetraPreconditioner: cast from Tpetra::BlockCrsMatrix to Xpetra::TpetraBlockCrsMatrix failed."); + RCP A = rcp(new Xpetra::CrsMatrixWrap(temp)); + + MueLu::ReuseXpetraPreconditioner(A, H); + } + } //namespace diff --git a/packages/muelu/src/Graph/Containers/MueLu_Aggregates_def.hpp b/packages/muelu/src/Graph/Containers/MueLu_Aggregates_def.hpp index 4ad46301f9df..7e5daf357256 100644 --- a/packages/muelu/src/Graph/Containers/MueLu_Aggregates_def.hpp +++ b/packages/muelu/src/Graph/Containers/MueLu_Aggregates_def.hpp @@ -191,12 +191,12 @@ namespace MueLu { for(LO i=0; i; using device_type = DeviceType; using range_type = Kokkos::RangePolicy; + using LO_view = Kokkos::View; using aggregates_sizes_type = Kokkos::View; @@ -259,6 +260,12 @@ namespace MueLu { local_graph_type GetGraph() const; + /*! @brief Generates a compressed list of nodes in each aggregate, where + the entries in aggNodes[aggPtr[i]] up to aggNodes[aggPtr[i+1]-1] contain the nodes in aggregate i. + unaggregated contains the list of nodes which are, for whatever reason, not aggregated (e.g. Dirichlet) + */ + void ComputeNodesInAggregate(LO_view & aggPtr, LO_view & aggNodes, LO_view & unaggregated) const; + //! @name Overridden from Teuchos::Describable //@{ diff --git a/packages/muelu/src/Graph/Containers/MueLu_Aggregates_kokkos_def.hpp b/packages/muelu/src/Graph/Containers/MueLu_Aggregates_kokkos_def.hpp index 80827cb480e8..f8cba41cfaee 100644 --- a/packages/muelu/src/Graph/Containers/MueLu_Aggregates_kokkos_def.hpp +++ b/packages/muelu/src/Graph/Containers/MueLu_Aggregates_kokkos_def.hpp @@ -180,6 +180,58 @@ namespace MueLu { return graph_; } + + template + void + Aggregates_kokkos >::ComputeNodesInAggregate(LO_view & aggPtr, LO_view & aggNodes, LO_view & unaggregated) const { + LO numAggs = GetNumAggregates(); + LO numNodes = vertex2AggId_->getLocalLength(); + auto vertex2AggId = vertex2AggId_->getDeviceLocalView(Xpetra::Access::ReadOnly); + typename aggregates_sizes_type::const_type aggSizes = ComputeAggregateSizes(true); + LO INVALID = Teuchos::OrdinalTraits::invalid(); + + aggPtr = LO_view("aggPtr",numAggs+1); + aggNodes = LO_view("aggNodes",numNodes); + LO_view aggCurr("agg curr",numAggs+1); + + // Construct the "rowptr" and the counter + Kokkos::parallel_scan("MueLu:Aggregates:ComputeNodesInAggregate:scan", range_type(0,numAggs+1), + KOKKOS_LAMBDA(const LO aggIdx, LO& aggOffset, bool final_pass) { + LO count = 0; + if(aggIdx < numAggs) + count = aggSizes(aggIdx); + if(final_pass) { + aggPtr(aggIdx) = aggOffset; + aggCurr(aggIdx) = aggOffset; + if(aggIdx==numAggs) + aggCurr(numAggs) = 0; // use this for counting unaggregated nodes + } + aggOffset += count; + }); + + // Preallocate unaggregated to the correct size + LO numUnaggregated = 0; + Kokkos::parallel_reduce("MueLu:Aggregates:ComputeNodesInAggregate:unaggregatedSize", range_type(0,numNodes), + KOKKOS_LAMBDA(const LO nodeIdx, LO & count) { + if(vertex2AggId(nodeIdx,0)==INVALID) + count++; + }, numUnaggregated); + unaggregated = LO_view("unaggregated",numUnaggregated); + + // Stick the nodes in each aggregate's spot + Kokkos::parallel_for("MueLu:Aggregates:ComputeNodesInAggregate:for", range_type(0,numNodes), + KOKKOS_LAMBDA(const LO nodeIdx) { + LO aggIdx = vertex2AggId(nodeIdx,0); + if(aggIdx != INVALID) { + // atomic postincrement aggCurr(aggIdx) each time + aggNodes(Kokkos::atomic_fetch_add(&aggCurr(aggIdx),1)) = nodeIdx; + } else { + // same, but using last entry of aggCurr for unaggregated nodes + unaggregated(Kokkos::atomic_fetch_add(&aggCurr(numAggs),1)) = nodeIdx; + } + }); + + } template std::string Aggregates_kokkos >::description() const { diff --git a/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_def.hpp b/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_def.hpp index 549b00075d81..9b4f08644349 100644 --- a/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_def.hpp +++ b/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_def.hpp @@ -75,11 +75,30 @@ namespace MueLu { RCP A = Get< RCP >(currentLevel, "A"); + /* NOTE: storageblocksize (from GetStorageBlockSize()) is the size of a block in the chosen storage scheme. + fullblocksize is the number of storage blocks that must kept together during the amalgamation process. + + Both of these quantities may be different than numPDEs (from GetFixedBlockSize()), but the following must always hold: + + numPDEs = fullblocksize * storageblocksize. + + If numPDEs==1 + Matrix is point storage (classical CRS storage). storageblocksize=1 and fullblocksize=1 + No other values makes sense. + + If numPDEs>1 + If matrix uses point storage, then storageblocksize=1 and fullblockssize=numPDEs. + If matrix uses block storage, with block size of n, then storageblocksize=n, and fullblocksize=numPDEs/n. + Thus far, only storageblocksize=numPDEs and fullblocksize=1 has been tested. + */ + + LO fullblocksize = 1; // block dim for fixed size blocks GO offset = 0; // global offset of dof gids LO blockid = -1; // block id in strided map LO nStridedOffset = 0; // DOF offset for strided block id "blockid" (default = 0) LO stridedblocksize = fullblocksize; // size of strided block id "blockid" (default = fullblocksize, only if blockid!=-1 stridedblocksize <= fullblocksize) + LO storageblocksize = A->GetStorageBlockSize(); // GO indexBase = A->getRowMap()->getIndexBase(); // index base for maps (unused) // 1) check for blocking/striding information @@ -101,6 +120,12 @@ namespace MueLu { } else { stridedblocksize = fullblocksize; } + // Correct for the storageblocksize + // NOTE: Before this point fullblocksize is actually numPDEs + TEUCHOS_TEST_FOR_EXCEPTION(fullblocksize % storageblocksize != 0,Exceptions::RuntimeError,"AmalgamationFactory: fullblocksize needs to be a multiple of A->GetStorageBlockSize()"); + fullblocksize /= storageblocksize; + stridedblocksize /= storageblocksize; + oldView = A->SwitchToView(oldView); GetOStream(Runtime1) << "AmalagamationFactory::Build():" << " found fullblocksize=" << fullblocksize << " and stridedblocksize=" << stridedblocksize << " from strided maps. offset=" << offset << std::endl; @@ -108,6 +133,7 @@ namespace MueLu { GetOStream(Warnings0) << "AmalagamationFactory::Build(): no striding information available. Use blockdim=1 with offset=0" << std::endl; } + // build node row map (uniqueMap) and node column map (nonUniqueMap) // the arrays rowTranslation and colTranslation contain the local node id // given a local dof id. They are only necessary for the CoalesceDropFactory if @@ -166,7 +192,7 @@ namespace MueLu { container filter; GO offset = 0; - LO blkSize = A.GetFixedBlockSize(); + LO blkSize = A.GetFixedBlockSize() / A.GetStorageBlockSize(); if (A.IsView("stridedMaps") == true) { Teuchos::RCP myMap = A.getRowMap("stridedMaps"); Teuchos::RCP strMap = Teuchos::rcp_dynamic_cast(myMap); diff --git a/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_kokkos_def.hpp b/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_kokkos_def.hpp index e945469b5ba1..cc284df2e27b 100644 --- a/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_kokkos_def.hpp +++ b/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationFactory_kokkos_def.hpp @@ -80,11 +80,29 @@ namespace MueLu { RCP A = Get< RCP >(currentLevel, "A"); + /* NOTE: storageblocksize (from GetStorageBlockSize()) is the size of a block in the chosen storage scheme. + fullblocksize is the number of storage blocks that must kept together during the amalgamation process. + + Both of these quantities may be different than numPDEs (from GetFixedBlockSize()), but the following must always hold: + + numPDEs = fullblocksize * storageblocksize. + + If numPDEs==1 + Matrix is point storage (classical CRS storage). storageblocksize=1 and fullblocksize=1 + No other values makes sense. + + If numPDEs>1 + If matrix uses point storage, then storageblocksize=1 and fullblockssize=numPDEs. + If matrix uses block storage, with block size of n, then storageblocksize=n, and fullblocksize=numPDEs/n. + Thus far, only storageblocksize=numPDEs and fullblocksize=1 has been tested. + */ + LO fullblocksize = 1; // block dim for fixed size blocks GO offset = 0; // global offset of dof gids LO blockid = -1; // block id in strided map LO nStridedOffset = 0; // DOF offset for strided block id "blockid" (default = 0) LO stridedblocksize = fullblocksize; // size of strided block id "blockid" (default = fullblocksize, only if blockid!=-1 stridedblocksize <= fullblocksize) + LO storageblocksize = A->GetStorageBlockSize(); // GO indexBase = A->getRowMap()->getIndexBase(); // index base for maps (unused) // 1) check for blocking/striding information @@ -106,6 +124,12 @@ namespace MueLu { } else { stridedblocksize = fullblocksize; } + // Correct for the storageblocksize + // NOTE: Before this point fullblocksize is actually numPDEs + TEUCHOS_TEST_FOR_EXCEPTION(fullblocksize % storageblocksize != 0,Exceptions::RuntimeError,"AmalgamationFactory::Build(): fullblocksize needs to be a multiple of A->GetStorageBlockSize()"); + fullblocksize /= storageblocksize; + stridedblocksize /= storageblocksize; + oldView = A->SwitchToView(oldView); GetOStream(Runtime1) << "AmalagamationFactory::Build():" << " found fullblocksize=" << fullblocksize << " and stridedblocksize=" << stridedblocksize << " from strided maps. offset=" << offset << std::endl; @@ -172,7 +196,7 @@ namespace MueLu { container filter; // TODO: replace std::set with an object having faster lookup/insert, hashtable for instance GO offset = 0; - LO blkSize = A.GetFixedBlockSize(); + LO blkSize = A.GetFixedBlockSize() / A.GetStorageBlockSize(); if (A.IsView("stridedMaps") == true) { Teuchos::RCP myMap = A.getRowMap("stridedMaps"); Teuchos::RCP strMap = Teuchos::rcp_dynamic_cast(myMap); diff --git a/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationInfo_def.hpp b/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationInfo_def.hpp index 9bd4b73d1169..edfad670c279 100644 --- a/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationInfo_def.hpp +++ b/packages/muelu/src/Graph/MatrixTransformation/MueLu_AmalgamationInfo_def.hpp @@ -132,7 +132,7 @@ namespace MueLu { template void AmalgamationInfo::UnamalgamateAggregatesLO(const Aggregates& aggregates, - Teuchos::ArrayRCP& aggStart, Teuchos::ArrayRCP& aggToRowMap) const { + Teuchos::ArrayRCP& aggStart, Teuchos::ArrayRCP& aggToRowMap) const { int myPid = aggregates.GetMap()->getComm()->getRank(); Teuchos::ArrayView nodeGlobalElts = aggregates.GetMap()->getLocalElementList(); diff --git a/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_def.hpp b/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_def.hpp index 6c2241c12b8d..4d5f035069e1 100644 --- a/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_def.hpp +++ b/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_def.hpp @@ -370,7 +370,7 @@ namespace MueLu { distanceLaplacianAlgo = scaled_cut_symmetric; else TEUCHOS_TEST_FOR_EXCEPTION(true, Exceptions::RuntimeError, "\"aggregation: distance laplacian algo\" must be one of (default|unscaled cut|scaled cut), not \"" << distanceLaplacianAlgoStr << "\""); - GetOStream(Runtime0) << "algorithm = \"" << algo << "\" distance laplacian algorithm = \"" << distanceLaplacianAlgoStr << "\": threshold = " << threshold << ", blocksize = " << A->GetFixedBlockSize() << std::endl; + GetOStream(Runtime0) << "algorithm = \"" << algo << "\" distance laplacian algorithm = \"" << distanceLaplacianAlgoStr << "\": threshold = " << threshold << ", blocksize = " << A->GetFixedBlockSize()<< std::endl; } else if (algo == "classical") { if (classicalAlgoStr == "default") classicalAlgo = defaultAlgo; @@ -396,6 +396,27 @@ namespace MueLu { GO numDropped = 0, numTotal = 0; std::string graphType = "unamalgamated"; //for description purposes only + + /* NOTE: storageblocksize (from GetStorageBlockSize()) is the size of a block in the chosen storage scheme. + BlockSize is the number of storage blocks that must kept together during the amalgamation process. + + Both of these quantities may be different than numPDEs (from GetFixedBlockSize()), but the following must always hold: + + numPDEs = BlockSize * storageblocksize. + + If numPDEs==1 + Matrix is point storage (classical CRS storage). storageblocksize=1 and BlockSize=1 + No other values makes sense. + + If numPDEs>1 + If matrix uses point storage, then storageblocksize=1 and BlockSize=numPDEs. + If matrix uses block storage, with block size of n, then storageblocksize=n, and BlockSize=numPDEs/n. + Thus far, only storageblocksize=numPDEs and BlockSize=1 has been tested. + */ + TEUCHOS_TEST_FOR_EXCEPTION(A->GetFixedBlockSize() % A->GetStorageBlockSize() != 0,Exceptions::RuntimeError,"A->GetFixedBlockSize() needs to be a multiple of A->GetStorageBlockSize()"); + const LO BlockSize = A->GetFixedBlockSize() / A->GetStorageBlockSize(); + + /************************** RS or SA-style Classical Dropping (and variants) **************************/ if (algo == "classical") { if (predrop_ == null) { @@ -417,7 +438,7 @@ namespace MueLu { // At this points we either have // (predrop_ != null) // Therefore, it is sufficient to check only threshold - if (A->GetFixedBlockSize() == 1 && threshold == STS::zero() && !useSignedClassicalRS && !useSignedClassicalSA && A->hasCrsGraph()) { + if ( BlockSize==1 && threshold == STS::zero() && !useSignedClassicalRS && !useSignedClassicalSA && A->hasCrsGraph()) { // Case 1: scalar problem, no dropping => just use matrix graph RCP graph = rcp(new Graph(A->getCrsGraph(), "graph of A")); // Detect and record rows that correspond to Dirichlet boundary conditions @@ -442,10 +463,10 @@ namespace MueLu { Set(currentLevel, "DofsPerNode", 1); Set(currentLevel, "Graph", graph); - } else if ( (A->GetFixedBlockSize() == 1 && threshold != STS::zero()) || - (A->GetFixedBlockSize() == 1 && threshold == STS::zero() && !A->hasCrsGraph()) || - (A->GetFixedBlockSize() == 1 && useSignedClassicalRS) || - (A->GetFixedBlockSize() == 1 && useSignedClassicalSA) ) { + } else if ( (BlockSize == 1 && threshold != STS::zero()) || + (BlockSize == 1 && threshold == STS::zero() && !A->hasCrsGraph()) || + (BlockSize == 1 && useSignedClassicalRS) || + (BlockSize == 1 && useSignedClassicalSA) ) { // Case 2: scalar problem with dropping => record the column indices of undropped entries, but still use original // graph's map information, e.g., whether index is local // OR a matrix without a CrsGraph @@ -721,7 +742,7 @@ namespace MueLu { } #endif }//end generateColoringGraph - } else if (A->GetFixedBlockSize() > 1 && threshold == STS::zero()) { + } else if (BlockSize > 1 && threshold == STS::zero()) { // Case 3: Multiple DOF/node problem without dropping const RCP rowMap = A->getRowMap(); const RCP colMap = A->getColMap(); @@ -853,7 +874,7 @@ namespace MueLu { Set(currentLevel, "Graph", graph); Set(currentLevel, "DofsPerNode", blkSize); // full block size - } else if (A->GetFixedBlockSize() > 1 && threshold != STS::zero()) { + } else if (BlockSize > 1 && threshold != STS::zero()) { // Case 4: Multiple DOF/node problem with dropping const RCP rowMap = A->getRowMap(); const RCP colMap = A->getColMap(); diff --git a/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_kokkos_def.hpp b/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_kokkos_def.hpp index a3f394fe36b8..8f5a42e6d653 100644 --- a/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_kokkos_def.hpp +++ b/packages/muelu/src/Graph/MatrixTransformation/MueLu_CoalesceDropFactory_kokkos_def.hpp @@ -506,7 +506,27 @@ namespace MueLu { const MT zero = Teuchos::ScalarTraits::zero(); auto A = Get< RCP >(currentLevel, "A"); - LO blkSize = A->GetFixedBlockSize(); + + + /* NOTE: storageblocksize (from GetStorageBlockSize()) is the size of a block in the chosen storage scheme. + blkSize is the number of storage blocks that must kept together during the amalgamation process. + + Both of these quantities may be different than numPDEs (from GetFixedBlockSize()), but the following must always hold: + + numPDEs = blkSize * storageblocksize. + + If numPDEs==1 + Matrix is point storage (classical CRS storage). storageblocksize=1 and blkSize=1 + No other values makes sense. + + If numPDEs>1 + If matrix uses point storage, then storageblocksize=1 and blkSize=numPDEs. + If matrix uses block storage, with block size of n, then storageblocksize=n, and blkSize=numPDEs/n. + Thus far, only storageblocksize=numPDEs and blkSize=1 has been tested. + */ + + TEUCHOS_TEST_FOR_EXCEPTION(A->GetFixedBlockSize() % A->GetStorageBlockSize() != 0,Exceptions::RuntimeError,"A->GetFixedBlockSize() needs to be a multiple of A->GetStorageBlockSize()"); + LO blkSize = A->GetFixedBlockSize() / A->GetStorageBlockSize(); auto amalInfo = Get< RCP >(currentLevel, "UnAmalgamationInfo"); @@ -542,7 +562,7 @@ namespace MueLu { boundaryNodes = Utilities_kokkos::DetectDirichletRows(*A, dirichletThreshold); // Trivial LWGraph construction - graph = rcp(new LWGraph_kokkos(A->getLocalMatrixDevice().graph, A->getRowMap(), A->getColMap(), "graph of A")); + graph = rcp(new LWGraph_kokkos(A->getCrsGraph()->getLocalGraphDevice(), A->getRowMap(), A->getColMap(), "graph of A")); graph->getLocalLWGraph().SetBoundaryNodeMap(boundaryNodes); numTotal = A->getLocalNumEntries(); diff --git a/packages/muelu/src/Misc/MueLu_RAPFactory_def.hpp b/packages/muelu/src/Misc/MueLu_RAPFactory_def.hpp index 318da1790a21..1a95e3cd1c7f 100644 --- a/packages/muelu/src/Misc/MueLu_RAPFactory_def.hpp +++ b/packages/muelu/src/Misc/MueLu_RAPFactory_def.hpp @@ -56,6 +56,7 @@ #include #include #include +#include #include "MueLu_RAPFactory_decl.hpp" @@ -281,8 +282,7 @@ namespace MueLu { Xpetra::TripleMatrixMultiply:: MultiplyRAP(*P, doTranspose, *A, !doTranspose, *P, !doTranspose, *Ac, doFillComplete, doOptimizeStorage, labelstr+std::string("MueLu::R*A*P-implicit-")+levelstr.str(), - RAPparams); - + RAPparams); } else { RCP R = Get< RCP >(coarseLevel, "R"); Ac = MatrixFactory::Build(R->getRowMap(), Teuchos::as(0)); diff --git a/packages/muelu/src/MueCentral/MueLu_Hierarchy_def.hpp b/packages/muelu/src/MueCentral/MueLu_Hierarchy_def.hpp index 81751dab90ec..eda2f914e8a5 100644 --- a/packages/muelu/src/MueCentral/MueLu_Hierarchy_def.hpp +++ b/packages/muelu/src/MueCentral/MueLu_Hierarchy_def.hpp @@ -1245,9 +1245,10 @@ namespace MueLu { break; } - Xpetra::global_size_t nnz = Am->getGlobalNumEntries(); + LO storageblocksize=Am->GetStorageBlockSize(); + Xpetra::global_size_t nnz = Am->getGlobalNumEntries()*storageblocksize*storageblocksize; nnzPerLevel .push_back(nnz); - rowsPerLevel .push_back(Am->getGlobalNumRows()); + rowsPerLevel .push_back(Am->getGlobalNumRows()*storageblocksize); numProcsPerLevel.push_back(Am->getRowMap()->getComm()->getSize()); } @@ -1434,8 +1435,9 @@ namespace MueLu { } GetOStream(Runtime1) << "Replacing coordinate map" << std::endl; + TEUCHOS_TEST_FOR_EXCEPTION(A->GetFixedBlockSize() % A->GetStorageBlockSize() != 0, Exceptions::RuntimeError, "Hierarchy::ReplaceCoordinateMap: Storage block size does not evenly divide fixed block size"); - size_t blkSize = A->GetFixedBlockSize(); + size_t blkSize = A->GetFixedBlockSize() / A->GetStorageBlockSize(); RCP nodeMap = A->getRowMap(); if (blkSize > 1) { diff --git a/packages/muelu/src/Smoothers/MueLu_Ifpack2Smoother_def.hpp b/packages/muelu/src/Smoothers/MueLu_Ifpack2Smoother_def.hpp index 798209774e57..14357b2feba9 100644 --- a/packages/muelu/src/Smoothers/MueLu_Ifpack2Smoother_def.hpp +++ b/packages/muelu/src/Smoothers/MueLu_Ifpack2Smoother_def.hpp @@ -67,6 +67,7 @@ #include #include #include +#include #include #include @@ -234,16 +235,21 @@ namespace MueLu { if(Acrs.is_null()) throw std::runtime_error("Ifpack2Smoother: Cannot extract CrsMatrix from matrix A."); RCP At = rcp_dynamic_cast(Acrs); - if(At.is_null()) - throw std::runtime_error("Ifpack2Smoother: Cannot extract TpetraCrsMatrix from matrix A."); - - RCP > blockCrs = Tpetra::convertToBlockCrsMatrix(*At->getTpetra_CrsMatrix(),blocksize); - RCP blockCrs_as_crs = rcp(new TpetraBlockCrsMatrix(blockCrs)); - RCP blockWrap = rcp(new CrsMatrixWrap(blockCrs_as_crs)); - A_ = blockWrap; - this->GetOStream(Statistics0) << "Ifpack2Smoother: Using BlockCrsMatrix storage with blocksize "<::isTpetraBlockCrs(matA)) + throw std::runtime_error("Ifpack2Smoother: Cannot extract CrsMatrix or BlockCrsMatrix from matrix A."); + this->GetOStream(Statistics0) << "Ifpack2Smoother: Using (native) BlockCrsMatrix storage with blocksize "< > blockCrs = Tpetra::convertToBlockCrsMatrix(*At->getTpetra_CrsMatrix(),blocksize); + RCP blockCrs_as_crs = rcp(new TpetraBlockCrsMatrix(blockCrs)); + RCP blockWrap = rcp(new CrsMatrixWrap(blockCrs_as_crs)); + A_ = blockWrap; + this->GetOStream(Statistics0) << "Ifpack2Smoother: Using BlockCrsMatrix storage with blocksize "< @@ -155,6 +157,7 @@ namespace MueLu { SubFactoryMonitor sfm(*this, "BuildCoordinates", coarseLevel); RCP coarseCoordsFineMap = Get< RCP >(fineLevel, "coarseCoordinatesFineMap"); RCP coarseCoordsMap = Get< RCP >(fineLevel, "coarseCoordinatesMap"); + fineCoordinates = Get< RCP >(fineLevel, "Coordinates"); coarseCoordinates = Xpetra::MultiVectorFactory::Build(coarseCoordsFineMap, fineCoordinates->getNumVectors()); @@ -172,6 +175,7 @@ namespace MueLu { *out << "Fine and coarse coordinates have been loaded from the fine level and set on the coarse level." << std::endl; + if(interpolationOrder == 0) { SubFactoryMonitor sfm(*this, "BuildConstantP", coarseLevel); // Compute the prolongator using piece-wise constant interpolation @@ -222,8 +226,19 @@ namespace MueLu { RCP fineNullspace = Get< RCP > (fineLevel, "Nullspace"); RCP coarseNullspace = MultiVectorFactory::Build(P->getDomainMap(), fineNullspace->getNumVectors()); - P->apply(*fineNullspace, *coarseNullspace, Teuchos::TRANS, Teuchos::ScalarTraits::one(), - Teuchos::ScalarTraits::zero()); + + using helpers=Xpetra::Helpers; + if(helpers::isTpetraBlockCrs(A)) { + // FIXME: BlockCrs doesn't currently support transpose apply, so we have to do this the hard way + RCP Ptrans = Utilities::Transpose(*P); + Ptrans->apply(*fineNullspace, *coarseNullspace, Teuchos::NO_TRANS, Teuchos::ScalarTraits::one(), + Teuchos::ScalarTraits::zero()); + } + else { + P->apply(*fineNullspace, *coarseNullspace, Teuchos::TRANS, Teuchos::ScalarTraits::one(), + Teuchos::ScalarTraits::zero()); + } + Set(coarseLevel, "Nullspace", coarseNullspace); } @@ -257,19 +272,78 @@ namespace MueLu { *out << "Call prolongator constructor" << std::endl; - // Create the prolongator matrix and its associated objects - RCP dummyList = rcp(new ParameterList()); - P = rcp(new CrsMatrixWrap(prolongatorGraph, dummyList)); - RCP PCrs = rcp_dynamic_cast(P)->getCrsMatrix(); - PCrs->setAllToScalar(1.0); - PCrs->fillComplete(); + using helpers=Xpetra::Helpers; + if(helpers::isTpetraBlockCrs(A)) { +#ifdef HAVE_MUELU_TPETRA + SC one = Teuchos::ScalarTraits::one(); + SC zero = Teuchos::ScalarTraits::zero(); + LO NSDim = A->GetStorageBlockSize(); + + // Build the exploded Map + RCP BlockMap = prolongatorGraph->getDomainMap(); + Teuchos::ArrayView block_dofs = BlockMap->getLocalElementList(); + Teuchos::Array point_dofs(block_dofs.size()*NSDim); + for(LO i=0, ct=0; i PointMap = MapFactory::Build(BlockMap->lib(), + BlockMap->getGlobalNumElements() *NSDim, + point_dofs(), + BlockMap->getIndexBase(), + BlockMap->getComm()); + strideInfo[0] = A->GetFixedBlockSize(); + RCP stridedPointMap = StridedMapFactory::Build(PointMap, strideInfo); + + RCP > P_xpetra = Xpetra::CrsMatrixFactory::BuildBlock(prolongatorGraph, PointMap, A->getRangeMap(),NSDim); + RCP > P_tpetra = rcp_dynamic_cast >(P_xpetra); + if(P_tpetra.is_null()) throw std::runtime_error("BuildConstantP Matrix factory did not return a Tpetra::BlockCrsMatrix"); + RCP P_wrap = rcp(new CrsMatrixWrap(P_xpetra)); + + // NOTE: Assumes block-diagonal prolongation + Teuchos::Array temp(1); + Teuchos::ArrayView indices; + Teuchos::Array block(NSDim*NSDim, zero); + for(LO i=0; igetLocalNumRows(); i++) { + prolongatorGraph->getLocalRowView(i,indices); + for(LO j=0; j<(LO)indices.size();j++) { + temp[0] = indices[j]; + P_tpetra->replaceLocalValues(i,temp(),block()); + } + } + + P = P_wrap; + if (A->IsView("stridedMaps") == true) { + P->CreateView("stridedMaps", A->getRowMap("stridedMaps"), stridedPointMap); + } + else { + P->CreateView("stridedMaps", P->getRangeMap(), PointMap); + } + +#else + throw std::runtime_error("GeometricInteroplationFactory::Build(): BlockCrs requires Tpetra"); +#endif - // set StridingInformation of P - if (A->IsView("stridedMaps") == true) { - P->CreateView("stridedMaps", A->getRowMap("stridedMaps"), stridedDomainMap); - } else { - P->CreateView("stridedMaps", P->getRangeMap(), stridedDomainMap); } + else { + // Create the prolongator matrix and its associated objects + RCP dummyList = rcp(new ParameterList()); + P = rcp(new CrsMatrixWrap(prolongatorGraph, dummyList)); + RCP PCrs = rcp_dynamic_cast(P)->getCrsMatrix(); + PCrs->setAllToScalar(1.0); + PCrs->fillComplete(); + + // set StridingInformation of P + if (A->IsView("stridedMaps") == true) + P->CreateView("stridedMaps", A->getRowMap("stridedMaps"), stridedDomainMap); + else + P->CreateView("stridedMaps", P->getRangeMap(), stridedDomainMap); + } } // BuildConstantP @@ -293,7 +367,7 @@ namespace MueLu { // Compute 2^numDimensions using bit logic to avoid round-off errors const int numInterpolationPoints = 1 << numDimensions; - const int dofsPerNode = A->GetFixedBlockSize(); + const int dofsPerNode = A->GetFixedBlockSize()/ A->GetStorageBlockSize();; RCP dummyList = rcp(new ParameterList()); P = rcp(new CrsMatrixWrap(prolongatorGraph, dummyList)); diff --git a/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_decl.hpp b/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_decl.hpp index 91024e11abe2..b0e5f27a90e2 100644 --- a/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_decl.hpp +++ b/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_decl.hpp @@ -123,8 +123,8 @@ namespace MueLu{ //@} - private: void BuildConstantP(RCP& P, RCP& prolongatorGraph, RCP& A) const; + private: void BuildLinearP(RCP& A, RCP& prolongatorGraph, RCP& fineCoordinates, RCP& ghostCoordinates, diff --git a/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_def.hpp b/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_def.hpp index 5cbcdd71108f..e7c94590f77d 100644 --- a/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_def.hpp +++ b/packages/muelu/src/Transfers/GeneralGeometric/MueLu_GeometricInterpolationPFactory_kokkos_def.hpp @@ -53,6 +53,11 @@ #include "MueLu_Monitor.hpp" #include "MueLu_IndexManager_kokkos.hpp" +#ifdef HAVE_MUELU_TPETRA +#include "Xpetra_TpetraCrsMatrix.hpp" +#endif + + // Including this one last ensure that the short names of the above headers are defined properly #include "MueLu_GeometricInterpolationPFactory_kokkos_decl.hpp" @@ -236,21 +241,87 @@ namespace MueLu { StridedMapFactory::Build(prolongatorGraph->getDomainMap(), strideInfo); *out << "Call prolongator constructor" << std::endl; + using helpers=Xpetra::Helpers; + if(helpers::isTpetraBlockCrs(A)) { +#ifdef HAVE_MUELU_TPETRA + LO NSDim = A->GetStorageBlockSize(); + + // Build the exploded Map + // FIXME: Should look at doing this on device + RCP BlockMap = prolongatorGraph->getDomainMap(); + Teuchos::ArrayView block_dofs = BlockMap->getLocalElementList(); + Teuchos::Array point_dofs(block_dofs.size()*NSDim); + for(LO i=0, ct=0; i PointMap = MapFactory::Build(BlockMap->lib(), + BlockMap->getGlobalNumElements() *NSDim, + point_dofs(), + BlockMap->getIndexBase(), + BlockMap->getComm()); + strideInfo[0] = A->GetFixedBlockSize(); + RCP stridedPointMap = StridedMapFactory::Build(PointMap, strideInfo); + + RCP > P_xpetra = Xpetra::CrsMatrixFactory::BuildBlock(prolongatorGraph, PointMap, A->getRangeMap(),NSDim); + RCP > P_tpetra = rcp_dynamic_cast >(P_xpetra); + if(P_tpetra.is_null()) throw std::runtime_error("BuildConstantP: Matrix factory did not return a Tpetra::BlockCrsMatrix"); + RCP P_wrap = rcp(new CrsMatrixWrap(P_xpetra)); + + const LO stride = strideInfo[0]*strideInfo[0]; + const LO in_stride = strideInfo[0]; + typename CrsMatrix::local_graph_type localGraph = prolongatorGraph->getLocalGraphDevice(); + auto rowptr = localGraph.row_map; + auto indices = localGraph.entries; + auto values = P_tpetra->getTpetra_BlockCrsMatrix()->getValuesDeviceNonConst(); + + using ISC = typename Tpetra::BlockCrsMatrix::impl_scalar_type; + ISC one = Teuchos::ScalarTraits::one(); + + const Kokkos::TeamPolicy policy(prolongatorGraph->getLocalNumRows(), 1); + + Kokkos::parallel_for("MueLu:GeoInterpFact::BuildConstantP::fill", policy, + KOKKOS_LAMBDA(const typename Kokkos::TeamPolicy::member_type &thread) { + auto row = thread.league_rank(); + for(LO j = (LO)rowptr[row]; j<(LO) rowptr[row+1]; j++) { + LO block_offset = j*stride; + for(LO k=0; kIsView("stridedMaps") == true) { + P->CreateView("stridedMaps", A->getRowMap("stridedMaps"), stridedPointMap); + } + else { + P->CreateView("stridedMaps", P->getRangeMap(), PointMap); + } - // Create the prolongator matrix and its associated objects - RCP dummyList = rcp(new ParameterList()); - P = rcp(new CrsMatrixWrap(prolongatorGraph, dummyList)); - RCP PCrs = rcp_dynamic_cast(P)->getCrsMatrix(); - PCrs->setAllToScalar(1.0); - PCrs->fillComplete(); +#else + throw std::runtime_error("GeometricInteroplationFactory::BuildConstantP(): BlockCrs requires Tpetra"); +#endif - // set StridingInformation of P - if (A->IsView("stridedMaps") == true) { - P->CreateView("stridedMaps", A->getRowMap("stridedMaps"), stridedDomainMap); - } else { - P->CreateView("stridedMaps", P->getRangeMap(), stridedDomainMap); } - + else { + // Create the prolongator matrix and its associated objects + RCP dummyList = rcp(new ParameterList()); + P = rcp(new CrsMatrixWrap(prolongatorGraph, dummyList)); + RCP PCrs = rcp_dynamic_cast(P)->getCrsMatrix(); + PCrs->setAllToScalar(1.0); + PCrs->fillComplete(); + + // set StridingInformation of P + if (A->IsView("stridedMaps") == true) { + P->CreateView("stridedMaps", A->getRowMap("stridedMaps"), stridedDomainMap); + } else { + P->CreateView("stridedMaps", P->getRangeMap(), stridedDomainMap); + } + } + } // BuildConstantP template diff --git a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_decl.hpp b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_decl.hpp index 24673877c7e6..2cce10e7d3e3 100644 --- a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_decl.hpp +++ b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_decl.hpp @@ -50,6 +50,7 @@ #include #include +#include #include #include #include @@ -157,6 +158,8 @@ template coarseMap, RCP& Ptentative, RCP& coarseNullspace, const int levelID) const; void BuildPcoupled (RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace) const; + void BuildPuncoupledBlockCrs(RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, + RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace, const int levelID) const; mutable bool bTransferCoordinates_ = false; diff --git a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_def.hpp b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_def.hpp index c5d76496c439..eb9d0a49f021 100644 --- a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_def.hpp +++ b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_def.hpp @@ -49,7 +49,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -58,6 +60,12 @@ #include #include #include +#include + +#ifdef HAVE_MUELU_TPETRA +#include "Xpetra_TpetraBlockCrsMatrix.hpp" +//#include "Tpetra_BlockCrsMatrix.hpp" +#endif #include "MueLu_TentativePFactory_decl.hpp" @@ -71,6 +79,9 @@ #include "MueLu_PerfUtils.hpp" #include "MueLu_Utilities.hpp" + + + namespace MueLu { template @@ -157,7 +168,8 @@ namespace MueLu { Set > >(coarseLevel, "Node Comm", nodeComm); } - TEUCHOS_TEST_FOR_EXCEPTION(A->getRowMap()->getLocalNumElements() != fineNullspace->getMap()->getLocalNumElements(), + // NOTE: We check DomainMap here rather than RowMap because those are different for BlockCrs matrices + TEUCHOS_TEST_FOR_EXCEPTION( A->getDomainMap()->getLocalNumElements() != fineNullspace->getMap()->getLocalNumElements(), Exceptions::RuntimeError,"MueLu::TentativePFactory::MakeTentative: Size mismatch between A and Nullspace"); RCP Ptentative; @@ -225,10 +237,16 @@ namespace MueLu { } } - if (!aggregates->AggregatesCrossProcessors()) - BuildPuncoupled(A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace,coarseLevel.GetLevelID()); + if (!aggregates->AggregatesCrossProcessors()) { + if(Xpetra::Helpers::isTpetraBlockCrs(A)) { + BuildPuncoupledBlockCrs(A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace,coarseLevel.GetLevelID()); + } + else { + BuildPuncoupled(A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace,coarseLevel.GetLevelID()); + } + } else - BuildPcoupled (A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace); + BuildPcoupled(A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace); // If available, use striding information of fine level matrix A for range // map and coarseMap as domain map; otherwise use plain range map of @@ -258,12 +276,24 @@ namespace MueLu { template void TentativePFactory:: - BuildPuncoupled(RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, - RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace, const int levelID) const { - RCP rowMap = A->getRowMap(); - RCP colMap = A->getColMap(); + BuildPuncoupledBlockCrs(RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, + RCP coarsePointMap, RCP& Ptentative, RCP& coarseNullspace, const int levelID) const { +#ifdef HAVE_MUELU_TPETRA - const size_t numRows = rowMap->getLocalNumElements(); + /* This routine generates a BlockCrs P for a BlockCrs A. There are a few assumptions here, which meet the use cases we care about, but could + be generalized later, if we ever need to do so: + 1) Null space dimension === block size of matrix: So no elasticity right now + 2) QR is not supported: Under assumption #1, this shouldn't cause problems. + 3) Maps are "good": Aka the first chunk of the ColMap is the RowMap. + + These assumptions keep our code way simpler and still support the use cases we actually care about. + */ + + RCP rowMap = A->getRowMap(); + RCP rangeMap = A->getRangeMap(); + RCP colMap = A->getColMap(); + // const size_t numFinePointRows = rangeMap->getLocalNumElements(); + const size_t numFineBlockRows = rowMap->getLocalNumElements(); typedef Teuchos::ScalarTraits STS; typedef typename STS::magnitudeType Magnitude; @@ -275,7 +305,16 @@ namespace MueLu { const size_t NSDim = fineNullspace->getNumVectors(); ArrayRCP aggSizes = aggregates->ComputeAggregateSizes(); - + // Need to generate the coarse block map + // NOTE: We assume NSDim == block size here + // NOTE: We also assume that coarseMap has contiguous GIDs + //const size_t numCoarsePointRows = coarsePointMap->getLocalNumElements(); + const size_t numCoarseBlockRows = coarsePointMap->getLocalNumElements() / NSDim; + RCP coarseBlockMap = MapFactory::Build(coarsePointMap->lib(), + Teuchos::OrdinalTraits::invalid(), + numCoarseBlockRows, + coarsePointMap->getIndexBase(), + coarsePointMap->getComm()); // Sanity checking const ParameterList& pL = GetParameterList(); const bool &doQRStep = pL.get("tentative: calculate qr"); @@ -284,6 +323,8 @@ namespace MueLu { TEUCHOS_TEST_FOR_EXCEPTION(doQRStep && constantColSums,Exceptions::RuntimeError, "MueLu::TentativePFactory::MakeTentative: cannot use 'constant column sums' and 'calculate qr' at the same time"); + // The aggregates use the amalgamated column map, which in this case is what we want + // Aggregates map is based on the amalgamated column map // We can skip global-to-local conversion if LIDs in row map are // same as LIDs in column map @@ -299,394 +340,238 @@ namespace MueLu { if (goodMap) { amalgInfo->UnamalgamateAggregatesLO(*aggregates, aggStart, aggToRowMapLO); GetOStream(Runtime1) << "Column map is consistent with the row map, good." << std::endl; - } else { - amalgInfo->UnamalgamateAggregates(*aggregates, aggStart, aggToRowMapGO); - GetOStream(Warnings0) << "Column map is not consistent with the row map\n" - << "using GO->LO conversion with performance penalty" << std::endl; + throw std::runtime_error("TentativePFactory::PuncoupledBlockCrs: Inconsistent maps not currently supported"); } - - coarseNullspace = MultiVectorFactory::Build(coarseMap, NSDim); + + coarseNullspace = MultiVectorFactory::Build(coarsePointMap, NSDim); // Pull out the nullspace vectors so that we can have random access. ArrayRCP > fineNS (NSDim); ArrayRCP > coarseNS(NSDim); for (size_t i = 0; i < NSDim; i++) { fineNS[i] = fineNullspace->getData(i); - if (coarseMap->getLocalNumElements() > 0) + if (coarsePointMap->getLocalNumElements() > 0) coarseNS[i] = coarseNullspace->getDataNonConst(i); } - size_t nnzEstimate = numRows * NSDim; - - // Time to construct the matrix and fill in the values - Ptentative = rcp(new CrsMatrixWrap(rowMap, coarseMap, 0)); - RCP PtentCrs = rcp_dynamic_cast(Ptentative)->getCrsMatrix(); - + // BlockCrs requires that we build the (block) graph first, so let's do that... + // NOTE: Because we're assuming that the NSDim == BlockSize, we only have one + // block non-zero per row in the matrix; + RCP BlockGraph = CrsGraphFactory::Build(rowMap,coarseBlockMap,0); ArrayRCP iaPtent; ArrayRCP jaPtent; - ArrayRCP valPtent; - - PtentCrs->allocateAllValues(nnzEstimate, iaPtent, jaPtent, valPtent); - + BlockGraph->allocateAllIndices(numFineBlockRows, iaPtent, jaPtent); ArrayView ia = iaPtent(); ArrayView ja = jaPtent(); - ArrayView val = valPtent(); - - ia[0] = 0; - for (size_t i = 1; i <= numRows; i++) - ia[i] = ia[i-1] + NSDim; - for (size_t j = 0; j < nnzEstimate; j++) { - ja [j] = INVALID; - val[j] = zero; + for (size_t i = 0; i < numFineBlockRows; i++) { + ia[i] = i; + ja[i] = INVALID; } + ia[numCoarseBlockRows] = numCoarseBlockRows; - if (doQRStep) { - //////////////////////////////// - // Standard aggregate-wise QR // - //////////////////////////////// - for (GO agg = 0; agg < numAggs; agg++) { - LO aggSize = aggStart[agg+1] - aggStart[agg]; + for (GO agg = 0; agg < numAggs; agg++) { + LO aggSize = aggStart[agg+1] - aggStart[agg]; + Xpetra::global_size_t offset = agg; - Xpetra::global_size_t offset = agg*NSDim; + for (LO j = 0; j < aggSize; j++) { + // FIXME: Allow for bad maps + const LO localRow = aggToRowMapLO[aggStart[agg]+j]; + const size_t rowStart = ia[localRow]; + ja[rowStart] = offset; + } + } - // Extract the piece of the nullspace corresponding to the aggregate, and - // put it in the flat array, "localQR" (in column major format) for the - // QR routine. - Teuchos::SerialDenseMatrix localQR(aggSize, NSDim); - if (goodMap) { - for (size_t j = 0; j < NSDim; j++) - for (LO k = 0; k < aggSize; k++) - localQR(k,j) = fineNS[j][aggToRowMapLO[aggStart[agg]+k]]; - } else { - for (size_t j = 0; j < NSDim; j++) - for (LO k = 0; k < aggSize; k++) - localQR(k,j) = fineNS[j][rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+k])]; + // Compress storage (remove all INVALID, which happen when we skip zeros) + // We do that in-place + size_t ia_tmp = 0, nnz = 0; + for (size_t i = 0; i < numFineBlockRows; i++) { + for (size_t j = ia_tmp; j < ia[i+1]; j++) + if (ja[j] != INVALID) { + ja [nnz] = ja [j]; + nnz++; } + ia_tmp = ia[i+1]; + ia[i+1] = nnz; + } - // Test for zero columns - for (size_t j = 0; j < NSDim; j++) { - bool bIsZeroNSColumn = true; + if (rowMap->lib() == Xpetra::UseTpetra) { + // - Cannot resize for Epetra, as it checks for same pointers + // - Need to resize for Tpetra, as it check ().size() == ia[numRows] + // NOTE: these invalidate ja and val views + jaPtent .resize(nnz); + } - for (LO k = 0; k < aggSize; k++) - if (localQR(k,j) != zero) - bIsZeroNSColumn = false; + GetOStream(Runtime1) << "TentativePFactory : generating block graph" << std::endl; + BlockGraph->setAllIndices(iaPtent, jaPtent); - TEUCHOS_TEST_FOR_EXCEPTION(bIsZeroNSColumn == true, Exceptions::RuntimeError, - "MueLu::TentativePFactory::MakeTentative: fine level NS part has a zero column in NS column " << j); - } + // Managing labels & constants for ESFC + { + RCP FCparams; + if(pL.isSublist("matrixmatrix: kernel params")) + FCparams=rcp(new ParameterList(pL.sublist("matrixmatrix: kernel params"))); + else + FCparams= rcp(new ParameterList); + // By default, we don't need global constants for TentativeP, but we do want it for the graph + // if we're printing statistics, so let's leave it on for now. + FCparams->set("compute global constants",FCparams->get("compute global constants",true)); + std::string levelIDs = toString(levelID); + FCparams->set("Timer Label",std::string("MueLu::TentativeP-")+levelIDs); + RCP dummy_e; + RCP dummy_i; + BlockGraph->expertStaticFillComplete(coarseBlockMap,rowMap,dummy_i,dummy_e,FCparams); + } - // Calculate QR decomposition (standard) - // NOTE: Q is stored in localQR and R is stored in coarseNS - if (aggSize >= Teuchos::as(NSDim)) { + // Now let's make a BlockCrs Matrix + // NOTE: Assumes block size== NSDim + RCP > P_xpetra = Xpetra::CrsMatrixFactory::BuildBlock(BlockGraph, coarsePointMap, rangeMap,NSDim); + RCP > P_tpetra = rcp_dynamic_cast >(P_xpetra); + if(P_tpetra.is_null()) throw std::runtime_error("BuildPUncoupled: Matrix factory did not return a Tpetra::BlockCrsMatrix"); + RCP P_wrap = rcp(new CrsMatrixWrap(P_xpetra)); + + ///////////////////////////// + // "no-QR" option // + ///////////////////////////// + // Local Q factor is just the fine nullspace support over the current aggregate. + // Local R factor is the identity. + // NOTE: We're not going to do a QR here as we're assuming that blocksize == NSDim + // NOTE: "goodMap" case only + Teuchos::Array block(NSDim*NSDim, zero); + Teuchos::Array bcol(1); + + GetOStream(Runtime1) << "TentativePFactory : bypassing local QR phase" << std::endl; + for (LO agg = 0; agg < numAggs; agg++) { + bcol[0] = agg; + const LO aggSize = aggStart[agg+1] - aggStart[agg]; + Xpetra::global_size_t offset = agg*NSDim; + + // Process each row in the local Q factor + // NOTE: Blocks are in row-major order + for (LO j = 0; j < aggSize; j++) { + const LO localBlockRow = aggToRowMapLO[aggStart[agg]+j]; + + for (size_t r = 0; r < NSDim; r++) { + LO localPointRow = localBlockRow*NSDim + r; + for (size_t c = 0; c < NSDim; c++) + block[r*NSDim+c] = fineNS[c][localPointRow]; + } + // NOTE: Assumes columns==aggs and are ordered sequentially + P_tpetra->replaceLocalValues(localBlockRow,bcol(),block()); + + }//end aggSize + + for (size_t j = 0; j < NSDim; j++) + coarseNS[j][offset+j] = one; + + } //for (GO agg = 0; agg < numAggs; agg++) + + Ptentative = P_wrap; +#else + throw std::runtime_error("TentativePFactory::BuildPuncoupledBlockCrs: Requires Tpetra"); +#endif + } - if (NSDim == 1) { - // Only one nullspace vector, calculate Q and R by hand - Magnitude norm = STS::magnitude(zero); - for (size_t k = 0; k < Teuchos::as(aggSize); k++) - norm += STS::magnitude(localQR(k,0)*localQR(k,0)); - norm = Teuchos::ScalarTraits::squareroot(norm); + template + void TentativePFactory:: + BuildPcoupled(RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, + RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace) const { + typedef Teuchos::ScalarTraits STS; + typedef typename STS::magnitudeType Magnitude; + const SC zero = STS::zero(); + const SC one = STS::one(); - // R = norm - coarseNS[0][offset] = norm; + // number of aggregates + GO numAggs = aggregates->GetNumAggregates(); - // Q = localQR(:,0)/norm - for (LO i = 0; i < aggSize; i++) - localQR(i,0) /= norm; + // Create a lookup table to determine the rows (fine DOFs) that belong to a given aggregate. + // aggStart is a pointer into aggToRowMap + // aggStart[i]..aggStart[i+1] are indices into aggToRowMap + // aggToRowMap[aggStart[i]]..aggToRowMap[aggStart[i+1]-1] are the DOFs in aggregate i + ArrayRCP aggStart; + ArrayRCP< GO > aggToRowMap; + amalgInfo->UnamalgamateAggregates(*aggregates, aggStart, aggToRowMap); - } else { - Teuchos::SerialQRDenseSolver qrSolver; - qrSolver.setMatrix(Teuchos::rcp(&localQR, false)); - qrSolver.factor(); + // find size of the largest aggregate + LO maxAggSize=0; + for (GO i=0; i maxAggSize) maxAggSize = sizeOfThisAgg; + } - // R = upper triangular part of localQR - for (size_t j = 0; j < NSDim; j++) - for (size_t k = 0; k <= j; k++) - coarseNS[j][offset+k] = localQR(k,j); //TODO is offset+k the correct local ID?! + // dimension of fine level nullspace + const size_t NSDim = fineNullspace->getNumVectors(); - // Calculate Q, the tentative prolongator. - // The Lapack GEQRF call only works for myAggsize >= NSDim - qrSolver.formQ(); - Teuchos::RCP > qFactor = qrSolver.getQ(); - for (size_t j = 0; j < NSDim; j++) - for (size_t i = 0; i < Teuchos::as(aggSize); i++) - localQR(i,j) = (*qFactor)(i,j); - } + // index base for coarse Dof map (usually 0) + GO indexBase=A->getRowMap()->getIndexBase(); - } else { - // Special handling for aggSize < NSDim (i.e. single node aggregates in structural mechanics) + const RCP nonUniqueMap = amalgInfo->ComputeUnamalgamatedImportDofMap(*aggregates); + const RCP uniqueMap = A->getDomainMap(); + RCP importer = ImportFactory::Build(uniqueMap, nonUniqueMap); + RCP fineNullspaceWithOverlap = MultiVectorFactory::Build(nonUniqueMap,NSDim); + fineNullspaceWithOverlap->doImport(*fineNullspace,*importer,Xpetra::INSERT); - // The local QR decomposition is not possible in the "overconstrained" - // case (i.e. number of columns in localQR > number of rows), which - // corresponds to #DOFs in Aggregate < NSDim. For usual problems this - // is only possible for single node aggregates in structural mechanics. - // (Similar problems may arise in discontinuous Galerkin problems...) - // We bypass the QR decomposition and use an identity block in the - // tentative prolongator for the single node aggregate and transfer the - // corresponding fine level null space information 1-to-1 to the coarse - // level null space part. + // Pull out the nullspace vectors so that we can have random access. + ArrayRCP< ArrayRCP > fineNS(NSDim); + for (size_t i=0; igetData(i); - // NOTE: The resulting tentative prolongation operator has - // (aggSize*DofsPerNode-NSDim) zero columns leading to a singular - // coarse level operator A. To deal with that one has the following - // options: - // - Use the "RepairMainDiagonal" flag in the RAPFactory (default: - // false) to add some identity block to the diagonal of the zero rows - // in the coarse level operator A, such that standard level smoothers - // can be used again. - // - Use special (projection-based) level smoothers, which can deal - // with singular matrices (very application specific) - // - Adapt the code below to avoid zero columns. However, we do not - // support a variable number of DOFs per node in MueLu/Xpetra which - // makes the implementation really hard. + //Allocate storage for the coarse nullspace. + coarseNullspace = MultiVectorFactory::Build(coarseMap, NSDim); - // R = extended (by adding identity rows) localQR - for (size_t j = 0; j < NSDim; j++) - for (size_t k = 0; k < NSDim; k++) - if (k < as(aggSize)) - coarseNS[j][offset+k] = localQR(k,j); - else - coarseNS[j][offset+k] = (k == j ? one : zero); + ArrayRCP< ArrayRCP > coarseNS(NSDim); + for (size_t i=0; igetLocalNumElements() > 0) coarseNS[i] = coarseNullspace->getDataNonConst(i); - // Q = I (rectangular) - for (size_t i = 0; i < as(aggSize); i++) - for (size_t j = 0; j < NSDim; j++) - localQR(i,j) = (j == i ? one : zero); - } + //This makes the rowmap of Ptent the same as that of A-> + //This requires moving some parts of some local Q's to other processors + //because aggregates can span processors. + RCP rowMapForPtent = A->getRowMap(); + const Map& rowMapForPtentRef = *rowMapForPtent; + // Set up storage for the rows of the local Qs that belong to other processors. + // FIXME This is inefficient and could be done within the main loop below with std::vector's. + RCP colMap = A->getColMap(); - // Process each row in the local Q factor - // FIXME: What happens if maps are blocked? - for (LO j = 0; j < aggSize; j++) { - LO localRow = (goodMap ? aggToRowMapLO[aggStart[agg]+j] : rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+j])); + RCP ghostQMap; + RCP ghostQvalues; + Array > > ghostQcolumns; + RCP > ghostQrowNums; + ArrayRCP< ArrayRCP > ghostQvals; + ArrayRCP< ArrayRCP > ghostQcols; + ArrayRCP< GO > ghostQrows; - size_t rowStart = ia[localRow]; - for (size_t k = 0, lnnz = 0; k < NSDim; k++) { - // Skip zeros (there may be plenty of them, i.e., NSDim > 1 or boundary conditions) - if (localQR(j,k) != zero) { - ja [rowStart+lnnz] = offset + k; - val[rowStart+lnnz] = localQR(j,k); - lnnz++; - } - } + Array ghostGIDs; + for (LO j=0; j1) - GetOStream(Warnings0) << "TentativePFactory : for nontrivial nullspace, this may degrade performance" << std::endl; - ///////////////////////////// - // "no-QR" option // - ///////////////////////////// - // Local Q factor is just the fine nullspace support over the current aggregate. - // Local R factor is the identity. - // TODO I have not implemented any special handling for aggregates that are too - // TODO small to locally support the nullspace, as is done in the standard QR - // TODO case above. - if (goodMap) { - for (GO agg = 0; agg < numAggs; agg++) { - const LO aggSize = aggStart[agg+1] - aggStart[agg]; - Xpetra::global_size_t offset = agg*NSDim; - - // Process each row in the local Q factor - // FIXME: What happens if maps are blocked? - for (LO j = 0; j < aggSize; j++) { - - //TODO Here I do not check for a zero nullspace column on the aggregate. - // as is done in the standard QR case. - - const LO localRow = aggToRowMapLO[aggStart[agg]+j]; - - const size_t rowStart = ia[localRow]; - - for (size_t k = 0, lnnz = 0; k < NSDim; k++) { - // Skip zeros (there may be plenty of them, i.e., NSDim > 1 or boundary conditions) - SC qr_jk = fineNS[k][aggToRowMapLO[aggStart[agg]+j]]; - if(constantColSums) qr_jk = qr_jk / (Magnitude)aggSizes[agg]; - if (qr_jk != zero) { - ja [rowStart+lnnz] = offset + k; - val[rowStart+lnnz] = qr_jk; - lnnz++; - } - } - } - for (size_t j = 0; j < NSDim; j++) - coarseNS[j][offset+j] = one; - } //for (GO agg = 0; agg < numAggs; agg++) - - } else { - for (GO agg = 0; agg < numAggs; agg++) { - const LO aggSize = aggStart[agg+1] - aggStart[agg]; - Xpetra::global_size_t offset = agg*NSDim; - for (LO j = 0; j < aggSize; j++) { - - const LO localRow = rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+j]); - - const size_t rowStart = ia[localRow]; - - for (size_t k = 0, lnnz = 0; k < NSDim; ++k) { - // Skip zeros (there may be plenty of them, i.e., NSDim > 1 or boundary conditions) - SC qr_jk = fineNS[k][rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+j])]; - if(constantColSums) qr_jk = qr_jk / (Magnitude)aggSizes[agg]; - if (qr_jk != zero) { - ja [rowStart+lnnz] = offset + k; - val[rowStart+lnnz] = qr_jk; - lnnz++; - } - } - } - for (size_t j = 0; j < NSDim; j++) - coarseNS[j][offset+j] = one; - } //for (GO agg = 0; agg < numAggs; agg++) - - } //if (goodmap) else ... - - } //if doQRStep ... else - - // Compress storage (remove all INVALID, which happen when we skip zeros) - // We do that in-place - size_t ia_tmp = 0, nnz = 0; - for (size_t i = 0; i < numRows; i++) { - for (size_t j = ia_tmp; j < ia[i+1]; j++) - if (ja[j] != INVALID) { - ja [nnz] = ja [j]; - val[nnz] = val[j]; - nnz++; - } - ia_tmp = ia[i+1]; - ia[i+1] = nnz; - } - if (rowMap->lib() == Xpetra::UseTpetra) { - // - Cannot resize for Epetra, as it checks for same pointers - // - Need to resize for Tpetra, as it check ().size() == ia[numRows] - // NOTE: these invalidate ja and val views - jaPtent .resize(nnz); - valPtent.resize(nnz); - } - - GetOStream(Runtime1) << "TentativePFactory : aggregates do not cross process boundaries" << std::endl; - - PtentCrs->setAllValues(iaPtent, jaPtent, valPtent); - - - // Managing labels & constants for ESFC - RCP FCparams; - if(pL.isSublist("matrixmatrix: kernel params")) - FCparams=rcp(new ParameterList(pL.sublist("matrixmatrix: kernel params"))); - else - FCparams= rcp(new ParameterList); - // By default, we don't need global constants for TentativeP - FCparams->set("compute global constants",FCparams->get("compute global constants",false)); - std::string levelIDs = toString(levelID); - FCparams->set("Timer Label",std::string("MueLu::TentativeP-")+levelIDs); - RCP dummy_e; - RCP dummy_i; - - PtentCrs->expertStaticFillComplete(coarseMap, A->getDomainMap(),dummy_i,dummy_e,FCparams); - } - - template - void TentativePFactory:: - BuildPcoupled(RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, - RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace) const { - typedef Teuchos::ScalarTraits STS; - typedef typename STS::magnitudeType Magnitude; - const SC zero = STS::zero(); - const SC one = STS::one(); - - // number of aggregates - GO numAggs = aggregates->GetNumAggregates(); - - // Create a lookup table to determine the rows (fine DOFs) that belong to a given aggregate. - // aggStart is a pointer into aggToRowMap - // aggStart[i]..aggStart[i+1] are indices into aggToRowMap - // aggToRowMap[aggStart[i]]..aggToRowMap[aggStart[i+1]-1] are the DOFs in aggregate i - ArrayRCP aggStart; - ArrayRCP< GO > aggToRowMap; - amalgInfo->UnamalgamateAggregates(*aggregates, aggStart, aggToRowMap); - - // find size of the largest aggregate - LO maxAggSize=0; - for (GO i=0; i maxAggSize) maxAggSize = sizeOfThisAgg; - } - - // dimension of fine level nullspace - const size_t NSDim = fineNullspace->getNumVectors(); - - // index base for coarse Dof map (usually 0) - GO indexBase=A->getRowMap()->getIndexBase(); - - const RCP nonUniqueMap = amalgInfo->ComputeUnamalgamatedImportDofMap(*aggregates); - const RCP uniqueMap = A->getDomainMap(); - RCP importer = ImportFactory::Build(uniqueMap, nonUniqueMap); - RCP fineNullspaceWithOverlap = MultiVectorFactory::Build(nonUniqueMap,NSDim); - fineNullspaceWithOverlap->doImport(*fineNullspace,*importer,Xpetra::INSERT); - - // Pull out the nullspace vectors so that we can have random access. - ArrayRCP< ArrayRCP > fineNS(NSDim); - for (size_t i=0; igetData(i); - - //Allocate storage for the coarse nullspace. - coarseNullspace = MultiVectorFactory::Build(coarseMap, NSDim); - - ArrayRCP< ArrayRCP > coarseNS(NSDim); - for (size_t i=0; igetLocalNumElements() > 0) coarseNS[i] = coarseNullspace->getDataNonConst(i); - - //This makes the rowmap of Ptent the same as that of A-> - //This requires moving some parts of some local Q's to other processors - //because aggregates can span processors. - RCP rowMapForPtent = A->getRowMap(); - const Map& rowMapForPtentRef = *rowMapForPtent; - - // Set up storage for the rows of the local Qs that belong to other processors. - // FIXME This is inefficient and could be done within the main loop below with std::vector's. - RCP colMap = A->getColMap(); - - RCP ghostQMap; - RCP ghostQvalues; - Array > > ghostQcolumns; - RCP > ghostQrowNums; - ArrayRCP< ArrayRCP > ghostQvals; - ArrayRCP< ArrayRCP > ghostQcols; - ArrayRCP< GO > ghostQrows; - - Array ghostGIDs; - for (LO j=0; jgetRowMap()->lib(), - Teuchos::OrdinalTraits::invalid(), - ghostGIDs, - indexBase, A->getRowMap()->getComm()); //JG:Xpetra::global_size_t>? - //Vector to hold bits of Q that go to other processors. - ghostQvalues = MultiVectorFactory::Build(ghostQMap,NSDim); - //Note that Epetra does not support MultiVectors templated on Scalar != double. - //So to work around this, we allocate an array of Vectors. This shouldn't be too - //expensive, as the number of Vectors is NSDim. - ghostQcolumns.resize(NSDim); - for (size_t i=0; i::Build(ghostQMap); - ghostQrowNums = Xpetra::VectorFactory::Build(ghostQMap); - if (ghostQvalues->getLocalLength() > 0) { - ghostQvals.resize(NSDim); - ghostQcols.resize(NSDim); - for (size_t i=0; igetDataNonConst(i); - ghostQcols[i] = ghostQcolumns[i]->getDataNonConst(0); - } - ghostQrows = ghostQrowNums->getDataNonConst(0); - } + } + ghostQMap = MapFactory::Build(A->getRowMap()->lib(), + Teuchos::OrdinalTraits::invalid(), + ghostGIDs, + indexBase, A->getRowMap()->getComm()); //JG:Xpetra::global_size_t>? + //Vector to hold bits of Q that go to other processors. + ghostQvalues = MultiVectorFactory::Build(ghostQMap,NSDim); + //Note that Epetra does not support MultiVectors templated on Scalar != double. + //So to work around this, we allocate an array of Vectors. This shouldn't be too + //expensive, as the number of Vectors is NSDim. + ghostQcolumns.resize(NSDim); + for (size_t i=0; i::Build(ghostQMap); + ghostQrowNums = Xpetra::VectorFactory::Build(ghostQMap); + if (ghostQvalues->getLocalLength() > 0) { + ghostQvals.resize(NSDim); + ghostQcols.resize(NSDim); + for (size_t i=0; igetDataNonConst(i); + ghostQcols[i] = ghostQcolumns[i]->getDataNonConst(0); + } + ghostQrows = ghostQrowNums->getDataNonConst(0); + } //importer to handle moving Q importer = ImportFactory::Build(ghostQMap, A->getRowMap()); @@ -961,6 +846,338 @@ namespace MueLu { + template + void TentativePFactory:: + BuildPuncoupled(RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, + RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace, const int levelID) const { + RCP rowMap = A->getRowMap(); + RCP colMap = A->getColMap(); + const size_t numRows = rowMap->getLocalNumElements(); + + typedef Teuchos::ScalarTraits STS; + typedef typename STS::magnitudeType Magnitude; + const SC zero = STS::zero(); + const SC one = STS::one(); + const LO INVALID = Teuchos::OrdinalTraits::invalid(); + + const GO numAggs = aggregates->GetNumAggregates(); + const size_t NSDim = fineNullspace->getNumVectors(); + ArrayRCP aggSizes = aggregates->ComputeAggregateSizes(); + + + // Sanity checking + const ParameterList& pL = GetParameterList(); + const bool &doQRStep = pL.get("tentative: calculate qr"); + const bool &constantColSums = pL.get("tentative: constant column sums"); + + TEUCHOS_TEST_FOR_EXCEPTION(doQRStep && constantColSums,Exceptions::RuntimeError, + "MueLu::TentativePFactory::MakeTentative: cannot use 'constant column sums' and 'calculate qr' at the same time"); + + // Aggregates map is based on the amalgamated column map + // We can skip global-to-local conversion if LIDs in row map are + // same as LIDs in column map + bool goodMap = MueLu::Utilities::MapsAreNested(*rowMap, *colMap); + + // Create a lookup table to determine the rows (fine DOFs) that belong to a given aggregate. + // aggStart is a pointer into aggToRowMapLO + // aggStart[i]..aggStart[i+1] are indices into aggToRowMapLO + // aggToRowMapLO[aggStart[i]]..aggToRowMapLO[aggStart[i+1]-1] are the DOFs in aggregate i + ArrayRCP aggStart; + ArrayRCP aggToRowMapLO; + ArrayRCP aggToRowMapGO; + if (goodMap) { + amalgInfo->UnamalgamateAggregatesLO(*aggregates, aggStart, aggToRowMapLO); + GetOStream(Runtime1) << "Column map is consistent with the row map, good." << std::endl; + + } else { + amalgInfo->UnamalgamateAggregates(*aggregates, aggStart, aggToRowMapGO); + GetOStream(Warnings0) << "Column map is not consistent with the row map\n" + << "using GO->LO conversion with performance penalty" << std::endl; + } + coarseNullspace = MultiVectorFactory::Build(coarseMap, NSDim); + + // Pull out the nullspace vectors so that we can have random access. + ArrayRCP > fineNS (NSDim); + ArrayRCP > coarseNS(NSDim); + for (size_t i = 0; i < NSDim; i++) { + fineNS[i] = fineNullspace->getData(i); + if (coarseMap->getLocalNumElements() > 0) + coarseNS[i] = coarseNullspace->getDataNonConst(i); + } + + size_t nnzEstimate = numRows * NSDim; + + // Time to construct the matrix and fill in the values + Ptentative = rcp(new CrsMatrixWrap(rowMap, coarseMap, 0)); + RCP PtentCrs = rcp_dynamic_cast(Ptentative)->getCrsMatrix(); + + ArrayRCP iaPtent; + ArrayRCP jaPtent; + ArrayRCP valPtent; + + PtentCrs->allocateAllValues(nnzEstimate, iaPtent, jaPtent, valPtent); + + ArrayView ia = iaPtent(); + ArrayView ja = jaPtent(); + ArrayView val = valPtent(); + + ia[0] = 0; + for (size_t i = 1; i <= numRows; i++) + ia[i] = ia[i-1] + NSDim; + + for (size_t j = 0; j < nnzEstimate; j++) { + ja [j] = INVALID; + val[j] = zero; + } + + + if (doQRStep) { + //////////////////////////////// + // Standard aggregate-wise QR // + //////////////////////////////// + for (GO agg = 0; agg < numAggs; agg++) { + LO aggSize = aggStart[agg+1] - aggStart[agg]; + + Xpetra::global_size_t offset = agg*NSDim; + + // Extract the piece of the nullspace corresponding to the aggregate, and + // put it in the flat array, "localQR" (in column major format) for the + // QR routine. + Teuchos::SerialDenseMatrix localQR(aggSize, NSDim); + if (goodMap) { + for (size_t j = 0; j < NSDim; j++) + for (LO k = 0; k < aggSize; k++) + localQR(k,j) = fineNS[j][aggToRowMapLO[aggStart[agg]+k]]; + } else { + for (size_t j = 0; j < NSDim; j++) + for (LO k = 0; k < aggSize; k++) + localQR(k,j) = fineNS[j][rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+k])]; + } + + // Test for zero columns + for (size_t j = 0; j < NSDim; j++) { + bool bIsZeroNSColumn = true; + + for (LO k = 0; k < aggSize; k++) + if (localQR(k,j) != zero) + bIsZeroNSColumn = false; + + TEUCHOS_TEST_FOR_EXCEPTION(bIsZeroNSColumn == true, Exceptions::RuntimeError, + "MueLu::TentativePFactory::MakeTentative: fine level NS part has a zero column in NS column " << j); + } + + // Calculate QR decomposition (standard) + // NOTE: Q is stored in localQR and R is stored in coarseNS + if (aggSize >= Teuchos::as(NSDim)) { + + if (NSDim == 1) { + // Only one nullspace vector, calculate Q and R by hand + Magnitude norm = STS::magnitude(zero); + for (size_t k = 0; k < Teuchos::as(aggSize); k++) + norm += STS::magnitude(localQR(k,0)*localQR(k,0)); + norm = Teuchos::ScalarTraits::squareroot(norm); + + // R = norm + coarseNS[0][offset] = norm; + + // Q = localQR(:,0)/norm + for (LO i = 0; i < aggSize; i++) + localQR(i,0) /= norm; + + } else { + Teuchos::SerialQRDenseSolver qrSolver; + qrSolver.setMatrix(Teuchos::rcp(&localQR, false)); + qrSolver.factor(); + + // R = upper triangular part of localQR + for (size_t j = 0; j < NSDim; j++) + for (size_t k = 0; k <= j; k++) + coarseNS[j][offset+k] = localQR(k,j); //TODO is offset+k the correct local ID?! + + // Calculate Q, the tentative prolongator. + // The Lapack GEQRF call only works for myAggsize >= NSDim + qrSolver.formQ(); + Teuchos::RCP > qFactor = qrSolver.getQ(); + for (size_t j = 0; j < NSDim; j++) + for (size_t i = 0; i < Teuchos::as(aggSize); i++) + localQR(i,j) = (*qFactor)(i,j); + } + + } else { + // Special handling for aggSize < NSDim (i.e. single node aggregates in structural mechanics) + + // The local QR decomposition is not possible in the "overconstrained" + // case (i.e. number of columns in localQR > number of rows), which + // corresponds to #DOFs in Aggregate < NSDim. For usual problems this + // is only possible for single node aggregates in structural mechanics. + // (Similar problems may arise in discontinuous Galerkin problems...) + // We bypass the QR decomposition and use an identity block in the + // tentative prolongator for the single node aggregate and transfer the + // corresponding fine level null space information 1-to-1 to the coarse + // level null space part. + + // NOTE: The resulting tentative prolongation operator has + // (aggSize*DofsPerNode-NSDim) zero columns leading to a singular + // coarse level operator A. To deal with that one has the following + // options: + // - Use the "RepairMainDiagonal" flag in the RAPFactory (default: + // false) to add some identity block to the diagonal of the zero rows + // in the coarse level operator A, such that standard level smoothers + // can be used again. + // - Use special (projection-based) level smoothers, which can deal + // with singular matrices (very application specific) + // - Adapt the code below to avoid zero columns. However, we do not + // support a variable number of DOFs per node in MueLu/Xpetra which + // makes the implementation really hard. + + // R = extended (by adding identity rows) localQR + for (size_t j = 0; j < NSDim; j++) + for (size_t k = 0; k < NSDim; k++) + if (k < as(aggSize)) + coarseNS[j][offset+k] = localQR(k,j); + else + coarseNS[j][offset+k] = (k == j ? one : zero); + + // Q = I (rectangular) + for (size_t i = 0; i < as(aggSize); i++) + for (size_t j = 0; j < NSDim; j++) + localQR(i,j) = (j == i ? one : zero); + } + + + // Process each row in the local Q factor + // FIXME: What happens if maps are blocked? + for (LO j = 0; j < aggSize; j++) { + LO localRow = (goodMap ? aggToRowMapLO[aggStart[agg]+j] : rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+j])); + + size_t rowStart = ia[localRow]; + for (size_t k = 0, lnnz = 0; k < NSDim; k++) { + // Skip zeros (there may be plenty of them, i.e., NSDim > 1 or boundary conditions) + if (localQR(j,k) != zero) { + ja [rowStart+lnnz] = offset + k; + val[rowStart+lnnz] = localQR(j,k); + lnnz++; + } + } + } + } + + } else { + GetOStream(Runtime1) << "TentativePFactory : bypassing local QR phase" << std::endl; + if (NSDim>1) + GetOStream(Warnings0) << "TentativePFactory : for nontrivial nullspace, this may degrade performance" << std::endl; + ///////////////////////////// + // "no-QR" option // + ///////////////////////////// + // Local Q factor is just the fine nullspace support over the current aggregate. + // Local R factor is the identity. + // TODO I have not implemented any special handling for aggregates that are too + // TODO small to locally support the nullspace, as is done in the standard QR + // TODO case above. + if (goodMap) { + for (GO agg = 0; agg < numAggs; agg++) { + const LO aggSize = aggStart[agg+1] - aggStart[agg]; + Xpetra::global_size_t offset = agg*NSDim; + + // Process each row in the local Q factor + // FIXME: What happens if maps are blocked? + for (LO j = 0; j < aggSize; j++) { + + //TODO Here I do not check for a zero nullspace column on the aggregate. + // as is done in the standard QR case. + + const LO localRow = aggToRowMapLO[aggStart[agg]+j]; + + const size_t rowStart = ia[localRow]; + + for (size_t k = 0, lnnz = 0; k < NSDim; k++) { + // Skip zeros (there may be plenty of them, i.e., NSDim > 1 or boundary conditions) + SC qr_jk = fineNS[k][aggToRowMapLO[aggStart[agg]+j]]; + if(constantColSums) qr_jk = qr_jk / (Magnitude)aggSizes[agg]; + if (qr_jk != zero) { + ja [rowStart+lnnz] = offset + k; + val[rowStart+lnnz] = qr_jk; + lnnz++; + } + } + } + for (size_t j = 0; j < NSDim; j++) + coarseNS[j][offset+j] = one; + } //for (GO agg = 0; agg < numAggs; agg++) + + } else { + for (GO agg = 0; agg < numAggs; agg++) { + const LO aggSize = aggStart[agg+1] - aggStart[agg]; + Xpetra::global_size_t offset = agg*NSDim; + for (LO j = 0; j < aggSize; j++) { + + const LO localRow = rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+j]); + + const size_t rowStart = ia[localRow]; + + for (size_t k = 0, lnnz = 0; k < NSDim; ++k) { + // Skip zeros (there may be plenty of them, i.e., NSDim > 1 or boundary conditions) + SC qr_jk = fineNS[k][rowMap->getLocalElement(aggToRowMapGO[aggStart[agg]+j])]; + if(constantColSums) qr_jk = qr_jk / (Magnitude)aggSizes[agg]; + if (qr_jk != zero) { + ja [rowStart+lnnz] = offset + k; + val[rowStart+lnnz] = qr_jk; + lnnz++; + } + } + } + for (size_t j = 0; j < NSDim; j++) + coarseNS[j][offset+j] = one; + } //for (GO agg = 0; agg < numAggs; agg++) + + } //if (goodmap) else ... + + } //if doQRStep ... else + + // Compress storage (remove all INVALID, which happen when we skip zeros) + // We do that in-place + size_t ia_tmp = 0, nnz = 0; + for (size_t i = 0; i < numRows; i++) { + for (size_t j = ia_tmp; j < ia[i+1]; j++) + if (ja[j] != INVALID) { + ja [nnz] = ja [j]; + val[nnz] = val[j]; + nnz++; + } + ia_tmp = ia[i+1]; + ia[i+1] = nnz; + } + if (rowMap->lib() == Xpetra::UseTpetra) { + // - Cannot resize for Epetra, as it checks for same pointers + // - Need to resize for Tpetra, as it check ().size() == ia[numRows] + // NOTE: these invalidate ja and val views + jaPtent .resize(nnz); + valPtent.resize(nnz); + } + + GetOStream(Runtime1) << "TentativePFactory : aggregates do not cross process boundaries" << std::endl; + + PtentCrs->setAllValues(iaPtent, jaPtent, valPtent); + + + // Managing labels & constants for ESFC + RCP FCparams; + if(pL.isSublist("matrixmatrix: kernel params")) + FCparams=rcp(new ParameterList(pL.sublist("matrixmatrix: kernel params"))); + else + FCparams= rcp(new ParameterList); + // By default, we don't need global constants for TentativeP + FCparams->set("compute global constants",FCparams->get("compute global constants",false)); + std::string levelIDs = toString(levelID); + FCparams->set("Timer Label",std::string("MueLu::TentativeP-")+levelIDs); + RCP dummy_e; + RCP dummy_i; + + PtentCrs->expertStaticFillComplete(coarseMap, A->getDomainMap(),dummy_i,dummy_e,FCparams); + } + + + } //namespace MueLu // TODO ReUse: If only P or Nullspace is missing, TentativePFactory can be smart and skip part of the computation. diff --git a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_decl.hpp b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_decl.hpp index 35b3151fc4b2..e387eb1a9677 100644 --- a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_decl.hpp +++ b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_decl.hpp @@ -55,6 +55,8 @@ #include "Teuchos_ScalarTraits.hpp" +#include "Xpetra_CrsGraphFactory_fwd.hpp" + #include "MueLu_Aggregates_kokkos_fwd.hpp" #include "MueLu_AmalgamationFactory_kokkos_fwd.hpp" #include "MueLu_AmalgamationInfo_kokkos_fwd.hpp" @@ -151,27 +153,27 @@ namespace MueLu { //@} - // CUDA 7.5 and 8.0 place a restriction on the placement of __device__ lambdas: - // - // An explicit __device__ lambda cannot be defined in a member function - // that has private or protected access within its class. - // - // Therefore, we expose BuildPuncoupled and isGoodMap for now. An alternative solution - // could be writing an out of class implementation, and then calling it in - // a member function. - void BuildPuncoupled(Level& coarseLevel, RCP A, RCP aggregates, - RCP amalgInfo, RCP fineNullspace, - RCP coarseMap, RCP& Ptentative, - RCP& coarseNullspace, const int levelID) const; + + // NOTE: All of thess should really be private, but CUDA doesn't like that + + void BuildPuncoupledBlockCrs(Level& coarseLevel, RCP A, RCP aggregates, RCP amalgInfo, + RCP fineNullspace, RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace, const int levelID) const; + + bool isGoodMap(const Map& rowMap, const Map& colMap) const; - private: + void BuildPcoupled (RCP A, RCP aggregates, RCP amalgInfo, RCP fineNullspace, RCP coarseMap, RCP& Ptentative, RCP& coarseNullspace) const; + void BuildPuncoupled(Level& coarseLevel, RCP A, RCP aggregates, + RCP amalgInfo, RCP fineNullspace, + RCP coarseMap, RCP& Ptentative, + RCP& coarseNullspace, const int levelID) const; + mutable bool bTransferCoordinates_ = false; }; diff --git a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_def.hpp b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_def.hpp index 08f647cbcce9..e12983cc10d1 100644 --- a/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_def.hpp +++ b/packages/muelu/src/Transfers/Smoothed-Aggregation/MueLu_TentativePFactory_kokkos_def.hpp @@ -49,6 +49,7 @@ #ifdef HAVE_MUELU_KOKKOS_REFACTOR #include "Kokkos_UnorderedMap.hpp" +#include "Xpetra_CrsGraphFactory.hpp" #include "MueLu_TentativePFactory_kokkos_decl.hpp" @@ -56,12 +57,15 @@ #include "MueLu_AmalgamationFactory_kokkos.hpp" #include "MueLu_AmalgamationInfo_kokkos.hpp" #include "MueLu_CoarseMapFactory_kokkos.hpp" + #include "MueLu_MasterList.hpp" #include "MueLu_NullspaceFactory_kokkos.hpp" #include "MueLu_PerfUtils.hpp" #include "MueLu_Monitor.hpp" #include "MueLu_Utilities_kokkos.hpp" +#include "Xpetra_IO.hpp" + namespace MueLu { namespace { // anonymous @@ -531,8 +535,15 @@ namespace MueLu { } } - if (!aggregates->AggregatesCrossProcessors()) - BuildPuncoupled(coarseLevel, A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace, coarseLevel.GetLevelID()); + if (!aggregates->AggregatesCrossProcessors()) { + if(Xpetra::Helpers::isTpetraBlockCrs(A)) { + BuildPuncoupledBlockCrs(coarseLevel,A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace, + coarseLevel.GetLevelID()); + } + else { + BuildPuncoupled(coarseLevel, A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace, coarseLevel.GetLevelID()); + } + } else BuildPcoupled (A, aggregates, amalgInfo, fineNullspace, coarseMap, Ptentative, coarseNullspace); @@ -966,6 +977,310 @@ namespace MueLu { } } + + template + void TentativePFactory_kokkos>:: + BuildPuncoupledBlockCrs(Level& coarseLevel, RCP A, RCP aggregates, + RCP amalgInfo, RCP fineNullspace, + RCP coarsePointMap, RCP& Ptentative, + RCP& coarseNullspace, const int levelID) const { +#ifdef HAVE_MUELU_TPETRA + /* This routine generates a BlockCrs P for a BlockCrs A. There are a few assumptions here, which meet the use cases we care about, but could + be generalized later, if we ever need to do so: + 1) Null space dimension === block size of matrix: So no elasticity right now + 2) QR is not supported: Under assumption #1, this shouldn't cause problems. + 3) Maps are "good": Aka the first chunk of the ColMap is the RowMap. + + These assumptions keep our code way simpler and still support the use cases we actually care about. + */ + + RCP rowMap = A->getRowMap(); + RCP rangeMap = A->getRangeMap(); + RCP colMap = A->getColMap(); + // const size_t numFinePointRows = rangeMap->getLocalNumElements(); + const size_t numFineBlockRows = rowMap->getLocalNumElements(); + + typedef Teuchos::ScalarTraits STS; + typedef typename STS::magnitudeType Magnitude; + // const SC zero = STS::zero(); + const SC one = STS::one(); + const LO INVALID = Teuchos::OrdinalTraits::invalid(); + + // const GO numAggs = aggregates->GetNumAggregates(); + const size_t NSDim = fineNullspace->getNumVectors(); + auto aggSizes = aggregates->ComputeAggregateSizes(); + + + typename Aggregates_kokkos::local_graph_type aggGraph; + { + SubFactoryMonitor m2(*this, "Get Aggregates graph", coarseLevel); + aggGraph = aggregates->GetGraph(); + } + auto aggRows = aggGraph.row_map; + auto aggCols = aggGraph.entries; + + + // Need to generate the coarse block map + // NOTE: We assume NSDim == block size here + // NOTE: We also assume that coarseMap has contiguous GIDs + //const size_t numCoarsePointRows = coarsePointMap->getLocalNumElements(); + const size_t numCoarseBlockRows = coarsePointMap->getLocalNumElements() / NSDim; + RCP coarseBlockMap = MapFactory::Build(coarsePointMap->lib(), + Teuchos::OrdinalTraits::invalid(), + numCoarseBlockRows, + coarsePointMap->getIndexBase(), + coarsePointMap->getComm()); + // Sanity checking + const ParameterList& pL = GetParameterList(); + // const bool &doQRStep = pL.get("tentative: calculate qr"); + + + // The aggregates use the amalgamated column map, which in this case is what we want + + // Aggregates map is based on the amalgamated column map + // We can skip global-to-local conversion if LIDs in row map are + // same as LIDs in column map + bool goodMap = MueLu::Utilities::MapsAreNested(*rowMap, *colMap); + TEUCHOS_TEST_FOR_EXCEPTION(!goodMap, Exceptions::RuntimeError, + "MueLu: TentativePFactory_kokkos: for now works only with good maps " + "(i.e. \"matching\" row and column maps)"); + + // STEP 1: do unamalgamation + // The non-kokkos version uses member functions from the AmalgamationInfo + // container class to unamalgamate the data. In contrast, the kokkos + // version of TentativePFactory does the unamalgamation here and only uses + // the data of the AmalgamationInfo container class + + // Extract information for unamalgamation + LO fullBlockSize, blockID, stridingOffset, stridedBlockSize; + GO indexBase; + amalgInfo->GetStridingInformation(fullBlockSize, blockID, stridingOffset, stridedBlockSize, indexBase); + //GO globalOffset = amalgInfo->GlobalOffset(); + + // Extract aggregation info (already in Kokkos host views) + auto procWinner = aggregates->GetProcWinner() ->getDeviceLocalView(Xpetra::Access::ReadOnly); + auto vertex2AggId = aggregates->GetVertex2AggId()->getDeviceLocalView(Xpetra::Access::ReadOnly); + const size_t numAggregates = aggregates->GetNumAggregates(); + + int myPID = aggregates->GetMap()->getComm()->getRank(); + + // Create Kokkos::View (on the device) to store the aggreate dof sizes + // Later used to get aggregate dof offsets + // NOTE: This zeros itself on construction + typedef typename Aggregates_kokkos::aggregates_sizes_type::non_const_type AggSizeType; + AggSizeType aggDofSizes; // This turns into "starts" after the parallel_scan + + { + SubFactoryMonitor m2(*this, "Calc AggSizes", coarseLevel); + + // FIXME_KOKKOS: use ViewAllocateWithoutInitializing + set a single value + aggDofSizes = AggSizeType("agg_dof_sizes", numAggregates+1); + + Kokkos::deep_copy(Kokkos::subview(aggDofSizes, Kokkos::make_pair(static_cast(1), numAggregates+1)), aggSizes); + } + + // Find maximum dof size for aggregates + // Later used to reserve enough scratch space for local QR decompositions + LO maxAggSize = 0; + ReduceMaxFunctor reduceMax(aggDofSizes); + Kokkos::parallel_reduce("MueLu:TentativePF:Build:max_agg_size", range_type(0, aggDofSizes.extent(0)), reduceMax, maxAggSize); + + // parallel_scan (exclusive) + // The aggDofSizes View then contains the aggregate dof offsets + Kokkos::parallel_scan("MueLu:TentativePF:Build:aggregate_sizes:stage1_scan", range_type(0,numAggregates+1), + KOKKOS_LAMBDA(const LO i, LO& update, const bool& final_pass) { + update += aggDofSizes(i); + if (final_pass) + aggDofSizes(i) = update; + }); + + // Create Kokkos::View on the device to store mapping + // between (local) aggregate id and row map ids (LIDs) + Kokkos::View aggToRowMapLO(Kokkos::ViewAllocateWithoutInitializing("aggtorow_map_LO"), numFineBlockRows); + { + SubFactoryMonitor m2(*this, "Create AggToRowMap", coarseLevel); + + AggSizeType aggOffsets(Kokkos::ViewAllocateWithoutInitializing("aggOffsets"), numAggregates); + Kokkos::deep_copy(aggOffsets, Kokkos::subview(aggDofSizes, Kokkos::make_pair(static_cast(0), numAggregates))); + + Kokkos::parallel_for("MueLu:TentativePF:Build:createAgg2RowMap", range_type(0, vertex2AggId.extent(0)), + KOKKOS_LAMBDA(const LO lnode) { + if (procWinner(lnode, 0) == myPID) { + // No need for atomics, it's one-to-one + auto aggID = vertex2AggId(lnode,0); + + auto offset = Kokkos::atomic_fetch_add( &aggOffsets(aggID), stridedBlockSize ); + // FIXME: I think this may be wrong + // We unconditionally add the whole block here. When we calculated + // aggDofSizes, we did the isLocalElement check. Something's fishy. + for (LO k = 0; k < stridedBlockSize; k++) + aggToRowMapLO(offset + k) = lnode*stridedBlockSize + k; + } + }); + } + + // STEP 2: prepare local QR decomposition + // Reserve memory for tentative prolongation operator + coarseNullspace = MultiVectorFactory::Build(coarsePointMap, NSDim); + + // Pull out the nullspace vectors so that we can have random access (on the device) + auto fineNS = fineNullspace ->getDeviceLocalView(Xpetra::Access::ReadWrite); + auto coarseNS = coarseNullspace->getDeviceLocalView(Xpetra::Access::OverwriteAll); + + typedef typename Xpetra::Matrix::local_matrix_type local_matrix_type; + typedef typename local_matrix_type::row_map_type::non_const_type rows_type; + typedef typename local_matrix_type::index_type::non_const_type cols_type; + typedef typename local_matrix_type::values_type::non_const_type vals_type; + + + // Device View for status (error messages...) + typedef Kokkos::View status_type; + status_type status("status"); + + typename AppendTrait::type fineNSRandom = fineNS; + typename AppendTrait ::type statusAtomic = status; + + // We're going to bypass QR in the BlockCrs version of the code regardless of what the user asks for + GetOStream(Runtime1) << "TentativePFactory : bypassing local QR phase" << std::endl; + + // BlockCrs requires that we build the (block) graph first, so let's do that... + + // NOTE: Because we're assuming that the NSDim == BlockSize, we only have one + // block non-zero per row in the matrix; + rows_type ia(Kokkos::ViewAllocateWithoutInitializing("BlockGraph_rowptr"), numFineBlockRows+1); + cols_type ja(Kokkos::ViewAllocateWithoutInitializing("BlockGraph_colind"), numFineBlockRows); + + Kokkos::parallel_for("MueLu:TentativePF:BlockCrs:graph_init", range_type(0, numFineBlockRows), + KOKKOS_LAMBDA(const LO j) { + ia[j] = j; + ja[j] = INVALID; + + if(j==(LO)numFineBlockRows-1) + ia[numFineBlockRows] = numFineBlockRows; + }); + + // Fill Graph + const Kokkos::TeamPolicy policy(numAggregates, 1); + Kokkos::parallel_for("MueLu:TentativePF:BlockCrs:fillGraph", policy, + KOKKOS_LAMBDA(const typename Kokkos::TeamPolicy::member_type &thread) { + auto agg = thread.league_rank(); + Xpetra::global_size_t offset = agg; + + // size of the aggregate (number of DOFs in aggregate) + LO aggSize = aggRows(agg+1) - aggRows(agg); + + for (LO j = 0; j < aggSize; j++) { + // FIXME: Allow for bad maps + const LO localRow = aggToRowMapLO[aggDofSizes[agg]+j]; + const size_t rowStart = ia[localRow]; + ja[rowStart] = offset; + } + }); + + // Compress storage (remove all INVALID, which happen when we skip zeros) + // We do that in-place + { + // Stage 2: compress the arrays + SubFactoryMonitor m2(*this, "Stage 2 (CompressData)", coarseLevel); + // Fill i_temp with the correct row starts + rows_type i_temp(Kokkos::ViewAllocateWithoutInitializing("BlockGraph_rowptr"), numFineBlockRows+1); + size_t nnz=0; + Kokkos::parallel_scan("MueLu:TentativePF:BlockCrs:compress_rows", range_type(0,numFineBlockRows), + KOKKOS_LAMBDA(const LO i, LO& upd, const bool& final) { + if(final) + i_temp[i] = upd; + for (auto j = ia[i]; j < ia[i+1]; j++) + if (ja[j] != INVALID) + upd++; + if(final && i == (LO) numFineBlockRows-1) + i_temp[numFineBlockRows] = upd; + },nnz); + + cols_type j_temp(Kokkos::ViewAllocateWithoutInitializing("BlockGraph_colind"), nnz); + + + Kokkos::parallel_for("MueLu:TentativePF:BlockCrs:compress_cols", range_type(0,numFineBlockRows), + KOKKOS_LAMBDA(const LO i) { + size_t rowStart = i_temp[i]; + size_t lnnz = 0; + for (auto j = ia[i]; j < ia[i+1]; j++) + if (ja[j] != INVALID) { + j_temp[rowStart+lnnz] = ja[j]; + lnnz++; + } + }); + + ia = i_temp; + ja = j_temp; + } + + RCP BlockGraph = CrsGraphFactory::Build(rowMap,coarseBlockMap,ia,ja); + + + // Managing labels & constants for ESFC + { + RCP FCparams; + if(pL.isSublist("matrixmatrix: kernel params")) + FCparams=rcp(new ParameterList(pL.sublist("matrixmatrix: kernel params"))); + else + FCparams= rcp(new ParameterList); + // By default, we don't need global constants for TentativeP + FCparams->set("compute global constants",FCparams->get("compute global constants",false)); + std::string levelIDs = toString(levelID); + FCparams->set("Timer Label",std::string("MueLu::TentativeP-")+levelIDs); + RCP dummy_e; + RCP dummy_i; + BlockGraph->expertStaticFillComplete(coarseBlockMap,rowMap,dummy_i,dummy_e,FCparams); + } + + // We can't leave the ia/ja pointers floating around, because of host/device view counting, so + // we clear them here + ia = rows_type(); + ja = cols_type(); + + + // Now let's make a BlockCrs Matrix + // NOTE: Assumes block size== NSDim + RCP > P_xpetra = Xpetra::CrsMatrixFactory::BuildBlock(BlockGraph, coarsePointMap, rangeMap,NSDim); + RCP > P_tpetra = rcp_dynamic_cast >(P_xpetra); + if(P_tpetra.is_null()) throw std::runtime_error("BuildPUncoupled: Matrix factory did not return a Tpetra::BlockCrsMatrix"); + RCP P_wrap = rcp(new CrsMatrixWrap(P_xpetra)); + + auto values = P_tpetra->getTpetra_BlockCrsMatrix()->getValuesDeviceNonConst(); + const LO stride = NSDim*NSDim; + + Kokkos::parallel_for("MueLu:TentativePF:BlockCrs:main_loop_noqr", policy, + KOKKOS_LAMBDA(const typename Kokkos::TeamPolicy::member_type &thread) { + auto agg = thread.league_rank(); + + // size of the aggregate (number of DOFs in aggregate) + LO aggSize = aggRows(agg+1) - aggRows(agg); + Xpetra::global_size_t offset = agg*NSDim; + + // Q = localQR(:,0)/norm + for (LO j = 0; j < aggSize; j++) { + LO localBlockRow = aggToRowMapLO(aggRows(agg)+j); + LO rowStart = localBlockRow * stride; + for (LO r = 0; r < (LO)NSDim; r++) { + LO localPointRow = localBlockRow*NSDim + r; + for (LO c = 0; c < (LO)NSDim; c++) { + values[rowStart + r*NSDim + c] = fineNSRandom(localPointRow,c); + } + } + } + + // R = norm + for(LO j=0; j<(LO)NSDim; j++) + coarseNS(offset+j,j) = one; + }); + + Ptentative = P_wrap; + +#else + throw std::runtime_error("TentativePFactory::BuildPuncoupledBlockCrs: Requires Tpetra"); +#endif + } + template void TentativePFactory_kokkos>:: BuildPcoupled(RCP /* A */, RCP /* aggregates */, diff --git a/packages/muelu/src/Utils/MueLu_Utilities_decl.hpp b/packages/muelu/src/Utils/MueLu_Utilities_decl.hpp index b3f8ab3fc887..e9f08ae11e6a 100644 --- a/packages/muelu/src/Utils/MueLu_Utilities_decl.hpp +++ b/packages/muelu/src/Utils/MueLu_Utilities_decl.hpp @@ -97,6 +97,8 @@ class Epetra_Vector; #ifdef HAVE_MUELU_TPETRA #include +#include +#include #include #include #include @@ -199,6 +201,14 @@ namespace MueLu { static const Tpetra::CrsMatrix& Op2TpetraCrs(const Xpetra::Matrix& Op); static Tpetra::CrsMatrix& Op2NonConstTpetraCrs(Xpetra::Matrix& Op); + static RCP > Op2TpetraBlockCrs(RCP > Op); + static RCP< Tpetra::BlockCrsMatrix > Op2NonConstTpetraBlockCrs(RCP > Op); + + static const Tpetra::BlockCrsMatrix& Op2TpetraBlockCrs(const Xpetra::Matrix& Op); + static Tpetra::BlockCrsMatrix& Op2NonConstTpetraBlockCrs(Xpetra::Matrix& Op); + + + static RCP > Op2TpetraRow(RCP > Op); static RCP< Tpetra::RowMatrix > Op2NonConstTpetraRow(RCP > Op); @@ -532,6 +542,76 @@ namespace MueLu { #endif } + + static RCP > Op2TpetraBlockCrs(RCP Op) { +#if ((defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_OPENMP) || !defined(HAVE_TPETRA_INST_INT_INT))) || \ + (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_INT)))) + throw Exceptions::RuntimeError("Op2TpetraBlockCrs: Tpetra has not been compiled with support for LO=GO=int."); +#else + // Get the underlying Tpetra Mtx + RCP crsOp = rcp_dynamic_cast(Op); + if (crsOp == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + const RCP > &tmp_ECrsMtx = rcp_dynamic_cast >(crsOp->getCrsMatrix()); + if (tmp_ECrsMtx == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + return tmp_ECrsMtx->getTpetra_BlockCrsMatrix(); +#endif + } + + static RCP< Tpetra::BlockCrsMatrix > Op2NonConstTpetraBlockCrs(RCP Op){ +#if ((defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_OPENMP) || !defined(HAVE_TPETRA_INST_INT_INT))) || \ + (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_INT)))) + throw Exceptions::RuntimeError("Op2NonConstTpetraBlockCrs: Tpetra has not been compiled with support for LO=GO=int."); +#else + RCP crsOp = rcp_dynamic_cast(Op); + if (crsOp == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + const RCP > &tmp_ECrsMtx = rcp_dynamic_cast >(crsOp->getCrsMatrix()); + if (tmp_ECrsMtx == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + return tmp_ECrsMtx->getTpetra_BlockCrsMatrixNonConst(); +#endif + }; + + static const Tpetra::BlockCrsMatrix& Op2TpetraBlockCrs(const Matrix& Op) { +#if ((defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_OPENMP) || !defined(HAVE_TPETRA_INST_INT_INT))) || \ + (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_INT)))) + throw Exceptions::RuntimeError("Op2TpetraBlockCrs: Tpetra has not been compiled with support for LO=GO=int."); +#else + try { + const CrsMatrixWrap& crsOp = dynamic_cast(Op); + try { + const Xpetra::TpetraBlockCrsMatrix& tmp_ECrsMtx = dynamic_cast&>(*crsOp.getCrsMatrix()); + return *tmp_ECrsMtx.getTpetra_BlockCrsMatrix(); + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + } + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + } +#endif + } + static Tpetra::BlockCrsMatrix& Op2NonConstTpetraBlockCrs(Matrix& Op) { +#if ((defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_OPENMP) || !defined(HAVE_TPETRA_INST_INT_INT))) || \ + (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_INT)))) + throw Exceptions::RuntimeError("Op2NonConstTpetraCrs: Tpetra has not been compiled with support for LO=GO=int."); +#else + try { + CrsMatrixWrap& crsOp = dynamic_cast(Op); + try { + Xpetra::TpetraBlockCrsMatrix& tmp_ECrsMtx = dynamic_cast&>(*crsOp.getCrsMatrix()); + return *tmp_ECrsMtx.getTpetra_BlockCrsMatrixNonConst(); + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + } + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + } +#endif + } + + static RCP > Op2TpetraRow(RCP Op) { #if ((defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_OPENMP) || !defined(HAVE_TPETRA_INST_INT_INT))) || \ (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_INT)))) @@ -799,9 +879,11 @@ namespace MueLu { (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_INT)))) throw Exceptions::RuntimeError("Utilities::Transpose: Tpetra is not compiled with LO=GO=int. Add TPETRA_INST_INT_INT:BOOL=ON to your configuration!"); #else - try { + using Helpers = Xpetra::Helpers; + /***************************************************************/ + if(Helpers::isTpetraCrs(Op)) { const Tpetra::CrsMatrix& tpetraOp = Utilities::Op2TpetraCrs(Op); - + // Compute the transpose A of the Tpetra matrix tpetraOp. RCP > A; Tpetra::RowMatrixTransposer transposer(rcpFromRef(tpetraOp),label); @@ -825,9 +907,43 @@ namespace MueLu { return AAAA; } - catch (std::exception& e) { - std::cout << "threw exception '" << e.what() << "'" << std::endl; - throw Exceptions::RuntimeError("Utilities::Transpose failed, perhaps because matrix is not a Crs matrix"); + /***************************************************************/ + else if(Helpers::isTpetraBlockCrs(Op)) { + using BCRS = Tpetra::BlockCrsMatrix; + using CRS = Tpetra::CrsMatrix; + const BCRS & tpetraOp = Utilities::Op2TpetraBlockCrs(Op); + + if(!Op.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Utilities::Transpose(): Using inefficient placeholder algorithm for Transpose"< At; + RCP Acrs = Tpetra::convertToCrsMatrix(tpetraOp); + { + Tpetra::RowMatrixTransposer transposer(Acrs,label); + + using Teuchos::ParameterList; + using Teuchos::rcp; + RCP transposeParams = params.is_null () ? + rcp (new ParameterList) : + rcp (new ParameterList (*params)); + transposeParams->set ("sort", false); + RCP Atcrs = transposer.createTranspose(transposeParams); + + At = Tpetra::convertToBlockCrsMatrix(*Atcrs,Op.GetStorageBlockSize()); + } + RCP > AA = rcp(new Xpetra::TpetraBlockCrsMatrix(At)); + RCP AAA = rcp_implicit_cast(AA); + RCP AAAA = rcp( new CrsMatrixWrap(AAA)); + + if (Op.IsView("stridedMaps")) + AAAA->CreateView("stridedMaps", Teuchos::rcpFromRef(Op), true/*doTranspose*/); + + return AAAA; + + } + /***************************************************************/ + else { + throw Exceptions::RuntimeError("Utilities::Transpose failed, perhaps because matrix is not a Crs or BlockCrs matrix"); } #endif #else diff --git a/packages/muelu/src/Utils/MueLu_Utilities_def.hpp b/packages/muelu/src/Utils/MueLu_Utilities_def.hpp index 469276531197..49a674aaa91e 100644 --- a/packages/muelu/src/Utils/MueLu_Utilities_def.hpp +++ b/packages/muelu/src/Utils/MueLu_Utilities_def.hpp @@ -300,6 +300,65 @@ namespace MueLu { } } + + template + RCP > Utilities::Op2TpetraBlockCrs(RCP > Op) { + using XCrsMatrixWrap = Xpetra::CrsMatrixWrap; + // Get the underlying Tpetra Mtx + RCP crsOp = rcp_dynamic_cast(Op); + if (crsOp == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + const RCP > &tmp_ECrsMtx = rcp_dynamic_cast >(crsOp->getCrsMatrix()); + if (tmp_ECrsMtx == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + return tmp_ECrsMtx->getTpetra_BlockCrsMatrix(); + } + + template + RCP< Tpetra::BlockCrsMatrix > Utilities::Op2NonConstTpetraBlockCrs(RCP > Op){ + using XCrsMatrixWrap = Xpetra::CrsMatrixWrap; + RCP crsOp = rcp_dynamic_cast(Op); + if (crsOp == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + const RCP > &tmp_ECrsMtx = rcp_dynamic_cast >(crsOp->getCrsMatrix()); + if (tmp_ECrsMtx == Teuchos::null) + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + return tmp_ECrsMtx->getTpetra_BlockCrsMatrixNonConst(); + } + + template + const Tpetra::BlockCrsMatrix& Utilities::Op2TpetraBlockCrs(const Xpetra::Matrix& Op) { + try { + using XCrsMatrixWrap = Xpetra::CrsMatrixWrap; + const XCrsMatrixWrap& crsOp = dynamic_cast(Op); + try { + const Xpetra::TpetraBlockCrsMatrix& tmp_ECrsMtx = dynamic_cast&>(*crsOp.getCrsMatrix()); + return *tmp_ECrsMtx.getTpetra_BlockCrsMatrix(); + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + } + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + } + } + + template + Tpetra::BlockCrsMatrix& Utilities::Op2NonConstTpetraBlockCrs(Xpetra::Matrix& Op) { + try { + using XCrsMatrixWrap = Xpetra::CrsMatrixWrap; + XCrsMatrixWrap& crsOp = dynamic_cast(Op); + try { + Xpetra::TpetraBlockCrsMatrix& tmp_ECrsMtx = dynamic_cast&>(*crsOp.getCrsMatrix()); + return *tmp_ECrsMtx.getTpetra_BlockCrsMatrixNonConst(); + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); + } + } catch (std::bad_cast&) { + throw Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed"); + } + } + + template RCP > Utilities::Op2TpetraRow(RCP > Op) { RCP > mat = rcp_dynamic_cast >(Op); @@ -498,12 +557,14 @@ namespace MueLu { #ifdef HAVE_MUELU_TPETRA if (TorE == "tpetra") { - try { + using Helpers = Xpetra::Helpers; + /***************************************************************/ + if(Helpers::isTpetraCrs(Op)) { const Tpetra::CrsMatrix& tpetraOp = Utilities::Op2TpetraCrs(Op); - + RCP > A; Tpetra::RowMatrixTransposer transposer(rcpFromRef(tpetraOp),label); //more than meets the eye - + { using Teuchos::ParameterList; using Teuchos::rcp; @@ -513,20 +574,53 @@ namespace MueLu { transposeParams->set ("sort", false); A = transposer.createTranspose (transposeParams); } - + RCP > AA = rcp(new Xpetra::TpetraCrsMatrix(A) ); RCP > AAA = rcp_implicit_cast >(AA); RCP > AAAA = rcp( new Xpetra::CrsMatrixWrap(AAA) ); if (!AAAA->isFillComplete()) AAAA->fillComplete(Op.getRangeMap(), Op.getDomainMap()); - + if (Op.IsView("stridedMaps")) AAAA->CreateView("stridedMaps", Teuchos::rcpFromRef(Op), true/*doTranspose*/); - + return AAAA; - - } catch (std::exception& e) { - std::cout << "threw exception '" << e.what() << "'" << std::endl; + } + else if(Helpers::isTpetraBlockCrs(Op)) { + using XMatrix = Xpetra::Matrix; + using XCrsMatrix = Xpetra::CrsMatrix; + using XCrsMatrixWrap = Xpetra::CrsMatrixWrap; + using BCRS = Tpetra::BlockCrsMatrix; + using CRS = Tpetra::CrsMatrix; + const BCRS & tpetraOp = Utilities::Op2TpetraBlockCrs(Op); + + if(!Op.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Utilities::Transpose(): Using inefficient placeholder algorithm for Transpose"< At; + RCP Acrs = Tpetra::convertToCrsMatrix(tpetraOp); + { + Tpetra::RowMatrixTransposer transposer(Acrs,label); + + using Teuchos::ParameterList; + using Teuchos::rcp; + RCP transposeParams = params.is_null () ? + rcp (new ParameterList) : + rcp (new ParameterList (*params)); + transposeParams->set ("sort", false); + RCP Atcrs = transposer.createTranspose(transposeParams); + + At = Tpetra::convertToBlockCrsMatrix(*Atcrs,Op.GetStorageBlockSize()); + } + RCP > AA = rcp(new Xpetra::TpetraBlockCrsMatrix(At)); + RCP AAA = rcp_implicit_cast(AA); + RCP AAAA = rcp( new XCrsMatrixWrap(AAA)); + + if (Op.IsView("stridedMaps")) + AAAA->CreateView("stridedMaps", Teuchos::rcpFromRef(Op), true/*doTranspose*/); + + return AAAA; + } else { throw Exceptions::RuntimeError("Utilities::Transpose failed, perhaps because matrix is not a Crs matrix"); } } //if diff --git a/packages/muelu/src/Utils/MueLu_Utilities_kokkos_def.hpp b/packages/muelu/src/Utils/MueLu_Utilities_kokkos_def.hpp index 90799c2cd562..7d379bc68259 100644 --- a/packages/muelu/src/Utils/MueLu_Utilities_kokkos_def.hpp +++ b/packages/muelu/src/Utils/MueLu_Utilities_kokkos_def.hpp @@ -368,47 +368,93 @@ namespace MueLu { using impl_scalar_type = typename Kokkos::ArithTraits::val_type; using ATS = Kokkos::ArithTraits; using range_type = Kokkos::RangePolicy; + using helpers = Xpetra::Helpers; - auto localMatrix = A.getLocalMatrixDevice(); - LO numRows = A.getLocalNumRows(); - Kokkos::View boundaryNodes(Kokkos::ViewAllocateWithoutInitializing("boundaryNodes"), numRows); - if (count_twos_as_dirichlet) - Kokkos::parallel_for("MueLu:Utils::DetectDirichletRows_Twos_As_Dirichlet", range_type(0,numRows), + if(helpers::isTpetraBlockCrs(A)) { +#ifdef HAVE_MUELU_TPETRA + const Tpetra::BlockCrsMatrix & Am = helpers::Op2TpetraBlockCrs(A); + auto b_graph = Am.getCrsGraph().getLocalGraphDevice(); + auto b_rowptr = Am.getCrsGraph().getLocalRowPtrsDevice(); + auto values = Am.getValuesDevice(); + LO numBlockRows = Am.getLocalNumRows(); + const LO stride = Am.getBlockSize() * Am.getBlockSize(); + + Kokkos::View boundaryNodes(Kokkos::ViewAllocateWithoutInitializing("boundaryNodes"), numBlockRows); + + if (count_twos_as_dirichlet) + throw Exceptions::RuntimeError("BlockCrs does not support counting twos as Dirichlet"); + + Kokkos::parallel_for("MueLu:Utils::DetectDirichletRowsBlockCrs", range_type(0,numBlockRows), KOKKOS_LAMBDA(const LO row) { - auto rowView = localMatrix.row(row); + auto rowView = b_graph.rowConst(row); auto length = rowView.length; + LO valstart = b_rowptr[row] * stride; boundaryNodes(row) = true; - if (length > 2) { - decltype(length) colID = 0; - for (; colID < length; colID++) - if ((rowView.colidx(colID) != row) && - (ATS::magnitude(rowView.value(colID)) > tol)) { - if (!boundaryNodes(row)) + decltype(length) colID =0; + for (; colID < length; colID++) { + if (rowView.colidx(colID) != row) { + LO current = valstart + colID*stride; + for(LO k=0; k tol) { + boundaryNodes(row) = false; break; - boundaryNodes(row) = false; + } } - if (colID == length) - boundaryNodes(row) = true; + } + if(boundaryNodes(row) == false) + break; } }); - else - Kokkos::parallel_for("MueLu:Utils::DetectDirichletRows", range_type(0,numRows), - KOKKOS_LAMBDA(const LO row) { - auto rowView = localMatrix.row(row); - auto length = rowView.length; - boundaryNodes(row) = true; - for (decltype(length) colID = 0; colID < length; colID++) - if ((rowView.colidx(colID) != row) && - (ATS::magnitude(rowView.value(colID)) > tol)) { - boundaryNodes(row) = false; - break; + return boundaryNodes; +#else + throw Exceptions::RuntimeError("BlockCrs requires Tpetra"); +#endif + } + else { + auto localMatrix = A.getLocalMatrixDevice(); + LO numRows = A.getLocalNumRows(); + Kokkos::View boundaryNodes(Kokkos::ViewAllocateWithoutInitializing("boundaryNodes"), numRows); + + if (count_twos_as_dirichlet) + Kokkos::parallel_for("MueLu:Utils::DetectDirichletRows_Twos_As_Dirichlet", range_type(0,numRows), + KOKKOS_LAMBDA(const LO row) { + auto rowView = localMatrix.row(row); + auto length = rowView.length; + + boundaryNodes(row) = true; + if (length > 2) { + decltype(length) colID =0; + for ( ; colID < length; colID++) + if ((rowView.colidx(colID) != row) && + (ATS::magnitude(rowView.value(colID)) > tol)) { + if (!boundaryNodes(row)) + break; + boundaryNodes(row) = false; + } + if (colID == length) + boundaryNodes(row) = true; } - }); - + }); + else + Kokkos::parallel_for("MueLu:Utils::DetectDirichletRows", range_type(0,numRows), + KOKKOS_LAMBDA(const LO row) { + auto rowView = localMatrix.row(row); + auto length = rowView.length; + + boundaryNodes(row) = true; + for (decltype(length) colID = 0; colID < length; colID++) + if ((rowView.colidx(colID) != row) && + (ATS::magnitude(rowView.value(colID)) > tol)) { + boundaryNodes(row) = false; + break; + } + }); return boundaryNodes; + } + } template diff --git a/packages/muelu/test/scaling/CMakeLists.txt b/packages/muelu/test/scaling/CMakeLists.txt index 67b36efd4e5f..e6f49b53bfab 100644 --- a/packages/muelu/test/scaling/CMakeLists.txt +++ b/packages/muelu/test/scaling/CMakeLists.txt @@ -63,7 +63,7 @@ IF (${PACKAGE_NAME}_HAVE_TPETRA_SOLVER_STACK OR ${PACKAGE_NAME}_HAVE_EPETRA_SOLV ) TRIBITS_COPY_FILES_TO_BINARY_DIR(Driver_cp - SOURCE_FILES scaling.xml scaling.yaml scaling-complex.xml scaling-withglobalconstants.xml scaling-complex-withglobalconstants.xml circ_nsp_dependency.xml isorropia.xml iso_poisson.xml conchas_milestone_zoltan.xml conchas_milestone_zoltan2.xml conchas_milestone_zoltan2_complex.xml sa_with_ilu.xml sa_with_Ifpack2_line_detection.xml rap.xml smoother.xml smoother_complex.xml tripleMatrixProduct.xml scaling-ml.xml elasticity3D.xml amgx.json amgx.xml scaling-with-rerun.xml scaling_distance2_agg.xml smooVec.mm smooVecCoalesce.xml pairwise.xml sa_enforce_constraints.xml recurMG.xml anisotropic.xml comp_rotations.xml generalBlkSmoothing.xml GblkMap.dat GblkAmat.dat GblkRhs.dat Gblks.dat + SOURCE_FILES scaling.xml scaling.yaml scaling-complex.xml scaling-withglobalconstants.xml scaling-complex-withglobalconstants.xml circ_nsp_dependency.xml isorropia.xml iso_poisson.xml conchas_milestone_zoltan.xml conchas_milestone_zoltan2.xml conchas_milestone_zoltan2_complex.xml sa_with_ilu.xml sa_with_Ifpack2_line_detection.xml rap.xml smoother.xml smoother_complex.xml tripleMatrixProduct.xml scaling-ml.xml elasticity3D.xml amgx.json amgx.xml scaling-with-rerun.xml scaling_distance2_agg.xml smooVec.mm smooVecCoalesce.xml pairwise.xml sa_enforce_constraints.xml recurMG.xml anisotropic.xml comp_rotations.xml generalBlkSmoothing.xml GblkMap.dat GblkAmat.dat GblkRhs.dat Gblks.dat blkSmooEquivOlapSchwarz.xml oLapSchwarzEquivBlkSmoo.xml regularOverLap.dat CATEGORIES BASIC PERFORMANCE ) @@ -289,6 +289,22 @@ IF (${PACKAGE_NAME}_HAVE_TPETRA_SOLVER_STACK) COMM mpi # HAVE_MPI required ) + TRIBITS_ADD_TEST( + Driver + NAME "BlockSmoothingWithAverages" + ARGS "--linAlgebra=Tpetra --xml=blkSmooEquivOlapSchwarz.xml --belosType=\"Fixed\ Point\" --rowmap=GblkMap.dat --matrix=GblkAmat.dat --rhs=GblkRhs.dat --tol=.1 --userBlks=regularOverLap.dat" + NUM_MPI_PROCS 4 + COMM mpi # HAVE_MPI required + ) + + TRIBITS_ADD_TEST( + Driver + NAME "SchwarzSmoothingWithAverages" + ARGS "--linAlgebra=Tpetra --xml=oLapSchwarzEquivBlkSmoo.xml --belosType=\"Fixed\ Point\" --rowmap=GblkMap.dat --matrix=GblkAmat.dat --rhs=GblkRhs.dat --tol=.1" + NUM_MPI_PROCS 4 + COMM mpi # HAVE_MPI required + ) + ENDIF() IF (${PACKAGE_NAME}_HAVE_TPETRA_SOLVER_STACK) diff --git a/packages/muelu/test/scaling/blkSmooEquivOlapSchwarz.xml b/packages/muelu/test/scaling/blkSmooEquivOlapSchwarz.xml new file mode 100644 index 000000000000..c9a3911fbc9d --- /dev/null +++ b/packages/muelu/test/scaling/blkSmooEquivOlapSchwarz.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/packages/muelu/test/scaling/oLapSchwarzEquivBlkSmoo.xml b/packages/muelu/test/scaling/oLapSchwarzEquivBlkSmoo.xml new file mode 100644 index 000000000000..caa33197936f --- /dev/null +++ b/packages/muelu/test/scaling/oLapSchwarzEquivBlkSmoo.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/packages/muelu/test/scaling/regularOverLap.dat b/packages/muelu/test/scaling/regularOverLap.dat new file mode 100644 index 000000000000..20768d39b94f --- /dev/null +++ b/packages/muelu/test/scaling/regularOverLap.dat @@ -0,0 +1,188 @@ +%%MatrixMarket matrix coordinate real general +4 137 186 +1 1 1.0 +1 5 1.0 +1 6 1.0 +1 7 1.0 +1 8 1.0 +1 9 1.0 +1 10 1.0 +1 25 1.0 +1 26 1.0 +1 27 1.0 +1 28 1.0 +1 29 1.0 +1 30 1.0 +1 46 1.0 +1 47 1.0 +1 51 1.0 +1 52 1.0 +1 56 1.0 +1 58 1.0 +1 59 1.0 +1 63 1.0 +1 64 1.0 +1 69 1.0 +1 71 1.0 +1 72 1.0 +1 81 1.0 +1 88 1.0 +1 89 1.0 +1 92 1.0 +1 98 1.0 +1 99 1.0 +1 103 1.0 +1 104 1.0 +1 105 1.0 +1 107 1.0 +1 109 1.0 +1 113 1.0 +1 114 1.0 +1 115 1.0 +1 118 1.0 +1 119 1.0 +1 123 1.0 +1 125 1.0 +1 126 1.0 +1 130 1.0 +1 132 1.0 +2 2 1.0 +2 9 1.0 +2 10 1.0 +2 11 1.0 +2 12 1.0 +2 13 1.0 +2 14 1.0 +2 35 1.0 +2 36 1.0 +2 37 1.0 +2 38 1.0 +2 39 1.0 +2 40 1.0 +2 41 1.0 +2 48 1.0 +2 49 1.0 +2 52 1.0 +2 54 1.0 +2 55 1.0 +2 56 1.0 +2 57 1.0 +2 62 1.0 +2 63 1.0 +2 65 1.0 +2 68 1.0 +2 76 1.0 +2 77 1.0 +2 79 1.0 +2 80 1.0 +2 81 1.0 +2 86 1.0 +2 93 1.0 +2 96 1.0 +2 99 1.0 +2 102 1.0 +2 104 1.0 +2 105 1.0 +2 106 1.0 +2 108 1.0 +2 118 1.0 +2 127 1.0 +2 128 1.0 +2 131 1.0 +2 132 1.0 +2 134 1.0 +2 137 1.0 +3 4 1.0 +3 15 1.0 +3 16 1.0 +3 17 1.0 +3 18 1.0 +3 19 1.0 +3 20 1.0 +3 29 1.0 +3 30 1.0 +3 31 1.0 +3 32 1.0 +3 33 1.0 +3 34 1.0 +3 47 1.0 +3 49 1.0 +3 50 1.0 +3 51 1.0 +3 53 1.0 +3 58 1.0 +3 63 1.0 +3 64 1.0 +3 66 1.0 +3 70 1.0 +3 74 1.0 +3 77 1.0 +3 79 1.0 +3 82 1.0 +3 83 1.0 +3 87 1.0 +3 88 1.0 +3 90 1.0 +3 91 1.0 +3 94 1.0 +3 95 1.0 +3 97 1.0 +3 104 1.0 +3 107 1.0 +3 110 1.0 +3 111 1.0 +3 112 1.0 +3 114 1.0 +3 117 1.0 +3 121 1.0 +3 122 1.0 +3 125 1.0 +3 126 1.0 +3 130 1.0 +3 132 1.0 +3 135 1.0 +4 3 1.0 +4 19 1.0 +4 20 1.0 +4 21 1.0 +4 22 1.0 +4 23 1.0 +4 24 1.0 +4 39 1.0 +4 40 1.0 +4 41 1.0 +4 42 1.0 +4 43 1.0 +4 44 1.0 +4 45 1.0 +4 49 1.0 +4 53 1.0 +4 60 1.0 +4 61 1.0 +4 62 1.0 +4 67 1.0 +4 73 1.0 +4 75 1.0 +4 76 1.0 +4 77 1.0 +4 78 1.0 +4 79 1.0 +4 84 1.0 +4 85 1.0 +4 97 1.0 +4 100 1.0 +4 101 1.0 +4 102 1.0 +4 108 1.0 +4 110 1.0 +4 111 1.0 +4 116 1.0 +4 117 1.0 +4 120 1.0 +4 124 1.0 +4 127 1.0 +4 129 1.0 +4 131 1.0 +4 133 1.0 +4 136 1.0 +4 137 1.0 diff --git a/packages/muelu/test/unit_tests/CMakeLists.txt b/packages/muelu/test/unit_tests/CMakeLists.txt index 122d6bd99c6c..e1d22c25e242 100644 --- a/packages/muelu/test/unit_tests/CMakeLists.txt +++ b/packages/muelu/test/unit_tests/CMakeLists.txt @@ -445,8 +445,9 @@ ENDIF() ADD_SUBDIRECTORY(ParameterList/FactoryFactory/) +ADD_SUBDIRECTORY(ParameterList/ParameterListInterpreter/) + IF (${PACKAGE_NAME}_ENABLE_Epetra) ADD_SUBDIRECTORY(ParameterList/MLParameterListInterpreter/) - ADD_SUBDIRECTORY(ParameterList/ParameterListInterpreter/) ADD_SUBDIRECTORY(ParameterList/CreateSublists/) ENDIF() diff --git a/packages/muelu/test/unit_tests/Hierarchy.cpp b/packages/muelu/test/unit_tests/Hierarchy.cpp index 1a8c92841ed7..6132086accfb 100644 --- a/packages/muelu/test/unit_tests/Hierarchy.cpp +++ b/packages/muelu/test/unit_tests/Hierarchy.cpp @@ -1672,7 +1672,7 @@ namespace MueLuTests { } - TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Hierarchy, BlockCrs, Scalar, LocalOrdinal, GlobalOrdinal, Node) + TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Hierarchy, BlockCrs_Mixed, Scalar, LocalOrdinal, GlobalOrdinal, Node) { # include MUELU_TESTING_SET_OSTREAM; @@ -1789,6 +1789,7 @@ namespace MueLuTests { TEST_EQUALITY(0,0); } + TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Hierarchy, CheckNullspaceDimension, Scalar, LocalOrdinal, GlobalOrdinal, Node) { // Test that HierarchyManager throws if user-supplied nullspace has dimension smaller than numPDEs @@ -1835,7 +1836,7 @@ namespace MueLuTests { TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Hierarchy, SetupHierarchy3levelFacManagers, Scalar, LO, GO, Node) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Hierarchy, SetupHierarchyTestBreakCondition, Scalar, LO, GO, Node) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Hierarchy, Write, Scalar, LO, GO, Node) \ - TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Hierarchy, BlockCrs, Scalar, LO, GO, Node) \ + TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Hierarchy, BlockCrs_Mixed, Scalar, LO, GO, Node) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Hierarchy, CheckNullspaceDimension, Scalar, LO, GO, Node) \ diff --git a/packages/muelu/test/unit_tests/MueLu_TestHelpers.hpp b/packages/muelu/test/unit_tests/MueLu_TestHelpers.hpp index 4bd0c652af14..a6f3e9af16c9 100644 --- a/packages/muelu/test/unit_tests/MueLu_TestHelpers.hpp +++ b/packages/muelu/test/unit_tests/MueLu_TestHelpers.hpp @@ -881,10 +881,24 @@ namespace MueLuTests { basematrix[4] = two; basematrix[7] = three; basematrix[8] = two; + Teuchos::Array offmatrix(blocksize*blocksize, zero); + offmatrix[0]=offmatrix[4]=offmatrix[8]=-1; + Teuchos::Array lclColInds(1); for (LocalOrdinal lclRowInd = meshRowMap.getMinLocalIndex (); lclRowInd <= meshRowMap.getMaxLocalIndex(); ++lclRowInd) { lclColInds[0] = lclRowInd; bcrsmatrix->replaceLocalValues(lclRowInd, lclColInds.getRawPtr(), &basematrix[0], 1); + + // Off diagonals + if(lclRowInd > meshRowMap.getMinLocalIndex ()) { + lclColInds[0] = lclRowInd - 1; + bcrsmatrix->replaceLocalValues(lclRowInd, lclColInds.getRawPtr(), &offmatrix[0], 1); + } + if(lclRowInd < meshRowMap.getMaxLocalIndex ()) { + lclColInds[0] = lclRowInd + 1; + bcrsmatrix->replaceLocalValues(lclRowInd, lclColInds.getRawPtr(), &offmatrix[0], 1); + } + } RCP > temp = rcp(new Xpetra::TpetraBlockCrsMatrix(bcrsmatrix)); diff --git a/packages/muelu/test/unit_tests/MueLu_Test_ETI.hpp b/packages/muelu/test/unit_tests/MueLu_Test_ETI.hpp index c904379b227b..1ee25d148301 100644 --- a/packages/muelu/test/unit_tests/MueLu_Test_ETI.hpp +++ b/packages/muelu/test/unit_tests/MueLu_Test_ETI.hpp @@ -57,7 +57,7 @@ // need this to have the ETI defined macros #if defined(HAVE_MUELU_EXPLICIT_INSTANTIATION) #include -#endif +#endif #if defined(HAVE_MUELU_TPETRA) #include @@ -87,6 +87,14 @@ bool Automatic_Test_ETI(int argc, char *argv[]) { // MPI initialization using Teuchos Teuchos::GlobalMPISession mpiSession(&argc, &argv, NULL); + Teuchos::RCP out = Teuchos::rcp(new Teuchos::FancyOStream(Teuchos::rcpFromRef(std::cout))); +#ifdef HAVE_MPI + Teuchos::RCP > comm = Teuchos::rcp_dynamic_cast >(Teuchos::DefaultComm::getComm()); + if (comm->getSize() > 1) { + out->setOutputToRootOnly(0); + } +#endif + // Tpetra nodes call Kokkos::execution_space::initialize if the execution // space is not initialized, but they don't call Kokkos::initialize. // Teuchos::GlobalMPISession captures its command-line arguments for later @@ -145,7 +153,7 @@ bool Automatic_Test_ETI(int argc, char *argv[]) { // Both Epetra and Tpetra (with double, int, int) enabled return MUELU_AUTOMATIC_TEST_ETI_NAME(clp, lib, argc, argv); # else - std::cout << "Skip running with Epetra since both Epetra and Tpetra are enabled but Tpetra is not instantiated on double, int, int." << std::endl; + *out << "Skip running with Epetra since both Epetra and Tpetra are enabled but Tpetra is not instantiated on double, int, int." << std::endl; # endif // end Tpetra instantiated on double, int, int # else // only Epetra enabled. No Tpetra instantiation possible @@ -162,6 +170,11 @@ bool Automatic_Test_ETI(int argc, char *argv[]) { if (node == "") { typedef KokkosClassic::DefaultNode::DefaultNodeType Node; + if (config) { + *out << "Node type: " << Node::execution_space::name() << std::endl; + Node::execution_space::print_configuration(*out, true/*details*/); + } + #ifndef HAVE_MUELU_EXPLICIT_INSTANTIATION return MUELU_AUTOMATIC_TEST_ETI_NAME(clp, lib, argc, argv); #else @@ -191,8 +204,10 @@ bool Automatic_Test_ETI(int argc, char *argv[]) { #ifdef KOKKOS_ENABLE_SERIAL typedef Kokkos::Compat::KokkosSerialWrapperNode Node; - if (config) - Kokkos::Serial().print_configuration(std::cout, true/*details*/); + if (config) { + *out << "Node type: " << Node::execution_space::name() << std::endl; + Kokkos::Serial().print_configuration(*out, true/*details*/); + } # ifndef HAVE_MUELU_EXPLICIT_INSTANTIATION return MUELU_AUTOMATIC_TEST_ETI_NAME(clp, lib, argc, argv); @@ -227,8 +242,9 @@ bool Automatic_Test_ETI(int argc, char *argv[]) { typedef Kokkos::Compat::KokkosOpenMPWrapperNode Node; if (config) { - Kokkos::OpenMP().print_configuration(std::cout, true/*details*/); - std::cout << "OpenMP Max Threads = " << omp_get_max_threads() << std::endl; + *out << "Node type: " << Node::execution_space::name() << std::endl; + Kokkos::OpenMP().print_configuration(*out, true/*details*/); + *out << "OpenMP Max Threads = " << omp_get_max_threads() << std::endl; } # ifndef HAVE_MUELU_EXPLICIT_INSTANTIATION @@ -263,8 +279,10 @@ bool Automatic_Test_ETI(int argc, char *argv[]) { #ifdef KOKKOS_ENABLE_CUDA typedef Kokkos::Compat::KokkosCudaWrapperNode Node; - if (config) - Kokkos::Cuda().print_configuration(std::cout, true/*details*/); + if (config) { + *out << "Node type: " << Node::execution_space::name() << std::endl; + Kokkos::Cuda().print_configuration(*out, true/*details*/); + } # ifndef HAVE_MUELU_EXPLICIT_INSTANTIATION return MUELU_AUTOMATIC_TEST_ETI_NAME(clp, lib, argc, argv); @@ -298,8 +316,10 @@ bool Automatic_Test_ETI(int argc, char *argv[]) { #ifdef KOKKOS_ENABLE_HIP typedef Kokkos::Compat::KokkosHIPWrapperNode Node; - if (config) - Kokkos::Experimental::HIP().print_configuration(std::cout, true/*details*/); + if (config) { + *out << "Node type: " << Node::execution_space::name() << std::endl; + Kokkos::Experimental::HIP().print_configuration(*out, true/*details*/); + } # ifndef HAVE_MUELU_EXPLICIT_INSTANTIATION return MUELU_AUTOMATIC_TEST_ETI_NAME(clp, lib, argc, argv); diff --git a/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter.cpp b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter.cpp index a0f0ada4b9b7..90639e98e0d9 100644 --- a/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter.cpp +++ b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter.cpp @@ -43,14 +43,23 @@ // *********************************************************************** // // @HEADER - #include +#include #include #include #include #include +#include + +#include + +#ifdef HAVE_MUELU_TPETRA +#include "Tpetra_BlockCrsMatrix_Helpers.hpp" +#include "TpetraExt_MatrixMatrix.hpp" +#endif + namespace MueLuTests { @@ -67,6 +76,14 @@ namespace MueLuTests { ArrayRCP fileList = TestHelpers::GetFileList(std::string("ParameterList/ParameterListInterpreter/"), std::string(".xml")); for(int i=0; i< fileList.size(); i++) { + // Ignore files with "BlockCrs" in their name + auto found = fileList[i].find("BlockCrs"); + if(found != std::string::npos) continue; + + // Ignore files with "Comparison" in their name + found = fileList[i].find("Comparison"); + if(found != std::string::npos) continue; + out << "Processing file: " << fileList[i] << std::endl; ParameterListInterpreter mueluFactory("ParameterList/ParameterListInterpreter/" + fileList[i],*comm); @@ -83,8 +100,178 @@ namespace MueLuTests { out << "Skipping test because some required packages are not enabled (Tpetra, Epetra, EpetraExt, Ifpack, Ifpack2, Amesos, Amesos2)." << std::endl; # endif } + + +TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(ParameterListInterpreter, BlockCrs, Scalar, LocalOrdinal, GlobalOrdinal, Node) + { +# include + MUELU_TESTING_SET_OSTREAM; + MUELU_TESTING_LIMIT_SCOPE(Scalar,GlobalOrdinal,Node); +#if defined(HAVE_MUELU_TPETRA) + MUELU_TEST_ONLY_FOR(Xpetra::UseTpetra) { + Teuchos::ParameterList matrixParams; + matrixParams.set("matrixType","Laplace1D"); + matrixParams.set("nx",(GlobalOrdinal)300);// needs to be even + + RCP A = TestHelpers::TpetraTestFactory::BuildBlockMatrix(matrixParams,Xpetra::UseTpetra); + out<<"Matrix Size (block) = "<getGlobalNumRows()<<" (point) "<getRangeMap()->getGlobalNumElements()< > comm = TestHelpers::Parameters::getDefaultComm(); + + ArrayRCP fileList = TestHelpers::GetFileList(std::string("ParameterList/ParameterListInterpreter/"), std::string(".xml")); + + for(int i=0; i< fileList.size(); i++) { + // Only run files with "BlockCrs" in their name + auto found = fileList[i].find("BlockCrs"); + if(found == std::string::npos) continue; + + out << "Processing file: " << fileList[i] << std::endl; + ParameterListInterpreter mueluFactory("ParameterList/ParameterListInterpreter/" + fileList[i],*comm); + + RCP H = mueluFactory.CreateHierarchy(); + H->GetLevel(0)->Set("A", A); + + mueluFactory.SetupHierarchy(*H); + + // Test to make sure all of the matrices in the Hierarchy are actually Block Matrices + using helpers = Xpetra::Helpers; + for(int j=0; jGetNumLevels(); j++) { + RCP level = H->GetLevel(j); + + RCP Am = level->Get >("A"); + TEST_EQUALITY(helpers::isTpetraBlockCrs(Am),true); + if(j>0) { + RCP P = level->Get >("P"); + TEST_EQUALITY(helpers::isTpetraBlockCrs(P),true); + RCP R = level->Get >("R"); + TEST_EQUALITY(helpers::isTpetraBlockCrs(R),true); + } + } + + //TODO: check no unused parameters + //TODO: check results of Iterate() + } + } +# endif + TEST_EQUALITY(1,1); + } + + + +#if defined(HAVE_MUELU_TPETRA) +template +MT compare_matrices(RCP & Ap, RCP &Ab) { + using SC = typename Matrix::scalar_type; + using LO = typename Matrix::local_ordinal_type; + using GO = typename Matrix::global_ordinal_type; + using NO = typename Matrix::node_type; + using CRS=Tpetra::CrsMatrix; + SC one = Teuchos::ScalarTraits::one(); + SC zero = Teuchos::ScalarTraits::zero(); + + RCP Ap_t = MueLu::Utilities::Op2TpetraCrs(Ap); + auto Ab_t = MueLu::Utilities::Op2TpetraBlockCrs(Ab); + RCP Ab_as_point = Tpetra::convertToCrsMatrix(*Ab_t); + + RCP diff = rcp(new CRS(Ap_t->getCrsGraph())); + diff->setAllToScalar(zero); + diff->fillComplete(); + Tpetra::MatrixMatrix::Add(*Ap_t,false,one,*Ab_as_point,false,-one,diff); + return diff->getFrobeniusNorm(); +} +#endif + + +TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(ParameterListInterpreter, PointCrs_vs_BlockCrs, Scalar, LocalOrdinal, GlobalOrdinal, Node) + { +# include + MUELU_TESTING_SET_OSTREAM; + MUELU_TESTING_LIMIT_SCOPE(Scalar,GlobalOrdinal,Node); +#if defined(HAVE_MUELU_TPETRA) + MUELU_TEST_ONLY_FOR(Xpetra::UseTpetra) { + Teuchos::ParameterList matrixParams; + matrixParams.set("matrixType","Laplace1D"); + matrixParams.set("nx",(GlobalOrdinal)300);// needs to be even + + RCP PointA = TestHelpers::TestFactory::BuildMatrix(matrixParams,Xpetra::UseTpetra); + RCP BlockA; + { + using XCRS = Xpetra::TpetraBlockCrsMatrix; + + auto tA = MueLu::Utilities::Op2TpetraCrs(PointA); + auto bA = Tpetra::convertToBlockCrsMatrix(*tA,1); + RCP AA = rcp(new XCRS(bA)); + BlockA = rcp(new CrsMatrixWrap(rcp_implicit_cast(AA))); + } + + out<<"Point: Matrix Size (block) = "<getGlobalNumRows()<<" (point) "<getRangeMap()->getGlobalNumElements()<getRangeMap()->getGlobalNumElements()< > comm = TestHelpers::Parameters::getDefaultComm(); + + ArrayRCP fileList = TestHelpers::GetFileList(std::string("ParameterList/ParameterListInterpreter/"), std::string(".xml")); + + for(int i=0; i< fileList.size(); i++) { + // Only run files with "Comparison" in their name + auto found = fileList[i].find("Comparison"); + if(found == std::string::npos) continue; + + out << "Processing file: " << fileList[i] << std::endl; + + // Point Hierarchy + ParameterListInterpreter mueluFactory1("ParameterList/ParameterListInterpreter/" + fileList[i],*comm); + RCP PointH = mueluFactory1.CreateHierarchy(); + PointH->GetLevel(0)->Set("A", PointA); + mueluFactory1.SetupHierarchy(*PointH); + + // Block Hierachy + ParameterListInterpreter mueluFactory2("ParameterList/ParameterListInterpreter/" + fileList[i],*comm); + RCP BlockH = mueluFactory2.CreateHierarchy(); + BlockH->GetLevel(0)->Set("A", BlockA); + mueluFactory2.SetupHierarchy(*BlockH); + + // Check to see that we get the same matrices in both hierarchies + TEST_EQUALITY(PointH->GetNumLevels(),BlockH->GetNumLevels()); + + for(int j=0; jGetNumLevels(); j++) { + using CRS=Tpetra::CrsMatrix; + using MT = typename Teuchos::ScalarTraits::magnitudeType; + MT tol = Teuchos::ScalarTraits::squareroot(Teuchos::ScalarTraits::eps()); + + RCP Plevel = PointH->GetLevel(j); + RCP Blevel = BlockH->GetLevel(j); + + // Compare A + RCP Ap = Plevel->Get >("A"); + RCP Ab = Blevel->Get >("A"); + MT norm = compare_matrices(Ap,Ab); + TEUCHOS_TEST_COMPARE(norm,<,tol,out,success); + + // Compare P, R + if(j>0) { + RCP Pp = Plevel->Get >("P"); + RCP Pb = Blevel->Get >("P"); + norm = compare_matrices(Pp,Pb); + TEUCHOS_TEST_COMPARE(norm,<,tol,out,success); + + RCP Rp = Plevel->Get >("R"); + RCP Rb = Blevel->Get >("R"); + norm = compare_matrices(Rp,Rb); + TEUCHOS_TEST_COMPARE(norm,<,tol,out,success); + } + } + + //TODO: check no unused parameters + //TODO: check results of Iterate() + } + } +# endif + TEST_EQUALITY(1,1); + } + + #define MUELU_ETI_GROUP(Scalar, LocalOrdinal, GlobalOrdinal, Node) \ - TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(ParameterListInterpreter, SetParameterList, Scalar, LocalOrdinal, GlobalOrdinal, Node) + TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(ParameterListInterpreter, SetParameterList, Scalar, LocalOrdinal, GlobalOrdinal, Node) \ + TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(ParameterListInterpreter, BlockCrs, Scalar, LocalOrdinal, GlobalOrdinal, Node) \ + TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(ParameterListInterpreter, PointCrs_vs_BlockCrs, Scalar, LocalOrdinal, GlobalOrdinal, Node) #include diff --git a/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/BlockCrs1.xml b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/BlockCrs1.xml new file mode 100644 index 000000000000..d048ab0d38d8 --- /dev/null +++ b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/BlockCrs1.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/BlockCrs2.xml b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/BlockCrs2.xml new file mode 100644 index 000000000000..5954af0d9af4 --- /dev/null +++ b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/BlockCrs2.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/CMakeLists.txt b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/CMakeLists.txt index 8e56cda05351..7023acbef07c 100644 --- a/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/CMakeLists.txt +++ b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/CMakeLists.txt @@ -3,6 +3,8 @@ # regenerate a build system incorporating the new file. # YOU MUST ALSO TOUCH A CMAKE CONFIGURATION FILE WHEN YOU PUSH THE NEW # FILE TO FORCE THE RECONFIGURE ON OTHER PEOPLE'S BUILDS. + + FILE(GLOB xmlFiles RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.xml) TRIBITS_COPY_FILES_TO_BINARY_DIR(ParameterList_ParameterListInterpreter_cp diff --git a/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/Comparison1.xml b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/Comparison1.xml new file mode 100644 index 000000000000..5432a0dbdc37 --- /dev/null +++ b/packages/muelu/test/unit_tests/ParameterList/ParameterListInterpreter/Comparison1.xml @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/packages/muelu/test/unit_tests/Smoothers/Ifpack2Smoother.cpp b/packages/muelu/test/unit_tests/Smoothers/Ifpack2Smoother.cpp index eb8d11f54ba0..81bcf614654e 100644 --- a/packages/muelu/test/unit_tests/Smoothers/Ifpack2Smoother.cpp +++ b/packages/muelu/test/unit_tests/Smoothers/Ifpack2Smoother.cpp @@ -567,7 +567,7 @@ namespace MueLuTests { } // banded - TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Ifpack2Smoother, BlockCrsMatrix_Relaxation, Scalar, LocalOrdinal, GlobalOrdinal, Node) + TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Ifpack2Smoother, BlockCrsMatrix_Relaxation_ViaPoint, Scalar, LocalOrdinal, GlobalOrdinal, Node) { # include MUELU_TESTING_SET_OSTREAM; @@ -592,6 +592,32 @@ namespace MueLuTests { } } + TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Ifpack2Smoother, BlockCrsMatrix_Relaxation_AsBlock, Scalar, LocalOrdinal, GlobalOrdinal, Node) + { +# include + MUELU_TESTING_SET_OSTREAM; + MUELU_TESTING_LIMIT_SCOPE(Scalar,GlobalOrdinal,Node); + + MUELU_TEST_ONLY_FOR(Xpetra::UseTpetra) { + Teuchos::ParameterList matrixParams, ifpack2Params; + + matrixParams.set("matrixType","Laplace1D"); + matrixParams.set("nx",(GlobalOrdinal)20);// needs to be even + + RCP A = TestHelpers::TpetraTestFactory::BuildBlockMatrix(matrixParams,Xpetra::UseTpetra); + ifpack2Params.set("smoother: use blockcrsmatrix storage",true); + + Ifpack2Smoother smoother("RELAXATION",ifpack2Params); + + Level level; TestHelpers::TestFactory::createSingleLevelHierarchy(level); + level.Set("A", A); + smoother.Setup(level); + + TEST_EQUALITY(1,1); + } + } + + #define MUELU_ETI_GROUP(SC,LO,GO,NO) \ @@ -605,7 +631,8 @@ namespace MueLuTests { TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Ifpack2Smoother,BandedRelaxation,SC,LO,GO,NO) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Ifpack2Smoother,TriDiRelaxation,SC,LO,GO,NO) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Ifpack2Smoother,BlockRelaxation_Autosize,SC,LO,GO,NO) \ - TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Ifpack2Smoother,BlockCrsMatrix_Relaxation,SC,LO,GO,NO) + TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Ifpack2Smoother,BlockCrsMatrix_Relaxation_ViaPoint,SC,LO,GO,NO) \ + TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT(Ifpack2Smoother,BlockCrsMatrix_Relaxation_AsBlock,SC,LO,GO,NO) #include diff --git a/packages/muelu/test/unit_tests_kokkos/Aggregates_kokkos.cpp b/packages/muelu/test/unit_tests_kokkos/Aggregates_kokkos.cpp index 2d8deec90d7a..a9bb67efc595 100644 --- a/packages/muelu/test/unit_tests_kokkos/Aggregates_kokkos.cpp +++ b/packages/muelu/test/unit_tests_kokkos/Aggregates_kokkos.cpp @@ -611,6 +611,20 @@ namespace MueLuTests { }, numBadAggregates); TEST_EQUALITY(numBadAggregates, 0); + // Check ComputeNodesInAggregate + typename Aggregates_kokkos::LO_view aggPtr, aggNodes, unaggregated; + aggregates->ComputeNodesInAggregate(aggPtr, aggNodes, unaggregated); + TEST_EQUALITY(aggPtr.extent_int(0), numAggs+1); + // TEST_EQUALITY(unaggregated.extent_int(0), 0); // 1 unaggregated node in the MPI_4 case + + // test aggPtr(i)+aggSizes(i)=aggPtr(i+1) + typename Aggregates_kokkos::LO_view::HostMirror aggPtr_h = Kokkos::create_mirror_view(aggPtr); + typename Aggregates_kokkos::aggregates_sizes_type::HostMirror aggSizes_h = Kokkos::create_mirror_view(aggSizes); + Kokkos::deep_copy(aggPtr_h, aggPtr); + Kokkos::deep_copy(aggSizes_h, aggSizes); + for(LO i=0; i SearchId; -typedef stk::search::Sphere Sphere; -typedef std::vector< std::pair > SphereIdVector; -typedef std::vector > SearchPairVector; -typedef std::vector > SearchPairSet; #endif namespace panzer_stk { @@ -69,6 +62,15 @@ namespace panzer_stk { */ namespace periodic_helpers { +#ifdef PANZER_HAVE_STKSEARCH + // Copied from PeriodicBoundarySearch + typedef stk::search::IdentProc SearchId; + typedef stk::search::Sphere Sphere; + typedef std::vector< std::pair > SphereIdVector; + typedef std::vector > SearchPairVector; + typedef std::vector > SearchPairSet; +#endif + /** Construct the vector pair (similar to getLocallyMatchedPair) * usign specified side sets, mesh object, and matcher object. This * is primarily a utility function. diff --git a/packages/panzer/adapters-stk/test/periodic_bcs/periodic_search.cpp b/packages/panzer/adapters-stk/test/periodic_bcs/periodic_search.cpp index 64002f04756b..8010a1d8ef70 100644 --- a/packages/panzer/adapters-stk/test/periodic_bcs/periodic_search.cpp +++ b/packages/panzer/adapters-stk/test/periodic_bcs/periodic_search.cpp @@ -117,7 +117,7 @@ namespace panzer { auto myrank = mesh->getBulkData()->parallel_rank(); panzer_stk::CoordMatcher x_matcher(0); - SphereIdVector coordsIds; + panzer_stk::periodic_helpers::SphereIdVector coordsIds; auto error = x_matcher.getAbsoluteTolerance(); panzer_stk::periodic_helpers::fillLocalSearchVector(*mesh,coordsIds,error,"top","coord"); @@ -150,7 +150,7 @@ namespace panzer { auto myrank = mesh->getBulkData()->parallel_rank(); panzer_stk::CoordMatcher x_matcher(0); - SphereIdVector coordsIds; + panzer_stk::periodic_helpers::SphereIdVector coordsIds; auto error = x_matcher.getAbsoluteTolerance(); panzer_stk::periodic_helpers::fillLocalSearchVector(*mesh,coordsIds,error,"top","edge"); @@ -185,7 +185,7 @@ namespace panzer { auto myrank = mesh->getBulkData()->parallel_rank(); panzer_stk::CoordMatcher x_matcher(0); - SphereIdVector coordsIds; + panzer_stk::periodic_helpers::SphereIdVector coordsIds; auto error = x_matcher.getAbsoluteTolerance(); panzer_stk::periodic_helpers::fillLocalSearchVector(*mesh,coordsIds,error,"top","face"); @@ -250,8 +250,8 @@ namespace panzer { TEST_EQUALITY(mesh->getBulkData()->parallel_size(),2); panzer_stk::CoordMatcher x_matcher(0),y_matcher(1),z_matcher(2); - SphereIdVector topCoordsIds,leftCoordsIds,frontCoordsIds; - SphereIdVector uniqueLeftCoordsIds,uniqueFrontCoordsIds; + panzer_stk::periodic_helpers::SphereIdVector topCoordsIds,leftCoordsIds,frontCoordsIds; + panzer_stk::periodic_helpers::SphereIdVector uniqueLeftCoordsIds,uniqueFrontCoordsIds; auto error = x_matcher.getAbsoluteTolerance(); // first get all the ids on each face @@ -261,7 +261,7 @@ namespace panzer { // now only get ids if they have not already been found std::vector > matchedSides(3); - std::vector doubleRequestsL, doubleRequestsF; + std::vector doubleRequestsL, doubleRequestsF; matchedSides[0].push_back("top"); panzer_stk::periodic_helpers::fillLocalSearchVector(*mesh,uniqueLeftCoordsIds,error,"left","coord",false,matchedSides[0],doubleRequestsL); matchedSides[0].push_back("left"); @@ -430,7 +430,7 @@ namespace panzer { panzer_stk::CoordMatcher x_matcher(0); panzer_stk::CoordMatcher y_matcher(1); - SphereIdVector bottom, left; + panzer_stk::periodic_helpers::SphereIdVector bottom, left; // create lines of points to be shifted @@ -438,7 +438,7 @@ namespace panzer { stk::mesh::EntityId id(0); // doesnt matter stk::mesh::EntityKey key(stk::topology::NODE_RANK,id); // doesnt matter - SearchId search_id(key,0); // doesnt matter + panzer_stk::periodic_helpers::SearchId search_id(key,0); // doesnt matter for (size_t n=0; n yCenter(0,n,0); stk::search::Point xCenter(n,0,0); @@ -474,7 +474,7 @@ namespace panzer { panzer_stk::PlaneMatcher yz_matcher(1,2); panzer_stk::PlaneMatcher xz_matcher(0,2); - SphereIdVector xy, yz, xz; + panzer_stk::periodic_helpers::SphereIdVector xy, yz, xz; // create planes of points to be shifted @@ -482,7 +482,7 @@ namespace panzer { stk::mesh::EntityId id(0); // doesnt matter stk::mesh::EntityKey key(stk::topology::NODE_RANK,id); // doesnt matter - SearchId search_id(key,0); // doesnt matter + panzer_stk::periodic_helpers::SearchId search_id(key,0); // doesnt matter for (size_t i=0; i xyCenter(i,j,0); @@ -546,7 +546,7 @@ namespace panzer { panzer_stk::QuarterPlaneMatcher xzY_matcher(0,2,1); panzer_stk::QuarterPlaneMatcher yxZ_matcher(1,0,2); - SphereIdVector yz,zx,zy,xz; + panzer_stk::periodic_helpers::SphereIdVector yz,zx,zy,xz; // create planes of points to be shifted (these are side B's) @@ -554,7 +554,7 @@ namespace panzer { stk::mesh::EntityId id(0); // doesnt matter stk::mesh::EntityKey key(stk::topology::NODE_RANK,id); // doesnt matter - SearchId search_id(key,0); // doesnt matter + panzer_stk::periodic_helpers::SearchId search_id(key,0); // doesnt matter for (size_t i=0; i yzCenter(0,i,j); @@ -633,7 +633,7 @@ namespace panzer { panzer_stk::WedgeMatcher YZ_matcher(panzer_stk::WedgeMatcher::MirrorPlane::YZ_PLANE,params); panzer_stk::WedgeMatcher XZ_matcher(panzer_stk::WedgeMatcher::MirrorPlane::XZ_PLANE,params); - SphereIdVector YZ_sideB, XZ_sideB; + panzer_stk::periodic_helpers::SphereIdVector YZ_sideB, XZ_sideB; // create planes of points to be shifted (these are side B's) @@ -641,7 +641,7 @@ namespace panzer { stk::mesh::EntityId id(0); // doesnt matter stk::mesh::EntityKey key(stk::topology::NODE_RANK,id); // doesnt matter - SearchId search_id(key,0); // doesnt matter + panzer_stk::periodic_helpers::SearchId search_id(key,0); // doesnt matter // we will create planes with corners (0,0,0) (1,1,0) (1,1,1) (0,0,1) for (size_t i=0; i > > globallyMatchedIds_edge = panzer_stk::periodic_helpers::matchPeriodicSidesSearch("left","right",*mesh,matcher,"edge"); - SphereIdVector leftCoordsIds_edge,rightCoordsIds_edge; + panzer_stk::periodic_helpers::SphereIdVector leftCoordsIds_edge,rightCoordsIds_edge; auto error = matcher.getAbsoluteTolerance(); panzer_stk::periodic_helpers::fillLocalSearchVector(*mesh,leftCoordsIds_edge,error,"left","edge"); panzer_stk::periodic_helpers::fillLocalSearchVector(*mesh,rightCoordsIds_edge,error,"right","edge"); diff --git a/packages/phalanx/src/Phalanx_KokkosViewOfViews.hpp b/packages/phalanx/src/Phalanx_KokkosViewOfViews.hpp index c4f2c52d27bd..c8d9206599ed 100644 --- a/packages/phalanx/src/Phalanx_KokkosViewOfViews.hpp +++ b/packages/phalanx/src/Phalanx_KokkosViewOfViews.hpp @@ -323,12 +323,17 @@ namespace PHX { check_use_count_(true) {} + // Making this a kokkos function eliminates cuda compiler warnings + // in objects that contain ViewOfViews3 that are copied to device. + KOKKOS_INLINE_FUNCTION ~ViewOfViews3() { // Make sure there is not another object pointing to device view // since the host view will delete the inner views on exit. - if ( check_use_count_ && (view_device_.impl_track().use_count() != use_count_) ) - Kokkos::abort("\n ERROR - PHX::ViewOfViews - please free all instances of device ViewOfView \n before deleting the host ViewOfView!\n\n"); + KOKKOS_IF_ON_HOST(( + if ( check_use_count_ && (view_device_.impl_track().use_count() != use_count_) ) + Kokkos::abort("\n ERROR - PHX::ViewOfViews - please free all instances of device ViewOfView \n before deleting the host ViewOfView!\n\n"); + )) } /// Enable safety check in dtor for external references. diff --git a/packages/phalanx/test/Kokkos/CMakeLists.txt b/packages/phalanx/test/Kokkos/CMakeLists.txt index 998ca7181de5..ddf102ec3d20 100644 --- a/packages/phalanx/test/Kokkos/CMakeLists.txt +++ b/packages/phalanx/test/Kokkos/CMakeLists.txt @@ -41,3 +41,10 @@ TRIBITS_ADD_EXECUTABLE_AND_TEST( TESTONLYLIBS phalanx_unit_test_main phalanx_test_utilities NUM_MPI_PROCS 1 ) + +TRIBITS_ADD_EXECUTABLE_AND_TEST( + tKokkosClassOnDevice + SOURCES tKokkosClassOnDevice.cpp + TESTONLYLIBS phalanx_unit_test_main phalanx_test_utilities + NUM_MPI_PROCS 1 + ) diff --git a/packages/phalanx/test/Kokkos/tKokkosClassOnDevice.cpp b/packages/phalanx/test/Kokkos/tKokkosClassOnDevice.cpp new file mode 100644 index 000000000000..80cd182e2aa3 --- /dev/null +++ b/packages/phalanx/test/Kokkos/tKokkosClassOnDevice.cpp @@ -0,0 +1,108 @@ +#include "Kokkos_Core.hpp" +#include "Kokkos_View.hpp" + +#include "Teuchos_Assert.hpp" +#include "Teuchos_UnitTestHarness.hpp" +#include "Phalanx_KokkosViewOfViews.hpp" + +template +class Vector { + T values_[dim]; +public: + + template + KOKKOS_INLINE_FUNCTION + T& operator[](const INDEX_I& index){return values_[index];} + + template + KOKKOS_INLINE_FUNCTION + const T& operator[](const INDEX_I& index)const{return values_[index];} + + template + KOKKOS_INLINE_FUNCTION + volatile T& operator[](const INDEX_I& index)volatile{return values_[index];} + + template + KOKKOS_INLINE_FUNCTION + const volatile T& operator[](const INDEX_I& index)const volatile{return values_[index];} +}; + +class MyClass { + Kokkos::View a_; + double b_[3]; + Kokkos::View c_; + Vector d_; + // To test for cuda warnings when MyClass is lambda captured to + // device + PHX::ViewOfViews3<1,Kokkos::View> e_; + +public: + MyClass() : + a_("a",3), + c_("c",3) + { + Kokkos::deep_copy(a_,1.0); + b_[0] = 1.0; + b_[1] = 2.0; + b_[2] = 3.0; + Kokkos::deep_copy(c_,2.0); + d_[0] = 1.0; + d_[1] = 2.0; + d_[2] = 3.0; + } + + void KOKKOS_FUNCTION checkInternalMethod1() const + { this->callInternalMethod1(); } + + void KOKKOS_FUNCTION + callInternalMethod1() const + { + printf("b_[0]=%f\n",b_[0]); + printf("b_[1]=%f\n",b_[1]); + printf("b_[2]=%f\n",b_[2]); + a_(0)=b_[0]; + a_(1)=b_[1]; + a_(2)=b_[2]; + } + + void KOKKOS_FUNCTION checkInternalMethod2() const + { this->callInternalMethod2(); } + + void KOKKOS_FUNCTION + callInternalMethod2() const + { + a_(0)=c_(0); + a_(1)=c_(1); + a_(2)=c_(2); + } + + void KOKKOS_FUNCTION checkInternalMethod3() const + { this->callInternalMethod3(); } + + void KOKKOS_FUNCTION + callInternalMethod3() const + { + a_(0)=d_[0]; + a_(1)=d_[1]; + a_(2)=d_[2]; + } +}; + +TEUCHOS_UNIT_TEST(KokkosClassOnDevice, One) +{ + MyClass my_class; + + Kokkos::parallel_for("test 1",1,KOKKOS_LAMBDA (const int ) { + my_class.checkInternalMethod1(); + }); + + Kokkos::parallel_for("test 2",1,KOKKOS_LAMBDA (const int ) { + my_class.checkInternalMethod2(); + }); + + Kokkos::parallel_for("test 3",1,KOKKOS_LAMBDA (const int ) { + my_class.checkInternalMethod3(); + }); + + Kokkos::fence(); +} diff --git a/packages/pliris/CMakeLists.txt b/packages/pliris/CMakeLists.txt index 738a4416edd3..c03d1d7390fe 100644 --- a/packages/pliris/CMakeLists.txt +++ b/packages/pliris/CMakeLists.txt @@ -20,11 +20,6 @@ TRIBITS_ADD_OPTION_AND_DEFINE(${PACKAGE_NAME}_ENABLE_SCPLX "Enable single precision complex functionality." OFF ) -TRIBITS_ADD_OPTION_AND_DEFINE(${PACKAGE_NAME}_ENABLE_DREAL - DREAL - "Set reals to double precision." - OFF ) - TRIBITS_ADD_OPTION_AND_DEFINE(${PACKAGE_NAME}_ENABLE_SREAL SREAL "Set reals to single precision." @@ -37,11 +32,16 @@ TRIBITS_ADD_OPTION_AND_DEFINE(${PACKAGE_NAME}_ENABLE_SREAL #If no options are specified set DREAL as the default IF(NOT ${PACKAGE_NAME}_ENABLE_ZCPLX AND NOT ${PACKAGE_NAME}_ENABLE_SCPLX AND - NOT ${PACKAGE_NAME}_ENABLE_SREAL AND - NOT ${PACKAGE_NAME}_ENABLE_DREAL) - - SET(${PACKAGE_NAME}_ENABLE_DREAL ON) + NOT ${PACKAGE_NAME}_ENABLE_SREAL + ) + SET(${PACKAGE_NAME}_ENABLE_DREAL_DEFAULT ON) +ELSE() + SET(${PACKAGE_NAME}_ENABLE_DREAL_DEFAULT OFF) ENDIF() +TRIBITS_ADD_OPTION_AND_DEFINE(${PACKAGE_NAME}_ENABLE_DREAL + DREAL + "Set reals to double precision." + ${${PACKAGE_NAME}_ENABLE_DREAL_DEFAULT} ) ADD_SUBDIRECTORY(src) diff --git a/packages/shylu/shylu_node/basker/src/shylubasker_def.hpp b/packages/shylu/shylu_node/basker/src/shylubasker_def.hpp index 10cd77c64f45..402a722c38fd 100644 --- a/packages/shylu/shylu_node/basker/src/shylubasker_def.hpp +++ b/packages/shylu/shylu_node/basker/src/shylubasker_def.hpp @@ -925,10 +925,11 @@ namespace BaskerNS } } - #ifdef AMD_ON_D + #define SHYLU_BASKER_AMD_ON_D + #ifdef SHYLU_BASKER_AMD_ON_D // -------------------------------------------- // reset the small D blocks - if (btf_top_tabs_offset > 0) { + if (btf_top_tabs_offset > 0) { Int d_last = btf_top_tabs_offset; Int ncol = btf_tabs(d_last); @@ -942,13 +943,15 @@ namespace BaskerNS permute_row(BTF_E, order_blk_amd_d); } - // revert BLK_MWM ordering - auto order_blk_mwm_d = Kokkos::subview(order_blk_mwm_inv, - range_type (0, ncol)); - permute_row(BTF_D, order_blk_mwm_d); - if (BTF_E.ncol > 0) { - // Apply MWM perm to cols - permute_row(BTF_E, order_blk_mwm_d); + if (Options.blk_matching != 0) { + // revert BLK_MWM ordering + auto order_blk_mwm_d = Kokkos::subview(order_blk_mwm_inv, + range_type (0, ncol)); + permute_row(BTF_D, order_blk_mwm_d); + if (BTF_E.ncol > 0) { + // Apply MWM perm to cols + permute_row(BTF_E, order_blk_mwm_d); + } } } #endif @@ -1100,7 +1103,7 @@ namespace BaskerNS numeric_col_iperm_array(i) = i; } - #ifdef AMD_ON_D + #ifdef SHYLU_BASKER_AMD_ON_D if (btf_top_tabs_offset > 0) { Kokkos::Timer mwm_amd_perm_timer; Int d_last = btf_top_tabs_offset; @@ -1116,7 +1119,7 @@ namespace BaskerNS } // ---------------------------------------------------------------------------------------------- - // recompute MWM and AMD on each block of C + // recompute MWM and AMD on each block of D INT_1DARRAY blk_nnz; INT_1DARRAY blk_work; btf_blk_mwm_amd(0, d_last, BTF_D, diff --git a/packages/stk/stk_balance/stk_balance/balanceUtils.cpp b/packages/stk/stk_balance/stk_balance/balanceUtils.cpp index abb3f316e0ba..60a3ab5918c8 100644 --- a/packages/stk/stk_balance/stk_balance/balanceUtils.cpp +++ b/packages/stk/stk_balance/stk_balance/balanceUtils.cpp @@ -1,6 +1,7 @@ #include "balanceUtils.hpp" #include "mpi.h" #include "search_tolerance/FaceSearchTolerance.hpp" +#include "stk_balance/search_tolerance_algs/SecondShortestEdgeFaceSearchTolerance.hpp" #include "stk_mesh/base/Field.hpp" // for field_data #include "stk_mesh/base/FieldBase.hpp" // for field_data #include "stk_util/diag/StringUtil.hpp" @@ -20,6 +21,7 @@ BalanceSettings::BalanceSettings() : m_numInputProcessors(0), m_numOutputProcessors(0), m_isRebalancing(false), + m_shouldFixCoincidentElements(true), m_initialDecompMethod("RIB"), m_useNestedDecomp(false), m_shouldPrintDiagnostics(false), @@ -74,6 +76,18 @@ VertexWeightMethod BalanceSettings::getVertexWeightMethod() const return m_vertexWeightMethod; } +bool +BalanceSettings::shouldFixCoincidentElements() const +{ + return m_shouldFixCoincidentElements; +} + +void +BalanceSettings::setShouldFixCoincidentElements(bool fixCoincidentElements) +{ + m_shouldFixCoincidentElements = fixCoincidentElements; +} + bool BalanceSettings::includeSearchResultsInGraph() const { return false; @@ -345,6 +359,39 @@ std::string BalanceSettings::get_log_filename() const ////////////////////////////////////// +GraphCreationSettings::GraphCreationSettings() + : m_method(DefaultSettings::decompMethod), + m_ToleranceForFaceSearch(DefaultSettings::faceSearchAbsTol), + m_ToleranceForParticleSearch(DefaultSettings::particleSearchTol), + m_vertexWeightMultiplierForVertexInSearch(DefaultSettings::faceSearchVertexMultiplier), + m_edgeWeightForSearch(DefaultSettings::faceSearchEdgeWeight), + m_UseConstantToleranceForFaceSearch(false), + m_shouldFixSpiders(DefaultSettings::fixSpiders), + m_shouldFixMechanisms(DefaultSettings::fixMechanisms), + m_spiderBeamConnectivityCountField(nullptr), + m_spiderVolumeConnectivityCountField(nullptr), + m_outputSubdomainField(nullptr), + m_includeSearchResultInGraph(DefaultSettings::useContactSearch), + m_useNodeBalancer(false), + m_nodeBalancerTargetLoadBalance(1.0), + m_nodeBalancerMaxIterations(5) +{ + setToleranceFunctionForFaceSearch( + std::make_shared(DefaultSettings::faceSearchRelTol) + ); +} + +GraphCreationSettings::GraphCreationSettings(double faceSearchTol, double particleSearchTol, double edgeWeightSearch, + const std::string& decompMethod, double multiplierVWSearch) + : GraphCreationSettings() +{ + m_method = decompMethod; + m_ToleranceForFaceSearch = faceSearchTol; + m_ToleranceForParticleSearch = particleSearchTol; + m_vertexWeightMultiplierForVertexInSearch = multiplierVWSearch; + m_edgeWeightForSearch = edgeWeightSearch; +} + size_t GraphCreationSettings::getNumNodesRequiredForConnection(stk::topology element1Topology, stk::topology element2Topology) const { const int noConnection = 1000; @@ -367,7 +414,7 @@ size_t GraphCreationSettings::getNumNodesRequiredForConnection(stk::topology ele double GraphCreationSettings::getGraphEdgeWeightForSearch() const { - return edgeWeightForSearch; + return m_edgeWeightForSearch; } double GraphCreationSettings::getGraphEdgeWeight(stk::topology element1Topology, stk::topology element2Topology) const @@ -476,7 +523,7 @@ void GraphCreationSettings::setIncludeSearchResultsInGraph(bool doContactSearch) double GraphCreationSettings::getToleranceForParticleSearch() const { - return mToleranceForParticleSearch; + return m_ToleranceForParticleSearch; } void GraphCreationSettings::setToleranceFunctionForFaceSearch(std::shared_ptr faceSearchTolerance) @@ -496,7 +543,7 @@ double GraphCreationSettings::getToleranceForFaceSearch(const stk::mesh::BulkDat const unsigned numFaceNodes) const { if (m_UseConstantToleranceForFaceSearch) { - return mToleranceForFaceSearch; + return m_ToleranceForFaceSearch; } else { return m_faceSearchToleranceFunction->compute(mesh, coordField, faceNodes, numFaceNodes); @@ -510,35 +557,35 @@ bool GraphCreationSettings::getEdgesForParticlesUsingSearch() const double GraphCreationSettings::getVertexWeightMultiplierForVertexInSearch() const { - return vertexWeightMultiplierForVertexInSearch; + return m_vertexWeightMultiplierForVertexInSearch; } std::string GraphCreationSettings::getDecompMethod() const { - return method; + return m_method; } void GraphCreationSettings::setDecompMethod(const std::string& input_method) { - method = input_method; + m_method = input_method; } void GraphCreationSettings::setToleranceForFaceSearch(double tol) { m_UseConstantToleranceForFaceSearch = true; - mToleranceForFaceSearch = tol; + m_ToleranceForFaceSearch = tol; } void GraphCreationSettings::setToleranceForParticleSearch(double tol) { - mToleranceForParticleSearch = tol; + m_ToleranceForParticleSearch = tol; } void GraphCreationSettings::setEdgeWeightForSearch(double w) { - edgeWeightForSearch = w; + m_edgeWeightForSearch = w; } void GraphCreationSettings::setVertexWeightMultiplierForVertexInSearch(double w) { - vertexWeightMultiplierForVertexInSearch = w; + m_vertexWeightMultiplierForVertexInSearch = w; } int GraphCreationSettings::getConnectionTableIndex(stk::topology elementTopology) const { diff --git a/packages/stk/stk_balance/stk_balance/balanceUtils.hpp b/packages/stk/stk_balance/stk_balance/balanceUtils.hpp index 0ce8ff6f87d4..e67359b5e6f0 100644 --- a/packages/stk/stk_balance/stk_balance/balanceUtils.hpp +++ b/packages/stk/stk_balance/stk_balance/balanceUtils.hpp @@ -115,6 +115,9 @@ class BalanceSettings virtual void setVertexWeightMethod(VertexWeightMethod method); virtual VertexWeightMethod getVertexWeightMethod() const; + virtual bool shouldFixCoincidentElements() const; + virtual void setShouldFixCoincidentElements(bool fixCoincidentElements); + // Graph based options only virtual bool includeSearchResultsInGraph() const; virtual void setIncludeSearchResultsInGraph(bool doContactSearch); @@ -208,6 +211,7 @@ class BalanceSettings unsigned m_numInputProcessors; unsigned m_numOutputProcessors; bool m_isRebalancing; + bool m_shouldFixCoincidentElements; std::string m_initialDecompMethod; std::string m_inputFilename; std::string m_outputFilename; @@ -230,33 +234,9 @@ class BasicGeometricSettings : public BalanceSettings class GraphCreationSettings : public BalanceSettings { public: - GraphCreationSettings() - : mToleranceForFaceSearch(DefaultSettings::faceSearchAbsTol), - mToleranceForParticleSearch(DefaultSettings::particleSearchTol), - edgeWeightForSearch(DefaultSettings::faceSearchEdgeWeight), - method(DefaultSettings::decompMethod), - vertexWeightMultiplierForVertexInSearch(DefaultSettings::faceSearchVertexMultiplier), - m_UseConstantToleranceForFaceSearch(true), - m_shouldFixSpiders(DefaultSettings::fixSpiders), - m_shouldFixMechanisms(DefaultSettings::fixMechanisms), - m_spiderBeamConnectivityCountField(nullptr), - m_spiderVolumeConnectivityCountField(nullptr), - m_outputSubdomainField(nullptr), - m_includeSearchResultInGraph(DefaultSettings::useContactSearch), - m_useNodeBalancer(false), - m_nodeBalancerTargetLoadBalance(1.0), - m_nodeBalancerMaxIterations(5) - {} - - GraphCreationSettings(double faceSearchTol, double particleSearchTol, double edgeWeightSearch, const std::string& decompMethod, double multiplierVWSearch) - : GraphCreationSettings() - { - mToleranceForFaceSearch = faceSearchTol; - mToleranceForParticleSearch = particleSearchTol; - edgeWeightForSearch = edgeWeightSearch; - method = decompMethod; - vertexWeightMultiplierForVertexInSearch = multiplierVWSearch; - } + GraphCreationSettings(); + GraphCreationSettings(double faceSearchTol, double particleSearchTol, double edgeWeightSearch, + const std::string& decompMethod, double multiplierVWSearch); virtual ~GraphCreationSettings() = default; @@ -311,11 +291,12 @@ class GraphCreationSettings : public BalanceSettings protected: int getConnectionTableIndex(stk::topology elementTopology) const; int getEdgeWeightTableIndex(stk::topology elementTopology) const; - double mToleranceForFaceSearch; - double mToleranceForParticleSearch; - double edgeWeightForSearch; - std::string method; - double vertexWeightMultiplierForVertexInSearch; + + std::string m_method; + double m_ToleranceForFaceSearch; + double m_ToleranceForParticleSearch; + double m_vertexWeightMultiplierForVertexInSearch; + double m_edgeWeightForSearch; bool m_UseConstantToleranceForFaceSearch; bool m_shouldFixSpiders; bool m_shouldFixMechanisms; @@ -335,8 +316,8 @@ class GraphCreationSettingsWithCustomTolerances : public GraphCreationSettings GraphCreationSettingsWithCustomTolerances() : GraphCreationSettings() { - mToleranceForFaceSearch = 0.1; - mToleranceForParticleSearch = 1.0; + m_ToleranceForFaceSearch = 0.1; + m_ToleranceForParticleSearch = 1.0; } virtual bool getEdgesForParticlesUsingSearch() const { return true; } @@ -384,7 +365,7 @@ class FieldVertexWeightSettings : public GraphCreationSettings m_weightField(weightField), m_defaultWeight(defaultWeight) { - method = "parmetis"; + m_method = "parmetis"; m_includeSearchResultInGraph = false; } virtual ~FieldVertexWeightSettings() = default; @@ -393,8 +374,8 @@ class FieldVertexWeightSettings : public GraphCreationSettings virtual bool areVertexWeightsProvidedViaFields() const { return true; } virtual int getGraphVertexWeight(stk::topology type) const { return 1; } virtual double getImbalanceTolerance() const { return 1.05; } - virtual void setDecompMethod(const std::string& input_method) { method = input_method;} - virtual std::string getDecompMethod() const { return method; } + virtual void setDecompMethod(const std::string& input_method) { m_method = input_method;} + virtual std::string getDecompMethod() const { return m_method; } virtual double getGraphVertexWeight(stk::mesh::Entity entity, int criteria_index = 0) const { diff --git a/packages/stk/stk_balance/stk_balance/internal/Balancer.cpp b/packages/stk/stk_balance/stk_balance/internal/Balancer.cpp index ad756b03d2ff..9614c78bbd74 100644 --- a/packages/stk/stk_balance/stk_balance/internal/Balancer.cpp +++ b/packages/stk/stk_balance/stk_balance/internal/Balancer.cpp @@ -77,8 +77,10 @@ bool loadBalance(const BalanceSettings& balanceSettings, stk::mesh::BulkData& st DecompositionChangeList changeList(stkMeshBulkData, decomp); balanceSettings.modifyDecomposition(changeList); - internal::logMessage(stkMeshBulkData.parallel(), "Moving coincident elements to the same processor"); - keep_coincident_elements_together(stkMeshBulkData, changeList); + if (balanceSettings.shouldFixCoincidentElements()) { + internal::logMessage(stkMeshBulkData.parallel(), "Moving coincident elements to the same processor"); + keep_coincident_elements_together(stkMeshBulkData, changeList); + } if (balanceSettings.shouldFixSpiders()) { internal::logMessage(stkMeshBulkData.parallel(), "Fixing spider elements"); @@ -88,19 +90,18 @@ bool loadBalance(const BalanceSettings& balanceSettings, stk::mesh::BulkData& st const size_t num_global_entity_migrations = changeList.get_num_global_entity_migrations(); const size_t max_global_entity_migrations = changeList.get_max_global_entity_migrations(); - if (num_global_entity_migrations > 0) - { + if (num_global_entity_migrations > 0) { internal::logMessage(stkMeshBulkData.parallel(), "Moving elements to new processors"); internal::rebalance(changeList); - if (balanceSettings.shouldFixMechanisms()) - { + if (balanceSettings.shouldFixMechanisms()) { internal::logMessage(stkMeshBulkData.parallel(), "Fixing mechanisms found during decomposition"); stk::balance::internal::detectAndFixMechanisms(balanceSettings, stkMeshBulkData); } - if (balanceSettings.shouldPrintMetrics()) + if (balanceSettings.shouldPrintMetrics()) { internal::print_rebalance_metrics(num_global_entity_migrations, max_global_entity_migrations, stkMeshBulkData); + } } internal::compute_balance_diagnostics(stkMeshBulkData, balanceSettings); diff --git a/packages/stk/stk_balance/stk_balance/internal/SubdomainWriter.cpp b/packages/stk/stk_balance/stk_balance/internal/SubdomainWriter.cpp index 858d3b604ce1..ff50d05dcf4e 100644 --- a/packages/stk/stk_balance/stk_balance/internal/SubdomainWriter.cpp +++ b/packages/stk/stk_balance/stk_balance/internal/SubdomainWriter.cpp @@ -62,8 +62,8 @@ SubdomainWriter::setup_output_file(const std::string& fileName, unsigned subdoma { Ioss::DatabaseIO *dbo = stk::io::create_database_for_subdomain(fileName, subdomain, numSubdomains); m_outRegion = new Ioss::Region(dbo, fileName); - - stk::io::add_properties_for_subdomain(*m_bulk, *m_outRegion, subdomain, numSubdomains, globalNumNodes, globalNumElems); + stk::io::OutputParams params(*m_outRegion, *m_bulk); + stk::io::add_properties_for_subdomain(params, subdomain, numSubdomains, globalNumNodes, globalNumElems); int dbIntSize = m_inputBroker.check_integer_size_requirements_serial(); if (dbIntSize > 4) { @@ -114,7 +114,8 @@ SubdomainWriter::write_mesh() { add_qa_records(); add_info_records(); - stk::io::write_file_for_subdomain(*m_outRegion, *m_bulk, m_nodeSharingInfo); + stk::io::OutputParams params(*m_outRegion, *m_bulk); + stk::io::write_file_for_subdomain(params, m_nodeSharingInfo); add_global_variables(); } @@ -143,7 +144,8 @@ SubdomainWriter::write_global_variables(int step) void SubdomainWriter::write_transient_data(double time) { - const int step = stk::io::write_transient_data_for_subdomain(*m_outRegion, *m_bulk, time); + stk::io::OutputParams params(*m_outRegion, *m_bulk); + const int step = stk::io::write_transient_data_for_subdomain(params, time); write_global_variables(step); } diff --git a/packages/stk/stk_balance/stk_balance/internal/privateDeclarations.cpp b/packages/stk/stk_balance/stk_balance/internal/privateDeclarations.cpp index ce37c90f902c..13ffaca54e70 100644 --- a/packages/stk/stk_balance/stk_balance/internal/privateDeclarations.cpp +++ b/packages/stk/stk_balance/stk_balance/internal/privateDeclarations.cpp @@ -1511,6 +1511,89 @@ void compute_relative_node_interface_size_diagnostic(RelativeNodeInterfaceSizeDi : 0.0); } +double getTypicalElemsPerNode(stk::topology type) +{ + switch(type) + { + case stk::topology::PARTICLE: + return 1; + case stk::topology::LINE_2_1D: + return 1; + case stk::topology::LINE_3_1D: + return 1.0/2.0; + case stk::topology::BEAM_2: + return 1; + case stk::topology::BEAM_3: + return 1.0/2.0; + case stk::topology::SHELL_LINE_2: + return 1; + case stk::topology::SHELL_LINE_3: + return 1.0/2.0; + case stk::topology::SPRING_2: + return 1; + case stk::topology::SPRING_3: + return 1.0/2.0; + case stk::topology::TRI_3_2D: + return 2; + case stk::topology::TRI_4_2D: + return 2.0/3.0; + case stk::topology::TRI_6_2D: + return 2.0/4.0; + case stk::topology::SHELL_TRI_3: + return 2; + case stk::topology::SHELL_TRI_4: + return 2.0/3.0; + case stk::topology::SHELL_TRI_6: + return 2.0/4.0; + case stk::topology::QUAD_4_2D: + return 1; + case stk::topology::QUAD_8_2D: + return 1.0/3.0; + case stk::topology::QUAD_9_2D: + return 1.0/4.0; + case stk::topology::SHELL_QUAD_4: + return 1; + case stk::topology::SHELL_QUAD_8: + return 1.0/3.0; + case stk::topology::SHELL_QUAD_9: + return 1.0/4.0; + case stk::topology::TET_4: + return 6; + case stk::topology::TET_8: + return 6.0/13.0; + case stk::topology::TET_10: + return 6.0/8.0; + case stk::topology::TET_11: + return 6.0/14.0; + case stk::topology::PYRAMID_5: + return 6.0/2.0; + case stk::topology::PYRAMID_13: + return 6.0/13.0; + case stk::topology::PYRAMID_14: + return 6.0/19.0; + case stk::topology::WEDGE_6: + return 2; + case stk::topology::WEDGE_12: + return 2.0/4.0; + case stk::topology::WEDGE_15: + return 2.0/5.0; + case stk::topology::WEDGE_18: + return 2.0/8.0; + case stk::topology::HEXAHEDRON_8: + return 1; + case stk::topology::HEXAHEDRON_20: + return 1.0/4.0; + case stk::topology::HEXAHEDRON_27: + return 1.0/8.0; + default: + if ( type.is_superelement( )) + { + return 1.0/100.0; + } + throw("Invalid Element Type In WeightsOfElement"); + } +} + double get_connected_node_weight(const stk::mesh::BulkData & bulk, std::vector & connectedNodesBuffer, const stk::mesh::Entity node) { @@ -1544,7 +1627,9 @@ void spread_weight_across_connected_elements(const stk::mesh::BulkData & bulk, c const stk::mesh::Entity element = elements[elemIndex]; if (bulk.bucket(element).owned()) { double * elemWeight = stk::mesh::field_data(elementWeights, element); - *elemWeight += nodeWeight / numElements; + const unsigned numNodes = bulk.num_nodes(element); + const double typicalElemsPerNode = getTypicalElemsPerNode(bulk.bucket(element).topology()); + *elemWeight += nodeWeight / (numNodes * typicalElemsPerNode); } } } diff --git a/packages/stk/stk_balance/stk_balance/m2n/M2NSubdomainWriter.cpp b/packages/stk/stk_balance/stk_balance/m2n/M2NSubdomainWriter.cpp index 0524e908e674..9985dcd901fc 100644 --- a/packages/stk/stk_balance/stk_balance/m2n/M2NSubdomainWriter.cpp +++ b/packages/stk/stk_balance/stk_balance/m2n/M2NSubdomainWriter.cpp @@ -63,8 +63,8 @@ SubdomainWriter::setup_output_file(const std::string& fileName, unsigned subdoma { Ioss::DatabaseIO *dbo = stk::io::create_database_for_subdomain(fileName, subdomain, numSubdomains); m_outRegion = new Ioss::Region(dbo, fileName); - - stk::io::add_properties_for_subdomain(*m_bulk, *m_outRegion, subdomain, numSubdomains, globalNumNodes, globalNumElems); + stk::io::OutputParams params(*m_outRegion, *m_bulk); + stk::io::add_properties_for_subdomain(params, subdomain, numSubdomains, globalNumNodes, globalNumElems); int dbIntSize = m_inputBroker.check_integer_size_requirements_serial(); if (dbIntSize > 4) { @@ -115,7 +115,8 @@ SubdomainWriter::write_mesh() { add_qa_records(); add_info_records(); - stk::io::write_file_for_subdomain(*m_outRegion, *m_bulk, m_nodeSharingInfo); + stk::io::OutputParams params(*m_outRegion, *m_bulk); + stk::io::write_file_for_subdomain(params, m_nodeSharingInfo); add_global_variables(); } @@ -144,7 +145,8 @@ SubdomainWriter::write_global_variables(int step) void SubdomainWriter::write_transient_data(double time) { - const int step = stk::io::write_transient_data_for_subdomain(*m_outRegion, *m_bulk, time); + stk::io::OutputParams params(*m_outRegion, *m_bulk); + const int step = stk::io::write_transient_data_for_subdomain(params, time); write_global_variables(step); } diff --git a/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.cpp b/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.cpp index 32138ef01a65..1df78175be38 100644 --- a/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.cpp +++ b/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.cpp @@ -63,10 +63,10 @@ std::string vertex_weight_method_name(VertexWeightMethod method) { // this entire file can be deleted. constexpr const char * DefaultSettings::logFile; +constexpr const char * DefaultSettings::outputDirectory; constexpr const char * DefaultSettings::decompMethod; constexpr bool DefaultSettings::useContactSearch; -constexpr bool DefaultSettings::fixSpiders; constexpr bool DefaultSettings::fixMechanisms; constexpr double DefaultSettings::faceSearchRelTol; @@ -74,13 +74,24 @@ constexpr double DefaultSettings::faceSearchAbsTol; constexpr double DefaultSettings::particleSearchTol; +constexpr VertexWeightMethod DefaultSettings::vertexWeightMethod; +constexpr double DefaultSettings::graphEdgeWeightMultiplier; constexpr double DefaultSettings::faceSearchVertexMultiplier; constexpr double DefaultSettings::faceSearchEdgeWeight; +constexpr bool DefaultSettings::fixSpiders; +constexpr VertexWeightMethod DefaultSettings::sdVertexWeightMethod; +constexpr double DefaultSettings::sdGraphEdgeWeightMultiplier; +constexpr double DefaultSettings::sdFaceSearchVertexMultiplier; +constexpr double DefaultSettings::sdFaceSearchEdgeWeight; +constexpr bool DefaultSettings::sdFixSpiders; + +constexpr VertexWeightMethod DefaultSettings::smVertexWeightMethod; +constexpr double DefaultSettings::smGraphEdgeWeightMultiplier; constexpr double DefaultSettings::smFaceSearchVertexMultiplier; constexpr double DefaultSettings::smFaceSearchEdgeWeight; +constexpr bool DefaultSettings::smFixSpiders; constexpr const char * DefaultSettings::vertexWeightBlockMultiplier; - } } diff --git a/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.hpp b/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.hpp index 60cade9bb1c3..166aff0399ae 100644 --- a/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.hpp +++ b/packages/stk/stk_balance/stk_balance/setup/DefaultSettings.hpp @@ -54,25 +54,33 @@ struct DefaultSettings { static constexpr const char * decompMethod {"parmetis"}; static constexpr bool useContactSearch {true}; - static constexpr bool fixSpiders {false}; - static constexpr bool fixMechanisms {true}; + static constexpr bool fixMechanisms {false}; static constexpr double faceSearchRelTol {0.15}; static constexpr double faceSearchAbsTol {0.0001}; static constexpr double particleSearchTol {3.0}; - static constexpr double faceSearchVertexMultiplier {5.0}; - static constexpr double faceSearchEdgeWeight {15.0}; + static constexpr VertexWeightMethod vertexWeightMethod {VertexWeightMethod::CONSTANT}; + static constexpr double graphEdgeWeightMultiplier {1.0}; + static constexpr double faceSearchVertexMultiplier {1.0}; + static constexpr double faceSearchEdgeWeight {1.0}; + static constexpr bool fixSpiders {false}; - static constexpr double smFaceSearchVertexMultiplier {10.0}; - static constexpr double smFaceSearchEdgeWeight {3.0}; + static constexpr VertexWeightMethod sdVertexWeightMethod {VertexWeightMethod::CONNECTIVITY}; + static constexpr double sdGraphEdgeWeightMultiplier {10.0}; + static constexpr double sdFaceSearchVertexMultiplier {2.0}; + static constexpr double sdFaceSearchEdgeWeight {1.0}; + static constexpr bool sdFixSpiders {true}; - static constexpr double graphEdgeWeightMultiplier {1.0}; + static constexpr VertexWeightMethod smVertexWeightMethod {VertexWeightMethod::CONSTANT}; + static constexpr double smGraphEdgeWeightMultiplier {1.0}; + static constexpr double smFaceSearchVertexMultiplier {3.0}; + static constexpr double smFaceSearchEdgeWeight {1.0}; + static constexpr bool smFixSpiders {false}; static constexpr const char * vertexWeightBlockMultiplier {""}; - static constexpr VertexWeightMethod vertexWeightMethod {VertexWeightMethod::TOPOLOGY}; }; } } diff --git a/packages/stk/stk_balance/stk_balance/setup/Parser.cpp b/packages/stk/stk_balance/stk_balance/setup/Parser.cpp index c83e5cbc8feb..5b33cb57496a 100644 --- a/packages/stk/stk_balance/stk_balance/setup/Parser.cpp +++ b/packages/stk/stk_balance/stk_balance/setup/Parser.cpp @@ -84,14 +84,11 @@ std::string Examples::get_long_examples() examples += tab + "To decompose for 512 processors and put the decomposition into a directory named 'temp1':\n"; examples += tab + tab + "> mpirun -n 512 " + m_execName + " mesh.g temp1\n"; examples += "\n"; - examples += tab + "To decompose for 16 processors and use the default relative contact search tolerance:\n"; - examples += tab + tab + "> mpirun -n 16 " + m_execName + " mesh.g " + stk::dash_it(m_optionNames.faceSearchRelTol) + "\n"; - examples += "\n"; examples += tab + "To decompose for 16 processors and use a relative contact search tolerance of 0.05:\n"; examples += tab + tab + "> mpirun -n 16 " + m_execName + " mesh.g " + stk::dash_it(m_optionNames.faceSearchRelTol) + "=0.05\n"; examples += "\n"; - examples += tab + "To decompose for 16 processors with the RCB decomposition method:\n"; - examples += tab + tab + "> mpirun -n 16 " + m_execName + " mesh.g " + stk::dash_it(m_optionNames.decompMethod) + "=rcb\n"; + examples += tab + "To decompose for 16 processors with the RIB decomposition method:\n"; + examples += tab + tab + "> mpirun -n 16 " + m_execName + " mesh.g " + stk::dash_it(m_optionNames.decompMethod) + "=rib\n"; examples += "\n"; examples += tab + "To rebalance a 16 processor mesh into 64 processors:\n"; examples += tab + tab + "> mpirun -n 16 " + m_execName + " mesh.g " + stk::dash_it(m_optionNames.rebalanceTo) + "=64\n"; @@ -156,20 +153,24 @@ void Parser::add_options_to_parser() smStream << "Use settings suitable for solving Solid Mechanics problems. " << "This flag implies:" << std::endl << " " << stk::dash_it(m_optionNames.faceSearchRelTol) << "=" << DefaultSettings::faceSearchRelTol << std::endl - << " " << stk::dash_it(m_optionNames.fixSpiders) << "=" << ((DefaultSettings::fixSpiders) ? "on" : "off") << std::endl + << " " << stk::dash_it(m_optionNames.fixSpiders) << "=" << ((DefaultSettings::smFixSpiders) ? "on" : "off") << std::endl << " " << stk::dash_it(m_optionNames.fixMechanisms) << "=" << ((DefaultSettings::fixMechanisms) ? "on" : "off") << std::endl - << " Face search graph vertex weight multiplier = " << DefaultSettings::smFaceSearchVertexMultiplier << std::endl - << " Face search graph edge weight = " << DefaultSettings::smFaceSearchEdgeWeight; + << " " << stk::dash_it(m_optionNames.vertexWeightMethod) << "=" << vertex_weight_method_name(DefaultSettings::smVertexWeightMethod) << std::endl + << " " << stk::dash_it(m_optionNames.edgeWeightMultiplier) << "=" << DefaultSettings::smGraphEdgeWeightMultiplier << std::endl + << " " << stk::dash_it(m_optionNames.contactSearchVertexWeightMultiplier) << "=" << DefaultSettings::smFaceSearchVertexMultiplier << std::endl + << " " << stk::dash_it(m_optionNames.contactSearchEdgeWeight) << "=" << DefaultSettings::smFaceSearchEdgeWeight << std::endl; stk::CommandLineOption smDefaults{m_optionNames.smDefaults, "", smStream.str()}; std::ostringstream sdStream; sdStream << "Use settings suitable for solving Structural Dynamics problems. " << "This flag implies:" << std::endl - << " " << stk::dash_it(m_optionNames.faceSearchAbsTol) << "=" << DefaultSettings::faceSearchAbsTol << std::endl - << " " << stk::dash_it(m_optionNames.fixSpiders) << "=on" << std::endl + << " " << stk::dash_it(m_optionNames.faceSearchRelTol) << "=" << DefaultSettings::faceSearchRelTol << std::endl + << " " << stk::dash_it(m_optionNames.fixSpiders) << "=" << ((DefaultSettings::sdFixSpiders) ? "on" : "off") << std::endl << " " << stk::dash_it(m_optionNames.fixMechanisms) << "=" << ((DefaultSettings::fixMechanisms) ? "on" : "off") << std::endl - << " Face search graph vertex weight multiplier = " << DefaultSettings::faceSearchVertexMultiplier << std::endl - << " Face search graph edge weight = " << DefaultSettings::faceSearchEdgeWeight; + << " " << stk::dash_it(m_optionNames.vertexWeightMethod) << "=" << vertex_weight_method_name(DefaultSettings::sdVertexWeightMethod) << std::endl + << " " << stk::dash_it(m_optionNames.edgeWeightMultiplier) << "=" << DefaultSettings::sdGraphEdgeWeightMultiplier << std::endl + << " " << stk::dash_it(m_optionNames.contactSearchVertexWeightMultiplier) << "=" << DefaultSettings::sdFaceSearchVertexMultiplier << std::endl + << " " << stk::dash_it(m_optionNames.contactSearchEdgeWeight) << "=" << DefaultSettings::sdFaceSearchEdgeWeight << std::endl; stk::CommandLineOption sdDefaults{m_optionNames.sdDefaults, "", sdStream.str()}; stk::CommandLineOption faceSearchAbsTol{m_optionNames.faceSearchAbsTol, "", @@ -177,7 +178,8 @@ void Parser::add_options_to_parser() "Optionally provide a numeric tolerance value."}; stk::CommandLineOption faceSearchRelTol{m_optionNames.faceSearchRelTol, "", "Use a tolerance relative to the face size for face contact search. " - "Optionally provide a numeric tolerance value."}; + "Optionally provide a numeric tolerance value. This is the global " + "default. Values less than 0.5 are recommended."}; stk::CommandLineOption contactSearch{m_optionNames.contactSearch, "", "Use proximity search for contact [on|off]"}; stk::CommandLineOption fixSpiders{m_optionNames.fixSpiders, "", @@ -202,16 +204,18 @@ void Parser::add_options_to_parser() "of processors must be an integer multiple of the input processors."}; stk::CommandLineOption vertexWeightMethod{m_optionNames.vertexWeightMethod, "", - "(Experimental) Method used to calculate vertex weights given to the partitioner. " + "Method used to calculate vertex weights given to the partitioner. " "[constant|topology|connectivity]"}; stk::CommandLineOption contactSearchEdgeWeight{m_optionNames.contactSearchEdgeWeight, "", - "(Experimental) Graph edge weight to use between elements that are determined to be " + "Graph edge weight to use between elements that are determined to be " "in contact."}; stk::CommandLineOption contactSearchVertexWeightMultiplier{m_optionNames.contactSearchVertexWeightMultiplier, "", - "(Experimental) Scale factor to be applied to graph vertex weights for elements that " + "Scale factor to be applied to graph vertex weights for elements that " "are determined to be in contact."}; stk::CommandLineOption edgeWeightMultiplier{m_optionNames.edgeWeightMultiplier, "", - "(Experimental) Scale factor to be applied to all graph edge weights."}; + "Scale factor to be applied to all graph edge weights. This will be " + "automatically set to 1.0 for constant vertex weights, 1.0 for topology " + "vertex weights, and 10.0 for connectivity vertex weights."}; m_commandLineParser.add_required_positional(infile); @@ -231,9 +235,9 @@ void Parser::add_options_to_parser() m_commandLineParser.add_flag(useNested); m_commandLineParser.add_optional(vertexWeightMethod, vertex_weight_method_name(DefaultSettings::vertexWeightMethod)); - m_commandLineParser.add_optional(contactSearchEdgeWeight); - m_commandLineParser.add_optional(contactSearchVertexWeightMultiplier); - m_commandLineParser.add_optional(edgeWeightMultiplier); + m_commandLineParser.add_optional(contactSearchEdgeWeight, DefaultSettings::faceSearchEdgeWeight); + m_commandLineParser.add_optional(contactSearchVertexWeightMultiplier, DefaultSettings::faceSearchVertexMultiplier); + m_commandLineParser.add_optional(edgeWeightMultiplier, DefaultSettings::graphEdgeWeightMultiplier); m_commandLineParser.disallow_unrecognized(); } @@ -298,15 +302,19 @@ void Parser::set_app_type_defaults(BalanceSettings& settings) const ThrowRequireMsg( !(useSM && useSD), "Can't set default settings for multiple apps at the same time"); if (useSM) { - settings.setEdgeWeightForSearch(DefaultSettings::smFaceSearchEdgeWeight); + settings.setVertexWeightMethod(DefaultSettings::smVertexWeightMethod); + settings.setGraphEdgeWeightMultiplier(DefaultSettings::smGraphEdgeWeightMultiplier); settings.setVertexWeightMultiplierForVertexInSearch(DefaultSettings::smFaceSearchVertexMultiplier); - settings.setToleranceFunctionForFaceSearch( - std::make_shared(DefaultSettings::faceSearchRelTol) - ); + settings.setEdgeWeightForSearch(DefaultSettings::smFaceSearchEdgeWeight); + settings.setShouldFixSpiders(DefaultSettings::smFixSpiders); } if (useSD) { - settings.setShouldFixSpiders(true); + settings.setVertexWeightMethod(DefaultSettings::sdVertexWeightMethod); + settings.setGraphEdgeWeightMultiplier(DefaultSettings::sdGraphEdgeWeightMultiplier); + settings.setVertexWeightMultiplierForVertexInSearch(DefaultSettings::sdFaceSearchVertexMultiplier); + settings.setEdgeWeightForSearch(DefaultSettings::sdFaceSearchEdgeWeight); + settings.setShouldFixSpiders(DefaultSettings::sdFixSpiders); } } @@ -424,12 +432,15 @@ void Parser::set_vertex_weight_method(BalanceSettings &settings) const // FIXME: case-insensitive comparison? Need this for decomp method too? if (vertexWeightMethodName == vertex_weight_method_name(VertexWeightMethod::CONSTANT)) { settings.setVertexWeightMethod(VertexWeightMethod::CONSTANT); + settings.setGraphEdgeWeightMultiplier(1.0); } else if (vertexWeightMethodName == vertex_weight_method_name(VertexWeightMethod::TOPOLOGY)) { settings.setVertexWeightMethod(VertexWeightMethod::TOPOLOGY); + settings.setGraphEdgeWeightMultiplier(1.0); } else if (vertexWeightMethodName == vertex_weight_method_name(VertexWeightMethod::CONNECTIVITY)) { settings.setVertexWeightMethod(VertexWeightMethod::CONNECTIVITY); + settings.setGraphEdgeWeightMultiplier(10.0); } else { ThrowErrorMsg("Unrecognized vertex weight method: " << vertexWeightMethodName); diff --git a/packages/stk/stk_balance/stk_balance/setup/Parser.hpp b/packages/stk/stk_balance/stk_balance/setup/Parser.hpp index 1b816b36f1f5..799125753df5 100644 --- a/packages/stk/stk_balance/stk_balance/setup/Parser.hpp +++ b/packages/stk/stk_balance/stk_balance/setup/Parser.hpp @@ -64,10 +64,10 @@ struct OptionNames const std::string vertexWeightBlockMultiplier = "block-weights"; const std::string useNestedDecomp = "use-nested-decomp"; - const std::string vertexWeightMethod = "EXP-vertex-weight-method"; - const std::string contactSearchEdgeWeight = "EXP-contact-search-edge-weight"; - const std::string contactSearchVertexWeightMultiplier = "EXP-contact-search-vertex-weight-mult"; - const std::string edgeWeightMultiplier = "EXP-edge-weight-mult"; + const std::string vertexWeightMethod = "vertex-weight-method"; + const std::string contactSearchEdgeWeight = "contact-search-edge-weight"; + const std::string contactSearchVertexWeightMultiplier = "contact-search-vertex-weight-mult"; + const std::string edgeWeightMultiplier = "edge-weight-mult"; }; class Examples diff --git a/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.cpp b/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.cpp index 96741d47b90f..5a62835f4ad6 100644 --- a/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.cpp +++ b/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.cpp @@ -22,6 +22,7 @@ #include #include +#ifndef STK_HIDE_DEPRECATED_CODE // delete October 2022 namespace stk { namespace coupling @@ -111,3 +112,5 @@ std::pair calc_my_root_and_other_root_ranks(MPI_Comm global, MPI_Comm } } + +#endif \ No newline at end of file diff --git a/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.hpp b/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.hpp index 2c3dda186d6c..25a408983a8f 100644 --- a/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.hpp +++ b/packages/stk/stk_coupling/stk_coupling/OldCommSplitting.hpp @@ -15,19 +15,25 @@ #include #include #include +#include "stk_util/stk_config.h" +#ifndef STK_HIDE_DEPRECATED_CODE // delete October 2022 namespace stk { namespace coupling { +STK_DEPRECATED std::pair calc_my_root_and_other_root_ranks(MPI_Comm global, MPI_Comm local); +STK_DEPRECATED_MSG("prefer stk::couping::are_comms_unequal") bool has_split_comm(MPI_Comm global, MPI_Comm local); +STK_DEPRECATED MPI_Comm split_comm(MPI_Comm parentCommunicator, int color); } } +#endif #endif /* STK_COUPLING_OLD_COMM_SPLITTING_HPP */ diff --git a/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.cpp b/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.cpp index 81e1a1c31352..0a110e61ab04 100644 --- a/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.cpp +++ b/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.cpp @@ -11,6 +11,8 @@ #include #include +#ifndef STK_HIDE_DEPRECATED_CODE // remove October 2022 + namespace stk { namespace coupling @@ -97,3 +99,4 @@ OldSyncInfo::exchange(stk::ParallelMachine global, stk::ParallelMachine local) } } +#endif \ No newline at end of file diff --git a/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.hpp b/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.hpp index 207dc9f528c0..cc361c16b84b 100644 --- a/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.hpp +++ b/packages/stk/stk_coupling/stk_coupling/OldSyncInfo.hpp @@ -14,13 +14,16 @@ #include #include // for CommBuffer #include +#include +#ifndef STK_HIDE_DEPRECATED_CODE namespace stk { namespace coupling { -class OldSyncInfo + +class STK_DEPRECATED OldSyncInfo { public: OldSyncInfo() @@ -137,4 +140,6 @@ inline bool OldSyncInfo::has_value(const std::string & parameterNam } // namespace coupling } // namespace stk +#endif /* STK_HIDE_DEPRECATED_CODE */ + #endif /* STK_COUPLING_OLDSYNCINFO_HPP */ diff --git a/packages/stk/stk_coupling/stk_coupling/SplitComms.cpp b/packages/stk/stk_coupling/stk_coupling/SplitComms.cpp index b7ee2301dde3..0fd53de528b5 100644 --- a/packages/stk/stk_coupling/stk_coupling/SplitComms.cpp +++ b/packages/stk/stk_coupling/stk_coupling/SplitComms.cpp @@ -239,6 +239,15 @@ void SplitCommsImpl::free_comms_impl() m_haveFreedComms = true; } +} + +bool are_comms_unequal(MPI_Comm global, MPI_Comm local) +{ + int result = 0; + MPI_Comm_compare(global, local, &result); + return result == MPI_UNEQUAL; +} + + } } -} \ No newline at end of file diff --git a/packages/stk/stk_coupling/stk_coupling/SplitComms.hpp b/packages/stk/stk_coupling/stk_coupling/SplitComms.hpp index 24fb42cc2965..a01f3c337c02 100644 --- a/packages/stk/stk_coupling/stk_coupling/SplitComms.hpp +++ b/packages/stk/stk_coupling/stk_coupling/SplitComms.hpp @@ -129,6 +129,8 @@ class SplitComms std::shared_ptr m_impl; }; +bool are_comms_unequal(MPI_Comm global, MPI_Comm local); + } } diff --git a/packages/stk/stk_expreval/stk_expreval/Parser.cpp b/packages/stk/stk_expreval/stk_expreval/Parser.cpp index 2ec9843e4372..21c181c6455c 100644 --- a/packages/stk/stk_expreval/stk_expreval/Parser.cpp +++ b/packages/stk/stk_expreval/stk_expreval/Parser.cpp @@ -404,6 +404,21 @@ parseFactor(Eval & eval, return factor; } +bool isRelation(Node* node) +{ + switch (node->m_opcode) { + case OPCODE_EQUAL: + case OPCODE_NOT_EQUAL: + case OPCODE_LESS: + case OPCODE_GREATER: + case OPCODE_LESS_EQUAL: + case OPCODE_GREATER_EQUAL: + return true; + default: + return false; + } +} + Node * parseRelation(Eval & eval, LexemVector::const_iterator from, @@ -445,6 +460,10 @@ parseRelation(Eval & eval, relation->m_left = parseExpression(eval, from, relation_it); relation->m_right = parseExpression(eval, relation_it + 1, to); + if (isRelation(relation->m_left) || isRelation(relation->m_right)) { + throw std::runtime_error("stk::expreval::parseRelation: stk_expreval does not support chained comparisons"); + } + return relation; } diff --git a/packages/stk/stk_expreval/unit_tests/UnitTestEvaluator.cpp b/packages/stk/stk_expreval/unit_tests/UnitTestEvaluator.cpp index e390d73589b7..ce3186277d64 100644 --- a/packages/stk/stk_expreval/unit_tests/UnitTestEvaluator.cpp +++ b/packages/stk/stk_expreval/unit_tests/UnitTestEvaluator.cpp @@ -947,6 +947,26 @@ TEST(UnitTestEvaluator, Ngp_testOpcode_GREATER_EQUAL) EXPECT_DOUBLE_EQ(device_evaluate("2>=(1+2)"), 0); } +TEST(UnitTestEvaluator, noChainedComparisons) +{ + EXPECT_ANY_THROW(evaluate("1 < 2 < 3")); + EXPECT_ANY_THROW(evaluate("3 > 4 > 5")); + EXPECT_ANY_THROW(evaluate("0 < 4 <= 2")); + EXPECT_ANY_THROW(evaluate("6 > 3 >= 1")); + EXPECT_ANY_THROW(evaluate("1 <= 2 < 3")); + EXPECT_ANY_THROW(evaluate("3 >= 4 > 5")); + EXPECT_ANY_THROW(evaluate("1 < x < 3", {{"x", 2}})); + EXPECT_ANY_THROW(evaluate("1 < (2 < 3)")); + EXPECT_ANY_THROW(evaluate("(1 < 2) < 3")); + EXPECT_ANY_THROW(evaluate("(3 > 1) > 0")); + EXPECT_ANY_THROW(evaluate("(2 <= 5) < 0")); + EXPECT_ANY_THROW(evaluate("(7 >= 3) > 1")); + EXPECT_ANY_THROW(evaluate("1 == 1 == 1")); + EXPECT_ANY_THROW(evaluate("(2 == 2) == 2")); + EXPECT_ANY_THROW(evaluate("2 != 1 != 6")); + EXPECT_ANY_THROW(evaluate("(3 != 4) != 8")); +} + TEST(UnitTestEvaluator, testOpcode_UNARY_NOT) { EXPECT_DOUBLE_EQ(evaluate("!0"), 1); diff --git a/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk b/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk index a1c03e945cec..80c2c1e26ff9 100755 --- a/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk +++ b/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk @@ -56,7 +56,7 @@ cmake \ -DTrilinos_ENABLE_CXX11=ON \ -DCMAKE_BUILD_TYPE=${build_type^^} \ -DTrilinos_ENABLE_EXPLICIT_INSTANTIATION:BOOL=ON \ --DTrilinos_ENABLE_TESTS:BOOL=OFF \ +-DTrilinos_ENABLE_TESTS:BOOL=ON \ -DTrilinos_ENABLE_ALL_OPTIONAL_PACKAGES=OFF \ -DTrilinos_ALLOW_NO_PACKAGES:BOOL=OFF \ -DTrilinos_ASSERT_MISSING_PACKAGES=OFF \ diff --git a/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk_no_stk_mesh b/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk_no_stk_mesh index 654835782c45..816469e7c410 100755 --- a/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk_no_stk_mesh +++ b/packages/stk/stk_integration_tests/cmake_install_test/run_cmake_stk_no_stk_mesh @@ -31,10 +31,12 @@ cmake \ -DSTK_ENABLE_TESTS:BOOL=ON \ -DTrilinos_ENABLE_STK:BOOL=ON \ -DTrilinos_ENABLE_STKMesh:BOOL=OFF \ +-DTrilinos_ENABLE_STKUtil:BOOL=ON \ +-DTrilinos_ENABLE_STKMath:BOOL=ON \ +-DTrilinos_ENABLE_STKSimd:BOOL=ON \ -DTrilinos_ENABLE_STKCoupling:BOOL=ON \ -DTrilinos_ENABLE_STKTransfer:BOOL=ON \ -DTrilinos_ENABLE_STKSearch:BOOL=ON \ --DTrilinos_ENABLE_STKUtil:BOOL=ON \ -DTrilinos_ENABLE_STKUnit_tests:BOOL=ON \ -DTrilinos_ENABLE_STKDoc_tests:BOOL=ON \ -DTrilinos_ENABLE_Gtest:BOOL=ON \ diff --git a/packages/stk/stk_integration_tests/mock_apps/mock_aria.cpp b/packages/stk/stk_integration_tests/mock_apps/mock_aria.cpp index 04694138e10b..33e6ddf354d5 100644 --- a/packages/stk/stk_integration_tests/mock_apps/mock_aria.cpp +++ b/packages/stk/stk_integration_tests/mock_apps/mock_aria.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "MockUtils.hpp" #include "StkMesh.hpp" @@ -67,6 +68,9 @@ class MockAria int defaultColor = stk::coupling::string_to_color(m_appName); int color = stk::get_command_line_option(argc, argv, "app-color", defaultColor); + int coupling_version_override = stk::get_command_line_option(argc, argv, "stk_coupling_version", STK_MAX_COUPLING_VERSION); + stk::util::impl::set_coupling_version(coupling_version_override); + stk::util::impl::set_error_on_reset(false); std::string defaultSyncMode = "Send"; std::string syncModeString = stk::get_command_line_option(argc, argv, "sync-mode", defaultSyncMode); m_syncMode = stk::coupling::string_to_sync_mode(syncModeString); diff --git a/packages/stk/stk_integration_tests/mock_apps/mock_fuego.cpp b/packages/stk/stk_integration_tests/mock_apps/mock_fuego.cpp index 56e3d6d84f32..7882ee661d3a 100644 --- a/packages/stk/stk_integration_tests/mock_apps/mock_fuego.cpp +++ b/packages/stk/stk_integration_tests/mock_apps/mock_fuego.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "MockUtils.hpp" #include "StkMesh.hpp" #include @@ -53,6 +54,9 @@ class MockFuego } int defaultColor = stk::coupling::string_to_color(m_appName); int color = stk::get_command_line_option(argc, argv, "app-color", defaultColor); + int coupling_version_override = stk::get_command_line_option(argc, argv, "stk_coupling_version", STK_MAX_COUPLING_VERSION); + stk::util::impl::set_coupling_version(coupling_version_override); + stk::util::impl::set_error_on_reset(false); m_splitComms = stk::coupling::SplitComms(commWorld, color); MPI_Comm splitComm = m_splitComms.get_split_comm(); diff --git a/packages/stk/stk_integration_tests/mock_apps/mock_salinas.cpp b/packages/stk/stk_integration_tests/mock_apps/mock_salinas.cpp index 82547f10d744..9af0a8fc2e1f 100644 --- a/packages/stk/stk_integration_tests/mock_apps/mock_salinas.cpp +++ b/packages/stk/stk_integration_tests/mock_apps/mock_salinas.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "MockUtils.hpp" #include "StkMesh.hpp" #include "StkRecvAdapter.hpp" @@ -59,6 +60,10 @@ class MockSalinas int defaultColor = stk::coupling::string_to_color(m_appName); int color = stk::get_command_line_option(argc, argv, "app-color", defaultColor); + int coupling_version_override = stk::get_command_line_option(argc, argv, "stk_coupling_version", STK_MAX_COUPLING_VERSION); + stk::util::impl::set_coupling_version(coupling_version_override); + stk::util::impl::set_error_on_reset(false); + m_splitComms = stk::coupling::SplitComms(commWorld, color); const std::vector& otherColors = m_splitComms.get_other_colors(); if (otherColors.size() != 1) { diff --git a/packages/stk/stk_integration_tests/mock_apps/mock_sparc.cpp b/packages/stk/stk_integration_tests/mock_apps/mock_sparc.cpp index b30738e9f2ab..31dd1e851be3 100644 --- a/packages/stk/stk_integration_tests/mock_apps/mock_sparc.cpp +++ b/packages/stk/stk_integration_tests/mock_apps/mock_sparc.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "MockUtils.hpp" #include "SparcMesh.hpp" @@ -54,6 +55,9 @@ class MockSparc int defaultColor = stk::coupling::string_to_color(m_appName); int color = stk::get_command_line_option(argc, argv, "app-color", defaultColor); + int coupling_version_override = stk::get_command_line_option(argc, argv, "stk_coupling_version", STK_MAX_COUPLING_VERSION); + stk::util::impl::set_coupling_version(coupling_version_override); + stk::util::impl::set_error_on_reset(false); m_splitComms = stk::coupling::SplitComms(commWorld, color); MPI_Comm splitComm = m_splitComms.get_split_comm(); int myAppRank = stk::parallel_machine_rank(splitComm); diff --git a/packages/stk/stk_integration_tests/stk_balance/IntegrationTestIncrementalRebalance.cpp b/packages/stk/stk_integration_tests/stk_balance/IntegrationTestIncrementalRebalance.cpp index effdd96d18fe..3bd670ff81bc 100644 --- a/packages/stk/stk_integration_tests/stk_balance/IntegrationTestIncrementalRebalance.cpp +++ b/packages/stk/stk_integration_tests/stk_balance/IntegrationTestIncrementalRebalance.cpp @@ -41,7 +41,7 @@ class FieldVertexWeightSettingsWithSearchForParticles : public stk::balance::Gra m_defaultWeight(defaultWeight), m_incrementalRebalance(incrementalRebalance) { - method = "parmetis"; + m_method = "parmetis"; } virtual ~FieldVertexWeightSettingsWithSearchForParticles() = default; @@ -56,8 +56,8 @@ class FieldVertexWeightSettingsWithSearchForParticles : public stk::balance::Gra virtual double getToleranceForFaceSearch() const { return 0.005; } virtual int getGraphVertexWeight(stk::topology type) const { return 1; } virtual double getImbalanceTolerance() const { return 1.05; } - virtual void setDecompMethod(const std::string& input_method) { method = input_method;} - virtual std::string getDecompMethod() const { return method; } + virtual void setDecompMethod(const std::string& input_method) { m_method = input_method;} + virtual std::string getDecompMethod() const { return m_method; } virtual bool incrementalRebalance() const { return m_incrementalRebalance; } virtual double getGraphVertexWeight(stk::mesh::Entity entity, int criteria_index = 0) const diff --git a/packages/stk/stk_integration_tests/stk_balance/IntegrationTestUserSupport.cpp b/packages/stk/stk_integration_tests/stk_balance/IntegrationTestUserSupport.cpp index 077edcacd97c..1112843d8090 100644 --- a/packages/stk/stk_integration_tests/stk_balance/IntegrationTestUserSupport.cpp +++ b/packages/stk/stk_integration_tests/stk_balance/IntegrationTestUserSupport.cpp @@ -508,45 +508,4 @@ TEST(Stkbalance, changeOptions) delete balanceOptions; } -class ToleranceTester : public stk::unit_test_util::simple_fields::MeshFixture -{ -public: - ToleranceTester() - : balanceRunner(get_comm()), - meshFile("gapped_plates.g") - { - balanceRunner.set_filename(meshFile); - balanceRunner.set_output_dir("."); - balanceRunner.set_app_type_defaults("sm"); - } - -protected: - stk::integration_test_utils::StkBalanceRunner balanceRunner; - const std::string meshFile; -}; - -TEST_F(ToleranceTester, smDefaults) -{ - if (get_parallel_size() > 4) return; - - if (get_parallel_size() > 1) - { - balanceRunner.run_end_to_end(); - } - - setup_mesh(meshFile, stk::mesh::BulkData::NO_AUTO_AURA); - for(unsigned i=1; i<101; i++) - { - stk::mesh::EntityId lowerId = i; - stk::mesh::EntityId upperId = i+700; - stk::mesh::Entity lower = get_bulk().get_entity(stk::topology::ELEM_RANK, lowerId); - stk::mesh::Entity upper = get_bulk().get_entity(stk::topology::ELEM_RANK, upperId); - if(get_bulk().is_valid(lower)) - { - EXPECT_TRUE(get_bulk().is_valid(upper)) << "Elements not on same proc: " << lowerId << ", " << upperId; - } - } -} - - } diff --git a/packages/stk/stk_io/stk_io/IossBridge.cpp b/packages/stk/stk_io/stk_io/IossBridge.cpp index eda99307a09a..5552e41b1aaa 100644 --- a/packages/stk/stk_io/stk_io/IossBridge.cpp +++ b/packages/stk/stk_io/stk_io/IossBridge.cpp @@ -43,6 +43,7 @@ #include // for operator<<, basic... #include // for allocator_traits<... #include // for runtime_error +#include #include // for BulkData #include // for comm_mesh_counts #include // for Cartesian, FullTe... @@ -54,6 +55,7 @@ #include // for PartVector, Entit... #include // for make_lower, to_st... #include // for RuntimeWarning +#include // for all_reduce_sum #include // for sort_and_unique #include // for tokenize #include // for type_info @@ -107,7 +109,7 @@ namespace stk { namespace mesh { class Bucket; } } namespace stk { namespace io { bool is_field_on_part(const stk::mesh::FieldBase *field, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, const stk::mesh::Part &part); stk::mesh::EntityRank get_entity_rank(const Ioss::GroupingEntity *entity, @@ -200,53 +202,53 @@ namespace { const stk::mesh::FieldBase *declare_stk_field(stk::mesh::MetaData &meta, stk::mesh::EntityRank type, stk::mesh::Part &part, - const Ioss::Field &io_field, - bool use_cartesian_for_scalar) + const Ioss::Field &ioField, + bool useCartesianForScalar) { - Ioss::Field::BasicType ioFieldType = io_field.get_type(); + Ioss::Field::BasicType ioFieldType = ioField.get_type(); const bool ioFieldTypeIsRecognized = (ioFieldType == Ioss::Field::INTEGER) || (ioFieldType == Ioss::Field::INT64) || (ioFieldType == Ioss::Field::REAL) || (ioFieldType == Ioss::Field::COMPLEX); - ThrowRequireMsg(ioFieldTypeIsRecognized, "Unrecognized field type for IO field '"< void internal_field_data_from_ioss(const stk::mesh::BulkData& mesh, - const Ioss::Field &io_field, + const Ioss::Field &ioField, const stk::mesh::FieldBase *field, std::vector &entities, - Ioss::GroupingEntity *io_entity) + Ioss::GroupingEntity *ioEntity) { - size_t iossNumFieldComponents = io_field.transformed_storage()->component_count(); + size_t iossNumFieldComponents = ioField.transformed_storage()->component_count(); - std::vector io_field_data; - size_t io_entity_count = io_entity->get_field_data(io_field.get_name(), io_field_data); - assert(io_field_data.size() == entities.size() * iossNumFieldComponents); + std::vector ioFieldData; + size_t ioEntityCount = ioEntity->get_field_data(ioField.get_name(), ioFieldData); + assert(ioFieldData.size() == entities.size() * iossNumFieldComponents); - size_t entity_count = entities.size(); + size_t entityCount = entities.size(); - if (io_entity_count != entity_count) { + if (ioEntityCount != entityCount) { std::ostringstream errmsg; errmsg << "ERROR: Field count mismatch for IO field '" - << io_field.get_name() - << "' on " << io_entity->type_string() << " " << io_entity->name() - << ". The IO system has " << io_entity_count - << " entries, but the stk:mesh system has " << entity_count + << ioField.get_name() + << "' on " << ioEntity->type_string() << " " << ioEntity->name() + << ". The IO system has " << ioEntityCount + << " entries, but the stk:mesh system has " << entityCount << " entries. The two counts must match."; throw std::runtime_error(errmsg.str()); } field->sync_to_host(); field->modify_on_host(); - for (size_t i=0; i < entity_count; ++i) { + for (size_t i=0; i < entityCount; ++i) { if (mesh.is_valid(entities[i])) { - T *fld_data = static_cast(stk::mesh::field_data(*field, entities[i])); - if (fld_data !=nullptr) { + T *fldData = static_cast(stk::mesh::field_data(*field, entities[i])); + if (fldData !=nullptr) { const size_t stkNumFieldComponents = stk::mesh::field_scalars_per_entity(*field, entities[i]); const size_t len = std::min(stkNumFieldComponents, iossNumFieldComponents); for(size_t j=0; j void internal_subsetted_field_data_from_ioss(const stk::mesh::BulkData& mesh, - const Ioss::Field &io_field, + const Ioss::Field &ioField, const stk::mesh::FieldBase *field, std::vector &entities, - Ioss::GroupingEntity *io_entity, - const stk::mesh::Part *stk_part) + Ioss::GroupingEntity *ioEntity, + const stk::mesh::Part *stkPart) { - size_t field_component_count = io_field.transformed_storage()->component_count(); - std::vector io_field_data; - size_t io_entity_count = io_entity->get_field_data(io_field.get_name(), io_field_data); - assert(io_field_data.size() == entities.size() * field_component_count); - size_t entity_count = entities.size(); - if (io_entity_count != entity_count) { + size_t field_componentCount = ioField.transformed_storage()->component_count(); + std::vector ioFieldData; + size_t ioEntityCount = ioEntity->get_field_data(ioField.get_name(), ioFieldData); + assert(ioFieldData.size() == entities.size() * field_componentCount); + size_t entityCount = entities.size(); + if (ioEntityCount != entityCount) { std::ostringstream errmsg; errmsg << "ERROR: Field count mismatch for IO field '" - << io_field.get_name() - << "' on " << io_entity->type_string() << " " << io_entity->name() - << ". The IO system has " << io_entity_count - << " entries, but the stk:mesh system has " << entity_count + << ioField.get_name() + << "' on " << ioEntity->type_string() << " " << ioEntity->name() + << ". The IO system has " << ioEntityCount + << " entries, but the stk:mesh system has " << entityCount << " entries. The two counts must match."; throw std::runtime_error(errmsg.str()); } - stk::mesh::MetaData &meta = stk::mesh::MetaData::get(*stk_part); - stk::mesh::Selector selector = (meta.globally_shared_part() | meta.locally_owned_part()) & *stk_part; + stk::mesh::MetaData &meta = stk::mesh::MetaData::get(*stkPart); + stk::mesh::Selector selector = (meta.globally_shared_part() | meta.locally_owned_part()) & *stkPart; field->sync_to_host(); field->modify_on_host(); - for (size_t i=0; i < entity_count; ++i) { + for (size_t i=0; i < entityCount; ++i) { if (mesh.is_valid(entities[i])) { const stk::mesh::Bucket &bucket = mesh.bucket(entities[i]); if (selector(bucket)) { - T *fld_data = static_cast(stk::mesh::field_data(*field, entities[i])); - if (fld_data !=nullptr) { - for(size_t j=0; j(stk::mesh::field_data(*field, entities[i])); + if (fldData !=nullptr) { + for(size_t j=0; j void internal_field_data_to_ioss(const stk::mesh::BulkData& mesh, - const Ioss::Field &io_field, + const Ioss::Field &ioField, const stk::mesh::FieldBase *field, std::vector &entities, - Ioss::GroupingEntity *io_entity) + Ioss::GroupingEntity *ioEntity) { - size_t iossFieldLength = io_field.transformed_storage()->component_count(); - size_t entity_count = entities.size(); + size_t iossFieldLength = ioField.transformed_storage()->component_count(); + size_t entityCount = entities.size(); - std::vector io_field_data(entity_count*iossFieldLength); + std::vector ioFieldData(entityCount*iossFieldLength); field->sync_to_host(); - for (size_t i=0; i < entity_count; ++i) { + for (size_t i=0; i < entityCount; ++i) { if (mesh.is_valid(entities[i]) && mesh.entity_rank(entities[i]) == field->entity_rank()) { - const T *fld_data = static_cast(stk::mesh::field_data(*field, entities[i])); - if (fld_data != nullptr) { + const T *fldData = static_cast(stk::mesh::field_data(*field, entities[i])); + if (fldData != nullptr) { size_t stkFieldLength = stk::mesh::field_scalars_per_entity(*field, entities[i]); - ThrowRequireMsg((iossFieldLength >= stkFieldLength), "Field "<name()<<" scalars-per-entity="<name()); + ThrowRequireMsg((iossFieldLength >= stkFieldLength), "Field "<name()<<" scalars-per-entity="<name()); size_t length = std::min(iossFieldLength, stkFieldLength); for(size_t j=0; jput_field_data(io_field.get_name(), io_field_data); - assert(io_field_data.size() == entities.size() * iossFieldLength); + size_t ioEntityCount = ioEntity->put_field_data(ioField.get_name(), ioFieldData); + assert(ioFieldData.size() == entities.size() * iossFieldLength); - if (io_entity_count != entity_count) { + if (ioEntityCount != entityCount) { std::ostringstream errmsg; errmsg << "ERROR: Field count mismatch for IO field '" - << io_field.get_name() - << "' on " << io_entity->type_string() << " " << io_entity->name() - << ". The IO system has " << io_entity_count - << " entries, but the stk:mesh system has " << entity_count + << ioField.get_name() + << "' on " << ioEntity->type_string() << " " << ioEntity->name() + << ". The IO system has " << ioEntityCount + << " entries, but the stk:mesh system has " << entityCount << " entries. The two counts must match."; throw std::runtime_error(errmsg.str()); } @@ -349,10 +351,9 @@ namespace { while (I != fields.end()) { const stk::mesh::FieldBase *f = *I ; ++I ; - bool valid_part_field = stk::io::is_valid_part_field(f, rank, part, Ioss::Field::TRANSIENT); - bool valid_part_field_by_bucket = false; //stk::io::is_valid_part_field_by_bucket(f, rank, part, Ioss::Field::TRANSIENT); + bool validPartField = stk::io::is_valid_part_field(f, rank, part, Ioss::Field::TRANSIENT); - if (valid_part_field || valid_part_field_by_bucket) { + if (validPartField) { return true; } } @@ -362,9 +363,9 @@ namespace { void add_canonical_name_property(Ioss::GroupingEntity* ge, stk::mesh::Part& part) { if(stk::io::has_alternate_part_name(part)) { - std::string canon_name = stk::io::get_alternate_part_name(part); - if(canon_name != ge->name()) { - ge->property_add(Ioss::Property("db_name", canon_name)); + std::string canonName = stk::io::get_alternate_part_name(part); + if(canonName != ge->name()) { + ge->property_add(Ioss::Property("db_name", canonName)); } } } @@ -374,9 +375,9 @@ namespace { std::string topoString("original_topology_type"); if(stk::io::has_original_topology_type(part)) { - std::string orig_topo = stk::io::get_original_topology_type(part); - if(!ge->property_exists(topoString) || (orig_topo != ge->get_property(topoString).get_string())) { - ge->property_add(Ioss::Property(topoString, orig_topo)); + std::string origTopology = stk::io::get_original_topology_type(part); + if(!ge->property_exists(topoString) || (origTopology != ge->get_property(topoString).get_string())) { + ge->property_add(Ioss::Property(topoString, origTopology)); } } } @@ -389,12 +390,12 @@ namespace { const std::vector& additionalFields = params.get_additional_attribute_fields(); - Ioss::Region & io_region = params.io_region(); + Ioss::Region & ioRegion = params.io_region(); stk::mesh::MetaData & meta = stk::mesh::MetaData::get(part); - Ioss::ElementBlock* io_block = io_region.get_element_block(stk::io::getPartName(part)); + Ioss::ElementBlock* ioBlock = ioRegion.get_element_block(stk::io::getPartName(part)); for(const stk::io::FieldAndName& attribute : additionalFields) { - if(attribute.apply_to_entity(io_block)) { + if(attribute.apply_to_entity(ioBlock)) { const stk::mesh::FieldBase *stkField = attribute.field(); ThrowRequireMsg(stkField->entity_rank() == rank, "Input attribute field: " + stkField->name() + " is not ELEM_RANK"); @@ -402,19 +403,18 @@ namespace { relevantParts.push_back(&part); stk::io::superset_mesh_parts(part, relevantParts); relevantParts.push_back(&meta.universal_part()); -// relevantParts.push_back(®ion->mesh_meta_data().active_part()); if(stkField->defined_on_any(relevantParts)) { const std::string dbName = attribute.db_name(); - if(!io_block->field_exists(dbName)) { - int eb_size = io_block->get_property("entity_count").get_int(); + if(!ioBlock->field_exists(dbName)) { + int ebSize = ioBlock->get_property("entity_count").get_int(); const stk::mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*stkField, rank, relevantParts); ThrowRequireMsg(res.num_scalars_per_entity() != 0, "Could not find a restriction for field: " + stkField->name() + " on part: " + part.name()); stk::io::FieldType fieldType; stk::io::get_io_field_type(stkField, res, &fieldType); - io_block->field_add(Ioss::Field(dbName, fieldType.type, fieldType.name, Ioss::Field::ATTRIBUTE, eb_size)); + ioBlock->field_add(Ioss::Field(dbName, fieldType.type, fieldType.name, Ioss::Field::ATTRIBUTE, ebSize)); } } } @@ -429,8 +429,8 @@ namespace { const std::vector& additionalFields = params.get_additional_attribute_fields(); - Ioss::Region & io_region = params.io_region(); - Ioss::ElementBlock* ioBlock = io_region.get_element_block(stk::io::getPartName(part)); + Ioss::Region & ioRegion = params.io_region(); + Ioss::ElementBlock* ioBlock = ioRegion.get_element_block(stk::io::getPartName(part)); for(const stk::io::FieldAndName& attribute : additionalFields) { if(attribute.apply_to_entity(ioBlock)) { @@ -447,13 +447,13 @@ namespace { } } - bool contain(const stk::mesh::BulkData& stkmesh, stk::mesh::Entity elem, const stk::mesh::Part* parent_block) + bool contain(const stk::mesh::BulkData& stkmesh, stk::mesh::Entity elem, const stk::mesh::Part* parentBlock) { const stk::mesh::PartVector& parts = stkmesh.bucket(elem).supersets(); - unsigned int part_id = parent_block->mesh_meta_data_ordinal(); + unsigned int partId = parentBlock->mesh_meta_data_ordinal(); auto i = parts.begin(); - for(; i != parts.end() && (*i)->mesh_meta_data_ordinal() != part_id; ++i) + for(; i != parts.end() && (*i)->mesh_meta_data_ordinal() != partId; ++i) ; return (i != parts.end()); @@ -504,13 +504,13 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta bool use_cartesian_for_scalar) { std::string name = io_field.get_name(); - stk::mesh::FieldBase *field_ptr = meta.get_field(type, name); + stk::mesh::FieldBase *fieldPtr = meta.get_field(type, name); // If the field has already been declared, don't redeclare it. - if (field_ptr != nullptr && stk::io::is_field_on_part(field_ptr, type, part)) { - return field_ptr; + if (fieldPtr != nullptr && stk::io::is_field_on_part(fieldPtr, type, part)) { + return fieldPtr; } - stk::topology::rank_t entity_rank = static_cast(type); + stk::topology::rank_t entityRank = static_cast(type); if (meta.is_using_simple_fields()) { const Ioss::VariableType* varType = io_field.transformed_storage(); @@ -526,7 +526,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } std::string field_type = varType->name(); - stk::mesh::Field & field = meta.declare_field(entity_rank, name); + stk::mesh::Field & field = meta.declare_field(entityRank, name); stk::mesh::put_field_on_mesh(field, part, numComponents, numCopies, nullptr); const int oldVarTypeSize = has_field_output_type(field) ? get_field_output_type(field)->component_count() : 0; @@ -536,57 +536,57 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta set_field_output_type(field, varType); } - field_ptr = &field; + fieldPtr = &field; } else { const Ioss::VariableType* varType = io_field.transformed_storage(); - size_t num_components = varType->component_count(); + size_t numComponents = varType->component_count(); const Ioss::CompositeVariableType* compVarType = dynamic_cast(varType); if (compVarType != nullptr) { varType = compVarType->GetBaseType(); } - std::string field_type = varType->name(); + std::string fieldType = varType->name(); - if (field_type == "scalar" || num_components == 1) { + if (fieldType == "scalar" || numComponents == 1) { if (!use_cartesian_for_scalar) { - stk::mesh::Field & field = meta.declare_field>(entity_rank, name); + stk::mesh::Field & field = meta.declare_field>(entityRank, name); stk::mesh::put_field_on_mesh(field, part, nullptr); - field_ptr = &field; + fieldPtr = &field; } else { stk::mesh::Field & field = - meta.declare_field>(entity_rank, name); + meta.declare_field>(entityRank, name); stk::mesh::put_field_on_mesh(field, part, 1, nullptr); - field_ptr = &field; + fieldPtr = &field; } } - else if (stk::string_starts_with(sierra::make_lower(field_type), "real[")) { - stk::mesh::Field & field = meta.declare_field>(entity_rank, name); - stk::mesh::put_field_on_mesh(field, part, num_components, nullptr); - field_ptr = &field; + else if (stk::string_starts_with(sierra::make_lower(fieldType), "real[")) { + stk::mesh::Field & field = meta.declare_field>(entityRank, name); + stk::mesh::put_field_on_mesh(field, part, numComponents, nullptr); + fieldPtr = &field; } - else if ((field_type == "vector_2d") || (field_type == "vector_3d")) { - field_ptr = add_stk_field(meta, name, entity_rank, part, num_components); + else if ((fieldType == "vector_2d") || (fieldType == "vector_3d")) { + fieldPtr = add_stk_field(meta, name, entityRank, part, numComponents); } - else if (field_type == "sym_tensor_33") { - field_ptr = add_stk_field(meta, name, entity_rank, part, num_components); + else if (fieldType == "sym_tensor_33") { + fieldPtr = add_stk_field(meta, name, entityRank, part, numComponents); } - else if (field_type == "full_tensor_36") { - field_ptr = add_stk_field(meta, name, entity_rank, part, num_components); + else if (fieldType == "full_tensor_36") { + fieldPtr = add_stk_field(meta, name, entityRank, part, numComponents); } - else if ((field_type == "matrix_22") || (field_type == "matrix_33")) { - field_ptr = add_stk_field(meta, name, entity_rank, part, num_components); + else if ((fieldType == "matrix_22") || (fieldType == "matrix_33")) { + fieldPtr = add_stk_field(meta, name, entityRank, part, numComponents); } else { - field_ptr = add_stk_field(meta, name, entity_rank, part, num_components); + fieldPtr = add_stk_field(meta, name, entityRank, part, numComponents); } } - if (field_ptr != nullptr) { - stk::io::set_field_role(*field_ptr, io_field.get_role()); + if (fieldPtr != nullptr) { + stk::io::set_field_role(*fieldPtr, io_field.get_role()); } - return field_ptr; + return fieldPtr; } } //namespace impl @@ -781,9 +781,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return part.name(); } - stk::mesh::Part *getPart(const stk::mesh::MetaData& meta_data, const std::string& name) + stk::mesh::Part *getPart(const stk::mesh::MetaData& metaData, const std::string& name) { - const mesh::PartVector & parts = meta_data.get_parts(); + const mesh::PartVector & parts = metaData.get_parts(); for (unsigned ii=0; ii < parts.size(); ++ii) { stk::mesh::Part *pp = parts[ii]; @@ -794,7 +794,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return 0; } - Ioss::GroupingEntity* get_grouping_entity(const Ioss::Region& region, stk::mesh::Part& part) + Ioss::GroupingEntity* get_grouping_entity(const Ioss::Region& region, const stk::mesh::Part& part) { if(!stk::io::is_part_io_part(part)) { return nullptr; } @@ -841,7 +841,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return types; } - std::vector get_ioss_entity_types(stk::mesh::Part& part) + std::vector get_ioss_entity_types(const stk::mesh::Part& part) { return get_ioss_entity_types(part.mesh_meta_data(), part.primary_entity_rank()); } @@ -851,20 +851,20 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return entity->get_database()->int_byte_size_api(); } - void initialize_spatial_dimension(stk::mesh::MetaData & meta, size_t spatial_dimension, - const std::vector &entity_rank_names) + void initialize_spatial_dimension(stk::mesh::MetaData & meta, size_t spatialDimension, + const std::vector &entityRankNames) { if (!meta.is_initialized() ) { - meta.initialize(spatial_dimension, entity_rank_names); + meta.initialize(spatialDimension, entityRankNames); } } bool is_field_on_part(const stk::mesh::FieldBase *field, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, const stk::mesh::Part &part) { const stk::mesh::MetaData &meta = stk::mesh::MetaData::get(part); - const stk::mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*field, part_type, part); + const stk::mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*field, partType, part); if (res.num_scalars_per_entity() > 0) { // The field exists on the current 'part'. Now check (for // node types only) whether the 'part' is *either* the @@ -880,12 +880,12 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta // exist on. There may be a problem if we start using element // sets ..., but wait until we get to that point; current code // works with current entity set. - if (part_type != stk::topology::NODE_RANK || part == meta.universal_part()) { + if (partType != stk::topology::NODE_RANK || part == meta.universal_part()) { return true; } - const stk::mesh::FieldBase::Restriction &res_universe = stk::mesh::find_restriction(*field, part_type, meta.universal_part()); - if (res_universe.num_scalars_per_entity() <= 0) { + const stk::mesh::FieldBase::Restriction &universalRes = stk::mesh::find_restriction(*field, partType, meta.universal_part()); + if (universalRes.num_scalars_per_entity() <= 0) { // Field exists on current part, but not on the universal // set (and this part is not the universal part) return true; @@ -895,9 +895,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } bool is_valid_part_field(const stk::mesh::FieldBase *field, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, const stk::mesh::Part &part, - const Ioss::Field::RoleType filter_role) + const Ioss::Field::RoleType filterRole) { const Ioss::Field::RoleType *role = stk::io::get_field_role(*field); @@ -905,10 +905,10 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return false; } - if (role != nullptr && *role != filter_role) + if (role != nullptr && *role != filterRole) return false; - return is_field_on_part(field, part_type, part); + return is_field_on_part(field, partType, part); } void assign_generic_field_type(const stk::mesh::FieldRestriction &res, FieldType *result) @@ -1266,9 +1266,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta !is_part_assembly_io_part(part); } - stk::topology get_start_topology(const Ioss::ElementTopology* topology, unsigned mesh_spatial_dimension) + stk::topology get_start_topology(const Ioss::ElementTopology* topology, unsigned meshSpatialDimension) { - if (topology->is_element() && topology->spatial_dimension() == (int)mesh_spatial_dimension) + if (topology->is_element() && topology->spatial_dimension() == (int)meshSpatialDimension) { return stk::topology::BEGIN_ELEMENT_RANK; } @@ -1276,10 +1276,10 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } stk::topology map_ioss_topology_to_stk(const Ioss::ElementTopology *topology, - unsigned mesh_spatial_dimension) + unsigned meshSpatialDimension) { - stk::topology begin_topo = get_start_topology(topology, mesh_spatial_dimension); - for (stk::topology topo=begin_topo; topo < stk::topology::END_TOPOLOGY; ++topo) { + stk::topology beginTopo = get_start_topology(topology, meshSpatialDimension); + for (stk::topology topo=beginTopo; topo < stk::topology::END_TOPOLOGY; ++topo) { if (topology->is_alias(topo.name())) { bool bothAreElements = topology->is_element() && topo.rank()==stk::topology::ELEM_RANK; @@ -1301,8 +1301,8 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta std::string map_stk_topology_to_ioss(stk::topology topo) { - Ioss::ElementTopology *ioss_topo = Ioss::ElementTopology::factory(topo.name(), true); - return ioss_topo != nullptr ? ioss_topo->name() : "invalid"; + Ioss::ElementTopology *iossTopo = Ioss::ElementTopology::factory(topo.name(), true); + return iossTopo != nullptr ? iossTopo->name() : "invalid"; } template @@ -1341,7 +1341,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } } - void internal_part_processing(Ioss::GroupingEntity *entity, stk::mesh::MetaData &meta) + void internal_part_processing(Ioss::GroupingEntity *entity, stk::mesh::MetaData &meta, TopologyErrorHandler handler) { if (include_entity(entity)) { stk::mesh::Part & part = declare_stk_part(entity, meta); @@ -1352,7 +1352,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } } - void internal_part_processing(Ioss::EntityBlock *entity, stk::mesh::MetaData &meta) + void internal_part_processing(Ioss::EntityBlock *entity, stk::mesh::MetaData &meta, TopologyErrorHandler handler) { if (include_entity(entity)) { mesh::EntityRank type = get_entity_rank(entity, meta); @@ -1389,9 +1389,11 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta set_original_topology_type_from_ioss(entity, *part); } - stk::topology stk_topology = map_ioss_topology_to_stk(topology, meta.spatial_dimension()); - if (stk_topology != stk::topology::INVALID_TOPOLOGY) { - stk::mesh::set_topology(*part, stk_topology); + stk::topology stkTopology = map_ioss_topology_to_stk(topology, meta.spatial_dimension()); + if (stkTopology != stk::topology::INVALID_TOPOLOGY) { + stk::mesh::set_topology(*part, stkTopology); + } else { + handler(*part); } stk::io::define_io_fields(entity, Ioss::Field::ATTRIBUTE, *part, type); } @@ -1407,21 +1409,21 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta const stk::mesh::FieldBase::Restriction &res, Ioss::GroupingEntity *entity, FieldAndName &namedField, - const Ioss::Field::RoleType filter_role) + const Ioss::Field::RoleType filterRole) { - FieldType field_type; - get_io_field_type(f, res, &field_type); - if ((field_type.type != Ioss::Field::INVALID) && namedField.apply_to_entity(entity)) { - size_t entity_size = entity->get_property("entity_count").get_int(); + FieldType fieldType; + get_io_field_type(f, res, &fieldType); + if ((fieldType.type != Ioss::Field::INVALID) && namedField.apply_to_entity(entity)) { + size_t entitySize = entity->get_property("entity_count").get_int(); std::string name = namedField.db_name(); - std::string storage = field_type.name; + std::string storage = fieldType.name; if (namedField.get_use_alias()) { Ioss::VariableType::get_field_type_mapping(f->name(), &storage); } - entity->field_add(Ioss::Field(name, field_type.type, storage, - field_type.copies, filter_role, entity_size)); + entity->field_add(Ioss::Field(name, fieldType.type, storage, + fieldType.copies, filterRole, entitySize)); if (entity->type() == Ioss::NODEBLOCK) { namedField.m_forceNodeblockOutput = true; } @@ -1429,10 +1431,10 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } bool is_valid_nodeset_field(const stk::mesh::Part &part, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, Ioss::GroupingEntity *entity, FieldAndName &namedField, - const Ioss::Field::RoleType filter_role) + const Ioss::Field::RoleType filterRole) { bool isValid = false; @@ -1440,13 +1442,13 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta const Ioss::Field::RoleType *role = stk::io::get_field_role(*f); bool isNodeset = (entity != nullptr) && (entity->type() == Ioss::NODESET); - bool hasMatchingFieldRole = (role != nullptr) ? (*role == filter_role) : false; - bool hasMatchingEntityRank = f->entity_rank() == part_type; + bool hasMatchingFieldRole = (role != nullptr) ? (*role == filterRole) : false; + bool hasMatchingEntityRank = f->entity_rank() == partType; bool isNodesetField = namedField.is_nodeset_variable(); if(isNodeset && hasMatchingFieldRole && hasMatchingEntityRank && isNodesetField) { - if(namedField.apply_to_entity(entity) /*sideblockPart->primary_entity_rank() == meta.side_rank()*/) { + if(namedField.apply_to_entity(entity)) { const stk::mesh::EntityRank nodeRank = stk::topology::NODE_RANK; const std::vector & restrictions = f->restrictions(); @@ -1460,25 +1462,24 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } void ioss_add_field_to_derived_nodeset(const stk::mesh::Part &part, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, Ioss::GroupingEntity *entity, FieldAndName &namedField, - const Ioss::Field::RoleType filter_role) + const Ioss::Field::RoleType filterRole) { - const bool isValid = is_valid_nodeset_field(part, part_type, entity, namedField, filter_role); + const bool isValid = is_valid_nodeset_field(part, partType, entity, namedField, filterRole); if(isValid) { const stk::mesh::FieldBase *f = namedField.field(); - const stk::mesh::FieldBase::Restriction *res = nullptr; //find_restriction_by_bucket(meta, *f, part, nodeRank); + const stk::mesh::FieldBase::Restriction *res = nullptr; const std::vector & restrictions = f->restrictions(); - if (restrictions.size() > 0 && f->entity_rank() == stk::topology::NODE_RANK) - { + if (restrictions.size() > 0 && f->entity_rank() == stk::topology::NODE_RANK) { res = &restrictions[0]; } if(res != nullptr) { - ioss_add_field_to_entity(f, *res, entity, namedField, filter_role); + ioss_add_field_to_entity(f, *res, entity, namedField, filterRole); } } } @@ -1512,27 +1513,27 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } void ioss_add_fields_for_subpart(const stk::mesh::Part &part, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, Ioss::GroupingEntity *entity, FieldAndName &namedField, - const Ioss::Field::RoleType filter_role) + const Ioss::Field::RoleType filterRole) { - stk::mesh::EntityRank part_rank = part_primary_entity_rank(part); + stk::mesh::EntityRank partRank = part_primary_entity_rank(part); stk::mesh::PartVector blocks = part.subsets(); const stk::mesh::FieldBase *f = namedField.field(); sort_by_descending_field_size(blocks, *f); for (size_t j = 0; j < blocks.size(); j++) { - mesh::Part & side_block_part = *blocks[j]; - bool validSubsetPartField = stk::io::is_valid_part_field(f, part_type, side_block_part, filter_role); + mesh::Part & sideBlockPart = *blocks[j]; + bool validSubsetPartField = stk::io::is_valid_part_field(f, partType, sideBlockPart, filterRole); Ioss::GroupingEntity* subEntity = entity; if (validSubsetPartField) { - const stk::mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*f, part_type, side_block_part); - if (part_rank < stk::topology::ELEM_RANK) { + const stk::mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*f, partType, sideBlockPart); + if (partRank < stk::topology::ELEM_RANK) { Ioss::Region* region = entity->get_database()->get_region(); if (nullptr != region) { - Ioss::GroupingEntity* tempEntity = region->get_entity(side_block_part.name()); + Ioss::GroupingEntity* tempEntity = region->get_entity(sideBlockPart.name()); if (nullptr != tempEntity) { const bool isEntityNodeRankOrSideSetBlock = (tempEntity->type() == Ioss::NODESET || @@ -1541,58 +1542,64 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta if (isEntityNodeRankOrSideSetBlock) { subEntity = tempEntity; } + + if(tempEntity->type() == Ioss::SIDEBLOCK && entity->type() == Ioss::SIDEBLOCK && tempEntity != entity) { + subEntity = nullptr; + } } } } - bool validIossField = namedField.is_nodeset_variable() ? (subEntity->type() == Ioss::NODESET) : true; - if((subEntity != nullptr) && validIossField) { - if (subEntity != entity && subEntity->type() == Ioss::SIDEBLOCK) { + bool validIossField = (subEntity == nullptr) ? false : + (namedField.is_nodeset_variable() ? (subEntity->type() == Ioss::NODESET) : true); + + if(validIossField) { + if(subEntity->type() == Ioss::SIDEBLOCK && subEntity != entity) { const bool shouldAddFieldToParent = - field_should_be_added(namedField.db_name(), - res.num_scalars_per_entity(), entity); + field_should_be_added(namedField.db_name(), res.num_scalars_per_entity(), entity); if (shouldAddFieldToParent) { - ioss_add_field_to_entity(f, res, entity, namedField, filter_role); + ioss_add_field_to_entity(f, res, entity, namedField, filterRole); } } - ioss_add_field_to_entity(f, res, subEntity, namedField, filter_role); + + ioss_add_field_to_entity(f, res, subEntity, namedField, filterRole); } } } } void ioss_add_fields(const stk::mesh::Part &part, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, Ioss::GroupingEntity *entity, std::vector &namedFields, - const Ioss::Field::RoleType filter_role) + const Ioss::Field::RoleType filterRole) { - stk::mesh::EntityRank part_rank = part_primary_entity_rank(part); + stk::mesh::EntityRank partRank = part_primary_entity_rank(part); const stk::mesh::PartVector &blocks = part.subsets(); - bool check_subparts = (part_rank == stk::topology::NODE_RANK || - part_rank == stk::topology::EDGE_RANK || - part_rank == stk::topology::FACE_RANK) && + bool checkSubparts = (partRank == stk::topology::NODE_RANK || + partRank == stk::topology::EDGE_RANK || + partRank == stk::topology::FACE_RANK) && (blocks.size() > 0); for (size_t i=0; i &namedFields) { const std::vector &fields = meta.get_fields(); @@ -1600,29 +1607,29 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta std::vector::const_iterator fieldIterator = fields.begin(); for(;fieldIterator != fields.end();++fieldIterator) { const Ioss::Field::RoleType *role = stk::io::get_field_role(**fieldIterator); - if (role && *role == filter_role) { + if (role != nullptr && *role == filterRole) { namedFields.emplace_back(*fieldIterator, (*fieldIterator)->name()); } } } void ioss_add_fields(const stk::mesh::Part &part, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, Ioss::GroupingEntity *entity, - const Ioss::Field::RoleType filter_role) + const Ioss::Field::RoleType filterRole) { std::vector namedFields; - stk::io::getNamedFields(mesh::MetaData::get(part), entity, filter_role, namedFields); + stk::io::getNamedFields(mesh::MetaData::get(part), entity, filterRole, namedFields); - ioss_add_fields(part, part_type, entity, namedFields, filter_role); + ioss_add_fields(part, partType, entity, namedFields, filterRole); } void ioss_add_fields(const stk::mesh::Part &part, - const stk::mesh::EntityRank part_type, + const stk::mesh::EntityRank partType, Ioss::GroupingEntity *entity, std::vector &namedFields) { - ioss_add_fields(part, part_type, entity, namedFields, Ioss::Field::Field::TRANSIENT); + ioss_add_fields(part, partType, entity, namedFields, Ioss::Field::Field::TRANSIENT); } @@ -1638,13 +1645,13 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void define_io_fields(Ioss::GroupingEntity *entity, Ioss::Field::RoleType role, stk::mesh::Part &part, - stk::mesh::EntityRank part_type) + stk::mesh::EntityRank partType) { stk::mesh::MetaData &meta = mesh::MetaData::get(part); - bool use_cartesian_for_scalar = false; + bool useCartesianForScalar = false; if (role == Ioss::Field::ATTRIBUTE) - use_cartesian_for_scalar = true; + useCartesianForScalar = true; Ioss::NameList names; entity->field_describe(role, &names); @@ -1660,8 +1667,8 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta // \todo IMPLEMENT Need to determine whether these are // multi-state fields or constant, or interpolated, or ... - Ioss::Field io_field = entity->get_field(*I); - declare_stk_field(meta, part_type, part, io_field, use_cartesian_for_scalar); + Ioss::Field ioField = entity->get_field(*I); + declare_stk_field(meta, partType, part, ioField, useCartesianForScalar); } } @@ -1685,94 +1692,94 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta delete_selector_property(region.get_nodesets()); delete_selector_property(region.get_commsets()); - const Ioss::SideSetContainer& side_sets = region.get_sidesets(); - for(Ioss::SideSetContainer::const_iterator it = side_sets.begin(); - it != side_sets.end(); ++it) { + const Ioss::SideSetContainer& sideSets = region.get_sidesets(); + for(Ioss::SideSetContainer::const_iterator it = sideSets.begin(); + it != sideSets.end(); ++it) { Ioss::SideSet *sset = *it; delete_selector_property(*it); delete_selector_property(sset->get_side_blocks()); } } - void delete_selector_property(Ioss::GroupingEntity *io_entity) + void delete_selector_property(Ioss::GroupingEntity *ioEntity) { // If the Ioss::GroupingEntity has a property named 'selector' of // type 'pointer', delete the pointer and remove the property. - if (io_entity->property_exists(s_internal_selector_name)) { - mesh::Selector *select = reinterpret_cast(io_entity->get_property(s_internal_selector_name).get_pointer()); + if (ioEntity->property_exists(s_internalSelectorName)) { + mesh::Selector *select = reinterpret_cast(ioEntity->get_property(s_internalSelectorName).get_pointer()); delete select; - io_entity->property_erase(s_internal_selector_name); + ioEntity->property_erase(s_internalSelectorName); } } template - void get_entity_list(Ioss::GroupingEntity *io_entity, - stk::mesh::EntityRank part_type, + void get_entity_list(Ioss::GroupingEntity *ioEntity, + stk::mesh::EntityRank partType, const stk::mesh::BulkData &bulk, std::vector &entities) { - if (io_entity->type() == Ioss::SIDEBLOCK) { - std::vector elem_side ; - io_entity->get_field_data("element_side", elem_side); - size_t side_count = elem_side.size() / 2; - for(size_t is=0; istype() == Ioss::SIDEBLOCK) { + std::vector elemSide ; + ioEntity->get_field_data("element_side", elemSide); + size_t sideCount = elemSide.size() / 2; + for(size_t is=0; is ids ; - io_entity->get_field_data("ids", ids); + std::vector ids ; + ioEntity->get_field_data("ids", ids); - size_t count = ids.size(); - entities.reserve(count); + size_t count = ids.size(); + entities.reserve(count); - for(size_t i=0; i &entities) { - ThrowRequireMsg(io_entity->get_database()->is_input(), "Database is output type"); - if (db_api_int_size(io_entity) == 4) { - get_entity_list(io_entity, part_type, bulk, entities); + ThrowRequireMsg(ioEntity->get_database()->is_input(), "Database is output type"); + if (db_api_int_size(ioEntity) == 4) { + get_entity_list(ioEntity, partType, bulk, entities); } else { - get_entity_list(io_entity, part_type, bulk, entities); + get_entity_list(ioEntity, partType, bulk, entities); } } - void get_output_entity_list(Ioss::GroupingEntity *io_entity, - stk::mesh::EntityRank part_type, + void get_output_entity_list(Ioss::GroupingEntity *ioEntity, + stk::mesh::EntityRank partType, OutputParams ¶ms, std::vector &entities) { const stk::mesh::BulkData &bulk = params.bulk_data(); - ThrowRequireMsg(!io_entity->get_database()->is_input(), "Database is input type"); - assert(io_entity->property_exists(s_internal_selector_name)); + ThrowRequireMsg(!ioEntity->get_database()->is_input(), "Database is input type"); + assert(ioEntity->property_exists(s_internalSelectorName)); - mesh::Selector *select = reinterpret_cast(io_entity->get_property(s_internal_selector_name).get_pointer()); + mesh::Selector *select = reinterpret_cast(ioEntity->get_property(s_internalSelectorName).get_pointer()); - if(io_entity->type() == Ioss::NODEBLOCK) { + if(ioEntity->type() == Ioss::NODEBLOCK) { get_selected_nodes(params, *select, entities); } else { const bool sortById = true; - stk::mesh::get_entities(bulk, part_type, *select, entities, sortById); + stk::mesh::get_entities(bulk, partType, *select, entities, sortById); } } - const std::string get_suffix_for_field_at_state(enum stk::mesh::FieldState field_state, std::vector* multiStateSuffixes) + const std::string get_suffix_for_field_at_state(enum stk::mesh::FieldState fieldState, std::vector* multiStateSuffixes) { if(nullptr != multiStateSuffixes) { - ThrowRequireMsg((multiStateSuffixes->size() >= field_state), - "Invalid field state index '" << field_state << "'"); - return (*multiStateSuffixes)[field_state]; + ThrowRequireMsg((multiStateSuffixes->size() >= fieldState), + "Invalid field state index '" << fieldState << "'"); + return (*multiStateSuffixes)[fieldState]; } std::string suffix = ""; - switch(field_state) + switch(fieldState) { case stk::mesh::StateN: suffix = ".N"; @@ -1792,106 +1799,106 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta case stk::mesh::StateNP1: break; default: - ThrowRequireMsg(false, "Internal Error: Unsupported stk::mesh::FieldState: " << field_state << ".\n"); + ThrowRequireMsg(false, "Internal Error: Unsupported stk::mesh::FieldState: " << fieldState << ".\n"); } return suffix; } - std::string get_stated_field_name(const std::string &field_base_name, stk::mesh::FieldState state_identifier, + std::string get_stated_field_name(const std::string &fieldBaseName, stk::mesh::FieldState stateIdentifier, std::vector* multiStateSuffixes) { - std::string field_name_with_suffix = field_base_name + get_suffix_for_field_at_state(state_identifier, multiStateSuffixes); + std::string field_name_with_suffix = fieldBaseName + get_suffix_for_field_at_state(stateIdentifier, multiStateSuffixes); return field_name_with_suffix; } - bool field_state_exists_on_io_entity(const std::string& db_name, const stk::mesh::FieldBase* field, stk::mesh::FieldState state_identifier, - Ioss::GroupingEntity *io_entity, std::vector* multiStateSuffixes) + bool field_state_exists_on_io_entity(const std::string& dbName, const stk::mesh::FieldBase* field, stk::mesh::FieldState stateIdentifier, + Ioss::GroupingEntity *ioEntity, std::vector* multiStateSuffixes) { - std::string field_name_with_suffix = get_stated_field_name(db_name, state_identifier, multiStateSuffixes); - return io_entity->field_exists(field_name_with_suffix); + std::string fieldNameWithSuffix = get_stated_field_name(dbName, stateIdentifier, multiStateSuffixes); + return ioEntity->field_exists(fieldNameWithSuffix); } - bool all_field_states_exist_on_io_entity(const std::string& db_name, const stk::mesh::FieldBase* field, Ioss::GroupingEntity *io_entity, - std::vector &missing_states, std::vector* inputMultiStateSuffixes) + bool all_field_states_exist_on_io_entity(const std::string& dbName, const stk::mesh::FieldBase* field, Ioss::GroupingEntity *ioEntity, + std::vector &missingStates, std::vector* inputMultiStateSuffixes) { - bool all_states_exist = true; - size_t state_count = field->number_of_states(); + bool allStatesExist = true; + size_t stateCount = field->number_of_states(); - std::vector* multiStateSuffixes = state_count > 2 ? inputMultiStateSuffixes : nullptr; + std::vector* multiStateSuffixes = stateCount > 2 ? inputMultiStateSuffixes : nullptr; if(nullptr != multiStateSuffixes) { - ThrowRequire(multiStateSuffixes->size() >= state_count); + ThrowRequire(multiStateSuffixes->size() >= stateCount); } - for(size_t state = 0; state < state_count - 1; state++) { - stk::mesh::FieldState state_identifier = static_cast(state); - if (!field_state_exists_on_io_entity(db_name, field, state_identifier, io_entity, multiStateSuffixes)) { - all_states_exist = false; - missing_states.push_back(state_identifier); + for(size_t state = 0; state < stateCount - 1; state++) { + stk::mesh::FieldState stateIdentifier = static_cast(state); + if (!field_state_exists_on_io_entity(dbName, field, stateIdentifier, ioEntity, multiStateSuffixes)) { + allStatesExist = false; + missingStates.push_back(stateIdentifier); } } - return all_states_exist; + return allStatesExist; } void multistate_field_data_from_ioss(const stk::mesh::BulkData& mesh, const stk::mesh::FieldBase *field, - std::vector &entity_list, - Ioss::GroupingEntity *io_entity, + std::vector &entityList, + Ioss::GroupingEntity *ioEntity, const std::string &name, - const size_t state_count, - bool ignore_missing_fields, + const size_t stateCount, + bool ignoreMissingFields, std::vector* inputMultiStateSuffixes) { - std::vector* multiStateSuffixes = state_count > 2 ? inputMultiStateSuffixes : nullptr; + std::vector* multiStateSuffixes = stateCount > 2 ? inputMultiStateSuffixes : nullptr; if(nullptr != multiStateSuffixes) { - ThrowRequire(multiStateSuffixes->size() >= state_count); + ThrowRequire(multiStateSuffixes->size() >= stateCount); } - for(size_t state = 0; state < state_count - 1; state++) + for(size_t state = 0; state < stateCount - 1; state++) { - stk::mesh::FieldState state_identifier = static_cast(state); - bool field_exists = field_state_exists_on_io_entity(name, field, state_identifier, io_entity, multiStateSuffixes); - if (!field_exists && !ignore_missing_fields) { - STKIORequire(field_exists); + stk::mesh::FieldState stateIdentifier = static_cast(state); + bool fieldExists = field_state_exists_on_io_entity(name, field, stateIdentifier, ioEntity, multiStateSuffixes); + if (!fieldExists && !ignoreMissingFields) { + STKIORequire(fieldExists); } - if (field_exists) { - stk::mesh::FieldBase *stated_field = field->field_state(state_identifier); - std::string field_name_with_suffix = get_stated_field_name(name, state_identifier, multiStateSuffixes); - stk::io::field_data_from_ioss(mesh, stated_field, entity_list, io_entity, field_name_with_suffix); + if (fieldExists) { + stk::mesh::FieldBase *statedField = field->field_state(stateIdentifier); + std::string fieldNameWithSuffix = get_stated_field_name(name, stateIdentifier, multiStateSuffixes); + stk::io::field_data_from_ioss(mesh, statedField, entityList, ioEntity, fieldNameWithSuffix); } } } void subsetted_multistate_field_data_from_ioss(const stk::mesh::BulkData& mesh, const stk::mesh::FieldBase *field, - std::vector &entity_list, - Ioss::GroupingEntity *io_entity, - const stk::mesh::Part *stk_part, + std::vector &entityList, + Ioss::GroupingEntity *ioEntity, + const stk::mesh::Part *stkPart, const std::string &name, - const size_t state_count, - bool ignore_missing_fields, + const size_t stateCount, + bool ignoreMissingFields, std::vector* inputMultiStateSuffixes) { - std::vector* multiStateSuffixes = state_count > 2 ? inputMultiStateSuffixes : nullptr; + std::vector* multiStateSuffixes = stateCount > 2 ? inputMultiStateSuffixes : nullptr; if(nullptr != multiStateSuffixes) { - ThrowRequire(multiStateSuffixes->size() >= state_count); + ThrowRequire(multiStateSuffixes->size() >= stateCount); } - for(size_t state = 0; state < state_count - 1; state++) + for(size_t state = 0; state < stateCount - 1; state++) { - stk::mesh::FieldState state_identifier = static_cast(state); - bool field_exists = field_state_exists_on_io_entity(name, field, state_identifier, io_entity, multiStateSuffixes); - if (!field_exists && !ignore_missing_fields) { - STKIORequire(field_exists); + stk::mesh::FieldState stateIdentifier = static_cast(state); + bool fieldExists = field_state_exists_on_io_entity(name, field, stateIdentifier, ioEntity, multiStateSuffixes); + if (!fieldExists && !ignoreMissingFields) { + STKIORequire(fieldExists); } - if (field_exists) { - stk::mesh::FieldBase *stated_field = field->field_state(state_identifier); - std::string field_name_with_suffix = get_stated_field_name(name, state_identifier, multiStateSuffixes); - stk::io::subsetted_field_data_from_ioss(mesh, stated_field, entity_list, - io_entity, stk_part, field_name_with_suffix); + if (fieldExists) { + stk::mesh::FieldBase *statedField = field->field_state(stateIdentifier); + std::string fieldNameWithSuffix = get_stated_field_name(name, stateIdentifier, multiStateSuffixes); + stk::io::subsetted_field_data_from_ioss(mesh, statedField, entityList, + ioEntity, stkPart, fieldNameWithSuffix); } } } @@ -1899,41 +1906,41 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void field_data_from_ioss(const stk::mesh::BulkData& mesh, const stk::mesh::FieldBase *field, std::vector &entities, - Ioss::GroupingEntity *io_entity, - const std::string &io_fld_name) + Ioss::GroupingEntity *ioEntity, + const std::string &ioFieldName) { /// \todo REFACTOR Need some additional compatibility checks between /// Ioss field and stk::mesh::Field; better error messages... - if (field != nullptr && io_entity->field_exists(io_fld_name)) { - const Ioss::Field &io_field = io_entity->get_fieldref(io_fld_name); + if (field != nullptr && ioEntity->field_exists(ioFieldName)) { + const Ioss::Field &ioField = ioEntity->get_fieldref(ioFieldName); if (field->type_is()) { - internal_field_data_from_ioss(mesh, io_field, field, entities, io_entity); + internal_field_data_from_ioss(mesh, ioField, field, entities, ioEntity); } else if (field->type_is()) { // Make sure the IO field type matches the STK field type. // By default, all IO fields are created of type 'double' - if (db_api_int_size(io_entity) == 4) { - io_field.check_type(Ioss::Field::INTEGER); - internal_field_data_from_ioss(mesh, io_field, field, entities, io_entity); + if (db_api_int_size(ioEntity) == 4) { + ioField.check_type(Ioss::Field::INTEGER); + internal_field_data_from_ioss(mesh, ioField, field, entities, ioEntity); } else { - io_field.check_type(Ioss::Field::INT64); - internal_field_data_from_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INT64); + internal_field_data_from_ioss(mesh, ioField, field, entities, ioEntity); } } else if (field->type_is()) { // Make sure the IO field type matches the STK field type. // By default, all IO fields are created of type 'double' - io_field.check_type(Ioss::Field::INT64); - internal_field_data_from_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INT64); + internal_field_data_from_ioss(mesh, ioField, field, entities, ioEntity); } else if (field->type_is()) { // Make sure the IO field type matches the STK field type. // By default, all IO fields are created of type 'double' - io_field.check_type(Ioss::Field::INTEGER); - internal_field_data_from_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INTEGER); + internal_field_data_from_ioss(mesh, ioField, field, entities, ioEntity); } else if (field->type_is()) { // Make sure the IO field type matches the STK field type. // By default, all IO fields are created of type 'double' - io_field.check_type(Ioss::Field::INT64); - internal_field_data_from_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INT64); + internal_field_data_from_ioss(mesh, ioField, field, entities, ioEntity); } } @@ -1942,37 +1949,37 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void subsetted_field_data_from_ioss(const stk::mesh::BulkData& mesh, const stk::mesh::FieldBase *field, std::vector &entities, - Ioss::GroupingEntity *io_entity, - const stk::mesh::Part *stk_part, - const std::string &io_fld_name) + Ioss::GroupingEntity *ioEntity, + const stk::mesh::Part *stkPart, + const std::string &ioFieldName) { /// \todo REFACTOR Need some additional compatibility checks between /// Ioss field and stk::mesh::Field; better error messages... - if (field != nullptr && io_entity->field_exists(io_fld_name)) { - const Ioss::Field &io_field = io_entity->get_fieldref(io_fld_name); + if (field != nullptr && ioEntity->field_exists(ioFieldName)) { + const Ioss::Field &ioField = ioEntity->get_fieldref(ioFieldName); if (field->type_is()) { - internal_subsetted_field_data_from_ioss(mesh, io_field, field, entities, io_entity, stk_part); + internal_subsetted_field_data_from_ioss(mesh, ioField, field, entities, ioEntity, stkPart); } else if (field->type_is()) { // Make sure the IO field type matches the STK field type. // By default, all IO fields are created of type 'double' - if (db_api_int_size(io_entity) == 4) { - io_field.check_type(Ioss::Field::INTEGER); - internal_subsetted_field_data_from_ioss(mesh, io_field, field, entities, io_entity, stk_part); + if (db_api_int_size(ioEntity) == 4) { + ioField.check_type(Ioss::Field::INTEGER); + internal_subsetted_field_data_from_ioss(mesh, ioField, field, entities, ioEntity, stkPart); } else { - io_field.check_type(Ioss::Field::INT64); - internal_subsetted_field_data_from_ioss(mesh, io_field, field, entities, io_entity, - stk_part); + ioField.check_type(Ioss::Field::INT64); + internal_subsetted_field_data_from_ioss(mesh, ioField, field, entities, ioEntity, + stkPart); } } else if (field->type_is()) { // Make sure the IO field type matches the STK field type. // By default, all IO fields are created of type 'double' - if (db_api_int_size(io_entity) == 4) { - io_field.check_type(Ioss::Field::INTEGER); - internal_subsetted_field_data_from_ioss(mesh, io_field, field, entities, io_entity, stk_part); + if (db_api_int_size(ioEntity) == 4) { + ioField.check_type(Ioss::Field::INTEGER); + internal_subsetted_field_data_from_ioss(mesh, ioField, field, entities, ioEntity, stkPart); } else { - io_field.check_type(Ioss::Field::INT64); - internal_subsetted_field_data_from_ioss(mesh, io_field, field, entities, io_entity, - stk_part); + ioField.check_type(Ioss::Field::INT64); + internal_subsetted_field_data_from_ioss(mesh, ioField, field, entities, ioEntity, + stkPart); } } } @@ -1981,48 +1988,47 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void multistate_field_data_to_ioss(const stk::mesh::BulkData& mesh, const stk::mesh::FieldBase *field, std::vector &entities, - Ioss::GroupingEntity *io_entity, - const std::string &io_fld_name, - Ioss::Field::RoleType filter_role, - const size_t state_count) + Ioss::GroupingEntity *ioEntity, + const std::string &ioFieldName, + Ioss::Field::RoleType filterRole, + const size_t stateCount) { - for(size_t state = 0; state < state_count - 1; state++) - { - stk::mesh::FieldState state_identifier = static_cast(state); - std::string field_name_with_suffix = get_stated_field_name(io_fld_name, state_identifier); - stk::mesh::FieldBase *stated_field = field->field_state(state_identifier); - //STKIORequire(io_entity->field_exists(field_name_with_suffix)); - stk::io::field_data_to_ioss(mesh, stated_field, entities, io_entity, field_name_with_suffix, filter_role); - } + for(size_t state = 0; state < stateCount - 1; state++) + { + stk::mesh::FieldState stateIdentifier = static_cast(state); + std::string fieldNameWithSuffix = get_stated_field_name(ioFieldName, stateIdentifier); + stk::mesh::FieldBase *statedField = field->field_state(stateIdentifier); + stk::io::field_data_to_ioss(mesh, statedField, entities, ioEntity, fieldNameWithSuffix, filterRole); + } } void field_data_to_ioss(const stk::mesh::BulkData& mesh, const stk::mesh::FieldBase *field, std::vector &entities, - Ioss::GroupingEntity *io_entity, - const std::string &io_fld_name, - Ioss::Field::RoleType filter_role) + Ioss::GroupingEntity *ioEntity, + const std::string &ioFieldName, + Ioss::Field::RoleType filterRole) { /// \todo REFACTOR Need some additional compatibility checks between /// Ioss field and stk::mesh::Field; better error messages... - if (field != nullptr && io_entity->field_exists(io_fld_name)) { - const Ioss::Field &io_field = io_entity->get_fieldref(io_fld_name); - if (io_field.get_role() == filter_role) { + if (field != nullptr && ioEntity->field_exists(ioFieldName)) { + const Ioss::Field &ioField = ioEntity->get_fieldref(ioFieldName); + if (ioField.get_role() == filterRole) { if (field->type_is()) { - internal_field_data_to_ioss(mesh, io_field, field, entities, io_entity); + internal_field_data_to_ioss(mesh, ioField, field, entities, ioEntity); } else if (field->type_is()) { - io_field.check_type(Ioss::Field::INTEGER); - internal_field_data_to_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INTEGER); + internal_field_data_to_ioss(mesh, ioField, field, entities, ioEntity); } else if (field->type_is()) { - io_field.check_type(Ioss::Field::INT64); - internal_field_data_to_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INT64); + internal_field_data_to_ioss(mesh, ioField, field, entities, ioEntity); } else if (field->type_is()) { - io_field.check_type(Ioss::Field::INT32); - internal_field_data_to_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INT32); + internal_field_data_to_ioss(mesh, ioField, field, entities, ioEntity); } else if (field->type_is()) { - io_field.check_type(Ioss::Field::INT64); - internal_field_data_to_ioss(mesh, io_field, field, entities, io_entity); + ioField.check_type(Ioss::Field::INT64); + internal_field_data_to_ioss(mesh, ioField, field, entities, ioEntity); } } } @@ -2048,14 +2054,13 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta namespace { - stk::mesh::EntityRank get_output_rank(const stk::io::OutputParams& params) + stk::mesh::EntityRank get_output_rank(stk::io::OutputParams& params) { return params.has_skin_mesh_selector() ? params.bulk_data().mesh_meta_data().side_rank() : stk::topology::ELEMENT_RANK; } //---------------------------------------------------------------------- - void define_node_block(stk::io::OutputParams ¶ms, - stk::mesh::Part &part) + void define_node_block(stk::io::OutputParams ¶ms, stk::mesh::Part &part) { //-------------------------------- // Set the spatial dimension: @@ -2065,23 +2070,23 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta //it from the coordinate-field's restriction onto the universal part. //This is because some codes (sierra framework) don't put the coordinate //field on the universal part. (framework puts it on active and inactive parts) - const int spatial_dim = meta.spatial_dimension(); + const int spatialDim = meta.spatial_dimension(); stk::mesh::EntityRank rank = get_output_rank(params); //-------------------------------- // Create the special universal node block: - mesh::Selector shared_selector = params.has_shared_selector() ? *(params.get_shared_selector()) - : meta.globally_shared_part(); + mesh::Selector sharedSelector = params.has_shared_selector() ? *(params.get_shared_selector()) + : meta.globally_shared_part(); - mesh::Selector all_selector = meta.globally_shared_part() | meta.locally_owned_part(); - if (params.get_subset_selector( )) all_selector &= *params.get_subset_selector(); - if (params.get_output_selector(rank)) all_selector &= *params.get_output_selector(rank); + mesh::Selector allSelector = meta.globally_shared_part() | meta.locally_owned_part(); + if (params.get_subset_selector( )) allSelector &= *params.get_subset_selector(); + if (params.get_output_selector(rank)) allSelector &= *params.get_output_selector(rank); - mesh::Selector own_selector = meta.locally_owned_part(); - if (params.get_subset_selector( )) own_selector &= *params.get_subset_selector(); - if (params.get_output_selector(rank)) own_selector &= *params.get_output_selector(rank); + mesh::Selector ownSelector = meta.locally_owned_part(); + if (params.get_subset_selector( )) ownSelector &= *params.get_subset_selector(); + if (params.get_output_selector(rank)) ownSelector &= *params.get_output_selector(rank); - int64_t all_nodes = count_selected_nodes(params, all_selector); - int64_t own_nodes = count_selected_nodes(params, own_selector); + int64_t allNodes = count_selected_nodes(params, allSelector); + int64_t ownNodes = count_selected_nodes(params, ownSelector); const std::string name("nodeblock_1"); @@ -2089,17 +2094,17 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta if(nb == nullptr) { nb = new Ioss::NodeBlock(params.io_region().get_database(), - name, all_nodes, spatial_dim); + name, allNodes, spatialDim); params.io_region().add( nb ); } delete_selector_property(nb); - mesh::Selector *node_select = new mesh::Selector(all_selector); - nb->property_add(Ioss::Property(s_internal_selector_name, node_select)); + mesh::Selector *nodeSelect = new mesh::Selector(allSelector); + nb->property_add(Ioss::Property(s_internalSelectorName, nodeSelect)); nb->property_add(Ioss::Property(base_stk_part_name, getPartName(part))); // Add locally-owned property... - nb->property_add(Ioss::Property("locally_owned_count", own_nodes)); + nb->property_add(Ioss::Property("locally_owned_count", ownNodes)); // Add the attribute fields. ioss_add_fields(part, part_primary_entity_rank(part), nb, Ioss::Field::ATTRIBUTE); } @@ -2111,39 +2116,39 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta { mesh::EntityRank rank = get_output_rank(params); mesh::MetaData & meta = mesh::MetaData::get(part); - Ioss::Region & io_region = params.io_region(); + Ioss::Region & ioRegion = params.io_region(); - mesh::Selector shared_selector = params.has_shared_selector() ? *(params.get_shared_selector()) - : meta.globally_shared_part(); + mesh::Selector sharedSelector = params.has_shared_selector() ? *(params.get_shared_selector()) + : meta.globally_shared_part(); - mesh::Selector all_selector = (meta.globally_shared_part() | meta.locally_owned_part()) & part; - if (params.get_subset_selector( )) all_selector &= *params.get_subset_selector(); - if (params.get_output_selector(rank)) all_selector &= *params.get_output_selector(rank); + mesh::Selector allSelector = (meta.globally_shared_part() | meta.locally_owned_part()) & part; + if (params.get_subset_selector( )) allSelector &= *params.get_subset_selector(); + if (params.get_output_selector(rank)) allSelector &= *params.get_output_selector(rank); - mesh::Selector own_selector = meta.locally_owned_part() & part; - if (params.get_subset_selector( )) own_selector &= *params.get_subset_selector(); - if (params.get_output_selector(rank)) own_selector &= *params.get_output_selector(rank); + mesh::Selector ownSelector = meta.locally_owned_part() & part; + if (params.get_subset_selector( )) ownSelector &= *params.get_subset_selector(); + if (params.get_output_selector(rank)) ownSelector &= *params.get_output_selector(rank); - int64_t all_nodes = count_selected_nodes(params, all_selector); - int64_t own_nodes = count_selected_nodes(params, own_selector); + int64_t allNodes = count_selected_nodes(params, allSelector); + int64_t ownNodes = count_selected_nodes(params, ownSelector); - Ioss::NodeSet *ns = io_region.get_nodeset(name); + Ioss::NodeSet *ns = ioRegion.get_nodeset(name); if(ns == nullptr) { - ns = new Ioss::NodeSet( io_region.get_database(), name, all_nodes); - io_region.add(ns); + ns = new Ioss::NodeSet( ioRegion.get_database(), name, allNodes); + ioRegion.add(ns); - bool use_generic_canonical_name = io_region.get_database()->get_use_generic_canonical_name(); + bool use_generic_canonical_name = ioRegion.get_database()->get_use_generic_canonical_name(); if(use_generic_canonical_name) { add_canonical_name_property(ns, part); } } - ns->property_add(Ioss::Property("locally_owned_count", own_nodes)); + ns->property_add(Ioss::Property("locally_owned_count", ownNodes)); delete_selector_property(ns); - mesh::Selector *select = new mesh::Selector(all_selector); - ns->property_add(Ioss::Property(s_internal_selector_name, select)); + mesh::Selector *select = new mesh::Selector(allSelector); + ns->property_add(Ioss::Property(s_internalSelectorName, select)); ns->property_add(Ioss::Property(base_stk_part_name, getPartName(part))); if(!isDerivedNodeset) { @@ -2163,17 +2168,17 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta stk::mesh::Selector selector, stk::mesh::Part &part, Ioss::SideSet *sset, - int spatial_dimension, - bool create_nodeset) + int spatialDimension, + bool createNodeset) { stk::mesh::EntityRank type = part.primary_entity_rank(); const stk::mesh::EntityRank siderank = stk::mesh::MetaData::get(part).side_rank(); const stk::mesh::EntityRank edgerank = stk::topology::EDGE_RANK; STKIORequire(type == siderank || type == edgerank); - stk::topology side_topology = part.topology(); - std::string io_topo = map_stk_topology_to_ioss(side_topology); - std::string element_topo_name = "unknown"; + stk::topology sideTopology = part.topology(); + std::string ioTopo = map_stk_topology_to_ioss(sideTopology); + std::string elementTopoName = "unknown"; const stk::mesh::BulkData &bulk = params.bulk_data(); @@ -2181,8 +2186,8 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta // Try to decode from part name... std::vector tokens; stk::util::tokenize(getPartName(part), "_", tokens); - const Ioss::ElementTopology *element_topo = nullptr; - stk::topology stk_element_topology = stk::topology::INVALID_TOPOLOGY; + const Ioss::ElementTopology *elementTopo = nullptr; + stk::topology stkElementTopology = stk::topology::INVALID_TOPOLOGY; if (tokens.size() >= 4) { // If the sideset has a "canonical" name as in "surface_{id}", // Then the sideblock name will be of the form: @@ -2194,123 +2199,123 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta // * "{sideset_name}_block_id_sidetopo" // Check the last token and see if it is an integer... - bool all_dig = tokens.back().find_first_not_of("0123456789") == std::string::npos; - if (all_dig) { - element_topo = Ioss::ElementTopology::factory(tokens[1], true); + bool allDigits = tokens.back().find_first_not_of("0123456789") == std::string::npos; + if (allDigits) { + elementTopo = Ioss::ElementTopology::factory(tokens[1], true); } else { - element_topo = Ioss::ElementTopology::factory(tokens[tokens.size() - 2], true); + elementTopo = Ioss::ElementTopology::factory(tokens[tokens.size() - 2], true); } - if (element_topo != nullptr) { - element_topo_name = element_topo->name(); - stk_element_topology = map_ioss_topology_to_stk(element_topo, bulk.mesh_meta_data().spatial_dimension()); + if (elementTopo != nullptr) { + elementTopoName = elementTopo->name(); + stkElementTopology = map_ioss_topology_to_stk(elementTopo, bulk.mesh_meta_data().spatial_dimension()); } } const stk::mesh::Part *parentElementBlock = get_parent_element_block(bulk, params.io_region(), part.name()); - size_t side_count = get_number_sides_in_sideset(params, part, stk_element_topology, parentElementBlock); + size_t sideCount = get_number_sides_in_sideset(params, part, stkElementTopology, parentElementBlock); std::string name = getPartName(part); - Ioss::SideBlock *side_block = sset->get_side_block(name); - if(side_block == nullptr) + Ioss::SideBlock *sideBlock = sset->get_side_block(name); + if(sideBlock == nullptr) { - side_block = new Ioss::SideBlock(sset->get_database(), name, io_topo, element_topo_name, side_count); - sset->add(side_block); + sideBlock = new Ioss::SideBlock(sset->get_database(), name, ioTopo, elementTopoName, sideCount); + sset->add(sideBlock); } const mesh::FieldBase *df = get_distribution_factor_field(part); if (df != nullptr) { - int nodes_per_side = side_topology.num_nodes(); - std::string storage_type = "Real["; - storage_type += sierra::to_string(nodes_per_side); - storage_type += "]"; - side_block->field_add(Ioss::Field(s_distribution_factors, Ioss::Field::REAL, storage_type, - Ioss::Field::MESH, side_count)); + int nodesPerSide = sideTopology.num_nodes(); + std::string storageType = "Real["; + storageType += sierra::to_string(nodesPerSide); + storageType += "]"; + sideBlock->field_add(Ioss::Field(s_distributionFactors, Ioss::Field::REAL, storageType, + Ioss::Field::MESH, sideCount)); } selector &= bulk.mesh_meta_data().locally_owned_part(); - delete_selector_property(side_block); + delete_selector_property(sideBlock); mesh::Selector *select = new mesh::Selector(selector); - side_block->property_add(Ioss::Property(s_internal_selector_name, select)); - side_block->property_add(Ioss::Property(base_stk_part_name, getPartName(part))); + sideBlock->property_add(Ioss::Property(s_internalSelectorName, select)); + sideBlock->property_add(Ioss::Property(base_stk_part_name, getPartName(part))); // Add the attribute fields. - ioss_add_fields(part, part_primary_entity_rank(part), side_block, Ioss::Field::ATTRIBUTE); + ioss_add_fields(part, part_primary_entity_rank(part), sideBlock, Ioss::Field::ATTRIBUTE); - if(create_nodeset) { - std::string nodes_name = getPartName(part) + s_entity_nodes_suffix; + if(createNodeset) { + std::string nodes_name = getPartName(part) + s_entityNodesSuffix; bool isDerivedNodeset = true; define_node_set(params, part, nodes_name, isDerivedNodeset); } } bool should_create_nodeset_from_sideset(stk::mesh::Part &part, - bool use_nodeset_for_nodal_fields, - bool check_field_existence) + bool useNodesetForNodalFields, + bool checkFieldExistence) { STKIORequire(part.primary_entity_rank() == stk::topology::FACE_RANK || stk::topology::EDGE_RANK); - bool create_nodesets = false; + bool createNodesets = false; - if (use_nodeset_for_nodal_fields) { - if(check_field_existence) { - bool lower_rank_fields = will_output_lower_rank_fields(part, stk::topology::NODE_RANK); + if (useNodesetForNodalFields) { + if(checkFieldExistence) { + bool lowerRankFields = will_output_lower_rank_fields(part, stk::topology::NODE_RANK); - if (!lower_rank_fields) { + if (!lowerRankFields) { // See if lower rank fields are defined on sideblock parts of this sideset... const stk::mesh::PartVector &blocks = part.subsets(); - for (size_t j = 0; j < blocks.size() && !lower_rank_fields; j++) { + for (size_t j = 0; j < blocks.size() && !lowerRankFields; j++) { mesh::Part & side_block_part = *blocks[j]; - lower_rank_fields |= will_output_lower_rank_fields(side_block_part, stk::topology::NODE_RANK); + lowerRankFields |= will_output_lower_rank_fields(side_block_part, stk::topology::NODE_RANK); } } - if (lower_rank_fields) { - create_nodesets = true; + if (lowerRankFields) { + createNodesets = true; } } else { - create_nodesets = true; + createNodesets = true; } } if(has_derived_nodeset_attribute(part)) { - create_nodesets = get_derived_nodeset_attribute(part); + createNodesets = get_derived_nodeset_attribute(part); } - return create_nodesets; + return createNodesets; } void define_side_blocks(stk::io::OutputParams ¶ms, stk::mesh::Part &part, Ioss::SideSet *sset, stk::mesh::EntityRank type, - int spatial_dimension) + int spatialDimension) { STKIORequire(type == stk::topology::FACE_RANK || stk::topology::EDGE_RANK); - bool create_nodesets = should_create_nodeset_from_sideset(part, - params.get_use_nodeset_for_sideset_node_fields(), - params.check_field_existence_when_creating_nodesets()); + bool createNodesets = should_create_nodeset_from_sideset(part, + params.get_use_nodeset_for_sideset_node_fields(), + params.check_field_existence_when_creating_nodesets()); stk::mesh::EntityRank rank = stk::topology::ELEM_RANK; const stk::mesh::PartVector &blocks = part.subsets(); if (blocks.size() > 0) { for (size_t j = 0; j < blocks.size(); j++) { - mesh::Part & side_block_part = *blocks[j]; - mesh::Selector selector = side_block_part; + mesh::Part & sideBlockPart = *blocks[j]; + mesh::Selector selector = sideBlockPart; if (params.get_subset_selector( )) selector &= *params.get_subset_selector(); if (params.get_output_selector(rank)) selector &= *params.get_output_selector(rank); define_side_block(params, selector, - side_block_part, sset, spatial_dimension, - create_nodesets); + sideBlockPart, sset, spatialDimension, + createNodesets); } } else { mesh::Selector selector = part; if (params.get_subset_selector( )) selector &= *params.get_subset_selector(); if (params.get_output_selector(rank)) selector &= *params.get_output_selector(rank); define_side_block(params, selector, - part, sset, spatial_dimension, - create_nodesets); + part, sset, spatialDimension, + createNodesets); } } @@ -2353,7 +2358,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta stk::mesh::PartVector leafParts = get_unique_leaf_parts(meta, assemblyPart.name()); for (stk::mesh::Part* leafPart : leafParts) { if (is_in_subsets_of_parts(*leafPart, leafParts)) {continue;} - if (is_valid_for_output(*leafPart, params.get_output_selector(leafPart->primary_entity_rank()))) { + if (is_valid_for_output(params, *leafPart)) { return true; } } @@ -2366,15 +2371,65 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta if (!assembly_has_valid_io_leaf_part(params, part)) { return; } - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); std::string name = getPartName(part); - Ioss::Assembly *assembly = io_region.get_assembly(name); + Ioss::Assembly *assembly = ioRegion.get_assembly(name); if (assembly == nullptr) { - assembly = new Ioss::Assembly(io_region.get_database(), name); + assembly = new Ioss::Assembly(ioRegion.get_database(), name); set_id_property(params, part, assembly); - io_region.add(assembly); + ioRegion.add(assembly); + } + } + + bool is_valid_assembly_member_type(const Ioss::Assembly *assem, const Ioss::GroupingEntity* member) + { + if(nullptr == member) return false; + + if((member->type() != Ioss::ELEMENTBLOCK) && (member->type() != Ioss::SIDESET) && + (member->type() != Ioss::NODESET) && (member->type() != Ioss::ASSEMBLY)) { + std::string filename = assem->get_database()->get_filename(); + stk::RuntimeWarning() << "The entity type of '" << member->name() << "' (" << member->type_string() << + ") is not a valid assembly member type for " + "assembly '" << assem->name() << "' (" << assem->contains_string() << + ").\n\t In the database file '" << filename << "'.\n"; + return false; + } + + return true; + } + + bool is_empty_element_block(stk::io::OutputParams ¶ms, const stk::mesh::Part* leafPart) + { + bool isEmptyElementBlock = false; + const std::unordered_map& blockSizes = params.get_block_sizes(); + + if(leafPart != nullptr && is_part_element_block_io_part(*leafPart)) { + auto iter = blockSizes.find(leafPart->mesh_meta_data_ordinal()); + ThrowRequireMsg(iter != blockSizes.end(), "Could not find element block in block size list: " << leafPart->name()); + isEmptyElementBlock = (iter->second == 0); + } + + return isEmptyElementBlock; + } + + bool can_add_to_assembly(stk::io::OutputParams ¶ms, const Ioss::Assembly *assembly, + const Ioss::GroupingEntity* leafEntity, const stk::mesh::Part* leafPart) + { + bool isNotCurrentMember = (leafEntity != nullptr) && (assembly->get_member(leafEntity->name()) == nullptr); + bool isValidMemberType = is_valid_assembly_member_type(assembly, leafEntity); + bool isEmptyElementBlock = false; + + bool filterEmptyBlocks = params.get_filter_empty_entity_blocks() || + params.get_filter_empty_assembly_entity_blocks(); + + if(filterEmptyBlocks) { + isEmptyElementBlock = is_empty_element_block(params, leafPart); } + + bool isValid = isNotCurrentMember && isValidMemberType && !isEmptyElementBlock; + + return isValid; } void define_assembly_hierarchy(stk::io::OutputParams ¶ms, @@ -2384,11 +2439,11 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return; } const stk::mesh::MetaData & meta = mesh::MetaData::get(part); - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); std::string name = getPartName(part); - Ioss::Assembly *assembly = io_region.get_assembly(name); + Ioss::Assembly *assembly = ioRegion.get_assembly(name); ThrowRequireMsg(assembly != nullptr, "Failed to find assembly "<get_member(subAssembly->name())==nullptr) { assembly->add(subAssembly); } @@ -2409,19 +2467,20 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta for(stk::mesh::Part* leafPart : leafParts) { if(is_in_subsets_of_parts(*leafPart, leafParts)) {continue;} std::string iossLeafPartName = getPartName(*leafPart); - const Ioss::GroupingEntity* leafEntity = io_region.get_entity(iossLeafPartName); + const Ioss::GroupingEntity* leafEntity = ioRegion.get_entity(iossLeafPartName); if (leafEntity == nullptr) { stk::RuntimeWarning() << "Failed to find ioss entity: '" << iossLeafPartName << "' in assembly: '" << name << "'"; } - if ((leafEntity != nullptr) && assembly->get_member(leafEntity->name()) == nullptr) { + if (can_add_to_assembly(params, assembly, leafEntity, leafPart)) { assembly->add(leafEntity); } } } if (assembly->member_count() == 0) { - io_region.remove(assembly); + ioRegion.remove(assembly); + delete assembly; } } @@ -2430,12 +2489,12 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta { mesh::MetaData & meta = mesh::MetaData::get(part); const stk::mesh::BulkData &bulk = params.bulk_data(); - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); stk::topology topo = part.topology(); if (topo == stk::topology::INVALID_TOPOLOGY) { std::ostringstream msg ; - msg << " INTERNAL_ERROR when defining output for region '"<get_use_generic_canonical_name(); - if(use_generic_canonical_name) { + bool useGenericCanonicalName = ioRegion.get_database()->get_use_generic_canonical_name(); + if(useGenericCanonicalName) { add_canonical_name_property(fb, part); } - bool use_original_topology = has_original_topology_type(part); - if(use_original_topology) { + bool useOriginalTopology = has_original_topology_type(part); + if(useOriginalTopology) { add_original_topology_property(fb, part); } } @@ -2477,7 +2536,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta delete_selector_property(fb); mesh::Selector *select = new mesh::Selector(selector); - fb->property_add(Ioss::Property(s_internal_selector_name, select)); + fb->property_add(Ioss::Property(s_internalSelectorName, select)); fb->property_add(Ioss::Property(base_stk_part_name, getPartName(part))); // Add the attribute fields. @@ -2489,12 +2548,12 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta { mesh::MetaData & meta = mesh::MetaData::get(part); const stk::mesh::BulkData &bulk = params.bulk_data(); - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); stk::topology topo = part.topology(); if (topo == stk::topology::INVALID_TOPOLOGY) { std::ostringstream msg ; - msg << " INTERNAL_ERROR when defining output for region '"<get_use_generic_canonical_name(); + bool use_generic_canonical_name = ioRegion.get_database()->get_use_generic_canonical_name(); if(use_generic_canonical_name) { add_canonical_name_property(eb, part); } @@ -2536,7 +2595,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta delete_selector_property(eb); mesh::Selector *select = new mesh::Selector(selector); - eb->property_add(Ioss::Property(s_internal_selector_name, select)); + eb->property_add(Ioss::Property(s_internalSelectorName, select)); eb->property_add(Ioss::Property(base_stk_part_name, getPartName(part))); // Add the attribute fields. @@ -2546,24 +2605,24 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void define_element_block(stk::io::OutputParams ¶ms, stk::mesh::Part &part, const std::vector> &attributeOrdering, - bool order_blocks_by_creation_order) + bool orderBlocksByCreationOrder) { mesh::MetaData & meta = mesh::MetaData::get(part); const stk::mesh::BulkData &bulk = params.bulk_data(); - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); stk::mesh::EntityRank rank = get_output_rank(params); mesh::Selector selector = impl::internal_build_selector(params.get_subset_selector(), - params.get_output_selector(rank), - nullptr, part, false); + params.get_output_selector(rank), + nullptr, part, false); - const size_t num_elems = stk::mesh::count_entities(bulk, rank, selector); + const size_t numElems = stk::mesh::count_entities(bulk, rank, selector); stk::topology topo = part.topology(); if (topo == stk::topology::INVALID_TOPOLOGY) { std::ostringstream msg ; - msg << " INTERNAL_ERROR when defining output for region '"<get_use_generic_canonical_name(); - if(use_generic_canonical_name) { + bool useGenericCanonicalName = ioRegion.get_database()->get_use_generic_canonical_name(); + if(useGenericCanonicalName) { add_canonical_name_property(eb, part); } - bool use_original_topology = has_original_topology_type(part); - if(use_original_topology && !params.has_skin_mesh_selector()) { + bool useOriginalTopology = has_original_topology_type(part); + if(useOriginalTopology && !params.has_skin_mesh_selector()) { add_original_topology_property(eb, part); } } - if (order_blocks_by_creation_order) + if (orderBlocksByCreationOrder) { int ordinal = part.mesh_meta_data_ordinal(); eb->property_update("original_block_order", ordinal); @@ -2615,7 +2674,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta delete_selector_property(eb); mesh::Selector *select = new mesh::Selector(selector); - eb->property_add(Ioss::Property(s_internal_selector_name, select)); + eb->property_add(Ioss::Property(s_internalSelectorName, select)); eb->property_add(Ioss::Property(base_stk_part_name, getPartName(part))); // Add the attribute fields. @@ -2629,27 +2688,27 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta // and output the fields on that nodeset... if (params.get_use_nodeset_for_block_node_fields() && will_output_lower_rank_fields(part, stk::topology::NODE_RANK)) { - std::string nodes_name = getPartName(part) + s_entity_nodes_suffix; + std::string nodesName = getPartName(part) + s_entityNodesSuffix; bool isDerivedNodeset = true; - define_node_set(params, part, nodes_name, isDerivedNodeset); + define_node_set(params, part, nodesName, isDerivedNodeset); } } void define_communication_maps(stk::io::OutputParams ¶ms) { const mesh::BulkData & bulk = params.bulk_data(); - Ioss::Region & io_region = params.io_region(); + Ioss::Region & ioRegion = params.io_region(); mesh::EntityRank rank = get_output_rank(params); - const stk::mesh::Selector *subset_selector = params.get_subset_selector(); - const stk::mesh::Selector *output_selector = params.get_output_selector(rank); + const stk::mesh::Selector *subsetSelector = params.get_subset_selector(); + const stk::mesh::Selector *outputSelector = params.get_output_selector(rank); if (bulk.parallel_size() > 1) { const stk::mesh::MetaData & meta = bulk.mesh_meta_data(); - const std::string cs_name("node_symm_comm_spec"); + const std::string csName("node_symm_comm_spec"); mesh::Selector selector = meta.globally_shared_part(); - if (subset_selector) selector &= *subset_selector; - if (output_selector) selector &= *output_selector; + if (subsetSelector) selector &= *subsetSelector; + if (outputSelector) selector &= *outputSelector; std::vector entities; get_selected_nodes(params, selector, entities); @@ -2661,31 +2720,30 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta size+=sharingProcs.size(); } - Ioss::DatabaseIO *dbo = io_region.get_database(); - Ioss::CommSet *io_cs = new Ioss::CommSet(dbo, cs_name, "node", size); - io_region.add(io_cs); + Ioss::DatabaseIO *dbo = ioRegion.get_database(); + Ioss::CommSet *ioCs = new Ioss::CommSet(dbo, csName, "node", size); + ioRegion.add(ioCs); - delete_selector_property(io_cs); + delete_selector_property(ioCs); mesh::Selector *select = new mesh::Selector(selector); - io_cs->property_add(Ioss::Property(s_internal_selector_name, select)); + ioCs->property_add(Ioss::Property(s_internalSelectorName, select)); // Update global node and element count... - if (!io_region.property_exists("global_node_count") || !io_region.property_exists("global_element_count")) { + if (!ioRegion.property_exists("global_node_count") || !ioRegion.property_exists("global_element_count")) { std::vector entityCounts; stk::mesh::comm_mesh_counts(bulk, entityCounts); - io_region.property_add(Ioss::Property("global_node_count", static_cast(entityCounts[stk::topology::NODE_RANK]))); - io_region.property_add(Ioss::Property("global_element_count", static_cast(entityCounts[stk::topology::ELEMENT_RANK]))); + ioRegion.property_add(Ioss::Property("global_node_count", static_cast(entityCounts[stk::topology::NODE_RANK]))); + ioRegion.property_add(Ioss::Property("global_element_count", static_cast(entityCounts[stk::topology::ELEMENT_RANK]))); } } } - void define_side_set(stk::io::OutputParams ¶ms, - stk::mesh::Part &part) + void define_side_set(stk::io::OutputParams ¶ms, stk::mesh::Part &part) { - const stk::mesh::EntityRank si_rank = mesh::MetaData::get(part).side_rank(); + const stk::mesh::EntityRank sideRank = mesh::MetaData::get(part).side_rank(); - bool create_sideset = ! params.has_skin_mesh_selector(); + bool createSideset = ! params.has_skin_mesh_selector(); if (part.subsets().empty()) { // Only define a sideset for this part if its superset part is // not a side-containing part.. (i.e., this part is not a subset part @@ -2693,23 +2751,23 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta const stk::mesh::PartVector &supersets = part.supersets(); for (size_t i=0; i < supersets.size(); i++) { if (is_part_surface_io_part(*supersets[i])) { - create_sideset = false; + createSideset = false; break; } } } - if (create_sideset) { + if (createSideset) { std::string name = getPartName(part); - Ioss::Region & io_region = params.io_region(); - Ioss::SideSet *ss = io_region.get_sideset(name); + Ioss::Region & ioRegion = params.io_region(); + Ioss::SideSet *ss = ioRegion.get_sideset(name); if(ss == nullptr) { - ss = new Ioss::SideSet(io_region.get_database(), name); - io_region.add(ss); + ss = new Ioss::SideSet(ioRegion.get_database(), name); + ioRegion.add(ss); - bool use_generic_canonical_name = io_region.get_database()->get_use_generic_canonical_name(); - if(use_generic_canonical_name) { + bool useGenericCanonicalName = ioRegion.get_database()->get_use_generic_canonical_name(); + if(useGenericCanonicalName) { add_canonical_name_property(ss, part); } } @@ -2721,14 +2779,14 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta ss->property_add(Ioss::Property("id", part.id())); } - int spatial_dim = io_region.get_property("spatial_dimension").get_int(); - define_side_blocks(params, part, ss, si_rank, spatial_dim); + int spatialDim = ioRegion.get_property("spatial_dimension").get_int(); + define_side_blocks(params, part, ss, sideRank, spatialDim); } } } // namespace - void set_element_block_order(const mesh::PartVector *parts, Ioss::Region & io_region) + void set_element_block_order(const mesh::PartVector *parts, Ioss::Region & ioRegion) { int64_t offset=0; @@ -2738,9 +2796,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta if (is_part_io_part(*part) && (part->primary_entity_rank() == stk::topology::ELEMENT_RANK)) { if(has_original_block_order(*part)) { int64_t order = get_original_block_order(*part); - Ioss::GroupingEntity *element_block = io_region.get_entity(getPartName(*part)); - if (element_block) { - element_block->property_update("original_block_order", order); + Ioss::GroupingEntity *elementBlock = ioRegion.get_entity(getPartName(*part)); + if (elementBlock) { + elementBlock->property_update("original_block_order", order); offset = std::max(offset, order); } } @@ -2753,10 +2811,10 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta mesh::Part * const part = *i ; if (is_part_io_part(*part) && (part->primary_entity_rank() == stk::topology::ELEMENT_RANK)) { - Ioss::GroupingEntity *element_block = io_region.get_entity(getPartName(*part)); - if (element_block) { - if (!element_block->property_exists("original_block_order")) { - element_block->property_add(Ioss::Property("original_block_order", offset)); + Ioss::GroupingEntity *elementBlock = ioRegion.get_entity(getPartName(*part)); + if (elementBlock) { + if (!elementBlock->property_exists("original_block_order")) { + elementBlock->property_add(Ioss::Property("original_block_order", offset)); ++offset; } } @@ -2819,37 +2877,37 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void define_output_db_within_state_define(stk::io::OutputParams ¶ms, const std::vector> &attributeOrdering, - const Ioss::Region *input_region = nullptr) + const Ioss::Region *inputRegion = nullptr) { - Ioss::Region & io_region = params.io_region(); - const mesh::BulkData &bulk_data = params.bulk_data(); - const bool sort_stk_parts_by_name = params.get_sort_stk_parts_by_name(); + Ioss::Region & ioRegion = params.io_region(); + const mesh::BulkData &bulkData = params.bulk_data(); + const bool sortStkPartsByName = params.get_sort_stk_parts_by_name(); - const mesh::MetaData & meta_data = bulk_data.mesh_meta_data(); - define_node_block(params, meta_data.universal_part()); + const mesh::MetaData & metaData = bulkData.mesh_meta_data(); + define_node_block(params, metaData.universal_part()); // All parts of the meta data: const mesh::PartVector *parts = nullptr; - mesh::PartVector all_parts_sorted; + mesh::PartVector allPartsSorted; - const mesh::PartVector & all_parts = meta_data.get_parts(); + const mesh::PartVector & allParts = metaData.get_parts(); // sort parts so they go out the same on all processors (srk: this was induced by streaming refine) - if (sort_stk_parts_by_name) { - all_parts_sorted = all_parts; - std::sort(all_parts_sorted.begin(), all_parts_sorted.end(), part_compare_by_name()); - parts = &all_parts_sorted; + if (sortStkPartsByName) { + allPartsSorted = allParts; + std::sort(allPartsSorted.begin(), allPartsSorted.end(), part_compare_by_name()); + parts = &allPartsSorted; } else { - parts = &all_parts; + parts = &allParts; } - const bool order_blocks_by_creation_order = (input_region == nullptr) && !sort_stk_parts_by_name; - const int spatialDim = meta_data.spatial_dimension(); + const bool orderBlocksByCreationOrder = (inputRegion == nullptr) && !sortStkPartsByName; + const int spatialDim = metaData.spatial_dimension(); for (stk::mesh::Part* const part : *parts) { const stk::mesh::EntityRank rank = part->primary_entity_rank(); if (is_part_io_part(*part)) { - bool isValidForOutput = is_valid_for_output(*part, params.get_output_selector(rank)); + bool isValidForOutput = is_valid_for_output(params, *part); if (is_part_assembly_io_part(*part)) { define_assembly(params, *part); @@ -2861,7 +2919,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta define_node_set(params, *part, getPartName(*part)); } else if ((rank == stk::topology::ELEMENT_RANK) && isValidForOutput) { - define_element_block(params, *part, attributeOrdering, order_blocks_by_creation_order); + define_element_block(params, *part, attributeOrdering, orderBlocksByCreationOrder); } else if (is_part_face_block_io_part(*part)) { define_face_block(params, *part); @@ -2884,18 +2942,18 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta define_communication_maps(params); - if (input_region != nullptr) - io_region.synchronize_id_and_name(input_region, true); + if (inputRegion != nullptr) + ioRegion.synchronize_id_and_name(inputRegion, true); - set_element_block_order(parts, io_region); + set_element_block_order(parts, ioRegion); } void define_output_db(stk::io::OutputParams ¶ms, const std::vector> &attributeOrdering, - const Ioss::Region *input_region) + const Ioss::Region *inputRegion) { params.io_region().begin_mode( Ioss::STATE_DEFINE_MODEL ); - define_output_db_within_state_define(params, attributeOrdering, input_region); + define_output_db_within_state_define(params, attributeOrdering, inputRegion); params.io_region().end_mode( Ioss::STATE_DEFINE_MODEL ); } @@ -2906,35 +2964,35 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void write_side_data_to_ioss( stk::io::OutputParams ¶ms, Ioss::GroupingEntity & io , mesh::Part * const part , - const Ioss::ElementTopology *element_topology) + const Ioss::ElementTopology *elementTopology) { - std::vector elem_side_ids; + std::vector elemSideIds; stk::mesh::EntityVector sides; - fill_data_for_side_block(params, io, part, element_topology, elem_side_ids, sides); - size_t num_sides = sides.size(); + fill_data_for_side_block(params, io, part, elementTopology, elemSideIds, sides); + size_t numSides = sides.size(); - const size_t num_side_written = io.put_field_data("element_side",elem_side_ids); + const size_t numSideWritten = io.put_field_data("element_side",elemSideIds); - if ( num_sides != num_side_written ) { + if ( numSides != numSideWritten ) { std::ostringstream msg ; msg << "stk::io::write_side_data_to_ioss FAILED for " ; msg << io.name(); msg << " in Ioss::GroupingEntity::put_field_data:" ; - msg << " num_sides = " << num_sides ; - msg << " , num_side_written = " << num_side_written ; + msg << " numSides = " << numSides ; + msg << " , num_side_written = " << numSideWritten ; throw std::runtime_error( msg.str() ); } const mesh::FieldBase *df = get_distribution_factor_field(*part); if (df != nullptr) { - field_data_to_ioss(params.bulk_data(), df, sides, &io, s_distribution_factors, Ioss::Field::MESH); + field_data_to_ioss(params.bulk_data(), df, sides, &io, s_distributionFactors, Ioss::Field::MESH); } - const mesh::MetaData & meta_data = mesh::MetaData::get(*part); + const mesh::MetaData & metaData = mesh::MetaData::get(*part); - const std::vector &fields = meta_data.get_fields(); + const std::vector &fields = metaData.get_fields(); std::vector::const_iterator I = fields.begin(); while (I != fields.end()) { const mesh::FieldBase *f = *I ; ++I ; @@ -2962,38 +3020,40 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta stk::mesh::EntityRank rank = get_output_rank(params); std::vector nodes; - size_t num_nodes = get_entities_for_nodeblock(params, part, rank, + size_t numNodes = get_entities_for_nodeblock(params, part, rank, nodes, true); - std::vector node_ids; node_ids.reserve(num_nodes); - for(size_t i=0; i nodeIds; + nodeIds.reserve(numNodes); + for(size_t i=0; i &fields = meta_data.get_fields(); + const std::vector &fields = metaData.get_fields(); std::vector::const_iterator I = fields.begin(); while (I != fields.end()) { const mesh::FieldBase *f = *I ; ++I ; @@ -3005,41 +3065,41 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } std::pair - get_parent_element(stk::io::OutputParams ¶ms, stk::mesh::Entity obj, const stk::mesh::Part* parent_block = nullptr) + get_parent_element(stk::io::OutputParams ¶ms, stk::mesh::Entity obj, const stk::mesh::Part* parentBlock = nullptr) { std::pair parent(stk::mesh::Entity(), 0U); const stk::mesh::BulkData& stkmesh = params.bulk_data(); - const stk::topology obj_topology = stkmesh.bucket(obj).topology(); + const stk::topology objTopology = stkmesh.bucket(obj).topology(); const stk::mesh::Entity* elems = stkmesh.begin_elements(obj); - const stk::mesh::ConnectivityOrdinal* elem_ordinals = stkmesh.begin_element_ordinals(obj); - const stk::mesh::Permutation* elem_permutations = stkmesh.begin_element_permutations(obj); + const stk::mesh::ConnectivityOrdinal* elemOrdinals = stkmesh.begin_element_ordinals(obj); + const stk::mesh::Permutation* elemPermutations = stkmesh.begin_element_permutations(obj); const stk::mesh::Selector* subsetSelector = params.get_subset_selector(); bool activeOnly = subsetSelector != nullptr; for(unsigned ielem = 0, e = stkmesh.num_elements(obj); ielem < e; ++ielem) { stk::mesh::Entity elem = elems[ielem]; - unsigned elem_side_ordinal = elem_ordinals[ielem]; + unsigned elemSideOrdinal = elemOrdinals[ielem]; stk::mesh::Bucket &elemBucket = stkmesh.bucket(elem); if(stkmesh.bucket(elem).owned() && (!activeOnly || (activeOnly && (*subsetSelector)(elemBucket)))) { - if((parent_block == nullptr && obj_topology.is_positive_polarity(elem_permutations[ielem])) || - (parent_block != nullptr && contain(stkmesh, elem, parent_block))) { + if((parentBlock == nullptr && objTopology.is_positive_polarity(elemPermutations[ielem])) || + (parentBlock != nullptr && contain(stkmesh, elem, parentBlock))) { if(params.has_output_selector(stk::topology::ELEMENT_RANK) && !params.get_is_restart()) { // See if elem is a member of any of the includedMeshBlocks. const stk::mesh::Selector* outputSelector = params.get_output_selector(stk::topology::ELEMENT_RANK); if((*outputSelector)(elemBucket)) { parent.first = elem; - parent.second = elem_side_ordinal; + parent.second = elemSideOrdinal; return parent; } return parent; } else { parent.first = elem; - parent.second = elem_side_ordinal; + parent.second = elemSideOrdinal; } return parent; } @@ -3053,8 +3113,8 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta const std::vector& meshObjects) { const stk::mesh::BulkData& stkmesh = params.bulk_data(); - bool skin_mesh = params.has_skin_mesh_selector(); - if(!skin_mesh) return; // This map only supported for skinning the mesh. + bool skinMesh = params.has_skin_mesh_selector(); + if(!skinMesh) return; // This map only supported for skinning the mesh. size_t entitySize = block->get_property("entity_count").get_int(); if(!block->field_exists("skin")) { @@ -3062,36 +3122,36 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } size_t count = block->get_field("skin").raw_count(); - int map_size = block->get_field("skin").get_size(); - std::vector elem_face(map_size); + int mapSize = block->get_field("skin").get_size(); + std::vector elemFace(mapSize); if(count > 0) { // global element id + local face of that element. size_t i = 0; - size_t face_count = meshObjects.size(); - assert(face_count == count); - for(size_t j = 0; j < face_count; j++) { + size_t faceCount = meshObjects.size(); + assert(faceCount == count); + for(size_t j = 0; j < faceCount; j++) { stk::mesh::Entity face = meshObjects[j]; - std::pair elem_face_pair = get_parent_element(params, face); - if(stkmesh.is_valid(elem_face_pair.first)) { - elem_face[i++] = stkmesh.identifier(elem_face_pair.first); - elem_face[i++] = elem_face_pair.second + 1; + std::pair elemFacePair = get_parent_element(params, face); + if(stkmesh.is_valid(elemFacePair.first)) { + elemFace[i++] = stkmesh.identifier(elemFacePair.first); + elemFace[i++] = elemFacePair.second + 1; } } assert(i == 2 * count); } - block->put_field_data("skin", elem_face.data(), map_size); + block->put_field_data("skin", elemFace.data(), mapSize); } template void output_element_block(stk::io::OutputParams ¶ms, Ioss::ElementBlock *block) { const stk::mesh::BulkData &bulk = params.bulk_data(); - const stk::mesh::MetaData & meta_data = bulk.mesh_meta_data(); + const stk::mesh::MetaData & metaData = bulk.mesh_meta_data(); const std::string& name = block->name(); - mesh::Part* part = getPart( meta_data, name); + mesh::Part* part = getPart( metaData, name); assert(part != nullptr); stk::topology topo = part->topology(); @@ -3102,52 +3162,52 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta std::vector elements; stk::mesh::EntityRank type = part_primary_entity_rank(*part); if (params.has_skin_mesh_selector()) { - type = meta_data.side_rank(); + type = metaData.side_rank(); } - size_t num_elems = get_entities(params, *part, type, elements, false); + size_t numElems = get_entities(params, *part, type, elements, false); - if (num_elems > 0 && topo == stk::topology::INVALID_TOPOLOGY) { + if (numElems > 0 && topo == stk::topology::INVALID_TOPOLOGY) { std::ostringstream msg ; msg << " INTERNAL_ERROR: Part " << part->name() << " returned INVALID from get_topology()"; throw std::runtime_error( msg.str() ); } - size_t nodes_per_elem = block->get_property("topology_node_count").get_int(); + size_t nodesPerElem = block->get_property("topology_node_count").get_int(); - std::vector elem_ids; elem_ids.reserve(num_elems == 0 ? 1 : num_elems); - std::vector connectivity; connectivity.reserve( (num_elems*nodes_per_elem) == 0 ? 1 : (num_elems*nodes_per_elem)); + std::vector elemIds; + elemIds.reserve(numElems == 0 ? 1 : numElems); + std::vector connectivity; + connectivity.reserve( (numElems*nodesPerElem) == 0 ? 1 : (numElems*nodesPerElem)); - for (size_t i = 0; i < num_elems; ++i) { + for (size_t i = 0; i < numElems; ++i) { + elemIds.push_back(bulk.identifier(elements[i])); + stk::mesh::Entity const * elemNodes = bulk.begin_nodes(elements[i]); - elem_ids.push_back(bulk.identifier(elements[i])); - - stk::mesh::Entity const * elem_nodes = bulk.begin_nodes(elements[i]); - - for (size_t j = 0; j < nodes_per_elem; ++j) { - connectivity.push_back(bulk.identifier(elem_nodes[j])); + for (size_t j = 0; j < nodesPerElem; ++j) { + connectivity.push_back(bulk.identifier(elemNodes[j])); } } - const size_t num_ids_written = block->put_field_data("ids", elem_ids); - const size_t num_con_written = block->put_field_data("connectivity", connectivity); + const size_t numIdsWritten = block->put_field_data("ids", elemIds); + const size_t numConWritten = block->put_field_data("connectivity", connectivity); - if ( num_elems != num_ids_written || num_elems != num_con_written ) { + if ( numElems != numIdsWritten || numElems != numConWritten ) { std::ostringstream msg ; msg << " FAILED in Ioss::ElementBlock::put_field_data:" << std::endl ; - msg << " num_elems = " << num_elems << std::endl ; - msg << " num_ids_written = " << num_ids_written << std::endl ; - msg << " num_connectivity_written = " << num_con_written << std::endl ; + msg << " numElems = " << numElems << std::endl ; + msg << " numIdsWritten = " << numIdsWritten << std::endl ; + msg << " num_connectivity_written = " << numConWritten << std::endl ; throw std::runtime_error( msg.str() ); } - stk::mesh::EntityRank elem_rank = stk::topology::ELEMENT_RANK; - const std::vector &fields = meta_data.get_fields(); + stk::mesh::EntityRank elemRank = stk::topology::ELEMENT_RANK; + const std::vector &fields = metaData.get_fields(); std::vector::const_iterator I = fields.begin(); while (I != fields.end()) { const mesh::FieldBase *f = *I ; ++I ; const Ioss::Field::RoleType *role = stk::io::get_field_role(*f); if (role != nullptr && *role == Ioss::Field::ATTRIBUTE) { - const mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*f, elem_rank, *part); + const mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*f, elemRank, *part); if (res.num_scalars_per_entity() > 0) { stk::io::field_data_to_ioss(bulk, f, elements, block, f->name(), Ioss::Field::ATTRIBUTE); } @@ -3167,19 +3227,19 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta { const stk::mesh::MetaData & metaData = bulk.mesh_meta_data(); const std::string& name = ns->name(); - const std::string dfName = s_distribution_factors + "_" + name; - stk::mesh::Field* df_field = metaData.get_field(stk::topology::NODE_RANK, dfName); + const std::string dfName = s_distributionFactors + "_" + name; + stk::mesh::Field* dfField = metaData.get_field(stk::topology::NODE_RANK, dfName); - if(df_field != nullptr) { - const stk::mesh::FieldBase::Restriction& res = stk::mesh::find_restriction(*df_field, stk::topology::NODE_RANK, *part); + if(dfField != nullptr) { + const stk::mesh::FieldBase::Restriction& res = stk::mesh::find_restriction(*dfField, stk::topology::NODE_RANK, *part); if(res.num_scalars_per_entity() > 0) { - stk::io::field_data_to_ioss(bulk, df_field, nodes, ns, s_distribution_factors, Ioss::Field::MESH); + stk::io::field_data_to_ioss(bulk, dfField, nodes, ns, s_distributionFactors, Ioss::Field::MESH); } } else { - assert(ns->field_exists(s_distribution_factors)); - size_t df_size = ns->get_field(s_distribution_factors).raw_count(); + assert(ns->field_exists(s_distributionFactors)); + size_t dfSize = ns->get_field(s_distributionFactors).raw_count(); std::vector df; - df.reserve(df_size); + df.reserve(dfSize); const auto* const nodeFactorVar = get_distribution_factor_field(*part); if((nodeFactorVar != nullptr) && (nodeFactorVar->entity_rank() == stk::topology::NODE_RANK)) { nodeFactorVar->sync_to_host(); @@ -3192,7 +3252,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta df.push_back(1.0); } } - ns->put_field_data(s_distribution_factors, df); + ns->put_field_data(s_distributionFactors, df); } } @@ -3200,9 +3260,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void output_node_set(stk::io::OutputParams ¶ms, Ioss::NodeSet *ns) { const stk::mesh::BulkData &bulk = params.bulk_data(); - const stk::mesh::MetaData & meta_data = bulk.mesh_meta_data(); + const stk::mesh::MetaData & metaData = bulk.mesh_meta_data(); const std::string& name = ns->name(); - mesh::Part* part = getPart( meta_data, name); + mesh::Part* part = getPart( metaData, name); // If part is null, then it is possible that this nodeset is a "viz nodeset" which // means that it is a nodeset containing the nodes of an element block. @@ -3210,8 +3270,8 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta // that name. if (part == nullptr) { if (ns->property_exists(base_stk_part_name)) { - std::string base_name = ns->get_property(base_stk_part_name).get_string(); - part = getPart( meta_data, base_name); + std::string baseName = ns->get_property(base_stk_part_name).get_string(); + part = getPart( metaData, baseName); } if (part == nullptr) { std::ostringstream msg ; @@ -3224,26 +3284,27 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta std::vector nodes; mesh::EntityRank rank = get_output_rank(params); - size_t num_nodes = get_entities_for_nodeblock(params, *part, rank, nodes, true); + size_t numNodes = get_entities_for_nodeblock(params, *part, rank, nodes, true); - std::vector node_ids; node_ids.reserve(num_nodes); - for(size_t i=0; i node_ids; + node_ids.reserve(numNodes); + for(size_t i=0; iput_field_data("ids", node_ids); - if ( num_nodes != num_ids_written ) { + size_t numIdsWritten = ns->put_field_data("ids", node_ids); + if ( numNodes != numIdsWritten ) { std::ostringstream msg ; msg << " FAILED in Ioss::NodeSet::output_node_set:" - << " num_nodes = " << num_nodes - << ", num_ids_written = " << num_ids_written; + << " numNodes = " << numNodes + << ", numIdsWritten = " << numIdsWritten; throw std::runtime_error( msg.str() ); } output_nodeset_distribution_factor(bulk, ns, part, nodes); - const std::vector &fields = meta_data.get_fields(); + const std::vector &fields = metaData.get_fields(); std::vector::const_iterator I = fields.begin(); while (I != fields.end()) { const mesh::FieldBase *f = *I ; ++I ; @@ -3260,28 +3321,28 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta template void output_communication_maps(stk::io::OutputParams ¶ms) { - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); const stk::mesh::BulkData &bulk = params.bulk_data(); mesh::EntityRank rank = get_output_rank(params); - const stk::mesh::Selector *subset_selector = params.get_subset_selector(); - const stk::mesh::Selector *output_selector = params.get_output_selector(rank); + const stk::mesh::Selector *subsetSelector = params.get_subset_selector(); + const stk::mesh::Selector *outputSelector = params.get_output_selector(rank); if (bulk.parallel_size() > 1) { const stk::mesh::MetaData & meta = bulk.mesh_meta_data(); mesh::Selector selector = meta.globally_shared_part(); - if (subset_selector) selector &= *subset_selector; - if (output_selector) selector &= *output_selector; + if (subsetSelector) selector &= *subsetSelector; + if (outputSelector) selector &= *outputSelector; std::vector entities; get_selected_nodes(params, selector, entities); - const std::string cs_name("node_symm_comm_spec"); - Ioss::CommSet * io_cs = io_region.get_commset(cs_name); - STKIORequire(io_cs != nullptr); + const std::string csName("node_symm_comm_spec"); + Ioss::CommSet * ioCs = ioRegion.get_commset(csName); + STKIORequire(ioCs != nullptr); // Allocate data space to store pair - assert(io_cs->field_exists("entity_processor")); - size_t size = io_cs->get_field("entity_processor").raw_count(); + assert(ioCs->field_exists("entity_processor")); + size_t size = ioCs->get_field("entity_processor").raw_count(); std::vector ep; ep.reserve(size*2); @@ -3295,7 +3356,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } } assert(size*2 == ep.size()); - io_cs->put_field_data("entity_processor", ep); + ioCs->put_field_data("entity_processor", ep); } } @@ -3304,14 +3365,13 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta { const stk::mesh::MetaData & meta = params.bulk_data().mesh_meta_data(); - size_t block_count = ss->block_count(); - for (size_t i=0; i < block_count; i++) { + size_t blockCount = ss->block_count(); + for (size_t i=0; i < blockCount; i++) { Ioss::SideBlock *block = ss->get_block(i); if (stk::io::include_entity(block)) { stk::mesh::Part * part = getPart(meta, block->name()); const Ioss::ElementTopology *parent_topology = block->parent_element_topology(); - stk::io::write_side_data_to_ioss(params, *block, part, - parent_topology); + stk::io::write_side_data_to_ioss(params, *block, part, parent_topology); } } } @@ -3320,9 +3380,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void output_face_block(stk::io::OutputParams ¶ms, Ioss::FaceBlock *fb) { const stk::mesh::BulkData &bulk = params.bulk_data(); - const stk::mesh::MetaData & meta_data = bulk.mesh_meta_data(); + const stk::mesh::MetaData & metaData = bulk.mesh_meta_data(); const std::string& name = fb->name(); - mesh::Part* part = getPart( meta_data, name); + mesh::Part* part = getPart( metaData, name); assert(part != nullptr); stk::topology topo = part->topology(); @@ -3334,42 +3394,41 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta std::vector faces; stk::mesh::EntityRank type = part_primary_entity_rank(*part); - size_t num_faces = get_entities(params, *part, type, faces, false); - - size_t nodes_per_face = fb->get_property("topology_node_count").get_int(); - - std::vector face_ids; face_ids.reserve(num_faces == 0 ? 1 : num_faces); - std::vector connectivity; connectivity.reserve( (num_faces*nodes_per_face) == 0 ? 1 : (num_faces*nodes_per_face)); + size_t numFaces = get_entities(params, *part, type, faces, false); - for (size_t i = 0; i < num_faces; ++i) { + size_t nodesPerFace = fb->get_property("topology_node_count").get_int(); - face_ids.push_back(bulk.identifier(faces[i])); + std::vector faceIds; + faceIds.reserve(numFaces == 0 ? 1 : numFaces); + std::vector connectivity; connectivity.reserve( (numFaces*nodesPerFace) == 0 ? 1 : (numFaces*nodesPerFace)); - stk::mesh::Entity const * face_nodes = bulk.begin_nodes(faces[i]); + for (size_t i = 0; i < numFaces; ++i) { + faceIds.push_back(bulk.identifier(faces[i])); + stk::mesh::Entity const * faceNodes = bulk.begin_nodes(faces[i]); - for (size_t j = 0; j < nodes_per_face; ++j) { - connectivity.push_back(bulk.identifier(face_nodes[j])); + for (size_t j = 0; j < nodesPerFace; ++j) { + connectivity.push_back(bulk.identifier(faceNodes[j])); } } - const size_t num_ids_written = fb->put_field_data("ids", face_ids); - const size_t num_con_written = fb->put_field_data("connectivity", connectivity); + const size_t numIdsWritten = fb->put_field_data("ids", faceIds); + const size_t numConWritten = fb->put_field_data("connectivity", connectivity); - if ( num_faces != num_ids_written || num_faces != num_con_written ) { + if ( numFaces != numIdsWritten || numFaces != numConWritten ) { std::ostringstream msg ; msg << " FAILED in Ioss::FaceBlock::put_field_data:" << std::endl ; - msg << " num_faces = " << num_faces << std::endl ; - msg << " num_ids_written = " << num_ids_written << std::endl ; - msg << " num_connectivity_written = " << num_con_written << std::endl ; + msg << " numFaces = " << numFaces << std::endl ; + msg << " numIdsWritten = " << numIdsWritten << std::endl ; + msg << " num_connectivity_written = " << numConWritten << std::endl ; throw std::runtime_error( msg.str() ); } - stk::mesh::EntityRank face_rank = stk::topology::FACE_RANK; - const std::vector &fields = meta_data.get_fields(); + stk::mesh::EntityRank faceRank = stk::topology::FACE_RANK; + const std::vector &fields = metaData.get_fields(); for(const mesh::FieldBase* f : fields) { const Ioss::Field::RoleType *role = stk::io::get_field_role(*f); if (role != nullptr && *role == Ioss::Field::ATTRIBUTE) { - const mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*f, face_rank, *part); + const mesh::FieldBase::Restriction &res = stk::mesh::find_restriction(*f, faceRank, *part); if (res.num_scalars_per_entity() > 0) { stk::io::field_data_to_ioss(bulk, f, faces, fb, f->name(), Ioss::Field::ATTRIBUTE); } @@ -3381,9 +3440,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void output_edge_block(stk::io::OutputParams ¶ms, Ioss::EdgeBlock *eb) { const stk::mesh::BulkData &bulk = params.bulk_data(); - const stk::mesh::MetaData & meta_data = bulk.mesh_meta_data(); + const stk::mesh::MetaData & metaData = bulk.mesh_meta_data(); const std::string& name = eb->name(); - mesh::Part* part = getPart( meta_data, name); + mesh::Part* part = getPart( metaData, name); assert(part != nullptr); stk::topology topo = part->topology(); @@ -3395,38 +3454,38 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta std::vector edges; stk::mesh::EntityRank type = part_primary_entity_rank(*part); - size_t num_edges = get_entities(params, *part, type, edges, false); + size_t numEdges = get_entities(params, *part, type, edges, false); - size_t nodes_per_edge = eb->get_property("topology_node_count").get_int(); + size_t nodesPerEdge = eb->get_property("topology_node_count").get_int(); - std::vector edge_ids; edge_ids.reserve(num_edges == 0 ? 1 : num_edges); - std::vector connectivity; connectivity.reserve( (num_edges*nodes_per_edge) == 0 ? 1 : (num_edges*nodes_per_edge)); + std::vector edgeIds; + edgeIds.reserve(numEdges == 0 ? 1 : numEdges); + std::vector connectivity; + connectivity.reserve( (numEdges*nodesPerEdge) == 0 ? 1 : (numEdges*nodesPerEdge)); - for (size_t i = 0; i < num_edges; ++i) { + for (size_t i = 0; i < numEdges; ++i) { + edgeIds.push_back(bulk.identifier(edges[i])); + stk::mesh::Entity const * edgeNodes = bulk.begin_nodes(edges[i]); - edge_ids.push_back(bulk.identifier(edges[i])); - - stk::mesh::Entity const * edge_nodes = bulk.begin_nodes(edges[i]); - - for (size_t j = 0; j < nodes_per_edge; ++j) { - connectivity.push_back(bulk.identifier(edge_nodes[j])); + for (size_t j = 0; j < nodesPerEdge; ++j) { + connectivity.push_back(bulk.identifier(edgeNodes[j])); } } - const size_t num_ids_written = eb->put_field_data("ids", edge_ids); - const size_t num_con_written = eb->put_field_data("connectivity", connectivity); + const size_t numIdsWritten = eb->put_field_data("ids", edgeIds); + const size_t numConWritten = eb->put_field_data("connectivity", connectivity); - if ( num_edges != num_ids_written || num_edges != num_con_written ) { + if ( numEdges != numIdsWritten || numEdges != numConWritten ) { std::ostringstream msg ; msg << " FAILED in Ioss::EdgeBlock::put_field_data:" << std::endl ; - msg << " num_edges = " << num_edges << std::endl ; - msg << " num_ids_written = " << num_ids_written << std::endl ; - msg << " num_connectivity_written = " << num_con_written << std::endl ; + msg << " numEdges = " << numEdges << std::endl ; + msg << " numIdsWritten = " << numIdsWritten << std::endl ; + msg << " num_connectivity_written = " << numConWritten << std::endl ; throw std::runtime_error( msg.str() ); } stk::mesh::EntityRank edge_rank = stk::topology::EDGE_RANK; - const std::vector &fields = meta_data.get_fields(); + const std::vector &fields = metaData.get_fields(); for(const mesh::FieldBase* f : fields) { const Ioss::Field::RoleType *role = stk::io::get_field_role(*f); if (role != nullptr && *role == Ioss::Field::ATTRIBUTE) { @@ -3442,11 +3501,11 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void write_output_db_node_block(stk::io::OutputParams ¶ms) { const stk::mesh::MetaData & meta = params.bulk_data().mesh_meta_data(); - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); - bool ints64bit = db_api_int_size(&io_region) == 8; + bool ints64bit = db_api_int_size(&ioRegion) == 8; - Ioss::NodeBlock & nb = *io_region.get_node_blocks()[0]; + Ioss::NodeBlock & nb = *ioRegion.get_node_blocks()[0]; if (ints64bit) output_node_block(params, nb, meta.universal_part()); @@ -3456,13 +3515,13 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void write_output_db_element_blocks(stk::io::OutputParams ¶ms) { - Ioss::Region &io_region = params.io_region(); - bool ints64bit = db_api_int_size(&io_region) == 8; + Ioss::Region &ioRegion = params.io_region(); + bool ints64bit = db_api_int_size(&ioRegion) == 8; //---------------------------------- - const Ioss::ElementBlockContainer& elem_blocks = io_region.get_element_blocks(); - for(Ioss::ElementBlockContainer::const_iterator it = elem_blocks.begin(); - it != elem_blocks.end(); ++it) { + const Ioss::ElementBlockContainer& elemBlocks = ioRegion.get_element_blocks(); + for(Ioss::ElementBlockContainer::const_iterator it = elemBlocks.begin(); + it != elemBlocks.end(); ++it) { if (ints64bit) output_element_block(params, *it); else @@ -3473,21 +3532,21 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta template void write_output_db_for_entitysets_and_comm_map(stk::io::OutputParams ¶ms) { - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); - for(Ioss::NodeSet *ns : io_region.get_nodesets()) { + for(Ioss::NodeSet *ns : ioRegion.get_nodesets()) { output_node_set(params, ns); } - for(Ioss::SideSet *ss : io_region.get_sidesets()) { + for(Ioss::SideSet *ss : ioRegion.get_sidesets()) { output_side_set(params, ss); } - for(Ioss::EdgeBlock *eb: io_region.get_edge_blocks()) { + for(Ioss::EdgeBlock *eb: ioRegion.get_edge_blocks()) { output_edge_block(params, eb); } - for(Ioss::FaceBlock *fb: io_region.get_face_blocks()) { + for(Ioss::FaceBlock *fb: ioRegion.get_face_blocks()) { output_face_block(params, fb); } @@ -3496,11 +3555,11 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void write_output_db_rest_of_mesh(stk::io::OutputParams ¶ms) { - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); write_output_db_element_blocks(params); - bool ints64bit = db_api_int_size(&io_region) == 8; + bool ints64bit = db_api_int_size(&ioRegion) == 8; if (ints64bit) { write_output_db_for_entitysets_and_comm_map(params); @@ -3511,12 +3570,12 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void write_output_db(stk::io::OutputParams ¶ms) { - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); - io_region.begin_mode( Ioss::STATE_MODEL ); + ioRegion.begin_mode( Ioss::STATE_MODEL ); write_output_db_node_block(params); write_output_db_rest_of_mesh(params); - io_region.end_mode( Ioss::STATE_MODEL ); + ioRegion.end_mode( Ioss::STATE_MODEL ); } //---------------------------------------------------------------------- @@ -3541,14 +3600,14 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } void set_distribution_factor_field(stk::mesh::Part &p, - const stk::mesh::FieldBase &df_field) + const stk::mesh::FieldBase &dfField) { stk::mesh::MetaData &m = mesh::MetaData::get(p); if (const stk::mesh::FieldBase * existingDistFactField = p.attribute()) { m.remove_attribute(p, existingDistFactField); } - m.declare_attribute_no_delete(p, &df_field); + m.declare_attribute_no_delete(p, &dfField); } const Ioss::Field::RoleType* get_field_role(const stk::mesh::FieldBase &f) @@ -3558,42 +3617,42 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void set_field_role(stk::mesh::FieldBase &f, const Ioss::Field::RoleType &role) { - Ioss::Field::RoleType *my_role = new Ioss::Field::RoleType(role); + Ioss::Field::RoleType *myRole = new Ioss::Field::RoleType(role); stk::mesh::MetaData &m = mesh::MetaData::get(f); - const Ioss::Field::RoleType *check = m.declare_attribute_with_delete(f, my_role); - if ( check != my_role ) { - if (*check != *my_role) { + const Ioss::Field::RoleType *check = m.declare_attribute_with_delete(f, myRole); + if ( check != myRole ) { + if (*check != *myRole) { std::ostringstream msg ; msg << " FAILED in IossBridge -- set_field_role:" << " The role type for field name= " << f.name() << " was already set to " << *check - << ", so it is not possible to change it to " << *my_role; - delete my_role; + << ", so it is not possible to change it to " << *myRole; + delete myRole; throw std::runtime_error( msg.str() ); } - delete my_role; + delete myRole; } } namespace { void define_input_nodeblock_fields(Ioss::Region ®ion, stk::mesh::MetaData &meta) { - const Ioss::NodeBlockContainer& node_blocks = region.get_node_blocks(); - assert(node_blocks.size() == 1); + const Ioss::NodeBlockContainer& nodeBlocks = region.get_node_blocks(); + assert(nodeBlocks.size() == 1); - Ioss::NodeBlock *nb = node_blocks[0]; + Ioss::NodeBlock *nb = nodeBlocks[0]; stk::io::define_io_fields(nb, Ioss::Field::TRANSIENT, meta.universal_part(), stk::topology::NODE_RANK); } void define_input_elementblock_fields(Ioss::Region ®ion, stk::mesh::MetaData &meta) { - const Ioss::ElementBlockContainer& elem_blocks = region.get_element_blocks(); - for(size_t i=0; i < elem_blocks.size(); i++) { - if (stk::io::include_entity(elem_blocks[i])) { - stk::mesh::Part* const part = meta.get_part(elem_blocks[i]->name()); + const Ioss::ElementBlockContainer& elemBlocks = region.get_element_blocks(); + for(size_t i=0; i < elemBlocks.size(); i++) { + if (stk::io::include_entity(elemBlocks[i])) { + stk::mesh::Part* const part = meta.get_part(elemBlocks[i]->name()); assert(part != nullptr); - stk::io::define_io_fields(elem_blocks[i], Ioss::Field::TRANSIENT, + stk::io::define_io_fields(elemBlocks[i], Ioss::Field::TRANSIENT, *part, part_primary_entity_rank(*part)); } } @@ -3617,9 +3676,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta unsigned sideRank = meta.side_rank(); if (meta.spatial_dimension() <= sideRank) return; - const Ioss::SideSetContainer& side_sets = region.get_sidesets(); - for(Ioss::SideSetContainer::const_iterator it = side_sets.begin(); - it != side_sets.end(); ++it) { + const Ioss::SideSetContainer& sideSets = region.get_sidesets(); + for(Ioss::SideSetContainer::const_iterator it = sideSets.begin(); + it != sideSets.end(); ++it) { Ioss::SideSet *entity = *it; if (stk::io::include_entity(entity)) { const Ioss::SideBlockContainer& blocks = entity->get_side_blocks(); @@ -3637,12 +3696,12 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void define_input_face_block_fields(Ioss::Region ®ion, stk::mesh::MetaData &meta) { - const Ioss::FaceBlockContainer& face_blocks = region.get_face_blocks(); - for(size_t i=0; i < face_blocks.size(); i++) { - if (stk::io::include_entity(face_blocks[i])) { - stk::mesh::Part* const part = meta.get_part(face_blocks[i]->name()); + const Ioss::FaceBlockContainer& faceBlocks = region.get_face_blocks(); + for(size_t i=0; i < faceBlocks.size(); i++) { + if (stk::io::include_entity(faceBlocks[i])) { + stk::mesh::Part* const part = meta.get_part(faceBlocks[i]->name()); assert(part != nullptr); - stk::io::define_io_fields(face_blocks[i], Ioss::Field::TRANSIENT, + stk::io::define_io_fields(faceBlocks[i], Ioss::Field::TRANSIENT, *part, part_primary_entity_rank(*part)); } } @@ -3650,12 +3709,12 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta void define_input_edge_block_fields(Ioss::Region ®ion, stk::mesh::MetaData &meta) { - const Ioss::EdgeBlockContainer& edge_blocks = region.get_edge_blocks(); - for(size_t i=0; i < edge_blocks.size(); i++) { - if (stk::io::include_entity(edge_blocks[i])) { - stk::mesh::Part* const part = meta.get_part(edge_blocks[i]->name()); + const Ioss::EdgeBlockContainer& edgeBlocks = region.get_edge_blocks(); + for(size_t i=0; i < edgeBlocks.size(); i++) { + if (stk::io::include_entity(edgeBlocks[i])) { + stk::mesh::Part* const part = meta.get_part(edgeBlocks[i]->name()); assert(part != nullptr); - stk::io::define_io_fields(edge_blocks[i], Ioss::Field::TRANSIENT, + stk::io::define_io_fields(edgeBlocks[i], Ioss::Field::TRANSIENT, *part, part_primary_entity_rank(*part)); } } @@ -3741,64 +3800,65 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } } - void put_field_data(stk::mesh::BulkData &bulk, - stk::io::OutputParams& params, + void put_field_data(stk::io::OutputParams& params, stk::mesh::Part &part, - stk::mesh::EntityRank part_type, - Ioss::GroupingEntity *io_entity, - Ioss::Field::RoleType filter_role) + stk::mesh::EntityRank partType, + Ioss::GroupingEntity *ioEntity, + Ioss::Field::RoleType filterRole) { std::vector entities; - stk::io::get_output_entity_list(io_entity, part_type, params, entities); + stk::io::get_output_entity_list(ioEntity, partType, params, entities); + const stk::mesh::BulkData &bulk = params.bulk_data(); stk::mesh::MetaData & meta = stk::mesh::MetaData::get(part); const std::vector &fields = meta.get_fields(); std::vector::const_iterator I = fields.begin(); while (I != fields.end()) { const stk::mesh::FieldBase *f = *I; ++I; - if (stk::io::is_valid_part_field(f, part_type, part, filter_role)) { - stk::io::field_data_to_ioss(bulk, f, entities, io_entity, f->name(), filter_role); + if (stk::io::is_valid_part_field(f, partType, part, filterRole)) { + stk::io::field_data_to_ioss(bulk, f, entities, ioEntity, f->name(), filterRole); } } } void put_field_data(stk::mesh::BulkData &bulk, stk::mesh::Part &part, - stk::mesh::EntityRank part_type, - Ioss::GroupingEntity *io_entity, - Ioss::Field::RoleType filter_role) + stk::mesh::EntityRank partType, + Ioss::GroupingEntity *ioEntity, + Ioss::Field::RoleType filterRole) { stk::io::OutputParams params(bulk); - put_field_data(bulk, params, part, part_type, io_entity, filter_role); + put_field_data(params, part, partType, ioEntity, filterRole); } struct DefineOutputFunctor { - void operator()(stk::mesh::BulkData &bulk, stk::io::OutputParams& params, stk::mesh::Part &part, stk::mesh::EntityRank rank, Ioss::GroupingEntity *ge, Ioss::Field::RoleType role) + void operator()(stk::io::OutputParams& params, stk::mesh::Part &part, stk::mesh::EntityRank rank, Ioss::GroupingEntity *ge, Ioss::Field::RoleType role) { stk::io::ioss_add_fields(part, rank, ge, role); } }; struct ProcessOutputFunctor { - void operator()(stk::mesh::BulkData &bulk, stk::io::OutputParams& params, stk::mesh::Part &part, stk::mesh::EntityRank rank, Ioss::GroupingEntity *ge, Ioss::Field::RoleType role) - { put_field_data(bulk, params, part, rank, ge, role); } + void operator()(stk::io::OutputParams& params, stk::mesh::Part &part, stk::mesh::EntityRank rank, Ioss::GroupingEntity *ge, Ioss::Field::RoleType role) + { put_field_data(params, part, rank, ge, role); } }; template - void process_field_loop(Ioss::Region ®ion, - stk::mesh::BulkData &bulk, T& callable) + void process_field_loop(stk::io::OutputParams& params, T& callable) { - stk::mesh::MetaData & meta = bulk.mesh_meta_data(); - stk::io::OutputParams params(region, bulk); + Ioss::Region ®ion = params.io_region(); + const stk::mesh::BulkData &bulk = params.bulk_data(); + + const stk::mesh::MetaData & meta = bulk.mesh_meta_data(); Ioss::NodeBlock *nb = region.get_node_blocks()[0]; - callable(bulk, params, meta.universal_part(), stk::topology::NODE_RANK, + callable(params, meta.universal_part(), stk::topology::NODE_RANK, dynamic_cast(nb), Ioss::Field::TRANSIENT); - const stk::mesh::PartVector & all_parts = meta.get_parts(); + const stk::mesh::PartVector & allParts = meta.get_parts(); for ( stk::mesh::PartVector::const_iterator - ip = all_parts.begin(); ip != all_parts.end(); ++ip ) { + ip = allParts.begin(); ip != allParts.end(); ++ip ) { stk::mesh::Part * const part = *ip; @@ -3812,12 +3872,12 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta for (int i=0; i < block_count; i++) { Ioss::SideBlock *fb = sset->get_block(i); - callable(bulk, params, *part, + callable(params, *part, stk::mesh::EntityRank( part->primary_entity_rank() ), dynamic_cast(fb), Ioss::Field::TRANSIENT); } } else { - callable(bulk, params, *part, + callable(params, *part, stk::mesh::EntityRank( part->primary_entity_rank() ), entity, Ioss::Field::TRANSIENT); } @@ -3825,18 +3885,16 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta } } - void process_output_request(Ioss::Region ®ion, - stk::mesh::BulkData &bulk, - int step) + void process_output_request(stk::io::OutputParams& params, int step) { - region.begin_state(step); + params.io_region().begin_state(step); ProcessOutputFunctor functor; - process_field_loop(region, bulk, functor); - region.end_state(step); + process_field_loop(params, functor); + params.io_region().end_state(step); } template - void output_node_sharing_info( Ioss::CommSet* io_cs, const EntitySharingInfo &nodeSharingInfo) + void output_node_sharing_info( Ioss::CommSet* ioCs, const EntitySharingInfo &nodeSharingInfo) { std::vector entity_proc(2*nodeSharingInfo.size()); int counter = 0; @@ -3847,28 +3905,28 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta counter += 2; } size_t size_field = entity_proc.size()*(sizeof(INT)); - io_cs->put_field_data("entity_processor", entity_proc.data(), size_field); + ioCs->put_field_data("entity_processor", entity_proc.data(), size_field); } void write_node_sharing_info(Ioss::DatabaseIO *dbo, const EntitySharingInfo &nodeSharingInfo) { bool ints64bit = db_api_int_size(dbo->get_region()) == 8; - Ioss::CommSet* io_cs = dbo->get_region()->get_commset("commset_node"); - if(io_cs) + Ioss::CommSet* ioCs = dbo->get_region()->get_commset("commset_node"); + if(ioCs) { if (ints64bit) - output_node_sharing_info(io_cs, nodeSharingInfo); + output_node_sharing_info(ioCs, nodeSharingInfo); else - output_node_sharing_info(io_cs, nodeSharingInfo); + output_node_sharing_info(ioCs, nodeSharingInfo); } } Ioss::DatabaseIO *create_database_for_subdomain(const std::string &baseFilename, - int index_subdomain, - int num_subdomains) + int indexSubdomain, + int numSubdomains) { - std::string parallelFilename{construct_filename_for_serial_or_parallel(baseFilename, num_subdomains, index_subdomain)}; + std::string parallelFilename{construct_filename_for_serial_or_parallel(baseFilename, numSubdomains, indexSubdomain)}; std::string dbtype("exodusII"); Ioss::DatabaseIO *dbo = Ioss::IOFactory::create(dbtype, parallelFilename, Ioss::WRITE_RESULTS, MPI_COMM_SELF); @@ -3876,111 +3934,117 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return dbo; } - void write_mesh_data_for_subdomain(Ioss::Region& out_region, stk::mesh::BulkData& bulkData, const EntitySharingInfo& nodeSharingInfo) + void write_mesh_data_for_subdomain(stk::io::OutputParams& params, const EntitySharingInfo& nodeSharingInfo) { - stk::io::OutputParams params(out_region, bulkData); - out_region.begin_mode(Ioss::STATE_DEFINE_MODEL); + Ioss::Region ®ion = params.io_region(); + region.begin_mode(Ioss::STATE_DEFINE_MODEL); stk::io::define_output_db_within_state_define(params, {}); - Ioss::CommSet *commset = new Ioss::CommSet(out_region.get_database(), "commset_node", "node", nodeSharingInfo.size()); + Ioss::CommSet *commset = new Ioss::CommSet(region.get_database(), "commset_node", "node", nodeSharingInfo.size()); commset->property_add(Ioss::Property("id", 1)); - out_region.add(commset); - out_region.end_mode(Ioss::STATE_DEFINE_MODEL); + region.add(commset); + region.end_mode(Ioss::STATE_DEFINE_MODEL); - out_region.begin_mode(Ioss::STATE_MODEL); + region.begin_mode(Ioss::STATE_MODEL); stk::io::write_output_db_node_block(params); - write_node_sharing_info(out_region.get_database(), nodeSharingInfo); + write_node_sharing_info(region.get_database(), nodeSharingInfo); stk::io::write_output_db_rest_of_mesh(params); - out_region.end_mode(Ioss::STATE_MODEL); + region.end_mode(Ioss::STATE_MODEL); } - int write_transient_data_for_subdomain(Ioss::Region &out_region, stk::mesh::BulkData& bulkData, double timeStep) + int write_transient_data_for_subdomain(stk::io::OutputParams& params, double timeStep) { - if(!out_region.transient_defined()) { - out_region.begin_mode(Ioss::STATE_DEFINE_TRANSIENT); + Ioss::Region &outRegion = params.io_region(); + + if(!outRegion.transient_defined()) { + outRegion.begin_mode(Ioss::STATE_DEFINE_TRANSIENT); DefineOutputFunctor functor; - process_field_loop(out_region, bulkData, functor); - out_region.end_mode(Ioss::STATE_DEFINE_TRANSIENT); + process_field_loop(params, functor); + outRegion.end_mode(Ioss::STATE_DEFINE_TRANSIENT); } - out_region.begin_mode(Ioss::STATE_TRANSIENT); - int out_step = out_region.add_state(timeStep); - process_output_request(out_region, bulkData, out_step); - out_region.end_mode(Ioss::STATE_TRANSIENT); + outRegion.begin_mode(Ioss::STATE_TRANSIENT); + int out_step = outRegion.add_state(timeStep); + process_output_request(params, out_step); + outRegion.end_mode(Ioss::STATE_TRANSIENT); return out_step; } - void write_file_for_subdomain(Ioss::Region &out_region, - stk::mesh::BulkData& bulkData, + void write_file_for_subdomain(stk::io::OutputParams& params, const EntitySharingInfo &nodeSharingInfo, int numSteps, double timeStep) { - Ioss::DatabaseIO *dbo = out_region.get_database(); + Ioss::Region &outRegion = params.io_region(); + + Ioss::DatabaseIO *dbo = outRegion.get_database(); ThrowRequire(nullptr != dbo); - write_mesh_data_for_subdomain(out_region, bulkData, nodeSharingInfo); + write_mesh_data_for_subdomain(params, nodeSharingInfo); if(numSteps > 0) { - write_transient_data_for_subdomain(out_region, bulkData, timeStep); + write_transient_data_for_subdomain(params, timeStep); } } - void add_properties_for_subdomain(stk::mesh::BulkData& bulkData, - Ioss::Region &out_region, - int index_subdomain, - int num_subdomains, - int global_num_nodes, - int global_num_elems) + void add_properties_for_subdomain(stk::io::OutputParams& params, + int indexSubdomain, + int numSubdomains, + int globalNumNodes, + int globalNumElems) { - out_region.property_add(Ioss::Property("processor_count", num_subdomains)); - out_region.property_add(Ioss::Property("my_processor", index_subdomain)); - out_region.property_add(Ioss::Property("global_node_count", global_num_nodes)); - out_region.property_add(Ioss::Property("global_element_count", global_num_elems)); + Ioss::Region &outRegion = params.io_region(); - if(bulkData.supports_large_ids()) { - out_region.property_add(Ioss::Property("INTEGER_SIZE_API" , 8)); - out_region.property_add(Ioss::Property("INTEGER_SIZE_DB" , 8)); + outRegion.property_add(Ioss::Property("processor_count", numSubdomains)); + outRegion.property_add(Ioss::Property("my_processor", indexSubdomain)); + outRegion.property_add(Ioss::Property("global_node_count", globalNumNodes)); + outRegion.property_add(Ioss::Property("global_element_count", globalNumElems)); - Ioss::DatabaseIO *dbo = out_region.get_database(); + if(params.bulk_data().supports_large_ids()) { + outRegion.property_add(Ioss::Property("INTEGER_SIZE_API" , 8)); + outRegion.property_add(Ioss::Property("INTEGER_SIZE_DB" , 8)); + + Ioss::DatabaseIO *dbo = outRegion.get_database(); dbo->set_int_byte_size_api(Ioss::USE_INT64_API); } } void write_file_for_subdomain(const std::string &baseFilename, - int index_subdomain, - int num_subdomains, - int global_num_nodes, - int global_num_elems, - stk::mesh::BulkData& bulkData, + int indexSubdomain, + int numSubdomains, + int globalNumNodes, + int globalNumElems, + stk::io::OutputParams& params, const EntitySharingInfo &nodeSharingInfo, int numSteps, double timeStep) { - Ioss::DatabaseIO *dbo = create_database_for_subdomain(baseFilename, index_subdomain, num_subdomains); - - Ioss::Region out_region(dbo, "name"); + Ioss::DatabaseIO *dbo = create_database_for_subdomain(baseFilename, indexSubdomain, numSubdomains); + Ioss::Region outRegion(dbo, "name"); - add_properties_for_subdomain(bulkData, out_region, index_subdomain, num_subdomains, global_num_nodes, global_num_elems); + ThrowRequireMsg(params.io_region_ptr() == nullptr, "OutputParams argument must have a NULL IORegion"); + params.set_io_region(&outRegion); + add_properties_for_subdomain(params, indexSubdomain, numSubdomains, globalNumNodes, globalNumElems); - write_file_for_subdomain(out_region, bulkData, nodeSharingInfo, numSteps, timeStep); + write_file_for_subdomain(params, nodeSharingInfo, numSteps, timeStep); - stk::io::delete_selector_property(out_region); + stk::io::delete_selector_property(outRegion); + params.set_io_region(nullptr); } const stk::mesh::Part* get_parent_element_block_by_adjacency(const stk::mesh::BulkData& bulk, const std::string& name, - const stk::mesh::Part* parent_element_block) + const stk::mesh::Part* parentElementBlock) { const stk::mesh::Part* part = bulk.mesh_meta_data().get_part(name); if (part != nullptr) { - std::vector touching_parts = bulk.mesh_meta_data().get_blocks_touching_surface(part); - if (touching_parts.size() == 1) { - parent_element_block = touching_parts[0]; + std::vector touchingParts = bulk.mesh_meta_data().get_blocks_touching_surface(part); + if (touchingParts.size() == 1) { + parentElementBlock = touchingParts[0]; } } - return parent_element_block; + return parentElementBlock; } @@ -3995,9 +4059,9 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta // correct faceblock... std::string part_name = name + "_context"; - const stk::mesh::Part* parent_element_block = bulk.mesh_meta_data().get_part(part_name); + const stk::mesh::Part* parentElementBlock = bulk.mesh_meta_data().get_part(part_name); - if(parent_element_block == nullptr) { + if(parentElementBlock == nullptr) { if(ioRegion.get_database()->get_surface_split_type() == Ioss::SPLIT_BY_ELEMENT_BLOCK) { // If the surfaces were split by element block, then the surface // name will be of the form: "name_block_id_facetopo_id" "name" is typically "surface". @@ -4007,37 +4071,46 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta stk::util::tokenize(name, "_", tokens); if(tokens.size() >= 4) { // Check whether the second-last token is a face topology - const Ioss::ElementTopology* face_topo = Ioss::ElementTopology::factory(tokens[tokens.size() - 2], true); - if(face_topo != nullptr) { + const Ioss::ElementTopology* faceTopo = Ioss::ElementTopology::factory(tokens[tokens.size() - 2], true); + if(faceTopo != nullptr) { // Extract the blockname or "block"_id... - std::string eb_name; - size_t last_token = tokens.size() - 2; - for(size_t tok = 1; tok < last_token; tok++) { - eb_name += tokens[tok]; - if(tok < last_token - 1) eb_name += "_"; + std::string ebName; + size_t lastToken = tokens.size() - 2; + for(size_t tok = 1; tok < lastToken; tok++) { + ebName += tokens[tok]; + if(tok < lastToken - 1) ebName += "_"; } - stk::mesh::Part* elementBlock = bulk.mesh_meta_data().get_part(eb_name); + stk::mesh::Part* elementBlock = bulk.mesh_meta_data().get_part(ebName); if(elementBlock != nullptr && is_part_io_part(*elementBlock)) - parent_element_block = elementBlock; + parentElementBlock = elementBlock; } } else { - parent_element_block = get_parent_element_block_by_adjacency(bulk, name, parent_element_block); + parentElementBlock = get_parent_element_block_by_adjacency(bulk, name, parentElementBlock); } } else { - parent_element_block = get_parent_element_block_by_adjacency(bulk, name, parent_element_block); + parentElementBlock = get_parent_element_block_by_adjacency(bulk, name, parentElementBlock); } } - return parent_element_block; + return parentElementBlock; } - bool is_valid_for_output(const stk::mesh::Part &part, const stk::mesh::Selector *output_selector) + bool is_valid_for_output(stk::io::OutputParams ¶ms, const stk::mesh::Part &part) { + const stk::mesh::EntityRank rank = part.primary_entity_rank(); + const stk::mesh::Selector *outputSelector = params.get_output_selector(rank); + bool isIoPart = stk::io::is_part_io_part(part); - bool isSelected = (output_selector == nullptr) || (*output_selector)(part); + bool isSelected = (outputSelector == nullptr) || (*outputSelector)(part); + + bool isEmptyElementBlock = false; - return (isIoPart && isSelected); + if(rank == stk::topology::ELEM_RANK && params.get_filter_empty_entity_blocks()) { + isEmptyElementBlock = is_empty_element_block(params, &part); + } + + return (isIoPart && isSelected && !isEmptyElementBlock); } bool node_is_connected_to_local_element(const stk::mesh::BulkData &bulk, stk::mesh::Entity node, const stk::mesh::Selector *subsetSelector) @@ -4056,8 +4129,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta return isLocalElement; } - size_t count_selected_nodes(OutputParams ¶ms, - const stk::mesh::Selector &selector) + size_t count_selected_nodes(OutputParams ¶ms, const stk::mesh::Selector &selector) { stk::mesh::EntityVector nodes; get_selected_nodes(params, selector, nodes); @@ -4088,10 +4160,10 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta bool hasAdaptivity = params.get_has_adaptivity(); bool result = true; - bool active_only = subsetSelector != nullptr; + bool activeOnly = subsetSelector != nullptr; stk::mesh::Bucket &nodeBucket = bulk.bucket(node); if (hasAdaptivity) { - result = active_only ? (*subsetSelector)(nodeBucket) : true; + result = activeOnly ? (*subsetSelector)(nodeBucket) : true; } if (hasGhosting && result) { // Now need to check whether this node is locally owned or is used by @@ -4106,7 +4178,7 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta for (unsigned i = 0, e = bulk.num_elements(node); i < e; ++i) { stk::mesh::Entity elem = elements[i]; stk::mesh::Bucket &elemBucket = bulk.bucket(elem); - if (elemBucket.owned() && (!active_only || (active_only && (*subsetSelector)(elemBucket)))) { + if (elemBucket.owned() && (!activeOnly || (activeOnly && (*subsetSelector)(elemBucket)))) { result = true; break; } @@ -4134,18 +4206,18 @@ const stk::mesh::FieldBase *declare_stk_field_internal(stk::mesh::MetaData &meta stk::mesh::EntityVector &nodes) { const stk::mesh::BulkData &bulk = params.bulk_data(); - Ioss::Region &io_region = params.io_region(); + Ioss::Region &ioRegion = params.io_region(); nodes.clear(); - bool ignore_disconnected_nodes = false; - if(io_region.property_exists(stk::io::s_ignore_disconnected_nodes)) { - ignore_disconnected_nodes = io_region.get_property(stk::io::s_ignore_disconnected_nodes).get_int(); + bool ignoreDisconnectedNodes = false; + if(ioRegion.property_exists(stk::io::s_ignoreDisconnectedNodes)) { + ignoreDisconnectedNodes = ioRegion.get_property(stk::io::s_ignoreDisconnectedNodes).get_int(); } const bool sortById = true; stk::mesh::get_entities(bulk, stk::topology::NODE_RANK, selector, nodes, sortById); filter_nodes_by_ghosting(params, nodes); - if(!ignore_disconnected_nodes) { + if(!ignoreDisconnectedNodes) { return; } diff --git a/packages/stk/stk_io/stk_io/IossBridge.hpp b/packages/stk/stk_io/stk_io/IossBridge.hpp index 853894f4b38e..67d287429874 100644 --- a/packages/stk/stk_io/stk_io/IossBridge.hpp +++ b/packages/stk/stk_io/stk_io/IossBridge.hpp @@ -46,6 +46,7 @@ #include // for string, operator<, basic_... #include // for pair #include // for vector +#include #include "Ioss_EntityType.h" // for EntityType, SIDEBLOCK #include "Ioss_GroupingEntity.h" // for GroupingEntity #include "SidesetTranslator.hpp" // for fill_element_and_side_ids @@ -95,6 +96,8 @@ namespace stk { */ namespace io { +using TopologyErrorHandler = std::function; + stk::mesh::EntityRank get_entity_rank(const Ioss::GroupingEntity *entity, const stk::mesh::MetaData &meta); @@ -108,12 +111,12 @@ struct GlobalAnyVariable { stk::util::ParameterType::Type m_type; }; -static const std::string s_internal_selector_name("_stk_io_internal_selector"); -static const std::string s_ignore_disconnected_nodes("ignore_disconnected_nodes"); -static const std::string s_process_all_input_nodes("process_all_input_nodes"); -static const std::string s_sort_stk_parts("sort_stk_parts"); -static const std::string s_entity_nodes_suffix("_n"); -static const std::string s_distribution_factors("distribution_factors"); +static const std::string s_internalSelectorName("_stk_io_internal_selector"); +static const std::string s_ignoreDisconnectedNodes("ignore_disconnected_nodes"); +static const std::string s_processAllInputNodes("process_all_input_nodes"); +static const std::string s_sortStkParts("sort_stk_parts"); +static const std::string s_entityNodesSuffix("_n"); +static const std::string s_distributionFactors("distribution_factors"); typedef std::pair EntityIdToProcPair; typedef std::vector EntitySharingInfo; @@ -136,7 +139,7 @@ typedef std::vector FieldNameToPartVector; stk::mesh::Part *getPart(const stk::mesh::MetaData& meta_data, const std::string& name); -bool is_valid_for_output(const stk::mesh::Part &part, const stk::mesh::Selector *output_selector = nullptr); +bool is_valid_for_output(stk::io::OutputParams ¶ms, const stk::mesh::Part &part); void get_selected_nodes(OutputParams ¶ms, const stk::mesh::Selector &selector, stk::mesh::EntityVector &nodes); @@ -157,9 +160,9 @@ bool node_is_connected_to_local_element(const stk::mesh::BulkData &bulk, stk::me */ bool include_entity(const Ioss::GroupingEntity *entity); -void internal_part_processing(Ioss::GroupingEntity *entity, stk::mesh::MetaData &meta); +void internal_part_processing(Ioss::GroupingEntity *entity, stk::mesh::MetaData &meta, TopologyErrorHandler handler); -void internal_part_processing(Ioss::EntityBlock *entity, stk::mesh::MetaData &meta); +void internal_part_processing(Ioss::EntityBlock *entity, stk::mesh::MetaData &meta, TopologyErrorHandler handler); /** This is the primary function used by an application to define * the stk::mesh which corresponds to the Ioss mesh read from the @@ -171,12 +174,23 @@ void internal_part_processing(Ioss::EntityBlock *entity, stk::mesh::MetaData &me * stk::io::define_output_db()) which will cause the part to be output to a * results or restart file. */ +template +void default_part_processing(const std::vector &entities, stk::mesh::MetaData &meta, TopologyErrorHandler handler) +{ + for(size_t i=0; i < entities.size(); i++) { + T* entity = entities[i]; + internal_part_processing(entity, meta, handler); + } +} + template void default_part_processing(const std::vector &entities, stk::mesh::MetaData &meta) { + TopologyErrorHandler handler = [](stk::mesh::Part &part) { }; + for(size_t i=0; i < entities.size(); i++) { T* entity = entities[i]; - internal_part_processing(entity, meta); + internal_part_processing(entity, meta, handler); } } @@ -463,11 +477,11 @@ bool is_part_element_block_io_part(const stk::mesh::Part &part); bool is_part_surface_io_part(const stk::mesh::Part &part); -Ioss::GroupingEntity* get_grouping_entity(const Ioss::Region& region, stk::mesh::Part& part); +Ioss::GroupingEntity* get_grouping_entity(const Ioss::Region& region, const stk::mesh::Part& part); std::vector get_ioss_entity_types(const stk::mesh::MetaData& meta, stk::mesh::EntityRank rank); -std::vector get_ioss_entity_types(stk::mesh::Part& part); +std::vector get_ioss_entity_types(const stk::mesh::Part& part); std::string getPartName(const stk::mesh::Part& part); @@ -546,15 +560,14 @@ void initialize_spatial_dimension(mesh::MetaData &meta, size_t spatial_dimension Ioss::DatabaseIO *create_database_for_subdomain(const std::string &baseFilename, int index_subdomain, int num_subdomains); -void add_properties_for_subdomain(stk::mesh::BulkData& bulkData, Ioss::Region &out_region, int index_subdomain, +void add_properties_for_subdomain(stk::io::OutputParams& params, int index_subdomain, int num_subdomains, int global_num_nodes, int global_num_elems); -void write_mesh_data_for_subdomain(Ioss::Region& out_region, stk::mesh::BulkData& bulkData, const EntitySharingInfo& nodeSharingInfo); +void write_mesh_data_for_subdomain(stk::io::OutputParams& params, const EntitySharingInfo& nodeSharingInfo); -int write_transient_data_for_subdomain(Ioss::Region &out_region, stk::mesh::BulkData& bulkData, double timeStep); +int write_transient_data_for_subdomain(stk::io::OutputParams& params, double timeStep); -void write_file_for_subdomain(Ioss::Region &out_region, - stk::mesh::BulkData& bulkData, +void write_file_for_subdomain(stk::io::OutputParams& params, const EntitySharingInfo &nodeSharingInfo, int numSteps = -1, double timeStep = 0.0); @@ -564,7 +577,7 @@ void write_file_for_subdomain(const std::string &baseFilename, int num_subdomains, int global_num_nodes, int global_num_elems, - stk::mesh::BulkData& bulkData, + stk::io::OutputParams& params, const EntitySharingInfo &nodeSharingInfo, int numSteps = -1, double timeStep = 0.0); diff --git a/packages/stk/stk_io/stk_io/OutputFile.cpp b/packages/stk/stk_io/stk_io/OutputFile.cpp index 33cf1a340f8e..fcfee6f454fb 100644 --- a/packages/stk/stk_io/stk_io/OutputFile.cpp +++ b/packages/stk/stk_io/stk_io/OutputFile.cpp @@ -135,6 +135,9 @@ void OutputFile::setup_output_params(OutputParams ¶ms) const params.set_additional_attribute_fields(m_additionalAttributeFields); params.set_is_restart(m_dbPurpose == stk::io::WRITE_RESTART); params.set_enable_edge_io(m_enableEdgeIO); + + params.set_filter_empty_entity_blocks(m_filterEmptyEntityBlocks); + params.set_filter_empty_assembly_entity_blocks(m_filterEmptyAssemblyEntityBlocks); } void OutputFile::set_input_region(const Ioss::Region *input_region) @@ -146,7 +149,7 @@ void OutputFile::set_input_region(const Ioss::Region *input_region) } void OutputFile::write_output_mesh(const stk::mesh::BulkData& bulk_data, - const std::vector> &attributeOrdering) + const std::vector> &attributeOrdering) { if ( m_meshDefined == false ) { @@ -556,7 +559,7 @@ void OutputFile::define_output_fields(const stk::mesh::BulkData& bulk_data, : m_useNodesetForBlockNodesFields; if (use_nodeset) { - std::string nodes_name = partName + s_entity_nodes_suffix; + std::string nodes_name = partName + s_entityNodesSuffix; node_entity = region->get_entity(nodes_name); } } @@ -645,7 +648,7 @@ int OutputFile::write_defined_output_fields(const stk::mesh::BulkData& bulk_data m_useNodesetForBlockNodesFields; if (use_nodeset) { - std::string nodes_name = partName + s_entity_nodes_suffix; + std::string nodes_name = partName + s_entityNodesSuffix; node_entity = region->get_entity(nodes_name); } } @@ -810,6 +813,16 @@ void OutputFile::set_enable_edge_io(bool enableEdgeIO) m_enableEdgeIO = enableEdgeIO; } +void OutputFile::set_filter_empty_entity_blocks(const bool filterEmptyEntityBlocks) +{ + m_filterEmptyEntityBlocks = filterEmptyEntityBlocks; +} + +void OutputFile::set_filter_empty_assembly_entity_blocks(const bool filterEmptyAssemblyEntityBlocks) +{ + m_filterEmptyAssemblyEntityBlocks = filterEmptyAssemblyEntityBlocks; +} + } // namespace impl } // namespace io } // namespace stk diff --git a/packages/stk/stk_io/stk_io/OutputFile.hpp b/packages/stk/stk_io/stk_io/OutputFile.hpp index 5241ba1d9965..b46d87507c52 100644 --- a/packages/stk/stk_io/stk_io/OutputFile.hpp +++ b/packages/stk/stk_io/stk_io/OutputFile.hpp @@ -100,7 +100,9 @@ class OutputFile m_subsetSelector(nullptr), m_sharedSelector(nullptr), m_skinMeshSelector(nullptr), - m_multiStateSuffixes(nullptr) + m_multiStateSuffixes(nullptr), + m_filterEmptyEntityBlocks(false), + m_filterEmptyAssemblyEntityBlocks(false) { initialize_output_selectors(); setup_output_file(filename, communicator, property_manager, type, openFileImmediately); @@ -128,7 +130,9 @@ class OutputFile m_subsetSelector(nullptr), m_sharedSelector(nullptr), m_skinMeshSelector(nullptr), - m_multiStateSuffixes(nullptr) + m_multiStateSuffixes(nullptr), + m_filterEmptyEntityBlocks(false), + m_filterEmptyAssemblyEntityBlocks(false) { m_region = ioss_output_region; m_meshDefined = true; @@ -216,6 +220,9 @@ class OutputFile void set_enable_edge_io(bool enableEdgeIO); + void set_filter_empty_entity_blocks(const bool filterEmptyEntityBlocks); + void set_filter_empty_assembly_entity_blocks(const bool filterEmptyAssemblyEntityBlocks); + Ioss::DatabaseIO *get_output_database(); std::vector get_output_entities(const stk::mesh::BulkData& bulk_data, const std::string &name); @@ -262,6 +269,9 @@ class OutputFile std::vector* m_multiStateSuffixes = nullptr; + bool m_filterEmptyEntityBlocks; + bool m_filterEmptyAssemblyEntityBlocks; + OutputFile(const OutputFile &); const OutputFile & operator=(const OutputFile &); }; diff --git a/packages/stk/stk_io/stk_io/OutputParams.hpp b/packages/stk/stk_io/stk_io/OutputParams.hpp index 6a0d158c8595..fc318ebe5ef5 100644 --- a/packages/stk/stk_io/stk_io/OutputParams.hpp +++ b/packages/stk/stk_io/stk_io/OutputParams.hpp @@ -45,9 +45,12 @@ #include // for string, operator<, etc #include // for pair #include // for vector +#include #include "Ioss_EntityType.h" // for EntityType #include "Ioss_GroupingEntity.h" #include "Ioss_Region.h" +#include "Ioss_Utils.h" +#include "Ioss_DatabaseIO.h" #include "stk_mesh/base/FieldState.hpp" // for FieldState #include "stk_mesh/base/FieldBase.hpp" // for FieldState #include "stk_mesh/base/Part.hpp" // for Part @@ -55,6 +58,9 @@ #include "stk_io/MeshField.hpp" #include "stk_io/FieldAndName.hpp" +#include // for count_selected_en... +#include // for all_reduce_sum + namespace Ioss { class ElementTopology; } namespace Ioss { class EntityBlock; } namespace Ioss { class GroupingEntity; } @@ -70,6 +76,8 @@ namespace stk { namespace mesh { class Selector; } } namespace stk { namespace io { +bool is_part_element_block_io_part(const stk::mesh::Part &part); + struct OutputParams { public: @@ -78,6 +86,7 @@ struct OutputParams m_bulkData(bulk) { initialize_output_selectors(); + initialize_block_sizes(); } OutputParams(const mesh::BulkData &bulk) : @@ -85,8 +94,15 @@ struct OutputParams m_bulkData(bulk) { initialize_output_selectors(); + initialize_block_sizes(); } + void set_io_region(Ioss::Region* region) { + m_ioRegion = region; + } + Ioss::Region *io_region_ptr() const { + return m_ioRegion; + } Ioss::Region &io_region() const { ThrowRequireMsg(m_ioRegion != nullptr, "Region is null"); return *m_ioRegion; } @@ -206,6 +222,25 @@ struct OutputParams m_additionalAttributeFields = additionalAttributeFields; } + void set_filter_empty_entity_blocks(const bool filterEmptyEntityBlocks) { + m_filterEmptyEntityBlocks = filterEmptyEntityBlocks; + } + + bool get_filter_empty_entity_blocks() const { + return m_filterEmptyEntityBlocks; + } + + void set_filter_empty_assembly_entity_blocks(const bool filterEmptyAssemblyEntityBlocks) { + m_filterEmptyAssemblyEntityBlocks = filterEmptyAssemblyEntityBlocks; + } + + bool get_filter_empty_assembly_entity_blocks() const { + return m_filterEmptyAssemblyEntityBlocks; + } + + const std::unordered_map& get_block_sizes() const { + return m_blockSizes; + } private: OutputParams(); OutputParams(const OutputParams &); @@ -217,6 +252,36 @@ struct OutputParams } } + void initialize_block_sizes() + { + const stk::mesh::MetaData & meta = m_bulkData.mesh_meta_data(); + const mesh::PartVector & parts = meta.get_parts(); + stk::mesh::ConstPartVector elementParts; + elementParts.reserve(parts.size()); + + for (const stk::mesh::Part * part : parts) { + if (is_part_element_block_io_part(*part)) { + elementParts.push_back(part); + } + } + + size_t length = elementParts.size(); + std::vector localBlockSizes(length, 0); + std::vector globalBlockSizes(length, 0); + + for(size_t i=0; imesh_meta_data_ordinal()] = globalBlockSizes[i]; + } + } + bool is_valid_rank(stk::topology::rank_t rank) const {return ((rank >= stk::topology::NODE_RANK) && (rank <= stk::topology::ELEM_RANK)); } Ioss::Region * m_ioRegion = nullptr; @@ -235,6 +300,11 @@ struct OutputParams bool m_isRestart = false; bool m_enableEdgeIO = false; std::vector m_additionalAttributeFields; + std::unordered_map m_blockSizes; + + bool m_filterEmptyEntityBlocks = false; + bool m_filterEmptyAssemblyEntityBlocks = false; + }; }//namespace io diff --git a/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.cpp b/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.cpp index de3146def953..3a2a47cf8d63 100644 --- a/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.cpp +++ b/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.cpp @@ -77,10 +77,10 @@ void process_nodeblocks(Ioss::Region ®ion, stk::mesh::MetaData &meta) -void process_elementblocks(Ioss::Region ®ion, stk::mesh::MetaData &meta) +void process_elementblocks(Ioss::Region ®ion, stk::mesh::MetaData &meta, TopologyErrorHandler handler) { const Ioss::ElementBlockContainer& elem_blocks = region.get_element_blocks(); - stk::io::default_part_processing(elem_blocks, meta); + stk::io::default_part_processing(elem_blocks, meta, handler); } void process_nodesets_without_distribution_factors(Ioss::Region ®ion, stk::mesh::MetaData &meta) @@ -639,16 +639,16 @@ void process_edge_blocks(Ioss::Region ®ion, stk::mesh::BulkData &bulk) } } -void process_face_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta) +void process_face_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta, TopologyErrorHandler handler) { const Ioss::FaceBlockContainer& face_blocks = region.get_face_blocks(); - stk::io::default_part_processing(face_blocks, meta); + stk::io::default_part_processing(face_blocks, meta, handler); } -void process_edge_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta) +void process_edge_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta, TopologyErrorHandler handler) { const Ioss::EdgeBlockContainer& edge_blocks = region.get_edge_blocks(); - stk::io::default_part_processing(edge_blocks, meta); + stk::io::default_part_processing(edge_blocks, meta, handler); } void process_assemblies(Ioss::Region ®ion, stk::mesh::MetaData &meta) diff --git a/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.hpp b/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.hpp index 5e209ace1d3c..7438af163ad7 100644 --- a/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.hpp +++ b/packages/stk/stk_io/stk_io/ProcessSetsOrBlocks.hpp @@ -189,7 +189,7 @@ void process_nodeblocks(Ioss::Region ®ion, stk::mesh::BulkData &bulk) stk::mesh::Part* get_part_from_alias(const Ioss::Region ®ion, const stk::mesh::MetaData &meta, const std::string &name); stk::mesh::Part* get_part_for_grouping_entity(const Ioss::Region ®ion, const stk::mesh::MetaData &meta, const Ioss::GroupingEntity *entity); -void process_elementblocks(Ioss::Region ®ion, stk::mesh::MetaData &meta); +void process_elementblocks(Ioss::Region ®ion, stk::mesh::MetaData &meta, TopologyErrorHandler handler); template void process_elementblocks(Ioss::Region ®ion, stk::mesh::BulkData &bulk) { @@ -324,9 +324,9 @@ void process_hidden_nodesets(Ioss::Region &io, stk::mesh::BulkData & bulk) void process_sidesets(Ioss::Region ®ion, stk::mesh::BulkData &bulk, const stk::mesh::EntityIdProcMap &elemIdMovedToProc, stk::io::StkMeshIoBroker::SideSetFaceCreationBehavior behavior); void process_sidesets(Ioss::Region ®ion, stk::mesh::MetaData &meta); void process_face_blocks(Ioss::Region ®ion, stk::mesh::BulkData &bulk); -void process_face_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta); +void process_face_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta, TopologyErrorHandler handler); void process_edge_blocks(Ioss::Region ®ion, stk::mesh::BulkData &bulk); -void process_edge_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta); +void process_edge_blocks(Ioss::Region ®ion, stk::mesh::MetaData &meta, TopologyErrorHandler handler); void process_assemblies(Ioss::Region ®ion, stk::mesh::MetaData &meta); void build_assembly_hierarchies(Ioss::Region ®ion, stk::mesh::MetaData &meta); diff --git a/packages/stk/stk_io/stk_io/StkMeshIoBroker.cpp b/packages/stk/stk_io/stk_io/StkMeshIoBroker.cpp index db25b379c93e..62630049852b 100644 --- a/packages/stk/stk_io/stk_io/StkMeshIoBroker.cpp +++ b/packages/stk/stk_io/stk_io/StkMeshIoBroker.cpp @@ -212,6 +212,22 @@ stk::mesh::FieldBase const& StkMeshIoBroker::get_coordinate_field() const return * coord_field; } +bool StkMeshIoBroker::get_filter_empty_input_entity_blocks() const +{ + return get_filter_empty_input_entity_blocks(m_activeMeshIndex); +} + +bool StkMeshIoBroker::get_filter_empty_input_entity_blocks(size_t input_file_index) const +{ + validate_input_file_index(input_file_index); + auto ioss_input_region = m_inputFiles[input_file_index]->get_input_io_region(); + + bool retainEmptyBlocks = (ioss_input_region->get_assemblies().size() > 0); + const Ioss::PropertyManager &properties = ioss_input_region->get_database()->get_property_manager(); + Ioss::Utils::check_set_bool_property(properties, "RETAIN_EMPTY_BLOCKS", retainEmptyBlocks); + return !retainEmptyBlocks; +} + size_t StkMeshIoBroker::add_mesh_database(Teuchos::RCP ioss_input_region) { auto input_file = Teuchos::rcp(new InputFile(ioss_input_region)); @@ -477,11 +493,22 @@ void StkMeshIoBroker::create_input_mesh() initialize_spatial_dimension(meta_data(), spatial_dimension, m_rankNames); } + TopologyErrorHandler handler; + if(get_filter_empty_input_entity_blocks()) { + handler = [](stk::mesh::Part &part) { + std::ostringstream msg ; + msg << "\n\nERROR: Entity Block " << part.name() << " has invalid topology\n\n"; + throw std::runtime_error( msg.str() ); + }; + } else { + handler = [](stk::mesh::Part &part) { }; + } + process_nodeblocks(*region, meta_data()); - process_elementblocks(*region, meta_data()); + process_elementblocks(*region, meta_data(), handler); process_sidesets(*region, meta_data()); - process_face_blocks(*region, meta_data()); - process_edge_blocks(*region, meta_data()); + process_face_blocks(*region, meta_data(), handler); + process_edge_blocks(*region, meta_data(), handler); if(m_autoLoadDistributionFactorPerNodeSet) { process_nodesets(*region, meta_data()); @@ -666,8 +693,8 @@ bool StkMeshIoBroker::populate_mesh_elements_and_nodes(bool delay_field_data_all Ioss::Region *region = m_inputFiles[m_activeMeshIndex]->get_input_io_region().get(); bool ints64bit = db_api_int_size(region) == 8; bool processAllInputNodes = true; - if(region->property_exists(stk::io::s_process_all_input_nodes)) { - processAllInputNodes = region->get_property(stk::io::s_process_all_input_nodes).get_int(); + if(region->property_exists(stk::io::s_processAllInputNodes)) { + processAllInputNodes = region->get_property(stk::io::s_processAllInputNodes).get_int(); } if (ints64bit) { diff --git a/packages/stk/stk_io/stk_io/StkMeshIoBroker.hpp b/packages/stk/stk_io/stk_io/StkMeshIoBroker.hpp index 9c5803b9d6c6..9fef003c599c 100644 --- a/packages/stk/stk_io/stk_io/StkMeshIoBroker.hpp +++ b/packages/stk/stk_io/stk_io/StkMeshIoBroker.hpp @@ -165,6 +165,9 @@ namespace stk { void set_adaptivity_filter(size_t output_file_index, bool hasAdaptivity); void set_skin_mesh_flag(size_t output_file_index, bool skinMesh); + void set_filter_empty_output_entity_blocks(size_t output_file_index, const bool filterEmptyEntityBlocks); + void set_filter_empty_output_assembly_entity_blocks(size_t output_file_index, const bool filterEmptyAssemblyEntityBlocks); + stk::mesh::Selector get_active_selector() const; void set_active_selector(stk::mesh::Selector my_selector); @@ -212,6 +215,9 @@ namespace stk { m_autoLoadDistributionFactorPerNodeSet = shouldAutoLoad; } + bool get_filter_empty_input_entity_blocks() const; + bool get_filter_empty_input_entity_blocks(size_t input_file_index) const; + // Create the Ioss::DatabaseIO associated with the specified filename // and type (exodus by default). The routine checks that the // file exists and is readable and will throw an exception if not. @@ -850,6 +856,16 @@ namespace stk { m_outputFiles[output_file_index]->is_skin_mesh(skinMesh); } + inline void StkMeshIoBroker::set_filter_empty_output_entity_blocks(size_t output_file_index, const bool filterEmptyEntityBlocks) { + validate_output_file_index(output_file_index); + m_outputFiles[output_file_index]->set_filter_empty_entity_blocks(filterEmptyEntityBlocks); + } + + inline void StkMeshIoBroker::set_filter_empty_output_assembly_entity_blocks(size_t output_file_index, const bool filterEmptyAssemblyEntityBlocks) { + validate_output_file_index(output_file_index); + m_outputFiles[output_file_index]->set_filter_empty_assembly_entity_blocks(filterEmptyAssemblyEntityBlocks); + } + inline stk::mesh::Selector StkMeshIoBroker::get_active_selector() const { return m_activeSelector; } diff --git a/packages/stk/stk_mesh/stk_mesh/base/BulkData.cpp b/packages/stk/stk_mesh/stk_mesh/base/BulkData.cpp index d0f8c1964389..e93a52194082 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/BulkData.cpp +++ b/packages/stk/stk_mesh/stk_mesh/base/BulkData.cpp @@ -80,12 +80,13 @@ #include #include "stk_mesh/base/GetNgpMesh.hpp" #include -#include // for EntityRepository, etc +#include #include #include #include #include // for SideConnector #include +#include #include // for CommSparse #include #include // for Reduce, all_reduce, etc @@ -325,6 +326,7 @@ void BulkData::find_and_delete_internal_faces(stk::mesh::EntityRank entityRank, } +#ifndef STK_HIDE_DEPRECATED_CODE // Delete after August 2022 //---------------------------------------------------------------------- BulkData::BulkData(MetaData & mesh_meta_data, ParallelMachine parallel, @@ -342,19 +344,16 @@ BulkData::BulkData(MetaData & mesh_meta_data, m_auraGhosting(std::make_shared()), m_entity_comm_map(), m_ghosting(), - m_meta_raw_ptr_to_be_deprecated( &mesh_meta_data ), m_meta_data(&mesh_meta_data, [](MetaData* metadata) {}), m_mark_entity(), m_add_node_sharing_called(false), m_closure_count(), m_mesh_indexes(), - m_entity_repo(new impl::EntityRepository()), + m_entityKeyMapping(new impl::EntityKeyMapping()), m_entity_comm_list(), m_entitycomm(), m_owner(), - m_comm_list_updater(m_entity_comm_list, m_entitycomm), - m_deleted_entities_current_modification_cycle(), - m_ghost_reuse_map(), + m_comm_list_updater(m_entity_comm_list, m_entitycomm, m_removedGhosts), m_entity_keys(), #ifdef SIERRA_MIGRATION m_add_fmwk_data(add_fmwk_data), @@ -370,7 +369,6 @@ BulkData::BulkData(MetaData & mesh_meta_data, m_volatile_fast_shared_comm_map_sync_count(0), m_all_sharing_procs(mesh_meta_data.entity_rank_count()), m_ghost_parts(), - m_deleted_entities(), m_num_fields(-1), // meta data not necessarily committed yet m_keep_fields_updated(true), m_local_ids(), @@ -396,7 +394,7 @@ BulkData::BulkData(MetaData & mesh_meta_data, mesh_meta_data.set_mesh_bulk_data(this); } catch(...) { - delete m_entity_repo; + delete m_entityKeyMapping; throw; } @@ -422,6 +420,7 @@ BulkData::BulkData(MetaData & mesh_meta_data, m_meshModification.set_sync_state_synchronized(); } +#endif //---------------------------------------------------------------------- BulkData::BulkData(std::shared_ptr mesh_meta_data, @@ -442,19 +441,16 @@ BulkData::BulkData(std::shared_ptr mesh_meta_data, m_auraGhosting((auraGhosting!=nullptr ? auraGhosting : std::make_shared())), m_entity_comm_map(), m_ghosting(), - m_meta_raw_ptr_to_be_deprecated( mesh_meta_data.get() ), m_meta_data(mesh_meta_data), m_mark_entity(), m_add_node_sharing_called(false), m_closure_count(), m_mesh_indexes(), - m_entity_repo(new impl::EntityRepository()), + m_entityKeyMapping(new impl::EntityKeyMapping()), m_entity_comm_list(), m_entitycomm(), m_owner(), - m_comm_list_updater(m_entity_comm_list, m_entitycomm), - m_deleted_entities_current_modification_cycle(), - m_ghost_reuse_map(), + m_comm_list_updater(m_entity_comm_list, m_entitycomm, m_removedGhosts), m_entity_keys(), #ifdef SIERRA_MIGRATION m_add_fmwk_data(add_fmwk_data), @@ -470,7 +466,6 @@ BulkData::BulkData(std::shared_ptr mesh_meta_data, m_volatile_fast_shared_comm_map_sync_count(0), m_all_sharing_procs(mesh_meta_data->entity_rank_count()), m_ghost_parts(), - m_deleted_entities(), m_num_fields(-1), // meta data not necessarily committed yet m_keep_fields_updated(true), m_local_ids(), @@ -496,7 +491,7 @@ BulkData::BulkData(std::shared_ptr mesh_meta_data, mesh_meta_data->set_mesh_bulk_data(this); } catch(...) { - delete m_entity_repo; + delete m_entityKeyMapping; throw; } @@ -546,7 +541,7 @@ BulkData::~BulkData() mesh_meta_data().set_mesh_bulk_data(nullptr); delete m_elemElemGraph; - delete m_entity_repo; + delete m_entityKeyMapping; delete m_ngpMeshBase; } @@ -594,18 +589,7 @@ void BulkData::set_automatic_aura_option(AutomaticAuraOption auraOption, bool ap void BulkData::update_deleted_entities_container() { - while(!m_deleted_entities_current_modification_cycle.empty()) { - Entity::entity_value_type entity_offset = m_deleted_entities_current_modification_cycle.front(); - m_deleted_entities_current_modification_cycle.pop_front(); - m_deleted_entities.push_front(entity_offset); - } - - // Reclaim offsets for deleted ghosts that were not regenerated - for (auto keyAndOffset : m_ghost_reuse_map) { - m_deleted_entities.push_front(keyAndOffset.second); - } - - m_ghost_reuse_map.clear(); + m_meshModification.get_deleted_entity_cache().update_deleted_entities_container(); } //---------------------------------------------------------------------- @@ -712,9 +696,12 @@ Entity BulkData::generate_new_entity(unsigned preferred_offset) if (preferred_offset != 0) { new_local_offset = preferred_offset; } - else if (!m_deleted_entities.empty()) { - new_local_offset = m_deleted_entities.front(); - m_deleted_entities.pop_front(); + else { + Entity::entity_value_type local_offset = m_meshModification.get_deleted_entity_cache().get_entity_for_reuse(); + if (local_offset != Entity::InvalidEntity) + { + new_local_offset = local_offset; + } } MeshIndex mesh_index = {nullptr, 0}; @@ -985,23 +972,6 @@ template Entity BulkData::declare_element_side_with_id(co template Entity BulkData::declare_element_side_with_id(const stk::mesh::EntityId, Entity, const unsigned, const stk::mesh::ConstPartVector&); -//---------------------------------------------------------------------- - -namespace { - -// A method for quickly finding an entity within a comm list -const EntityCommListInfo& find_entity(const BulkData& mesh, - const EntityCommListInfoVector& entities, - const EntityKey& key) -{ - EntityCommListInfoVector::const_iterator lb_itr = std::lower_bound(entities.begin(), entities.end(), key); - ThrowAssertMsg(lb_itr != entities.end() && lb_itr->key == key, - "proc " << mesh.parallel_rank() << " Cannot find entity-key " << key << " in comm-list" ); - return *lb_itr; -} - -} - void BulkData::entity_comm_list_insert(Entity node) { stk::mesh::EntityKey key = entity_key(node); @@ -1225,7 +1195,7 @@ void BulkData::change_entity_id( EntityId id, Entity entity) void BulkData::internal_change_entity_key( EntityKey old_key, EntityKey new_key, Entity entity) { - m_entity_repo->update_entity_key(new_key, old_key, entity); + m_entityKeyMapping->update_entity_key(new_key, old_key, entity); set_entity_key(entity, new_key); m_bucket_repository.set_needs_to_be_sorted(this->bucket(entity), true); } @@ -1298,15 +1268,8 @@ bool BulkData::internal_destroy_entity(Entity entity, bool wasGhost) } } - // If this is a ghosted entity, store key->local_offset so that local_offset can be - // reused if the entity is recreated in the next aura-regen. This will prevent clients - // from having their handles to ghosted entities go invalid when the ghost is refreshed. - const stk::mesh::EntityKey key = entity_key(entity); - if ( ghost ) { - m_ghost_reuse_map[key] = entity.local_offset(); - } - // Need to invalidate Entity handles in comm-list + const stk::mesh::EntityKey key = entity_key(entity); stk::mesh::EntityCommListInfoVector::iterator lb_itr = std::lower_bound(m_entity_comm_list.begin(), m_entity_comm_list.end(), key); if (lb_itr != m_entity_comm_list.end() && lb_itr->key == key) { @@ -1317,24 +1280,21 @@ bool BulkData::internal_destroy_entity(Entity entity, bool wasGhost) m_bucket_repository.remove_entity(mesh_index(entity)); - record_entity_deletion(entity); + record_entity_deletion(entity, ghost); - if ( !ghost ) { - m_deleted_entities_current_modification_cycle.push_front(entity.local_offset()); - } m_check_invalid_rels = true; return true ; } -void BulkData::record_entity_deletion(Entity entity) +void BulkData::record_entity_deletion(Entity entity, bool isGhost) { const EntityKey key = entity_key(entity); set_mesh_index(entity, 0, 0); - m_entity_repo->destroy_entity(key, entity); + m_entityKeyMapping->destroy_entity(key, entity); notifier.notify_local_entities_created_or_deleted(key.rank()); notifier.notify_local_buckets_changed(key.rank()); - m_meshModification.mark_entity_as_deleted(entity.local_offset()); + m_meshModification.mark_entity_as_deleted(entity, isGhost); m_mark_entity[entity.local_offset()] = NOT_MARKED; m_closure_count[entity.local_offset()] = static_cast(0u); } @@ -1352,7 +1312,7 @@ size_t get_max_num_ids_needed_across_all_procs(const stk::mesh::BulkData& bulkDa std::vector BulkData::internal_get_ids_in_use(stk::topology::rank_t rank, const std::vector& reserved_ids) const { std::vector ids_in_use; - ids_in_use.reserve(m_entity_keys.size() + m_deleted_entities_current_modification_cycle.size()); + ids_in_use.reserve(m_entity_keys.size() + m_meshModification.get_deleted_entity_cache().get_deleted_entities_current_mod_cycle().size()); const BucketVector& bkts = this->buckets(rank); for (const Bucket* bptr : bkts) { @@ -1361,7 +1321,7 @@ std::vector BulkData::internal_get_ids_in_use(stk::topology::rank_t ra } } - for (Entity::entity_value_type local_offset : m_deleted_entities_current_modification_cycle) { + for (Entity::entity_value_type local_offset : m_meshModification.get_deleted_entity_cache().get_deleted_entities_current_mod_cycle()) { stk::mesh::Entity entity; entity.set_local_offset(local_offset); if ((entity_rank(entity) == rank) && (is_valid(entity) || state(entity)==Deleted)) { @@ -1415,7 +1375,7 @@ void BulkData::generate_new_ids(stk::topology::rank_t rank, size_t numIdsNeeded, if ( globalNumIdsRequested == 0 ) return; EntityId globalMaxId = impl::get_global_max_id_in_use(*this, rank, - m_deleted_entities_current_modification_cycle); + m_meshModification.get_deleted_entity_cache().get_deleted_entities_current_mod_cycle()); uint64_t maxAllowedId = get_max_allowed_id(); uint64_t availableIds = maxAllowedId - globalMaxId; @@ -1465,7 +1425,7 @@ std::pair BulkData::internal_create_entity(EntityKey key, size_t p { m_modSummary.track_declare_entity(key.rank(), key.id(), stk::mesh::PartVector()); - std::pair entityBoolPair = m_entity_repo->internal_create_entity(key); + std::pair entityBoolPair = m_entityKeyMapping->internal_create_entity(key); if(entityBoolPair.second) { @@ -1529,7 +1489,7 @@ void BulkData::declare_entities(stk::topology::rank_t rank, const IDVECTOR& newI m_modSummary.track_declare_entity(key.rank(), key.id(), stk::mesh::PartVector()); - std::pair entityBoolPair = m_entity_repo->internal_create_entity(key); + std::pair entityBoolPair = m_entityKeyMapping->internal_create_entity(key); ThrowErrorMsgIf( ! entityBoolPair.second, "Generated id " << key.id() << " of rank " << key.rank() << " which was already used."); @@ -1719,10 +1679,11 @@ bool BulkData::is_communicated_with_proc(Entity entity, int proc) const void BulkData::comm_procs(Entity entity, std::vector & procs ) const { - ThrowAssertMsg(is_valid(entity), - "BulkData::comm_procs ERROR, input entity "<comm_map.begin(),entityComm->comm_map.end()), procs); + } } void BulkData::comm_shared_procs(EntityKey key, std::vector & procs ) const @@ -2072,12 +2033,12 @@ void BulkData::update_field_data_states() const_entity_iterator BulkData::begin_entities(EntityRank ent_rank) const { - return m_entity_repo->begin_rank(ent_rank); + return m_entityKeyMapping->begin_rank(ent_rank); } const_entity_iterator BulkData::end_entities(EntityRank ent_rank) const { - return m_entity_repo->end_rank(ent_rank); + return m_entityKeyMapping->end_rank(ent_rank); } Entity BulkData::get_entity( EntityRank ent_rank , EntityId entity_id ) const @@ -2085,12 +2046,12 @@ Entity BulkData::get_entity( EntityRank ent_rank , EntityId entity_id ) const if (!impl::is_good_rank_and_id(mesh_meta_data(), ent_rank, entity_id)) { return Entity(); } - return m_entity_repo->get_entity( EntityKey(ent_rank, entity_id)); + return m_entityKeyMapping->get_entity( EntityKey(ent_rank, entity_id)); } Entity BulkData::get_entity( const EntityKey key ) const { - return m_entity_repo->get_entity(key); + return m_entityKeyMapping->get_entity(key); } void BulkData::reorder_buckets_callback(EntityRank rank, const std::vector& reorderedBucketIds) @@ -2760,7 +2721,7 @@ void BulkData::internal_change_entity_owner( const std::vector & arg // Compute the closure of all the locally changing entities for (const EntityProc& entityProc : local_change) { store_entity_proc_in_set.proc = entityProc.second; - impl::VisitClosureGeneral(*this,entityProc.first,store_entity_proc_in_set,store_entity_proc_in_set); + impl::VisitClosureGeneral(*this,entityProc.first,entity_rank(entityProc.first),store_entity_proc_in_set,store_entity_proc_in_set); } // Calculate all the ghosts that are impacted by the set of ownership @@ -3241,7 +3202,10 @@ void BulkData::internal_verify_inputs_and_change_ghosting( //---------------------------------------------------------------------- -void BulkData::ghost_entities_and_fields(Ghosting & ghosting, const std::set& sendGhosts, bool isFullRegen) +void BulkData::ghost_entities_and_fields(Ghosting & ghosting, + const std::set& sendGhosts, + bool isFullRegen, + const std::vector& removedSendGhosts) { //------------------------------------ // Push newly ghosted entities to the receivers and update the comm list. @@ -3261,6 +3225,8 @@ void BulkData::ghost_entities_and_fields(Ghosting & ghosting, const std::set( entity_rank(entity) ); + unsigned flag = 1; + buf.pack(flag); pack_entity_info(*this, buf , entity ); pack_field_values(*this, buf , entity ); @@ -3274,6 +3240,14 @@ void BulkData::ghost_entities_and_fields(Ghosting & ghosting, const std::set(entity_rank(ep.first)); + unsigned flag = 0; + buf.pack(flag); + buf.pack(entity_key(ep.first)); + } + if (phase == 0) { commSparse.allocate_buffers(); } @@ -3287,6 +3261,7 @@ void BulkData::ghost_entities_and_fields(Ghosting & ghosting, const std::set relations ; + std::vector removedRecvGhosts; const MetaData & meta = mesh_meta_data() ; const unsigned rank_count = meta.entity_rank_count(); @@ -3305,12 +3280,30 @@ void BulkData::ghost_entities_and_fields(Ghosting & ghosting, const std::set( this_rank ); + unsigned rankAndFlag[2] = {~0u,~0u}; + buf.peek( rankAndFlag, 2 ); - if ( this_rank != rank ) break ; + if ( rankAndFlag[1] == 1 && rankAndFlag[0] != rank ) break ; - buf.unpack( this_rank ); + if (rankAndFlag[1] == 0) { + while(buf.remaining()) { + buf.unpack( rankAndFlag[0] ); + buf.unpack( rankAndFlag[1] ); + + ThrowAssert(rankAndFlag[1] == 0); + EntityKey key; + buf.unpack(key); + Entity rmEnt = get_entity(key); + if (!is_valid(rmEnt)) { + continue; + } + removedRecvGhosts.push_back(EntityProc(rmEnt,p)); + } + break; + } + + buf.unpack( rankAndFlag[0] ); + buf.unpack( rankAndFlag[1] ); } parts.clear(); @@ -3326,10 +3319,11 @@ void BulkData::ghost_entities_and_fields(Ghosting & ghosting, const std::setsecond; + auto& ghost_reuse_map = m_meshModification.get_deleted_entity_cache().get_ghost_reuse_map(); + GhostReuseMap::iterator f_itr = ghost_reuse_map.find(key); + const size_t use_this_offset = f_itr == ghost_reuse_map.end() ? 0 : f_itr->second; if (use_this_offset != 0) { - m_ghost_reuse_map.erase(f_itr); + ghost_reuse_map.erase(f_itr); } std::pair result = internal_get_or_create_entity_with_notification( key, use_this_offset ); @@ -3391,6 +3385,35 @@ void BulkData::ghost_entities_and_fields(Ghosting & ghosting, const std::setkey == key) { + const int owner = parallel_owner_rank(itr->entity); + if (owner != parallel_rank()) { + if ( internal_entity_comm_map(itr->entity).empty() ) { + if (is_valid(itr->entity)) { + internal_destroy_entity_with_notification(itr->entity, true); + } + } + else { + internal_change_entity_parts(itr->entity, addParts, removeParts, scratchOrdinalVec, scratchSpace); + } + } + } + } + } + delete_unneeded_entries_from_the_comm_list(); } void BulkData::conditionally_add_entity_to_ghosting_set(const stk::mesh::Ghosting &ghosting, stk::mesh::Entity entity, int toProc, std::set &entitiesWithClosure) @@ -3498,8 +3521,8 @@ void BulkData::filter_ghosting_remove_receives(const stk::mesh::Ghosting &ghosti Entity const *rels_e = end(e, irank); for (; rels_i != rels_e; ++rels_i) { - if (erank > stk::topology::ELEM_RANK) { - impl::VisitClosureGeneral(*this, *rels_i, vpb, org); + if (irank > stk::topology::ELEM_RANK) { + impl::VisitClosureGeneral(*this, *rels_i, irank, vpb, org); } else { if ( is_valid(*rels_i) && @@ -3545,7 +3568,7 @@ void BulkData::internal_change_ghosting( for ( const EntityProc& entityProc : add_send ) { og.proc = entityProc.second; sieps.proc = entityProc.second; - impl::VisitClosureGeneral(*this,entityProc.first,sieps,og); + impl::VisitClosureGeneral(*this,entityProc.first,entity_rank(entityProc.first),sieps,og); } //remove newSendGhosts that are already in comm-list: @@ -3666,98 +3689,6 @@ int BulkData::determine_new_owner( Entity entity ) const return new_owner ; } -bool BulkData::pack_entity_modification( const bool packShared , stk::CommSparse & comm ) -{ - bool flag = false; - bool packGhosted = packShared == false; - - const EntityCommListInfoVector & entityCommList = this->internal_comm_list(); - - for ( EntityCommListInfoVector::const_iterator - i = entityCommList.begin() ; i != entityCommList.end() ; ++i ) { - if (i->entity_comm != nullptr) { - Entity entity = i->entity; - EntityState status = this->is_valid(entity) ? this->state(entity) : Deleted; - - if ( status == Modified || status == Deleted ) { - int owned_closure_int = owned_closure(entity) ? 1 : 0; - - for ( PairIterEntityComm ec(i->entity_comm->comm_map); ! ec.empty() ; ++ec ) - { - if ( ( packGhosted && ec->ghost_id > BulkData::SHARED ) || ( packShared && ec->ghost_id == BulkData::SHARED ) ) - { - comm.send_buffer( ec->proc ) - .pack( i->key ) - .pack( status ) - .pack(owned_closure_int); - - const bool promotingGhostToShared = - packGhosted && owned_closure_int==1 && !bucket(entity).owned(); - if (promotingGhostToShared) { - comm.send_buffer(parallel_rank()) - .pack( i->key ) - .pack( status ) - .pack(owned_closure_int); - } - - flag = true ; - } - } - } - } - } - - return flag ; -} - -void BulkData::communicate_entity_modification( const bool shared , std::vector & data ) -{ - stk::CommSparse comm( this->parallel() ); - const int p_size = comm.parallel_size(); - - // Sizing send buffers: - pack_entity_modification(shared , comm); - - const bool needToSendOrRecv = comm.allocate_buffers(); - if ( needToSendOrRecv ) - { - - // Packing send buffers: - pack_entity_modification(shared , comm); - - comm.communicate(); - - const EntityCommListInfoVector & entityCommList = this->internal_comm_list(); - for ( int procNumber = 0 ; procNumber < p_size ; ++procNumber ) { - CommBuffer & buf = comm.recv_buffer( procNumber ); - EntityKey key; - EntityState state; - int remote_owned_closure_int; - bool remote_owned_closure; - - while ( buf.remaining() ) { - - buf.unpack( key ) - .unpack( state ) - .unpack( remote_owned_closure_int); - remote_owned_closure = ((remote_owned_closure_int==1)?true:false); - - // search through entity_comm, should only receive info on entities - // that are communicated. - EntityCommListInfo info = find_entity(*this, entityCommList, key); - int remoteProc = procNumber; - if (!shared && remoteProc == parallel_rank()) { - remoteProc = parallel_owner_rank(info.entity); - } - EntityParallelState parallel_state = {remoteProc, state, info, remote_owned_closure, this}; - data.push_back( parallel_state ); - } - } - } - - std::sort( data.begin() , data.end() ); -} - //---------------------------------------------------------------------- //---------------------------------------------------------------------- @@ -3932,34 +3863,6 @@ void BulkData::internal_update_sharing_comm_map_and_fill_list_modified_shared_en std::fill(m_mark_entity.begin(), m_mark_entity.end(), BulkData::NOT_MARKED); } - - -//---------------------------------------------------------------------- - -void BulkData::internal_establish_new_owner(stk::mesh::Entity entity) -{ - const int new_owner = determine_new_owner(entity); - - internal_set_owner(entity, new_owner); -} - -void BulkData::internal_update_parts_for_shared_entity(stk::mesh::Entity entity, const bool is_entity_shared, const bool did_i_just_become_owner) -{ - OrdinalVector parts_to_add_entity_to , parts_to_remove_entity_from, scratchOrdinalVec, scratchSpace; - - if ( !is_entity_shared ) { - parts_to_remove_entity_from.push_back(mesh_meta_data().globally_shared_part().mesh_meta_data_ordinal()); - } - - if ( did_i_just_become_owner ) { - parts_to_add_entity_to.push_back(mesh_meta_data().locally_owned_part().mesh_meta_data_ordinal()); - } - - if ( ! parts_to_add_entity_to.empty() || ! parts_to_remove_entity_from.empty() ) { - internal_change_entity_parts( entity , parts_to_add_entity_to , parts_to_remove_entity_from, scratchOrdinalVec, scratchSpace ); - } -} - void BulkData::filter_upward_ghost_relations(const Entity entity, std::function filter) { EntityRank rank = entity_rank(entity); @@ -3999,182 +3902,6 @@ EntityVector BulkData::get_upward_send_ghost_relations(const Entity entity) return ghosts; } -void BulkData::add_entity_to_same_ghosting(Entity entity, Entity connectedGhost) { - for(PairIterEntityComm ec(internal_entity_comm_map(connectedGhost)); ! ec.empty(); ++ec) { - if (ec->ghost_id > BulkData::AURA) { - entity_comm_map_insert(entity, EntityCommInfo(ec->ghost_id, ec->proc)); - entity_comm_list_insert(entity); - } - } -} - -void BulkData::internal_resolve_formerly_shared_entities(const EntityVector& entitiesNoLongerShared) -{ - for(Entity entity : entitiesNoLongerShared) { - EntityVector ghostRelations = get_upward_send_ghost_relations(entity); - - for(Entity ghost : ghostRelations) { - add_entity_to_same_ghosting(entity, ghost); - } - } -} - -//---------------------------------------------------------------------- -// Resolve modifications for ghosted entities: -// If a ghosted entity is modified or destroyed on the owning -// process then the ghosted entity must be destroyed. -// -// Post condition: -// Ghosted entities of modified or deleted entities are destroyed. -// Ghosted communication lists are cleared to reflect all deletions. - -void BulkData::internal_resolve_ghosted_modify_delete(const stk::mesh::EntityVector& entitiesNoLongerShared) -{ - ThrowRequireMsg(parallel_size() > 1, "Do not call this in serial"); - // Resolve modifications for ghosted entities: - - std::vector remotely_modified_ghosted_entities ; - internal_resolve_formerly_shared_entities(entitiesNoLongerShared); - - // Communicate entity modification state for ghost entities - const bool communicate_shared = false ; - communicate_entity_modification( communicate_shared , remotely_modified_ghosted_entities ); - - const size_t ghosting_count = m_ghosting.size(); - const size_t ghosting_count_minus_shared = ghosting_count - 1; - - std::vector promotingToShared; - - // We iterate backwards over remote_mod to ensure that we hit the - // higher-ranking entities first. This is important because higher-ranking - // entities like element must be deleted before the nodes they have are - // deleted. - for ( std::vector::reverse_iterator - i = remotely_modified_ghosted_entities.rbegin(); i != remotely_modified_ghosted_entities.rend() ; ++i ) - { - Entity entity = i->comm_info.entity; - const EntityKey key = i->comm_info.key; - const int remote_proc = i->from_proc; - const bool local_owner = parallel_owner_rank(entity) == parallel_rank() ; - const bool remotely_destroyed = Deleted == i->state ; - const bool remote_proc_is_owner = remote_proc == parallel_owner_rank(entity); - const bool isAlreadyDestroyed = !is_valid(entity); - - if ( local_owner ) { // Sending to 'remote_proc' for ghosting - - if ( remotely_destroyed ) { - - // remove from ghost-send list - - for ( size_t j = ghosting_count_minus_shared ; j>=1 ; --j) { - entity_comm_map_erase( key, EntityCommInfo( j , remote_proc ) ); - } - } - else { - if (!in_ghost(aura_ghosting(), entity) && state(entity)==Unchanged) { - set_state(entity, Modified); - } - - const bool shouldPromoteToShared = !isAlreadyDestroyed && i->remote_owned_closure==1 && key.rank() < stk::topology::ELEM_RANK; - if (shouldPromoteToShared) { - entity_comm_map_insert(entity, EntityCommInfo(SHARED, remote_proc)); - promotingToShared.push_back(entity); - } - } - } - else if (remote_proc_is_owner) { // Receiving from 'remote_proc' for ghosting - - const bool hasBeenPromotedToSharedOrOwned = this->owned_closure(entity); - bool isAuraGhost = false; - bool isCustomGhost = false; - PairIterEntityComm pairIterEntityComm = internal_entity_comm_map(entity); - if(pairIterEntityComm.empty()) { - if(std::binary_search(entitiesNoLongerShared.begin(), entitiesNoLongerShared.end(), entity)) { - EntityVector ghosts = get_upward_recv_ghost_relations(entity); - - for(Entity ghost : ghosts) { - add_entity_to_same_ghosting(entity, ghost); - } - } - } else { - for(unsigned j=0; j AURA) - { - isCustomGhost = true; - } - } - } - - if ( isAuraGhost ) { - if (!isAlreadyDestroyed && hasBeenPromotedToSharedOrOwned) { - entity_comm_map_insert(entity, EntityCommInfo(SHARED, remote_proc)); - promotingToShared.push_back(entity); - } - entity_comm_map_erase(key, aura_ghosting()); - } - - if(!isAlreadyDestroyed) - { - const bool wasDestroyedByOwner = remotely_destroyed; - const bool shouldDestroyGhost = wasDestroyedByOwner || (isAuraGhost && !isCustomGhost && !hasBeenPromotedToSharedOrOwned); - const bool shouldRemoveFromGhosting = remotely_destroyed && !isAuraGhost && hasBeenPromotedToSharedOrOwned; - - if (shouldRemoveFromGhosting) { - for ( size_t j = ghosting_count_minus_shared ; j >=1 ; --j ) { - entity_comm_map_erase( key, *m_ghosting[j] ); - } - } - - if ( shouldDestroyGhost ) - { - const bool was_ghost = true; - internal_destroy_entity_with_notification(entity, was_ghost); - } - - entity_comm_list_insert(entity); - } - } - } // end loop on remote mod - - // Erase all ghosting communication lists for: - // 1) Destroyed entities. - // 2) Owned and modified entities. - - for ( EntityCommListInfoVector::const_reverse_iterator - i = internal_comm_list().rbegin() ; i != internal_comm_list().rend() ; ++i) { - - Entity entity = i->entity; - - const bool locally_destroyed = !is_valid(entity); - const bool locally_owned_and_modified = locally_destroyed ? false : - (Modified == state(entity) && (parallel_rank() == parallel_owner_rank(entity))); - - if ( locally_destroyed ) { - for ( size_t j = ghosting_count_minus_shared ; j >=1 ; --j ) { - entity_comm_map_erase( i->key, *m_ghosting[j] ); - } - } - else if ( locally_owned_and_modified ) { - entity_comm_map_erase( i->key, aura_ghosting() ); - } - } - - if (!promotingToShared.empty()) { - OrdinalVector sharedPart, auraPart, scratchOrdinalVec, scratchSpace; - sharedPart.push_back(mesh_meta_data().globally_shared_part().mesh_meta_data_ordinal()); - auraPart.push_back(mesh_meta_data().aura_part().mesh_meta_data_ordinal()); - for(Entity entity : promotingToShared) { - internal_change_entity_parts(entity, sharedPart /*add*/, auraPart /*remove*/, scratchOrdinalVec, scratchSpace); - } - add_comm_list_entries_for_entities(promotingToShared); - } -} - void BulkData::resolve_ownership_of_modified_entities( const std::vector &shared_modified ) { const BulkData& bulk = *this; @@ -4985,7 +4712,7 @@ void BulkData::internal_send_part_memberships_from_owner(const std::vectorbucket(entity).supersets(); @@ -5811,71 +5538,7 @@ bool BulkData::comm_mesh_verify_parallel_consistency(std::ostream & error_log ) return verified_ok == 1 ; } -// Enforce that shared entities must be in the owned closure: - -void BulkData::destroy_dependent_ghosts( Entity entity, EntityProcVec& entitiesToRemoveFromSharing ) -{ - EntityRank entity_rank = this->entity_rank(entity); - const EntityRank end_rank = static_cast(this->mesh_meta_data().entity_rank_count()); - - for (EntityRank irank = static_cast(end_rank - 1); irank > entity_rank; --irank) - { - int num_rels = this->num_connectivity(entity, irank); - const Entity* rels = this->begin(entity, irank); - - for (int r = num_rels - 1; r >= 0; --r) - { - Entity e = rels[r]; - - bool upwardRelationOfEntityIsInClosure = this->owned_closure(e); - ThrowRequireMsg( !upwardRelationOfEntityIsInClosure, this->entity_rank(e) << " with id " << this->identifier(e) << " should not be in closure." ); - - // Recursion - if (this->is_valid(e) && this->bucket(e).in_aura()) - { - this->destroy_dependent_ghosts( e, entitiesToRemoveFromSharing ); - } - } - } - - const bool successfully_destroyed_entity = this->destroy_entity(entity); - if (!successfully_destroyed_entity) - { - std::vector sharing_procs; - comm_shared_procs(entity_key(entity), sharing_procs); - for(int p : sharing_procs) { - entitiesToRemoveFromSharing.emplace_back(entity, p); - } - } -} - -// Entities with sharing information that are not in the owned closure -// have been modified such that they are no longer shared. -// These may no longer be needed or may become ghost entities. -// There is not enough information so assume they are to be deleted -// and let these entities be re-ghosted if they are needed. - -// Open question: Should an owned and shared entity that does not -// have an upward relation to an owned entity be destroyed so that -// ownership transfers to another process? - -void BulkData::delete_shared_entities_which_are_no_longer_in_owned_closure(EntityProcVec& entitiesToRemoveFromSharing) -{ - for ( EntityCommListInfoVector::const_reverse_iterator - i = internal_comm_list().rbegin() ; - i != internal_comm_list().rend() ; ++i) - { - Entity entity = i->entity; - if (is_valid(entity) && !owned_closure(entity)) { - if ( in_shared(entity) ) - { - destroy_dependent_ghosts( entity, entitiesToRemoveFromSharing ); - } - } - } -} - -void BulkData::remove_entities_from_sharing(const EntityProcVec& entitiesToRemoveFromSharing, stk::mesh::EntityVector & entitiesNoLongerShared) +void BulkData::remove_entities_from_sharing(const EntityProcVec& entitiesToRemoveFromSharing, EntityVector & entitiesNoLongerShared) { entitiesNoLongerShared.clear(); OrdinalVector scratchOrdinalVec, scratchSpace; @@ -5886,7 +5549,7 @@ void BulkData::remove_entities_from_sharing(const EntityProcVec& entitiesToRemov entitiesNoLongerShared.push_back(entityAndProc.first); this->internal_change_entity_parts(entityAndProc.first,{},{this->mesh_meta_data().globally_shared_part().mesh_meta_data_ordinal()}, scratchOrdinalVec, scratchSpace); this->internal_mark_entity(entityAndProc.first, NOT_SHARED); - } + } } stk::util::sort_and_unique(entitiesNoLongerShared); } @@ -6407,7 +6070,9 @@ void BulkData::mark_entities_as_deleted(stk::mesh::Bucket * bucket) for(Entity e : *bucket) { notifier.notify_entity_deleted(e); - record_entity_deletion(e); + record_entity_deletion(e, false); // the only other user of record_entity_deletion adds the + // entity to the m_deleted_entities_current_modification_cycle if + // it is not a ghost. Not sure why this doesn't. } } diff --git a/packages/stk/stk_mesh/stk_mesh/base/BulkData.hpp b/packages/stk/stk_mesh/stk_mesh/base/BulkData.hpp index e03ffb431f12..7f648452c677 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/BulkData.hpp +++ b/packages/stk/stk_mesh/stk_mesh/base/BulkData.hpp @@ -80,7 +80,7 @@ namespace stk { namespace mesh { class MetaData; } } namespace stk { namespace mesh { class Part; } } namespace stk { namespace mesh { class BulkData; } } namespace stk { namespace mesh { namespace impl { class AuraGhosting; } } } -namespace stk { namespace mesh { namespace impl { class EntityRepository; } } } +namespace stk { namespace mesh { namespace impl { class EntityKeyMapping; } } } namespace stk { namespace mesh { class FaceCreator; } } namespace stk { namespace mesh { class ElemElemGraph; } } namespace stk { namespace mesh { class ElemElemGraphUpdater; } } @@ -145,7 +145,6 @@ stk::mesh::Entity connect_side_to_element(stk::mesh::BulkData& bulkData, stk::me stk::mesh::Permutation side_permutation, const stk::mesh::PartVector& parts); } -typedef std::unordered_map GhostReuseMap; struct sharing_info { @@ -175,6 +174,7 @@ class BulkData { enum EntitySharing : char { NOT_MARKED=0, POSSIBLY_SHARED=1, IS_SHARED=2, NOT_SHARED }; enum AutomaticAuraOption { NO_AUTO_AURA, AUTO_AURA }; +#ifndef STK_HIDE_DEPRECATED_CODE // Delete after August 2022 /** \brief Construct mesh bulk data manager conformal to the given * \ref stk::mesh::MetaData "meta data manager" and will * distribute bulk data over the given parallel machine. @@ -182,7 +182,7 @@ class BulkData { * - The maximum number of entities per bucket may be supplied. * - The bulk data is in the synchronized or "locked" state. */ - BulkData( MetaData & mesh_meta_data + STK_DEPRECATED BulkData( MetaData & mesh_meta_data , ParallelMachine parallel , enum AutomaticAuraOption auto_aura_option = AUTO_AURA #ifdef SIERRA_MIGRATION @@ -191,13 +191,14 @@ class BulkData { , FieldDataManager *field_dataManager = nullptr , unsigned bucket_capacity = impl::BucketRepository::default_bucket_capacity ); +#endif virtual ~BulkData(); //------------------------------------ /** \brief The meta data manager for this bulk data manager. */ - const MetaData & mesh_meta_data() const { return *m_meta_raw_ptr_to_be_deprecated ; } - MetaData & mesh_meta_data() { return *m_meta_raw_ptr_to_be_deprecated ; } + const MetaData & mesh_meta_data() const { return *m_meta_data ; } + MetaData & mesh_meta_data() { return *m_meta_data ; } std::shared_ptr mesh_meta_data_ptr() {return m_meta_data; } const std::shared_ptr mesh_meta_data_ptr() const { return m_meta_data; } @@ -945,7 +946,8 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu void ghost_entities_and_fields(Ghosting & ghosting, const std::set& new_send, - bool isFullRegen = false); + bool isFullRegen = false, + const std::vector& removedSendGhosts = std::vector()); void conditionally_add_entity_to_ghosting_set(const stk::mesh::Ghosting &ghosting, stk::mesh::Entity entity, @@ -975,7 +977,7 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu PairIterEntityComm internal_entity_comm_map(Entity entity, const Ghosting & sub ) const { if (m_entitycomm[entity.local_offset()] != nullptr) { - return ghost_info_range(m_entitycomm[entity.local_offset()]->comm_map, sub); + return ghost_info_range(m_entitycomm[entity.local_offset()]->comm_map, sub.ordinal()); } return PairIterEntityComm(); } @@ -1103,11 +1105,8 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu void filter_upward_ghost_relations(const Entity entity, std::function filter); EntityVector get_upward_send_ghost_relations(const Entity entity); EntityVector get_upward_recv_ghost_relations(const Entity entity); - void add_entity_to_same_ghosting(Entity entity, Entity connectedGhost); void update_comm_list_based_on_changes_in_comm_map(); - void internal_resolve_formerly_shared_entities(const stk::mesh::EntityVector& entitiesNoLongerShared); - void internal_resolve_ghosted_modify_delete(const stk::mesh::EntityVector& entitiesNoLongerShared); void internal_resolve_shared_part_membership_for_element_death(); // Mod Mark void remove_unneeded_induced_parts(stk::mesh::Entity entity, const EntityCommInfoVector& entity_comm_info, @@ -1230,8 +1229,7 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu void check_mesh_consistency(); bool comm_mesh_verify_parallel_consistency(std::ostream & error_log); - void delete_shared_entities_which_are_no_longer_in_owned_closure(EntityProcVec& entitiesToRemoveFromSharing); // Mod Mark - virtual void remove_entities_from_sharing(const EntityProcVec& entitiesToRemoveFromSharing, stk::mesh::EntityVector & entitiesNoLongerShared); + virtual void remove_entities_from_sharing(const EntityProcVec& entitiesToRemoveFromSharing, EntityVector & entitiesNoLongerShared); virtual void check_if_entity_from_other_proc_exists_on_this_proc_and_update_info_if_shared(std::vector& shared_entity_map, int proc_id, const shared_entity_type &sentity); void update_owner_global_key_and_sharing_proc(stk::mesh::EntityKey global_key_other_proc, shared_entity_type& shared_entity_this_proc, int proc_id) const; void update_shared_entity_this_proc(EntityKey global_key_other_proc, shared_entity_type& shared_entity_this_proc, int proc_id); @@ -1289,7 +1287,7 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu void set_ngp_mesh(NgpMeshBase * ngpMesh) const { m_ngpMeshBase = ngpMesh; } NgpMeshBase * get_ngp_mesh() const { return m_ngpMeshBase; } - void record_entity_deletion(Entity entity); + void record_entity_deletion(Entity entity, bool isGhost); void break_boundary_relations_and_delete_buckets(const std::vector & relationsToDestroy, const stk::mesh::BucketVector & bucketsToDelete); void delete_buckets(const stk::mesh::BucketVector & buckets); void mark_entities_as_deleted(stk::mesh::Bucket * bucket); @@ -1311,26 +1309,6 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu void internal_resolve_sharing_and_ghosting_for_sides(bool connectFacesToPreexistingGhosts); -#ifdef __CUDACC__ -public: -#endif - struct EntityParallelState { - int from_proc; - EntityState state; - EntityCommListInfo comm_info; - bool remote_owned_closure; - const BulkData* mesh; - - bool operator<(const EntityParallelState& rhs) const - { return EntityLess(*mesh)(comm_info.entity, rhs.comm_info.entity); } - }; -#ifdef __CUDACC__ -private: -#endif - - void communicate_entity_modification( const bool shared , std::vector & data ); // Mod Mark - bool pack_entity_modification( const bool packShared , stk::CommSparse & comm ); - virtual bool does_entity_need_orphan_protection(stk::mesh::Entity entity) const { const bool isNode = (stk::topology::NODE_RANK == entity_rank(entity)); @@ -1450,9 +1428,6 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu ModEndOptimizationFlag opt ); // Mod Mark - void internal_establish_new_owner(stk::mesh::Entity entity); - void internal_update_parts_for_shared_entity(stk::mesh::Entity entity, const bool is_entity_shared, const bool did_i_just_become_owner); - inline void internal_check_unpopulated_relations(Entity entity, EntityRank rank) const; void internal_adjust_closure_count(Entity entity, @@ -1536,8 +1511,6 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu void reset_add_node_sharing() { m_add_node_sharing_called = false; } - void destroy_dependent_ghosts( Entity entity, EntityProcVec& entitiesToRemoveFromSharing ); - template Entity create_and_connect_side(const stk::mesh::EntityId globalSideId, Entity elem, @@ -1558,19 +1531,17 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu static const uint16_t orphaned_node_marking; EntityCommDatabase m_entity_comm_map; std::vector m_ghosting; - MetaData *m_meta_raw_ptr_to_be_deprecated; std::shared_ptr m_meta_data; std::vector m_mark_entity; //indexed by Entity bool m_add_node_sharing_called; std::vector m_closure_count; //indexed by Entity std::vector m_mesh_indexes; //indexed by Entity - impl::EntityRepository* m_entity_repo; + impl::EntityKeyMapping* m_entityKeyMapping; EntityCommListInfoVector m_entity_comm_list; std::vector m_entitycomm; std::vector m_owner; + std::vector> m_removedGhosts; CommListUpdater m_comm_list_updater; - std::list m_deleted_entities_current_modification_cycle; - GhostReuseMap m_ghost_reuse_map; std::vector m_entity_keys; //indexed by Entity #ifdef SIERRA_MIGRATION @@ -1603,7 +1574,6 @@ void get_entities(EntityRank rank, Selector const& selector, EntityVector& outpu mutable unsigned m_volatile_fast_shared_comm_map_sync_count; std::vector > m_all_sharing_procs; PartVector m_ghost_parts; - std::list m_deleted_entities; int m_num_fields; bool m_keep_fields_updated; std::vector m_local_ids; //indexed by Entity diff --git a/packages/stk/stk_mesh/stk_mesh/base/CommListUpdater.hpp b/packages/stk/stk_mesh/stk_mesh/base/CommListUpdater.hpp index 08e334714926..05dfe6ce99f5 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/CommListUpdater.hpp +++ b/packages/stk/stk_mesh/stk_mesh/base/CommListUpdater.hpp @@ -44,12 +44,19 @@ namespace mesh { class CommListUpdater : public CommMapChangeListener { public: CommListUpdater(EntityCommListInfoVector& comm_list, - std::vector& entity_comms) - : m_comm_list(comm_list), m_entity_comms(entity_comms) + std::vector& entity_comms, + std::vector>& removedGhosts) + : m_comm_list(comm_list), + m_entity_comms(entity_comms), + m_removedGhosts(removedGhosts) {} virtual ~CommListUpdater(){} - virtual void removedKey(const EntityKey& key) { + void removedGhost(const EntityKey& key, unsigned ghostId, int proc) override { + m_removedGhosts.emplace_back(key, EntityCommInfo(ghostId, proc)); + } + + void removedKey(const EntityKey& key) override { EntityCommListInfoVector::iterator iter = std::lower_bound(m_comm_list.begin(), m_comm_list.end(), key); if (iter != m_comm_list.end() && iter->key == key) { @@ -61,6 +68,7 @@ class CommListUpdater : public CommMapChangeListener { private: EntityCommListInfoVector& m_comm_list; std::vector& m_entity_comms; + std::vector>& m_removedGhosts; }; } diff --git a/packages/stk/stk_mesh/stk_mesh/base/DeviceField.hpp b/packages/stk/stk_mesh/stk_mesh/base/DeviceField.hpp index 07368d21f3ba..0d077012ad37 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/DeviceField.hpp +++ b/packages/stk/stk_mesh/stk_mesh/base/DeviceField.hpp @@ -474,12 +474,8 @@ class DeviceField : public NgpFieldBase newDeviceSelectedBucketOffset = UnsignedViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, hostField->name() + "_bucket_offset"), allBuckets.size()); -#ifndef NEW_TRILINOS_INTEGRATION - newHostSelectedBucketOffset = Kokkos::create_mirror_view(Kokkos::HostSpace(), newDeviceSelectedBucketOffset, Kokkos::WithoutInitializing); -#else newHostSelectedBucketOffset = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), newDeviceSelectedBucketOffset); -#endif for(unsigned i = 0; i < allBuckets.size(); i++) { if(selector(*allBuckets[i])) { diff --git a/packages/stk/stk_mesh/stk_mesh/base/DeviceMesh.cpp b/packages/stk/stk_mesh/stk_mesh/base/DeviceMesh.cpp index 30bd8e5479ce..328749115bf2 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/DeviceMesh.cpp +++ b/packages/stk/stk_mesh/stk_mesh/base/DeviceMesh.cpp @@ -48,11 +48,7 @@ void DeviceBucket::initialize_bucket_attributes(const stk::mesh::Bucket &bucket) void DeviceBucket::allocate(const stk::mesh::Bucket &bucket) { nodeOffsets = OrdinalViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "NodeOffsets"), bucket.size()+1); -#ifndef NEW_TRILINOS_INTEGRATION - hostNodeOffsets = Kokkos::create_mirror_view(Kokkos::HostSpace(), nodeOffsets, Kokkos::WithoutInitializing); -#else hostNodeOffsets = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), nodeOffsets); -#endif unsigned maxNodesPerEntity = bucketTopology.num_nodes(); unsigned totalNumNodes = bucketTopology.num_nodes()*bucketCapacity; @@ -68,33 +64,17 @@ void DeviceBucket::allocate(const stk::mesh::Bucket &bucket) const stk::mesh::PartVector& parts = bucket.supersets(); entities = EntityViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "BucketEntities"), bucketCapacity); -#ifndef NEW_TRILINOS_INTEGRATION - hostEntities = Kokkos::create_mirror_view(Kokkos::HostSpace(), entities, Kokkos::WithoutInitializing); -#else hostEntities = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), entities); -#endif nodeConnectivity = BucketConnectivityType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "BucketConnectivity"), totalNumNodes); -#ifndef NEW_TRILINOS_INTEGRATION - hostNodeConnectivity = Kokkos::create_mirror_view(Kokkos::HostSpace(), nodeConnectivity, Kokkos::WithoutInitializing); -#else hostNodeConnectivity = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), nodeConnectivity); -#endif nodeOrdinals = OrdinalViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "NodeOrdinals"), static_cast(maxNodesPerEntity)); -#ifndef NEW_TRILINOS_INTEGRATION - hostNodeOrdinals = Kokkos::create_mirror_view(Kokkos::HostSpace(), nodeOrdinals, Kokkos::WithoutInitializing); -#else hostNodeOrdinals = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), nodeOrdinals); -#endif partOrdinals = PartOrdinalViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "PartOrdinals"), parts.size()); -#ifndef NEW_TRILINOS_INTEGRATION - hostPartOrdinals = Kokkos::create_mirror_view(Kokkos::HostSpace(), partOrdinals, Kokkos::WithoutInitializing); -#else hostPartOrdinals = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), partOrdinals); -#endif } void DeviceBucket::initialize_from_host(const stk::mesh::Bucket &bucket) @@ -127,11 +107,7 @@ void DeviceBucket::update_from_host(const stk::mesh::Bucket &bucket) if (bucketSize+1 != hostNodeOffsets.size()) { nodeOffsets = OrdinalViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "NodeOffsets"), bucketSize+1); -#ifndef NEW_TRILINOS_INTEGRATION - hostNodeOffsets = Kokkos::create_mirror_view(Kokkos::HostSpace(), nodeOffsets, Kokkos::WithoutInitializing); -#else hostNodeOffsets = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), nodeOffsets); -#endif } unsigned totalNumNodes = bucket.topology().num_nodes()*bucketCapacity; @@ -147,23 +123,14 @@ void DeviceBucket::update_from_host(const stk::mesh::Bucket &bucket) if (totalNumNodes != hostNodeConnectivity.size()) { nodeConnectivity = BucketConnectivityType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "BucketConnectivity"), totalNumNodes); -#ifndef NEW_TRILINOS_INTEGRATION - hostNodeConnectivity = - Kokkos::create_mirror_view(Kokkos::HostSpace(), nodeConnectivity, Kokkos::WithoutInitializing); -#else hostNodeConnectivity = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), nodeConnectivity); -#endif } if (maxNodesPerEntity != hostNodeOrdinals.size()) { nodeOrdinals = OrdinalViewType(Kokkos::view_alloc(Kokkos::WithoutInitializing, "NodeOrdinals"), static_cast(maxNodesPerEntity)); -#ifndef NEW_TRILINOS_INTEGRATION - hostNodeOrdinals = Kokkos::create_mirror_view(Kokkos::HostSpace(), nodeOrdinals, Kokkos::WithoutInitializing); -#else hostNodeOrdinals = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), nodeOrdinals); -#endif for (unsigned i = 0; i < maxNodesPerEntity; ++i) { hostNodeOrdinals(i) = static_cast(i); } @@ -277,11 +244,7 @@ inline void reallocate_views(DEVICE_VIEW & deviceView, HOST_VIEW & hostView, siz if (needGrowth || needShrink) { const size_t newSize = requiredSize + static_cast(resizeFactor*requiredSize); deviceView = DEVICE_VIEW(Kokkos::view_alloc(Kokkos::WithoutInitializing, deviceView.label()), newSize); -#ifndef NEW_TRILINOS_INTEGRATION - hostView = Kokkos::create_mirror_view(Kokkos::HostSpace(), deviceView, Kokkos::WithoutInitializing); -#else hostView = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, Kokkos::HostSpace(), deviceView); -#endif } } diff --git a/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.cpp b/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.cpp index 418540387ef3..deee10010760 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.cpp +++ b/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.cpp @@ -472,6 +472,9 @@ bool EntityCommDatabase::erase( const EntityKey & key, const EntityCommInfo & va const bool result = ( (i != comm_map.end()) && (val == *i) ) ; if ( result ) { + if (m_comm_map_change_listener != nullptr) { + m_comm_map_change_listener->removedGhost(key, i->ghost_id, i->proc); + } comm_map.erase( i ); bool deleted = false; if (comm_map.empty()) { @@ -515,6 +518,12 @@ bool EntityCommDatabase::erase( const EntityKey & key, const Ghosting & ghost ) const bool result = i != e ; if ( result ) { + if (m_comm_map_change_listener != nullptr) { + for(EntityCommInfoVector::iterator it = i; it != e; ++it) { + m_comm_map_change_listener->removedGhost(key, it->ghost_id, it->proc); + } + } + comm_map.erase( i , e ); bool deleted = false; if (comm_map.empty()) { diff --git a/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.hpp b/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.hpp index 0c7befed3007..25ca1d63fa84 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.hpp +++ b/packages/stk/stk_mesh/stk_mesh/base/EntityCommDatabase.hpp @@ -62,6 +62,7 @@ namespace mesh { class CommMapChangeListener { public: virtual ~CommMapChangeListener(){} + virtual void removedGhost(const EntityKey& key, unsigned ghostId, int proc) = 0; virtual void removedKey(const EntityKey& key) = 0; }; @@ -128,17 +129,17 @@ PairIterEntityComm shared_comm_info_range(const EntityCommInfoVector& comm_info_ } inline -PairIterEntityComm ghost_info_range(const EntityCommInfoVector& commInfo, const Ghosting & ghosting) +PairIterEntityComm ghost_info_range(const EntityCommInfoVector& commInfo, unsigned ghostingOrdinal) { EntityCommInfoVector::const_iterator ghostBegin = commInfo.begin(); EntityCommInfoVector::const_iterator ghostEnd, end = commInfo.end(); - while(ghostBegin != end && ghostBegin->ghost_id != ghosting.ordinal()) { + while(ghostBegin != end && ghostBegin->ghost_id != ghostingOrdinal) { ++ghostBegin; } if (ghostBegin != end) { ghostEnd = ghostBegin+1; - while(ghostEnd != end && ghostEnd->ghost_id == ghosting.ordinal()) { + while(ghostEnd != end && ghostEnd->ghost_id == ghostingOrdinal) { ++ghostEnd; } return PairIterEntityComm( ghostBegin , ghostEnd ); diff --git a/packages/stk/stk_mesh/stk_mesh/base/EntityCommListInfo.hpp b/packages/stk/stk_mesh/stk_mesh/base/EntityCommListInfo.hpp index 6eda60b5cb9e..81eb78ac083c 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/EntityCommListInfo.hpp +++ b/packages/stk/stk_mesh/stk_mesh/base/EntityCommListInfo.hpp @@ -73,7 +73,7 @@ struct IsInvalid { bool operator()(const EntityCommListInfo& comm) const { - return comm.key == EntityKey(); + return comm.key == EntityKey() || comm.entity_comm == nullptr; } }; diff --git a/packages/stk/stk_mesh/stk_mesh/base/EntityProcMapping.hpp b/packages/stk/stk_mesh/stk_mesh/base/EntityProcMapping.hpp index ad5ffc58a71f..cd2bdde6405a 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/EntityProcMapping.hpp +++ b/packages/stk/stk_mesh/stk_mesh/base/EntityProcMapping.hpp @@ -79,11 +79,23 @@ bool is_valid(Entity entity) class EntityProcMapping { public: - EntityProcMapping(unsigned sizeOfEntityIndexSpace) + EntityProcMapping(unsigned sizeOfEntityIndexSpace = 1024) : entityOffsets(sizeOfEntityIndexSpace, -1), entitiesAndProcs() {} + void reset(unsigned sizeOfEntityIndexSpace) + { + for(int& n : entityOffsets) { + if (n != -1) { + n = -1; + } + } +// std::fill(entityOffsets.begin(), entityOffsets.end(), -1); + entityOffsets.resize(sizeOfEntityIndexSpace, -1); + entitiesAndProcs.clear(); + } + void addEntityProc(Entity entity, int proc) { int offset = entityOffsets[entity.local_offset()]; @@ -178,38 +190,35 @@ class EntityProcMapping { return 0; } - template - void fill_set(SetType& entityProcSet) + template + void visit_entity_procs(const Alg& alg) { - entityProcSet.clear(); for(const EntityAndProcs& entProcs : entitiesAndProcs) { if (is_valid(entProcs.entity) && entProcs.proc >= 0) { - entityProcSet.insert(EntityProc(entProcs.entity, entProcs.proc)); + alg(entProcs.entity, entProcs.proc); } else if (is_valid(entProcs.entity)) { for(int p : entProcs.procs) { - entityProcSet.insert(EntityProc(entProcs.entity, p)); + alg(entProcs.entity, p); } } } } + template + void fill_set(SetType& entityProcSet) + { + entityProcSet.clear(); + visit_entity_procs([&entityProcSet](Entity ent, int proc){entityProcSet.insert(EntityProc(ent,proc));}); + } + template void fill_vec(VecType& entityProcVec) { - entityProcVec.clear(); size_t lengthEstimate = static_cast(std::floor(1.2*entitiesAndProcs.size())); entityProcVec.reserve(lengthEstimate); - for(const EntityAndProcs& entProcs : entitiesAndProcs) { - if (is_valid(entProcs.entity) && entProcs.proc >= 0) { - entityProcVec.emplace_back(EntityProc(entProcs.entity, entProcs.proc)); - } - else if (is_valid(entProcs.entity)) { - for(int p : entProcs.procs) { - entityProcVec.emplace_back(EntityProc(entProcs.entity, p)); - } - } - } + entityProcVec.clear(); + visit_entity_procs([&entityProcVec](Entity ent, int proc){entityProcVec.push_back(EntityProc(ent,proc));}); } private: diff --git a/packages/stk/stk_mesh/stk_mesh/base/MeshBuilder.cpp b/packages/stk/stk_mesh/stk_mesh/base/MeshBuilder.cpp index c89e6489af1c..9fd5d03965b6 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/MeshBuilder.cpp +++ b/packages/stk/stk_mesh/stk_mesh/base/MeshBuilder.cpp @@ -47,7 +47,8 @@ MeshBuilder::MeshBuilder() m_fieldDataManager(nullptr), m_bucketCapacity(impl::BucketRepository::default_bucket_capacity), m_spatialDimension(0), - m_entityRankNames() + m_entityRankNames(), + m_upwardConnectivity(true) { } diff --git a/packages/stk/stk_mesh/stk_mesh/base/Types.hpp b/packages/stk/stk_mesh/stk_mesh/base/Types.hpp index e4f744e25ff9..9c00a95a1fd0 100644 --- a/packages/stk/stk_mesh/stk_mesh/base/Types.hpp +++ b/packages/stk/stk_mesh/stk_mesh/base/Types.hpp @@ -56,7 +56,6 @@ namespace stk { namespace mesh { class Part; } } namespace stk { namespace mesh { class Selector; } } namespace stk { namespace mesh { class Relation; } } namespace stk { namespace mesh { struct Entity; } } -namespace stk { namespace mesh { namespace impl { class EntityRepository; } } } namespace stk { namespace mesh { struct EntityKey; } } diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.cpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.cpp index bb5018e7cb90..aa638c8d99e0 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.cpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.cpp @@ -46,6 +46,9 @@ namespace mesh { namespace impl { AuraGhosting::AuraGhosting() +: m_entitySharing(), + m_sendAura(), + m_scratchSpace() { } @@ -55,12 +58,13 @@ AuraGhosting::~AuraGhosting() void AuraGhosting::generate_aura(BulkData& bulkData) { - EntityProcMapping entitySharing(bulkData.get_size_of_entity_index_space()); + m_entitySharing.reset(bulkData.get_size_of_entity_index_space()); std::vector ranks = {stk::topology::NODE_RANK, stk::topology::EDGE_RANK}; const MetaData& meta = bulkData.mesh_meta_data(); if (meta.side_rank() > stk::topology::EDGE_RANK) { ranks.push_back(meta.side_rank()); } + EntityProcMapping& entitySharing = m_entitySharing; std::vector sharingProcs; for(EntityRank rank : ranks) { impl::for_each_selected_entity_run_no_threads(bulkData, rank, meta.globally_shared_part(), @@ -73,10 +77,10 @@ void AuraGhosting::generate_aura(BulkData& bulkData) }); } - EntityProcMapping sendAuraEntityProcs(bulkData.get_size_of_entity_index_space()); - fill_send_aura_entities(bulkData, sendAuraEntityProcs, entitySharing); + m_sendAura.reset(bulkData.get_size_of_entity_index_space()); + fill_send_aura_entities(bulkData, m_sendAura, m_entitySharing); - change_ghosting(bulkData, sendAuraEntityProcs, entitySharing); + change_ghosting(bulkData, m_sendAura, m_entitySharing); } void AuraGhosting::remove_aura(BulkData& bulkData) @@ -90,7 +94,8 @@ void AuraGhosting::fill_send_aura_entities(BulkData& bulkData, EntityProcMapping& sendAuraEntityProcs, const EntityProcMapping& entitySharing) { - const EntityRank end_rank = static_cast(bulkData.mesh_meta_data().entity_rank_count()); + const EntityRank endRank = static_cast(bulkData.mesh_meta_data().entity_rank_count()); + const EntityRank maxRank = static_cast(endRank-1); // Iterate over all shared entities, ensure that upwardly related // entities to each shared entity will be ghosted to the sharing proc. @@ -98,24 +103,22 @@ void AuraGhosting::fill_send_aura_entities(BulkData& bulkData, std::vector sharingProcs; impl::for_each_selected_entity_run_no_threads(bulkData, stk::topology::NODE_RANK, shared, - [&sendAuraEntityProcs, &entitySharing, &sharingProcs, &end_rank] + [&sendAuraEntityProcs, &entitySharing, &sharingProcs, &endRank, &maxRank] (const BulkData& bulk, const MeshIndex& meshIndex) { const Bucket& bucket = *meshIndex.bucket; const unsigned bucketOrd = meshIndex.bucket_ordinal; - const EntityRank nextHigherRank = stk::topology::EDGE_RANK; bulk.comm_shared_procs(bucket[bucketOrd], sharingProcs); - for (const int sharingProc : sharingProcs) { - for (EntityRank higherRank = nextHigherRank; higherRank < end_rank; ++higherRank) { - const unsigned num_rels = bucket.num_connectivity(bucketOrd, higherRank); - const Entity* rels = bucket.begin(bucketOrd, higherRank); + static constexpr EntityRank nextHigherRank = stk::topology::EDGE_RANK; + for (EntityRank higherRank = nextHigherRank; higherRank < endRank; ++higherRank) { + const unsigned num_rels = bucket.num_connectivity(bucketOrd, higherRank); + const Entity* rels = bucket.begin(bucketOrd, higherRank); - for (unsigned r = 0; r < num_rels; ++r) { - stk::mesh::impl::insert_upward_relations(bulk, entitySharing, rels[r], stk::topology::NODE_RANK, sharingProc, sendAuraEntityProcs); - } + for (unsigned r = 0; r < num_rels; ++r) { + stk::mesh::impl::insert_upward_relations(bulk, entitySharing, rels[r], higherRank, maxRank, sharingProcs, sendAuraEntityProcs); } - } + } } ); // for_each_entity_run } @@ -124,42 +127,49 @@ void AuraGhosting::change_ghosting(BulkData& bulkData, EntityProcMapping& sendAuraEntityProcs, const EntityProcMapping& entitySharing) { - std::vector add_send; - sendAuraEntityProcs.fill_vec(add_send); + std::vector& sendAuraGhosts = m_scratchSpace; + sendAuraEntityProcs.fill_vec(sendAuraGhosts); //------------------------------------ // Add the specified entities and their closure to sendAuraEntityProcs - impl::StoreInEntityProcMapping siepm(bulkData, sendAuraEntityProcs); - EntityProcMapping epm(bulkData.get_size_of_entity_index_space()); - impl::OnlyGhostsEPM og(bulkData, epm, entitySharing); - for ( const EntityProc& entityProc : add_send ) { - og.proc = entityProc.second; - siepm.proc = entityProc.second; - impl::VisitClosureGeneral(bulkData,entityProc.first,siepm,og); + impl::StoreInEntityProcMapping storeEntity(bulkData, sendAuraEntityProcs); + impl::NotAlreadyShared entityBelongsInAura(bulkData, entitySharing); + for ( const EntityProc& entityProc : sendAuraGhosts ) { + entityBelongsInAura.proc = entityProc.second; + storeEntity.proc = entityProc.second; + const EntityRank entityRank = bulkData.entity_rank(entityProc.first); + if (entityRank > stk::topology::ELEM_RANK) { + VisitClosureGeneral(bulkData, entityProc.first, entityRank, storeEntity, entityBelongsInAura); + } + else { + VisitClosureBelowEntityNoRecurse(bulkData, entityProc.first, entityRank, storeEntity, entityBelongsInAura); + } } - sendAuraEntityProcs.fill_vec(add_send); - - // Synchronize the send and receive list. - // If the send list contains a not-owned entity - // inform the owner and receiver to add that entity - // to their ghost send and receive lists. - - std::vector ghostStatus(bulkData.get_size_of_entity_index_space(), false); + std::vector& nonOwnedSendAuraGhosts = m_scratchSpace; + nonOwnedSendAuraGhosts.clear(); + sendAuraEntityProcs.visit_entity_procs( + [&bulkData,&nonOwnedSendAuraGhosts](Entity ent, int p) + { + if (!bulkData.bucket(ent).owned()) { + nonOwnedSendAuraGhosts.emplace_back(ent,p); + } + }); - stk::mesh::impl::comm_sync_aura_send_recv(bulkData, add_send, - sendAuraEntityProcs, ghostStatus ); + impl::comm_sync_nonowned_sends(bulkData, nonOwnedSendAuraGhosts, sendAuraEntityProcs); //------------------------------------ - // Remove the ghost entities that will not remain. - // If the last reference to the receive ghost entity then delete it. + // Remove send-ghost entities from the comm-list that no longer need to be sent. OrdinalVector addParts; OrdinalVector removeParts(1, bulkData.m_ghost_parts[BulkData::AURA]->mesh_meta_data_ordinal()); OrdinalVector scratchOrdinalVec, scratchSpace; bool removed = false ; + std::vector removedSendGhosts; + const unsigned auraGhostingOrdinal = bulkData.aura_ghosting().ordinal(); + std::vector comm_ghost ; for ( EntityCommListInfoVector::reverse_iterator i = bulkData.m_entity_comm_list.rbegin() ; i != bulkData.m_entity_comm_list.rend() ; ++i) { @@ -174,44 +184,30 @@ void AuraGhosting::change_ghosting(BulkData& bulkData, } const bool is_owner = bulkData.parallel_owner_rank(entityComm.entity) == bulkData.parallel_rank() ; - const bool remove_recv = ( ! is_owner ) && - !ghostStatus[entityComm.entity.local_offset()] && bulkData.in_receive_ghost(bulkData.aura_ghosting(), entityComm.entity); + if ( is_owner ) { + // Is owner, potentially removing ghost-sends + // Have to make a copy - if(bulkData.is_valid(entityComm.entity)) - { - if ( is_owner ) { - // Is owner, potentially removing ghost-sends - // Have to make a copy - - const PairIterEntityComm ec = ghost_info_range(entityComm.entity_comm->comm_map, bulkData.aura_ghosting()); - comm_ghost.assign( ec.first , ec.second ); - - for ( ; ! comm_ghost.empty() ; comm_ghost.pop_back() ) { - const EntityCommInfo tmp = comm_ghost.back(); - - if (!sendAuraEntityProcs.find(entityComm.entity, tmp.proc) ) { - bulkData.entity_comm_map_erase(entityComm.key, tmp); - } - else { - sendAuraEntityProcs.eraseEntityProc(entityComm.entity, tmp.proc); - } - } - } - else if ( remove_recv ) { - bulkData.entity_comm_map_erase(entityComm.key, bulkData.aura_ghosting()); - bulkData.internal_change_entity_parts(entityComm.entity, addParts, removeParts, scratchOrdinalVec, scratchSpace); - } + const PairIterEntityComm ec = ghost_info_range(entityComm.entity_comm->comm_map, auraGhostingOrdinal); + comm_ghost.assign( ec.first , ec.second ); - if ( bulkData.internal_entity_comm_map(entityComm.entity).empty() ) { - removed = true ; - entityComm.key = EntityKey(); // No longer communicated - if ( remove_recv ) { - ThrowRequireMsg( bulkData.internal_destroy_entity_with_notification( entityComm.entity, remove_recv ), - "P[" << bulkData.parallel_rank() << "]: FAILED attempt to destroy entity: " - << bulkData.entity_key(entityComm.entity) ); + for ( ; ! comm_ghost.empty() ; comm_ghost.pop_back() ) { + const EntityCommInfo tmp = comm_ghost.back(); + + if (!sendAuraEntityProcs.find(entityComm.entity, tmp.proc) ) { + bulkData.entity_comm_map_erase(entityComm.key, tmp); + removedSendGhosts.push_back(EntityProc(entityComm.entity, tmp.proc)); + } + else { + sendAuraEntityProcs.eraseEntityProc(entityComm.entity, tmp.proc); } } } + + if ( bulkData.internal_entity_comm_map(entityComm.entity).empty() ) { + removed = true ; + entityComm.key = EntityKey(); // No longer communicated + } } // if an entry in the comm_list has the EntityKey() value, it is invalid, @@ -221,12 +217,22 @@ void AuraGhosting::change_ghosting(BulkData& bulkData, bulkData.delete_unneeded_entries_from_the_comm_list(); } + const std::vector>& allRemovedGhosts = bulkData.m_removedGhosts; + for(const std::pair& rmGhost : allRemovedGhosts) { + Entity rmEnt = bulkData.get_entity(rmGhost.first); + if (bulkData.is_valid(rmEnt) && + rmGhost.second.ghost_id == auraGhostingOrdinal && + bulkData.parallel_owner_rank(rmEnt) == bulkData.parallel_rank() && + !sendAuraEntityProcs.find(rmEnt, rmGhost.second.proc)) { + removedSendGhosts.push_back(EntityProc(rmEnt,rmGhost.second.proc)); + } + } EntityLess entityLess(bulkData); std::set finalSendGhosts(entityLess); sendAuraEntityProcs.fill_set(finalSendGhosts); const bool isFullRegen = true; - bulkData.ghost_entities_and_fields(bulkData.aura_ghosting(), finalSendGhosts, isFullRegen); + bulkData.ghost_entities_and_fields(bulkData.aura_ghosting(), finalSendGhosts, isFullRegen, removedSendGhosts); } }}} // end namepsace stk mesh impl diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.hpp index fe0a80368dce..532c03185cf5 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.hpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/AuraGhosting.hpp @@ -35,6 +35,10 @@ #ifndef stk_mesh_impl_AuraGhosting_hpp #define stk_mesh_impl_AuraGhosting_hpp +#include +#include +#include + namespace stk { namespace mesh { @@ -60,6 +64,10 @@ class AuraGhosting virtual void change_ghosting(BulkData& bulkData, EntityProcMapping& entityProcMapping, const EntityProcMapping& entitySharing); +private: + EntityProcMapping m_entitySharing; + EntityProcMapping m_sendAura; + std::vector m_scratchSpace; }; }}} // end namepsace stk mesh impl diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/DeletedEntityCache.cpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/DeletedEntityCache.cpp new file mode 100644 index 000000000000..1ea43bb90a61 --- /dev/null +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/DeletedEntityCache.cpp @@ -0,0 +1,49 @@ + +#include "MeshModification.hpp" +#include + + +namespace stk { +namespace mesh { +namespace impl { + +void DeletedEntityCache::mark_entity_as_deleted(Entity entity, bool is_ghost) +{ + if (is_ghost) + { + m_ghost_reuse_map[m_bulkData.entity_key(entity)] = entity.local_offset(); + } else + { + m_deleted_entities_current_modification_cycle.push_back(entity.local_offset()); + } +} + +Entity::entity_value_type DeletedEntityCache::get_entity_for_reuse() +{ + if (!m_deleted_entities.empty()) + { + size_t new_local_offset = m_deleted_entities.back(); + m_deleted_entities.pop_back(); + return new_local_offset; + } else + { + return Entity::InvalidEntity; + } +} + +void DeletedEntityCache::update_deleted_entities_container() +{ + m_deleted_entities.insert(m_deleted_entities.end(), m_deleted_entities_current_modification_cycle.begin(), + m_deleted_entities_current_modification_cycle.end()); + m_deleted_entities_current_modification_cycle.clear(); + + for (auto keyAndOffset : m_ghost_reuse_map) { + m_deleted_entities.push_back(keyAndOffset.second); + } + m_ghost_reuse_map.clear(); +} + + +} +} +} \ No newline at end of file diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/DeletedEntityCache.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/DeletedEntityCache.hpp new file mode 100644 index 000000000000..f516d7fe42f0 --- /dev/null +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/DeletedEntityCache.hpp @@ -0,0 +1,81 @@ +// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering +// Solutions of Sandia, LLC (NTESS). Under the terms of Contract +// DE-NA0003525 with NTESS, the U.S. Government retains certain rights +// in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// +// * Neither the name of NTESS nor the names of its contributors +// may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// + +#ifndef stk_mesh_impl_DeletedEntityCache_hpp +#define stk_mesh_impl_DeletedEntityCache_hpp + +#include // for MeshIndex, EntityRank, etc +#include +#include "stk_mesh/base/EntityKey.hpp" + +namespace stk { +namespace mesh { + +class BulkData; +typedef std::unordered_map GhostReuseMap; + + +namespace impl { + +class DeletedEntityCache +{ + public: + explicit DeletedEntityCache(BulkData& bulkData) : + m_bulkData(bulkData) + {} + + void mark_entity_as_deleted(Entity entity, bool is_ghost); + + const std::vector& get_deleted_entities_current_mod_cycle() const { return m_deleted_entities_current_modification_cycle; } + + GhostReuseMap& get_ghost_reuse_map() { return m_ghost_reuse_map; } + + const GhostReuseMap& get_ghost_reuse_map() const { return m_ghost_reuse_map; } + + Entity::entity_value_type get_entity_for_reuse(); + + void update_deleted_entities_container(); + + private: + BulkData& m_bulkData; + std::vector m_deleted_entities_current_modification_cycle; + std::vector m_deleted_entities; + GhostReuseMap m_ghost_reuse_map; +}; + +} +} +} + +#endif \ No newline at end of file diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityRepository.cpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityKeyMapping.cpp similarity index 92% rename from packages/stk/stk_mesh/stk_mesh/baseImpl/EntityRepository.cpp rename to packages/stk/stk_mesh/stk_mesh/baseImpl/EntityKeyMapping.cpp index 55ee17693921..067192738055 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityRepository.cpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityKeyMapping.cpp @@ -32,7 +32,7 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -#include +#include #include // for NULL #include // for operator<<, basic_ostream, etc #include @@ -74,7 +74,7 @@ struct match_EntityKey { const EntityKey& m_key; }; -EntityRepository::EntityRepository() +EntityKeyMapping::EntityKeyMapping() : m_entities(stk::topology::NUM_RANKS), m_create_cache(stk::topology::NUM_RANKS), m_update_cache(stk::topology::NUM_RANKS), @@ -84,11 +84,11 @@ EntityRepository::EntityRepository() { } -EntityRepository::~EntityRepository() +EntityKeyMapping::~EntityKeyMapping() { } -void EntityRepository::clear_all_cache() +void EntityKeyMapping::clear_all_cache() { EntityRank nRanks = static_cast(m_create_cache.size()); for(EntityRank rank=stk::topology::BEGIN_RANK; rank& destroy = m_destroy_cache[rank]; @@ -135,7 +135,7 @@ void EntityRepository::clear_destroyed_entity_cache(EntityRank rank) const } } -void EntityRepository::clear_updated_entity_cache(EntityRank rank) const +void EntityKeyMapping::clear_updated_entity_cache(EntityRank rank) const { if (!m_update_cache[rank].empty()) { std::vector >& update = m_update_cache[rank]; @@ -154,7 +154,7 @@ void EntityRepository::clear_updated_entity_cache(EntityRank rank) const } } -void EntityRepository::clear_created_entity_cache(EntityRank rank) const +void EntityKeyMapping::clear_created_entity_cache(EntityRank rank) const { if (!m_create_cache[rank].empty()) { std::sort(m_create_cache[rank].begin(), m_create_cache[rank].end()); @@ -167,7 +167,7 @@ void EntityRepository::clear_created_entity_cache(EntityRank rank) const } } -void EntityRepository::clear_cache(EntityRank rank) const +void EntityKeyMapping::clear_cache(EntityRank rank) const { clear_created_entity_cache(rank); @@ -177,7 +177,7 @@ void EntityRepository::clear_cache(EntityRank rank) const } std::pair -EntityRepository::add_to_cache(const EntityKey& key) +EntityKeyMapping::add_to_cache(const EntityKey& key) { bool inserted_new_entity = false; EntityRank rank = key.rank(); @@ -208,7 +208,7 @@ EntityRepository::add_to_cache(const EntityKey& key) return std::make_pair(iter, inserted_new_entity); } -stk::mesh::entity_iterator EntityRepository::get_from_cache(const EntityKey& key) const +stk::mesh::entity_iterator EntityKeyMapping::get_from_cache(const EntityKey& key) const { if (!m_create_cache[key.rank()].empty()) { EntityKeyEntityVector& cache = m_create_cache[key.rank()]; @@ -222,7 +222,7 @@ stk::mesh::entity_iterator EntityRepository::get_from_cache(const EntityKey& key } std::pair -EntityRepository::internal_create_entity( const EntityKey & key) +EntityKeyMapping::internal_create_entity( const EntityKey & key) { if (key.rank() > entity_rank_count()) { m_entities.resize(key.rank()); @@ -242,7 +242,7 @@ EntityRepository::internal_create_entity( const EntityKey & key) return add_to_cache(key); } -Entity EntityRepository::get_entity(const EntityKey &key) const +Entity EntityKeyMapping::get_entity(const EntityKey &key) const { EntityRank rank = key.rank(); if (!m_destroy_cache[rank].empty()) { @@ -279,7 +279,7 @@ Entity EntityRepository::get_entity(const EntityKey &key) const return (iter != entities.end() && (iter->first==key)) ? iter->second : Entity() ; } -void EntityRepository::update_entity_key(EntityKey new_key, EntityKey old_key, Entity entity) +void EntityKeyMapping::update_entity_key(EntityKey new_key, EntityKey old_key, Entity entity) { EntityRank rank = new_key.rank(); clear_created_entity_cache(rank); @@ -292,7 +292,7 @@ void EntityRepository::update_entity_key(EntityKey new_key, EntityKey old_key, E m_update_cache[rank].emplace_back(old_key, new_key); } -void EntityRepository::destroy_entity(EntityKey key, Entity entity) +void EntityKeyMapping::destroy_entity(EntityKey key, Entity entity) { EntityRank rank = key.rank(); clear_created_entity_cache(rank); diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityRepository.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityKeyMapping.hpp similarity index 93% rename from packages/stk/stk_mesh/stk_mesh/baseImpl/EntityRepository.hpp rename to packages/stk/stk_mesh/stk_mesh/baseImpl/EntityKeyMapping.hpp index 6c840513e010..791f377f8b18 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityRepository.hpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/EntityKeyMapping.hpp @@ -32,8 +32,8 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -#ifndef stk_mesh_baseImpl_EntityRepository_hpp -#define stk_mesh_baseImpl_EntityRepository_hpp +#ifndef stk_mesh_baseImpl_EntityKeyMapping_hpp +#define stk_mesh_baseImpl_EntityKeyMapping_hpp #include // for size_t #include // for map, map<>::value_compare @@ -47,7 +47,7 @@ namespace stk { namespace mesh { namespace impl { -class EntityRepository { +class EntityKeyMapping { public: @@ -56,9 +56,9 @@ class EntityRepository { typedef EntityKeyEntityVector::const_iterator const_iterator; typedef EntityKeyEntityVector::iterator iterator; - EntityRepository(); + EntityKeyMapping(); - ~EntityRepository(); + ~EntityKeyMapping(); Entity get_entity( const EntityKey &key ) const; @@ -112,8 +112,8 @@ class EntityRepository { mutable unsigned m_maxUpdateCacheSize; //disable copy constructor and assignment operator - EntityRepository(const EntityRepository &); - EntityRepository & operator =(const EntityRepository &); + EntityKeyMapping(const EntityKeyMapping &); + EntityKeyMapping & operator =(const EntityKeyMapping &); }; } // namespace impl @@ -121,5 +121,5 @@ class EntityRepository { } // namespace mesh } // namespace stk -#endif // stk_mesh_baseImpl_EntityRepository_hpp +#endif // stk_mesh_baseImpl_EntityKeyMapping_hpp diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.cpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.cpp index 3b5dc54a2c99..d67294fc3e5c 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.cpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.cpp @@ -117,6 +117,15 @@ void find_entities_with_larger_ids_these_nodes_have_in_common_and_locally_owned( } } +const EntityCommListInfo& find_entity(const BulkData& mesh, + const EntityCommListInfoVector& entities, + const EntityKey& key) +{ + EntityCommListInfoVector::const_iterator lb_itr = std::lower_bound(entities.begin(), entities.end(), key); + ThrowAssertMsg(lb_itr != entities.end() && lb_itr->key == key, + "proc " << mesh.parallel_rank() << " Cannot find entity-key " << key << " in comm-list" ); + return *lb_itr; +} bool do_these_nodes_have_any_shell_elements_in_common(BulkData& mesh, unsigned numNodes, const Entity* nodes) { @@ -1264,43 +1273,61 @@ void comm_sync_aura_send_recv( } } -void insert_upward_relations(const BulkData& bulk_data, Entity rel_entity, - const EntityRank rank_of_orig_entity, - const int share_proc, - std::vector& send) +void comm_sync_nonowned_sends( + const BulkData & mesh , + std::vector & nonOwnedSendGhosts, + EntityProcMapping& entityProcMapping) { - EntityRank rel_entity_rank = bulk_data.entity_rank(rel_entity); - ThrowAssert(rel_entity_rank > rank_of_orig_entity); + const int parallel_size = mesh.parallel_size(); + const int parallel_rank = mesh.parallel_rank(); + stk::CommSparse commSparse( mesh.parallel() ); - // If related entity is higher rank, I own it, and it is not - // already shared by proc, ghost it to the sharing processor. - if ( bulk_data.bucket(rel_entity).owned() && ! bulk_data.in_shared(rel_entity, share_proc) ) { + for (const EntityProc& ep : nonOwnedSendGhosts) { + const int owner = mesh.parallel_owner_rank(ep.first); + if ( owner != mesh.parallel_rank() ) { + commSparse.send_buffer( owner ).skip(1).skip(1); + } + } - send.emplace_back(rel_entity,share_proc); + commSparse.allocate_buffers(); - // There may be even higher-ranking entities that need to be ghosted, so we must recurse - const EntityRank end_rank = static_cast(bulk_data.mesh_meta_data().entity_rank_count()); - for (EntityRank irank = static_cast(rel_entity_rank + 1); irank < end_rank; ++irank) - { - const int num_rels = bulk_data.num_connectivity(rel_entity, irank); - Entity const* rels = bulk_data.begin(rel_entity, irank); + for (const EntityProc& ep : nonOwnedSendGhosts) { + const int owner = mesh.parallel_owner_rank(ep.first); + if ( owner != parallel_rank ) { + commSparse.send_buffer( owner ).pack(mesh.entity_key(ep.first)).pack(ep.second); + entityProcMapping.eraseEntityProc(ep.first, ep.second); + } + } - for (int r = 0; r < num_rels; ++r) - { - Entity const rel_of_rel_entity = rels[r]; - if (bulk_data.is_valid(rel_of_rel_entity)) { - insert_upward_relations(bulk_data, rel_of_rel_entity, rel_entity_rank, share_proc, send); - } - } + commSparse.communicate(); + + for ( int p = 0 ; p < parallel_size ; ++p ) { + CommBuffer & buf = commSparse.recv_buffer(p); + while ( buf.remaining() ) { + + EntityKey entity_key; + int proc = 0; + + buf.unpack(entity_key).unpack(proc); + + Entity const e = mesh.get_entity( entity_key ); + + ThrowAssert(parallel_rank != proc); + ThrowAssert(mesh.is_valid(e)); + + //Receiving a ghosting need for an entity I own, add it. + entityProcMapping.addEntityProc(e, proc); } } } -EntityRank get_highest_upward_connected_rank(const BulkData& mesh, Entity entity) +EntityRank get_highest_upward_connected_rank(const Bucket& bucket, + unsigned bucketOrdinal, + EntityRank entityRank, + EntityRank maxRank) { - const EntityRank entityRank = mesh.entity_rank(entity); - EntityRank highestRank = static_cast(mesh.mesh_meta_data().entity_rank_count()-1); - while(highestRank > entityRank && mesh.num_connectivity(entity, highestRank) == 0) { + EntityRank highestRank = maxRank; + while(highestRank > entityRank && bucket.num_connectivity(bucketOrdinal, highestRank) == 0) { highestRank = static_cast(highestRank-1); } return highestRank; @@ -1308,29 +1335,39 @@ EntityRank get_highest_upward_connected_rank(const BulkData& mesh, Entity entity void insert_upward_relations(const BulkData& bulk_data, const EntityProcMapping& entitySharing, - Entity rel_entity, - const EntityRank rank_of_orig_entity, - const int share_proc, + const Entity entity, + const EntityRank entityRank, + const EntityRank maxRank, + const std::vector& sharingProcs, EntityProcMapping& send) { // If related entity is higher rank, I own it, and it is not // already shared by proc, ghost it to the sharing processor. - const MeshIndex& idx = bulk_data.mesh_index(rel_entity); + const MeshIndex& idx = bulk_data.mesh_index(entity); const Bucket& bucket = *idx.bucket; - if ( bucket.owned() && !entitySharing.find(rel_entity, share_proc) ) { - - send.addEntityProc(rel_entity,share_proc); - + if (bucket.owned()) { const unsigned bucketOrd = idx.bucket_ordinal; - const EntityRank upwardRank = get_highest_upward_connected_rank(bulk_data, rel_entity); - const int numRels = bucket.num_connectivity(bucketOrd, upwardRank); - Entity const* rels = bucket.begin(bucketOrd, upwardRank); - - for (int r = 0; r < numRels; ++r) { - Entity const upwardEntity = rels[r]; - if (bulk_data.is_valid(upwardEntity) && bulk_data.bucket(upwardEntity).owned()) { - if (!entitySharing.find(upwardEntity, share_proc)) { - send.addEntityProc(upwardEntity, share_proc); + const EntityRank upwardRank = get_highest_upward_connected_rank(bucket, bucketOrd, entityRank, maxRank); + + if (upwardRank > entityRank) { + const int numRels = bucket.num_connectivity(bucketOrd, upwardRank); + const Entity* rels = bucket.begin(bucketOrd, upwardRank); + + for (int r = 0; r < numRels; ++r) { + Entity const upwardEntity = rels[r]; + if (bulk_data.is_valid(upwardEntity) && bulk_data.bucket(upwardEntity).owned()) { + for(int sharingProc : sharingProcs) { + if (upwardRank >= stk::topology::ELEM_RANK || !entitySharing.find(upwardEntity, sharingProc)) { + send.addEntityProc(upwardEntity, sharingProc); + } + } + } + } + } + else { + for(int sharingProc : sharingProcs) { + if (entityRank >= stk::topology::ELEM_RANK || !entitySharing.find(entity, sharingProc)) { + send.addEntityProc(entity,sharingProc); } } } @@ -1604,7 +1641,7 @@ bool is_good_rank_and_id(const MetaData& meta, EntityId get_global_max_id_in_use(const BulkData& mesh, EntityRank rank, - const std::list& deletedEntitiesCurModCycle) + const std::vector& deletedEntitiesCurModCycle) { EntityId localMax = stk::mesh::get_max_id_on_local_proc(mesh, rank); diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.hpp index e493ac58311a..b1a709fd430f 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.hpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshImplUtils.hpp @@ -89,6 +89,10 @@ void find_entities_these_nodes_have_in_common_and(const BulkData& mesh, EntityRa } } +const EntityCommListInfo& find_entity(const BulkData& mesh, + const EntityCommListInfoVector& entities, + const EntityKey& key); + bool do_these_nodes_have_any_shell_elements_in_common(BulkData& mesh, unsigned numNodes, const Entity* nodes); void find_locally_owned_elements_these_nodes_have_in_common(const BulkData& mesh, unsigned numNodes, const Entity* nodes, std::vector& elems); @@ -244,16 +248,17 @@ void comm_sync_aura_send_recv( EntityProcMapping& entityProcMapping, std::vector& ghostStatus ); -void insert_upward_relations(const BulkData& bulk_data, Entity rel_entity, - const EntityRank rank_of_orig_entity, - const int share_proc, - std::vector& send); +void comm_sync_nonowned_sends( + const BulkData & mesh , + std::vector & nonOwnedSendGhosts, + EntityProcMapping& entityProcMapping); void insert_upward_relations(const BulkData& bulk_data, const EntityProcMapping& entitySharing, - Entity rel_entity, - const EntityRank rank_of_orig_entity, - const int share_proc, + const Entity entity, + const EntityRank entityRank, + const EntityRank maxRank, + const std::vector& share_proc, EntityProcMapping& send); void move_unowned_entities_for_owner_to_ghost( @@ -318,7 +323,7 @@ bool is_good_rank_and_id(const MetaData& meta, EntityId get_global_max_id_in_use(const BulkData& mesh, EntityRank rank, - const std::list& deletedEntitiesCurModCycle); + const std::vector& deletedEntitiesCurModCycle); void check_declare_element_side_inputs(const BulkData & mesh, const Entity elem, diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.cpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.cpp index b7bb82c7bb7a..1938862ea4b9 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.cpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.cpp @@ -1,9 +1,11 @@ #include "MeshModification.hpp" +#include #include #include #include +#include #include -#include +#include namespace stk { namespace mesh { @@ -16,7 +18,7 @@ bool MeshModification::modification_begin(const std::string description) if (this->synchronized_count() == 0) { m_bulkData.mesh_meta_data().set_mesh_on_fields(&m_bulkData); - m_bulkData.m_entity_repo->update_num_ranks(m_bulkData.mesh_meta_data().entity_rank_count()); + m_bulkData.m_entityKeyMapping->update_num_ranks(m_bulkData.mesh_meta_data().entity_rank_count()); const unsigned numRanks = m_bulkData.mesh_meta_data().entity_rank_count(); if (numRanks > m_bulkData.m_selector_to_buckets_maps.size()) { m_bulkData.m_selector_to_buckets_maps.resize(numRanks); @@ -36,6 +38,7 @@ bool MeshModification::modification_begin(const std::string description) else { this->reset_undeleted_entity_states_to_unchanged(); + m_bulkData.m_removedGhosts.clear(); } this->set_sync_state_modifiable(); @@ -55,82 +58,77 @@ bool MeshModification::modification_begin(const std::string description) bool MeshModification::modification_end(modification_optimization opt) { - return this->internal_modification_end( opt ); -} - -bool MeshModification::resolve_node_sharing() -{ - return this->internal_resolve_node_sharing( MOD_END_SORT ); -} - -bool MeshModification::modification_end_after_node_sharing_resolution() -{ - return this->internal_modification_end_after_node_sharing_resolution( MOD_END_SORT ); -} - -bool MeshModification::internal_modification_end(modification_optimization opt) -{ - if(this->in_synchronized_state()) - { - return false; - } - - ThrowAssertMsg(impl::check_for_connected_nodes(m_bulkData)==0, "BulkData::modification_end ERROR, all entities with rank higher than node are required to have connected nodes."); - - ThrowAssertMsg(m_bulkData.add_fmwk_data() || impl::check_no_shared_elements_or_higher(m_bulkData)==0, "BulkData::modification_end ERROR, Sharing of entities with rank ELEMENT_RANK or higher is not allowed."); - - m_bulkData.m_entity_repo->clear_all_cache(); + if(this->in_synchronized_state()) + { + return false; + } - if(m_bulkData.parallel_size() > 1) - { - // Resolve modification or deletion of shared entities - // which can cause deletion of ghost entities. - stk::mesh::EntityVector entitiesNoLongerShared; - internal_resolve_shared_modify_delete(entitiesNoLongerShared); + ThrowAssertMsg(impl::check_for_connected_nodes(m_bulkData)==0, "BulkData::modification_end ERROR, all entities with rank higher than node are required to have connected nodes."); - // Resolve modification or deletion of ghost entities - // by destroying ghost entities that have been touched. - m_bulkData.internal_resolve_ghosted_modify_delete(entitiesNoLongerShared); - m_bulkData.update_comm_list_based_on_changes_in_comm_map(); + ThrowAssertMsg(m_bulkData.add_fmwk_data() || impl::check_no_shared_elements_or_higher(m_bulkData)==0, "BulkData::modification_end ERROR, Sharing of entities with rank ELEMENT_RANK or higher is not allowed."); - // Resolve creation of entities: discover sharing and set unique ownership. - m_bulkData.internal_resolve_parallel_create(); + m_bulkData.m_entityKeyMapping->clear_all_cache(); - // Manoj: consider adding check_sharing_comm_maps here which is currently - // in BulkDataTester in UnitTestModificationEnd.cpp + if(m_bulkData.parallel_size() > 1) + { + // Resolve modification or deletion of shared entities + // which can cause deletion of ghost entities. + stk::mesh::EntityVector entitiesNoLongerShared; + internal_resolve_shared_modify_delete(entitiesNoLongerShared); + + // Resolve modification or deletion of ghost entities + // by destroying ghost entities that have been touched. + internal_resolve_ghosted_modify_delete(entitiesNoLongerShared); + m_bulkData.update_comm_list_based_on_changes_in_comm_map(); + + // Resolve creation of entities: discover sharing and set unique ownership. + m_bulkData.internal_resolve_parallel_create(); + + // Manoj: consider adding check_sharing_comm_maps here which is currently + // in BulkDataTester in UnitTestModificationEnd.cpp + + // Resolve part membership for shared entities. + // This occurs after resolving creation so created and shared + // entities are resolved along with previously existing shared entities. + m_bulkData.internal_resolve_shared_membership(entitiesNoLongerShared); + + // Regenerate the ghosting aura around all shared mesh entities. + if(m_bulkData.is_automatic_aura_on()) + { + m_bulkData.internal_regenerate_aura(); + } + else if (m_bulkData.m_turningOffAutoAura) { + m_bulkData.internal_remove_aura(); + } - // Resolve part membership for shared entities. - // This occurs after resolving creation so created and shared - // entities are resolved along with previously existing shared entities. - m_bulkData.internal_resolve_shared_membership(entitiesNoLongerShared); + m_bulkData.internal_resolve_send_ghost_membership(); - // Regenerate the ghosting aura around all shared mesh entities. - if(m_bulkData.is_automatic_aura_on()) - { - m_bulkData.internal_regenerate_aura(); - } - else if (m_bulkData.m_turningOffAutoAura) { - m_bulkData.internal_remove_aura(); - } + m_bulkData.m_modSummary.write_summary(synchronized_count()); + m_bulkData.check_mesh_consistency(); + } + else + { + m_bulkData.m_modSummary.write_summary(synchronized_count()); + if(!m_bulkData.add_fmwk_data()) + { + std::vector shared_modified; + m_bulkData.internal_update_sharing_comm_map_and_fill_list_modified_shared_entities(shared_modified); + } + } - m_bulkData.internal_resolve_send_ghost_membership(); + m_bulkData.internal_finish_modification_end(opt); - m_bulkData.m_modSummary.write_summary(synchronized_count()); - m_bulkData.check_mesh_consistency(); - } - else - { - m_bulkData.m_modSummary.write_summary(synchronized_count()); - if(!m_bulkData.add_fmwk_data()) - { - std::vector shared_modified; - m_bulkData.internal_update_sharing_comm_map_and_fill_list_modified_shared_entities(shared_modified); - } - } + return true; +} - m_bulkData.internal_finish_modification_end(opt); +bool MeshModification::resolve_node_sharing() +{ + return this->internal_resolve_node_sharing( MOD_END_SORT ); +} - return true; +bool MeshModification::modification_end_after_node_sharing_resolution() +{ + return this->internal_modification_end_after_node_sharing_resolution( MOD_END_SORT ); } bool MeshModification::internal_resolve_node_sharing(modification_optimization opt) @@ -217,6 +215,107 @@ void MeshModification::change_entity_owner( const EntityProcVec & arg_change) m_bulkData.internal_modification_end_for_change_entity_owner(mod_optimization); } +bool MeshModification::pack_entity_modification( const bool packShared , stk::CommSparse & comm ) +{ + bool flag = false; + bool packGhosted = packShared == false; + + const EntityCommListInfoVector & entityCommList = m_bulkData.internal_comm_list(); + + for ( EntityCommListInfoVector::const_iterator + i = entityCommList.begin() ; i != entityCommList.end() ; ++i ) { + if (i->entity_comm != nullptr) { + Entity entity = i->entity; + EntityState status = m_bulkData.is_valid(entity) ? m_bulkData.state(entity) : Deleted; + + if ( status == Modified || status == Deleted ) { + int owned_closure_int = m_bulkData.owned_closure(entity) ? 1 : 0; + + for ( PairIterEntityComm ec(i->entity_comm->comm_map); ! ec.empty() ; ++ec ) + { + if ( ( packGhosted && ec->ghost_id > BulkData::SHARED ) || ( packShared && ec->ghost_id == BulkData::SHARED ) ) + { + comm.send_buffer( ec->proc ) + .pack( i->key ) + .pack( status ) + .pack(owned_closure_int); + + const bool promotingGhostToShared = + packGhosted && owned_closure_int==1 && !m_bulkData.bucket(entity).owned(); + if (promotingGhostToShared) { + comm.send_buffer(comm.parallel_rank()) + .pack( i->key ) + .pack( status ) + .pack(owned_closure_int); + } + + flag = true ; + } + } + } + } + } + + return flag ; +} + +void MeshModification::communicate_entity_modification( const bool shared , std::vector & data ) +{ + stk::CommSparse comm( m_bulkData.parallel() ); + const int p_size = comm.parallel_size(); + + // Sizing send buffers: + pack_entity_modification(shared , comm); + + comm.allocate_buffers(); + + bool needToSend = false; + for (int procNumber=0; procNumber < p_size; ++procNumber) + { + if (comm.send_buffer(procNumber).capacity() > 0) + { + needToSend = true; + break; + } + } + + // Packing send buffers: + if (needToSend) { + pack_entity_modification(shared , comm); + } + + comm.communicate(); + + const EntityCommListInfoVector & entityCommList = m_bulkData.internal_comm_list(); + for ( int procNumber = 0 ; procNumber < p_size ; ++procNumber ) { + CommBuffer & buf = comm.recv_buffer( procNumber ); + EntityKey key; + EntityState state; + int remote_owned_closure_int; + bool remote_owned_closure; + + while ( buf.remaining() ) { + + buf.unpack( key ) + .unpack( state ) + .unpack( remote_owned_closure_int); + remote_owned_closure = ((remote_owned_closure_int==1)?true:false); + + // search through entity_comm, should only receive info on entities + // that are communicated. + EntityCommListInfo info = find_entity(m_bulkData, entityCommList, key); + int remoteProc = procNumber; + if (!shared && remoteProc == m_bulkData.parallel_rank()) { + remoteProc = m_bulkData.parallel_owner_rank(info.entity); + } + EntityParallelState parallel_state = {remoteProc, state, info, remote_owned_closure}; + data.push_back( parallel_state ); + } + } + + std::sort( data.begin() , data.end() ); +} + // Resolve modifications for shared entities: // If not locally destroyed and remotely modified // then set to locally modified. @@ -233,17 +332,17 @@ void MeshModification::internal_resolve_shared_modify_delete(stk::mesh::EntityVe ThrowRequireMsg(m_bulkData.parallel_size() > 1, "Do not call this in serial"); stk::mesh::EntityProcVec entitiesToRemoveFromSharing; - m_bulkData.delete_shared_entities_which_are_no_longer_in_owned_closure(entitiesToRemoveFromSharing); + delete_shared_entities_which_are_no_longer_in_owned_closure(entitiesToRemoveFromSharing); - std::vector remotely_modified_shared_entities; + std::vector remotely_modified_shared_entities; // Communicate entity modification state for shared entities // the resulting vector is sorted by entity and process. const bool communicate_shared = true; - m_bulkData.communicate_entity_modification(communicate_shared, remotely_modified_shared_entities); + communicate_entity_modification(communicate_shared, remotely_modified_shared_entities); // We iterate backwards to ensure that we hit the higher-ranking entities first. - for(std::vector::reverse_iterator + for(std::vector::reverse_iterator i = remotely_modified_shared_entities.rbegin(); i != remotely_modified_shared_entities.rend();) { @@ -298,16 +397,15 @@ void MeshModification::internal_resolve_shared_modify_delete(stk::mesh::EntityVe { const bool am_i_old_local_owner = m_bulkData.parallel_rank() == owner; - if(remote_owner_destroyed) - { - m_bulkData.internal_establish_new_owner(entity); + if(remote_owner_destroyed) { + internal_establish_new_owner(entity); } const bool am_i_new_local_owner = m_bulkData.parallel_rank() == m_bulkData.parallel_owner_rank(entity); const bool did_i_just_become_owner = (!am_i_old_local_owner && am_i_new_local_owner ); const bool is_entity_shared = !m_bulkData.internal_entity_comm_map_shared(key).empty(); - m_bulkData.internal_update_parts_for_shared_entity(entity, is_entity_shared, did_i_just_become_owner); + internal_update_parts_for_shared_entity(entity, is_entity_shared, did_i_just_become_owner); } } // remote mod loop @@ -324,6 +422,253 @@ void MeshModification::internal_resolve_shared_modify_delete(stk::mesh::EntityVe m_bulkData.remove_entities_from_sharing(entitiesToRemoveFromSharing, entitiesNoLongerShared); } +void MeshModification::internal_establish_new_owner(stk::mesh::Entity entity) +{ + const int new_owner = m_bulkData.determine_new_owner(entity); + + m_bulkData.internal_set_owner(entity, new_owner); +} + +void MeshModification::internal_update_parts_for_shared_entity(stk::mesh::Entity entity, const bool is_entity_shared, const bool did_i_just_become_owner) +{ + OrdinalVector parts_to_add_entity_to , parts_to_remove_entity_from, scratchOrdinalVec, scratchSpace; + + if ( !is_entity_shared ) { + parts_to_remove_entity_from.push_back(m_bulkData.mesh_meta_data().globally_shared_part().mesh_meta_data_ordinal()); + } + + if ( did_i_just_become_owner ) { + parts_to_add_entity_to.push_back(m_bulkData.mesh_meta_data().locally_owned_part().mesh_meta_data_ordinal()); + } + + if ( ! parts_to_add_entity_to.empty() || ! parts_to_remove_entity_from.empty() ) { + m_bulkData.internal_change_entity_parts( entity , parts_to_add_entity_to , parts_to_remove_entity_from, scratchOrdinalVec, scratchSpace ); + } +} + +void MeshModification::destroy_dependent_ghosts( Entity entity, EntityProcVec& entitiesToRemoveFromSharing ) +{ + EntityRank entity_rank = m_bulkData.entity_rank(entity); + const EntityRank end_rank = static_cast(m_bulkData.mesh_meta_data().entity_rank_count()); + for (EntityRank irank = static_cast(end_rank - 1); irank > entity_rank; --irank) + { + int num_rels = m_bulkData.num_connectivity(entity, irank); + const Entity* rels = m_bulkData.begin(entity, irank); + + for (int r = num_rels - 1; r >= 0; --r) + { + Entity e = rels[r]; + + bool upwardRelationOfEntityIsInClosure = m_bulkData.owned_closure(e); + ThrowRequireMsg( !upwardRelationOfEntityIsInClosure, m_bulkData.entity_rank(e) << " with id " << m_bulkData.identifier(e) << " should not be in closure." ); + + // Recursion + if (m_bulkData.is_valid(e) && m_bulkData.bucket(e).in_aura()) + { + destroy_dependent_ghosts( e, entitiesToRemoveFromSharing ); + } + } + } + + const bool successfully_destroyed_entity = m_bulkData.destroy_entity(entity); + if (!successfully_destroyed_entity) + { + std::vector sharing_procs; + m_bulkData.comm_shared_procs(m_bulkData.entity_key(entity), sharing_procs); + for(int p : sharing_procs) { + entitiesToRemoveFromSharing.emplace_back(entity, p); + } + } +} + +// Entities with sharing information that are not in the owned closure +// have been modified such that they are no longer shared. +// These may no longer be needed or may become ghost entities. +// There is not enough information so assume they are to be deleted +// and let these entities be re-ghosted if they are needed. + +// Open question: Should an owned and shared entity that does not +// have an upward relation to an owned entity be destroyed so that +// ownership transfers to another process? + +void MeshModification::delete_shared_entities_which_are_no_longer_in_owned_closure(EntityProcVec& entitiesToRemoveFromSharing) +{ + for ( EntityCommListInfoVector::const_reverse_iterator + i = m_bulkData.internal_comm_list().rbegin() ; + i != m_bulkData.internal_comm_list().rend() ; ++i) + { + Entity entity = i->entity; + if (m_bulkData.is_valid(entity) && !m_bulkData.owned_closure(entity)) { + if ( m_bulkData.in_shared(entity) ) + { + destroy_dependent_ghosts( entity, entitiesToRemoveFromSharing ); + } + } + } +} + +//---------------------------------------------------------------------- +// Resolve modifications for ghosted entities: +// If a ghosted entity is modified or destroyed on the owning +// process then the ghosted entity must be destroyed. +// +// Post condition: +// Ghosted entities of modified or deleted entities are destroyed. +// Ghosted communication lists are cleared to reflect all deletions. + +void MeshModification::internal_resolve_ghosted_modify_delete(const stk::mesh::EntityVector& entitiesNoLongerShared) +{ + ThrowRequireMsg(m_bulkData.parallel_size() > 1, "Do not call this in serial"); + // Resolve modifications for ghosted entities: + + std::vector remotely_modified_ghosted_entities ; + internal_resolve_formerly_shared_entities(entitiesNoLongerShared); + + // Communicate entity modification state for ghost entities + const bool communicate_shared = false ; + communicate_entity_modification( communicate_shared , remotely_modified_ghosted_entities ); + + const size_t ghosting_count = m_bulkData.m_ghosting.size(); + const size_t ghosting_count_minus_shared = ghosting_count - 1; + + std::vector promotingToShared; + + // We iterate backwards over remote_mod to ensure that we hit the + // higher-ranking entities first. This is important because higher-ranking + // entities like element must be deleted before the nodes they have are + // deleted. + for ( std::vector::reverse_iterator + i = remotely_modified_ghosted_entities.rbegin(); i != remotely_modified_ghosted_entities.rend() ; ++i ) + { + Entity entity = i->comm_info.entity; + const EntityKey key = i->comm_info.key; + const int remote_proc = i->from_proc; + const bool local_owner = m_bulkData.parallel_owner_rank(entity) == m_bulkData.parallel_rank() ; + const bool remotely_destroyed = Deleted == i->state ; + const bool remote_proc_is_owner = remote_proc == m_bulkData.parallel_owner_rank(entity); + const bool isAlreadyDestroyed = !m_bulkData.is_valid(entity); + + if ( local_owner ) { // Sending to 'remote_proc' for ghosting + + if ( remotely_destroyed ) { + + // remove from ghost-send list + + for ( size_t j = ghosting_count_minus_shared ; j>=1 ; --j) { + m_bulkData.entity_comm_map_erase( key, EntityCommInfo( j , remote_proc ) ); + } + } + else { + if (!m_bulkData.in_ghost(m_bulkData.aura_ghosting(), entity) && m_bulkData.state(entity)==Unchanged) { + m_bulkData.set_state(entity, Modified); + } + + const bool shouldPromoteToShared = !isAlreadyDestroyed && i->remote_owned_closure==1 && key.rank() < stk::topology::ELEM_RANK; + if (shouldPromoteToShared) { + m_bulkData.entity_comm_map_insert(entity, EntityCommInfo(BulkData::SHARED, remote_proc)); + promotingToShared.push_back(entity); + } + } + } + else if (remote_proc_is_owner) { // Receiving from 'remote_proc' for ghosting + + const bool hasBeenPromotedToSharedOrOwned = m_bulkData.owned_closure(entity); + bool isAuraGhost = false; + bool isCustomGhost = false; + PairIterEntityComm pairIterEntityComm = m_bulkData.internal_entity_comm_map(entity); + for(unsigned j=0; j BulkData::AURA) { + isCustomGhost = true; + } + } + + if ( isAuraGhost ) { + if (!isAlreadyDestroyed && hasBeenPromotedToSharedOrOwned) { + m_bulkData.entity_comm_map_insert(entity, EntityCommInfo(BulkData::SHARED, remote_proc)); + promotingToShared.push_back(entity); + } + m_bulkData.entity_comm_map_erase(key, m_bulkData.aura_ghosting()); + } + + if(!isAlreadyDestroyed) { + const bool wasDestroyedByOwner = remotely_destroyed; + const bool shouldDestroyGhost = wasDestroyedByOwner || (isAuraGhost && !isCustomGhost && !hasBeenPromotedToSharedOrOwned); + const bool shouldRemoveFromGhosting = remotely_destroyed && !isAuraGhost && hasBeenPromotedToSharedOrOwned; + + if (shouldRemoveFromGhosting) { + for ( size_t j = ghosting_count_minus_shared ; j >=1 ; --j ) { + m_bulkData.entity_comm_map_erase( key, *m_bulkData.m_ghosting[j] ); + } + } + + if ( shouldDestroyGhost ) { + const bool was_ghost = true; + m_bulkData.internal_destroy_entity_with_notification(entity, was_ghost); + } + + m_bulkData.entity_comm_list_insert(entity); + } + } + } // end loop on remote mod + + // Erase all ghosting communication lists for: + // 1) Destroyed entities. + // 2) Owned and modified entities. + + for ( EntityCommListInfoVector::const_reverse_iterator + i = m_bulkData.internal_comm_list().rbegin() ; i != m_bulkData.internal_comm_list().rend() ; ++i) { + + Entity entity = i->entity; + + const bool locally_destroyed = !is_valid(entity); + const bool locally_owned_and_modified = locally_destroyed ? false : + (Modified == m_bulkData.state(entity) && (m_bulkData.parallel_rank() == m_bulkData.parallel_owner_rank(entity))); + + if ( locally_destroyed ) { + for ( size_t j = ghosting_count_minus_shared ; j >=1 ; --j ) { + m_bulkData.entity_comm_map_erase( i->key, *m_bulkData.m_ghosting[j] ); + } + } + else if ( locally_owned_and_modified ) { + m_bulkData.entity_comm_map_erase( i->key, m_bulkData.aura_ghosting() ); + } + } + + if (!promotingToShared.empty()) { + OrdinalVector sharedPart, auraPart, scratchOrdinalVec, scratchSpace; + sharedPart.push_back(m_bulkData.mesh_meta_data().globally_shared_part().mesh_meta_data_ordinal()); + auraPart.push_back(m_bulkData.mesh_meta_data().aura_part().mesh_meta_data_ordinal()); + for(Entity entity : promotingToShared) { + m_bulkData.internal_change_entity_parts(entity, sharedPart /*add*/, auraPart /*remove*/, scratchOrdinalVec, scratchSpace); + } + m_bulkData.add_comm_list_entries_for_entities(promotingToShared); + } +} + +void MeshModification::add_entity_to_same_ghosting(Entity entity, Entity connectedGhost) +{ + for(PairIterEntityComm ec(m_bulkData.internal_entity_comm_map(connectedGhost)); ! ec.empty(); ++ec) { + if (ec->ghost_id > BulkData::AURA) { + m_bulkData.entity_comm_map_insert(entity, EntityCommInfo(ec->ghost_id, ec->proc)); + m_bulkData.entity_comm_list_insert(entity); + } + } +} + +void MeshModification::internal_resolve_formerly_shared_entities(const EntityVector& entitiesNoLongerShared) +{ + for(Entity entity : entitiesNoLongerShared) { + EntityVector ghostRelations = m_bulkData.get_upward_send_ghost_relations(entity); + + for(Entity ghost : ghostRelations) { + add_entity_to_same_ghosting(entity, ghost); + } + } +} + void MeshModification::ensure_meta_data_is_committed() { if (!m_bulkData.mesh_meta_data().is_commit()) diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.hpp index 2d120dd8d4fb..f677db89e866 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.hpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/MeshModification.hpp @@ -37,14 +37,29 @@ #include // for MeshIndex, EntityRank, etc #include +#include +#include +#include "stk_mesh/base/EntityKey.hpp" +#include "stk_mesh/baseImpl/DeletedEntityCache.hpp" namespace stk { +class CommSparse; namespace mesh { class BulkData; namespace impl { +struct EntityParallelState { + int from_proc; + EntityState state; + EntityCommListInfo comm_info; + bool remote_owned_closure; + + bool operator<(const EntityParallelState& rhs) const + { return comm_info.key < rhs.comm_info.key; } +}; + class MeshModification { public: @@ -52,7 +67,7 @@ class MeshModification enum modification_optimization {MOD_END_SORT, MOD_END_NO_SORT }; MeshModification(stk::mesh::BulkData& bulkData) : m_bulkData(bulkData), m_entity_states(), - m_sync_state(MODIFIABLE), m_sync_count(0), m_did_any_shared_entity_change_parts(false) + m_deleted_entity_cache(bulkData), m_sync_state(MODIFIABLE), m_sync_count(0), m_did_any_shared_entity_change_parts(false) { m_entity_states.push_back(Deleted); } @@ -78,10 +93,12 @@ class MeshModification void change_entity_owner( const EntityProcVec & arg_change); void internal_resolve_shared_modify_delete(stk::mesh::EntityVector & entitiesNoLongerShared); + void internal_resolve_ghosted_modify_delete(const stk::mesh::EntityVector& entitiesNoLongerShared); bool did_any_shared_entity_change_parts () const { return m_did_any_shared_entity_change_parts; } void set_shared_entity_changed_parts() { m_did_any_shared_entity_change_parts = true; } + //TODO: these should be Entity::entity_value_type bool is_entity_deleted(size_t entity_index) const { return m_entity_states[entity_index] == Deleted; } bool is_entity_modified(size_t entity_index) const { return m_entity_states[entity_index] == Modified; } bool is_entity_created(size_t entity_index) const { return m_entity_states[entity_index] == Created; } @@ -90,23 +107,41 @@ class MeshModification stk::mesh::EntityState get_entity_state(size_t entity_index) const { return static_cast(m_entity_states[entity_index]); } void set_entity_state(size_t entity_index, stk::mesh::EntityState state) { m_entity_states[entity_index] = state; } - void mark_entity_as_deleted(size_t entity_index) { m_entity_states[entity_index] = Deleted; } + void mark_entity_as_deleted(Entity entity, bool is_ghost) + { + m_entity_states[entity.local_offset()] = Deleted; + m_deleted_entity_cache.mark_entity_as_deleted(entity, is_ghost); + } + void mark_entity_as_created(size_t entity_index) { m_entity_states[entity_index] = Created; } void add_created_entity_state() { m_entity_states.push_back(Created); } + DeletedEntityCache& get_deleted_entity_cache() { return m_deleted_entity_cache; } + + const DeletedEntityCache& get_deleted_entity_cache() const { return m_deleted_entity_cache; } private: + bool pack_entity_modification( const bool packShared , stk::CommSparse & comm ); + void communicate_entity_modification( const bool shared , std::vector & data ); void reset_shared_entity_changed_parts() { m_did_any_shared_entity_change_parts = false; } + void internal_establish_new_owner(stk::mesh::Entity entity); + void internal_update_parts_for_shared_entity(stk::mesh::Entity entity, const bool is_entity_shared, const bool did_i_just_become_owner); + void destroy_dependent_ghosts( Entity entity, EntityProcVec& entitiesToRemoveFromSharing ); + void delete_shared_entities_which_are_no_longer_in_owned_closure(EntityProcVec& entitiesToRemoveFromSharing); + void remove_entities_from_sharing(const EntityProcVec& entitiesToRemoveFromSharing, stk::mesh::EntityVector & entitiesNoLongerShared); + void add_entity_to_same_ghosting(Entity entity, Entity connectedGhost); + void internal_resolve_formerly_shared_entities(const EntityVector& entitiesNoLongerShared); void reset_undeleted_entity_states_to_unchanged(); void ensure_meta_data_is_committed(); - bool internal_modification_end(modification_optimization opt); bool internal_resolve_node_sharing(modification_optimization opt); bool internal_modification_end_after_node_sharing_resolution(modification_optimization opt); stk::mesh::BulkData &m_bulkData; std::vector m_entity_states; + DeletedEntityCache m_deleted_entity_cache; + BulkDataSyncState m_sync_state; size_t m_sync_count; bool m_did_any_shared_entity_change_parts; diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/Visitors.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/Visitors.hpp index f2e0127ed63d..7f4b92c89ab3 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/Visitors.hpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/Visitors.hpp @@ -50,36 +50,49 @@ namespace stk { namespace mesh { namespace impl { +template +void VisitClosureBelowEntityNoRecurse( + const BulkData & mesh, + Entity inputEntity, + EntityRank inputEntityRank, + DO_THIS_FOR_ENTITY_IN_CLOSURE & do_this, + DESIRED_ENTITY & desired_entity) +{ + for (EntityRank rank = stk::topology::NODE_RANK ; rank < inputEntityRank ; ++rank) { + const unsigned num_entities_of_rank = mesh.num_connectivity(inputEntity,rank); + if (num_entities_of_rank > 0) { + const Entity * entities = mesh.begin(inputEntity,rank); + + for (unsigned i=0 ; i void VisitClosureNoRecurse( const BulkData & mesh, Entity inputEntity, + EntityRank inputEntityRank, DO_THIS_FOR_ENTITY_IN_CLOSURE & do_this, DESIRED_ENTITY & desired_entity) { if (desired_entity(inputEntity)) { do_this(inputEntity); - const EntityRank inputEntityRank = mesh.entity_rank(inputEntity); - for (EntityRank rank = stk::topology::NODE_RANK ; rank < inputEntityRank ; ++rank) { - const unsigned num_entities_of_rank = mesh.num_connectivity(inputEntity,rank); - if (num_entities_of_rank > 0) { - const Entity * entities = mesh.begin(inputEntity,rank); - - for (unsigned i=0 ; i(mesh.entity_rank(entity) - 1); - while (mesh.num_connectivity(entity, nextLowerRank) == 0 && nextLowerRank > stk::topology::NODE_RANK) { + EntityRank nextLowerRank = static_cast(entityRank - 1); + while (nextLowerRank > stk::topology::NODE_RANK && mesh.num_connectivity(entity, nextLowerRank) == 0) { nextLowerRank = static_cast(nextLowerRank-1); } return nextLowerRank; @@ -89,21 +102,21 @@ template void VisitClosureGeneral( const BulkData & mesh, Entity inputEntity, + EntityRank inputEntityRank, DO_THIS_FOR_ENTITY_IN_CLOSURE & do_this, DESIRED_ENTITY & desired_entity) { - const EntityRank inputEntityRank = mesh.entity_rank(inputEntity); if (inputEntityRank <= stk::topology::ELEM_RANK) { - VisitClosureNoRecurse(mesh, inputEntity, do_this, desired_entity); + VisitClosureNoRecurse(mesh, inputEntity, inputEntityRank, do_this, desired_entity); } else if (desired_entity(inputEntity)) { do_this(inputEntity); - const EntityRank nextLowerRank = get_highest_downward_connected_rank(mesh, inputEntity); + const EntityRank nextLowerRank = get_highest_downward_connected_rank(mesh, inputEntity, inputEntityRank); const unsigned num_entities_of_rank = mesh.num_connectivity(inputEntity,nextLowerRank); if (num_entities_of_rank > 0) { const Entity * entities = mesh.begin(inputEntity,nextLowerRank); for (unsigned i=0 ; i(mesh,get_entity(entity_iterator),do_this,desired_entity); + Entity entity = get_entity(entity_iterator); + VisitClosureGeneral(mesh,entity,mesh.entity_rank(entity),do_this,desired_entity); } } @@ -234,7 +248,7 @@ void VisitClosure( DO_THIS_FOR_ENTITY_IN_CLOSURE & do_this) { OnlyVisitOnce ovo(mesh); - VisitClosureGeneral(mesh,entity_of_interest,do_this,ovo); + VisitClosureGeneral(mesh,entity_of_interest,mesh.entity_rank(entity_of_interest),do_this,ovo); } @@ -457,19 +471,30 @@ struct OnlyGhosts { }; struct OnlyGhostsEPM { - OnlyGhostsEPM(BulkData & mesh_in, const EntityProcMapping& epm_in, const EntityProcMapping& entityShr) - : mesh(mesh_in), myMapping(epm_in), entitySharing(entityShr) {} + OnlyGhostsEPM(BulkData & mesh_in, const EntityProcMapping& entityShr) + : mesh(mesh_in), entitySharing(entityShr) {} + bool operator()(Entity entity) { + if (proc != mesh.parallel_owner_rank(entity)) { + const bool isSharedWithProc = entitySharing.find(entity, proc); + return !isSharedWithProc; + } + return false; + } + BulkData & mesh; + const EntityProcMapping& entitySharing; + int proc; +}; + +struct NotAlreadyShared { + NotAlreadyShared(BulkData & mesh_in, const EntityProcMapping& entityShr) + : mesh(mesh_in), entitySharing(entityShr) {} bool operator()(Entity entity) { - if (!myMapping.find(entity, proc)) { - if (proc != mesh.parallel_owner_rank(entity)) { - const bool isSharedWithProc = entitySharing.find(entity, proc); - return !isSharedWithProc; - } + if (proc != mesh.parallel_owner_rank(entity)) { + return !entitySharing.find(entity,proc); } return false; } BulkData & mesh; - const EntityProcMapping& myMapping; const EntityProcMapping& entitySharing; int proc; }; diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemElemGraphImpl.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemElemGraphImpl.hpp index a30762821b93..eb8e3a388d05 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemElemGraphImpl.hpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemElemGraphImpl.hpp @@ -277,8 +277,7 @@ struct IdViaSidePair }//namespace impl -const int max_num_sides_per_elem = 10; -const double inverse_of_max_num_sides_per_elem = 0.1; +constexpr int max_num_sides_per_elem = 8; struct GraphEdge { @@ -289,7 +288,7 @@ struct GraphEdge } GraphEdge() : - vertex1(std::numeric_limits::max()), vertex2(std::numeric_limits::max()) + vertex1(impl::INVALID_LOCAL_ID), vertex2(impl::INVALID_LOCAL_ID) {} GraphEdge(const GraphEdge& rhs) @@ -326,12 +325,12 @@ struct GraphEdge impl::LocalId elem1() const { - return vertex1*inverse_of_max_num_sides_per_elem; + return vertex1/max_num_sides_per_elem; } impl::LocalId elem2() const { - return vertex2*inverse_of_max_num_sides_per_elem; + return vertex2/max_num_sides_per_elem; } int get_side(const impl::LocalId& vertex) const @@ -350,6 +349,11 @@ struct GraphEdge impl::LocalId vertex2; }; +constexpr bool is_valid(const GraphEdge& lhs) +{ + return lhs.vertex1 != impl::INVALID_LOCAL_ID; +} + using CoincidentElementConnection = GraphEdge; struct GraphEdgeLessByElem1 { @@ -383,6 +387,17 @@ struct GraphEdgeLessByElem1 { } }; +struct GraphEdgeLessByElem2Only +{ + bool operator()(const GraphEdge& a, const GraphEdge& b) const + { + impl::LocalId a_elem2 = std::abs(a.elem2()); + impl::LocalId b_elem2 = std::abs(b.elem2()); + + return a_elem2 < b_elem2 || (a_elem2 == b_elem2 && a.side2() < b.side2()); + } +}; + inline bool operator<(const GraphEdge& a, const GraphEdge& b) { @@ -421,7 +436,9 @@ bool operator==(const GraphEdge& a, const GraphEdge& b) inline std::ostream& operator<<(std::ostream& out, const GraphEdge& graphEdge) { - out << "(" << graphEdge.vertex1 << " -> " << graphEdge.vertex2 << ")"; + out << "GraphEdge vertices: (" << graphEdge.vertex1 << " -> " << graphEdge.vertex2 + << "), element-side pairs: (" << graphEdge.elem1() << ", " << graphEdge.side1() + << ") -> (" << graphEdge.elem2() << ", " << graphEdge.side2() << ")"; return out; } diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemGraphShellConnections.cpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemGraphShellConnections.cpp index e8578433b975..2f14c6d323e4 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemGraphShellConnections.cpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/ElemGraphShellConnections.cpp @@ -120,8 +120,14 @@ void remove_graph_edges_blocked_by_shell(GraphInfo &graphInfo) SideConnections sideConnectionsForElement(graphInfo.elementTopologies[localId].num_sides()); for(int side : sideConnectionsForElement.get_sides_connected_to_shell_and_nonshell(graphInfo, localId)) fill_non_shell_graph_edges_to_delete(graphInfo, stk::mesh::impl::ElementSidePair(localId, side), edgesToDelete); + + if (edgesToDelete.size() > 0) + { + std::sort(edgesToDelete.begin(), edgesToDelete.end(), GraphEdgeLessByElem1()); + graphInfo.graph.delete_sorted_edges(edgesToDelete); + edgesToDelete.clear(); + } } - graphInfo.graph.delete_sorted_edges(edgesToDelete); } } diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.cpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.cpp index aa3ea37716bb..aeb9629f53af 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.cpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.cpp @@ -1,5 +1,6 @@ #include "GraphEdgeData.hpp" #include "ElemElemGraphImpl.hpp" +#include "stk_mesh/baseImpl/elementGraph/GraphTypes.hpp" #include #include @@ -10,35 +11,37 @@ namespace mesh void Graph::set_num_local_elements(size_t n) { - m_elemOffsets.resize(n+1); + m_elemOffsets.resize(n, IndexRange(m_graphEdges.size()+1, m_graphEdges.size()+1)); } void Graph::add_new_element() { - if (m_elemOffsets.empty()) { - m_elemOffsets.assign(1, 0); - } - m_elemOffsets.push_back(m_graphEdges.size()); + m_elemOffsets.push_back({m_graphEdges.size()+1, m_graphEdges.size()+1}); } + size_t Graph::get_num_elements_in_graph() const { - return m_elemOffsets.size() - 1; + return m_elemOffsets.size(); } size_t Graph::get_num_edges() const { - return m_graphEdges.size(); + return m_graphEdges.size() - m_numUnusedEntries; } size_t Graph::get_num_edges_for_element(impl::LocalId elem) const { - return m_elemOffsets[elem+1] - m_elemOffsets[elem]; + auto& indices = m_elemOffsets[elem]; + return indices.second - indices.first; } const GraphEdge & Graph::get_edge_for_element(impl::LocalId elem1, size_t index) const { - return m_graphEdges[m_elemOffsets[elem1]+index]; + ThrowAssertMsg(get_num_edges_for_element(elem1) != 0, "Cannot retrieve graph edge for element that has no faces"); + ThrowAssertMsg(get_num_edges_for_element(elem1) > index, "index out of range"); + + return m_graphEdges[m_elemOffsets[elem1].first+index]; } void fill_graph_edges_for_elem_side(const GraphEdgesForElement &graphEdgesForElement, int side, std::vector& edges) @@ -61,80 +64,183 @@ std::vector Graph::get_edges_for_element_side(impl::LocalId elem, int GraphEdgesForElement Graph::get_edges_for_element(impl::LocalId elem) const { - const unsigned begin = m_elemOffsets[elem]; - const unsigned end = m_elemOffsets[elem+1]; - return GraphEdgesForElement(&m_graphEdges[begin], &m_graphEdges[end]); + const unsigned beginOffset = m_elemOffsets[elem].first; + const unsigned endOffset = m_elemOffsets[elem].second; + + const GraphEdge* beginEdge = m_graphEdges.data() + beginOffset; + const GraphEdge* endEdge = m_graphEdges.data() + endOffset; + return GraphEdgesForElement(beginEdge, endEdge); } + void Graph::set_offsets() { - const unsigned numOffsets = m_elemOffsets.size(); - m_elemOffsets.assign(std::max(1u, numOffsets), 0); + if (m_graphEdges.size() == 0) + { + return; + } - impl::LocalId prevElem = impl::INVALID_LOCAL_ID; - unsigned edgeCounter = 0; - for(const GraphEdge& edge : m_graphEdges) { - impl::LocalId elem1 = edge.elem1(); - if (elem1 != prevElem) { - if (prevElem != impl::INVALID_LOCAL_ID) { - m_elemOffsets[prevElem] = edgeCounter; + impl::LocalId currElem = m_graphEdges[0].elem1(); + unsigned startIdx = 0; + for (unsigned i=0; i < m_graphEdges.size(); ++i) + { + impl::LocalId nextElem = m_graphEdges[i].elem1(); + if (nextElem != currElem) + { + ThrowAssertMsg(currElem >= 0 && size_t(currElem) <= m_elemOffsets.size(), "element out of range"); + m_elemOffsets[currElem] = IndexRange(startIdx, i); + for (impl::LocalId elem=currElem+1; elem < nextElem; elem++) + { + m_elemOffsets[elem] = IndexRange(0, 0); } - edgeCounter = 0; - prevElem = elem1; - } - ++edgeCounter; - } - if (prevElem != impl::INVALID_LOCAL_ID) { - m_elemOffsets[prevElem] = edgeCounter; + currElem = nextElem; + startIdx = i; + } } - unsigned edgeOffset = 0; - size_t numElems = m_elemOffsets.size()-1; - for(size_t i=0; i::iterator; void Graph::add_sorted_edges(const std::vector& graphEdges) { ThrowAssertMsg(stk::util::is_sorted_and_unique(graphEdges, GraphEdgeLessByElem1()),"Input vector 'graphEdges' is expected to be sorted-and-unique"); - if (!graphEdges.empty()) { - stk::util::insert_keep_sorted(graphEdges, m_graphEdges, GraphEdgeLessByElem1()); - set_offsets(); + + for (auto& edge : graphEdges) + { + insert_edge(edge); + } +} + + +void Graph::insert_edge(const GraphEdge& graphEdge) +{ + auto elem1 = graphEdge.elem1(); + auto& indices = m_elemOffsets[elem1]; + + if (check_for_edge(graphEdge)) + { + return; + } + + if (m_graphEdges.size() > 0 && double(m_numUnusedEntries) / m_graphEdges.size() > m_compressionThreshold) + { + compress_graph(); + } + + if (get_num_edges_for_element(elem1) == 0) + { + m_graphEdges.push_back(graphEdge); + indices.first = m_graphEdges.size()-1; + indices.second = m_graphEdges.size(); + } else if (indices.second >= m_graphEdges.size()) + { + m_graphEdges.emplace_back(); + insert_edge_into_sorted_range_or_next_entry(indices, graphEdge); + } else if (is_valid(m_graphEdges[indices.second])) + { + move_edges_to_end(elem1); + + m_graphEdges.emplace_back(); + insert_edge_into_sorted_range_or_next_entry(indices, graphEdge); + } else if (!is_valid(m_graphEdges[indices.second])) + { + insert_edge_into_sorted_range_or_next_entry(indices, graphEdge); + m_numUnusedEntries--; + } else + { + throw std::runtime_error("unreachable case"); } } +void Graph::insert_edge_into_sorted_range_or_next_entry(IndexRange& indices, const GraphEdge& graphEdge) +{ + unsigned idxToInsert = find_sorted_insertion_index(indices, graphEdge); + + for (unsigned i=indices.second; i > idxToInsert; i--) + { + m_graphEdges[i] = m_graphEdges[i-1]; + } + + m_graphEdges[idxToInsert] = graphEdge; + indices.second++; +} + + +unsigned Graph::find_sorted_insertion_index(IndexRange indices, const GraphEdge& graphEdge) +{ + GraphEdgeLessByElem2Only isLess; + for (unsigned i=indices.first; i < indices.second; ++i) + { + if (isLess(graphEdge, m_graphEdges[i])) + { + return i; + } + } + + return indices.second; +} + void Graph::replace_sorted_edges(std::vector& graphEdges) { + ThrowAssertMsg(stk::util::is_sorted_and_unique(graphEdges, GraphEdgeLessByElem1()),"Input vector 'graphEdges' is expected to be sorted-and-unique"); + m_graphEdges.swap(graphEdges); set_offsets(); + m_numUnusedEntries = 0; } + void Graph::delete_sorted_edges(const std::vector& edgesToDelete) { - for(const GraphEdge& edgeToDelete : edgesToDelete) { - impl::LocalId elem1 = edgeToDelete.elem1(); - for(unsigned offset = m_elemOffsets[elem1]; offset < m_elemOffsets[elem1+1]; ++offset) { - GraphEdge& thisEdge = m_graphEdges[offset]; - if (thisEdge == edgeToDelete) { - thisEdge.vertex1 = impl::INVALID_LOCAL_ID; - } + ThrowAssertMsg(std::is_sorted(edgesToDelete.begin(), edgesToDelete.end(), GraphEdgeLessByElem1()), + "Input vector is expected to be sorted"); + + int startIdx = 0; + while (size_t(startIdx) != edgesToDelete.size()) + { + int endIdx = get_end_of_element_range_for_sorted_edges(edgesToDelete, startIdx); + for (int idx=endIdx; idx >= startIdx; idx--) + { + delete_edge(edgesToDelete[idx]); } + + startIdx = endIdx + 1; } +} + +unsigned Graph::get_end_of_element_range_for_sorted_edges(const std::vector& edges, unsigned startIdx) +{ + unsigned currElement = edges[startIdx].elem1(); + unsigned endIdx = startIdx; + while (endIdx < edges.size() && edges[endIdx].elem1() == currElement) + { + endIdx++; + } + endIdx--; - if (!edgesToDelete.empty()) { - const unsigned offset = m_elemOffsets[edgesToDelete[0].elem1()]; - m_graphEdges.erase(std::remove_if(m_graphEdges.begin()+offset, m_graphEdges.end(), - [](const GraphEdge& edge) - { return edge.vertex1 == impl::INVALID_LOCAL_ID; }), - m_graphEdges.end()); - set_offsets(); + return endIdx; +} + +void Graph::delete_edge(const GraphEdge& edgeToDelete) +{ + impl::LocalId elem1 = edgeToDelete.elem1(); + auto& indices = m_elemOffsets[elem1]; + for(unsigned offset = indices.first; offset < indices.second; ++offset) { + if (m_graphEdges[offset] == edgeToDelete) + { + for (unsigned i=offset; i < indices.second-1; ++i) + { + m_graphEdges[i] = m_graphEdges[i+1]; + } + indices.second--; + m_graphEdges[indices.second] = GraphEdge(); + m_numUnusedEntries++; + break; + } } } @@ -142,8 +248,93 @@ void Graph::clear() { m_graphEdges.clear(); m_elemOffsets.clear(); + m_numUnusedEntries = 0; +} + + +void Graph::move_edges_to_end(impl::LocalId elem) +{ + auto& indices = m_elemOffsets[elem]; + size_t newStartIdx = m_graphEdges.size(); + for (unsigned i=indices.first; i < indices.second; ++i) + { + m_graphEdges.push_back(m_graphEdges[i]); + m_graphEdges[i] = GraphEdge(); + m_numUnusedEntries++; + } + + m_elemOffsets[elem] = IndexRange(newStartIdx, m_graphEdges.size()); +} + +void Graph::compress_graph() +{ + if (m_graphEdges.size() == 0 || m_graphEdges.size() == m_numUnusedEntries) + return; + + impl::LocalId prevElement = 0; + unsigned offset = 0; + for (unsigned i=0; i < m_graphEdges.size(); ++i) + { + if (is_valid(m_graphEdges[i])) + { + prevElement = m_graphEdges[i].elem1(); + break; + } else + { + offset++; + } + } + + { + auto& indices = m_elemOffsets[prevElement]; + indices.first -= offset; + indices.second -= offset; + } + + for (unsigned idx=offset; idx < m_graphEdges.size(); ++idx) + { + if (is_valid(m_graphEdges[idx])) + { + m_graphEdges[idx - offset] = m_graphEdges[idx]; + + impl::LocalId currElement = m_graphEdges[idx].elem1(); + if (currElement != prevElement) + { + auto& indices = m_elemOffsets[currElement]; + if (indices.first != indices.second) + { + indices.first -= offset; + indices.second -= offset; + } + prevElement = currElement; + } + + } else + { + offset++; + } + } + + ThrowRequireMsg(is_valid(m_graphEdges[m_graphEdges.size() - offset - 1]), "The count of unused edges is incorrect"); + m_graphEdges.resize(m_graphEdges.size() - offset); + m_numUnusedEntries = 0; } + +bool Graph::check_for_edge(const GraphEdge& edge) +{ + auto& indices = m_elemOffsets[edge.elem1()]; + for (unsigned i=indices.first; i < indices.second; ++i) + if (m_graphEdges[i] == edge) + { + return true; + } + + return false; +} + + + impl::ParallelInfo& ParallelInfoForGraphEdges::get_parallel_info_for_graph_edge(const GraphEdge& graphEdge) { return const_cast(get_parallel_info_iterator_for_graph_edge(graphEdge)->second); diff --git a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.hpp b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.hpp index c6640523b300..352dcab6e426 100644 --- a/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.hpp +++ b/packages/stk/stk_mesh/stk_mesh/baseImpl/elementGraph/GraphEdgeData.hpp @@ -65,9 +65,31 @@ class Graph void clear(); private: + using IndexRange = std::pair; + void set_offsets(); + + void insert_edge(const GraphEdge& graphEdge); + + void insert_edge_into_sorted_range_or_next_entry(IndexRange& indices, const GraphEdge& graphEdge); + + unsigned find_sorted_insertion_index(IndexRange indices, const GraphEdge& graphEdge); + + void move_edges_to_end(impl::LocalId elem); + + void compress_graph(); + + + unsigned get_end_of_element_range_for_sorted_edges(const std::vector& edges, unsigned startIdx); + + void delete_edge(const GraphEdge& edgeToDelete); + + bool check_for_edge(const GraphEdge& edge); + std::vector m_graphEdges; - std::vector m_elemOffsets; + std::vector m_elemOffsets; + unsigned m_numUnusedEntries = 0; + const double m_compressionThreshold = 0.2; }; class ParallelInfoForGraphEdges diff --git a/packages/stk/stk_performance_tests/stk_mesh/NgpMeshUpdate.cpp b/packages/stk/stk_performance_tests/stk_mesh/NgpMeshUpdate.cpp index 4597903835f0..716b0d2c5337 100644 --- a/packages/stk/stk_performance_tests/stk_mesh/NgpMeshUpdate.cpp +++ b/packages/stk/stk_performance_tests/stk_mesh/NgpMeshUpdate.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include class NgpMeshChangeElementPartMembership : public stk::unit_test_util::simple_fields::MeshFixture @@ -48,12 +49,17 @@ class NgpMeshChangeElementPartMembership : public stk::unit_test_util::simple_fi public: NgpMeshChangeElementPartMembership() : stk::unit_test_util::simple_fields::MeshFixture(), - newPartName("block2") + newPartName("block2"), + numElements(1000000) { } - void setup_host_mesh() + void setup_host_mesh(stk::mesh::BulkData::AutomaticAuraOption auraOption) { - setup_mesh("generated:100x100x100", stk::mesh::BulkData::NO_AUTO_AURA); +#ifdef NDEBUG + setup_mesh("generated:400x250x10", auraOption); +#else + setup_mesh("generated:10x10x100", auraOption); +#endif get_meta().declare_part(newPartName); } @@ -75,8 +81,10 @@ class NgpMeshChangeElementPartMembership : public stk::unit_test_util::simple_fi private: stk::mesh::Entity get_element(int cycle) { - stk::mesh::EntityId elemId = cycle+1; - return get_bulk().get_entity(stk::topology::ELEM_RANK, elemId); + stk::mesh::EntityId firstLocalElemId = get_parallel_rank()*numElements/2 + 1; + stk::mesh::EntityId elemId = firstLocalElemId + cycle; + stk::mesh::Entity elem = get_bulk().get_entity(stk::topology::ELEM_RANK, elemId); + return elem; } stk::mesh::Part* get_part() @@ -85,6 +93,7 @@ class NgpMeshChangeElementPartMembership : public stk::unit_test_util::simple_fi } std::string newPartName; + unsigned numElements; }; class NgpMeshCreateEntity : public stk::unit_test_util::simple_fields::MeshFixture @@ -131,12 +140,12 @@ class NgpMeshGhosting : public stk::unit_test_util::simple_fields::MeshFixture { } protected: - void setup_host_mesh() + void setup_host_mesh(stk::mesh::BulkData::AutomaticAuraOption auraOption) { #ifdef NDEBUG - setup_mesh("generated:100x100x100", stk::mesh::BulkData::NO_AUTO_AURA); + setup_mesh("generated:400x250x10", auraOption); #else - setup_mesh("generated:10x10x100", stk::mesh::BulkData::NO_AUTO_AURA); + setup_mesh("generated:10x10x100", auraOption); #endif get_bulk().modification_begin(); ghosting = &get_bulk().create_ghosting(ghostingName); @@ -174,7 +183,26 @@ TEST_F( NgpMeshChangeElementPartMembership, Timing ) stk::performance_tests::Timer timer(get_comm()); timer.start_timing(); - setup_host_mesh(); + setup_host_mesh(stk::mesh::BulkData::NO_AUTO_AURA); + + for (int i=0; iinternal_resolve_ghosted_modify_delete(entitiesNoLongerShared); + this->m_meshModification.internal_resolve_ghosted_modify_delete(entitiesNoLongerShared); } void my_internal_resolve_parallel_create() @@ -280,12 +280,6 @@ class BulkDataTester : public stk::mesh::BulkData set_state(entity,entity_state); } - void my_delete_shared_entities_which_are_no_longer_in_owned_closure() - { - stk::mesh::EntityProcVec entitiesToRemoveFromSharing; - delete_shared_entities_which_are_no_longer_in_owned_closure(entitiesToRemoveFromSharing); - } - void my_ghost_entities_and_fields(stk::mesh::Ghosting & ghosting, const std::set& new_send) { ghost_entities_and_fields(ghosting, new_send); diff --git a/packages/stk/stk_unit_tests/stk_balance/UnitTestCommandLineParsing.cpp b/packages/stk/stk_unit_tests/stk_balance/UnitTestCommandLineParsing.cpp index f732dfa5f5b0..e50b486f650d 100644 --- a/packages/stk/stk_unit_tests/stk_balance/UnitTestCommandLineParsing.cpp +++ b/packages/stk/stk_unit_tests/stk_balance/UnitTestCommandLineParsing.cpp @@ -138,14 +138,14 @@ TEST_F(BalanceCommandLine, createBalanceSettings_default) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -164,12 +164,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_outputDirectory) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -190,12 +190,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_outputDirectory_fullOptions) EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_TRUE(balanceSettings.includeSearchResultsInGraph()); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -210,12 +210,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_customLogfile) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -230,12 +230,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_shortCustomLogfile) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -250,12 +250,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_coutLogfile) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -266,12 +266,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_printDiagnostics) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), true); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -282,12 +282,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_shortPrintDiagnostics) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), true); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -304,12 +304,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_rebalanceTo) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -326,12 +326,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_useNestedDecomp) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), true); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -342,13 +342,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -360,14 +360,14 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -379,10 +379,12 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaultsOverrideSpider) EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -394,11 +396,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaultsOverrideMechanism) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -411,11 +415,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaultsOverrideSpiders) EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -426,12 +432,14 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaultsOverrideMechanisms) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -448,11 +456,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_defaultAbsoluteTolerance) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -464,11 +474,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_defaultRelativeTolerance) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -486,11 +498,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_faceSearchAbsoluteTolerance) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); check_absolute_tolerance_for_face_search(balanceSettings, 0.001); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -502,11 +516,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_faceSearchRelativeTolerance) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, 0.123); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -519,11 +535,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_defaultAbsoluteToler EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -535,11 +553,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_defaultRelativeToler EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -551,11 +571,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_faceSearchAbsoluteTo EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_absolute_tolerance_for_face_search(balanceSettings, 0.005); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -567,11 +589,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_faceSearchRelativeTo EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, 0.123); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -584,11 +608,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_defaultAbsoluteToler EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -600,11 +626,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_defaultRelativeToler EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -616,11 +644,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_faceSearchAbsoluteTo EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); check_absolute_tolerance_for_face_search(balanceSettings, 0.0005); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -632,11 +662,13 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_faceSearchRelativeTo EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, 0.123); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -644,7 +676,7 @@ TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_faceSearchRelativeTo TEST_F(BalanceCommandLine, createBalanceSettings_contactSearchEdgeWeight) { - const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--EXP-contact-search-edge-weight=20"}); + const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--contact-search-edge-weight=20"}); const int finalNumProcs = stk::parallel_machine_size(MPI_COMM_WORLD); @@ -655,30 +687,32 @@ TEST_F(BalanceCommandLine, createBalanceSettings_contactSearchEdgeWeight) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), 20); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_contactSearchEdgeWeight) { const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--sm", - "--EXP-contact-search-edge-weight=20"}); + "--contact-search-edge-weight=20"}); EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), 20); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -686,25 +720,26 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_contactSearchEdgeWei TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_contactSearchEdgeWeight) { const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--sd", - "--EXP-contact-search-edge-weight=20"}); + "--contact-search-edge-weight=20"}); EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), 20); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } TEST_F(BalanceCommandLine, createBalanceSettings_contactSearchVertexWeightMultiplier) { - const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--EXP-contact-search-vertex-weight-mult=9"}); + const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--contact-search-vertex-weight-mult=9"}); const int finalNumProcs = stk::parallel_machine_size(MPI_COMM_WORLD); @@ -715,30 +750,32 @@ TEST_F(BalanceCommandLine, createBalanceSettings_contactSearchVertexWeightMultip EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), 9); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_contactSearchVertexWeightMultiplier) { const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--sm", - "--EXP-contact-search-vertex-weight-mult=9"}); + "--contact-search-vertex-weight-mult=9"}); EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), 9); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -746,25 +783,26 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_contactSearchVertexW TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_contactSearchVertexWeightMultiplier) { const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--sd", - "--EXP-contact-search-vertex-weight-mult=9"}); + "--contact-search-vertex-weight-mult=9"}); EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), 9); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } TEST_F(BalanceCommandLine, createBalanceSettings_edgeWeightMultiplier) { - const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--EXP-edge-weight-mult=3"}); + const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--edge-weight-mult=3"}); const int finalNumProcs = stk::parallel_machine_size(MPI_COMM_WORLD); @@ -775,29 +813,29 @@ TEST_F(BalanceCommandLine, createBalanceSettings_edgeWeightMultiplier) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), 3); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_edgeWeightMultiplier) { const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--sm", - "--EXP-edge-weight-mult=3"}); + "--edge-weight-mult=3"}); EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), 3); @@ -808,19 +846,19 @@ TEST_F(BalanceCommandLine, createBalanceSettings_smDefaults_edgeWeightMultiplier TEST_F(BalanceCommandLine, createBalanceSettings_sdDefaults_edgeWeightMultiplier) { const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--sd", - "--EXP-edge-weight-mult=3"}); + "--edge-weight-mult=3"}); EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), 3); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -853,12 +891,14 @@ TEST_F(BalanceCommandLine, disableSearch_default_caseInsensitive) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), false); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -869,11 +909,13 @@ TEST_F(BalanceCommandLine, disableSearch_smDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), false); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -885,12 +927,14 @@ TEST_F(BalanceCommandLine, disableSearch_sdDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), false); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -902,12 +946,14 @@ TEST_F(BalanceCommandLine, enableSearch_default_caseInsensitive) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -926,11 +972,13 @@ TEST_F(BalanceCommandLine, enableSearch_smDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -942,12 +990,14 @@ TEST_F(BalanceCommandLine, enableSearch_sdDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -994,29 +1044,32 @@ TEST_F(BalanceCommandLine, decompMethodParmetis) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } TEST_F(BalanceCommandLine, vertexWeightMethodConnectivity) { - const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--EXP-vertex-weight-method=connectivity"}); + const stk::balance::BalanceSettings& balanceSettings = get_stk_balance_settings({"--vertex-weight-method=connectivity"}); EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); EXPECT_EQ(balanceSettings.getVertexWeightMethod(), VertexWeightMethod::CONNECTIVITY); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), 10.); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {}); } @@ -1028,12 +1081,14 @@ TEST_F(BalanceCommandLine, userSpecifiedBlockMultiplier_default) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {{"block_1", 2.0}}); } @@ -1045,11 +1100,13 @@ TEST_F(BalanceCommandLine, userSpecifiedBlockMultiplier_smDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::smFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::smVertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::smFaceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::smFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::smGraphEdgeWeightMultiplier); check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {{"block_2", 10.0}, {"block_5", 2.5}}); } @@ -1062,12 +1119,14 @@ TEST_F(BalanceCommandLine, userSpecifiedWeights_sdDefaults) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), true); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::sdFixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); - EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); - EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::sdVertexWeightMethod); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::sdFaceSearchEdgeWeight); + EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::sdFaceSearchVertexMultiplier); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::sdGraphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {{"block_1", 1.5}, {"block_2", 5.0}}); } @@ -1079,12 +1138,14 @@ TEST_F(BalanceCommandLine, userSpecifiedBlockMultiplier_badFormatting) EXPECT_EQ(balanceSettings.getDecompMethod(), DefaultSettings::decompMethod); EXPECT_EQ(balanceSettings.get_use_nested_decomp(), false); EXPECT_EQ(balanceSettings.includeSearchResultsInGraph(), DefaultSettings::useContactSearch); - EXPECT_EQ(balanceSettings.shouldFixSpiders(), false); - EXPECT_EQ(balanceSettings.shouldFixMechanisms(), true); + EXPECT_EQ(balanceSettings.shouldFixSpiders(), DefaultSettings::fixSpiders); + EXPECT_EQ(balanceSettings.shouldFixMechanisms(), false); EXPECT_EQ(balanceSettings.shouldPrintDiagnostics(), false); + EXPECT_EQ(balanceSettings.getVertexWeightMethod(), (VertexWeightMethod)DefaultSettings::vertexWeightMethod); EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightForSearch(), DefaultSettings::faceSearchEdgeWeight); EXPECT_DOUBLE_EQ(balanceSettings.getVertexWeightMultiplierForVertexInSearch(), DefaultSettings::faceSearchVertexMultiplier); - check_absolute_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchAbsTol); + EXPECT_DOUBLE_EQ(balanceSettings.getGraphEdgeWeightMultiplier(), DefaultSettings::graphEdgeWeightMultiplier); + check_relative_tolerance_for_face_search(balanceSettings, DefaultSettings::faceSearchRelTol); check_vertex_weight_block_multiplier(balanceSettings, {{"block_1", 1.5}, {"block_2", 3.0}, {"block_3", 1.1}}); } diff --git a/packages/stk/stk_unit_tests/stk_balance/UnitTestDiagnosticsComputation.cpp b/packages/stk/stk_unit_tests/stk_balance/UnitTestDiagnosticsComputation.cpp index fb7714ceba5c..c3a0d1e3dc75 100644 --- a/packages/stk/stk_unit_tests/stk_balance/UnitTestDiagnosticsComputation.cpp +++ b/packages/stk/stk_unit_tests/stk_balance/UnitTestDiagnosticsComputation.cpp @@ -276,9 +276,9 @@ TEST_F(TestDiagnosticsComputation, ElementCount_Balance_HexMesh_GraphPartitioner std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {3}; } - else if (get_parallel_size() == 2) { expectedValues = {1, 2}; } + else if (get_parallel_size() == 2) { expectedValues = {2, 1}; } else if (get_parallel_size() == 3) { expectedValues = {1, 1, 1}; } - else if (get_parallel_size() == 4) { expectedValues = {0, 1, 1, 1}; } + else if (get_parallel_size() == 4) { expectedValues = {1, 1, 0, 1}; } test_diag_values(expectedValues); } @@ -297,9 +297,9 @@ TEST_F(TestDiagnosticsComputation, ElementCount_Balance_HexPyramidTetMesh_Geomet std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {6}; } - else if (get_parallel_size() == 2) { expectedValues = {1, 5}; } - else if (get_parallel_size() == 3) { expectedValues = {1, 2, 3}; } - else if (get_parallel_size() == 4) { expectedValues = {1, 0, 4, 1}; } + else if (get_parallel_size() == 2) { expectedValues = {3, 3}; } + else if (get_parallel_size() == 3) { expectedValues = {1, 3, 2}; } + else if (get_parallel_size() == 4) { expectedValues = {1, 2, 2, 1}; } test_diag_values(expectedValues); } @@ -317,9 +317,9 @@ TEST_F(TestDiagnosticsComputation, ElementCount_Balance_HexPyramidTetMesh_GraphP std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {6}; } - else if (get_parallel_size() == 2) { expectedValues = {1, 5}; } + else if (get_parallel_size() == 2) { expectedValues = {3, 3}; } else if (get_parallel_size() == 3) { expectedValues = {1, 1, 4}; } - else if (get_parallel_size() == 4) { expectedValues = {0, 1, 1, 4}; } + else if (get_parallel_size() == 4) { expectedValues = {1, 2, 1, 2}; } test_diag_values(expectedValues); } @@ -364,9 +364,9 @@ TEST_P(RebalanceNumOutputProcs, ElementCount_Rebalance_HexMesh_GraphPartitioner) std::vector expectedValues; if (GetParam() == 1) { expectedValues = {3}; } - else if (GetParam() == 2) { expectedValues = {1, 2}; } + else if (GetParam() == 2) { expectedValues = {2, 1}; } else if (GetParam() == 3) { expectedValues = {1, 1, 1}; } - else if (GetParam() == 4) { expectedValues = {0, 1, 1, 1}; } + else if (GetParam() == 4) { expectedValues = {1, 1, 0, 1}; } test_diag_values(expectedValues); } @@ -386,9 +386,9 @@ TEST_P(RebalanceNumOutputProcs, ElementCount_Rebalance_HexPyramidTetMesh_Geometr std::vector expectedValues; if (GetParam() == 1) { expectedValues = {6}; } - else if (GetParam() == 2) { expectedValues = {1, 5}; } - else if (GetParam() == 3) { expectedValues = {1, 2, 3}; } - else if (GetParam() == 4) { expectedValues = {1, 0, 4, 1}; } + else if (GetParam() == 2) { expectedValues = {3, 3}; } + else if (GetParam() == 3) { expectedValues = {1, 4, 1}; } + else if (GetParam() == 4) { expectedValues = {1, 2, 2, 1}; } test_diag_values(expectedValues); } @@ -408,9 +408,9 @@ TEST_P(RebalanceNumOutputProcs, ElementCount_Rebalance_HexPyramidTetMesh_GraphPa std::vector expectedValues; if (GetParam() == 1) { expectedValues = {6}; } - else if (GetParam() == 2) { expectedValues = {1, 5}; } + else if (GetParam() == 2) { expectedValues = {3, 3}; } else if (GetParam() == 3) { expectedValues = {1, 1, 4}; } - else if (GetParam() == 4) { expectedValues = {0, 1, 1, 4}; } + else if (GetParam() == 4) { expectedValues = {1, 2, 1, 2}; } test_diag_values(expectedValues); } @@ -428,12 +428,11 @@ TEST_F(TestDiagnosticsComputation, TotalElementWeight_Balance_HexMesh_GeometricP stk::balance::balanceStkMesh(balanceSettings, get_bulk()); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {0}; } - else if (get_parallel_size() == 2) { expectedValues = {2*hexWeight, 1*hexWeight}; } - else if (get_parallel_size() == 3) { expectedValues = {1*hexWeight, 1*hexWeight, 1*hexWeight}; } - else if (get_parallel_size() == 4) { expectedValues = {1*hexWeight, 1*hexWeight, 1*hexWeight, 0}; } + else if (get_parallel_size() == 2) { expectedValues = {2, 1}; } + else if (get_parallel_size() == 3) { expectedValues = {1, 1, 1}; } + else if (get_parallel_size() == 4) { expectedValues = {1, 1, 1, 0}; } test_diag_multi_values(0, expectedValues); } @@ -449,12 +448,11 @@ TEST_F(TestDiagnosticsComputation, TotalElementWeight_Balance_HexMesh_GraphParti stk::balance::balanceStkMesh(balanceSettings, get_bulk()); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {0}; } - else if (get_parallel_size() == 2) { expectedValues = {1*hexWeight, 2*hexWeight}; } - else if (get_parallel_size() == 3) { expectedValues = {1*hexWeight, 1*hexWeight, 1*hexWeight}; } - else if (get_parallel_size() == 4) { expectedValues = {0*hexWeight, 1*hexWeight, 1*hexWeight, 1*hexWeight}; } + else if (get_parallel_size() == 2) { expectedValues = {2, 1}; } + else if (get_parallel_size() == 3) { expectedValues = {1, 1, 1}; } + else if (get_parallel_size() == 4) { expectedValues = {1, 1, 0, 1}; } test_diag_multi_values(0, expectedValues); } @@ -471,14 +469,11 @@ TEST_F(TestDiagnosticsComputation, TotalElementWeight_Balance_HexPyramidTetMesh_ stk::balance::balanceStkMesh(balanceSettings, get_bulk()); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); - const unsigned pyrWeight = balanceSettings.getGraphVertexWeight(stk::topology::PYRAMID_5); - const unsigned tetWeight = balanceSettings.getGraphVertexWeight(stk::topology::TET_4); std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {0}; } - else if (get_parallel_size() == 2) { expectedValues = {1*hexWeight, 1*pyrWeight+4*tetWeight}; } - else if (get_parallel_size() == 3) { expectedValues = {1*hexWeight, 1*pyrWeight+1*tetWeight, 3*tetWeight}; } - else if (get_parallel_size() == 4) { expectedValues = {1*hexWeight, 0, 1*pyrWeight+3*tetWeight, 1*tetWeight}; } + else if (get_parallel_size() == 2) { expectedValues = {3, 3}; } + else if (get_parallel_size() == 3) { expectedValues = {1, 3, 2}; } + else if (get_parallel_size() == 4) { expectedValues = {1, 2, 2, 1}; } test_diag_multi_values(0, expectedValues); } @@ -494,14 +489,11 @@ TEST_F(TestDiagnosticsComputation, TotalElementWeight_Balance_HexPyramidTetMesh_ stk::balance::balanceStkMesh(balanceSettings, get_bulk()); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); - const unsigned pyrWeight = balanceSettings.getGraphVertexWeight(stk::topology::PYRAMID_5); - const unsigned tetWeight = balanceSettings.getGraphVertexWeight(stk::topology::TET_4); std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {0}; } - else if (get_parallel_size() == 2) { expectedValues = {1*hexWeight, 1*pyrWeight+4*tetWeight}; } - else if (get_parallel_size() == 3) { expectedValues = {1*hexWeight, 1*tetWeight, 1*pyrWeight+3*tetWeight}; } - else if (get_parallel_size() == 4) { expectedValues = {0, 1*hexWeight, 1*tetWeight, 1*pyrWeight+3*tetWeight}; } + else if (get_parallel_size() == 2) { expectedValues = {3, 3}; } + else if (get_parallel_size() == 3) { expectedValues = {1, 1, 4}; } + else if (get_parallel_size() == 4) { expectedValues = {1, 2, 1, 2}; } test_diag_multi_values(0, expectedValues); } @@ -520,12 +512,11 @@ TEST_P(RebalanceNumOutputProcs, TotalElementWeight_Rebalance_HexMesh_GeometricPa rebalanceMesh(ioBroker, balanceSettings); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); std::vector expectedValues; if (GetParam() == 1) { expectedValues = {0}; } - else if (GetParam() == 2) { expectedValues = {2*hexWeight, 1*hexWeight}; } - else if (GetParam() == 3) { expectedValues = {1*hexWeight, 1*hexWeight, 1*hexWeight}; } - else if (GetParam() == 4) { expectedValues = {1*hexWeight, 1*hexWeight, 1*hexWeight, 0}; } + else if (GetParam() == 2) { expectedValues = {2, 1}; } + else if (GetParam() == 3) { expectedValues = {1, 1, 1}; } + else if (GetParam() == 4) { expectedValues = {1, 1, 1, 0}; } test_diag_multi_values(0, expectedValues); } @@ -543,12 +534,11 @@ TEST_P(RebalanceNumOutputProcs, TotalElementWeight_Rebalance_HexMesh_GraphPartit rebalanceMesh(ioBroker, balanceSettings); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); std::vector expectedValues; if (GetParam() == 1) { expectedValues = {0}; } - else if (GetParam() == 2) { expectedValues = {1*hexWeight, 2*hexWeight}; } - else if (GetParam() == 3) { expectedValues = {1*hexWeight, 1*hexWeight, 1*hexWeight}; } - else if (GetParam() == 4) { expectedValues = {0*hexWeight, 1*hexWeight, 1*hexWeight, 1*hexWeight}; } + else if (GetParam() == 2) { expectedValues = {2, 1}; } + else if (GetParam() == 3) { expectedValues = {1, 1, 1}; } + else if (GetParam() == 4) { expectedValues = {1, 1, 0, 1}; } test_diag_multi_values(0, expectedValues); } @@ -566,14 +556,11 @@ TEST_P(RebalanceNumOutputProcs, TotalElementWeight_Rebalance_HexPyramidTetMesh_G rebalanceMesh(ioBroker, balanceSettings); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); - const unsigned pyrWeight = balanceSettings.getGraphVertexWeight(stk::topology::PYRAMID_5); - const unsigned tetWeight = balanceSettings.getGraphVertexWeight(stk::topology::TET_4); std::vector expectedValues; if (GetParam() == 1) { expectedValues = {0}; } - else if (GetParam() == 2) { expectedValues = {1*hexWeight, 1*pyrWeight+4*tetWeight}; } - else if (GetParam() == 3) { expectedValues = {1*hexWeight, 1*pyrWeight+1*tetWeight, 3*tetWeight}; } - else if (GetParam() == 4) { expectedValues = {1*hexWeight, 0, 1*pyrWeight+3*tetWeight, 1*tetWeight}; } + else if (GetParam() == 2) { expectedValues = {3, 3}; } + else if (GetParam() == 3) { expectedValues = {1, 4, 1}; } + else if (GetParam() == 4) { expectedValues = {1, 2, 2, 1}; } test_diag_multi_values(0, expectedValues); } @@ -591,14 +578,11 @@ TEST_P(RebalanceNumOutputProcs, TotalElementWeight_Rebalance_HexPyramidTetMesh_G rebalanceMesh(ioBroker, balanceSettings); - const unsigned hexWeight = balanceSettings.getGraphVertexWeight(stk::topology::HEX_8); - const unsigned pyrWeight = balanceSettings.getGraphVertexWeight(stk::topology::PYRAMID_5); - const unsigned tetWeight = balanceSettings.getGraphVertexWeight(stk::topology::TET_4); std::vector expectedValues; if (GetParam() == 1) { expectedValues = {0}; } - else if (GetParam() == 2) { expectedValues = {1*hexWeight, 1*pyrWeight+4*tetWeight}; } - else if (GetParam() == 3) { expectedValues = {1*hexWeight, 1*tetWeight, 1*pyrWeight+3*tetWeight}; } - else if (GetParam() == 4) { expectedValues = {0, 1*hexWeight, 1*tetWeight, 1*pyrWeight+3*tetWeight}; } + else if (GetParam() == 2) { expectedValues = {3, 3}; } + else if (GetParam() == 3) { expectedValues = {1, 1, 4}; } + else if (GetParam() == 4) { expectedValues = {1, 2, 1, 2}; } test_diag_multi_values(0, expectedValues); } @@ -790,11 +774,13 @@ TEST_F(TestDiagnosticsComputation, NodeInterfaceSize_Balance_HexPyramidTetMesh_G stk::balance::balanceStkMesh(balanceSettings, get_bulk()); + stk::io::write_mesh("nodeInterfaceSize_balance_hexPyramidTet_geometric.g", get_bulk()); + std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {0.0/12.0}; } - else if (get_parallel_size() == 2) { expectedValues = { 4.0/8.0, 4.0/8.0}; } - else if (get_parallel_size() == 3) { expectedValues = { 4.0/8.0, 6.0/6.0, 4.0/6.0}; } - else if (get_parallel_size() == 4) { expectedValues = { 4.0/8.0, 0.0, 6.0/8.0, 6.0/6.0}; } + else if (get_parallel_size() == 2) { expectedValues = { 4.0/10.0, 4.0/6.0}; } + else if (get_parallel_size() == 3) { expectedValues = { 4.0/8.0, 6.0/6.0, 4.0/5.0}; } + else if (get_parallel_size() == 4) { expectedValues = { 4.0/8.0, 6.0/6.0, 4.0/5.0, 4.0/4.0}; } test_diag_values(expectedValues); } @@ -810,11 +796,13 @@ TEST_F(TestDiagnosticsComputation, NodeInterfaceSize_Balance_HexPyramidTetMesh_G stk::balance::balanceStkMesh(balanceSettings, get_bulk()); + stk::io::write_mesh("nodeInterfaceSize_balance_hexPyramidTet_graph.g", get_bulk()); + std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {0.0/12.0}; } - else if (get_parallel_size() == 2) { expectedValues = {4.0/8.0, 4.0/8.0}; } - else if (get_parallel_size() == 3) { expectedValues = {4.0/8.0, 4.0/4.0, 6.0/8.0}; } - else if (get_parallel_size() == 4) { expectedValues = { 0.0, 4.0/8.0, 4.0/4.0, 6.0/8.0}; } + else if (get_parallel_size() == 2) { expectedValues = {4.0/10.0, 4.0/6.0}; } + else if (get_parallel_size() == 3) { expectedValues = {4.0/8.0, 4.0/4.0, 6.0/8.0}; } + else if (get_parallel_size() == 4) { expectedValues = {4.0/8.0, 6.0/6.0, 4.0/4.0, 4.0/5.0}; } test_diag_values(expectedValues); } @@ -915,9 +903,9 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_HexMesh_GeometricP stk::balance::balanceStkMesh(balanceSettings, get_bulk()); const double cornerNode = 8.0; - const double edgeNode = 12.0/2.0; - const double centerNode = 18.0/4.0; - const double elemWeight = 2*cornerNode + 4*edgeNode + 2*centerNode; + const double edgeNode = 12.0; + const double centerNode = 18.0; + const double elemWeight = (2*cornerNode + 4*edgeNode + 2*centerNode)/8; std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {4*elemWeight}; } else if (get_parallel_size() == 2) { expectedValues = {2*elemWeight, 2*elemWeight}; } @@ -939,9 +927,9 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_HexMesh_GraphParti stk::balance::balanceStkMesh(balanceSettings, get_bulk()); const double cornerNode = 8.0; - const double edgeNode = 12.0/2.0; - const double centerNode = 18.0/4.0; - const double elemWeight = 2*cornerNode + 4*edgeNode + 2*centerNode; + const double edgeNode = 12.0; + const double centerNode = 18.0; + const double elemWeight = (2*cornerNode + 4*edgeNode + 2*centerNode)/8; std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {4*elemWeight}; } else if (get_parallel_size() == 2) { expectedValues = {2*elemWeight, 2*elemWeight}; } @@ -954,25 +942,28 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_HexMesh_GraphParti std::tuple get_hex_pyramid_tet_element_connectivity_weights() { - const double node1Weight = 8.0/1.0; - const double node2Weight = 8.0/1.0; - const double node3Weight = 8.0/1.0; - const double node4Weight = 8.0/1.0; - const double node5Weight = 10.0/3.0; - const double node6Weight = 9.0/2.0; - const double node7Weight = 10.0/3.0; - const double node8Weight = 12.0/6.0; - const double node9Weight = 5.0/2.0; - const double node10Weight = 8.0/5.0; - const double node11Weight = 5.0/2.0; - const double node12Weight = 5.0/2.0; - const double elem1Weight = node1Weight + node2Weight + node3Weight + node4Weight + - node5Weight + node6Weight + node7Weight + node8Weight; - const double elem2Weight = node5Weight + node6Weight + node7Weight + node8Weight + node10Weight; - const double elem3Weight = node5Weight + node9Weight + node8Weight + node10Weight; - const double elem4Weight = node8Weight + node9Weight + node12Weight + node10Weight; - const double elem5Weight = node8Weight + node12Weight + node10Weight + node11Weight; - const double elem6Weight = node7Weight + node8Weight + node10Weight + node11Weight; + const double node1Weight = 8.0; + const double node2Weight = 8.0; + const double node3Weight = 8.0; + const double node4Weight = 8.0; + const double node5Weight = 10.0; + const double node6Weight = 9.0; + const double node7Weight = 10.0; + const double node8Weight = 12.0; + const double node9Weight = 5.0; + const double node10Weight = 8.0; + const double node11Weight = 5.0; + const double node12Weight = 5.0; + const double hexElemsPerNode = 1; + const double pyrElemsPerNode = 6.0/2.0; + const double tetElemsPerNode = 6; + const double elem1Weight = (node1Weight + node2Weight + node3Weight + node4Weight + + node5Weight + node6Weight + node7Weight + node8Weight)/8/hexElemsPerNode; + const double elem2Weight = (node5Weight + node6Weight + node7Weight + node8Weight + node10Weight)/5/pyrElemsPerNode; + const double elem3Weight = (node5Weight + node9Weight + node8Weight + node10Weight)/4/tetElemsPerNode; + const double elem4Weight = (node8Weight + node9Weight + node12Weight + node10Weight)/4/tetElemsPerNode; + const double elem5Weight = (node8Weight + node12Weight + node10Weight + node11Weight)/4/tetElemsPerNode; + const double elem6Weight = (node7Weight + node8Weight + node10Weight + node11Weight)/4/tetElemsPerNode; return std::make_tuple(elem1Weight, elem2Weight, elem3Weight, elem4Weight, elem5Weight, elem6Weight); } @@ -994,9 +985,9 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_HexPyramidTetMesh_ std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {e1wt+e2wt+e3wt+e4wt+e5wt+e6wt}; } - else if (get_parallel_size() == 2) { expectedValues = {e1wt, e2wt+e3wt+e4wt+e5wt+e6wt}; } - else if (get_parallel_size() == 3) { expectedValues = {e1wt, e2wt+e3wt, e4wt+e5wt+e6wt}; } - else if (get_parallel_size() == 4) { expectedValues = {e1wt, 0, e2wt+e3wt+e4wt+e5wt, e6wt}; } + else if (get_parallel_size() == 2) { expectedValues = {e1wt+e2wt+e3wt, e4wt+e5wt+e6wt}; } + else if (get_parallel_size() == 3) { expectedValues = {e1wt, e2wt+e3wt+e6wt, e4wt+e5wt}; } + else if (get_parallel_size() == 4) { expectedValues = {e1wt, e2wt+e3wt, e4wt+e5wt, e6wt}; } test_diag_values(expectedValues); } @@ -1017,9 +1008,9 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_HexPyramidTetMesh_ std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {e1wt+e2wt+e3wt+e4wt+e5wt+e6wt}; } - else if (get_parallel_size() == 2) { expectedValues = {e1wt, e2wt+e3wt+e4wt+e5wt+e6wt}; } + else if (get_parallel_size() == 2) { expectedValues = {e1wt+e2wt+e3wt, e4wt+e5wt+e6wt}; } else if (get_parallel_size() == 3) { expectedValues = {e1wt, e6wt, e2wt+e3wt+e4wt+e5wt}; } - else if (get_parallel_size() == 4) { expectedValues = {0, e1wt, e6wt, e2wt+e3wt+e4wt+e5wt}; } + else if (get_parallel_size() == 4) { expectedValues = {e1wt, e2wt+e3wt, e6wt, e4wt+e5wt}; } test_diag_values(expectedValues); } @@ -1037,9 +1028,10 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_ShellMesh_Geometri stk::balance::balanceStkMesh(balanceSettings, get_bulk()); const double cornerNode = 4.0; - const double edgeNode = 6.0/2.0; - const double centerNode = 9.0/4.0; - const double elemWeight = cornerNode + 2*edgeNode + centerNode; + const double edgeNode = 6.0; + const double centerNode = 9.0; + const double quadShellElemsPerNode = 1.0; + const double elemWeight = (cornerNode + 2*edgeNode + centerNode)/4/quadShellElemsPerNode; std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {4*elemWeight}; } else if (get_parallel_size() == 2) { expectedValues = {2*elemWeight, 2*elemWeight}; } @@ -1061,9 +1053,10 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_ShellMesh_GraphPar stk::balance::balanceStkMesh(balanceSettings, get_bulk()); const double cornerNode = 4.0; - const double edgeNode = 6.0/2.0; - const double centerNode = 9.0/4.0; - const double elemWeight = cornerNode + 2*edgeNode + centerNode; + const double edgeNode = 6.0; + const double centerNode = 9.0; + const double quadShellElemsPerNode = 1.0; + const double elemWeight = (cornerNode + 2*edgeNode + centerNode)/4/quadShellElemsPerNode; std::vector expectedValues; if (get_parallel_size() == 1) { expectedValues = {4*elemWeight}; } else if (get_parallel_size() == 2) { expectedValues = {2*elemWeight, 2*elemWeight}; } @@ -1076,15 +1069,16 @@ TEST_F(TestDiagnosticsComputation, ConnectivityWeight_Balance_ShellMesh_GraphPar std::tuple get_beam_element_connectivity_weights() { - const double node1Weight = 2.0/1.0; - const double node2Weight = 3.0/2.0; - const double node3Weight = 4.0/3.0; - const double node4Weight = 2.0/1.0; - const double node5Weight = 2.0/1.0; - const double elem1Weight = node1Weight + node2Weight; - const double elem2Weight = node2Weight + node3Weight; - const double elem3Weight = node3Weight + node4Weight; - const double elem4Weight = node3Weight + node5Weight; + const double node1Weight = 2.0; + const double node2Weight = 3.0; + const double node3Weight = 4.0; + const double node4Weight = 2.0; + const double node5Weight = 2.0; + const double beamElemsPerNode = 1.0; + const double elem1Weight = (node1Weight + node2Weight)/2/beamElemsPerNode; + const double elem2Weight = (node2Weight + node3Weight)/2/beamElemsPerNode; + const double elem3Weight = (node3Weight + node4Weight)/2/beamElemsPerNode; + const double elem4Weight = (node3Weight + node5Weight)/2/beamElemsPerNode; return std::make_tuple(elem1Weight, elem2Weight, elem3Weight, elem4Weight); } diff --git a/packages/stk/stk_unit_tests/stk_balance/UnitTestGeometricMethodsWithSelector.cpp b/packages/stk/stk_unit_tests/stk_balance/UnitTestGeometricMethodsWithSelector.cpp index 0d9a620c1d0b..0d73eecfa441 100644 --- a/packages/stk/stk_unit_tests/stk_balance/UnitTestGeometricMethodsWithSelector.cpp +++ b/packages/stk/stk_unit_tests/stk_balance/UnitTestGeometricMethodsWithSelector.cpp @@ -19,13 +19,13 @@ class GeometricBalanceSettingsTester : public stk::balance::GraphCreationSetting { public: GeometricBalanceSettingsTester(const std::string& decompMethod) - : method(decompMethod) { } + : m_method(decompMethod) { } virtual ~GeometricBalanceSettingsTester() = default; - virtual std::string getDecompMethod() const { return method; } + virtual std::string getDecompMethod() const { return m_method; } private: - const std::string& method; + const std::string& m_method; }; class ZoltanGeometricMethods : public stk::unit_test_util::simple_fields::MeshFixture diff --git a/packages/stk/stk_unit_tests/stk_balance/UnitTestM2NFileOutput.cpp b/packages/stk/stk_unit_tests/stk_balance/UnitTestM2NFileOutput.cpp index 5b2c71b0b5c8..22853f16cf43 100644 --- a/packages/stk/stk_unit_tests/stk_balance/UnitTestM2NFileOutput.cpp +++ b/packages/stk/stk_unit_tests/stk_balance/UnitTestM2NFileOutput.cpp @@ -98,12 +98,13 @@ TEST_F(M2NFileOutput, CheckSharingInformation) int global_num_elems = counts[stk::topology::ELEM_RANK]; const std::string outputFilename = "TemporaryOutputFile.g"; + stk::io::OutputParams params(get_bulk()); stk::io::write_file_for_subdomain(outputFilename, get_bulk().parallel_rank(), get_bulk().parallel_size(), global_num_nodes, global_num_elems, - get_bulk(), + params, nodeSharingInfo); verify_node_sharing_info(nodeSharingInfo, outputFilename); diff --git a/packages/stk/stk_unit_tests/stk_balance/UnitTestRebalanceFileOutput.cpp b/packages/stk/stk_unit_tests/stk_balance/UnitTestRebalanceFileOutput.cpp index 0ba8c811ca0b..765eeb41b022 100644 --- a/packages/stk/stk_unit_tests/stk_balance/UnitTestRebalanceFileOutput.cpp +++ b/packages/stk/stk_unit_tests/stk_balance/UnitTestRebalanceFileOutput.cpp @@ -101,12 +101,13 @@ TEST_F(RebalanceFileOutput, CheckSharingInformation) int global_num_elems = counts[stk::topology::ELEM_RANK]; const std::string outputFilename = "TemporaryOutputFile.g"; + stk::io::OutputParams params(get_bulk()); stk::io::write_file_for_subdomain(outputFilename, get_bulk().parallel_rank(), get_bulk().parallel_size(), global_num_nodes, global_num_elems, - get_bulk(), + params, nodeSharingInfo); verify_node_sharing_info(nodeSharingInfo, outputFilename); diff --git a/packages/stk/stk_unit_tests/stk_balance/UnitTestSearchTolerance.cpp b/packages/stk/stk_unit_tests/stk_balance/UnitTestSearchTolerance.cpp index 5324580037e5..0f1a82662803 100644 --- a/packages/stk/stk_unit_tests/stk_balance/UnitTestSearchTolerance.cpp +++ b/packages/stk/stk_unit_tests/stk_balance/UnitTestSearchTolerance.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -116,6 +117,7 @@ TEST_F(SearchToleranceTester, constantTolerance) if(stk::parallel_machine_size(get_comm()) == 1) { stk::balance::GraphCreationSettings balanceSettings; + balanceSettings.setToleranceForFaceSearch(stk::balance::DefaultSettings::faceSearchAbsTol); const unsigned numSelfInteractions = 2; EXPECT_EQ(numSelfInteractions, get_num_search_results_with_app_settings(balanceSettings)); } @@ -126,7 +128,6 @@ TEST_F(SearchToleranceTester, secondShortestEdgeFaceSearchTolerance) if(stk::parallel_machine_size(get_comm()) == 1) { stk::balance::GraphCreationSettings balanceSettings; - balanceSettings.setToleranceFunctionForFaceSearch(std::make_shared()); const unsigned numSelfPlusSymmetricInteractions = 4; EXPECT_EQ(numSelfPlusSymmetricInteractions, get_num_search_results_with_app_settings(balanceSettings)); } diff --git a/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldCommSplitting.cpp b/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldCommSplitting.cpp index 9aa2299bde61..5937a3e161c6 100644 --- a/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldCommSplitting.cpp +++ b/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldCommSplitting.cpp @@ -41,6 +41,7 @@ #include #include +#ifndef STK_HIDE_DEPRECATED_CODE // delete October 2022 namespace { TEST(UnitTestSplitComm, has_split_comm_false_when_same) @@ -141,3 +142,5 @@ TEST(UnitTestSplitComm, calc_my_root_and_other_root_ranks_non_contig_comm) } } + +#endif \ No newline at end of file diff --git a/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldSyncInfo.cpp b/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldSyncInfo.cpp index 0c42ff759053..61d6b55e8d45 100644 --- a/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldSyncInfo.cpp +++ b/packages/stk/stk_unit_tests/stk_coupling/UnitTestOldSyncInfo.cpp @@ -39,6 +39,8 @@ #include #include +#ifndef STK_HIDE_DEPRECATED_CODE // delete October 2022 + namespace { TEST(UnitTestOldSyncInfo, get_and_set) @@ -180,3 +182,4 @@ TEST(UnitTestOldSyncInfo, exchangeAsymmetric) } } +#endif \ No newline at end of file diff --git a/packages/stk/stk_unit_tests/stk_io/UnitTestReadWriteAssemblies.cpp b/packages/stk/stk_unit_tests/stk_io/UnitTestReadWriteAssemblies.cpp index 9a381067fb74..b7fd9ef7fa3b 100644 --- a/packages/stk/stk_unit_tests/stk_io/UnitTestReadWriteAssemblies.cpp +++ b/packages/stk/stk_unit_tests/stk_io/UnitTestReadWriteAssemblies.cpp @@ -127,12 +127,12 @@ TEST_F(Assembly_legacy, readWriteAssembly_simple_emptyblock) const std::vector partNames {"block_1", "block_2"}; stk::mesh::Part& assemblyPart = create_assembly(assemblyName, 10); - stk::mesh::Part& block1Part = create_io_part(partNames[0]); - stk::mesh::Part& block2Part = create_io_part(partNames[1]); + stk::mesh::Part& block1Part = create_io_part(partNames[0], 1); + stk::mesh::Part& block2Part = create_io_part(partNames[1], 2); declare_subsets(assemblyPart, {&block1Part, &block2Part}); stk::io::fill_mesh("generated:2x2x2", get_bulk()); - test_write_then_read_block_assemblies(1); + test_write_then_read_block_assemblies(1, stk::mesh::PartVector{&block2Part}); } TEST_F(Assembly_legacy, readWriteAssembly_simple_emptysurface) @@ -638,12 +638,12 @@ TEST_F(Assembly, readWriteAssembly_simple_emptyblock) const std::vector partNames {"block_1", "block_2"}; stk::mesh::Part& assemblyPart = create_assembly(assemblyName, 10); - stk::mesh::Part& block1Part = create_io_part(partNames[0]); - stk::mesh::Part& block2Part = create_io_part(partNames[1]); + stk::mesh::Part& block1Part = create_io_part(partNames[0], 1); + stk::mesh::Part& block2Part = create_io_part(partNames[1], 2); declare_subsets(assemblyPart, {&block1Part, &block2Part}); stk::io::fill_mesh("generated:2x2x2", get_bulk()); - test_write_then_read_block_assemblies(1); + test_write_then_read_block_assemblies(1, stk::mesh::PartVector{&block2Part}); } TEST_F(Assembly, readWriteAssembly_simple_emptysurface) diff --git a/packages/stk/stk_unit_tests/stk_mesh/CMakeLists.txt b/packages/stk/stk_unit_tests/stk_mesh/CMakeLists.txt index 0f54b1fb7b62..68a60f91d76e 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/CMakeLists.txt +++ b/packages/stk/stk_unit_tests/stk_mesh/CMakeLists.txt @@ -68,6 +68,7 @@ LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestCommInfoObserver.c LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestCommunicateFieldData.cpp") LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestCreateEdges.cpp") LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestCreateFaces.cpp") +LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestDeletedEntityCache.cpp") LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestDeleteEntities.cpp") LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestDestroyElements.cpp") LIST(REMOVE_ITEM SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/UnitTestDistributedIndexWithBulkData.cpp") diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestBucketRepository.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestBucketRepository.cpp index 09448493d3f0..27369071e656 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/UnitTestBucketRepository.cpp +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestBucketRepository.cpp @@ -35,7 +35,7 @@ #include // for AssertHelper, EXPECT_EQ, etc #include // for size_t #include // for BucketRepository -#include // for EntityRepository +#include #include // for Partition #include // for ParallelMachine #include @@ -64,7 +64,7 @@ TEST(BucketRepositoryTest, createBuckets) stkMeshMetaData.commit(); stk::unit_test_util::BulkDataTester stkMeshBulkData(stkMeshMetaData, comm); - stk::mesh::impl::EntityRepository entityRepository; + stk::mesh::impl::EntityKeyMapping entityKeyMapping; stk::mesh::impl::BucketRepository &bucketRepository = stkMeshBulkData.my_get_bucket_repository(); stk::mesh::impl::Partition* partition = bucketRepository.get_or_create_partition(stk::topology::NODE_RANK, parts); @@ -74,7 +74,7 @@ TEST(BucketRepositoryTest, createBuckets) { stk::mesh::EntityId nodeID = i+1; stk::mesh::EntityKey nodeKey(stk::topology::NODE_RANK, nodeID); - std::pair createResult = entityRepository.internal_create_entity(nodeKey); + std::pair createResult = entityKeyMapping.internal_create_entity(nodeKey); bool aNewEntityWasCreated = createResult.second; EXPECT_TRUE(aNewEntityWasCreated); stk::mesh::Entity node = stkMeshBulkData.my_generate_new_entity(); @@ -94,7 +94,7 @@ TEST(BucketRepositoryTest, createBuckets) stk::mesh::EntityId nodeID = numNodes+1; stk::mesh::EntityKey nodeKey(stk::topology::NODE_RANK, nodeID); - std::pair createResult = entityRepository.internal_create_entity(nodeKey); + std::pair createResult = entityKeyMapping.internal_create_entity(nodeKey); bool aNewEntityWasCreated = createResult.second; EXPECT_TRUE(aNewEntityWasCreated); stk::mesh::Entity node = stkMeshBulkData.my_generate_new_entity(); diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestBulkData.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestBulkData.cpp index ca388144ff60..66dac7553940 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/UnitTestBulkData.cpp +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestBulkData.cpp @@ -3119,11 +3119,11 @@ TEST(BulkData, ModificationEnd) exodusFileReader.populate_bulk_data(); } - int elementToMove = 3; - int nodeToCheck = 9; + stk::mesh::EntityId elementToDestroy = 3; + stk::mesh::EntityId nodeToCheck = 9; stk::mesh::EntityKey nodeEntityKey(stk::topology::NODE_RANK, nodeToCheck); - stk::mesh::EntityKey entityToMoveKey(stk::topology::ELEMENT_RANK, elementToMove); + stk::mesh::EntityKey entityToDestroyKey(stk::topology::ELEMENT_RANK, elementToDestroy); stk::mesh::EntityCommListInfoVector::const_iterator iter = std::lower_bound(stkMeshBulkData->my_internal_comm_list().begin(), stkMeshBulkData->my_internal_comm_list().end(), @@ -3135,28 +3135,22 @@ TEST(BulkData, ModificationEnd) stkMeshBulkData->modification_begin(); - ASSERT_TRUE( stkMeshBulkData->is_valid(stkMeshBulkData->get_entity(entityToMoveKey))); + ASSERT_TRUE( stkMeshBulkData->is_valid(stkMeshBulkData->get_entity(entityToDestroyKey))); if(stkMeshBulkData->parallel_rank() == 1) { - stkMeshBulkData->destroy_entity(stkMeshBulkData->get_entity(entityToMoveKey)); + stkMeshBulkData->destroy_entity(stkMeshBulkData->get_entity(entityToDestroyKey)); } - // Really testing destroy_entity - stkMeshBulkData->my_delete_shared_entities_which_are_no_longer_in_owned_closure(); + stkMeshBulkData->modification_end(); - iter = std::lower_bound(stkMeshBulkData->my_internal_comm_list().begin(), stkMeshBulkData->my_internal_comm_list().end(), nodeEntityKey); - - ASSERT_TRUE(iter != stkMeshBulkData->my_internal_comm_list().end()); - EXPECT_EQ(nodeEntityKey, iter->key); - - if(stkMeshBulkData->parallel_rank() == 0) - { - EXPECT_TRUE(stkMeshBulkData->is_valid(iter->entity)); + stk::mesh::Entity nodeEntity = stkMeshBulkData->get_entity(nodeEntityKey); + if (stkMeshBulkData->parallel_rank() == 0) { + EXPECT_TRUE(stkMeshBulkData->is_valid(nodeEntity)); + EXPECT_FALSE(stkMeshBulkData->in_shared(nodeEntity)); } - else - { - EXPECT_FALSE(stkMeshBulkData->is_valid(iter->entity)); + else { + EXPECT_FALSE(stkMeshBulkData->is_valid(nodeEntity)); } std::vector globalCounts; diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestDeleteEntities.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestDeleteEntities.cpp index b8ffda2b67c1..4bc7b59a5d99 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/UnitTestDeleteEntities.cpp +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestDeleteEntities.cpp @@ -5,6 +5,8 @@ #include #include "stk_mesh/base/FEMHelpers.hpp" #include "stk_mesh/base/GetEntities.hpp" +#include "stk_mesh/base/Types.hpp" +#include "stk_unit_test_utils/BulkDataTester.hpp" namespace { @@ -240,5 +242,4 @@ TEST_F(SingleHexMesh, DISABLED_CreateFacesThenCreateAnotherElement_ConnectivityI expect_one_face_connected_to_two_elements(); } } - } diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestDeletedEntityCache.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestDeletedEntityCache.cpp new file mode 100644 index 000000000000..959275ccb0fc --- /dev/null +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestDeletedEntityCache.cpp @@ -0,0 +1,150 @@ +#include "gtest/gtest.h" +#include "stk_mesh/baseImpl/DeletedEntityCache.hpp" +#include "stk_unit_test_utils/MeshFixture.hpp" +#include "stk_unit_test_utils/TextMesh.hpp" + +namespace { + +class DeletedEntityCacheTester : public stk::unit_test_util::MeshFixture +{ + protected: + DeletedEntityCacheTester() : + MeshFixture(3), + cache() + { + std::string meshDesc = + "0,1,HEX_8,1,2,3,4,5,6,7,8\n\ + 0,2,HEX_8,2,9,10,3,6,11,12,7"; + setup_empty_mesh(stk::mesh::BulkData::NO_AUTO_AURA); + stk::unit_test_util::simple_fields::setup_text_mesh(get_bulk(), meshDesc); + + stk::mesh::BucketVector const& buckets = bulkData->get_buckets(stk::topology::NODE_RANK, metaData->universal_part()); + for (auto& bucket : buckets) + { + for (const stk::mesh::Entity& entity : *bucket) + { + nodes.push_back(entity); + max_local_offset = std::max(entity.local_offset(), max_local_offset); + } + } + + cache = std::make_shared(*bulkData); + } + + std::vector get_ghost_entity_counts() + { + auto& ghostReuseMap = cache->get_ghost_reuse_map(); + std::vector usedValuesCount(max_local_offset+1, 0); + for (auto& keyOffsetPair : ghostReuseMap) + { + usedValuesCount[keyOffsetPair.second] += 1; + } + + return usedValuesCount; + } + + std::shared_ptr cache; + std::vector nodes; + stk::mesh::Entity::entity_value_type max_local_offset = 0; + +}; + +} + +TEST_F(DeletedEntityCacheTester, mark_entity_as_deleted_nonghost) +{ + if (get_parallel_size() != 1) + { + GTEST_SKIP(); + } + + for (int i=0; i < 3; ++i) + { + cache->mark_entity_as_deleted(nodes[i], false); + } + + auto& ghost_reuse_map = cache->get_ghost_reuse_map(); + EXPECT_EQ(ghost_reuse_map.size(), 0u); + + auto& deleted_entities = cache->get_deleted_entities_current_mod_cycle(); + EXPECT_EQ(deleted_entities.size(), 3u); + for (int i=0; i < 3; ++i) + { + EXPECT_EQ(deleted_entities[i], nodes[i].local_offset()); + } +} + +TEST_F(DeletedEntityCacheTester, mark_entity_as_deleted_ghost) +{ + if (get_parallel_size() != 1) + { + GTEST_SKIP(); + } + + for (int i=0; i < 3; ++i) + { + cache->mark_entity_as_deleted(nodes[i], true); + } + + EXPECT_EQ(cache->get_deleted_entities_current_mod_cycle().size(), 0u); + EXPECT_EQ(cache->get_ghost_reuse_map().size(), 3u); + + auto usedGhosts = get_ghost_entity_counts(); + for (size_t i=0; i < nodes.size(); ++i) + { + size_t expected_val = i < 3 ? 1 : 0; + EXPECT_EQ(usedGhosts[nodes[i].local_offset()], expected_val); + } +} + +TEST_F(DeletedEntityCacheTester, get_entity_for_reuse_initial) +{ + if (get_parallel_size() != 1) + { + GTEST_SKIP(); + } + + EXPECT_EQ(cache->get_entity_for_reuse(), stk::mesh::Entity::InvalidEntity); +} + +TEST_F(DeletedEntityCacheTester, update_deleted_entities_container) +{ + if (get_parallel_size() != 1) + { + GTEST_SKIP(); + } + + std::vector destroyedEntities; + for (int i=0; i < 5; ++i) + { + bool isGhost = i < 3; + cache->mark_entity_as_deleted(nodes[i], isGhost); + destroyedEntities.push_back(nodes[i].local_offset()); + } + std::sort(destroyedEntities.begin(), destroyedEntities.end()); + + EXPECT_EQ(cache->get_entity_for_reuse(), stk::mesh::Entity::InvalidEntity); + cache->update_deleted_entities_container(); + + std::vector reusedEntities; + for (int i=0; i < 5; ++i) + { + reusedEntities.push_back(cache->get_entity_for_reuse()); + } + std::sort(reusedEntities.begin(), reusedEntities.end()); + + for (int i=0; i < 5; ++i) + { + EXPECT_EQ(destroyedEntities[i], reusedEntities[i]); + } + + for (int i=0; i < 10; ++i) + { + EXPECT_EQ(cache->get_entity_for_reuse(), stk::mesh::Entity::InvalidEntity); + } +} + + + + + diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntity.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntity.cpp index 76b678f2083e..9981194fe82d 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntity.cpp +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntity.cpp @@ -45,7 +45,6 @@ namespace stk { namespace mesh { class Bucket; } } namespace stk { namespace mesh { class BulkData; } } namespace stk { namespace mesh { class MetaData; } } namespace stk { namespace mesh { class Part; } } -namespace stk { namespace mesh { namespace impl { class EntityRepository; } } } namespace stk { namespace mesh { namespace impl { class PartRepository; } } } namespace stk { namespace mesh { struct Entity; } } @@ -58,7 +57,6 @@ using stk::mesh::EntityKey; using stk::mesh::Entity; using stk::mesh::Bucket; using stk::mesh::impl::PartRepository; -using stk::mesh::impl::EntityRepository; namespace { diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityCommDatabase.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityCommDatabase.cpp index 171def564192..3fded4a5c568 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityCommDatabase.cpp +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityCommDatabase.cpp @@ -46,15 +46,22 @@ TEST(EntityCommDatabase, testCommMapChangeListener) stk::mesh::EntityCommDatabase commDB; stk::mesh::EntityCommListInfoVector comm_list; std::vector entityComms(200); - stk::mesh::CommListUpdater comm_list_updater(comm_list, entityComms); + std::vector> removedGhosts; + stk::mesh::CommListUpdater comm_list_updater(comm_list, entityComms, removedGhosts); commDB.setCommMapChangeListener(&comm_list_updater); int owner = 0; stk::mesh::EntityKey key(stk::topology::NODE_RANK, 99); unsigned ghost_id = 3; - int proc = 4; - stk::mesh::EntityCommInfo value(ghost_id, proc); - commDB.insert(key, value, owner); + commDB.insert(key, stk::mesh::EntityCommInfo(ghost_id, 2), owner); + commDB.insert(key, stk::mesh::EntityCommInfo(ghost_id, 3), owner); + commDB.insert(key, stk::mesh::EntityCommInfo(ghost_id, 4), owner); + + EXPECT_FALSE(commDB.erase(key, stk::mesh::EntityCommInfo(ghost_id, 1))); + EXPECT_TRUE(commDB.erase(key, stk::mesh::EntityCommInfo(ghost_id, 3))); + EXPECT_EQ(1u, removedGhosts.size()); + EXPECT_TRUE(commDB.erase(key, stk::mesh::EntityCommInfo(ghost_id, 2))); + EXPECT_EQ(2u, removedGhosts.size()); //CommListUpdater only manages removing entries from comm-list, //so we must add an entry manually to set up the test. diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityProcMapping.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityProcMapping.cpp index a9c0f1c718d2..f3945d51e131 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityProcMapping.cpp +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestEntityProcMapping.cpp @@ -70,14 +70,14 @@ TEST(EntityAndProcs, find_proc_multiple) TEST(EntityProcMapping, basic) { - if (stk::parallel_machine_size(MPI_COMM_WORLD) > 4) { + if (stk::parallel_machine_size(MPI_COMM_WORLD) > 3) { return; } const unsigned spatialDim = 3; std::shared_ptr bulkPtr = build_mesh(spatialDim, MPI_COMM_WORLD); stk::mesh::BulkData& bulk = *bulkPtr; - stk::io::fill_mesh("generated:4x4x4",bulk); + stk::io::fill_mesh("generated:1x2x3",bulk); if (stk::parallel_machine_rank(MPI_COMM_WORLD) != 0) { return; @@ -127,11 +127,8 @@ TEST(EntityProcMapping, basic) EXPECT_EQ(3u, entityProcVec.size()); } -TEST(EntityProcMapping, add_two_remove_one_then_other_still_found) +void test_add_two_remove_one_then_other_still_found(stk::mesh::EntityProcMapping& mapping, stk::mesh::Entity entity) { - stk::mesh::Entity entity(1); - const unsigned arbitraryMaxNumEntities = 10; - stk::mesh::EntityProcMapping mapping(arbitraryMaxNumEntities); mapping.addEntityProc(entity, 0); mapping.addEntityProc(entity, 2); EXPECT_TRUE(mapping.find(entity,0)); @@ -139,6 +136,31 @@ TEST(EntityProcMapping, add_two_remove_one_then_other_still_found) mapping.eraseEntityProc(entity,2); EXPECT_TRUE(mapping.find(entity,0)); + EXPECT_FALSE(mapping.find(entity,2)); +} + +TEST(EntityProcMapping, add_two_remove_one_then_other_still_found) +{ + stk::mesh::Entity entity(1); + const unsigned arbitraryMaxNumEntities = 10; + stk::mesh::EntityProcMapping mapping(arbitraryMaxNumEntities); + test_add_two_remove_one_then_other_still_found(mapping, entity); +} + +TEST(EntityProcMapping, add_two_remove_one_then_other_still_found_with_reset) +{ + stk::mesh::Entity entity(1); + const unsigned arbitraryMaxNumEntities = 10; + stk::mesh::EntityProcMapping mapping(arbitraryMaxNumEntities); + test_add_two_remove_one_then_other_still_found(mapping, entity); + + const unsigned largerMaxNumEntities = 128; + mapping.reset(largerMaxNumEntities); + EXPECT_FALSE(mapping.find(entity,0)); + EXPECT_FALSE(mapping.find(entity,2)); + EXPECT_FALSE(mapping.find(entity)); + + test_add_two_remove_one_then_other_still_found(mapping, entity); } TEST(EntityProcMapping, erase_nonexisting_then_previous_proc_still_found) @@ -153,3 +175,18 @@ TEST(EntityProcMapping, erase_nonexisting_then_previous_proc_still_found) EXPECT_TRUE(mapping.find(entity,0)); } +TEST(EntityProcMapping, visitEntityProcs) +{ + stk::mesh::Entity entity1(1), entity2(2); + const unsigned arbitraryMaxNumEntities = 10; + stk::mesh::EntityProcMapping mapping(arbitraryMaxNumEntities); + mapping.addEntityProc(entity1, 2); + mapping.addEntityProc(entity2, 1); + mapping.addEntityProc(entity2, 3); + + std::vector gold = {stk::mesh::EntityProc(entity1,2),stk::mesh::EntityProc(entity2,1),stk::mesh::EntityProc(entity2,3)}; + std::vector entityProcs; + mapping.visit_entity_procs([&](stk::mesh::Entity entity, int proc){entityProcs.push_back(stk::mesh::EntityProc(entity,proc));}); + EXPECT_EQ(gold, entityProcs); +} + diff --git a/packages/stk/stk_unit_tests/stk_mesh/UnitTestMetaData.cpp b/packages/stk/stk_unit_tests/stk_mesh/UnitTestMetaData.cpp index 08e5604aed9d..1bce4f059ad1 100644 --- a/packages/stk/stk_unit_tests/stk_mesh/UnitTestMetaData.cpp +++ b/packages/stk/stk_unit_tests/stk_mesh/UnitTestMetaData.cpp @@ -63,8 +63,6 @@ using stk::mesh::Part; using stk::mesh::PartVector; using stk::mesh::EntityRank; using stk::mesh::MeshBuilder; -using std::cout; -using std::endl; namespace { @@ -164,11 +162,10 @@ TEST( UnitTestMetaData, rankHigherThanDefined ) ); } -TEST( UnitTestMetaData, testEntityRepository ) +TEST( UnitTestMetaData, testEntityKeyMapping ) { static const size_t spatial_dimension = 3; - //Test Entity repository - covering EntityRepository.cpp/hpp stk::mesh::MetaData meta ( spatial_dimension ); meta.use_simple_fields(); stk::mesh::Part & part = meta.declare_part("another part"); diff --git a/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCommSparse.cpp b/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCommSparse.cpp index 99bb6ad5c834..2a6d49a87d8f 100644 --- a/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCommSparse.cpp +++ b/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCommSparse.cpp @@ -33,17 +33,18 @@ // #include "gtest/gtest.h" +#include "stk_util/stk_config.h" // for STK_HAS_MPI +#if defined ( STK_HAS_MPI ) + #include "stk_util/parallel/CommSparse.hpp" // for CommSparse, comm_recv_msg_sizes, comm_recv... #include "stk_util/parallel/Parallel.hpp" // for parallel_machine_rank, parallel_machine_size #include "stk_util/parallel/ParallelComm.hpp" // for CommBuffer -#include "stk_util/stk_config.h" // for STK_HAS_MPI #include "stk_util/util/ReportHandler.hpp" // for ThrowRequireMsg #include // for allocator_traits<>::value_type #include // for basic_ostream::operator<<, operator<<, bas... #include // for vector -#if defined ( STK_HAS_MPI ) - +#ifndef STK_HIDE_DEPRECATED_CODE // delete after August 2022 TEST(ParallelComm, comm_recv_msg_sizes) { MPI_Comm comm = MPI_COMM_WORLD; @@ -107,6 +108,7 @@ TEST(ParallelComm, comm_recv_procs_and_msg_sizes) } } } +#endif TEST(ParallelComm, CommSparse_pair_with_string) { @@ -216,21 +218,6 @@ TEST(ParallelComm, CommSparse_all_including_self) srcBuf.unpack(msg); EXPECT_EQ(msg, srcProc); } - - commSparse.swap_send_recv(); - commSparse.reset_buffers(); - - for(int destProc=0; destProc(srcBuf.remaining())); - int msg; - srcBuf.unpack(msg); - EXPECT_EQ(msg, myProc); - } } TEST(ParallelComm, CommSparse_set_procs) diff --git a/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCouplingVersions.cpp b/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCouplingVersions.cpp index f6d66953a16d..26b6c60ae6b1 100644 --- a/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCouplingVersions.cpp +++ b/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestCouplingVersions.cpp @@ -25,14 +25,14 @@ class CouplingVersionsTester : public ::testing::Test TEST_F(CouplingVersionsTester, CompatibileRangeGetter) { - EXPECT_EQ(stk::util::get_local_max_coupling_version(), STK_MAX_COUPLING_VERSION); + EXPECT_EQ(stk::util::get_local_max_coupling_version(), stk::util::impl::SHORT_TERM_STK_MAX_COUPLING_VERSION /*STK_MAX_COUPLING_VERSION*/); EXPECT_EQ(stk::util::get_local_min_coupling_version(), STK_MIN_COUPLING_VERSION); } TEST_F(CouplingVersionsTester, DefaultVersion) { EXPECT_EQ(stk::util::get_common_coupling_version(), STK_MAX_COUPLING_VERSION); - EXPECT_EQ(stk::util::get_global_max_coupling_version(), STK_MAX_COUPLING_VERSION); + EXPECT_EQ(stk::util::get_global_max_coupling_version(), stk::util::impl::SHORT_TERM_STK_MAX_COUPLING_VERSION /*STK_MAX_COUPLING_VERSION*/); } TEST_F(CouplingVersionsTester, NewVersion) @@ -48,7 +48,7 @@ TEST_F(CouplingVersionsTester, OldVersion) { stk::util::impl::set_coupling_version(STK_MIN_COUPLING_VERSION); EXPECT_EQ(stk::util::get_common_coupling_version(), STK_MIN_COUPLING_VERSION); - EXPECT_EQ(stk::util::get_global_max_coupling_version(), STK_MAX_COUPLING_VERSION); + EXPECT_EQ(stk::util::get_global_max_coupling_version(), stk::util::impl::SHORT_TERM_STK_MAX_COUPLING_VERSION /*STK_MAX_COUPLING_VERSION*/); } TEST_F(CouplingVersionsTester, MixedVersion) @@ -81,8 +81,8 @@ TEST_F(CouplingVersionsTester, DeprecatedVersionCheck) TEST_F(CouplingVersionsTester, NewVersionComm) { stk::util::set_coupling_version(MPI_COMM_WORLD); - EXPECT_EQ(stk::util::get_common_coupling_version(), STK_MAX_COUPLING_VERSION); - EXPECT_EQ(stk::util::get_global_max_coupling_version(), STK_MAX_COUPLING_VERSION); + EXPECT_EQ(stk::util::get_common_coupling_version(), stk::util::impl::SHORT_TERM_STK_MAX_COUPLING_VERSION /*STK_MAX_COUPLING_VERSION*/); + EXPECT_EQ(stk::util::get_global_max_coupling_version(), stk::util::impl::SHORT_TERM_STK_MAX_COUPLING_VERSION /*STK_MAX_COUPLING_VERSION*/); } @@ -90,12 +90,12 @@ TEST_F(CouplingVersionsTester, NonincreasingVersion) { stk::util::impl::set_coupling_version(STK_MIN_COUPLING_VERSION); EXPECT_EQ(stk::util::get_common_coupling_version(), STK_MIN_COUPLING_VERSION); - EXPECT_EQ(stk::util::get_global_max_coupling_version(), STK_MAX_COUPLING_VERSION); + EXPECT_EQ(stk::util::get_global_max_coupling_version(), stk::util::impl::SHORT_TERM_STK_MAX_COUPLING_VERSION /*STK_MAX_COUPLING_VERSION*/); stk::util::set_coupling_version(MPI_COMM_WORLD); EXPECT_EQ(stk::util::get_common_coupling_version(), STK_MIN_COUPLING_VERSION); - EXPECT_EQ(stk::util::get_global_max_coupling_version(), STK_MAX_COUPLING_VERSION); + EXPECT_EQ(stk::util::get_global_max_coupling_version(), stk::util::impl::SHORT_TERM_STK_MAX_COUPLING_VERSION /*STK_MAX_COUPLING_VERSION*/); } diff --git a/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestParallelReduceBool.cpp b/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestParallelReduceBool.cpp new file mode 100644 index 000000000000..f194f4d6f9e9 --- /dev/null +++ b/packages/stk/stk_unit_tests/stk_util/parallel/UnitTestParallelReduceBool.cpp @@ -0,0 +1,66 @@ +#include "gtest/gtest.h" +#include "stk_util/parallel/Parallel.hpp" +#include "stk_util/parallel/ParallelReduceBool.hpp" + +#ifdef STK_HAS_MPI + +TEST(ParallelReduceBool, MPI_Cxx_Bool) +{ + EXPECT_NE(MPI_CXX_BOOL, MPI_DATATYPE_NULL); +} + +//----------------------------------------------------------------------------- +// is_true_on_any_proc test + +TEST(ParallelReduceBool, is_true_on_any_proc_all_true) +{ + EXPECT_TRUE(stk::is_true_on_any_proc(MPI_COMM_WORLD, true)); +} + +TEST(ParallelReduceBool, is_true_on_any_proc_all_false) +{ + EXPECT_FALSE(stk::is_true_on_any_proc(MPI_COMM_WORLD, false)); +} + +TEST(ParallelReduceBool, is_true_on_any_proc_one_true) +{ + int comm_rank = stk::parallel_machine_rank(MPI_COMM_WORLD); + int comm_size = stk::parallel_machine_size(MPI_COMM_WORLD); + + for (int i=0; i < comm_size; ++i) + { + bool val = comm_rank == i ? true : false; + EXPECT_TRUE(stk::is_true_on_any_proc(MPI_COMM_WORLD, val)); + } +} + +//----------------------------------------------------------------------------- +// is_true_on_all_procs tests + +TEST(ParallelReduceBool, is_true_on_all_procs_all_true) +{ + EXPECT_TRUE(stk::is_true_on_all_procs(MPI_COMM_WORLD, true)); +} + +TEST(ParallelReduceBool, is_true_on_all_procs_all_false) +{ + EXPECT_FALSE(stk::is_true_on_all_procs(MPI_COMM_WORLD, false)); +} + +TEST(ParallelReduceBool, is_true_on_all_procs_one_true) +{ + int comm_rank = stk::parallel_machine_rank(MPI_COMM_WORLD); + int comm_size = stk::parallel_machine_size(MPI_COMM_WORLD); + + for (int i=0; i < comm_size; ++i) + { + bool val = comm_rank == i ? true : false; + if (comm_size > 1) { + EXPECT_FALSE(stk::is_true_on_all_procs(MPI_COMM_WORLD, val)); + } else { + EXPECT_TRUE(stk::is_true_on_all_procs(MPI_COMM_WORLD, val)); + } + } +} + +#endif \ No newline at end of file diff --git a/packages/stk/stk_unit_tests/stk_util/schedulerTest.cpp b/packages/stk/stk_unit_tests/stk_util/schedulerTest.cpp index f38a3121f358..93992a20f6ea 100644 --- a/packages/stk/stk_unit_tests/stk_util/schedulerTest.cpp +++ b/packages/stk/stk_unit_tests/stk_util/schedulerTest.cpp @@ -164,6 +164,23 @@ TEST(SchedulerTest, emptyScheduler) EXPECT_FALSE(scheduler.is_it_time(terminationTime+0.5, 2)); } +TEST(SchedulerTest, stepIntervalWithTerminationTime) +{ + using stk::util::Step; + + stk::util::Scheduler scheduler; + const stk::util::Time terminationTime = 4.5; + scheduler.set_termination_time(terminationTime); + scheduler.add_interval(Step(0), Step(2)); + EXPECT_TRUE(scheduler.is_it_time(0.0, 0)); + EXPECT_FALSE(scheduler.is_it_time(0.5, 1)); + EXPECT_TRUE(scheduler.is_it_time(2.0, 2)); + EXPECT_FALSE(scheduler.is_it_time(3.5, 3)); + EXPECT_TRUE(scheduler.is_it_time(4.0, 4)); + EXPECT_TRUE(scheduler.is_it_time(terminationTime, 5)); + EXPECT_FALSE(scheduler.is_it_time(terminationTime+0.5, 6)); +} + TEST(SchedulerTest, largeStartingTimeFollowedBySmallStep) { stk::util::Scheduler scheduler; diff --git a/packages/stk/stk_unit_tests/stk_util/util/UnitTestScheduler.cpp b/packages/stk/stk_unit_tests/stk_util/util/UnitTestScheduler.cpp index 277f99f7181f..d57466edaaae 100644 --- a/packages/stk/stk_unit_tests/stk_util/util/UnitTestScheduler.cpp +++ b/packages/stk/stk_unit_tests/stk_util/util/UnitTestScheduler.cpp @@ -93,9 +93,7 @@ TEST(Scheduler, LogarithmicOutput) const double dt_max = 100.0; std::mt19937 rng; - auto time = static_cast(stk::wall_time()); - rng.seed(time); - std::cout << "Running with seed = " << time << std::endl; + rng.seed(666); std::uniform_real_distribution noise(-1.0, 1.0); diff --git a/packages/stk/stk_util/stk_util/environment/CPUTime.cpp b/packages/stk/stk_util/stk_util/environment/CPUTime.cpp index d7d10f4a9c3f..67e7ba09aee1 100644 --- a/packages/stk/stk_util/stk_util/environment/CPUTime.cpp +++ b/packages/stk/stk_util/stk_util/environment/CPUTime.cpp @@ -33,23 +33,15 @@ // #include "stk_util/environment/CPUTime.hpp" -#include // for rusage, getrusage, RUSAGE_SELF -#include // for timeval - +#include namespace stk { double cpu_time() { - struct rusage my_rusage; - - ::getrusage(RUSAGE_SELF, &my_rusage); - - double seconds = my_rusage.ru_utime.tv_sec + my_rusage.ru_stime.tv_sec; - double micro_seconds = my_rusage.ru_utime.tv_usec + my_rusage.ru_stime.tv_usec; - - return seconds + micro_seconds*1.0e-6; + auto time = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(time.time_since_epoch()).count(); } } // namespace stk diff --git a/packages/stk/stk_util/stk_util/environment/Env.cpp b/packages/stk/stk_util/stk_util/environment/Env.cpp index d66b99f8a969..4b3298d9dbe7 100644 --- a/packages/stk/stk_util/stk_util/environment/Env.cpp +++ b/packages/stk/stk_util/stk_util/environment/Env.cpp @@ -363,7 +363,6 @@ set_param( const char *s = std::strcpy(new char[std::strlen(option) + 1], option); stk::parse_command_line_args(argc, &s, stk::get_options_specification(), stk::get_parsed_options()); - delete [] s; } diff --git a/packages/stk/stk_util/stk_util/parallel/CommSparse.cpp b/packages/stk/stk_util/stk_util/parallel/CommSparse.cpp index bff636ae975d..471cee624eee 100644 --- a/packages/stk/stk_util/stk_util/parallel/CommSparse.cpp +++ b/packages/stk/stk_util/stk_util/parallel/CommSparse.cpp @@ -52,10 +52,15 @@ namespace stk { static const int STK_COMMSPARSE_MPI_TAG_MSG_SIZING = 10101; static const int STK_COMMSPARSE_MPI_TAG_PROC_SIZING = 10111; + +#if STK_MIN_COUPLING_VERSION < 6 static const int STK_COMMSPARSE_MPI_TAG_DATA = 11011; +#endif + namespace { +#if STK_MIN_COPULING_VERSION < 6 void launch_ireceives(ParallelMachine p_comm, const std::vector& recv_procs, std::vector& recv, @@ -175,6 +180,8 @@ void communicate_unpack(ParallelMachine p_comm , } +#endif // STK_MIN_COUPLING_VERSION + #else // Not parallel @@ -183,6 +190,7 @@ void communicate_unpack(ParallelMachine p_comm , //---------------------------------------------------------------------- +#if STK_MIN_COUPLING_VERSION < 6 namespace { inline @@ -194,28 +202,40 @@ size_t align_quad( size_t n ) } -//---------------------------------------------------------------------- +#endif -void CommSparse::reset_buffers() -{ - for (size_t i=0 ; i= 6) { + if (m_exchanger) + { + for (int p=0; p < m_size; ++p) + { + m_exchanger->get_send_buf(p).reset(); + m_exchanger->get_recv_buf(p).reset(); + } + } else + { + m_null_comm_send_buffer.reset(); + m_null_comm_recv_buffer.reset(); + } -//---------------------------------------------------------------------- + m_num_recvs = DataExchangeUnknownPatternNonBlocking::Unknown; + } else { + for (size_t i=0 ; i& bufs, std::vector& data) { size_t n_size = 0; @@ -235,46 +255,71 @@ void CommSparse::allocate_data(std::vector& bufs, std::vector= 6) { + if (m_exchanger) { + m_exchanger->allocate_send_buffers(); + } else { + size_t size = m_null_comm_send_buffer.size(); + m_null_comm_storage.resize(size); + auto* ptr = m_null_comm_storage.data(); + m_null_comm_send_buffer.set_buffer_ptrs(ptr, ptr, ptr + size); + m_null_comm_recv_buffer.set_buffer_ptrs(ptr, ptr, ptr + size); + } - if (m_size > 1) { - comm_recv_procs_and_msg_sizes(m_comm, m_send, m_recv, m_send_procs, m_recv_procs); - allocate_data(m_send, m_send_data); - allocate_data(m_recv, m_recv_data); - } - else { - allocate_data(m_send, m_send_data); - m_recv = m_send; - m_recv_data = m_send_data; - if (m_send[0].capacity() > 0) { - m_send_procs.resize(1); - m_send_procs[0] = 0; - m_recv_procs = m_send_procs; + return false; + } else { + m_send.resize(m_size); + m_recv.resize(m_size); + + if (m_size > 1) { + comm_recv_procs_and_msg_sizes(m_comm, m_send, m_recv, m_send_procs, m_recv_procs); + allocate_data(m_send, m_send_data); + allocate_data(m_recv, m_recv_data); } + else { + allocate_data(m_send, m_send_data); + m_recv = m_send; + m_recv_data = m_send_data; + if (m_send[0].capacity() > 0) { + m_send_procs.resize(1); + m_send_procs[0] = 0; + m_recv_procs = m_send_procs; + } + } + return ((m_send_procs.size() > 0) || (m_recv_procs.size() > 0)); } - return ((m_send_procs.size() > 0) || (m_recv_procs.size() > 0)); } void CommSparse::allocate_buffers(const std::vector& send_procs, const std::vector& recv_procs) { - m_send.resize(m_size); - m_recv.resize(m_size); - - m_send_procs = send_procs; - m_recv_procs = recv_procs; - - if (m_size > 1) { - comm_recv_msg_sizes(m_comm , send_procs, recv_procs, m_send, m_recv); - allocate_data(m_send, m_send_data); - allocate_data(m_recv, m_recv_data); - } - else { - m_recv = m_send; - m_recv_data = m_send_data; + + stk::util::print_unsupported_version_warning(5, __LINE__, __FILE__); + + if (stk::util::get_common_coupling_version() >= 6) { + allocate_buffers(); + m_num_recvs = recv_procs.size(); + } else { + m_send.resize(m_size); + m_recv.resize(m_size); + + m_send_procs = send_procs; + m_recv_procs = recv_procs; + + if (m_size > 1) { + comm_recv_msg_sizes(m_comm , send_procs, recv_procs, m_send, m_recv); + allocate_data(m_send, m_send_data); + allocate_data(m_recv, m_recv_data); + } + else { + m_recv = m_send; + m_recv_data = m_send_data; + } } } @@ -283,7 +328,7 @@ void CommSparse::verify_send_buffers_filled() #ifndef NDEBUG for ( int i = 0 ; i < m_size ; ++i ) { // Verify the send buffers have been filled - if ( m_send[i].remaining() ) { + if ( send_buffer(i).remaining() ) { std::ostringstream msg ; msg << "stk::CommSparse::communicate LOCAL[" << m_rank << "] ERROR: Send[" << i << "] Buffer not filled." ; @@ -295,23 +340,50 @@ void CommSparse::verify_send_buffers_filled() void CommSparse::communicate() { - verify_send_buffers_filled(); +#ifdef STK_HAS_MPI + stk::util::print_unsupported_version_warning(5, __LINE__, __FILE__); + + if (stk::util::get_common_coupling_version() >= 6) { + if (m_exchanger) + { + auto f = [](int rank, stk::CommBuffer& buf) {}; + communicate_with_unpacker(f); + } + } else { + verify_send_buffers_filled(); - if ( 1 < m_size ) { - communicate_any( m_comm , m_send , m_recv, m_send_procs, m_recv_procs ); + if ( 1 < m_size ) { + communicate_any( m_comm , m_send , m_recv, m_send_procs, m_recv_procs ); + } } +#endif } void CommSparse::communicate_with_unpacker(const std::function& functor) { - verify_send_buffers_filled(); +#ifdef STK_HAS_MPI + stk::util::print_unsupported_version_warning(5, __LINE__, __FILE__); + + if (stk::util::get_common_coupling_version() >= 6) { + if (m_exchanger) + { + verify_send_buffers_filled(); + + m_exchanger->start_nonblocking(m_num_recvs); + m_exchanger->post_nonblocking_receives(); + m_exchanger->complete_receives(functor); + m_exchanger->complete_sends(); + } + } else { + verify_send_buffers_filled(); - if (1 < m_size) { - communicate_unpack(m_comm , m_send , m_recv, m_send_procs, m_recv_procs, functor); + if (1 < m_size) { + communicate_unpack(m_comm , m_send , m_recv, m_send_procs, m_recv_procs, functor); + } } +#endif } -//---------------------------------------------------------------------- //---------------------------------------------------------------------- #if defined(STK_HAS_MPI) diff --git a/packages/stk/stk_util/stk_util/parallel/CommSparse.hpp b/packages/stk/stk_util/stk_util/parallel/CommSparse.hpp index 99251cbc297f..09fa19e5ee8a 100644 --- a/packages/stk/stk_util/stk_util/parallel/CommSparse.hpp +++ b/packages/stk/stk_util/stk_util/parallel/CommSparse.hpp @@ -35,9 +35,10 @@ #ifndef stk_util_parallel_CommSparse_hpp #define stk_util_parallel_CommSparse_hpp +#include "stk_util/parallel/CouplingVersions.hpp" #include "stk_util/util/ReportHandler.hpp" #include "stk_util/parallel/Parallel.hpp" // for ParallelMachine, parallel_machine_null -#include "stk_util/parallel/ParallelComm.hpp" // for CommBuffer +#include "stk_util/parallel/DataExchangeUnknownPatternNonBlockingBuffer.hpp" #include // for size_t #include // for vector @@ -53,13 +54,14 @@ namespace stk { * Output vectors for send-procs and recv-procs will have * length num-send-procs and num-recv-procs respectively. */ -void comm_recv_procs_and_msg_sizes(ParallelMachine comm, +#ifndef STK_HIDE_DEPRECATED_CODE // delete coupling version 5 is deprecated +STK_DEPRECATED void comm_recv_procs_and_msg_sizes(ParallelMachine comm, const unsigned * const send_size, unsigned * const recv_size, std::vector& output_send_procs, std::vector& output_recv_procs); -void comm_recv_procs_and_msg_sizes(ParallelMachine comm , +STK_DEPRECATED void comm_recv_procs_and_msg_sizes(ParallelMachine comm , const std::vector& send_bufs , std::vector& recv_bufs, std::vector& send_procs, @@ -69,17 +71,19 @@ void comm_recv_procs_and_msg_sizes(ParallelMachine comm , * send-procs and recv-procs (of length number-of-procs-to-send/recv-with), * set recv sizes (recv_size array has length number-of-MPI-processor-ranks). */ -void comm_recv_msg_sizes(ParallelMachine comm , + +STK_DEPRECATED void comm_recv_msg_sizes(ParallelMachine comm , const unsigned * const send_size , const std::vector& send_procs, const std::vector& recv_procs, unsigned * const recv_size); -void comm_recv_msg_sizes(ParallelMachine comm , +STK_DEPRECATED void comm_recv_msg_sizes(ParallelMachine comm , const std::vector& send_procs, const std::vector& recv_procs, const std::vector& send_bufs, std::vector& recv_bufs); +#endif class CommSparse { public: @@ -92,21 +96,67 @@ class CommSparse { CommBuffer & send_buffer( int p ) { ThrowAssertMsg(p < m_size,"CommSparse::send_buffer: "<= 6) { + if (m_exchanger) { + return m_exchanger->get_send_buf(p); + } else { + return m_null_comm_send_buffer; + } + } else { + return m_send[p] ; + } + } + + const CommBuffer & send_buffer( int p ) const + { + ThrowAssertMsg(p < m_size,"CommSparse::send_buffer: "<= 6) { + if (m_exchanger) { + return m_exchanger->get_send_buf(p); + } else { + return m_null_comm_send_buffer; + } + } else { + return m_send[p] ; + } } /** Obtain the message buffer for a given processor */ CommBuffer & recv_buffer( int p ) { ThrowAssertMsg(p < m_size,"CommSparse::recv_buffer: "<= 6) { + if (m_exchanger) { + return m_exchanger->get_recv_buf(p); + } else { + return m_null_comm_recv_buffer; + } + } else { + return m_recv[p] ; + } } /** Obtain the message buffer for a given processor */ const CommBuffer & recv_buffer( int p ) const { ThrowAssertMsg(p < m_size,"CommSparse::recv_buffer: "<= 6) { + if (m_exchanger) { + return m_exchanger->get_recv_buf(p); + } else { + return m_null_comm_recv_buffer; + } + } else { + return m_recv[p] ; + } } //---------------------------------------- @@ -125,21 +175,25 @@ class CommSparse { : m_comm( comm ), m_size( parallel_machine_size( comm ) ), m_rank( parallel_machine_rank( comm ) ), +#if STK_MIN_COUPLING_VERSION < 6 m_send(m_size), m_recv(m_size), m_send_data(), m_recv_data(), m_send_procs(), - m_recv_procs() + m_recv_procs(), +#endif + m_exchanger(nullptr) { + if (comm != MPI_COMM_NULL && stk::util::get_common_coupling_version() >= 6) { + m_exchanger = std::make_shared(comm); + } } CommSparse(const CommSparse&) = delete; /** Allocate communication buffers based upon * sizing from the surrogate send buffer packing. - * Returns true if the local processor is actually - * sending or receiving. */ bool allocate_buffers(); @@ -162,64 +216,65 @@ class CommSparse { communicate_with_unpacker(alg); } - /** Swap send and receive buffers leading to reversed communication. */ - void swap_send_recv(); - /** Reset, but do not reallocate, message buffers for reprocessing. * Sets 'size() == 0' and 'remaining() == capacity()'. */ void reset_buffers(); - ~CommSparse() - { - m_comm = parallel_machine_null(); - m_size = 0 ; - m_rank = 0 ; - m_send.clear(); - m_recv.clear(); - } private: - /** Construct for undefined communication. - * No buffers are allocated. - */ - CommSparse() - : m_comm( parallel_machine_null() ), - m_size( 0 ), - m_rank( 0 ), - m_send(), - m_recv(), - m_send_data(), - m_recv_data(), - m_send_procs(), - m_recv_procs() - {} - +#if STK_MIN_COUPLING_VERSION < 6 void allocate_data(std::vector& bufs, std::vector& data); +#endif void verify_send_buffers_filled(); void communicate_with_unpacker(const std::function& functor); ParallelMachine m_comm ; int m_size ; int m_rank ; +#if STK_MIN_COUPLING_VERSION < 6 std::vector m_send; std::vector m_recv; std::vector m_send_data; std::vector m_recv_data; std::vector m_send_procs; std::vector m_recv_procs; +#endif + + int m_num_recvs = DataExchangeUnknownPatternNonBlocking::Unknown; + std::shared_ptr m_exchanger; + + stk::CommBuffer m_null_comm_send_buffer; + stk::CommBuffer m_null_comm_recv_buffer; + std::vector m_null_comm_storage; }; template bool pack_and_communicate(COMM & comm, const PACK_ALGORITHM & algorithm) { + stk::util::print_unsupported_version_warning(5, __LINE__, __FILE__); + + if (stk::util::get_common_coupling_version() >= 6) { + algorithm(); + comm.allocate_buffers(); + algorithm(); + comm.communicate(); + + for (int i=0; i < comm.parallel_size(); ++i) { + if (comm.send_buffer(i).capacity() > 0 || comm.recv_buffer(i).capacity() > 0) { + return true; + } + } + return false; + } else { algorithm(); const bool actuallySendingOrReceiving = comm.allocate_buffers(); if (actuallySendingOrReceiving) { algorithm(); comm.communicate(); } - return actuallySendingOrReceiving; + return actuallySendingOrReceiving; + } } template diff --git a/packages/stk/stk_util/stk_util/parallel/CouplingVersions.cpp b/packages/stk/stk_util/stk_util/parallel/CouplingVersions.cpp index db132ddc15a2..54bdb0ca8301 100644 --- a/packages/stk/stk_util/stk_util/parallel/CouplingVersions.cpp +++ b/packages/stk/stk_util/stk_util/parallel/CouplingVersions.cpp @@ -12,6 +12,7 @@ namespace stk { namespace util { + #ifdef STK_HAS_MPI void MPI_Op_MaxMinReduction(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype) @@ -19,17 +20,17 @@ void MPI_Op_MaxMinReduction(void* invec, void* inoutvec, int* len, MPI_Datatype* int* invec_int = reinterpret_cast(invec); int* inoutvec_int = reinterpret_cast(inoutvec); - inoutvec_int[0] = std::max(invec_int[0], inoutvec_int[0]); inoutvec_int[1] = std::min(invec_int[1], inoutvec_int[1]); + inoutvec_int[2] = std::max(invec_int[2], inoutvec_int[2]); } std::pair allreduce_minmax(MPI_Comm comm, int localVersion) { // for compatibility with the ParallelReduce code, the buffer has to be - // large enough for 3 ints (2 ints + empty struct + padding), even though - // only the first 2 are used + // large enough for 3 ints (empty struct + padding + 2 ints), even though + // only the ints are used constexpr int bufSize = 3; - std::array inbuf{localVersion, localVersion, -1}, outbuf; + std::array inbuf{-1, localVersion, localVersion}, outbuf; MPI_Op mpiOp = MPI_OP_NULL ; MPI_Op_create( MPI_Op_MaxMinReduction , false , &mpiOp ); @@ -49,7 +50,7 @@ std::pair allreduce_minmax(MPI_Comm comm, int localVersion) MPI_Op_free(&mpiOp); - return {outbuf[1], outbuf[0]}; + return {outbuf[1], outbuf[2]}; } @@ -68,7 +69,7 @@ class StkCompatibleVersion void set_version(MPI_Comm comm) { - set_version_impl(comm, m_version); + set_version_impl(comm, std::min(impl::SHORT_TERM_STK_MAX_COUPLING_VERSION, m_version) /*m_version*/); } void set_error_on_reset(bool val) @@ -78,7 +79,7 @@ class StkCompatibleVersion void reset_global_max_coupling_version() { - m_globalMaxVersion = STK_MAX_COUPLING_VERSION; + m_globalMaxVersion = impl::SHORT_TERM_STK_MAX_COUPLING_VERSION; // STK_MAX_COUPLING_VERSION; } private: @@ -150,7 +151,7 @@ class StkCompatibleVersion } int m_version = STK_MAX_COUPLING_VERSION; - int m_globalMaxVersion = STK_MAX_COUPLING_VERSION; + int m_globalMaxVersion = impl::SHORT_TERM_STK_MAX_COUPLING_VERSION; // STK_MAX_COUPLING_VERSION; bool m_isVersionSet = false; bool m_errorOnResetVersion = true; }; @@ -170,13 +171,13 @@ int get_common_coupling_version() #ifdef STK_HAS_MPI return get_stk_coupling_version().get_version(); #else - return STK_MAX_COUPLING_VERSION; + return STK_impl::SHORT_TERM_MAX_COUPLING_VERSION; //STK_MAX_COUPLING_VERSION; #endif } int get_local_max_coupling_version() { - return STK_MAX_COUPLING_VERSION; + return impl::SHORT_TERM_STK_MAX_COUPLING_VERSION; //STK_MAX_COUPLING_VERSION; } int get_local_min_coupling_version() @@ -190,7 +191,7 @@ int get_global_max_coupling_version() #ifdef STK_HAS_MPI return get_stk_coupling_version().get_global_max_version(); #else - return STK_MAX_COUPLING_VERSION + return impl::SHORT_TERM_STK_MAX_COUPLING_VERSION; // STK_MAX_COUPLING_VERSION #endif } @@ -202,7 +203,10 @@ std::string get_deprecation_date(int version) std::make_pair(2, "7/26/2022"), std::make_pair(3, "7/26/2022"), std::make_pair(4, "7/27/2022"), - std::make_pair(5, "") + std::make_pair(5, "9/13/2022"), + std::make_pair(6, "9/18/2022"), + std::make_pair(7, "10/16/2022"), + std::make_pair(8, "") }; return deprecationDates.at(version); @@ -222,10 +226,10 @@ bool is_local_stk_coupling_deprecated() } -void print_unsupported_version_warning(int version, int line, const std::string& file) +void print_unsupported_version_warning(int version, int line, const char* file) { if ( STK_MIN_COUPLING_VERSION > version ) { - std::cerr << "The function at line " << __LINE__ << " of file " << __FILE__ + std::cerr << "The function at line " << line << " of file " << file << " can be simplified now that STK_MIN_COUPLING_VERSION is greater than " << (version) << std::endl; } diff --git a/packages/stk/stk_util/stk_util/parallel/CouplingVersions.hpp b/packages/stk/stk_util/stk_util/parallel/CouplingVersions.hpp index 99725675ca08..da2fbaee8caa 100644 --- a/packages/stk/stk_util/stk_util/parallel/CouplingVersions.hpp +++ b/packages/stk/stk_util/stk_util/parallel/CouplingVersions.hpp @@ -8,13 +8,18 @@ #include -#define STK_MAX_COUPLING_VERSION 5 +#define STK_MAX_COUPLING_VERSION 8 #define STK_MIN_COUPLING_VERSION 0 - + namespace stk { namespace util { +namespace impl { +constexpr int SHORT_TERM_STK_MAX_COUPLING_VERSION=1; +} + + int get_common_coupling_version(); int get_local_max_coupling_version(); @@ -31,7 +36,7 @@ void set_coupling_version(MPI_Comm comm); bool is_local_stk_coupling_deprecated(); -void print_unsupported_version_warning(int version, int line, const std::string& file); +void print_unsupported_version_warning(int version, int line, const char* file); } diff --git a/packages/stk/stk_util/stk_util/parallel/DataExchangeUnknownPatternNonBlocking.cpp b/packages/stk/stk_util/stk_util/parallel/DataExchangeUnknownPatternNonBlocking.cpp index a7674d509142..b3a8ef258325 100644 --- a/packages/stk/stk_util/stk_util/parallel/DataExchangeUnknownPatternNonBlocking.cpp +++ b/packages/stk/stk_util/stk_util/parallel/DataExchangeUnknownPatternNonBlocking.cpp @@ -25,7 +25,7 @@ void DataExchangeUnknownPatternNonBlocking::yield() { // Note: sleep_for would be better for this, but its minimum sleep time is // too long - std::this_thread::yield(); + //std::this_thread::yield(); } } // namespace diff --git a/packages/stk/stk_util/stk_util/parallel/MPITagManager.cpp b/packages/stk/stk_util/stk_util/parallel/MPITagManager.cpp index 205e1d9e235a..69e7ce3f4e76 100644 --- a/packages/stk/stk_util/stk_util/parallel/MPITagManager.cpp +++ b/packages/stk/stk_util/stk_util/parallel/MPITagManager.cpp @@ -1,4 +1,5 @@ #include "stk_util/parallel/MPITagManager.hpp" +#include "stk_util/parallel/CouplingVersions.hpp" #include namespace stk { @@ -176,7 +177,16 @@ void MPITagManager::check_same_value_on_all_procs_debug_only(MPI_Comm comm, int MPITagManager& get_mpi_tag_manager() { - int deletionGroupSize = 32; + stk::util::print_unsupported_version_warning(7, __LINE__, __FILE__); + int deletionGroupSize; + if (stk::util::get_common_coupling_version() >= 8) + { + deletionGroupSize = 33; + } else + { + deletionGroupSize = 32; + } + static int delayCount = -1; if (delayCount < 0) { diff --git a/packages/stk/stk_util/stk_util/parallel/ManagedBufferBase.hpp b/packages/stk/stk_util/stk_util/parallel/ManagedBufferBase.hpp index 82dab06afd67..65f9c91835b8 100644 --- a/packages/stk/stk_util/stk_util/parallel/ManagedBufferBase.hpp +++ b/packages/stk/stk_util/stk_util/parallel/ManagedBufferBase.hpp @@ -121,8 +121,7 @@ class ManagedCommBufferBase explicit ManagedCommBufferBase(MPI_Comm comm) : m_comm(comm) { - int commSize; - MPI_Comm_size(comm, &commSize); + int commSize = parallel_machine_size(comm); m_sendBufs.resize(commSize); m_sendBufStorage.resize(commSize); m_recvBufs.resize(commSize); @@ -141,6 +140,11 @@ class ManagedCommBufferBase return m_sendBufs[rank]; } + const stk::CommBuffer& get_send_buf(int rank) const + { + return m_sendBufs[rank]; + } + stk::CommBuffer& get_recv_buf(int rank) { if (m_recvsInProgress) diff --git a/packages/stk/stk_util/stk_util/parallel/ParallelReduceBool.hpp b/packages/stk/stk_util/stk_util/parallel/ParallelReduceBool.hpp index 4b1f6c692240..b66ce7486b82 100644 --- a/packages/stk/stk_util/stk_util/parallel/ParallelReduceBool.hpp +++ b/packages/stk/stk_util/stk_util/parallel/ParallelReduceBool.hpp @@ -46,8 +46,13 @@ inline bool is_true_on_all_procs(ParallelMachine comm , const bool truthValue) { #ifdef STK_HAS_MPI stk::util::print_unsupported_version_warning(2, __LINE__, __FILE__); - - if (stk::util::get_common_coupling_version() >= 3) { + stk::util::print_unsupported_version_warning(6, __LINE__, __FILE__); + + if (stk::util::get_common_coupling_version() >= 7) { + int truthValueInt = truthValue, globalResult; + MPI_Allreduce(&truthValueInt, &globalResult, 1, MPI_INT, MPI_LAND, comm); + return globalResult; + } else if (stk::util::get_common_coupling_version() >= 3) { bool globalResult; MPI_Allreduce(&truthValue, &globalResult, 1, MPI_CXX_BOOL, MPI_LAND, comm); return globalResult; @@ -67,8 +72,13 @@ inline bool is_true_on_any_proc(ParallelMachine comm , const bool truthValue) { #ifdef STK_HAS_MPI stk::util::print_unsupported_version_warning(2, __LINE__, __FILE__); + stk::util::print_unsupported_version_warning(6, __LINE__, __FILE__); - if (stk::util::get_common_coupling_version() >= 3) { + if (stk::util::get_common_coupling_version() >= 7) { + int truthValueInt = truthValue, globalResult; + MPI_Allreduce(&truthValueInt, &globalResult, 1, MPI_INT, MPI_LOR, comm); + return globalResult; + } else if (stk::util::get_common_coupling_version() >= 3) { bool globalResult; MPI_Allreduce(&truthValue, &globalResult, 1, MPI_CXX_BOOL, MPI_LOR, comm); return globalResult; diff --git a/packages/stk/stk_util/stk_util/registry/ProductRegistry.cpp b/packages/stk/stk_util/stk_util/registry/ProductRegistry.cpp index 83dfc511008d..2f9a625d632c 100644 --- a/packages/stk/stk_util/stk_util/registry/ProductRegistry.cpp +++ b/packages/stk/stk_util/stk_util/registry/ProductRegistry.cpp @@ -42,7 +42,7 @@ //In Sierra, STK_VERSION_STRING is provided on the compile line by bake. //For Trilinos stk snapshots, the following macro definition gets populated with //the real version string by the trilinos_snapshot.sh script. -#define STK_VERSION_STRING "5.7.4-14-gb9702494" +#define STK_VERSION_STRING "5.9.2-596-g5255aa34" #endif namespace stk { diff --git a/packages/stk/stk_util/stk_util/util/StkNgpVector.hpp b/packages/stk/stk_util/stk_util/util/StkNgpVector.hpp index 529eeec21516..2f386caef4d5 100644 --- a/packages/stk/stk_util/stk_util/util/StkNgpVector.hpp +++ b/packages/stk/stk_util/stk_util/util/StkNgpVector.hpp @@ -53,11 +53,7 @@ class NgpVector NgpVector(const std::string &n, size_t s) : mSize(s), deviceVals(Kokkos::view_alloc(Kokkos::WithoutInitializing, n), mSize), -#ifndef NEW_TRILINOS_INTEGRATION - hostVals(Kokkos::create_mirror_view(HostSpace(), deviceVals, Kokkos::WithoutInitializing)) -#else hostVals(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, HostSpace(), deviceVals)) -#endif { } NgpVector(size_t s) : NgpVector(get_default_name(), s) diff --git a/packages/stk/stk_util/stk_util/util/VecSet.hpp b/packages/stk/stk_util/stk_util/util/VecSet.hpp index bdce1e339730..0f7b79b4de56 100644 --- a/packages/stk/stk_util/stk_util/util/VecSet.hpp +++ b/packages/stk/stk_util/stk_util/util/VecSet.hpp @@ -104,11 +104,8 @@ class vecset { public: + typedef typename storage::allocator_type allocator_type ; - typedef typename allocator_type::reference reference ; - typedef typename allocator_type::const_reference const_reference ; - typedef typename allocator_type::pointer pointer ; - typedef typename allocator_type::const_pointer const_pointer ; typedef typename storage::size_type size_type ; typedef typename storage::difference_type difference_type ; typedef typename storage::iterator iterator ; diff --git a/packages/tpetra/CMakeLists.txt b/packages/tpetra/CMakeLists.txt index a3c855107418..919ce4ab8120 100644 --- a/packages/tpetra/CMakeLists.txt +++ b/packages/tpetra/CMakeLists.txt @@ -684,7 +684,7 @@ MESSAGE(STATUS "Tpetra: Tpetra_INST_INT_LONG is disabled by default.") # IF anything other than long long is defined and enabled, and long long isn't defined then we disable long long IF( ( ( DEFINED Tpetra_INST_INT_INT AND Tpetra_INST_INT_INT) - OR ( DEFINED Tpetra_INST_INT_LONG AND Tpetra_INST_INT_UNSIGNED_LONG) + OR ( DEFINED Tpetra_INST_INT_LONG AND Tpetra_INST_INT_LONG) OR ( DEFINED Tpetra_INST_INT_UNSIGNED AND Tpetra_INST_INT_UNSIGNED) OR ( DEFINED Tpetra_INST_INT_UNSIGNED_LONG AND Tpetra_INST_INT_UNSIGNED_LONG)) AND (NOT DEFINED Tpetra_INST_INT_LONG_LONG) ) diff --git a/packages/tpetra/core/CMakeLists.txt b/packages/tpetra/core/CMakeLists.txt index 3bfbbce2590d..56d2a94e72f8 100644 --- a/packages/tpetra/core/CMakeLists.txt +++ b/packages/tpetra/core/CMakeLists.txt @@ -198,6 +198,16 @@ TRIBITS_ADD_OPTION_AND_DEFINE( OFF ) +TRIBITS_ADD_OPTION_AND_DEFINE( + Tpetra_ENABLE_KokkosIntegrationTest + HAVE_TPETRA_KOKKOSINTEGRATION_TEST + "Enable the KokkosIntegrationTest" + ON + ) + + + + # # Add libraries, tests, and examples # diff --git a/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_decl.hpp b/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_decl.hpp index 5217d8e6fb32..a6f6ea61c732 100644 --- a/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_decl.hpp +++ b/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_decl.hpp @@ -50,6 +50,20 @@ namespace Tpetra { +template +Teuchos::RCP +importAndFillCompleteBlockCrsMatrix (const Teuchos::RCP& sourceMatrix, + const Import& importer, + const Teuchos::RCP >& domainMap = Teuchos::null, + const Teuchos::RCP >& rangeMap = Teuchos::null, + const Teuchos::RCP& params = Teuchos::null); + /// \class BlockCrsMatrix /// \brief Sparse matrix whose entries are small dense square blocks, /// all of the same dimensions. @@ -378,6 +392,13 @@ class BlockCrsMatrix : const Scalar alpha = Teuchos::ScalarTraits::one (), const Scalar beta = Teuchos::ScalarTraits::zero ()); + void + importAndFillComplete (Teuchos::RCP >& destMatrix, + const Import& importer, + const Teuchos::RCP& domainMap, + const Teuchos::RCP& rangeMap, + const Teuchos::RCP& params = Teuchos::null) const; + /// \brief Replace values at the given (mesh, i.e., block) column /// indices, in the given (mesh, i.e., block) row. /// @@ -1195,8 +1216,42 @@ class BlockCrsMatrix : virtual typename ::Tpetra::RowMatrix::mag_type getFrobeniusNorm () const override; //@} + + // Friend declaration for nonmember function. + template + friend Teuchos::RCP + Tpetra::importAndFillCompleteBlockCrsMatrix (const Teuchos::RCP& sourceMatrix, + const Import& importer, + const Teuchos::RCP >& domainMap, + const Teuchos::RCP >& rangeMap, + const Teuchos::RCP& params); }; +template +Teuchos::RCP +importAndFillCompleteBlockCrsMatrix (const Teuchos::RCP& sourceMatrix, + const Import& importer, + const Teuchos::RCP >& domainMap, + const Teuchos::RCP >& rangeMap, + const Teuchos::RCP& params) +{ + Teuchos::RCP destMatrix; + sourceMatrix->importAndFillComplete (destMatrix, importer, domainMap, rangeMap, params); + return destMatrix; +} + } // namespace Tpetra #endif // TPETRA_BLOCKCRSMATRIX_DECL_HPP diff --git a/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_def.hpp b/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_def.hpp index 19ca96919c17..701256680171 100644 --- a/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_def.hpp +++ b/packages/tpetra/core/src/Tpetra_BlockCrsMatrix_def.hpp @@ -943,6 +943,40 @@ class GetLocalDiagCopy { } } + template + void + BlockCrsMatrix:: + importAndFillComplete (Teuchos::RCP >& destMatrix, + const Import& importer, + const Teuchos::RCP& domainMap, + const Teuchos::RCP& rangeMap, + const Teuchos::RCP& params) const + { + using Teuchos::RCP; + using Teuchos::rcp; + using this_type = BlockCrsMatrix; + + // Right now, we make many assumptions... + TEUCHOS_TEST_FOR_EXCEPTION(!destMatrix.is_null(), std::invalid_argument, + "Right now, assuming destMatrix is null."); + TEUCHOS_TEST_FOR_EXCEPTION(!domainMap.is_null(), std::invalid_argument, + "Right now, assuming domainMap is null."); + TEUCHOS_TEST_FOR_EXCEPTION(!rangeMap.is_null(), std::invalid_argument, + "Right now, assuming rangeMap is null."); + TEUCHOS_TEST_FOR_EXCEPTION(!params.is_null(), std::invalid_argument, + "Right now, assuming params is null."); + + // BlockCrsMatrix requires a complete graph at construction. + // So first step is to import and fill complete the destGraph. + RCP destGraph = rcp (new crs_graph_type (importer.getTargetMap(), 0)); + destGraph->doImport(this->getCrsGraph(), importer, Tpetra::INSERT); + destGraph->fillComplete(); + + // Final step, create and import the destMatrix. + destMatrix = rcp (new this_type (*destGraph, getBlockSize())); + destMatrix->doImport(*this, importer, Tpetra::INSERT); + } + template void BlockCrsMatrix:: @@ -2738,23 +2772,25 @@ class GetLocalDiagCopy { errorDuringUnpack () = 0; { using policy_type = Kokkos::TeamPolicy; - const auto policy = policy_type (numImportLIDs, 1, 1) - .set_scratch_size (0, Kokkos::PerTeam (sizeof (GO) * maxRowNumEnt + - sizeof (LO) * maxRowNumEnt + - numBytesPerValue * maxRowNumScalarEnt)); + size_t scratch_per_row = sizeof(GO) * maxRowNumEnt + sizeof (LO) * maxRowNumEnt + numBytesPerValue * maxRowNumScalarEnt + + 2 * sizeof(GO); // Yeah, this is a fudge factor + + const auto policy = policy_type (numImportLIDs, 1, 1) + .set_scratch_size (0, Kokkos::PerTeam (scratch_per_row)); using host_scratch_space = typename host_exec::scratch_memory_space; + using pair_type = Kokkos::pair; Kokkos::parallel_for ("Tpetra::BlockCrsMatrix::unpackAndCombine: unpack", policy, [=] (const typename policy_type::member_type& member) { const size_t i = member.league_rank(); - Kokkos::View gblColInds (member.team_scratch (0), maxRowNumEnt); Kokkos::View lclColInds (member.team_scratch (0), maxRowNumEnt); Kokkos::View vals (member.team_scratch (0), maxRowNumScalarEnt); + const size_t offval = offset(i); const LO lclRow = importLIDsHost(i); diff --git a/packages/tpetra/core/src/Tpetra_Details_computeOffsets.hpp b/packages/tpetra/core/src/Tpetra_Details_computeOffsets.hpp index 193633f265dd..e0ccadd155fc 100644 --- a/packages/tpetra/core/src/Tpetra_Details_computeOffsets.hpp +++ b/packages/tpetra/core/src/Tpetra_Details_computeOffsets.hpp @@ -127,7 +127,7 @@ class ComputeOffsetsFromCounts { functor_type functor (offsets, counts); OffsetType total (0); const char funcName[] = "Tpetra::Details::computeOffsetsFromCounts"; - Kokkos::parallel_scan (range, functor, total, funcName); + Kokkos::parallel_scan (funcName, range, functor, total); return total; } diff --git a/packages/tpetra/core/test/Block/BlockCrsMatrix.cpp b/packages/tpetra/core/test/Block/BlockCrsMatrix.cpp index 42c99ff85ff4..1b92db5a415f 100644 --- a/packages/tpetra/core/test/Block/BlockCrsMatrix.cpp +++ b/packages/tpetra/core/test/Block/BlockCrsMatrix.cpp @@ -61,6 +61,7 @@ namespace { using Teuchos::reduceAll; using Teuchos::RCP; using Teuchos::rcp; + using Teuchos::ScalarTraits; using std::endl; typedef Tpetra::global_size_t GST; @@ -1524,6 +1525,503 @@ namespace { TEST_EQUALITY_CONST( gblSuccess, 1 ); } + // Test that two graphs are same. + template + bool graphs_are_same(const Graph& G1, const Graph& G2) + { + typedef typename Graph::local_ordinal_type LO; + + int my_rank = G1.getRowMap()->getComm()->getRank(); + + // Make sure each graph is fill complete before checking other properties + if (! G1.isFillComplete()) { + if (my_rank == 0) + std::cerr << "Error: Graph 1 is not fill complete!" << std::endl; + return false; + } + if (! G2.isFillComplete()) { + if (my_rank == 0) + std::cerr << "Error: Graph 2 is not fill complete!" << std::endl; + return false; + } + + int errors = 0; + + if (! G1.getRowMap()->isSameAs(*G2.getRowMap())) { + if (my_rank == 0) + std::cerr << "Error: Graph 1's row map is different than Graph 2's" << std::endl; + errors++; + } + if (! G1.getDomainMap()->isSameAs(*G2.getDomainMap())) { + if (my_rank == 0) + std::cerr << "Error: Graph 1's domain map is different than Graph 2's" << std::endl; + errors++; + } + if (! G1.getRangeMap()->isSameAs(*G2.getRangeMap())) { + if (my_rank == 0) + std::cerr << "Error: Graph 1's range map is different than Graph 2's" << std::endl; + errors++; + } + if (G1.getLocalNumEntries() != G2.getLocalNumEntries()) { + std::cerr << "Error: Graph 1 does not have the same number of entries as Graph 2 on Process " + << my_rank << std::endl; + errors++; + } + + if (errors != 0) return false; + + for (LO i=0; i(G1.getLocalNumRows()); i++) { + typename Graph::local_inds_host_view_type V1, V2; + G1.getLocalRowView(i, V1); + G2.getLocalRowView(i, V2); + if (V1.size() != V2.size()) { + std::cerr << "Error: Graph 1 and Graph 2 have different number of entries in local row " + << i << " on Process " << my_rank << std::endl; + errors++; + continue; + } + int jerr = 0; + for (LO j=0; j(V1.size()); j++) { + if (V1[j] != V2[j]) + jerr++; + } + if (jerr != 0) { + std::cerr << "Error: One or more entries in row " << i << " on Process " << my_rank + << " Graphs 1 and 2 are not the same" << std::endl; + errors++; + continue; + } + } + + return (errors == 0); + + } + + // Test that two matrices' rows have the same entries. + template + bool matrices_are_same(const RCP& A1, + const RCP& A2) + { + // Loop through A1 and make sure each row has the same + // entries as A2. In the fully general case, the + // redistribution may have added together values, resulting in + // small rounding errors. This is why we use an error tolerance + // (with a little bit of wiggle room). + + int my_rank = A1->getRowMap()->getComm()->getRank(); + + using LO = typename BlockCrsMatrixType::local_ordinal_type; + using Scalar = typename BlockCrsMatrixType::scalar_type; + using lids_type = typename BlockCrsMatrixType::local_inds_host_view_type; + using vals_type = typename BlockCrsMatrixType::values_host_view_type; + + using ST = ScalarTraits; + using magnitude_type = typename ST::magnitudeType; + const magnitude_type tol = + Teuchos::as (10) * ScalarTraits::eps (); + + const LO blocksize = A1->getBlockSize(); + // Verify the blocksizes are identical + if (blocksize != A2->getBlockSize()) { + if (my_rank==0) std::cerr << "Error: Blocksizes are not the same!" << std::endl; + return false; + } + + // Verify the maps are identical + bool maps_same = A1->getRowMap()->isSameAs(*(A2->getRowMap())); + if (!maps_same) { + if (my_rank==0) std::cerr << "Error: RowMaps are not the same!" << std::endl; + return false; + } + + // Verify the graphs are identical + bool graphs_same = graphs_are_same(A1->getCrsGraph(), A2->getCrsGraph()); + if (!graphs_same) { + if (my_rank==0) std::cerr << "Error: Graphs are not the same!" << std::endl; + return false; + } + + lids_type A1RowInds; + vals_type A1RowVals; + lids_type A2RowInds; + vals_type A2RowVals; + for (LO localrow = A1->getRowMap()->getMinLocalIndex(); + localrow <= A1->getRowMap()->getMaxLocalIndex(); + ++localrow) + { + size_t A1NumEntries = A1->getNumEntriesInLocalRow (localrow); + size_t A2NumEntries = A1->getNumEntriesInLocalRow (localrow); + + // Verify the same number of entries in each row + if (A1NumEntries != A2NumEntries) { + if (my_rank==0) std::cerr << "Error: Matrices have different number of entries in at least one row!" << std::endl; + return false; + } + + A1->getLocalRowView (localrow, A1RowInds, A1RowVals); + A2->getLocalRowView (localrow, A2RowInds, A2RowVals); + + // Verify the same number of values in each row + if (A1RowVals.extent(0) != A2RowVals.extent(0)) { + if (my_rank==0) std::cerr << "Error: Matrices have different number of entries in at least one row!" << std::endl; + return false; + } + + typedef typename Array::size_type size_type; + for (size_type k = 0; k < static_cast (A1NumEntries); ++k) { + // Verify the same column indices + if(A1RowInds[k]!=A2RowInds[k]) { + if (my_rank==0) std::cerr << "Error: Matrices have different column indices!" << std::endl; + return false; + } + } + + for (size_t val=0; val tol) { + if (my_rank==0) std::cerr << "Error: Matrices have different values!" << std::endl; + return false; + } + } + } + + return true; + } + + // Build lower diag matrix for test + template + void build_lower_diag_matrix (const RCP& A) { + + using LO = typename BlockCrsMatrixType::local_ordinal_type; + using GO = typename BlockCrsMatrixType::global_ordinal_type; + using Scalar = typename BlockCrsMatrixType::scalar_type; + + const typename BlockCrsMatrixType::map_type row_map = *(A->getRowMap()); + const typename BlockCrsMatrixType::map_type col_map = *(A->getColMap()); + + int my_rank = row_map.getComm()->getRank(); + + if(A->getBlockSize() != 3) { + if (my_rank==0) std::cerr << "Error: A->getBlockSize != 3!" << std::endl; + return; + } + const int blocksize = 3; + + for (LO localrow = row_map.getMinLocalIndex(); + localrow <= row_map.getMaxLocalIndex(); + ++localrow) { + + const GO globalrow = row_map.getGlobalElement(localrow); + + if (globalrow == 0) { + + LO local_col_indices[1]; + local_col_indices[0] = col_map.getLocalElement(0); + + Scalar values[blocksize*blocksize]; + for (size_t b=0; breplaceLocalValues(localrow, + local_col_indices, + values, + 1); + } + else if (globalrow == 1) { + + LO local_col_indices[2]; + local_col_indices[0] = col_map.getLocalElement(0); + local_col_indices[1] = col_map.getLocalElement(1); + + Scalar values[2*blocksize*blocksize]; + for (GO globalcol=0; globalcol<2; ++globalcol) { + int start = globalcol*blocksize*blocksize; + for (size_t b=0; breplaceLocalValues(localrow, + local_col_indices, + values, + 2); + } else { + + LO local_col_indices[3]; + local_col_indices[0] = col_map.getLocalElement(globalrow-2); + local_col_indices[1] = col_map.getLocalElement(globalrow-1); + local_col_indices[2] = col_map.getLocalElement(globalrow); + + Scalar values[3*blocksize*blocksize]; + int local_indx = 0; + for (GO globalcol=globalrow-2; globalcol<=globalrow; ++globalcol) { + int start = local_indx*blocksize*blocksize; + for (size_t b=0; breplaceLocalValues(localrow, + local_col_indices, + values, + 3); + } + } + + return; + } + + // Test BlockCrsMatrix importAndFillComplete + TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL( BlockCrsMatrix, importAndFillComplete, Scalar, LO, GO, Node ) + { + using Tpetra::Details::gathervPrint; + typedef Tpetra::BlockCrsMatrix block_crs_type; + typedef Tpetra::CrsGraph crs_graph_type; + typedef Tpetra::Map map_type; + typedef Tpetra::Import import_type; + using Teuchos::REDUCE_MAX; + + std::ostringstream err; + int lclErr = 0; + int gblErr = 0; + + out << "Testing Tpetra::BlockCrsMatrix importAndFillComplete" << endl; + Teuchos::OSTab tab0 (out); + + RCP > comm = getDefaultComm (); + const int myRank = comm->getRank (); + const int numRanks = comm->getSize(); + const GST INVALID = Teuchos::OrdinalTraits::invalid (); + + out << "1st test: Import a diagonal BlockCrsMatrix from a source row Map " + "that has all indices on Process 0, to a target row Map that is " + "uniformly distributed over processes. Blocksize=3." << endl; + try { + Teuchos::OSTab tab1 (out); + + const GO indexBase = 0; + const LO tgt_num_local_elements = 2; + const LO src_num_local_elements = (myRank == 0) ? + static_cast (numRanks*tgt_num_local_elements) : + static_cast (0); + + const int blocksize = 3; + + // Create row Maps for the source and target + RCP src_map = + rcp (new map_type (INVALID, + src_num_local_elements, + indexBase, comm)); + RCP tgt_map = + rcp (new map_type (INVALID, + tgt_num_local_elements, + indexBase, comm)); + + // Build src graph. + Teuchos::RCP src_graph = + Teuchos::rcp (new crs_graph_type (src_map, 1)); + for (LO localrow = src_map->getMinLocalIndex(); + localrow<=src_map->getMaxLocalIndex(); + ++localrow) { + + const GO globalrow = src_map->getGlobalElement(localrow); + GO globalcol[1]; + globalcol[0] = globalrow; + + src_graph->insertGlobalIndices(globalrow, 1, globalcol); + } + src_graph->fillComplete(); + + // Build src matrix. Simple block diagonal matrix with A(b,b) = [b*b*row,...,+b*b]. + RCP src_mat = + rcp (new block_crs_type (*src_graph, blocksize)); + if (src_num_local_elements != 0) { + for (LO localrow = src_map->getMinLocalIndex(); + localrow <= src_map->getMaxLocalIndex(); + ++localrow) { + const GO globalrow = src_map->getGlobalElement(localrow); + LO col_indices[1]; Scalar values[blocksize*blocksize]; + col_indices[0] = localrow; + for (size_t b=0; breplaceLocalValues(localrow, + col_indices, + values, + 1); + TEST_EQUALITY_CONST(actual_num_replaces, 1); + } + } + + // Create the importer + import_type importer (src_map, tgt_map); + + // Call importAndFillComplete to get the tgt matrix + RCP tgt_mat = + Tpetra::importAndFillCompleteBlockCrsMatrix (src_mat, importer); + + // Manually build the tgt matrix and test that it matches the returned matrix + + // Build tgt graph. + Teuchos::RCP tgt_graph_for_testing = + Teuchos::rcp (new crs_graph_type (tgt_map, 1)); + for (LO localrow = tgt_map->getMinLocalIndex(); + localrow<=tgt_map->getMaxLocalIndex(); + ++localrow) { + + const GO globalrow = tgt_map->getGlobalElement(localrow); + GO globalcol[1]; + globalcol[0] = globalrow; + + tgt_graph_for_testing->insertGlobalIndices(globalrow, 1, globalcol); + } + tgt_graph_for_testing->fillComplete(); + + // Build tgt matrix + RCP tgt_mat_for_testing = + rcp (new block_crs_type (*tgt_graph_for_testing, blocksize)); + for (LO localrow = tgt_map->getMinLocalIndex(); + localrow <= tgt_map->getMaxLocalIndex(); + ++localrow) { + const GO globalrow = tgt_map->getGlobalElement(localrow); + LO col_indices[1]; Scalar values[blocksize*blocksize]; + col_indices[0] = localrow; + for (size_t b=0; breplaceLocalValues(localrow, + col_indices, + values, + 1); + TEST_EQUALITY_CONST(actual_num_replaces, 1); + } + + // Test that matrices are identical + bool matrices_match = matrices_are_same(tgt_mat, tgt_mat_for_testing); + TEST_ASSERT(matrices_match); + } + catch (std::exception& e) { // end of the first test + err << "Proc " << myRank << ": " << e.what () << endl; + lclErr = 1; + } + + reduceAll (*comm, REDUCE_MAX, lclErr, outArg (gblErr)); + TEST_EQUALITY_CONST( gblErr, 0 ); + if (gblErr != 0) { + Tpetra::Details::gathervPrint (out, err.str (), *comm); + out << "Above test failed; aborting further tests" << endl; + return; + } + + out << "2nd test: Import a lower triangular BlockCrsMatrix from a source row Map " + "where even processors have 1 element and odd processors have 3 elements, " + "to a target row Map where each processor have 2 elements. Blocksize=3." << endl; + try { + Teuchos::OSTab tab1 (out); + + // This test only makes sense for even number of ranks + if (numRanks % 2 != 0) { + return; + } + + const GO indexBase = 0; + LO src_num_local_elements; + if (myRank % 2 == 0) src_num_local_elements = 1; + else src_num_local_elements = 3; + LO tgt_num_local_elements = 2; + const int blocksize = 3; + + // Create row Maps for the source and target + RCP src_map = + rcp (new map_type (INVALID, + src_num_local_elements, + indexBase, comm)); + RCP tgt_map = + rcp (new map_type (INVALID, + tgt_num_local_elements, + indexBase, comm)); + //src_map->describe(out, Teuchos::VERB_EXTREME); + //tgt_map->describe(out, Teuchos::VERB_EXTREME); + + // Build src graph. Allow for up to 2 off-diagonal entries. + Teuchos::RCP src_graph = + Teuchos::rcp (new crs_graph_type (src_map, 3)); + { + Array cols(3); + for (GO globalrow = src_map->getMinGlobalIndex (); + globalrow <= src_map->getMaxGlobalIndex (); ++globalrow) { + if (globalrow==0) cols.resize(1); + else if (globalrow==1) cols.resize(2); + else cols.resize(3); + for (GO col = 0; col < cols.size(); ++col) { + cols[col] = globalrow - col; + } + src_graph->insertGlobalIndices (globalrow, cols()); + } + src_graph->fillComplete(); + //src_graph->describe(out, Teuchos::VERB_EXTREME); + } + + // Build src matrix. Simple block lower-diagonal matrix with + // A(b1,b2) = [(b1)+10*(b2+1)]. + RCP src_mat = + rcp (new block_crs_type (*src_graph, blocksize)); + build_lower_diag_matrix(src_mat); + //src_mat->describe(out, Teuchos::VERB_EXTREME); + + // Create the importer + import_type importer (src_map, tgt_map); + + // Call importAndFillComplete to get the tgt matrix + RCP tgt_mat = + Tpetra::importAndFillCompleteBlockCrsMatrix (src_mat, importer); + //tgt_mat->describe(out, Teuchos::VERB_EXTREME); + + // Manually build the tgt matrix and test that it matches the returned matrix + + // Build tgt graph. + Teuchos::RCP tgt_graph_for_testing = + Teuchos::rcp (new crs_graph_type (tgt_map, 3)); + { + Array cols(3); + for (GO globalrow = tgt_map->getMinGlobalIndex (); + globalrow <= tgt_map->getMaxGlobalIndex (); ++globalrow) { + if (globalrow==0) cols.resize(1); + else if (globalrow==1) cols.resize(2); + else cols.resize(3); + for (GO col = 0; col < cols.size(); ++col) { + cols[col] = globalrow - col; + } + tgt_graph_for_testing->insertGlobalIndices (globalrow, cols()); + } + tgt_graph_for_testing->fillComplete(); + //tgt_graph_for_testing->describe(out, Teuchos::VERB_EXTREME); + } + + // Build tgt matrix + RCP tgt_mat_for_testing = + rcp (new block_crs_type (*tgt_graph_for_testing, blocksize)); + build_lower_diag_matrix(tgt_mat_for_testing); + //tgt_mat_for_testing->describe(out, Teuchos::VERB_EXTREME); + + // Test that matrices are identical + bool matrices_match = matrices_are_same(tgt_mat, tgt_mat_for_testing); + TEST_ASSERT(matrices_match); + } + catch (std::exception& e) { // end of the first test + err << "Proc " << myRank << ": " << e.what () << endl; + lclErr = 1; + } + + reduceAll (*comm, REDUCE_MAX, lclErr, outArg (gblErr)); + TEST_EQUALITY_CONST( gblErr, 0 ); + if (gblErr != 0) { + Tpetra::Details::gathervPrint (out, err.str (), *comm); + out << "Above test failed; aborting further tests" << endl; + return; + } + } + // Test BlockCrsMatrix Export for different graphs with different // row Maps. This tests packAndPrepare and unpackAndCombine. TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL( BlockCrsMatrix, ExportDiffRowMaps, Scalar, LO, GO, Node ) @@ -2307,6 +2805,7 @@ namespace { TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( BlockCrsMatrix, getLocalDiagCopy, SCALAR, LO, GO, NODE ) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( BlockCrsMatrix, SetAllToScalar, SCALAR, LO, GO, NODE ) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( BlockCrsMatrix, ImportCopy, SCALAR, LO, GO, NODE ) \ + TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( BlockCrsMatrix, importAndFillComplete, SCALAR, LO, GO, NODE ) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( BlockCrsMatrix, ExportDiffRowMaps, SCALAR, LO, GO, NODE ) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( BlockCrsMatrix, point2block, SCALAR, LO, GO, NODE ) \ TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( BlockCrsMatrix, block2point, SCALAR, LO, GO, NODE ) diff --git a/packages/tpetra/core/test/ImportExport2/AsyncTransfer_UnitTests.cpp b/packages/tpetra/core/test/ImportExport2/AsyncTransfer_UnitTests.cpp index 6780d2459ee1..1dfad1f7ecd4 100644 --- a/packages/tpetra/core/test/ImportExport2/AsyncTransfer_UnitTests.cpp +++ b/packages/tpetra/core/test/ImportExport2/AsyncTransfer_UnitTests.cpp @@ -289,99 +289,6 @@ namespace { std::vector> targetMVs; }; - template - class MultiVectorCyclicGroupTransferFixture { - private: - using map_type = Map; - using mv_type = MultiVector; - - public: - MultiVectorCyclicGroupTransferFixture(FancyOStream& o, bool& s) - : out(o), - success(s), - comm(getDefaultComm()), - numProcs(comm->getSize()), - myRank(comm->getRank()), - numMVs(4) - { } - - ~MultiVectorCyclicGroupTransferFixture() { } - - void setup(int collectRank) { - setupMaps(collectRank); - setupMultiVectors(); - } - - template - void performTransfer(const TransferMethod& transfer) { - transfer(sourceMVs, targetMVs); - for (int i=0; iimportsAreAliased(), false); - } - } - - template - void checkResults(const ReferenceSolution& referenceSolution) { - for (int i=0; i referenceMV = referenceSolution.generateWithClassicalCodePath(sourceMVs[i], targetMap); - compareMultiVectors(targetMVs[i], referenceMV); - } - } - - private: - void setupMaps(int collectRank) { - const GO indexBase = 0; - const global_size_t INVALID = OrdinalTraits::invalid(); - - const size_t sourceNumLocalElements = 3; - const size_t totalElements = numProcs*sourceNumLocalElements; - const size_t targetNumLocalElements = (myRank == collectRank) ? totalElements : 0; - - Teuchos::Array sourceEntries(sourceNumLocalElements); - for (size_t i=0; irandomize(); - - targetMVs.push_back(rcp(new mv_type(targetMap, 1))); - targetMVs[i]->putScalar(ScalarTraits::zero()); - } - } - - void compareMultiVectors(RCP resultMV, RCP referenceMV) { - auto data = resultMV->getLocalViewHost(Tpetra::Access::ReadOnly); - auto referenceData = referenceMV->getLocalViewHost(Tpetra::Access::ReadOnly); - - TEST_EQUALITY(data.size(), referenceData.size()); - for (LO localRow = 0; localRow < as(data.size()); localRow++) { - TEST_EQUALITY(data(localRow, 0), referenceData(localRow, 0)); - } - } - - FancyOStream& out; - bool& success; - - RCP> comm; - const int numProcs; - const int myRank; - - const int numMVs; - - RCP sourceMap; - RCP targetMap; - - std::vector> sourceMVs; - std::vector> targetMVs; - }; - template class DiagonalCrsMatrixTransferFixture { private: @@ -1086,6 +993,34 @@ namespace { } + template + class ForwardImportGroup { + private: + using DistObjectRCP = RCP>; + + public: + void operator()(std::vector& sources, std::vector& targets) const { + Import importer(sources[0]->getMap(), targets[0]->getMap()); + + for (unsigned i=0; ibeginImport(*sources[i], importer, INSERT); + } + + unsigned completedImports = 0; + std::vector completedImport(sources.size(), false); + while (completedImports < completedImport.size()) { + for (unsigned i=0; itransferArrived()) { + targets[i]->endImport(*sources[i], importer, INSERT); + completedImport[i] = true; + completedImports++; + } + } + } + } + }; + template class ContiguousMaps { private: @@ -1122,48 +1057,61 @@ namespace { RCP targetMap; }; - template - class ForwardImportGroup { + TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL( AsyncForwardImport, MultiVectorGroup_ContiguousMaps_rank0, Scalar, LO, GO ) + { + MultiVectorGroupTransferFixture fixture(out, success); + + fixture.template setup>(0); + fixture.performTransfer(ForwardImportGroup()); + fixture.checkResults(ReferenceImportMultiVector()); + } + + template + class CyclicMaps { private: - using DistObjectRCP = RCP>; + using map_type = Map; public: - void operator()(std::vector& sources, std::vector& targets) const { - Import importer(sources[0]->getMap(), targets[0]->getMap()); + CyclicMaps(RCP> c) + : comm(c), + numProcs(comm->getSize()), + myRank(comm->getRank()) + { } - for (unsigned i=0; ibeginImport(*sources[i], importer, INSERT); - } + void setup(int collectRank) { + const GO indexBase = 0; + const global_size_t INVALID = OrdinalTraits::invalid(); - unsigned completedImports = 0; - std::vector completedImport(sources.size(), false); - while (completedImports < completedImport.size()) { - for (unsigned i=0; itransferArrived()) { - targets[i]->endImport(*sources[i], importer, INSERT); - completedImport[i] = true; - completedImports++; - } - } + const size_t sourceNumLocalElements = 3; + const size_t totalElements = numProcs*sourceNumLocalElements; + const size_t targetNumLocalElements = (myRank == collectRank) ? totalElements : 0; + + Teuchos::Array sourceEntries(sourceNumLocalElements); + for (size_t i=0; i fixture(out, success); + RCP getSourceMap() { return sourceMap; } + RCP getTargetMap() { return targetMap; } - fixture.template setup>(0); - fixture.performTransfer(ForwardImportGroup()); - fixture.checkResults(ReferenceImportMultiVector()); - } + private: + RCP> comm; + const int numProcs; + const int myRank; - TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL( AsyncForwardImport, MultiVectorCyclicGroup_rank0, Scalar, LO, GO ) + RCP sourceMap; + RCP targetMap; + }; + + TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL( AsyncForwardImport, MultiVectorGroup_CyclicMaps_rank0, Scalar, LO, GO ) { - MultiVectorCyclicGroupTransferFixture fixture(out, success); + MultiVectorGroupTransferFixture fixture(out, success); - fixture.setup(0); + fixture.template setup>(0); fixture.performTransfer(ForwardImportGroup()); fixture.checkResults(ReferenceImportMultiVector()); } @@ -1198,7 +1146,7 @@ namespace { TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( TransferArrived, MultiVector_forwardImportFalse, SC, LO, GO ) \ TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( TransferArrived, MultiVector_forwardExportFalse, SC, LO, GO ) \ TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( AsyncForwardImport, MultiVectorGroup_ContiguousMaps_rank0, SC, LO, GO ) \ - TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( AsyncForwardImport, MultiVectorCyclicGroup_rank0, SC, LO, GO ) \ + TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( AsyncForwardImport, MultiVectorGroup_CyclicMaps_rank0, SC, LO, GO ) \ TPETRA_ETI_MANGLING_TYPEDEFS() diff --git a/packages/tpetra/core/test/KokkosIntegration/CMakeLists.txt b/packages/tpetra/core/test/KokkosIntegration/CMakeLists.txt index ea0b7d2fb933..e97ec317b559 100644 --- a/packages/tpetra/core/test/KokkosIntegration/CMakeLists.txt +++ b/packages/tpetra/core/test/KokkosIntegration/CMakeLists.txt @@ -1,7 +1,9 @@ ASSERT_DEFINED (Tpetra_ENABLE_CUDA) -IF (Tpetra_ENABLE_CUDA AND BUILD_SHARED_LIBS) +ASSERT_DEFINED (Tpetra_ENABLE_KokkosIntegrationTest) + +IF (Tpetra_ENABLE_CUDA AND BUILD_SHARED_LIBS AND Tpetra_ENABLE_KokkosIntegrationTest) MESSAGE(STATUS "Tpetra: Enabling KokkosIntegration Tests") TRIBITS_ADD_LIBRARY( diff --git a/packages/xpetra/src/BlockedCrsMatrix/Xpetra_BlockedCrsMatrix.hpp b/packages/xpetra/src/BlockedCrsMatrix/Xpetra_BlockedCrsMatrix.hpp index bb9d38caf18b..4c7893e0734b 100644 --- a/packages/xpetra/src/BlockedCrsMatrix/Xpetra_BlockedCrsMatrix.hpp +++ b/packages/xpetra/src/BlockedCrsMatrix/Xpetra_BlockedCrsMatrix.hpp @@ -1520,6 +1520,9 @@ namespace Xpetra { return thbOp; } #endif + //! Returns the block size of the storage mechanism + LocalOrdinal GetStorageBlockSize() const {return 1;} + //! Compute a residual R = B - (*this) * X void residual(const MultiVector & X, diff --git a/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrix.hpp b/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrix.hpp index 407f1d45f31a..c0806e943ba8 100644 --- a/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrix.hpp +++ b/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrix.hpp @@ -352,6 +352,9 @@ namespace Xpetra { //! Does this have an underlying matrix virtual bool hasMatrix() const = 0; + //! Returns the block size of the storage mechanism, which is usually 1, except for Tpetra::BlockCrsMatrix + virtual LocalOrdinal GetStorageBlockSize() const = 0; + //! Compute a residual R = B - (*this) * X virtual void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & B, diff --git a/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrixFactory.hpp b/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrixFactory.hpp index d7ad7d410d61..9a81d816f02b 100644 --- a/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrixFactory.hpp +++ b/packages/xpetra/src/CrsMatrix/Xpetra_CrsMatrixFactory.hpp @@ -52,6 +52,7 @@ #ifdef HAVE_XPETRA_TPETRA #include "Xpetra_TpetraCrsMatrix.hpp" +#include "Xpetra_TpetraBlockCrsMatrix.hpp" #endif #ifdef HAVE_XPETRA_EPETRA @@ -289,6 +290,25 @@ namespace Xpetra { XPETRA_FACTORY_END; } #endif + + // Builds a BlockCrsMatrix + static RCP > BuildBlock ( + const Teuchos::RCP >& blockGraph, + const Teuchos::RCP >& domainMap, + const Teuchos::RCP >& rangeMap, + LocalOrdinal blockSize) { + + XPETRA_MONITOR("CrsMatrixFactory::BuildBlock"); + +#ifdef HAVE_XPETRA_TPETRA + if (domainMap->lib() == UseTpetra) { + return rcp(new Xpetra::TpetraBlockCrsMatrix(blockGraph,domainMap,rangeMap,blockSize) ); + } +#endif + TEUCHOS_TEST_FOR_EXCEPTION(domainMap->lib() == UseEpetra, std::logic_error, "Epetra doesn't support this matrix constructor"); + + XPETRA_FACTORY_END; + } }; @@ -318,10 +338,9 @@ namespace Xpetra { if (rowMap->lib() == UseTpetra) return rcp( new TpetraCrsMatrix(rowMap, 0) ); #endif -#ifdef HAVE_XPETRA_EPETRA if(rowMap->lib() == UseEpetra) return rcp( new EpetraCrsMatrixT(rowMap)); -#endif + XPETRA_FACTORY_END; } @@ -543,6 +562,23 @@ namespace Xpetra { } #endif + //! Build a BlockCrsMatrix + static RCP > BuildBlock ( + const Teuchos::RCP >& blockGraph, + const Teuchos::RCP >& domainMap, + const Teuchos::RCP >& rangeMap, + LocalOrdinal blockSize) { + + XPETRA_MONITOR("CrsMatrixFactory::BuildBlock"); +#ifdef HAVE_XPETRA_TPETRA + if (domainMap->lib() == UseTpetra) + return rcp(new Xpetra::TpetraBlockCrsMatrix(blockGraph,domainMap,rangeMap,blockSize) ); +#endif + TEUCHOS_TEST_FOR_EXCEPTION(domainMap->lib() == UseEpetra, std::logic_error, "Epetra doesn't support this matrix constructor"); + + XPETRA_FACTORY_END; + } + }; #endif @@ -772,6 +808,28 @@ namespace Xpetra { } #endif + + //! Build a BlockCrsMatrix + static RCP > BuildBlock ( + const Teuchos::RCP >& blockGraph, + const Teuchos::RCP >& domainMap, + const Teuchos::RCP >& rangeMap, + LocalOrdinal blockSize) { + + XPETRA_MONITOR("CrsMatrixFactory::BuildBlock"); + +#ifdef HAVE_XPETRA_TPETRA + if (domainMap->lib() == UseTpetra) { + return rcp(new Xpetra::TpetraBlockCrsMatrix(blockGraph,domainMap,rangemap,blockSize) ); + } +#endif + TEUCHOS_TEST_FOR_EXCEPTION(domainMap->lib() == UseEpetra, std::logic_error, "Epetra doesn't support this matrix constructor"); + + XPETRA_FACTORY_END; + } + + + }; #endif diff --git a/packages/xpetra/src/CrsMatrix/Xpetra_EpetraCrsMatrix.hpp b/packages/xpetra/src/CrsMatrix/Xpetra_EpetraCrsMatrix.hpp index 9b9321049333..42dc9b4d390a 100644 --- a/packages/xpetra/src/CrsMatrix/Xpetra_EpetraCrsMatrix.hpp +++ b/packages/xpetra/src/CrsMatrix/Xpetra_EpetraCrsMatrix.hpp @@ -258,6 +258,9 @@ local_matrix_type getLocalMatrixDevice () const { TEUCHOS_TEST_FOR_EXCEPTION(true, Xpetra::Exceptions::RuntimeError, "Xpetra::EpetraCrsMatrix only available for GO=int or GO=long long with EpetraNode (Serial or OpenMP depending on configuration)"); } + + LocalOrdinal GetStorageBlockSize() const {return 1;} + #else #ifdef __GNUC__ #warning "Xpetra Kokkos interface for CrsMatrix is enabled (HAVE_XPETRA_KOKKOS_REFACTOR) but Tpetra is disabled. The Kokkos interface needs Tpetra to be enabled, too." @@ -265,6 +268,7 @@ local_matrix_type getLocalMatrixDevice () const { #endif #endif + void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & B, MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & R) const{ @@ -1305,6 +1309,7 @@ typename local_matrix_type::HostMirror getLocalMatrixHost () const { } + LocalOrdinal GetStorageBlockSize() const {return 1;} private: #else @@ -1315,6 +1320,8 @@ typename local_matrix_type::HostMirror getLocalMatrixHost () const { #endif //@} + + void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & B, MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & R) const { @@ -2348,6 +2355,8 @@ class EpetraCrsMatrixT } + + LocalOrdinal GetStorageBlockSize() const {return 1;} private: #else @@ -2356,6 +2365,8 @@ class EpetraCrsMatrixT #endif #endif #endif + + //@} void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, diff --git a/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_decl.hpp b/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_decl.hpp index 1db4e9d5bd3f..2dc8bccf794b 100644 --- a/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_decl.hpp +++ b/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_decl.hpp @@ -113,6 +113,12 @@ namespace Xpetra { TpetraBlockCrsMatrix(const Teuchos::RCP< const CrsGraph< LocalOrdinal, GlobalOrdinal, Node> > &graph, const LocalOrdinal blockSize); + //! Constructor specifying a previously constructed graph, point maps & blocksize + TpetraBlockCrsMatrix(const Teuchos::RCP< const CrsGraph< LocalOrdinal, GlobalOrdinal, Node> > &graph, + const Teuchos::RCP >& pointDomainMap, + const Teuchos::RCP >& pointRangeMap, + const LocalOrdinal blockSize); + //! Constructor for a fused import ( not implemented ) TpetraBlockCrsMatrix(const Teuchos::RCP >& sourceMatrix, @@ -410,6 +416,9 @@ namespace Xpetra { #endif // HAVE_XPETRA_TPETRA #endif // HAVE_XPETRA_KOKKOS_REFACTOR + //! Returns the block size of the storage mechanism + LocalOrdinal GetStorageBlockSize() const {return mtx_->getBlockSize();} + //! Compute a residual R = B - (*this) * X void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & B, diff --git a/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_def.hpp b/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_def.hpp index 410dbddfab76..fd069dd54356 100644 --- a/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_def.hpp +++ b/packages/xpetra/src/CrsMatrix/Xpetra_TpetraBlockCrsMatrix_def.hpp @@ -46,7 +46,9 @@ #ifndef XPETRA_TPETRABLOCKCRSMATRIX_DEF_HPP #define XPETRA_TPETRABLOCKCRSMATRIX_DEF_HPP + #include "Xpetra_TpetraBlockCrsMatrix_decl.hpp" +#include "Xpetra_TpetraCrsGraph.hpp" namespace Xpetra { @@ -118,6 +120,17 @@ namespace Xpetra { { } + //! Constructor specifying a previously constructed graph, point maps & blocksize + template + TpetraBlockCrsMatrix:: + TpetraBlockCrsMatrix(const Teuchos::RCP< const CrsGraph< LocalOrdinal, GlobalOrdinal, Node> > &graph, + const Teuchos::RCP >& pointDomainMap, + const Teuchos::RCP >& pointRangeMap, + const LocalOrdinal blockSize) + : mtx_(Teuchos::rcp(new Tpetra::BlockCrsMatrix(*toTpetra(graph), *toTpetra(pointDomainMap), *toTpetra(pointRangeMap),blockSize))) + { } + + //! Constructor for a fused import ( not implemented ) template TpetraBlockCrsMatrix:: @@ -377,7 +390,12 @@ namespace Xpetra { TpetraBlockCrsMatrix:: getCrsGraph() const { - throw std::runtime_error("Xpetra::TpetraBlockCrsMatrix function not implemented in "+std::string(__FILE__)+":"+std::to_string(__LINE__)); + XPETRA_MONITOR("TpetraBlockCrsMatrix::getCrsGraph"); + using G_t = Tpetra::CrsGraph; + using G_x = TpetraCrsGraph; + RCP t_graph = Teuchos::rcp_const_cast(Teuchos::rcpFromRef(mtx_->getCrsGraph())); + RCP x_graph = rcp(new G_x(t_graph)); + return x_graph; } diff --git a/packages/xpetra/src/CrsMatrix/Xpetra_TpetraCrsMatrix_decl.hpp b/packages/xpetra/src/CrsMatrix/Xpetra_TpetraCrsMatrix_decl.hpp index d5e94fee5ade..b6fa11b68eae 100644 --- a/packages/xpetra/src/CrsMatrix/Xpetra_TpetraCrsMatrix_decl.hpp +++ b/packages/xpetra/src/CrsMatrix/Xpetra_TpetraCrsMatrix_decl.hpp @@ -453,6 +453,9 @@ namespace Xpetra { #endif #endif + //! Returns the block size of the storage mechanism, which is usually 1, except for Tpetra::BlockCrsMatrix + LocalOrdinal GetStorageBlockSize() const {return 1;} + //! Compute a residual R = B - (*this) * X void residual(const MultiVector & X, const MultiVector & B, @@ -860,6 +863,9 @@ namespace Xpetra { #endif #endif + //! Returns the block size of the storage mechanism, which is usually 1, except for Tpetra::BlockCrsMatrix + LocalOrdinal GetStorageBlockSize() const {return 1;} + //! Compute a residual R = B - (*this) * X void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & B, @@ -1263,6 +1269,9 @@ namespace Xpetra { #endif #endif + //! Returns the block size of the storage mechanism, which is usually 1, except for Tpetra::BlockCrsMatrix + LocalOrdinal GetStorageBlockSize() const {return 1;} + //! Compute a residual R = B - (*this) * X void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & B, diff --git a/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_decl.hpp b/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_decl.hpp index c30d1a00991f..415acfa6c940 100644 --- a/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_decl.hpp +++ b/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_decl.hpp @@ -480,12 +480,18 @@ class CrsMatrixWrap : RCP getCrsMatrix() const; + //! Returns the block size of the storage mechanism, which is usually 1, except for Tpetra::BlockCrsMatrix + LocalOrdinal GetStorageBlockSize() const; + //! Compute a residual R = B - (*this) * X void residual(const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & B, MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & R) const; + //! Expert only + void replaceCrsMatrix(RCP & M); + //@} private: diff --git a/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_def.hpp b/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_def.hpp index c3914921aa83..8a1d48d3b181 100644 --- a/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_def.hpp +++ b/packages/xpetra/sup/Matrix/Xpetra_CrsMatrixWrap_def.hpp @@ -497,6 +497,28 @@ namespace Xpetra { } + // Expert only + template + void CrsMatrixWrap::replaceCrsMatrix(RCP & M) { + // Clear the old view table + Teuchos::Hashtable > dummy_table; + Matrix::operatorViewTable_ = dummy_table; + + finalDefaultView_ = M->isFillComplete(); + // Set matrix data + matrixData_ = M; + + + // Default view + CreateDefaultView(); + } + + + template + LocalOrdinal CrsMatrixWrap::GetStorageBlockSize() const { + return matrixData_->GetStorageBlockSize(); + } + template void CrsMatrixWrap::residual( const MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > & X, diff --git a/packages/xpetra/sup/Matrix/Xpetra_Matrix.hpp b/packages/xpetra/sup/Matrix/Xpetra_Matrix.hpp index b7f9432f054f..69c11f74ccec 100644 --- a/packages/xpetra/sup/Matrix/Xpetra_Matrix.hpp +++ b/packages/xpetra/sup/Matrix/Xpetra_Matrix.hpp @@ -559,11 +559,17 @@ namespace Xpetra { return 1; }; //TODO: why LocalOrdinal? + //! Returns true, if `SetFixedBlockSize` has been called before. bool IsFixedBlockSizeSet() const { return IsView("stridedMaps"); }; + + //! Returns the block size of the storage mechanism, which is usually 1, except for Tpetra::BlockCrsMatrix + virtual LocalOrdinal GetStorageBlockSize() const = 0; + + // ---------------------------------------------------------------------------------- virtual void SetMaxEigenvalueEstimate(Scalar const &sigma) { diff --git a/packages/xpetra/sup/Utils/Xpetra_IO.hpp b/packages/xpetra/sup/Utils/Xpetra_IO.hpp index 364b6743fdb4..da863a390f23 100644 --- a/packages/xpetra/sup/Utils/Xpetra_IO.hpp +++ b/packages/xpetra/sup/Utils/Xpetra_IO.hpp @@ -313,6 +313,15 @@ namespace Xpetra { Tpetra::MatrixMarket::Writer >::writeSparseFile(fileName, A); return; } + const RCP >& tmp_BlockCrs = + Teuchos::rcp_dynamic_cast >(tmp_CrsMtx); + if(tmp_BlockCrs != Teuchos::null) { + std::ofstream outstream (fileName,std::ofstream::out); + Teuchos::FancyOStream ofs(Teuchos::rcpFromRef(outstream)); + tmp_BlockCrs->getTpetra_BlockCrsMatrix()->describe(ofs,Teuchos::VERB_EXTREME); + return; + } + #endif // HAVE_XPETRA_TPETRA throw Exceptions::BadCast("Could not cast to EpetraCrsMatrix or TpetraCrsMatrix in matrix writing"); @@ -1037,6 +1046,15 @@ namespace Xpetra { Tpetra::MatrixMarket::Writer >::writeSparseFile(fileName, A); return; } + const RCP >& tmp_BlockCrs = + Teuchos::rcp_dynamic_cast >(tmp_CrsMtx); + if(tmp_BlockCrs != Teuchos::null) { + std::ofstream outstream (fileName,std::ofstream::out); + Teuchos::FancyOStream ofs(Teuchos::rcpFromRef(outstream)); + tmp_BlockCrs->getTpetra_BlockCrsMatrix()->describe(ofs,Teuchos::VERB_EXTREME); + return; + } + # endif #endif // HAVE_XPETRA_TPETRA diff --git a/packages/xpetra/sup/Utils/Xpetra_MatrixMatrix.hpp b/packages/xpetra/sup/Utils/Xpetra_MatrixMatrix.hpp index 5fad0246aafe..217c590bf912 100644 --- a/packages/xpetra/sup/Utils/Xpetra_MatrixMatrix.hpp +++ b/packages/xpetra/sup/Utils/Xpetra_MatrixMatrix.hpp @@ -75,6 +75,7 @@ #include #include #include +#include #include #include #endif // HAVE_XPETRA_TPETRA @@ -254,25 +255,25 @@ Note: this class is not in the Xpetra_UseShortNames.hpp return tmp_BlockCrs->getTpetra_BlockCrsMatrixNonConst(); } - static RCP > Op2TpetraBlockCrs(const Matrix& Op) { + static const Tpetra::BlockCrsMatrix & Op2TpetraBlockCrs(const Matrix& Op) { try { const CrsMatrixWrap& crsOp = dynamic_cast(Op); RCP tmp_CrsMtx = crsOp.getCrsMatrix(); RCP tmp_BlockCrs= Teuchos::rcp_dynamic_cast(tmp_CrsMtx); TEUCHOS_TEST_FOR_EXCEPTION(tmp_BlockCrs == Teuchos::null, Xpetra::Exceptions::BadCast, "Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); - return tmp_BlockCrs->getTpetra_BlockCrsMatrix(); + return *tmp_BlockCrs->getTpetra_BlockCrsMatrix(); } catch(...) { throw(Xpetra::Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed")); } } - static RCP > Op2NonTpetraBlockCrs(const Matrix& Op) { + static Tpetra::BlockCrsMatrix & Op2NonConstTpetraBlockCrs(const Matrix& Op) { try { const CrsMatrixWrap& crsOp = dynamic_cast(Op); RCP tmp_CrsMtx = crsOp.getCrsMatrix(); RCP tmp_BlockCrs= Teuchos::rcp_dynamic_cast(tmp_CrsMtx); TEUCHOS_TEST_FOR_EXCEPTION(tmp_BlockCrs == Teuchos::null, Xpetra::Exceptions::BadCast, "Cast from Xpetra::CrsMatrix to Xpetra::TpetraBlockCrsMatrix failed"); - return tmp_BlockCrs->getTpetra_BlockCrsMatrixNonConst(); + return *tmp_BlockCrs->getTpetra_BlockCrsMatrixNonConst(); } catch(...) { throw(Xpetra::Exceptions::BadCast("Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed")); } @@ -298,7 +299,14 @@ Note: this class is not in the Xpetra_UseShortNames.hpp return false; } } +#else // HAVE_XPETRA_TPETRA + static bool isTpetraCrs(const Matrix& Op) { + return false; + } + static bool isTpetraBlockCrs(const Matrix& Op) { + return false; + } #endif // HAVE_XPETRA_TPETRA @@ -491,9 +499,38 @@ Note: this class is not in the Xpetra_UseShortNames.hpp // Previously, Tpetra's matrix matrix multiply did not support fillComplete. Tpetra::MatrixMatrix::Multiply(tpA, transposeA, tpB, transposeB, tpC, haveMultiplyDoFillComplete, label, params); } - else if (helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(B) && helpers::isTpetraBlockCrs(C)) { - // All matrices are BlockCrs - TEUCHOS_TEST_FOR_EXCEPTION(1, Exceptions::RuntimeError, "BlockCrs Multiply not currently supported"); + else if (helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(B)) { + // All matrices are BlockCrs (except maybe Ac) + // FIXME: For the moment we're just going to clobber the innards of Ac, so no reuse. Once we have a reuse kernel, + // we'll need to think about refactoring BlockCrs so we can do something smarter here. + if(!A.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Using inefficient BlockCrs Multiply Placeholder"< & tpA = Xpetra::Helpers::Op2TpetraBlockCrs(A); + const Tpetra::BlockCrsMatrix & tpB = Xpetra::Helpers::Op2TpetraBlockCrs(B); + using CRS=Tpetra::CrsMatrix; + RCP Acrs = Tpetra::convertToCrsMatrix(tpA); + RCP Bcrs = Tpetra::convertToCrsMatrix(tpB); + + // We need the global constants to do the copy back to BlockCrs + RCP new_params; + if(!params.is_null()) { + new_params = rcp(new Teuchos::ParameterList(*params)); + new_params->set("compute global constants",true); + } + + // FIXME: The lines below only works because we're assuming Ac is Point + RCP tempAc = Teuchos::rcp(new CRS(Acrs->getRowMap(),0)); + Tpetra::MatrixMatrix::Multiply(*Acrs, transposeA, *Bcrs, transposeB, *tempAc, haveMultiplyDoFillComplete, label, new_params); + + // Temporary output matrix + RCP > Ac_t = Tpetra::convertToBlockCrsMatrix(*tempAc,A.GetStorageBlockSize()); + RCP > Ac_x = Teuchos::rcp(new Xpetra::TpetraBlockCrsMatrix(Ac_t)); + RCP > Ac_p = Ac_x; + + // We can now cheat and replace the innards of Ac + RCP > Ac_w = Teuchos::rcp_dynamic_cast >(Teuchos::rcpFromRef(C)); + Ac_w->replaceCrsMatrix(Ac_p); } else { // Mix and match @@ -1022,9 +1059,39 @@ Note: this class is not in the Xpetra_UseShortNames.hpp // Previously, Tpetra's matrix matrix multiply did not support fillComplete. Tpetra::MatrixMatrix::Multiply(tpA, transposeA, tpB, transposeB, tpC, haveMultiplyDoFillComplete, label, params); } - else if (helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(B) && helpers::isTpetraBlockCrs(C)) { - // All matrices are BlockCrs - TEUCHOS_TEST_FOR_EXCEPTION(1, Exceptions::RuntimeError, "BlockCrs Multiply not currently supported"); + else if (helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(B)) { + // All matrices are BlockCrs (except maybe Ac) + // FIXME: For the moment we're just going to clobber the innards of Ac, so no reuse. Once we have a reuse kernel, + // we'll need to think about refactoring BlockCrs so we can do something smarter here. + + if(!A.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Using inefficient BlockCrs Multiply Placeholder"< & tpA = Xpetra::Helpers::Op2TpetraBlockCrs(A); + const Tpetra::BlockCrsMatrix & tpB = Xpetra::Helpers::Op2TpetraBlockCrs(B); + using CRS=Tpetra::CrsMatrix; + RCP Acrs = Tpetra::convertToCrsMatrix(tpA); + RCP Bcrs = Tpetra::convertToCrsMatrix(tpB); + + // We need the global constants to do the copy back to BlockCrs + RCP new_params; + if(!params.is_null()) { + new_params = rcp(new Teuchos::ParameterList(*params)); + new_params->set("compute global constants",true); + } + + // FIXME: The lines below only works because we're assuming Ac is Point + RCP tempAc = Teuchos::rcp(new CRS(Acrs->getRowMap(),0)); + Tpetra::MatrixMatrix::Multiply(*Acrs, transposeA, *Bcrs, transposeB, *tempAc, haveMultiplyDoFillComplete, label, new_params); + + // Temporary output matrix + RCP > Ac_t = Tpetra::convertToBlockCrsMatrix(*tempAc,A.GetStorageBlockSize()); + RCP > Ac_x = Teuchos::rcp(new Xpetra::TpetraBlockCrsMatrix(Ac_t)); + RCP > Ac_p = Ac_x; + + // We can now cheat and replace the innards of Ac + RCP > Ac_w = Teuchos::rcp_dynamic_cast >(Teuchos::rcpFromRef(C)); + Ac_w->replaceCrsMatrix(Ac_p); } else { // Mix and match @@ -1787,9 +1854,39 @@ Note: this class is not in the Xpetra_UseShortNames.hpp // Previously, Tpetra's matrix matrix multiply did not support fillComplete. Tpetra::MatrixMatrix::Multiply(tpA, transposeA, tpB, transposeB, tpC, haveMultiplyDoFillComplete, label, params); } - else if (helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(B) && helpers::isTpetraBlockCrs(C)) { - // All matrices are BlockCrs - TEUCHOS_TEST_FOR_EXCEPTION(1, Exceptions::RuntimeError, "BlockCrs Multiply not currently supported"); + else if (helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(B)) { + // All matrices are BlockCrs (except maybe Ac) + // FIXME: For the moment we're just going to clobber the innards of Ac, so no reuse. Once we have a reuse kernel, + // we'll need to think about refactoring BlockCrs so we can do something smarter here. + + if(!A.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Using inefficient BlockCrs Multiply Placeholder"< & tpA = Xpetra::Helpers::Op2TpetraBlockCrs(A); + const Tpetra::BlockCrsMatrix & tpB = Xpetra::Helpers::Op2TpetraBlockCrs(B); + using CRS=Tpetra::CrsMatrix; + RCP Acrs = Tpetra::convertToCrsMatrix(tpA); + RCP Bcrs = Tpetra::convertToCrsMatrix(tpB); + + // We need the global constants to do the copy back to BlockCrs + RCP new_params; + if(!params.is_null()) { + new_params = rcp(new Teuchos::ParameterList(*params)); + new_params->set("compute global constants",true); + } + + // FIXME: The lines below only works because we're assuming Ac is Point + RCP tempAc = Teuchos::rcp(new CRS(Acrs->getRowMap(),0)); + Tpetra::MatrixMatrix::Multiply(*Acrs, transposeA, *Bcrs, transposeB, *tempAc, haveMultiplyDoFillComplete, label, new_params); + + // Temporary output matrix + RCP > Ac_t = Tpetra::convertToBlockCrsMatrix(*tempAc,A.GetStorageBlockSize()); + RCP > Ac_x = Teuchos::rcp(new Xpetra::TpetraBlockCrsMatrix(Ac_t)); + RCP > Ac_p = Ac_x; + + // We can now cheat and replace the innards of Ac + RCP > Ac_w = Teuchos::rcp_dynamic_cast >(Teuchos::rcpFromRef(C)); + Ac_w->replaceCrsMatrix(Ac_p); } else { // Mix and match diff --git a/packages/xpetra/sup/Utils/Xpetra_TripleMatrixMultiply.hpp b/packages/xpetra/sup/Utils/Xpetra_TripleMatrixMultiply.hpp index 3b5798fc62ce..d930088f3908 100644 --- a/packages/xpetra/sup/Utils/Xpetra_TripleMatrixMultiply.hpp +++ b/packages/xpetra/sup/Utils/Xpetra_TripleMatrixMultiply.hpp @@ -58,10 +58,13 @@ #include "Xpetra_Matrix.hpp" #include "Xpetra_StridedMapFactory.hpp" #include "Xpetra_StridedMap.hpp" +#include "Xpetra_IO.hpp" #ifdef HAVE_XPETRA_TPETRA #include #include +#include +#include // #include // #include #endif // HAVE_XPETRA_TPETRA @@ -126,14 +129,55 @@ namespace Xpetra { throw(Xpetra::Exceptions::RuntimeError("Xpetra::TripleMatrixMultiply::MultiplyRAP is only implemented for Tpetra")); } else if (Ac.getRowMap()->lib() == Xpetra::UseTpetra) { #ifdef HAVE_XPETRA_TPETRA - const Tpetra::CrsMatrix & tpR = Xpetra::Helpers::Op2TpetraCrs(R); - const Tpetra::CrsMatrix & tpA = Xpetra::Helpers::Op2TpetraCrs(A); - const Tpetra::CrsMatrix & tpP = Xpetra::Helpers::Op2TpetraCrs(P); - Tpetra::CrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraCrs(Ac); - - // 18Feb2013 JJH I'm reenabling the code that allows the matrix matrix multiply to do the fillComplete. - // Previously, Tpetra's matrix matrix multiply did not support fillComplete. - Tpetra::TripleMatrixMultiply::MultiplyRAP(tpR, transposeR, tpA, transposeA, tpP, transposeP, tpAc, haveMultiplyDoFillComplete, label, params); + using helpers = Xpetra::Helpers; + if(helpers::isTpetraCrs(R) && helpers::isTpetraCrs(A) && helpers::isTpetraCrs(P)) { + // All matrices are Crs + const Tpetra::CrsMatrix & tpR = Xpetra::Helpers::Op2TpetraCrs(R); + const Tpetra::CrsMatrix & tpA = Xpetra::Helpers::Op2TpetraCrs(A); + const Tpetra::CrsMatrix & tpP = Xpetra::Helpers::Op2TpetraCrs(P); + Tpetra::CrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraCrs(Ac); + + // 18Feb2013 JJH I'm reenabling the code that allows the matrix matrix multiply to do the fillComplete. + // Previously, Tpetra's matrix matrix multiply did not support fillComplete. + Tpetra::TripleMatrixMultiply::MultiplyRAP(tpR, transposeR, tpA, transposeA, tpP, transposeP, tpAc, haveMultiplyDoFillComplete, label, params); + } + else if (helpers::isTpetraBlockCrs(R) && helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(P)) { + // All matrices are BlockCrs (except maybe Ac) + // FIXME: For the moment we're just going to clobber the innards of Ac, so no reuse. Once we have a reuse kernel, + // we'll need to think about refactoring BlockCrs so we can do something smarter here. + + if(!A.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Using inefficient BlockCrs Multiply Placeholder"< & tpR = Xpetra::Helpers::Op2TpetraBlockCrs(R); + const Tpetra::BlockCrsMatrix & tpA = Xpetra::Helpers::Op2TpetraBlockCrs(A); + const Tpetra::BlockCrsMatrix & tpP = Xpetra::Helpers::Op2TpetraBlockCrs(P); + // Tpetra::BlockCrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraBlockCrs(Ac); + + using CRS=Tpetra::CrsMatrix; + RCP Rcrs = Tpetra::convertToCrsMatrix(tpR); + RCP Acrs = Tpetra::convertToCrsMatrix(tpA); + RCP Pcrs = Tpetra::convertToCrsMatrix(tpP); + // RCP Accrs = Tpetra::convertToCrsMatrix(tpAc); + + // FIXME: The lines below only works because we're assuming Ac is Point + RCP Accrs = Teuchos::rcp(new CRS(Rcrs->getRowMap(),0)); + const bool do_fill_complete=true; + Tpetra::TripleMatrixMultiply::MultiplyRAP(*Rcrs, transposeR, *Acrs, transposeA, *Pcrs, transposeP, *Accrs, do_fill_complete, label, params); + + // Temporary output matrix + RCP > Ac_t = Tpetra::convertToBlockCrsMatrix(*Accrs,A.GetStorageBlockSize()); + RCP > Ac_x = Teuchos::rcp(new Xpetra::TpetraBlockCrsMatrix(Ac_t)); + RCP > Ac_p = Ac_x; + + // We can now cheat and replace the innards of Ac + RCP > Ac_w = Teuchos::rcp_dynamic_cast >(Teuchos::rcpFromRef(Ac)); + Ac_w->replaceCrsMatrix(Ac_p); + } + else { + // Mix and match + TEUCHOS_TEST_FOR_EXCEPTION(1, Exceptions::RuntimeError, "Mix-and-match Crs/BlockCrs Multiply not currently supported"); + } #else throw(Xpetra::Exceptions::RuntimeError("Xpetra must be compiled with Tpetra.")); #endif @@ -215,14 +259,55 @@ namespace Xpetra { (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_INT)))) throw(Xpetra::Exceptions::RuntimeError("Xpetra must be compiled with Tpetra ETI enabled.")); # else - const Tpetra::CrsMatrix & tpR = Xpetra::Helpers::Op2TpetraCrs(R); - const Tpetra::CrsMatrix & tpA = Xpetra::Helpers::Op2TpetraCrs(A); - const Tpetra::CrsMatrix & tpP = Xpetra::Helpers::Op2TpetraCrs(P); - Tpetra::CrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraCrs(Ac); - - // 18Feb2013 JJH I'm reenabling the code that allows the matrix matrix multiply to do the fillComplete. - // Previously, Tpetra's matrix matrix multiply did not support fillComplete. - Tpetra::TripleMatrixMultiply::MultiplyRAP(tpR, transposeR, tpA, transposeA, tpP, transposeP, tpAc, haveMultiplyDoFillComplete, label, params); + using helpers = Xpetra::Helpers; + if(helpers::isTpetraCrs(R) && helpers::isTpetraCrs(A) && helpers::isTpetraCrs(P) && helpers::isTpetraCrs(Ac)) { + // All matrices are Crs + const Tpetra::CrsMatrix & tpR = Xpetra::Helpers::Op2TpetraCrs(R); + const Tpetra::CrsMatrix & tpA = Xpetra::Helpers::Op2TpetraCrs(A); + const Tpetra::CrsMatrix & tpP = Xpetra::Helpers::Op2TpetraCrs(P); + Tpetra::CrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraCrs(Ac); + + // 18Feb2013 JJH I'm reenabling the code that allows the matrix matrix multiply to do the fillComplete. + // Previously, Tpetra's matrix matrix multiply did not support fillComplete. + Tpetra::TripleMatrixMultiply::MultiplyRAP(tpR, transposeR, tpA, transposeA, tpP, transposeP, tpAc, haveMultiplyDoFillComplete, label, params); + } + else if (helpers::isTpetraBlockCrs(R) && helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(P)) { + // All matrices are BlockCrs (except maybe Ac) + // FIXME: For the moment we're just going to clobber the innards of AC, so no reuse. Once we have a reuse kernel, + // we'll need to think about refactoring BlockCrs so we can do something smarter here. + if(!A.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Using inefficient BlockCrs Multiply Placeholder"< & tpR = Xpetra::Helpers::Op2TpetraBlockCrs(R); + const Tpetra::BlockCrsMatrix & tpA = Xpetra::Helpers::Op2TpetraBlockCrs(A); + const Tpetra::BlockCrsMatrix & tpP = Xpetra::Helpers::Op2TpetraBlockCrs(P); + // Tpetra::BlockCrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraBlockCrs(Ac); + + using CRS=Tpetra::CrsMatrix; + RCP Rcrs = Tpetra::convertToCrsMatrix(tpR); + RCP Acrs = Tpetra::convertToCrsMatrix(tpA); + RCP Pcrs = Tpetra::convertToCrsMatrix(tpP); + // RCP Accrs = Tpetra::convertToCrsMatrix(tpAc); + + // FIXME: The lines below only works because we're assuming Ac is Point + RCP Accrs = Teuchos::rcp(new CRS(Rcrs->getRowMap(),0)); + const bool do_fill_complete=true; + Tpetra::TripleMatrixMultiply::MultiplyRAP(*Rcrs, transposeR, *Acrs, transposeA, *Pcrs, transposeP, *Accrs, do_fill_complete, label, params); + + // Temporary output matrix + RCP > Ac_t = Tpetra::convertToBlockCrsMatrix(*Accrs,A.GetStorageBlockSize()); + RCP > Ac_x = Teuchos::rcp(new Xpetra::TpetraBlockCrsMatrix(Ac_t)); + RCP > Ac_p = Ac_x; + + // We can now cheat and replace the innards of Ac + RCP > Ac_w = Teuchos::rcp_dynamic_cast >(Teuchos::rcpFromRef(Ac)); + Ac_w->replaceCrsMatrix(Ac_p); + + } + else { + // Mix and match (not supported) + TEUCHOS_TEST_FOR_EXCEPTION(1, Exceptions::RuntimeError, "Mix-and-match Crs/BlockCrs Multiply not currently supported"); + } # endif #else throw(Xpetra::Exceptions::RuntimeError("Xpetra must be compiled with Tpetra.")); @@ -303,14 +388,55 @@ namespace Xpetra { (!defined(EPETRA_HAVE_OMP) && (!defined(HAVE_TPETRA_INST_SERIAL) || !defined(HAVE_TPETRA_INST_INT_LONG_LONG)))) throw(Xpetra::Exceptions::RuntimeError("Xpetra must be compiled with Tpetra ETI enabled.")); # else - const Tpetra::CrsMatrix & tpR = Xpetra::Helpers::Op2TpetraCrs(R); - const Tpetra::CrsMatrix & tpA = Xpetra::Helpers::Op2TpetraCrs(A); - const Tpetra::CrsMatrix & tpP = Xpetra::Helpers::Op2TpetraCrs(P); - Tpetra::CrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraCrs(Ac); - - // 18Feb2013 JJH I'm reenabling the code that allows the matrix matrix multiply to do the fillComplete. - // Previously, Tpetra's matrix matrix multiply did not support fillComplete. - Tpetra::TripleMatrixMultiply::MultiplyRAP(tpR, transposeR, tpA, transposeA, tpP, transposeP, tpAc, haveMultiplyDoFillComplete, label, params); + using helpers = Xpetra::Helpers; + if(helpers::isTpetraCrs(R) && helpers::isTpetraCrs(A) && helpers::isTpetraCrs(P)) { + // All matrices are Crs + const Tpetra::CrsMatrix & tpR = Xpetra::Helpers::Op2TpetraCrs(R); + const Tpetra::CrsMatrix & tpA = Xpetra::Helpers::Op2TpetraCrs(A); + const Tpetra::CrsMatrix & tpP = Xpetra::Helpers::Op2TpetraCrs(P); + Tpetra::CrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraCrs(Ac); + + // 18Feb2013 JJH I'm reenabling the code that allows the matrix matrix multiply to do the fillComplete. + // Previously, Tpetra's matrix matrix multiply did not support fillComplete. + Tpetra::TripleMatrixMultiply::MultiplyRAP(tpR, transposeR, tpA, transposeA, tpP, transposeP, tpAc, haveMultiplyDoFillComplete, label, params); + } + else if (helpers::isTpetraBlockCrs(R) && helpers::isTpetraBlockCrs(A) && helpers::isTpetraBlockCrs(P)) { + // All matrices are BlockCrs (except maybe Ac) + // FIXME: For the moment we're just going to clobber the innards of AC, so no reuse. Once we have a reuse kernel, + // we'll need to think about refactoring BlockCrs so we can do something smarter here. + if(!A.getRowMap()->getComm()->getRank()) + std::cout<<"WARNING: Using inefficient BlockCrs Multiply Placeholder"< & tpR = Xpetra::Helpers::Op2TpetraBlockCrs(R); + const Tpetra::BlockCrsMatrix & tpA = Xpetra::Helpers::Op2TpetraBlockCrs(A); + const Tpetra::BlockCrsMatrix & tpP = Xpetra::Helpers::Op2TpetraBlockCrs(P); + // Tpetra::BlockCrsMatrix & tpAc = Xpetra::Helpers::Op2NonConstTpetraBlockCrs(Ac); + + using CRS=Tpetra::CrsMatrix; + RCP Rcrs = Tpetra::convertToCrsMatrix(tpR); + RCP Acrs = Tpetra::convertToCrsMatrix(tpA); + RCP Pcrs = Tpetra::convertToCrsMatrix(tpP); + // RCP Accrs = Tpetra::convertToCrsMatrix(tpAc); + + // FIXME: The lines below only works because we're assuming Ac is Point + RCP Accrs = Teuchos::rcp(new CRS(Rcrs->getRowMap(),0)); + const bool do_fill_complete=true; + Tpetra::TripleMatrixMultiply::MultiplyRAP(*Rcrs, transposeR, *Acrs, transposeA, *Pcrs, transposeP, *Accrs, do_fill_complete, label, params); + + // Temporary output matrix + RCP > Ac_t = Tpetra::convertToBlockCrsMatrix(*Accrs,A.GetStorageBlockSize()); + RCP > Ac_x = Teuchos::rcp(new Xpetra::TpetraBlockCrsMatrix(Ac_t)); + RCP > Ac_p = Ac_x; + + // We can now cheat and replace the innards of Ac + RCP > Ac_w = Teuchos::rcp_dynamic_cast >(Teuchos::rcpFromRef(Ac)); + Ac_w->replaceCrsMatrix(Ac_p); + } + else { + // Mix and match + TEUCHOS_TEST_FOR_EXCEPTION(1, Exceptions::RuntimeError, "Mix-and-match Crs/BlockCrs Multiply not currently supported"); + } + # endif #else throw(Xpetra::Exceptions::RuntimeError("Xpetra must be compiled with Tpetra."));