Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Feb 15, 2024
1 parent a03d957 commit 893215e
Show file tree
Hide file tree
Showing 8 changed files with 776 additions and 385 deletions.
143 changes: 101 additions & 42 deletions cpp/src/arrow/util/byte_stream_split_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

#include "arrow/util/endian.h"
#include "arrow/util/simd.h"
#include "arrow/util/small_vector.h"
#include "arrow/util/ubsan.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

#ifdef ARROW_HAVE_SSE4_2
#define ARROW_HAVE_SIMD_SPLIT
Expand All @@ -37,8 +40,9 @@ namespace arrow::util::internal {

#if defined(ARROW_HAVE_SSE4_2)
template <int kNumStreams>
void ByteStreamSplitDecodeSse2(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
void ByteStreamSplitDecodeSse2(const uint8_t* data, int width, int64_t num_values,
int64_t stride, uint8_t* out) {
assert(width == kNumStreams);
static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams.");
constexpr int kNumStreamsLog2 = (kNumStreams == 8 ? 3 : 2);
constexpr int64_t kBlockSize = sizeof(__m128i) * kNumStreams;
Expand Down Expand Up @@ -89,8 +93,9 @@ void ByteStreamSplitDecodeSse2(const uint8_t* data, int64_t num_values, int64_t
}

template <int kNumStreams>
void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* output_buffer_raw) {
assert(width == kNumStreams);
static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams.");
constexpr int kBlockSize = sizeof(__m128i) * kNumStreams;

Expand Down Expand Up @@ -176,15 +181,16 @@ void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, const int64_t num_valu

#if defined(ARROW_HAVE_AVX2)
template <int kNumStreams>
void ByteStreamSplitDecodeAvx2(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
void ByteStreamSplitDecodeAvx2(const uint8_t* data, int width, int64_t num_values,
int64_t stride, uint8_t* out) {
assert(width == kNumStreams);
static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams.");
constexpr int kNumStreamsLog2 = (kNumStreams == 8 ? 3 : 2);
constexpr int64_t kBlockSize = sizeof(__m256i) * kNumStreams;

const int64_t size = num_values * kNumStreams;
if (size < kBlockSize) // Back to SSE for small size
return ByteStreamSplitDecodeSse2<kNumStreams>(data, num_values, stride, out);
return ByteStreamSplitDecodeSse2<kNumStreams>(data, width, num_values, stride, out);
const int64_t num_blocks = size / kBlockSize;

// First handle suffix.
Expand Down Expand Up @@ -260,18 +266,19 @@ void ByteStreamSplitDecodeAvx2(const uint8_t* data, int64_t num_values, int64_t
}

template <int kNumStreams>
void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* output_buffer_raw) {
assert(width == kNumStreams);
static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams.");
constexpr int kBlockSize = sizeof(__m256i) * kNumStreams;

if constexpr (kNumStreams == 8) // Back to SSE, currently no path for double.
return ByteStreamSplitEncodeSse2<kNumStreams>(raw_values, num_values,
return ByteStreamSplitEncodeSse2<kNumStreams>(raw_values, width, num_values,
output_buffer_raw);

const int64_t size = num_values * kNumStreams;
if (size < kBlockSize) // Back to SSE for small size
return ByteStreamSplitEncodeSse2<kNumStreams>(raw_values, num_values,
return ByteStreamSplitEncodeSse2<kNumStreams>(raw_values, width, num_values,
output_buffer_raw);
const int64_t num_blocks = size / kBlockSize;
const __m256i* raw_values_simd = reinterpret_cast<const __m256i*>(raw_values);
Expand Down Expand Up @@ -334,15 +341,16 @@ void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const int64_t num_valu

#if defined(ARROW_HAVE_AVX512)
template <int kNumStreams>
void ByteStreamSplitDecodeAvx512(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
void ByteStreamSplitDecodeAvx512(const uint8_t* data, int width, int64_t num_values,
int64_t stride, uint8_t* out) {
assert(width == kNumStreams);
static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams.");
constexpr int kNumStreamsLog2 = (kNumStreams == 8 ? 3 : 2);
constexpr int64_t kBlockSize = sizeof(__m512i) * kNumStreams;

const int64_t size = num_values * kNumStreams;
if (size < kBlockSize) // Back to AVX2 for small size
return ByteStreamSplitDecodeAvx2<kNumStreams>(data, num_values, stride, out);
return ByteStreamSplitDecodeAvx2<kNumStreams>(data, width, num_values, stride, out);
const int64_t num_blocks = size / kBlockSize;

// First handle suffix.
Expand Down Expand Up @@ -436,15 +444,16 @@ void ByteStreamSplitDecodeAvx512(const uint8_t* data, int64_t num_values, int64_
}

template <int kNumStreams>
void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* output_buffer_raw) {
assert(width == kNumStreams);
static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams.");
constexpr int kBlockSize = sizeof(__m512i) * kNumStreams;

const int64_t size = num_values * kNumStreams;

if (size < kBlockSize) // Back to AVX2 for small size
return ByteStreamSplitEncodeAvx2<kNumStreams>(raw_values, num_values,
return ByteStreamSplitEncodeAvx2<kNumStreams>(raw_values, width, num_values,
output_buffer_raw);

const int64_t num_blocks = size / kBlockSize;
Expand Down Expand Up @@ -547,30 +556,31 @@ void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const int64_t num_va

#if defined(ARROW_HAVE_SIMD_SPLIT)
template <int kNumStreams>
void inline ByteStreamSplitDecodeSimd(const uint8_t* data, int64_t num_values,
void inline ByteStreamSplitDecodeSimd(const uint8_t* data, int width, int64_t num_values,
int64_t stride, uint8_t* out) {
#if defined(ARROW_HAVE_AVX512)
return ByteStreamSplitDecodeAvx512<kNumStreams>(data, num_values, stride, out);
return ByteStreamSplitDecodeAvx512<kNumStreams>(data, width, num_values, stride, out);
#elif defined(ARROW_HAVE_AVX2)
return ByteStreamSplitDecodeAvx2<kNumStreams>(data, num_values, stride, out);
return ByteStreamSplitDecodeAvx2<kNumStreams>(data, width, num_values, stride, out);
#elif defined(ARROW_HAVE_SSE4_2)
return ByteStreamSplitDecodeSse2<kNumStreams>(data, num_values, stride, out);
return ByteStreamSplitDecodeSse2<kNumStreams>(data, width, num_values, stride, out);
#else
#error "ByteStreamSplitDecodeSimd not implemented"
#endif
}

template <int kNumStreams>
void inline ByteStreamSplitEncodeSimd(const uint8_t* raw_values, const int64_t num_values,
void inline ByteStreamSplitEncodeSimd(const uint8_t* raw_values, int width,
const int64_t num_values,
uint8_t* output_buffer_raw) {
#if defined(ARROW_HAVE_AVX512)
return ByteStreamSplitEncodeAvx512<kNumStreams>(raw_values, num_values,
return ByteStreamSplitEncodeAvx512<kNumStreams>(raw_values, width, num_values,
output_buffer_raw);
#elif defined(ARROW_HAVE_AVX2)
return ByteStreamSplitEncodeAvx2<kNumStreams>(raw_values, num_values,
return ByteStreamSplitEncodeAvx2<kNumStreams>(raw_values, width, num_values,
output_buffer_raw);
#elif defined(ARROW_HAVE_SSE4_2)
return ByteStreamSplitEncodeSse2<kNumStreams>(raw_values, num_values,
return ByteStreamSplitEncodeSse2<kNumStreams>(raw_values, width, num_values,
output_buffer_raw);
#else
#error "ByteStreamSplitEncodeSimd not implemented"
Expand Down Expand Up @@ -671,45 +681,94 @@ inline void DoMergeStreams(const uint8_t** src_streams, int width, int64_t nvalu
}

template <int kNumStreams>
void ByteStreamSplitEncodeScalar(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
void ByteStreamSplitEncodeScalar(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* out) {
assert(width == kNumStreams);
std::array<uint8_t*, kNumStreams> dest_streams;
for (int stream = 0; stream < kNumStreams; ++stream) {
dest_streams[stream] = &output_buffer_raw[stream * num_values];
dest_streams[stream] = &out[stream * num_values];
}
DoSplitStreams(raw_values, kNumStreams, num_values, dest_streams.data());
}

inline void ByteStreamSplitEncodeScalarDynamic(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* out) {
::arrow::internal::SmallVector<uint8_t*, 16> dest_streams;
dest_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
dest_streams[stream] = &out[stream * num_values];
}
DoSplitStreams(raw_values, width, num_values, dest_streams.data());
}

template <int kNumStreams>
void ByteStreamSplitDecodeScalar(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
void ByteStreamSplitDecodeScalar(const uint8_t* data, int width, int64_t num_values,
int64_t stride, uint8_t* out) {
assert(width == kNumStreams);
std::array<const uint8_t*, kNumStreams> src_streams;
for (int stream = 0; stream < kNumStreams; ++stream) {
src_streams[stream] = &data[stream * stride];
}
DoMergeStreams(src_streams.data(), kNumStreams, num_values, out);
}

template <int kNumStreams>
void inline ByteStreamSplitEncode(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
inline void ByteStreamSplitDecodeScalarDynamic(const uint8_t* data, int width,
int64_t num_values, int64_t stride,
uint8_t* out) {
::arrow::internal::SmallVector<const uint8_t*, 16> src_streams;
src_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
src_streams[stream] = &data[stream * stride];
}
DoMergeStreams(src_streams.data(), width, num_values, out);
}

inline void ByteStreamSplitEncode(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* out) {
#if defined(ARROW_HAVE_SIMD_SPLIT)
return ByteStreamSplitEncodeSimd<kNumStreams>(raw_values, num_values,
output_buffer_raw);
#define ByteStreamSplitEncodePerhapsSimd ByteStreamSplitEncodeSimd
#else
return ByteStreamSplitEncodeScalar<kNumStreams>(raw_values, num_values,
output_buffer_raw);
#define ByteStreamSplitEncodePerhapsSimd ByteStreamSplitEncodeScalar
#endif
switch (width) {
case 1:
memcpy(out, raw_values, num_values);
return;
case 2:
return ByteStreamSplitEncodeScalar<2>(raw_values, width, num_values, out);
case 4:
return ByteStreamSplitEncodePerhapsSimd<4>(raw_values, width, num_values, out);
case 8:
return ByteStreamSplitEncodePerhapsSimd<8>(raw_values, width, num_values, out);
case 16:
return ByteStreamSplitEncodeScalar<16>(raw_values, width, num_values, out);
}
return ByteStreamSplitEncodeScalarDynamic(raw_values, width, num_values, out);
#undef ByteStreamSplitEncodePerhapsSimd
}

template <int kNumStreams>
void inline ByteStreamSplitDecode(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
inline void ByteStreamSplitDecode(const uint8_t* data, int width, int64_t num_values,
int64_t stride, uint8_t* out) {
#if defined(ARROW_HAVE_SIMD_SPLIT)
return ByteStreamSplitDecodeSimd<kNumStreams>(data, num_values, stride, out);
#define ByteStreamSplitDecodePerhapsSimd ByteStreamSplitDecodeSimd
#else
return ByteStreamSplitDecodeScalar<kNumStreams>(data, num_values, stride, out);
#define ByteStreamSplitDecodePerhapsSimd ByteStreamSplitDecodeScalar
#endif
switch (width) {
case 1:
memcpy(out, data, num_values);
return;
case 2:
return ByteStreamSplitDecodeScalar<2>(data, width, num_values, stride, out);
case 4:
return ByteStreamSplitDecodePerhapsSimd<4>(data, width, num_values, stride, out);
case 8:
return ByteStreamSplitDecodePerhapsSimd<8>(data, width, num_values, stride, out);
case 16:
return ByteStreamSplitDecodeScalar<16>(data, width, num_values, stride, out);
}
return ByteStreamSplitDecodeScalarDynamic(data, width, num_values, stride, out);
#undef ByteStreamSplitDecodePerhapsSimd
}

} // namespace arrow::util::internal
Loading

0 comments on commit 893215e

Please sign in to comment.