Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Replace most of preprocessor usage in nvcomp adapter with constexpr" #11999

Merged
merged 1 commit into from
Oct 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 135 additions & 123 deletions cpp/src/io/comp/nvcomp_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,46 @@
#include NVCOMP_ZSTD_HEADER
#endif

constexpr bool NVCOMP_HAS_ZSTD_DECOMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_ZSTD_DECOMP 1
#else
#define NVCOMP_HAS_ZSTD_DECOMP 0
#endif

constexpr bool NVCOMP_HAS_ZSTD_COMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4)
#define NVCOMP_HAS_ZSTD_COMP 1
#else
#define NVCOMP_HAS_ZSTD_COMP 0
#endif

constexpr bool NVCOMP_HAS_DEFLATE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_DEFLATE 1
#else
#define NVCOMP_HAS_DEFLATE 0
#endif

constexpr bool NVCOMP_HAS_TEMPSIZE_EX = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 1);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 1)
#define NVCOMP_HAS_TEMPSIZE_EX 1
#else
#define NVCOMP_HAS_TEMPSIZE_EX 0
#endif

// ZSTD is stable for nvcomp 2.3.2 or newer
constexpr bool NVCOMP_ZSTD_IS_STABLE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 2);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 2)
#define NVCOMP_ZSTD_IS_STABLE 1
#else
#define NVCOMP_ZSTD_IS_STABLE 0
#endif

// Issue https://github.com/NVIDIA/spark-rapids/issues/6614 impacts nvCOMP 2.4.0 ZSTD decompression
// on compute 6.x
constexpr bool NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL =
NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0;
#if NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 1
#else
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 0
#endif

namespace cudf::io::nvcomp {

Expand All @@ -63,20 +79,20 @@ template <typename... Args>
std::optional<nvcompStatus_t> batched_decompress_get_temp_size_ex(compression_type compression,
Args&&... args)
{
if constexpr (NVCOMP_HAS_TEMPSIZE_EX) {
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
} else {
return std::nullopt;
}
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
}
#if NVCOMP_HAS_TEMPSIZE_EX
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
#else
return std::nullopt;
#endif
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
}
#endif
return std::nullopt;
}

Expand All @@ -88,17 +104,17 @@ auto batched_decompress_get_temp_size(compression_type compression, Args&&... ar
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSize(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand All @@ -111,18 +127,17 @@ auto batched_decompress_async(compression_type compression, Args&&... args)
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressAsync(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand Down Expand Up @@ -155,13 +170,13 @@ void check_is_zstd_enabled()
"Zstandard compression is experimental, you can enable it through "
"`LIBCUDF_NVCOMP_POLICY` environment variable.");

if constexpr (NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL) {
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
}
#if NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
#endif
}

void batched_decompress(compression_type compression,
Expand Down Expand Up @@ -213,22 +228,21 @@ auto batched_compress_temp_size(compression_type compression,
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedSnappyDefaultOpts, &temp_size);
break;
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}

Expand All @@ -252,21 +266,21 @@ size_t compress_max_output_chunk_size(compression_type compression,
capped_uncomp_bytes, nvcompBatchedSnappyDefaultOpts, &max_comp_chunk_size);
break;
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_ZSTD_COMP
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}

Expand Down Expand Up @@ -302,39 +316,37 @@ static void batched_compress_async(compression_type compression,
stream.value());
break;
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}
CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, "Error in compression");
Expand Down Expand Up @@ -418,11 +430,11 @@ std::optional<size_t> compress_max_allowed_chunk_size(compression_type compressi
case compression_type::DEFLATE: return 64 * 1024;
case compression_type::SNAPPY: return std::nullopt;
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
return nvcompZstdCompressionMaxAllowedChunkSize;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_ZSTD_COMP
return nvcompZstdCompressionMaxAllowedChunkSize;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: return std::nullopt;
}
}
Expand Down