Skip to content

Commit

Permalink
fix illegal instruction on platform has avx only
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Dec 2, 2024
1 parent 0a98c28 commit c7a00b3
Show file tree
Hide file tree
Showing 19 changed files with 254 additions and 142 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/DCO.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: DCO Check

on:
pull_request:
branches: [ "main" ]

jobs:
dco-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Validate DCO
uses: tisonkun/[email protected]
39 changes: 39 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: build & test

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
clang-format-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Clang format
run: sudo apt-get install clang-format
- name: Run Clang format check
run: make fmt

build:
runs-on: ubuntu-latest
container:
image: vsaglib/vsag:ubuntu
steps:
- uses: actions/checkout@v4
- name: load cache
uses: actions/[email protected]
with:
path: ./build/
key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }}
- name: make asan
run: make asan
- name: save cache
uses: actions/[email protected]
with:
path: ./build/
key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }}
- name: make test
run: make test_asan_parallel
15 changes: 10 additions & 5 deletions src/simd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ set (SIMD_SRCS
sq4_simd.cpp
sq4_uniform_simd.cpp
sq8_uniform_simd.cpp
sse.cpp
avx.cpp
avx512.cpp
)
if (DIST_CONTAINS_SSE)
set (SIMD_SRCS ${SIMD_SRCS} sse.cpp)
Expand All @@ -18,11 +21,13 @@ if (DIST_CONTAINS_AVX)
set (SIMD_SRCS ${SIMD_SRCS} avx.cpp)
set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx")
endif ()
if (DIST_CONTAINS_AVX2)
set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
endif ()
# FIXME(LHT): cause illegal instruction on platform which has avx only
#if (DIST_CONTAINS_AVX2)
# set_source_files_properties (avx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
#endif ()
if (DIST_CONTAINS_AVX512)
set (SIMD_SRCS ${SIMD_SRCS} avx512.cpp)
set (SIMD_SRCS ${SIMD_SRCS} avx512.cpp
normalize.cpp)
set_source_files_properties (
avx512.cpp
PROPERTIES
Expand Down Expand Up @@ -50,7 +55,7 @@ endmacro ()

simd_add_definitions (DIST_CONTAINS_SSE -DENABLE_SSE=1)
simd_add_definitions (DIST_CONTAINS_AVX -DENABLE_AVX=1)
simd_add_definitions (DIST_CONTAINS_AVX2 -DENABLE_AVX2=1)
#simd_add_definitions (DIST_CONTAINS_AVX2 -DENABLE_AVX2=1)
simd_add_definitions (DIST_CONTAINS_AVX512 -DENABLE_AVX512=1)

target_link_libraries (simd PRIVATE cpuinfo)
Expand Down
14 changes: 7 additions & 7 deletions src/simd/avx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ FP32ComputeIP(const float* query, const float* codes, uint64_t dim) {
ip += sse::FP32ComputeIP(query + n * 8, codes + n * 8, dim - n * 8);
return ip;
#else
return vsag::Generic::FP32ComputeIP(query, codes, dim);
return vsag::generic::FP32ComputeIP(query, codes, dim);
#endif
}

Expand All @@ -235,7 +235,7 @@ FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim) {
l2 += sse::FP32ComputeL2Sqr(query + n * 8, codes + n * 8, dim - n * 8);
return l2;
#else
return vsag::Generic::FP32ComputeL2Sqr(query, codes, dim);
return vsag::generic::FP32ComputeL2Sqr(query, codes, dim);
#endif
}

Expand Down Expand Up @@ -275,7 +275,7 @@ SQ8ComputeIP(const float* query,
finalResult += sse::SQ8ComputeIP(query + i, codes + i, lowerBound + i, diff + i, dim - i);
return finalResult;
#else
return Generic::SQ8ComputeIP(query, codes, lowerBound, diff, dim);
return generic::SQ8ComputeIP(query, codes, lowerBound, diff, dim);
#endif
}

Expand Down Expand Up @@ -320,7 +320,7 @@ SQ8ComputeL2Sqr(const float* query,
result += sse::SQ8ComputeL2Sqr(query + i, codes + i, lowerBound + i, diff + i, dim - i);
return result;
#else
return vsag::Generic::SQ8ComputeL2Sqr(query, codes, lowerBound, diff, dim); // TODO
return vsag::generic::SQ8ComputeL2Sqr(query, codes, lowerBound, diff, dim); // TODO
#endif
}

Expand Down Expand Up @@ -364,7 +364,7 @@ SQ8ComputeCodesIP(const uint8_t* codes1,
result += sse::SQ8ComputeCodesIP(codes1 + i, codes2 + i, lowerBound + i, diff + i, dim - i);
return result;
#else
return Generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim);
return generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim);
#endif
}

Expand Down Expand Up @@ -407,7 +407,7 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1,
result += sse::SQ8ComputeCodesL2Sqr(codes1 + i, codes2 + i, lowerBound + i, diff + i, dim - i);
return result;
#else
return Generic::SQ8ComputeCodesIP(codes1, codes2, lowerBound, diff, dim);
return generic::SQ8ComputeCodesL2Sqr(codes1, codes2, lowerBound, diff, dim);
#endif
}

Expand Down Expand Up @@ -511,7 +511,7 @@ SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t
result += static_cast<int32_t>(sse::SQ8UniformComputeCodesIP(codes1 + d, codes2 + d, dim - d));
return static_cast<float>(result);
#else
return sse::S8UniformComputeCodesIP(codes1, codes2, dim);
return sse::SQ8UniformComputeCodesIP(codes1, codes2, dim);
#endif
}

Expand Down
3 changes: 2 additions & 1 deletion src/simd/avx512_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
#include "cpuinfo.h"
#include "fixtures.h"
#include "simd.h"
#include "simd_status.h"

TEST_CASE("avx512 int8", "[ut][simd][avx]") {
#if defined(ENABLE_AVX512)
if (cpuinfo_has_x86_sse()) {
if (vsag::SimdStatus::SupportAVX512()) {
auto common_dims = fixtures::get_common_used_dims();
for (size_t dim : common_dims) {
auto vectors = fixtures::generate_vectors(2, dim);
Expand Down
15 changes: 7 additions & 8 deletions src/simd/avx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@
// limitations under the License.

#include <catch2/catch_test_macros.hpp>
#include <cstdint>

#include "./simd.h"
#include "catch2/catch_approx.hpp"
#include "cpuinfo.h"
#include "fixtures.h"
#include "simd_status.h"

TEST_CASE("avx l2 simd16", "[ut][simd][avx]") {
#if defined(ENABLE_AVX)
if (cpuinfo_has_x86_sse()) {
#if defined(ENABLE_AVX2)
if (vsag::SimdStatus::SupportAVX2()) {
size_t dim = 16;
auto vectors = fixtures::generate_vectors(2, dim);

Expand All @@ -37,8 +36,8 @@ TEST_CASE("avx l2 simd16", "[ut][simd][avx]") {
}

TEST_CASE("avx ip simd16", "[ut][simd][avx]") {
#if defined(ENABLE_AVX)
if (cpuinfo_has_x86_sse()) {
#if defined(ENABLE_AVX2)
if (vsag::SimdStatus::SupportAVX2()) {
size_t dim = 16;
auto vectors = fixtures::generate_vectors(2, dim);

Expand All @@ -52,8 +51,8 @@ TEST_CASE("avx ip simd16", "[ut][simd][avx]") {
}

TEST_CASE("avx pq calculation", "[ut][simd][avx]") {
#if defined(ENABLE_AVX)
if (cpuinfo_has_x86_avx2()) {
#if defined(ENABLE_AVX2)
if (vsag::SimdStatus::SupportAVX2()) {
size_t dim = 256;
float single_dim_value = 0.571;
float results_expected[256]{0.0f};
Expand Down
6 changes: 0 additions & 6 deletions src/simd/fp32_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,26 @@ float
FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
} // namespace generic

#if defined(ENABLE_SSE)
namespace sse {
float
FP32ComputeIP(const float* query, const float* codes, uint64_t dim);
float
FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
} // namespace sse
#endif

#if defined(ENABLE_AVX2)
namespace avx2 {
float
FP32ComputeIP(const float* query, const float* codes, uint64_t dim);
float
FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
} // namespace avx2
#endif

#if defined(ENABLE_AVX512)
namespace avx512 {
float
FP32ComputeIP(const float* query, const float* codes, uint64_t dim);
float
FP32ComputeL2Sqr(const float* query, const float* codes, uint64_t dim);
} // namespace avx512
#endif

using FP32ComputeType = float (*)(const float* query, const float* codes, uint64_t dim);
extern FP32ComputeType FP32ComputeIP;
Expand Down
26 changes: 17 additions & 9 deletions src/simd/fp32_simd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "catch2/benchmark/catch_benchmark.hpp"
#include "catch2/catch_test_macros.hpp"
#include "fixtures.h"
#include "simd_status.h"

using namespace vsag;

Expand All @@ -33,15 +34,22 @@ namespace avx2 = sse;
namespace avx512 = avx2;
#endif

#define TEST_ACCURACY(Func) \
{ \
auto gt = generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
auto sse = sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
auto avx2 = avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
auto avx512 = avx512::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
#define TEST_ACCURACY(Func) \
{ \
float gt, sse, avx2, avx512; \
gt = generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
if (SimdStatus::SupportSSE()) { \
sse = sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(sse)); \
} \
if (SimdStatus::SupportAVX2()) { \
avx2 = avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx2)); \
} \
if (SimdStatus::SupportAVX512()) { \
avx512 = avx512::Func(vec1.data() + i * dim, vec2.data() + i * dim, dim); \
REQUIRE(fixtures::dist_t(gt) == fixtures::dist_t(avx512)); \
} \
};

TEST_CASE("FP32 SIMD Compute", "[FP32SIMD]") {
Expand Down
60 changes: 60 additions & 0 deletions src/simd/normalize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "normalize.h"

#include "simd_status.h"

namespace vsag {

static NormalizeType
SetNormalize() {
if (SimdStatus::SupportAVX512()) {
#if defined(ENABLE_AVX512)
return avx512::Normalize;
#endif
} else if (SimdStatus::SupportAVX2()) {
#if defined(ENABLE_AVX2)
return avx2::Normalize;
#endif
} else if (SimdStatus::SupportSSE()) {
#if defined(ENABLE_SSE)
return sse::Normalize;
#endif
}
return generic::Normalize;
}
NormalizeType Normalize = SetNormalize();

static DivScalarType
SetDivScalar() {
if (SimdStatus::SupportAVX512()) {
#if defined(ENABLE_AVX512)
return avx512::DivScalar;
#endif
} else if (SimdStatus::SupportAVX2()) {
#if defined(ENABLE_AVX2)
return avx2::DivScalar;
#endif
} else if (SimdStatus::SupportSSE()) {
#if defined(ENABLE_SSE)
return sse::DivScalar;
#endif
}
return generic::DivScalar;
}
DivScalarType DivScalar = SetDivScalar();

} // namespace vsag
Loading

0 comments on commit c7a00b3

Please sign in to comment.