Skip to content

Commit

Permalink
platform change
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Nov 8, 2024
1 parent adce284 commit 8050f0a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,6 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasQNBitGemmDispatchNeon;

//
// Check if the processor supports ASIMD dot product instructions.
Expand Down Expand Up @@ -561,6 +560,7 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
this->QNBitGemmDispatch = &MlasQNBitGemmDispatchNeon;
}

#if defined(__linux__)
Expand Down
56 changes: 28 additions & 28 deletions onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ index e46105324a..414c46a1ce 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -82,6 +82,9 @@ Abstract:

#if (!defined(_MSC_VER)) || (_MSC_VER >= 1930)
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC)
+#if !defined(USE_VSINPU)
Expand All @@ -26,16 +26,16 @@ index e46105324a..414c46a1ce 100644
// Had to temporary disable fp16 under APPLE ARM64, as compiling
// the source files require a hardware specific compilation flag.
@@ -90,6 +93,7 @@ Abstract:

#define MLAS_F16VEC_INTRINSICS_SUPPORTED

+#endif //
#endif //
#endif // ARM64
#endif // Visual Studio 16 or earlier does not support fp16 intrinsic
@@ -1635,6 +1639,7 @@ MlasHalfGemmConvertPackB(
);

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)
/**
Expand All @@ -46,7 +46,7 @@ index e46105324a..414c46a1ce 100644
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
#endif
+#endif

/**
* @brief Indirect Depthwise convolution for fp16
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
Expand All @@ -55,7 +55,7 @@ index 4239e2ecae..3df7e5573d 100644
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -361,6 +361,7 @@ size_t
#else

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)
typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)(
Expand All @@ -66,7 +66,7 @@ index 4239e2ecae..3df7e5573d 100644
);
#endif
+#endif

typedef
size_t
@@ -763,8 +765,10 @@ extern "C" {
Expand All @@ -82,13 +82,13 @@ index 4239e2ecae..3df7e5573d 100644
MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd;
@@ -899,8 +903,10 @@ extern "C" {
#define MLAS_QGEMM_THREAD_COMPLEXITY 65536

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)
#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024))
#endif
+#endif

//
// Single-threaded single precision matrix/matrix multiply operation.
@@ -2570,4 +2576,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH
Expand All @@ -103,16 +103,16 @@ index ed437f20f7..8c9d0a75fd 100644
@@ -20,7 +20,7 @@ Abstract:
#include <thread>
#include <mutex>
-#if defined(MLAS_TARGET_POWER)

-#if defined(MLAS_TARGET_POWER)
+#if defined(MLAS_TARGET_POWER)
#if defined(__linux__)
#include <sys/auxv.h>
#elif defined(_AIX)
@@ -536,7 +536,7 @@ Return Value:
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
this->QNBitGemmDispatch = &MlasQNBitGemmDispatchNeon;
}

-#if defined(__linux__)
+#if defined(__linux__) && !defined(USE_VSINPU)
//
Expand All @@ -124,12 +124,12 @@ index de7fd72fad..4f75dbd6fa 100644
+++ b/onnxruntime/core/mlas/lib/sbgemm.h
@@ -31,6 +31,7 @@ Abstract:
--*/

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)

#pragma once

@@ -396,4 +397,5 @@ MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t Bat
}
);
Expand All @@ -141,7 +141,7 @@ index 6a71283f9d..d8bd348854 100644
--- a/onnxruntime/core/providers/cpu/math/matmul.cc
+++ b/onnxruntime/core/providers/cpu/math/matmul.cc
@@ -132,7 +132,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {

return Status::OK();
}
-#if defined(__aarch64__) && defined(__linux__)
Expand Down Expand Up @@ -187,18 +187,18 @@ index b9bbe36583..2f570502d2 100644
+++ b/onnxruntime/core/providers/cpu/math/matmul.h
@@ -31,8 +31,10 @@ class MatMul<float> final : public OpKernel {
trans_batch_b_ = trans_batch_b_attr != 0;

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)
auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16);
use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported();
+#endif
#endif
}

@@ -57,12 +59,14 @@ class MatMul<float> final : public OpKernel {
bool trans_batch_b_;

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)
// fastmath mode state
Expand All @@ -209,20 +209,20 @@ index b9bbe36583..2f570502d2 100644
#endif
+#endif
};

} // namespace onnxruntime
diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp
index f85fe97776..6039b7fa9e 100644
--- a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp
+++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp
@@ -16,6 +16,7 @@ Abstract:
--*/

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)

#include "test_sbgemm.h"

@@ -138,4 +139,5 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe
}
return SBGemmRegistLongExecute() > 0;
Expand All @@ -235,15 +235,15 @@ index 13701e2e3d..7e432f53c2 100644
+++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h
@@ -16,6 +16,7 @@ Abstract:
--*/

#if defined(__aarch64__) && defined(__linux__)
+#if !defined(USE_VSINPU)

#pragma once

@@ -278,4 +279,5 @@ class MlasSBGemmTest : public MlasTestBase {
}
};

+#endif
#endif // defined(__aarch64__) && defined(__linux__)

0 comments on commit 8050f0a

Please sign in to comment.