Skip to content

Commit

Permalink
C code tests & avx512f f16 implement (#183)
Browse files Browse the repository at this point in the history
* test: add tests for c code

Signed-off-by: usamoi <[email protected]>

* fix: relax EPSILON for tests

Signed-off-by: usamoi <[email protected]>

---------

Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Dec 15, 2023
1 parent 2869fbd commit c50912e
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
- name: Test
run: |
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu -- --nocapture
- name: Install release
run: ./scripts/ci_install.sh
- name: Sqllogictest
Expand Down
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions crates/c/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ name = "c"
version.workspace = true
edition.workspace = true

[dependencies]
half = { version = "~2.3", features = ["use-intrinsics"] }
[dev-dependencies]
half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] }
detect = { path = "../detect" }
rand = "0.8.5"

[build-dependencies]
cc = "1.0"
78 changes: 76 additions & 2 deletions crates/c/src/c.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy);
}
return (float)(_mm512_reduce_add_ph(xy) /
sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
{
float rxy = _mm512_reduce_add_ph(xy);
float rxx = _mm512_reduce_add_ph(xx);
float ryy = _mm512_reduce_add_ph(yy);
return rxy / sqrt(rxx * ryy);
}
}

__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
Expand Down Expand Up @@ -74,6 +78,76 @@ v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
return (float)_mm512_reduce_add_ph(dd);
}

__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_cosine_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 xy = _mm512_set1_ps(0);
__m512 xx = _mm512_set1_ps(0);
__m512 yy = _mm512_set1_ps(0);

while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
xx = _mm512_fmadd_ps(x, x, xx);
yy = _mm512_fmadd_ps(y, y, yy);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ps(x, y, xy);
xx = _mm512_fmadd_ps(x, x, xx);
yy = _mm512_fmadd_ps(y, y, yy);
}
{
float rxy = _mm512_reduce_add_ps(xy);
float rxx = _mm512_reduce_add_ps(xx);
float ryy = _mm512_reduce_add_ps(yy);
return rxy / sqrt(rxx * ryy);
}
}

__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_dot_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 xy = _mm512_set1_ps(0);

while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ps(x, y, xy);
}
return _mm512_reduce_add_ps(xy);
}

__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_sl2_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 dd = _mm512_set1_ps(0);

while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
__m512 d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
__m512 d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
return _mm512_reduce_add_ps(dd);
}

__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0;
Expand Down
3 changes: 3 additions & 0 deletions crates/c/src/c.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);
Expand Down
17 changes: 3 additions & 14 deletions crates/c/src/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,10 @@ extern "C" {
pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
}

// `compiler_builtin` defines `__extendhfsf2` with integer calling convention.
// However C compilers links `__extendhfsf2` with floating calling convention.
// The code should be removed once Rust offically supports `f16`.

#[cfg(target_arch = "x86_64")]
#[no_mangle]
#[linkage = "external"]
extern "C" fn __extendhfsf2(f: f64) -> f32 {
unsafe {
let f: half::f16 = std::mem::transmute_copy(&f);
f.to_f32()
}
}
126 changes: 126 additions & 0 deletions crates/c/tests/x86_64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#![cfg(target_arch = "x86_64")]

#[test]
fn test_v_f16_cosine() {
const EPSILON: f32 = f16::EPSILON.to_f32_const();
use half::f16;
unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut xy = 0.0f32;
let mut xx = 0.0f32;
let mut yy = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
xy += x * y;
xx += x * x;
yy += y * y;
}
xy / (xx * yy).sqrt()
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}

#[test]
fn test_v_f16_dot() {
const EPSILON: f32 = 1.0f32;
use half::f16;
unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut xy = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
xy += x * y;
}
xy
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}

#[test]
fn test_v_f16_sl2() {
const EPSILON: f32 = 1.0f32;
use half::f16;
unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut dd = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
let d = x - y;
dd += d * d;
}
dd
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}
8 changes: 8 additions & 0 deletions crates/detect/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "detect"
version.workspace = true
edition.workspace = true

[dependencies]
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
ctor = "0.2.6"
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn ctor_v4() {
ATOMIC_V4.store(test_v4(), Ordering::Relaxed);
}

pub fn _detect_v4() -> bool {
pub fn detect_v4() -> bool {
ATOMIC_V4.load(Ordering::Relaxed)
}

Expand Down
3 changes: 1 addition & 2 deletions crates/service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ bincode.workspace = true
half.workspace = true
num-traits.workspace = true
c = { path = "../c" }
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
detect = { path = "../detect" }
rand = "0.8.5"
crc32fast = "1.3.2"
crossbeam = "0.8.2"
Expand All @@ -32,7 +32,6 @@ arc-swap = "1.6.0"
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
serde_with = "3.4.0"
multiversion = "0.7.3"
ctor = "0.2.6"

[target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0"
Expand Down
Loading

0 comments on commit c50912e

Please sign in to comment.