Skip to content

Commit

Permalink
Test thrust headers for disabled half/bf16 support (#2219)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored Aug 14, 2024
1 parent 352638b commit dded5f1
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 8 deletions.
2 changes: 2 additions & 0 deletions cub/cmake/CubHeaderTesting.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ set(header_definitions
"CUB_WRAPPED_NAMESPACE=wrapped_cub")
cub_add_header_test(base "${header_definitions}")

# Check that BF16 support can be disabled
set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
"CCCL_DISABLE_BF16_SUPPORT")
cub_add_header_test(bf16 "${header_definitions}")

# Check that half support can be disabled
set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
Expand Down
37 changes: 29 additions & 8 deletions thrust/cmake/ThrustHeaderTesting.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Meta target for all configs' header builds:
add_custom_target(thrust.all.headers)

foreach(thrust_target IN LISTS THRUST_TARGETS)
function(thrust_add_header_test thrust_target label definitions)
thrust_get_target_property(config_host ${thrust_target} HOST)
thrust_get_target_property(config_device ${thrust_target} DEVICE)
thrust_get_target_property(config_prefix ${thrust_target} PREFIX)
Expand Down Expand Up @@ -115,14 +115,10 @@ foreach(thrust_target IN LISTS THRUST_TARGETS)
list(APPEND headertest_srcs "${headertest_src}")
endforeach()

set(headertest_target ${config_prefix}.headers)
set(headertest_target ${config_prefix}.headers.${label})
add_library(${headertest_target} OBJECT ${headertest_srcs})
target_link_libraries(${headertest_target} PUBLIC ${thrust_target})
# Wrap Thrust/CUB in a custom namespace to check proper use of ns macros:
target_compile_definitions(${headertest_target} PRIVATE
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
)
target_compile_definitions(${headertest_target} PRIVATE ${header_definitions})
thrust_clone_target_properties(${headertest_target} ${thrust_target})

if ("CUDA" STREQUAL "${config_device}")
Expand All @@ -141,4 +137,29 @@ foreach(thrust_target IN LISTS THRUST_TARGETS)

add_dependencies(thrust.all.headers ${headertest_target})
add_dependencies(${config_prefix}.all ${headertest_target})
endforeach()
endfunction()

foreach(thrust_target IN LISTS THRUST_TARGETS)
# Wrap Thrust/CUB in a custom namespace to check proper use of ns macros:
set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub")
thrust_add_header_test(${thrust_target} base "${header_definitions}")

thrust_get_target_property(config_device ${thrust_target} DEVICE)
if ("CUDA" STREQUAL "${config_device}")
# Check that BF16 support can be disabled
set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
"CCCL_DISABLE_BF16_SUPPORT")
thrust_add_header_test(${thrust_target} bf16 "${header_definitions}")

# Check that half support can be disabled
set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
"CCCL_DISABLE_FP16_SUPPORT")
thrust_add_header_test(${thrust_target} half "${header_definitions}")
endif()
endforeach ()
15 changes: 15 additions & 0 deletions thrust/cmake/header_test.in
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,18 @@
#endif // THRUST_IGNORE_MACRO_CHECKS

#include <thrust/${header}>

#if defined(CCCL_DISABLE_BF16_SUPPORT)
#if defined(__CUDA_BF16_TYPES_EXIST__)
#error Thrust should not include cuda_bf16.h when BF16 support is disabled
#endif // __CUDA_BF16_TYPES_EXIST__
#endif // CCCL_DISABLE_BF16_SUPPORT

#if defined(CCCL_DISABLE_FP16_SUPPORT)
#if defined(__CUDA_FP16_TYPES_EXIST__)
#error Thrust should not include cuda_fp16.h when half support is disabled
#endif // __CUDA_FP16_TYPES_EXIST__
#if defined(__CUDA_BF16_TYPES_EXIST__)
#error Thrust should not include cuda_bf16.h when half support is disabled
#endif // __CUDA_BF16_TYPES_EXIST__
#endif // CCCL_DISABLE_FP16_SUPPORT

0 comments on commit dded5f1

Please sign in to comment.