diff --git a/cub/cmake/CubHeaderTesting.cmake b/cub/cmake/CubHeaderTesting.cmake index f0ca17186ce..fdf9be3be48 100644 --- a/cub/cmake/CubHeaderTesting.cmake +++ b/cub/cmake/CubHeaderTesting.cmake @@ -42,12 +42,14 @@ set(header_definitions "CUB_WRAPPED_NAMESPACE=wrapped_cub") cub_add_header_test(base "${header_definitions}") +# Check that BF16 support can be disabled set(header_definitions "THRUST_WRAPPED_NAMESPACE=wrapped_thrust" "CUB_WRAPPED_NAMESPACE=wrapped_cub" "CCCL_DISABLE_BF16_SUPPORT") cub_add_header_test(bf16 "${header_definitions}") +# Check that half support can be disabled set(header_definitions "THRUST_WRAPPED_NAMESPACE=wrapped_thrust" "CUB_WRAPPED_NAMESPACE=wrapped_cub" diff --git a/thrust/cmake/ThrustHeaderTesting.cmake b/thrust/cmake/ThrustHeaderTesting.cmake index ad438b0f879..4c1d07f744b 100644 --- a/thrust/cmake/ThrustHeaderTesting.cmake +++ b/thrust/cmake/ThrustHeaderTesting.cmake @@ -7,7 +7,7 @@ # Meta target for all configs' header builds: add_custom_target(thrust.all.headers) -foreach(thrust_target IN LISTS THRUST_TARGETS) +function(thrust_add_header_test thrust_target label definitions) thrust_get_target_property(config_host ${thrust_target} HOST) thrust_get_target_property(config_device ${thrust_target} DEVICE) thrust_get_target_property(config_prefix ${thrust_target} PREFIX) @@ -115,14 +115,10 @@ foreach(thrust_target IN LISTS THRUST_TARGETS) list(APPEND headertest_srcs "${headertest_src}") endforeach() - set(headertest_target ${config_prefix}.headers) + set(headertest_target ${config_prefix}.headers.${label}) add_library(${headertest_target} OBJECT ${headertest_srcs}) target_link_libraries(${headertest_target} PUBLIC ${thrust_target}) - # Wrap Thrust/CUB in a custom namespace to check proper use of ns macros: - target_compile_definitions(${headertest_target} PRIVATE - "THRUST_WRAPPED_NAMESPACE=wrapped_thrust" - "CUB_WRAPPED_NAMESPACE=wrapped_cub" - ) + target_compile_definitions(${headertest_target} PRIVATE ${header_definitions}) thrust_clone_target_properties(${headertest_target} ${thrust_target}) if ("CUDA" STREQUAL "${config_device}") @@ -141,4 +137,29 @@ foreach(thrust_target IN LISTS THRUST_TARGETS) add_dependencies(thrust.all.headers ${headertest_target}) add_dependencies(${config_prefix}.all ${headertest_target}) -endforeach() +endfunction() + +foreach(thrust_target IN LISTS THRUST_TARGETS) + # Wrap Thrust/CUB in a custom namespace to check proper use of ns macros: + set(header_definitions + "THRUST_WRAPPED_NAMESPACE=wrapped_thrust" + "CUB_WRAPPED_NAMESPACE=wrapped_cub") + thrust_add_header_test(${thrust_target} base "${header_definitions}") + + thrust_get_target_property(config_device ${thrust_target} DEVICE) + if ("CUDA" STREQUAL "${config_device}") + # Check that BF16 support can be disabled + set(header_definitions + "THRUST_WRAPPED_NAMESPACE=wrapped_thrust" + "CUB_WRAPPED_NAMESPACE=wrapped_cub" + "CCCL_DISABLE_BF16_SUPPORT") + thrust_add_header_test(${thrust_target} bf16 "${header_definitions}") + + # Check that half support can be disabled + set(header_definitions + "THRUST_WRAPPED_NAMESPACE=wrapped_thrust" + "CUB_WRAPPED_NAMESPACE=wrapped_cub" + "CCCL_DISABLE_FP16_SUPPORT") + thrust_add_header_test(${thrust_target} half "${header_definitions}") + endif() +endforeach () diff --git a/thrust/cmake/header_test.in b/thrust/cmake/header_test.in index 59e44e03c15..236cb9bde4e 100644 --- a/thrust/cmake/header_test.in +++ b/thrust/cmake/header_test.in @@ -64,3 +64,18 @@ #endif // THRUST_IGNORE_MACRO_CHECKS #include + +#if defined(CCCL_DISABLE_BF16_SUPPORT) +#if defined(__CUDA_BF16_TYPES_EXIST__) +#error Thrust should not include cuda_bf16.h when BF16 support is disabled +#endif // __CUDA_BF16_TYPES_EXIST__ +#endif // CCCL_DISABLE_BF16_SUPPORT + +#if defined(CCCL_DISABLE_FP16_SUPPORT) +#if defined(__CUDA_FP16_TYPES_EXIST__) +#error Thrust should not include cuda_fp16.h when half support is disabled +#endif // __CUDA_FP16_TYPES_EXIST__ +#if defined(__CUDA_BF16_TYPES_EXIST__) +#error Thrust should not include cuda_bf16.h when half support is disabled +#endif // __CUDA_BF16_TYPES_EXIST__ +#endif // CCCL_DISABLE_FP16_SUPPORT