Skip to content

Commit

Permalink
Revert "Replace most of preprocessor usage in nvcomp adapter with `constexpr` (#11980)"
Browse files Browse the repository at this point in the history

This reverts commit 2ee41d0.
  • Loading branch information
vuule authored Oct 25, 2022
1 parent 5bfc9a4 commit 14477bb
Showing 1 changed file with 135 additions and 123 deletions.
258 changes: 135 additions & 123 deletions cpp/src/io/comp/nvcomp_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,46 @@
#include NVCOMP_ZSTD_HEADER
#endif

constexpr bool NVCOMP_HAS_ZSTD_DECOMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_ZSTD_DECOMP 1
#else
#define NVCOMP_HAS_ZSTD_DECOMP 0
#endif

constexpr bool NVCOMP_HAS_ZSTD_COMP = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 4)
#define NVCOMP_HAS_ZSTD_COMP 1
#else
#define NVCOMP_HAS_ZSTD_COMP 0
#endif

constexpr bool NVCOMP_HAS_DEFLATE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION >= 3)
#define NVCOMP_HAS_DEFLATE 1
#else
#define NVCOMP_HAS_DEFLATE 0
#endif

constexpr bool NVCOMP_HAS_TEMPSIZE_EX = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 1);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 1)
#define NVCOMP_HAS_TEMPSIZE_EX 1
#else
#define NVCOMP_HAS_TEMPSIZE_EX 0
#endif

// ZSTD is stable for nvcomp 2.3.2 or newer
constexpr bool NVCOMP_ZSTD_IS_STABLE = NVCOMP_MAJOR_VERSION > 2 or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and
NVCOMP_PATCH_VERSION >= 2);
#if NVCOMP_MAJOR_VERSION > 2 or (NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION > 3) or \
(NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 3 and NVCOMP_PATCH_VERSION >= 2)
#define NVCOMP_ZSTD_IS_STABLE 1
#else
#define NVCOMP_ZSTD_IS_STABLE 0
#endif

// Issue https://github.com/NVIDIA/spark-rapids/issues/6614 impacts nvCOMP 2.4.0 ZSTD decompression
// on compute 6.x
constexpr bool NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL =
NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0;
#if NVCOMP_MAJOR_VERSION == 2 and NVCOMP_MINOR_VERSION == 4 and NVCOMP_PATCH_VERSION == 0
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 1
#else
#define NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL 0
#endif

namespace cudf::io::nvcomp {

Expand All @@ -63,20 +79,20 @@ template <typename... Args>
std::optional<nvcompStatus_t> batched_decompress_get_temp_size_ex(compression_type compression,
Args&&... args)
{
if constexpr (NVCOMP_HAS_TEMPSIZE_EX) {
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
} else {
return std::nullopt;
}
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
}
#if NVCOMP_HAS_TEMPSIZE_EX
switch (compression) {
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward<Args>(args)...);
case compression_type::ZSTD:
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward<Args>(args)...);
#else
return std::nullopt;
#endif
case compression_type::DEFLATE: [[fallthrough]];
default: return std::nullopt;
}
#endif
return std::nullopt;
}

Expand All @@ -88,17 +104,17 @@ auto batched_decompress_get_temp_size(compression_type compression, Args&&... ar
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressGetTempSize(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressGetTempSize(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand All @@ -111,18 +127,17 @@ auto batched_decompress_async(compression_type compression, Args&&... args)
case compression_type::SNAPPY:
return nvcompBatchedSnappyDecompressAsync(std::forward<Args>(args)...);
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_DECOMP) {
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_ZSTD_DECOMP
return nvcompBatchedZstdDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
return nvcompBatchedDeflateDecompressAsync(std::forward<Args>(args)...);
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}
}
Expand Down Expand Up @@ -155,13 +170,13 @@ void check_is_zstd_enabled()
"Zstandard compression is experimental, you can enable it through "
"`LIBCUDF_NVCOMP_POLICY` environment variable.");

if constexpr (NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL) {
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
}
#if NVCOMP_ZSTD_IS_DISABLED_ON_PASCAL
int device;
int cc_major;
CUDF_CUDA_TRY(cudaGetDevice(&device));
CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
CUDF_EXPECTS(cc_major != 6, "Zstandard decompression is disabled on Pascal GPUs");
#endif
}

void batched_decompress(compression_type compression,
Expand Down Expand Up @@ -213,22 +228,21 @@ auto batched_compress_temp_size(compression_type compression,
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedSnappyDefaultOpts, &temp_size);
break;
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressGetTempSize(
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}

Expand All @@ -252,21 +266,21 @@ size_t compress_max_output_chunk_size(compression_type compression,
capped_uncomp_bytes, nvcompBatchedSnappyDefaultOpts, &max_comp_chunk_size);
break;
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_DEFLATE
status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_ZSTD_COMP
status = nvcompBatchedZstdCompressGetMaxOutputChunkSize(
capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size);
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}

Expand Down Expand Up @@ -302,39 +316,37 @@ static void batched_compress_async(compression_type compression,
stream.value());
break;
case compression_type::DEFLATE:
if constexpr (NVCOMP_HAS_DEFLATE) {
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_DEFLATE
nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedDeflateDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
} else {
CUDF_FAIL("Unsupported compression type");
}

#if NVCOMP_HAS_ZSTD_COMP
nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs,
device_uncompressed_bytes,
max_uncompressed_chunk_bytes,
batch_size,
device_temp_ptr,
temp_bytes,
device_compressed_ptrs,
device_compressed_bytes,
nvcompBatchedZstdDefaultOpts,
stream.value());
break;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: CUDF_FAIL("Unsupported compression type");
}
CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, "Error in compression");
Expand Down Expand Up @@ -418,11 +430,11 @@ std::optional<size_t> compress_max_allowed_chunk_size(compression_type compressi
case compression_type::DEFLATE: return 64 * 1024;
case compression_type::SNAPPY: return std::nullopt;
case compression_type::ZSTD:
if constexpr (NVCOMP_HAS_ZSTD_COMP) {
return nvcompZstdCompressionMaxAllowedChunkSize;
} else {
CUDF_FAIL("Unsupported compression type");
}
#if NVCOMP_HAS_ZSTD_COMP
return nvcompZstdCompressionMaxAllowedChunkSize;
#else
CUDF_FAIL("Unsupported compression type");
#endif
default: return std::nullopt;
}
}
Expand Down

0 comments on commit 14477bb

Please sign in to comment.