Skip to content

Commit

Permalink
Replace most of preprocessor usage in nvcomp adapter with constexpr (
Browse files Browse the repository at this point in the history
…#11980)

C++17's "constexpr if" provides the same functionality as `#if` directive, as used in the nvcomp adapter.
This PR replaces macros with `constexpr` variables and uses them as conditions in "constexpr if" statements.

Authors:
  - Vukasin Milovanovic (https://github.com/vuule)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Tobias Ribizel (https://github.com/upsj)

URL: #11980
  • Loading branch information
vuule authored Oct 25, 2022
1 parent 11918ae commit 2ee41d0
Showing 1 changed file with 123 additions and 135 deletions.
258 changes: 123 additions & 135 deletions cpp/src/io/comp/nvcomp_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,46 +31,30 @@
#include NVCOMP_ZSTD_HEADER
#endif

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_ZSTD_DECOMP 1
#else
#define NVCOMP_HAS_ZSTD_DECOMP 0
#endif
constexpr bool NVCOMP_HAS_ZSTD_DECOMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4)
#define NVCOMP_HAS_ZSTD_COMP 1
#else
#define NVCOMP_HAS_ZSTD_COMP 0
#endif
constexpr bool NVCOMP_HAS_ZSTD_COMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4);

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_DEFLATE 1
#else
#define NVCOMP_HAS_DEFLATE 0
#endif
constexpr bool NVCOMP_HAS_DEFLATE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);

#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 1)
#define NVCOMP_HAS_TEMPSIZE_EX 1
#else
#define NVCOMP_HAS_TEMPSIZE_EX 0
#endif
constexpr bool NVCOMP_HAS_TEMPSIZE_EX = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 1);

// ZSTD is stable for nvcomp 2.3.2 or newer
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 2)
#define NVCOMP_ZSTD_IS_STABLE 1
#else
#define NVCOMP_ZSTD_IS_STABLE 0
#endif
constexpr bool NVCOMP_ZSTD_IS_STABLE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 2);

// Issue https://github.com/NVIDIA/spark-rapids/issues/6614 impacts nvCOMP 2.4.0 ZSTD decompression
// on compute 6.x
#if NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 1
#else
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 0
#endif
constexpr bool NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL =
NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0;

namespace cudf::io::nvcomp {

Expand All @@ -79,20 +63,20 @@ template <typename... Args>
std::optional<nvcompStatus_t> batched_decompress_get_temp_size_ex(compression_type compression,
Args&&... args)
{
#if NVCOMP_HAS_TEMPSIZE_EX
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
#else
return std::nullopt;
#endif
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
if constexpr (NVCOMP_HAS_TEMPSIZE_EX) {
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
} else {
return std::nullopt;
}
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
}
}
#endif
return std::nullopt;
}

Expand All @@ -104,17 +88,17 @@ auto batched_decompress_get_temp_size(compression_type compression, Args&&... ar
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSize(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand All @@ -127,17 +111,18 @@ auto batched_decompress_async(compression_type compression, Args&&... args)
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressAsync(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}

case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand Down Expand Up @@ -170,13 +155,13 @@ void check_is_zstd_enabled()
"Zstandard compression is experimental, you can enable it through "
"`LIBCUDF_NVCOMP_POLICY` environment variable.");

#if NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
#endif
if constexpr (NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL) {
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
}
}

void batched_decompress(compression_type compression,
Expand Down Expand Up @@ -228,21 +213,22 @@ auto batched_compress_temp_size(compression_type compression,
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedSnappyDefaultOpts, &temp_size);
break;
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

default: CUDF_FAIL("Unsupported compression type");
}

Expand All @@ -266,21 +252,21 @@ size_t compress_max_output_chunk_size(compression_type compression,
capped_uncomp_bytes, nvcompBatchedSnappyDefaultOpts, &max_comp_chunk_size);
break;
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
default: CUDF_FAIL("Unsupported compression type");
}

Expand Down Expand Up @@ -316,37 +302,39 @@ static void batched_compress_async(compression_type compression,
stream.value());
break;
case compression_type::DEFLATE:
#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

default: CUDF_FAIL("Unsupported compression type");
}
CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, "Error in compression");
Expand Down Expand Up @@ -430,11 +418,11 @@ std::optional<size_t> compress_max_allowed_chunk_size(compression_type compressi
case compression_type::DEFLATE: return 64 * 1024;
case compression_type::SNAPPY: return std::nullopt;
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_COMP
return nvcompZstdCompressionMaxAllowedChunkSize;
#else
CUDF_FAIL("Unsupported compression type");
#endif
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
return nvcompZstdCompressionMaxAllowedChunkSize;
} else {
CUDF_FAIL("Unsupported compression type");
}
default: return std::nullopt;
}
}
Expand Down

0 comments on commit 2ee41d0

Please sign in to comment.