Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Sync itrex1.3 #12

Merged
merged 4 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ option(NE_AVX512_VBMI "neural_engine: enable AVX512-VBMI"
option(NE_AVX512_VNNI "neural_engine: enable AVX512-VNNI" OFF)
option(NE_FMA "neural_engine: enable FMA" ON)
option(NE_AMX "neural_engine: enable AMX" OFF)

# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
option(NE_F16C "neural_engine: enable F16C" ON)
endif()
option(NE_F16C "neural_engine: enable F16C" ON)
airMeng marked this conversation as resolved.
Show resolved Hide resolved

# 3rd party libs
option(NE_ONEDNN "neural_engine: use oneDNN" ON)
airMeng marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -93,6 +89,8 @@ if (NE_GELU_VEC)
endif()
option(NE_PYTHON_API "neural_engine: use python api" OFF)
option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON)
option(BUILD_SHARED_LIBS "If build as shared libs" ON)

if (NE_SIMD_VEC_DOT_F16)
add_compile_definitions(NE_SIMD_VEC_DOT_F16)
endif()
Expand All @@ -103,7 +101,6 @@ endif()

if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS NOMINMAX)

if (BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
Expand Down
21 changes: 12 additions & 9 deletions bestla/jblas/jit_blas_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class SchedulerBase : public Scheduler2D {
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
}
const float DensityThres = 32;
const float DensityThres = 16;
static size_t constexpr ReservedSize = 32ULL * 1024ULL;

virtual float calculate_score() {
Expand Down Expand Up @@ -364,7 +364,7 @@ class SchedulerKBlock : public Scheduler2D {
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
}
const float DensityThres = 32;
const float DensityThres = 16;

float calculate_score() {
int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? mThdSize[1] : _GemmCore_T::PREFERRED_N;
Expand Down Expand Up @@ -489,13 +489,14 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
this->mL2Use += static_cast<size_t>(blks) * (this->mBlock[1] + this->mStep[0]) *
(sizeof(float) + sizeof(int8_t) + sizeof(float)); // scale+zp+reduce
assert(this->mL2Use <= this->mL2Size - ReservedSize);
assert(this->mBlock[0]>0);
assert(this->mBlock[1]>0);
assert(this->mBlock[2]>0);
assert(this->mBlock[0] > 0);
assert(this->mBlock[1] > 0);
assert(this->mBlock[2] > 0);
assert(this->mBlock[2] % _GemmCore_T::KTILE == 0);
}

protected:
const float DensityThres = 32;
const float DensityThres = 16;
static size_t constexpr ReservedSize = 32ULL * 1024ULL;

void cache_blocking_compute() override {
Expand Down Expand Up @@ -529,6 +530,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
(this->mStep[0] * this->mEleSize[0] +
float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock +
this->mBlock[1] * this->mEleSize[1]));
if (rawk < this->mKBlock) {
rawk = static_cast<int>((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] -
1 * CorSize * (this->mStep[0] + this->mBlock[1])) /
(this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1]));
}
rawk = std::min(rawk, this->mSizePadded[2]);
this->mBlock[2] = utils::padto_le(rawk, this->mStep[2]);
if (this->mBlock[2] > this->mKBlock) {
Expand Down Expand Up @@ -569,9 +575,6 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
this->mBlock[2] = static_cast<int>(getMaxK(this->mBlock[1]));
this->mBlock[2] = utils::padto_le(this->mBlock[2], this->mStep[2]);
this->mBlock[2] = std::min(mKBlock, this->mBlock[2]);
auto tmp = utils::updiv(mKBlock, this->mBlock[2]);
while (mKBlock % tmp != 0) tmp++; // TODO(Yu) optimize
this->mBlock[2] = utils::downdiv(mKBlock, tmp);
}
}

Expand Down
20 changes: 16 additions & 4 deletions bestla/jblas/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,14 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
for (; j < align_col; j += 8) quant();
for (; j < col; j++) {
auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type);
if constexpr (std::is_same_v<_S_T, utils::f8>) {
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
} else if constexpr (std::is_same_v<_S_T, float>) {
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
if constexpr (WITH_SCALE) {
if constexpr (std::is_same_v<_S_T, utils::f8>) {
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
} else if constexpr (std::is_same_v<_S_T, float>) {
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
}
} else {
dstptr[i * ld_dst + j] = fp_v;
}
}
}
Expand Down Expand Up @@ -636,6 +640,14 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(
vzps[iv] = _mm256_cvtepi8_epi32(tmp);
}
}
auto rowre = row - irow;
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter16 = 0; iter16 < Loop16; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
for (int iterr = 0; iterr < UnrollRow; iterr++)
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
}
for (; irow < row; irow++) {
if constexpr (_NCOL == 24) {
pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
Expand Down
51 changes: 33 additions & 18 deletions bestla/jblas/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,28 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr,
vzps[iv] = _mm512_cvtepi8_epi32(tmp);
}
}
}
for (; irow < row; irow++) {
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
if constexpr (_IS_SYM) {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
} else {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
auto rowre = row - irow;
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter64 = 0; iter64 < Loop64; iter64++) {
pad_bit4(tmpbuf + iter64 * 64, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 32 * iter64), zmm_mask,
LoadMask64);
}
for (int iterr = 0; iterr < UnrollRow; iterr++) {
if constexpr (_IS_SYM) {
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr);
} else {
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps);
}
}
}
for (; irow < row; irow++) {
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
if constexpr (_IS_SYM) {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
} else {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
}
}
}
return JblasSuccess;
Expand Down Expand Up @@ -565,7 +580,7 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
auto quant = [&](__mmask16 mask) {
__m128i f8_src;
auto sign_revert =
_mm512_cvtepi8_epi32(_mm_mask_loadu_epi8(f8_src, mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
_mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
auto e_revert = sign_revert;
auto mantissa_revert = sign_revert;
sign_revert = _mm512_slli_epi32(sign_revert, 24);
Expand Down Expand Up @@ -888,10 +903,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
zmm2 = _mm512_add_ps(zmm2, zmm_zp);
zmm3 = _mm512_add_ps(zmm3, zmm_zp);
} else {
mask4 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
mask5 = _mm512_cmplt_ps_mask(zmm1, zmm_v0);
mask6 = _mm512_cmplt_ps_mask(zmm2, zmm_v0);
mask7 = _mm512_cmplt_ps_mask(zmm3, zmm_v0);
mask4 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
mask5 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 1);
mask6 = _mm512_cmp_ps_mask(zmm2, zmm_v0, 1);
mask7 = _mm512_cmp_ps_mask(zmm3, zmm_v0, 1);

zmm0 = _mm512_abs_ps(zmm0);
zmm1 = _mm512_abs_ps(zmm1);
Expand All @@ -908,10 +923,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
zmm5 = _mm512_sub_ps(zmm1, sub_v);
zmm6 = _mm512_sub_ps(zmm2, sub_v);
zmm7 = _mm512_sub_ps(zmm3, sub_v);
mask0 = _mm512_cmple_ps_mask(zmm4, zmm_v0);
mask1 = _mm512_cmple_ps_mask(zmm5, zmm_v0);
mask2 = _mm512_cmple_ps_mask(zmm6, zmm_v0);
mask3 = _mm512_cmple_ps_mask(zmm7, zmm_v0);
mask0 = _mm512_cmp_ps_mask(zmm4, zmm_v0, 2);
mask1 = _mm512_cmp_ps_mask(zmm5, zmm_v0, 2);
mask2 = _mm512_cmp_ps_mask(zmm6, zmm_v0, 2);
mask3 = _mm512_cmp_ps_mask(zmm7, zmm_v0, 2);
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
xmm1 = _mm_mask_blend_epi8(mask1, xmm1, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
xmm2 = _mm_mask_blend_epi8(mask2, xmm2, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
Expand Down Expand Up @@ -949,7 +964,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
auto zp = _mm512_set1_ps(0.8480964004993439f);
zmm0 = _mm512_add_ps(zmm0, zp);
} else {
mask1 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
mask1 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
zmm0 = _mm512_abs_ps(zmm0);
}
constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8;
Expand All @@ -959,7 +974,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]);
if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]);
zmm1 = _mm512_sub_ps(zmm0, sub_v);
mask0 = _mm512_cmple_ps_mask(zmm1, zmm_v0);
mask0 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 2);
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
zmm0 = _mm512_mask_add_ps(zmm0, mask0, zmm0, avoid_double_cmp);
}
Expand Down
56 changes: 39 additions & 17 deletions bestla/jblas/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,25 +230,47 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) {
dstptr[7] = tmp;
}

inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) {
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
auto tmp = static_cast<int>(src32 & 0xf);
dstptr[0] = static_cast<int8_t>(tmp);
tmp = static_cast<int>(src32 & 0xf0) >> 4;
dstptr[1] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf00) >> 8);
dstptr[2] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf000) >> 12);
dstptr[3] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf0000) >> 16);
dstptr[4] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf00000) >> 20);
dstptr[5] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf000000) >> 24);
dstptr[6] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf0000000) >> 28);
dstptr[7] = static_cast<int8_t>(tmp);
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::S4_FULLRANGE>(int8_t* dstptr, int8_t* srcptr) {
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
auto tmp = static_cast<int8_t>(src32 & 0xf);
dstptr[0] = tmp - 8;
tmp = static_cast<int8_t>(src32 & 0xf0) >> 4;
dstptr[1] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf00) >> 8);
dstptr[2] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf000) >> 12);
dstptr[3] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf0000) >> 16);
dstptr[4] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf00000) >> 20);
dstptr[5] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf000000) >> 24);
dstptr[6] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf0000000) >> 28);
dstptr[7] = tmp - 8;
convert_s4_s8_8_lowbits(dstptr, srcptr);
for (size_t i = 0; i < 8; i++) {
dstptr[i] -= 8;
}
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_BNB>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_NF4>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_E2M1>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

template <JBLAS_DTYPE S4_T>
Expand Down
20 changes: 18 additions & 2 deletions neural_speed/cmake/Common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,25 @@ function(add_executable_w_warning TARGET)
warning_check(${TARGET})
endfunction()

function(add_library_w_warning TARGET)
add_library(${TARGET} STATIC ${ARGN})
function(add_library_w_warning_ TARGET)
add_library(${TARGET} ${ARGN})
set_target_properties(${TARGET} PROPERTIES C_STANDARD 11 C_STANDARD_REQUIRED ON C_EXTENSIONS OFF)
set_target_properties(${TARGET} PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF)
warning_check(${TARGET})
endfunction()

function(add_library_w_warning TARGET)
add_library_w_warning_(${TARGET} STATIC ${ARGN})
endfunction()

function(add_shared_library_w_warning TARGET)
add_library_w_warning_(${TARGET} SHARED ${ARGN})
endfunction()

function(add_shareable_library_w_warning TARGET)
if (BUILD_SHARED_LIBS)
add_library_w_warning_(${TARGET} SHARED ${ARGN})
else()
add_library_w_warning_(${TARGET} STATIC ${ARGN})
endif()
endfunction()
3 changes: 3 additions & 0 deletions neural_speed/cmake/ISA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

if (MSVC)
if(NE_F16C)
add_compile_definitions(__F16C__)
endif()
if (NE_AVX512)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
Expand Down
2 changes: 1 addition & 1 deletion neural_speed/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ find_package(Threads REQUIRED)
file(GLOB layers_srcs "layers/*.cpp")
set(sources ne_layers.c ${layers_srcs})

add_library_w_warning(ne_layers "${sources}")
add_shareable_library_w_warning(ne_layers "${sources}")

target_include_directories(ne_layers PUBLIC .)
target_compile_features(ne_layers PUBLIC c_std_11) # don't bump
Expand Down
4 changes: 2 additions & 2 deletions neural_speed/scripts/convert_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,8 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)


SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL}

SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL,
'BF16': DT_BF16}

def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
header_size, = struct.unpack('<Q', fp.read(8))
Expand Down