diff --git a/.github/workflows/DCO.yml b/.github/workflows/DCO.yml
new file mode 100644
index 00000000..b546a2be
--- /dev/null
+++ b/.github/workflows/DCO.yml
@@ -0,0 +1,15 @@
+name: DCO Check
+
+on:
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  dco-check:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Validate DCO
+        uses: tisonkun/actions-dco@v1.1
diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml
new file mode 100644
index 00000000..a12d74a0
--- /dev/null
+++ b/.github/workflows/build_test.yml
@@ -0,0 +1,39 @@
+name: build & test
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  clang-format-check:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+      - name: Install Clang format
+        run: sudo apt-get install clang-format
+      - name: Run Clang format check
+        run: make fmt
+
+  build:
+    runs-on: ubuntu-latest
+    container:
+      image: vsaglib/vsag:ubuntu
+    steps:
+      - uses: actions/checkout@v4
+      - name: Load Cache
+        uses: actions/cache@v4.1.2
+        with:
+          path: ./build/
+          key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }}
+      - name: Make Asan
+        run: make asan
+      - name: Save Cache
+        uses: actions/cache@v4.1.2
+        with:
+          path: ./build/
+          key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }}
+      - name: Test
+        run: make test_asan
\ No newline at end of file
diff --git a/src/simd/CMakeLists.txt b/src/simd/CMakeLists.txt
index 4501d97a..68501f06 100644
--- a/src/simd/CMakeLists.txt
+++ b/src/simd/CMakeLists.txt
@@ -6,6 +6,9 @@ set (SIMD_SRCS
   sq4_simd.cpp
   sq4_uniform_simd.cpp
   sq8_uniform_simd.cpp
+  sse.cpp
+  avx.cpp
+  avx512.cpp
 )
 if (DIST_CONTAINS_SSE)
   set (SIMD_SRCS ${SIMD_SRCS} sse.cpp)
@@ -18,11 +21,13 @@ if (DIST_CONTAINS_AVX)
   set (SIMD_SRCS ${SIMD_SRCS} avx.cpp)
   set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx")
 endif ()
-if (DIST_CONTAINS_AVX2)
-  set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
-endif ()
+# FIXME(LHT): cause illegal instruction on platform which has avx only
+#if (DIST_CONTAINS_AVX2)
+#  set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
+#endif ()
 if (DIST_CONTAINS_AVX512)
-  set (SIMD_SRCS ${SIMD_SRCS} avx512.cpp)
+  set (SIMD_SRCS ${SIMD_SRCS} avx512.cpp
+    normalize.cpp)
   set_source_files_properties (
     avx512.cpp
     PROPERTIES
@@ -50,7 +55,7 @@ endmacro ()
 
 simd_add_definitions (DIST_CONTAINS_SSE -DENABLE_SSE=1)
 simd_add_definitions (DIST_CONTAINS_AVX -DENABLE_AVX=1)
-simd_add_definitions (DIST_CONTAINS_AVX2 -DENABLE_AVX2=1)
+#simd_add_definitions (DIST_CONTAINS_AVX2 -DENABLE_AVX2=1)
 simd_add_definitions (DIST_CONTAINS_AVX512 -DENABLE_AVX512=1)
 
 target_link_libraries (simd PRIVATE cpuinfo)
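The CMake hunks above are the core of the illegal-instruction fix: avx.cpp was previously compiled with `-mavx2 -mfma`, so AVX2/FMA encodings could end up in a translation unit that is still reached on AVX-only hosts and dies with SIGILL. Restricting each ISA's flags to its own translation unit only works if callers also check the CPU at runtime, which the rest of this patch does through `simd_status.h`. That header is not part of the diff, so the sketch below is only an assumed minimal shape for it, layered over the cpuinfo calls the `simd` target already links against; the real class may differ.

```cpp
// Hypothetical sketch of a SimdStatus-style wrapper over cpuinfo.
// The actual simd_status.h is not shown in this patch; names are assumptions.
#include "cpuinfo.h"

namespace vsag {

struct SimdStatus {
    // cpuinfo_initialize() is cheap and idempotent; it must run before any query.
    static bool
    SupportSSE() {
        return cpuinfo_initialize() && cpuinfo_has_x86_sse();
    }
    static bool
    SupportAVX2() {
        return cpuinfo_initialize() && cpuinfo_has_x86_avx2();
    }
    static bool
    SupportAVX512() {
        return cpuinfo_initialize() && cpuinfo_has_x86_avx512f();
    }
};

}  // namespace vsag
```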
diff --git a/src/simd/avx.cpp b/src/simd/avx.cpp
index ed7b3fca..486d0e1a 100644
--- a/src/simd/avx.cpp
+++ b/src/simd/avx.cpp
@@ -209,7 +209,7 @@ FP32ComputeIP(const float* query, const float* codes, uint64_t dim) {
     ip += sse::FP32ComputeIP(query + n * 8, codes + n * 8, dim - n * 8);
     return ip;
 #else
-    return vsag::Generic::FP32ComputeIP(query, codes, dim);
+    return vsag::generic::FP32ComputeIP(query, codes, dim);
 #endif
 }
 
@@ -235,7 +235,7 @@ FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim) {
     l2 += sse::FP32ComputeL2Sqr(query + n * 8, codes + n * 8, dim - n * 8);
     return l2;
 #else
-    return vsag::Generic::FP32ComputeL2Sqr(query, codes, dim);
+    return vsag::generic::FP32ComputeL2Sqr(query, codes, dim);
 #endif
 }
 
@@ -275,7 +275,7 @@ SQ8ComputeIP(const float* query,
     finalResult += sse::SQ8ComputeIP(query + i, codes + i, lowerBound + i, diff + i, dim - i);
     return finalResult;
 #else
-    return Generic::SQ8ComputeIP(query, codes, lowerBound, diff, dim);
+    return generic::SQ8ComputeIP(query, codes, lowerBound, diff, dim);
 #endif
 }
 
@@ -320,7 +320,7 @@ SQ8ComputeL2Sqr(const float* query,
     result += sse::SQ8ComputeL2Sqr(query + i, codes + i, lowerBound + i, diff + i, dim - i);
     return result;
 #else
-    return vsag::Generic::SQ8ComputeL2Sqr(query, codes, lowerBound, diff, dim);  // TODO
+    return vsag::generic::SQ8ComputeL2Sqr(query, codes, lowerBound, diff, dim);  // TODO
 #endif
 }
 
@@ -364,7 +364,7 @@ SQ8ComputeCodesIP(const uint8_t* codes1,
     result += sse::SQ8ComputeCodesIP(codes1 + i, codes2 + i, lowerBound + i, diff + i, dim - i);
     return result;
 #else
-    return Generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim);
+    return generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim);
 #endif
 }
 
@@ -407,7 +407,7 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1,
     result += sse::SQ8ComputeCodesL2Sqr(codes1 + i, codes2 + i, lowerBound + i, diff + i, dim - i);
     return result;
 #else
-    return Generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim);
+    return generic::SQ8ComputeCodesL2Sqr(codes1, codes2, lowerBound, diff, dim);
 #endif
 }
 
@@ -511,7 +511,7 @@ SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t
     result += static_cast(sse::SQ8UniformComputeCodesIP(codes1 + d, codes2 + d, dim - d));
     return static_cast(result);
 #else
-    return sse::S8UniformComputeCodesIP(codes1, codes2, dim);
+    return sse::SQ8UniformComputeCodesIP(codes1, codes2, dim);
 #endif
 }
 
diff --git a/src/simd/avx512_test.cpp b/src/simd/avx512_test.cpp
index 155ea12b..39f27559 100644
--- a/src/simd/avx512_test.cpp
+++ b/src/simd/avx512_test.cpp
@@ -22,10 +22,11 @@
 #include "cpuinfo.h"
 #include "fixtures.h"
 #include "simd.h"
+#include "simd_status.h"
 
 TEST_CASE("avx512 int8", "[ut][simd][avx]") {
 #if defined(ENABLE_AVX512)
-    if (cpuinfo_has_x86_sse()) {
+    if (vsag::SimdStatus::SupportAVX512()) {
         auto common_dims = fixtures::get_common_used_dims();
         for (size_t dim : common_dims) {
             auto vectors = fixtures::generate_vectors(2, dim);
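The avx.cpp hunks are small but load-bearing: every fallback path now names symbols that actually exist (`generic::` rather than `Generic::`, `SQ8UniformComputeCodesIP` rather than `S8Uniform...`), and `SQ8ComputeCodesL2Sqr` no longer falls back to the inner-product kernel by mistake; the old spellings only survived because those `#else` branches used to be preprocessed away. For orientation, the sketch below shows the general shape of such a kernel under plain AVX, with hypothetical names, not the project's actual implementation: the wide loop consumes 8 floats per iteration and the tail (or a non-AVX build) is delegated to a narrower routine, which is exactly where the wrong names were hiding.

```cpp
// Simplified, hypothetical AVX inner-product kernel; build with -mavx only,
// matching the CMake change above. The real avx.cpp differs in its details.
#include <immintrin.h>

#include <cstdint>

#include "fp32_simd.h"  // declares vsag::generic::FP32ComputeIP after this patch

static float
FP32ComputeIPAvxSketch(const float* query, const float* codes, uint64_t dim) {
    const uint64_t n = dim / 8;
    __m256 acc = _mm256_setzero_ps();
    for (uint64_t i = 0; i < n; ++i) {
        __m256 q = _mm256_loadu_ps(query + i * 8);
        __m256 c = _mm256_loadu_ps(codes + i * 8);
        acc = _mm256_add_ps(acc, _mm256_mul_ps(q, c));  // no FMA: stays legal on AVX-only CPUs
    }
    alignas(32) float buf[8];
    _mm256_store_ps(buf, acc);
    float ip = buf[0] + buf[1] + buf[2] + buf[3] + buf[4] + buf[5] + buf[6] + buf[7];
    // Tail elements go through the narrower implementation, as in the real file.
    return ip + vsag::generic::FP32ComputeIP(query + n * 8, codes + n * 8, dim - n * 8);
}
```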
diff --git a/src/simd/avx_test.cpp b/src/simd/avx_test.cpp
index ea9ee7ee..dc787821 100644
--- a/src/simd/avx_test.cpp
+++ b/src/simd/avx_test.cpp
@@ -14,16 +14,15 @@
 // limitations under the License.
 
 #include 
-#include 
 
 #include "./simd.h"
 #include "catch2/catch_approx.hpp"
-#include "cpuinfo.h"
 #include "fixtures.h"
+#include "simd_status.h"
 
 TEST_CASE("avx l2 simd16", "[ut][simd][avx]") {
-#if defined(ENABLE_AVX)
-    if (cpuinfo_has_x86_sse()) {
+#if defined(ENABLE_AVX2)
+    if (vsag::SimdStatus::SupportAVX2()) {
         size_t dim = 16;
 
         auto vectors = fixtures::generate_vectors(2, dim);
@@ -37,8 +36,8 @@ TEST_CASE("avx l2 simd16", "[ut][simd][avx]") {
 }
 
 TEST_CASE("avx ip simd16", "[ut][simd][avx]") {
-#if defined(ENABLE_AVX)
-    if (cpuinfo_has_x86_sse()) {
+#if defined(ENABLE_AVX2)
+    if (vsag::SimdStatus::SupportAVX2()) {
         size_t dim = 16;
 
         auto vectors = fixtures::generate_vectors(2, dim);
@@ -52,8 +51,8 @@ TEST_CASE("avx ip simd16", "[ut][simd][avx]") {
 }
 
 TEST_CASE("avx pq calculation", "[ut][simd][avx]") {
-#if defined(ENABLE_AVX)
-    if (cpuinfo_has_x86_avx2()) {
+#if defined(ENABLE_AVX2)
+    if (vsag::SimdStatus::SupportAVX2()) {
         size_t dim = 256;
         float single_dim_value = 0.571;
         float results_expected[256]{0.0f};
diff --git a/src/simd/fp32_simd.h b/src/simd/fp32_simd.h
index 024cfce4..47027f3d 100644
--- a/src/simd/fp32_simd.h
+++ b/src/simd/fp32_simd.h
@@ -27,32 +27,26 @@ float
 FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
 }  // namespace generic
 
-#if defined(ENABLE_SSE)
 namespace sse {
 float
 FP32ComputeIP(const float* query, const float* codes, uint64_t dim);
 float
 FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
 }  // namespace sse
-#endif
 
-#if defined(ENABLE_AVX2)
 namespace avx2 {
 float
 FP32ComputeIP(const float* query, const float* codes, uint64_t dim);
 float
 FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
 }  // namespace avx2
-#endif
 
-#if defined(ENABLE_AVX512)
 namespace avx512 {
 float
 FP32ComputeIP(const float* query, const float* codes, uint64_t dim);
 float
 FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
 }  // namespace avx512
-#endif
 
 using FP32ComputeType = float (*)(const float* query, const float* codes, uint64_t dim);
 extern FP32ComputeType FP32ComputeIP;
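fp32_simd.h (and the other kernel headers below) now declare the sse/avx2/avx512 namespaces unconditionally and expose the entry points as `extern` function pointers such as `FP32ComputeType FP32ComputeIP`. Declarations cost nothing on builds or CPUs without those ISAs; only the definitions live in the specially-flagged translation units, and the pointer picks one implementation once, at load time. The binding code for the FP32 functions is not shown in this diff, so the following is an illustration that mirrors the `SetNormalize()` pattern appearing later in the patch, with the `#if` and runtime checks nested the other way round so a build without a given ISA simply falls through to the next one.

```cpp
// Illustrative binding for the FP32 entry point, modeled on normalize.cpp in
// this patch; the project's actual fp32_simd.cpp may be organized differently.
#include "fp32_simd.h"
#include "simd_status.h"

namespace vsag {

static FP32ComputeType
SetFP32ComputeIP() {
#if defined(ENABLE_AVX512)
    if (SimdStatus::SupportAVX512()) {
        return avx512::FP32ComputeIP;
    }
#endif
#if defined(ENABLE_AVX2)
    if (SimdStatus::SupportAVX2()) {
        return avx2::FP32ComputeIP;
    }
#endif
#if defined(ENABLE_SSE)
    if (SimdStatus::SupportSSE()) {
        return sse::FP32ComputeIP;
    }
#endif
    return generic::FP32ComputeIP;
}

FP32ComputeType FP32ComputeIP = SetFP32ComputeIP();

}  // namespace vsag
```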
diff --git a/src/simd/fp32_simd_test.cpp b/src/simd/fp32_simd_test.cpp
index a20e744f..36a41da4 100644
--- a/src/simd/fp32_simd_test.cpp
+++ b/src/simd/fp32_simd_test.cpp
@@ -18,6 +18,7 @@
 #include "catch2/benchmark/catch_benchmark.hpp"
 #include "catch2/catch_test_macros.hpp"
 #include "fixtures.h"
+#include "simd_status.h"
 
 using namespace vsag;
 
@@ -33,15 +34,22 @@ namespace avx2 = sse;
 namespace avx512 = avx2;
 #endif
 
-#define TEST_ACCURACY(Func) \
-    { \
-        auto gt = generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
-        auto sse = sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
-        auto avx2 = avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
-        auto avx512 = avx512::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+#define TEST_ACCURACY(Func) \
+    { \
+        float gt, sse, avx2, avx512; \
+        gt = generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
+        if (SimdStatus::SupportSSE()) { \
+            sse = sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
+        } \
+        if (SimdStatus::SupportAVX2()) { \
+            avx2 = avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
+        } \
+        if (SimdStatus::SupportAVX512()) { \
+            avx512 = avx512::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+        } \
     };
 
 TEST_CASE("FP32 SIMD Compute", "[FP32SIMD]") {
diff --git a/src/simd/normalize.cpp b/src/simd/normalize.cpp
new file mode 100644
index 00000000..74ccca03
--- /dev/null
+++ b/src/simd/normalize.cpp
@@ -0,0 +1,60 @@
+
+// Copyright 2024-present the vsag project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "normalize.h"
+
+#include "simd_status.h"
+
+namespace vsag {
+
+static NormalizeType
+SetNormalize() {
+    if (SimdStatus::SupportAVX512()) {
+#if defined(ENABLE_AVX512)
+        return avx512::Normalize;
+#endif
+    } else if (SimdStatus::SupportAVX2()) {
+#if defined(ENABLE_AVX2)
+        return avx2::Normalize;
+#endif
+    } else if (SimdStatus::SupportSSE()) {
+#if defined(ENABLE_SSE)
+        return sse::Normalize;
+#endif
+    }
+    return generic::Normalize;
+}
+NormalizeType Normalize = SetNormalize();
+
+static DivScalarType
+SetDivScalar() {
+    if (SimdStatus::SupportAVX512()) {
+#if defined(ENABLE_AVX512)
+        return avx512::DivScalar;
+#endif
+    } else if (SimdStatus::SupportAVX2()) {
+#if defined(ENABLE_AVX2)
+        return avx2::DivScalar;
+#endif
+    } else if (SimdStatus::SupportSSE()) {
+#if defined(ENABLE_SSE)
+        return sse::DivScalar;
+#endif
+    }
+    return generic::DivScalar;
+}
+DivScalarType DivScalar = SetDivScalar();
+
+}  // namespace vsag
\ No newline at end of file
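With normalize.cpp in place, `vsag::Normalize` and `vsag::DivScalar` are ordinary function pointers bound once, during static initialization, by `SetNormalize()`/`SetDivScalar()` above (falling back to `generic::` when the preferred ISA is not compiled in), so call sites keep the exact syntax they had when these were inline functions in normalize.h. A minimal usage sketch; the `NormalizeCopy` helper is made up for illustration:

```cpp
// Callers just call through the pointer; dispatch already happened at startup.
#include <cstdint>
#include <vector>

#include "normalize.h"

float
NormalizeCopy(const std::vector<float>& src, std::vector<float>& dst) {
    dst.resize(src.size());
    // Routed to avx512/avx2/sse/generic::Normalize depending on the host CPU.
    return vsag::Normalize(src.data(), dst.data(), static_cast<uint64_t>(src.size()));
}
```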
diff --git a/src/simd/normalize.h b/src/simd/normalize.h
index ab59c934..30268ba1 100644
--- a/src/simd/normalize.h
+++ b/src/simd/normalize.h
@@ -26,7 +26,6 @@ float
 Normalize(const float* from, float* to, uint64_t dim);
 }  // namespace generic
 
-#if defined(ENABLE_SSE)
 namespace sse {
 void
 DivScalar(const float* from, float* to, uint64_t dim, float scalar);
@@ -34,9 +33,7 @@ DivScalar(const float* from, float* to, uint64_t dim, float scalar);
 float
 Normalize(const float* from, float* to, uint64_t dim);
 }  // namespace sse
-#endif
 
-#if defined(ENABLE_AVX2)
 namespace avx2 {
 void
 DivScalar(const float* from, float* to, uint64_t dim, float scalar);
@@ -44,9 +41,7 @@ DivScalar(const float* from, float* to, uint64_t dim, float scalar);
 float
 Normalize(const float* from, float* to, uint64_t dim);
 }  // namespace avx2
-#endif
 
-#if defined(ENABLE_AVX512)
 namespace avx512 {
 void
 DivScalar(const float* from, float* to, uint64_t dim, float scalar);
@@ -54,34 +49,10 @@ DivScalar(const float* from, float* to, uint64_t dim, float scalar);
 float
 Normalize(const float* from, float* to, uint64_t dim);
 }  // namespace avx512
-#endif
 
-inline void
-DivScalar(const float* from, float* to, uint64_t dim, float scalar) {
-#if defined(ENABLE_AVX512)
-    avx512::DivScalar(from, to, dim, scalar);
-#endif
-#if defined(ENABLE_AVX2)
-    avx2::DivScalar(from, to, dim, scalar);
-#endif
-#if defined(ENABLE_SSE)
-    sse::DivScalar(from, to, dim, scalar);
-#endif
-    generic::DivScalar(from, to, dim, scalar);
-}
-
-inline float
-Normalize(const float* from, float* to, uint64_t dim) {
-#if defined(ENABLE_AVX512)
-    return avx512::Normalize(from, to, dim);
-#endif
-#if defined(ENABLE_AVX2)
-    return avx2::Normalize(from, to, dim);
-#endif
-#if defined(ENABLE_SSE)
-    return sse::Normalize(from, to, dim);
-#endif
-    return generic::Normalize(from, to, dim);
-}
+using NormalizeType = float (*)(const float* from, float* to, uint64_t dim);
+extern NormalizeType Normalize;
+using DivScalarType = void (*)(const float* from, float* to, uint64_t dim, float scalar);
+extern DivScalarType DivScalar;
 
 }  // namespace vsag
diff --git a/src/simd/normalize_test.cpp b/src/simd/normalize_test.cpp
index 4f3c6ccd..6b174016 100644
--- a/src/simd/normalize_test.cpp
+++ b/src/simd/normalize_test.cpp
@@ -18,6 +18,7 @@
 #include "catch2/benchmark/catch_benchmark.hpp"
 #include "catch2/catch_test_macros.hpp"
 #include "fixtures.h"
+#include "simd_status.h"
 
 using namespace vsag;
 
@@ -41,21 +42,36 @@ TEST_CASE("Normalize SIMD Compute", "[simd]") {
         std::vector tmp_value(dim * 4);
         for (uint64_t i = 0; i < count; ++i) {
             auto gt = generic::Normalize(vec1.data() + i * dim, tmp_value.data(), dim);
-            auto sse = sse::Normalize(vec1.data() + i * dim, tmp_value.data() + dim, dim);
-            auto avx2 = avx2::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 2, dim);
-            auto avx512 = avx512::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 3, dim);
-            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse));
-            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2));
-            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512));
-            for (int j = 0; j < dim; ++j) {
-                REQUIRE(fixtures::dist_t(tmp_value[j]) == fixtures::dist_t(tmp_value[j + dim]));
-                REQUIRE(fixtures::dist_t(tmp_value[j]) == fixtures::dist_t(tmp_value[j + dim * 2]));
-                REQUIRE(fixtures::dist_t(tmp_value[j]) == fixtures::dist_t(tmp_value[j + dim * 3]));
+            if (SimdStatus::SupportSSE()) {
+                auto sse = sse::Normalize(vec1.data() + i * dim, tmp_value.data() + dim, dim);
+                REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse));
+                for (int j = 0; j < dim; ++j) {
+                    REQUIRE(fixtures::dist_t(tmp_value[j]) ==
+                            fixtures::dist_t(tmp_value[j + dim * 1]));
+                }
+            }
+            if (SimdStatus::SupportAVX2()) {
+                auto avx2 = avx2::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 2, dim);
+                REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2));
+                for (int j = 0; j < dim; ++j) {
+                    REQUIRE(fixtures::dist_t(tmp_value[j]) ==
+                            fixtures::dist_t(tmp_value[j + dim * 2]));
+                }
+            }
+            if (SimdStatus::SupportAVX512()) {
+                auto avx512 =
+                    avx512::Normalize(vec1.data() + i * dim, tmp_value.data() + dim * 3, dim);
+                REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512));
+                for (int j = 0; j < dim; ++j) {
+                    REQUIRE(fixtures::dist_t(tmp_value[j]) ==
+                            fixtures::dist_t(tmp_value[j + dim * 3]));
+                }
+            }
         }
     }
 }
+
 #define BENCHMARK_SIMD_COMPUTE(Simd, Comp) \
     BENCHMARK_ADVANCED(#Simd #Comp) { \
         for (int i = 0; i < count; ++i) { \
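The test changes above all follow one pattern: the `#if defined(ENABLE_*)` guard keeps the file buildable when an ISA is compiled out, and the new `SimdStatus::Support*()` check keeps the assertion from executing instructions the host CPU lacks, so `make test_asan` can run on any CI runner. Condensed to a single standalone case it looks like the sketch below; the 128-dim setup and the values are placeholders.

```cpp
// Minimal Catch2 sketch of the compile-time + runtime double guard used above.
#include <cstdint>
#include <vector>

#include "catch2/catch_approx.hpp"
#include "catch2/catch_test_macros.hpp"
#include "fp32_simd.h"
#include "simd_status.h"

TEST_CASE("runtime-guarded avx512 kernel", "[ut][simd]") {
#if defined(ENABLE_AVX512)
    if (vsag::SimdStatus::SupportAVX512()) {
        const uint64_t dim = 128;
        std::vector<float> a(dim, 0.5F), b(dim, 0.25F);
        auto gt = vsag::generic::FP32ComputeIP(a.data(), b.data(), dim);
        auto val = vsag::avx512::FP32ComputeIP(a.data(), b.data(), dim);
        REQUIRE(gt == Catch::Approx(val));
    }
#endif
}
```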
diff --git a/src/simd/simd.cpp b/src/simd/simd.cpp
index 5c5f6f07..ac3cfae8 100644
--- a/src/simd/simd.cpp
+++ b/src/simd/simd.cpp
@@ -126,13 +126,15 @@ GetInnerProductDistanceFunc(size_t dim) {
 
 DistanceFunc
 GetINT8InnerProductDistanceFunc(size_t dim) {
+    if (SimdStatus::SupportAVX512()) {
 #ifdef ENABLE_AVX512
-    if (dim > 32) {
-        return vsag::INT8InnerProduct512ResidualsAVX512Distance;
-    } else if (dim > 16) {
-        return vsag::INT8InnerProduct256ResidualsAVX512Distance;
-    }
+        if (dim > 32) {
+            return vsag::INT8InnerProduct512ResidualsAVX512Distance;
+        } else if (dim > 16) {
+            return vsag::INT8InnerProduct256ResidualsAVX512Distance;
+        }
 #endif
+    }
     return vsag::INT8InnerProductDistance;
 }
 
diff --git a/src/simd/sq4_simd.h b/src/simd/sq4_simd.h
index 9d390f90..9e0715c0 100644
--- a/src/simd/sq4_simd.h
+++ b/src/simd/sq4_simd.h
@@ -45,7 +45,6 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1,
                      uint64_t dim);
 }  // namespace generic
 
-#if defined(ENABLE_SSE)
 namespace sse {
 float
 SQ4ComputeIP(const float* query,
@@ -72,9 +71,7 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1,
                      const float* diff,
                      uint64_t dim);
 }  // namespace sse
-#endif
 
-#if defined(ENABLE_AVX2)
 namespace avx2 {
 float
 SQ4ComputeIP(const float* query,
@@ -101,9 +98,7 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1,
                      const float* diff,
                      uint64_t dim);
 }  // namespace avx2
-#endif
 
-#if defined(ENABLE_AVX512)
 namespace avx512 {
 float
 SQ4ComputeIP(const float* query,
@@ -130,7 +125,6 @@ SQ4ComputeCodesL2Sqr(const uint8_t* codes1,
                      const float* diff,
                      uint64_t dim);
 }  // namespace avx512
-#endif
 
 using SQ4ComputeType = float (*)(const float* query,
                                  const uint8_t* codes,
diff --git a/src/simd/sq4_uniform_simd.h b/src/simd/sq4_uniform_simd.h
index c240776b..0a6f352c 100644
--- a/src/simd/sq4_uniform_simd.h
+++ b/src/simd/sq4_uniform_simd.h
@@ -23,26 +23,20 @@ float
 SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace generic
 
-#if defined(ENABLE_SSE)
 namespace sse {
 float
 SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace sse
-#endif
 
-#if defined(ENABLE_AVX2)
 namespace avx2 {
 float
 SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace avx2
-#endif
 
-#if defined(ENABLE_AVX512)
 namespace avx512 {
 float
 SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace avx512
-#endif
 
 using SQ4UniformComputeCodesType = float (*)(const uint8_t* codes1,
                                              const uint8_t* codes2,
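Dropping the `#if` guards from sq4_simd.h and sq4_uniform_simd.h (and the sq8 headers below) is safe because a declaration emits no code: it only names a symbol. Whether the matching definition exists still depends on the build flags (the `ENABLE_*` macros and `DIST_CONTAINS_*` options above), which is why the dispatching code keeps its `#if` blocks and falls back to `generic::` when an ISA is compiled out. In miniature:

```cpp
// A declaration of an AVX512-backed function is harmless everywhere; only the
// definition (compiled in avx512.cpp with its own flags) contains AVX512 code.
#include <cstdint>

namespace vsag::avx512 {
float
SQ4UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
}  // namespace vsag::avx512

// Callers still need a SimdStatus check (or a bound function pointer) before
// actually jumping into it on a host without AVX512.
```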
diff --git a/src/simd/sq4_uniform_simd_test.cpp b/src/simd/sq4_uniform_simd_test.cpp
index b074c6a3..73e37660 100644
--- a/src/simd/sq4_uniform_simd_test.cpp
+++ b/src/simd/sq4_uniform_simd_test.cpp
@@ -17,9 +17,9 @@
 
 #include 
 
-#include "../logger.h"
 #include "catch2/benchmark/catch_benchmark.hpp"
 #include "fixtures.h"
+#include "simd_status.h"
 
 using namespace vsag;
 
@@ -35,17 +35,25 @@ namespace avx2 = sse;
 namespace avx512 = avx2;
 #endif
 
-#define TEST_ACCURACY(Func) \
-    { \
-        auto gt = \
-            generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        auto sse = sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        auto avx2 = avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        auto avx512 = \
-            avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+#define TEST_ACCURACY(Func) \
+    { \
+        auto gt = \
+            generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+        if (SimdStatus::SupportSSE()) { \
+            auto sse = \
+                sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
+        } \
+        if (SimdStatus::SupportAVX2()) { \
+            auto avx2 = \
+                avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
+        } \
+        if (SimdStatus::SupportAVX512()) { \
+            auto avx512 = \
+                avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+        } \
     }
 
 TEST_CASE("SQ4 Uniform SIMD Compute Codes", "[SQ4 Uniform SIMD]") {
diff --git a/src/simd/sq8_simd.h b/src/simd/sq8_simd.h
index 1bba09fc..e6c58858 100644
--- a/src/simd/sq8_simd.h
+++ b/src/simd/sq8_simd.h
@@ -74,7 +74,6 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1,
                      const float* diff,
                      uint64_t dim);
 }  // namespace sse
 #endif
 
-#if defined(ENABLE_AVX2)
 namespace avx2 {
 float
 SQ8ComputeIP(const float* query,
@@ -101,9 +100,7 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1,
                      const float* diff,
                      uint64_t dim);
 }  // namespace avx2
-#endif
 
-#if defined(ENABLE_AVX512)
 namespace avx512 {
 float
 SQ8ComputeIP(const float* query,
@@ -130,7 +127,6 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1,
                      const float* diff,
                      uint64_t dim);
 }  // namespace avx512
-#endif
 
 using SQ8ComputeType = float (*)(const float* query,
                                  const uint8_t* codes,
diff --git a/src/simd/sq8_simd_test.cpp b/src/simd/sq8_simd_test.cpp
index 77cdb9f2..e399714c 100644
--- a/src/simd/sq8_simd_test.cpp
+++ b/src/simd/sq8_simd_test.cpp
@@ -18,6 +18,7 @@
 #include "catch2/benchmark/catch_benchmark.hpp"
 #include "catch2/catch_test_macros.hpp"
 #include "fixtures.h"
+#include "simd_status.h"
 
 using namespace vsag;
 
@@ -33,19 +34,25 @@ namespace avx2 = sse;
 namespace avx512 = avx2;
 #endif
 
-#define TEST_ACCURACY(Func) \
-    { \
-        auto gt = generic::Func( \
-            vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
-        auto sse = \
-            sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
-        auto avx2 = \
-            avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
-        auto avx512 = avx512::Func( \
-            vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+#define TEST_ACCURACY(Func) \
+    { \
+        auto gt = generic::Func( \
+            vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
+        if (SimdStatus::SupportSSE()) { \
+            auto sse = sse::Func( \
+                vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
+        } \
+        if (SimdStatus::SupportAVX2()) { \
+            auto avx2 = avx2::Func( \
+                vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
+        } \
+        if (SimdStatus::SupportAVX512()) { \
+            auto avx512 = avx512::Func( \
+                vec1.data() + i * dim, vec2.data() + i * dim, lb.data(), diff.data(), dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+        } \
     }
 
 TEST_CASE("SQ8 SIMD Compute Codes", "[SQ8 SIMD]") {
diff --git a/src/simd/sq8_uniform_simd.h b/src/simd/sq8_uniform_simd.h
index 4f6dca78..ae8050c3 100644
--- a/src/simd/sq8_uniform_simd.h
+++ b/src/simd/sq8_uniform_simd.h
@@ -23,26 +23,20 @@ float
 SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace generic
 
-#if defined(ENABLE_SSE)
 namespace sse {
 float
 SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace sse
-#endif
 
-#if defined(ENABLE_AVX2)
 namespace avx2 {
 float
 SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace avx2
-#endif
 
-#if defined(ENABLE_AVX512)
 namespace avx512 {
 float
 SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t dim);
 }  // namespace avx512
-#endif
 
 using SQ8UniformComputeCodesType = float (*)(const uint8_t* codes1,
                                              const uint8_t* codes2,
diff --git a/src/simd/sq8_uniform_simd_test.cpp b/src/simd/sq8_uniform_simd_test.cpp
index ff06275a..13db3c19 100644
--- a/src/simd/sq8_uniform_simd_test.cpp
+++ b/src/simd/sq8_uniform_simd_test.cpp
@@ -19,6 +19,7 @@
 
 #include "catch2/benchmark/catch_benchmark.hpp"
 #include "fixtures.h"
+#include "simd_status.h"
 
 using namespace vsag;
 
@@ -34,17 +35,25 @@ namespace avx2 = sse;
 namespace avx512 = avx2;
 #endif
 
-#define TEST_ACCURACY(Func) \
-    { \
-        auto gt = \
-            generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        auto sse = sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        auto avx2 = avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        auto avx512 = \
-            avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
-        REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+#define TEST_ACCURACY(Func) \
+    { \
+        auto gt = \
+            generic::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+        if (SimdStatus::SupportSSE()) { \
+            auto sse = \
+                sse::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
+        } \
+        if (SimdStatus::SupportAVX2()) { \
+            auto avx2 = \
+                avx2::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
+        } \
+        if (SimdStatus::SupportAVX512()) { \
+            auto avx512 = \
+                avx512::Func(codes1.data() + i * code_size, codes2.data() + i * code_size, dim); \
+            REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
+        } \
     }
 
 TEST_CASE("SQ8 Uniform SIMD Compute Codes", "[SQ8 Uniform SIMD]") {