Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Nov 1, 2024
1 parent 6df1676 commit 4233379
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
34 changes: 18 additions & 16 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -713,24 +713,26 @@ void
int8_t ZeroPoint);

typedef
void (MLASCALL MLAS_QUANTIZE_SOFTMAX_I8_KERNEL)(
size_t D,
const int8_t* Xdata,
int8_t* Ydata,
const float* LookupTable,
float Yscale,
int8_t YZeroPoint,
float* Buff);
void
(MLASCALL MLAS_QUANTIZE_SOFTMAX_I8_KERNEL)(
size_t D,
const int8_t* Xdata,
int8_t* Ydata,
const float* LookupTable,
float Yscale,
int8_t YZeroPoint,
float* Buff);

typedef
void (MLASCALL MLAS_QUANTIZE_SOFTMAX_U8_KERNEL)(
size_t D,
const uint8_t* Xdata,
uint8_t* Ydata,
const float* LookupTable,
float Yscale,
uint8_t YZeroPoint,
float* Buff);
void
(MLASCALL MLAS_QUANTIZE_SOFTMAX_U8_KERNEL)(
size_t D,
const uint8_t* Xdata,
uint8_t* Ydata,
const float* LookupTable,
float Yscale,
uint8_t YZeroPoint,
float* Buff);

template<typename InputType, typename FilterType>
struct MLAS_QUANT_KERNEL
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/qsoftmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ struct MLAS_QSOFTMAX_WORK_BLOCK {
};

static void BuildLookupTable(gsl::span<float> table,
const float x_scale,
size_t reduce_len, bool is_signed) {
const float x_scale,
size_t reduce_len, bool is_signed) {
// make sure sum(exp(x)) < max<T>()
double bit_shift = log(std::numeric_limits<float>::max() / reduce_len);
double reserve_bit = std::is_same_v<float, float> ? 5 : 3;
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/mlas/lib/qsoftmax_kernel_naive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Module Name:

#include "mlasi.h"

void MlasQuantizeSoftmaxU8KernelNaive(size_t D, const uint8_t* x_data, uint8_t* y_data, const float* lookup_table,
float y_scale, uint8_t yzp, float*) {
void MLASCALL MlasQuantizeSoftmaxU8KernelNaive(size_t D, const uint8_t* x_data, uint8_t* y_data,
const float* lookup_table, float y_scale, uint8_t yzp, float*) {
constexpr size_t N = 1;
const auto c_y_scale = y_scale;
const auto c_y_zp = yzp;
Expand Down Expand Up @@ -66,8 +66,8 @@ void MlasQuantizeSoftmaxU8KernelNaive(size_t D, const uint8_t* x_data, uint8_t*
}
}

void MlasQuantizeSoftmaxI8KernelNaive(size_t D, const int8_t* x_data, int8_t* y_data, const float* lookup_table,
float y_scale, int8_t yzp, float*) {
void MLASCALL MlasQuantizeSoftmaxI8KernelNaive(size_t D, const int8_t* x_data, int8_t* y_data,
const float* lookup_table, float y_scale, int8_t yzp, float*) {
constexpr size_t N = 1;
const auto c_y_scale = y_scale;
const auto c_y_zp = yzp;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/mlas/bench/bench_qsoftmax.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "bench_util.h"
#include "core/mlas/lib/mlasi.h"
#include "core/util/thread_utils.h"

static const std::vector<std::string> qsoftmax_bench_arg_names = {"N", "D", "is_signed"};
//(const void* Input, void* Output, size_t N, size_t D, const float* LoopupTable,float Scale, int ZeroPoint, bool is_signed, MLAS_THREADPOOL* ThreadPool);
Expand Down

0 comments on commit 4233379

Please sign in to comment.