Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace most of preprocessor usage in nvcomp adapter with constexpr #11980

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 123 additions & 135 deletions cpp/src/io/comp/nvcomp_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,46 +31,30 @@
#include NVCOMP_ZSTD_HEADER
#endif

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_ZSTD_DECOMP 1
#else
#define NVCOMP_HAS_ZSTD_DECOMP 0
#endif
constexpr bool NVCOMP_HAS_ZSTD_DECOMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4)
#define NVCOMP_HAS_ZSTD_COMP 1
#else
#define NVCOMP_HAS_ZSTD_COMP 0
#endif
constexpr bool NVCOMP_HAS_ZSTD_COMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4);

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_DEFLATE 1
#else
#define NVCOMP_HAS_DEFLATE 0
#endif
constexpr bool NVCOMP_HAS_DEFLATE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 1)
#define NVCOMP_HAS_TEMPSIZE_EX 1
#else
#define NVCOMP_HAS_TEMPSIZE_EX 0
#endif
constexpr bool NVCOMP_HAS_TEMPSIZE_EX = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 1);

// ZSTD is stable for nvcomp 2.3.2 or newer
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 2)
#define NVCOMP_ZSTD_IS_STABLE 1
#else
#define NVCOMP_ZSTD_IS_STABLE 0
#endif
constexpr bool NVCOMP_ZSTD_IS_STABLE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 2);

// Issue https://github.com/NVIDIA/spark-rapids/issues/6614 impacts nvCOMP 2.4.0 ZSTD decompression
// on compute 6.x
#if NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 1
#else
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 0
#endif
constexpr bool NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL =
NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0;

namespace cudf::io::nvcomp {

Expand All @@ -79,20 +63,20 @@ template <typename... Args>
std::optional<nvcompStatus_t> batched_decompress_get_temp_size_ex(compression_type compression,
Args&&... args)
{
#if NVCOMP_HAS_TEMPSIZE_EX
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
#else
return std::nullopt;
#endif
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
if constexpr (NVCOMP_HAS_TEMPSIZE_EX) {
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
} else {
return std::nullopt;
}
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
}
}
#endif
return std::nullopt;
}

Expand All @@ -104,17 +88,17 @@ auto batched_decompress_get_temp_size(compression_type compression, Args&&... ar
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSize(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand All @@ -127,17 +111,18 @@ auto batched_decompress_async(compression_type compression, Args&&... args)
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressAsync(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}

case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand Down Expand Up @@ -170,13 +155,13 @@ void check_is_zstd_enabled()
"Zstandard compression is experimental, you can enable it through "
"`LIBCUDF_NVCOMP_POLICY` environment variable.");

#if NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
#endif
if constexpr (NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL) {
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
}
}

void batched_decompress(compression_type compression,
Expand Down Expand Up @@ -228,21 +213,22 @@ auto batched_compress_temp_size(compression_type compression,
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedSnappyDefaultOpts, &temp_size);
break;
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

default: CUDF_FAIL("Unsupported compression type");
}

Expand All @@ -266,21 +252,21 @@ size_t compress_max_output_chunk_size(compression_type compression,
capped_uncomp_bytes, nvcompBatchedSnappyDefaultOpts, &max_comp_chunk_size);
break;
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
default: CUDF_FAIL("Unsupported compression type");
}

Expand Down Expand Up @@ -316,37 +302,39 @@ static void batched_compress_async(compression_type compression,
stream.value());
break;
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

default: CUDF_FAIL("Unsupported compression type");
}
CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, "Error in compression");
Expand Down Expand Up @@ -430,11 +418,11 @@ std::optional<size_t> compress_max_allowed_chunk_size(compression_type compressi
case compression_type::DEFLATE: return 64 * 1024;
case compression_type::SNAPPY: return std::nullopt;
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
return nvcompZstdCompressionMaxAllowedChunkSize;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
return nvcompZstdCompressionMaxAllowedChunkSize;
} else {
CUDF_FAIL("Unsupported compression type");
}
default: return std::nullopt;
}
}
Expand Down