From d32d2fbfb7e81e158e06258c22dedd624aaa7ea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 14 May 2019 22:04:18 +0200 Subject: [PATCH] Fix svm runtime (replaces #91, #33) (#101) * Fix svm runtime --- onnxruntime/core/providers/cpu/ml/ml_common.h | 81 ++++--- .../core/providers/cpu/ml/svmclassifier.cc | 210 +++++++++--------- .../core/providers/cpu/ml/svmclassifier.h | 27 ++- .../providers/cpu/ml/svmclassifier_test.cc | 106 ++++++++- 4 files changed, 262 insertions(+), 162 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 3e8c32f979931..b7a2f71d2cc7d 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -287,39 +287,56 @@ static inline void ComputeSoftmaxZero(std::vector& values) { template void write_scores(std::vector& scores, POST_EVAL_TRANSFORM post_transform, int64_t write_index, Tensor* Z, int add_second_class) { - if (post_transform == POST_EVAL_TRANSFORM::PROBIT && scores.size() == 1) { - scores[0] = ComputeProbit(scores[0]); - } else if (scores.size() >= 2) { //multiclass - if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) { - for (float& score : scores) { - score = ComputeLogistic(score); - } - } else if (post_transform == POST_EVAL_TRANSFORM::SOFTMAX) { - ComputeSoftmax(scores); - } else if (post_transform == POST_EVAL_TRANSFORM::SOFTMAX_ZERO) { - ComputeSoftmaxZero(scores); + if (scores.size() >= 2) { + switch (post_transform) { + case POST_EVAL_TRANSFORM::PROBIT: + for (float& score : scores) + score = ComputeProbit(score); + break; + case POST_EVAL_TRANSFORM::LOGISTIC: + for (float& score : scores) + score = ComputeLogistic(score); + break; + case POST_EVAL_TRANSFORM::SOFTMAX: + ComputeSoftmax(scores); + break; + case POST_EVAL_TRANSFORM::SOFTMAX_ZERO: + ComputeSoftmaxZero(scores); + break; + default: + case POST_EVAL_TRANSFORM::NONE: + break; } - } else { //binary case - if (add_second_class == 0 && scores.size() == 1) { //0=all positive weights, winning class is positive - scores.push_back(scores[0]); - scores[0] = 1.f - scores[0]; //put opposite score in positive slot - } else if (add_second_class == 1 && scores.size() == 1) { //1 = all positive weights, winning class is negative - scores.push_back(scores[0]); - scores[0] = 1.f - scores[0]; //put opposite score in positive slot - } else if (add_second_class == 2 && scores.size() == 1) { //2 = mixed weights, winning class is positive - if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) { - scores.push_back(ComputeLogistic(scores[0])); - scores[0] = ComputeLogistic(-scores[0]); - } else { - scores.push_back(scores[0]); - scores[0] = -scores[0]; - } - } else if (add_second_class == 3 && scores.size() == 1) { //3 = mixed weights, winning class is negative - if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) { - scores.push_back(ComputeLogistic(scores[0])); - scores[0] = ComputeLogistic(-scores[0]); - } else { - scores.push_back(-scores[0]); + } else if (scores.size() == 1) { //binary case + if (post_transform == POST_EVAL_TRANSFORM::PROBIT) { + scores[0] = ComputeProbit(scores[0]); + } else { + switch (add_second_class) { + case 0: //0=all positive weights, winning class is positive + scores.push_back(scores[0]); + scores[0] = 1.f - scores[0]; //put opposite score in positive slot + break; + case 1: //1 = all positive weights, winning class is negative + scores.push_back(scores[0]); + scores[0] = 1.f - scores[0]; //put opposite score in positive slot + break; + case 2: //2 = mixed weights, winning class is positive + if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) { + scores.push_back(ComputeLogistic(scores[0])); //ml_logit(scores[k]); + scores[0] = ComputeLogistic(-scores[0]); + } else { + scores.push_back(scores[0]); + scores[0] = -scores[0]; + } + break; + case 3: //3 = mixed weights, winning class is negative + if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) { + scores.push_back(ComputeLogistic(scores[0])); //ml_logit(scores[k]); + scores[0] = ComputeLogistic(-scores[0]); + } else { + scores.push_back(-scores[0]); + } + break; } } } diff --git a/onnxruntime/core/providers/cpu/ml/svmclassifier.cc b/onnxruntime/core/providers/cpu/ml/svmclassifier.cc index 7624d5c52cede..3bc845fa2044e 100644 --- a/onnxruntime/core/providers/cpu/ml/svmclassifier.cc +++ b/onnxruntime/core/providers/cpu/ml/svmclassifier.cc @@ -75,6 +75,32 @@ SVMClassifier::SVMClassifier(const OpKernelInfo& info) } } +template +int _set_score_svm(Tensor* Y, float max_weight, const int64_t maxclass, const int64_t n, + POST_EVAL_TRANSFORM post_transform_, const std::vector& proba_, bool weights_are_all_positive_, + const std::vector& classlabels, LabelType posclass, LabelType negclass) { + int write_additional_scores = -1; + auto output_data = Y->template MutableData(); + if (classlabels.size() == 2) { + write_additional_scores = post_transform_ == POST_EVAL_TRANSFORM::NONE ? 2 : 0; + if (proba_.size() == 0) { + if (weights_are_all_positive_ && max_weight >= 0.5) + output_data[n] = classlabels[1]; + else if (max_weight > 0 && !weights_are_all_positive_) + output_data[n] = classlabels[1]; + else + output_data[n] = classlabels[maxclass]; + } else { + output_data[n] = classlabels[maxclass]; + } + } else if (max_weight > 0) { + output_data[n] = posclass; + } else { + output_data[n] = negclass; + } + return write_additional_scores; +} + template Status SVMClassifier::Compute(OpKernelContext* ctx) const { const Tensor* X = ctx->Input(0); @@ -83,40 +109,51 @@ Status SVMClassifier::Compute(OpKernelContext* ctx) const { int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0]; Tensor* Y = ctx->Output(0, TensorShape({N})); - Tensor* Z; - std::vector dims; - if (mode_ == SVM_TYPE::SVM_SVC && proba_.size() == 0) - dims = {static_cast(N), static_cast(class_count_ * (class_count_ - 1) / 2)}; - else - dims = {static_cast(N), static_cast(class_count_)}; - Z = ctx->Output(1, TensorShape(dims)); + int64_t nb_columns = class_count_; + if (proba_.size() == 0 && vector_count_ > 0) { + if (class_count_ > 2) + nb_columns = class_count_ * (class_count_ - 1) / 2; + else + nb_columns = 2; + } + + std::vector dims{N, nb_columns}; + Tensor* Z = ctx->Output(1, TensorShape(dims)); - const auto* x_data = X->template Data(); + const T* x_data = X->template Data(); int64_t zindex = 0; for (int64_t n = 0; n < N; n++) //for each example { int64_t current_weight_0 = n * stride; int64_t maxclass = -1; - double maxweight = 0.f; std::vector decisions; std::vector scores; std::vector kernels; std::vector votes; - if (mode_ == SVM_TYPE::SVM_SVC) { + if (vector_count_ == 0 && mode_ == SVM_TYPE::SVM_LINEAR) { + for (int64_t j = 0; j < class_count_; j++) { //for each class + auto val = kernel_dot(x_data, current_weight_0, coefficients_, feature_count_ * j, + feature_count_, get_kernel_type()); + val += rho_[0]; + scores.push_back(val); + } + } else { + if (vector_count_ == 0) + return Status(common::ONNXRUNTIME, common::FAIL, "No support vectors."); + int evals = 0; + for (int64_t j = 0; j < vector_count_; j++) { - float val = kernel_dot(x_data, current_weight_0, support_vectors_, feature_count_ * j, feature_count_, get_kernel_type()); + auto val = kernel_dot(x_data, current_weight_0, support_vectors_, feature_count_ * j, + feature_count_, get_kernel_type()); kernels.push_back(val); } - for (int64_t j = 0; j < class_count_; j++) { - votes.push_back(0); - } - int evals = 0; - for (int64_t i = 0; i < class_count_; i++) { //for each class - for (int64_t j = i + 1; j < class_count_; j++) { //for each class - float sum = 0; + votes.resize(class_count_, 0); + for (int64_t i = 0; i < class_count_; i++) { // for each class + for (int64_t j = i + 1; j < class_count_; j++) { // for each class + double sum = 0; int64_t start_index_i = starting_vector_[i]; // *feature_count_; int64_t start_index_j = starting_vector_[j]; // *feature_count_; @@ -125,120 +162,71 @@ Status SVMClassifier::Compute(OpKernelContext* ctx) const { int64_t pos1 = (vector_count_) * (j - 1); int64_t pos2 = (vector_count_) * (i); - for (int64_t m = 0; m < class_i_support_count; m++) { - float val1 = coefficients_[pos1 + start_index_i + m]; - float val2 = kernels[start_index_i + m]; - sum += val1 * val2; - } - for (int64_t m = 0; m < class_j_support_count; m++) { - float val1 = coefficients_[pos2 + start_index_j + m]; - float val2 = kernels[start_index_j + m]; - sum += val1 * val2; - } + const float* val1 = &(coefficients_[pos1 + start_index_i]); + const float* val2 = &(kernels[start_index_i]); + for (int64_t m = 0; m < class_i_support_count; ++m, ++val1, ++val2) + sum += *val1 * *val2; + + val1 = &(coefficients_[pos2 + start_index_j]); + val2 = &(kernels[start_index_j]); + for (int64_t m = 0; m < class_j_support_count; ++m, ++val1, ++val2) + sum += *val1 * *val2; sum += rho_[evals]; - scores.push_back(sum); - if (sum > 0) { - votes[i]++; - } else { - votes[j]++; - } - evals++; //index into rho + scores.push_back((float)sum); + ++(votes[sum > 0 ? i : j]); + ++evals; //index into rho } } - } else if (mode_ == SVM_TYPE::SVM_LINEAR) { //liblinear - for (int64_t j = 0; j < class_count_; j++) { //for each class - float val = kernel_dot(x_data, current_weight_0, coefficients_, feature_count_ * j, feature_count_, get_kernel_type()); - val += rho_[0]; - scores.push_back(val); - } } + if (proba_.size() > 0 && mode_ == SVM_TYPE::SVM_SVC) { //compute probabilities from the scores - std::vector estimates; - std::vector probsp2; int64_t num = class_count_ * class_count_; - for (int64_t m = 0; m < num; m++) { - probsp2.push_back(0.f); //min prob - } - for (int64_t m = 0; m < class_count_; m++) { - estimates.push_back(0.f); //min prob - } + std::vector probsp2(num, 0.f); + std::vector estimates(class_count_, 0.f); int64_t index = 0; - for (int64_t i = 0; i < class_count_; i++) { - for (int64_t j = i + 1; j < class_count_; j++) { + for (int64_t i = 0; i < class_count_; ++i) { + int64_t p1 = i * class_count_ + i + 1; + int64_t p2 = (i + 1) * class_count_ + i; + for (int64_t j = i + 1; j < class_count_; ++j, ++index) { float val1 = sigmoid_probability(scores[index], proba_[index], probb_[index]); float val2 = std::max(val1, 1.0e-7f); - probsp2[i * class_count_ + j] = std::min(val2, 1 - 1.0e-7f); - probsp2[j * class_count_ + i] = 1 - probsp2[i * class_count_ + j]; - index++; + val2 = std::min(val2, 1 - 1.0e-7f); + probsp2[p1] = val2; + probsp2[p2] = 1 - val2; + ++p1; + p2 += class_count_; } } multiclass_probability(class_count_, probsp2, estimates); - //copy probabilities back into scores + // copy probabilities back into scores scores.resize(estimates.size()); - for (int64_t k = 0; k < static_cast(estimates.size()); k++) { - scores[k] = estimates[k]; - } + std::copy(estimates.begin(), estimates.end(), scores.begin()); } - int64_t maxvotes = 0; + + float max_weight = 0; if (votes.size() > 0) { - for (int64_t k = 0; k < static_cast(votes.size()); k++) { - if (votes[k] > maxvotes) { - maxvotes = votes[k]; - maxclass = k; - } - } + auto it_maxvotes = std::max_element(votes.begin(), votes.end()); + maxclass = std::distance(votes.begin(), it_maxvotes); } else { - for (int64_t k = 0; k < static_cast(scores.size()); k++) { - if (scores[k] > maxweight) { - maxclass = k; - maxweight = scores[k]; - } - } + auto it_max_weight = std::max_element(scores.begin(), scores.end()); + maxclass = std::distance(scores.begin(), it_max_weight); + max_weight = *it_max_weight; } - //write top class + + // write top class + // onnx specs expects one column per class. int write_additional_scores = -1; - if (rho_.size() == 1) //binary - { + if (rho_.size() == 1) { if (using_strings_) { - if (classlabels_strings_.size() == 2 && weights_are_all_positive_ && maxweight >= 0.5 && proba_.size() == 0) { - Y->template MutableData()[n] = classlabels_strings_[1]; //positive label - write_additional_scores = 0; - } else if (classlabels_strings_.size() == 2 && maxweight > 0 && !weights_are_all_positive_ && proba_.size() == 0) { - Y->template MutableData()[n] = classlabels_strings_[1]; //positive label - write_additional_scores = 0; - } else if (classlabels_strings_.size() == 2 && proba_.size() > 0) { //this case all classes are in their rightful spot - Y->template MutableData()[n] = classlabels_strings_[maxclass]; //whichever label - write_additional_scores = -1; - } else if (classlabels_strings_.size() == 2) { - Y->template MutableData()[n] = classlabels_strings_[0]; //negative label - write_additional_scores = 1; - } else if (maxweight > 0) { - Y->template MutableData()[n] = "1"; //positive label - } else { - Y->template MutableData()[n] = "0"; //negative label - } - } else //no strings - { - if (classlabels_ints_.size() == 2 && weights_are_all_positive_ && maxweight >= 0.5 && proba_.size() == 0) { - Y->template MutableData()[n] = classlabels_ints_[1]; //positive label - write_additional_scores = 0; - } else if (classlabels_ints_.size() == 2 && maxweight > 0 && !weights_are_all_positive_ && proba_.size() == 0) { - Y->template MutableData()[n] = classlabels_ints_[0]; //pos label - write_additional_scores = 0; - } else if (classlabels_ints_.size() == 2 && proba_.size() > 0) //this case all classes are in their rightful spot - { - Y->template MutableData()[n] = classlabels_ints_[maxclass]; //whichever label - write_additional_scores = -1; - } else if (classlabels_ints_.size() == 2) { - Y->template MutableData()[n] = classlabels_ints_[0]; //negative label - write_additional_scores = 1; - } else if (maxweight > 0) { - Y->template MutableData()[n] = 1; //positive label - } else { - Y->template MutableData()[n] = 0; //negative label - } + write_additional_scores = _set_score_svm( + Y, max_weight, maxclass, n, post_transform_, proba_, + weights_are_all_positive_, classlabels_strings_, "1", "0"); + } else { + write_additional_scores = _set_score_svm( + Y, max_weight, maxclass, n, post_transform_, proba_, + weights_are_all_positive_, classlabels_ints_, 1, 0); } } else { //multiclass if (using_strings_) { diff --git a/onnxruntime/core/providers/cpu/ml/svmclassifier.h b/onnxruntime/core/providers/cpu/ml/svmclassifier.h index 5490a44fbb333..6ef0d0f3acf12 100644 --- a/onnxruntime/core/providers/cpu/ml/svmclassifier.h +++ b/onnxruntime/core/providers/cpu/ml/svmclassifier.h @@ -30,31 +30,30 @@ class SVMCommon { KERNEL get_kernel_type() const { return kernel_type_; } float kernel_dot(const T* A, int64_t a, const std::vector& B, int64_t b, int64_t len, KERNEL k) const { - float sum = 0.f; + double sum = 0; + const T* pA = A + a; + const float* pB = B.data() + b; if (k == KERNEL::POLY) { - for (int64_t i = 0; i < len; i++) { - sum += B[b + i] * static_cast(A[a + i]); - } + for (int64_t i = len; i > 0; --i, ++pA, ++pB) + sum += *pA * *pB; sum = gamma_ * sum + coef0_; sum = std::pow(sum, degree_); } else if (k == KERNEL::SIGMOID) { - for (int64_t i = 0; i < len; i++) { - sum += B[b + i] * static_cast(A[a + i]); - } + for (int64_t i = len; i > 0; --i, ++pA, ++pB) + sum += *pA * *pB; sum = gamma_ * sum + coef0_; sum = std::tanh(sum); } else if (k == KERNEL::RBF) { - for (int64_t i = 0; i < len; i++) { - float val = static_cast(A[a + i]) - B[b + i]; - sum += (val * val); + for (int64_t i = len; i > 0; --i, ++pA, ++pB) { + double val = *pA - *pB; + sum += val * val; } sum = std::exp(-gamma_ * sum); } else if (k == KERNEL::LINEAR) { - for (int64_t i = 0; i < len; i++) { - sum += B[b + i] * static_cast(A[a + i]); - } + for (int64_t i = len; i > 0; --i, ++pA, ++pB) + sum += *pA * *pB; } - return sum; + return (float)sum; } private: diff --git a/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc b/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc index 9b75ebe5616b2..2928b39e6f5f0 100644 --- a/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc +++ b/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc @@ -10,14 +10,20 @@ namespace test { TEST(MLOpTest, SVMClassifierMulticlassSVC) { OpTester test("SVMClassifier", 1, onnxruntime::kMLDomain); - std::vector dual_coefficients = {1.14360327f, 1.95968249f, -1.175683f, -1.92760275f, -1.32575698f, -1.32575698f, 0.66332785f, 0.66242913f, 0.53120854f, 0.53510444f, -1.06631298f, -1.06631298f, 0.66332785f, 0.66242913f, 0.53120854f, 0.53510444f, 1.f, -1.f}; - std::vector support_vectors = {0.f, 0.5f, 32.f, 2.f, 2.9f, -32.f, 1.f, 1.5f, 1.f, 3.f, 13.3f, -11.f, 12.f, 12.9f, -312.f, 43.f, 413.3f, -114.f}; + std::vector dual_coefficients = {1.14360327f, 1.95968249f, -1.175683f, -1.92760275f, -1.32575698f, + -1.32575698f, 0.66332785f, 0.66242913f, 0.53120854f, 0.53510444f, + -1.06631298f, -1.06631298f, 0.66332785f, 0.66242913f, 0.53120854f, + 0.53510444f, 1.f, -1.f}; + std::vector support_vectors = {0.f, 0.5f, 32.f, 2.f, 2.9f, -32.f, 1.f, 1.5f, 1.f, 3.f, + 13.3f, -11.f, 12.f, 12.9f, -312.f, 43.f, 413.3f, -114.f}; std::vector classes = {0, 1, 2, 3}; std::vector vectors_per_class = {2, 2, 1, 1}; std::vector rho = {0.5279583f, 0.32605162f, 0.32605162f, 0.06663721f, 0.06663721f, 0.f}; std::vector kernel_params = {0.001f, 0.f, 3.f}; //gamma, coef0, degree - std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; + std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, + 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, + 11.3f, -222.f, 43.0f, 413.3f, -114.f}; std::vector predictions = {1, 1, 2, 0, 0, 0, 0, 3}; std::vector scores = { -0.956958294f, 0.799815655f, 0.799815655f, 0.988598406f, 0.988598406f, 0, @@ -47,12 +53,18 @@ TEST(MLOpTest, SVMClassifierMulticlassSVC) { TEST(MLOpTest, SVMClassifierMulticlassLinearSVC) { OpTester test("SVMClassifier", 1, onnxruntime::kMLDomain); - std::vector dual_coefficients = {-1.55181212e-01f, 2.42698956e-01f, 7.01893432e-03f, 4.07614474e-01f, -3.24927823e-02f, 2.79897536e-04f, -1.95771302e-01f, -3.52437368e-01f, -2.15973096e-02f, -4.38190277e-01f, 4.56869105e-02f, -1.29375499e-02f}; + std::vector dual_coefficients = {-1.55181212e-01f, 2.42698956e-01f, 7.01893432e-03f, + 4.07614474e-01f, -3.24927823e-02f, 2.79897536e-04f, + -1.95771302e-01f, -3.52437368e-01f, -2.15973096e-02f, + -4.38190277e-01f, 4.56869105e-02f, -1.29375499e-02f}; std::vector classes = {0, 1, 2, 3}; std::vector rho = {-0.07489691f, -0.1764396f, -0.21167431f, -0.51619097f}; std::vector kernel_params = {0.001f, 0.f, 3.f}; //gamma, coef0, degree - std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; + std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, + 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, + 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, + 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; std::vector predictions = {1, 0, 1, 1, 1, 0, 1, 0}; std::vector scores = { -0.227270544f, 0.332829535f, -0.279307127f, -0.518262208f, @@ -115,5 +127,89 @@ TEST(MLOpTest, SVMClassifierSVCProbabilities) { test.Run(); } +TEST(MLOpTest, SVMClassifierSVC) { + OpTester test("SVMClassifier", 1, onnxruntime::kMLDomain); + + std::vector coefficients = {1.14360327f, 1.95968249f, -1.175683f, -1.92760275f, -1.32575698f, -1.32575698f, + 0.66332785f, 0.66242913f, 0.53120854f, 0.53510444f, -1.06631298f, -1.06631298f, + 0.66332785f, 0.66242913f, 0.53120854f, 0.53510444f, 1.f, -1.f}; + std::vector support_vectors = {0.f, 0.5f, 32.f, 2.f, 2.9f, -32.f, + 1.f, 1.5f, 1.f, 3.f, 13.3f, -11.f, + 12.f, 12.9f, -312.f, 43.f, 413.3f, -114.f}; + std::vector rho = {0.5279583f}; + std::vector kernel_params = {0.001f, 0.f, 3.f}; //gamma, coef0, degree + std::vector classes = {0, 1}; + std::vector vectors_per_class = {3, 3}; + + std::vector X = {1.f, 0.0f, 0.4f, + 3.0f, 44.0f, -3.f, + 12.0f, 12.9f, -312.f, + 23.0f, 11.3f, -222.f, + 23.0f, 11.3f, -222.f}; + std::vector scores_predictions = { + 0.95695829391479492f, -0.95695829391479492f, + 0.1597825288772583f, -0.1597825288772583f, + 0.797798752784729f, -0.797798752784729f, + -0.52760261297225952f, 0.52760261297225952f, + -0.52760261297225952f, 0.52760261297225952f}; + std::vector class_predictions = {1, 1, 1, 0, 0}; + + test.AddAttribute("kernel_type", std::string("RBF")); + test.AddAttribute("coefficients", coefficients); + test.AddAttribute("support_vectors", support_vectors); + test.AddAttribute("vectors_per_class", vectors_per_class); + test.AddAttribute("rho", rho); + test.AddAttribute("kernel_params", kernel_params); + test.AddAttribute("classlabels_ints", classes); + + test.AddInput("X", {5, 3}, X); + test.AddOutput("Y", {5}, class_predictions); + test.AddOutput("Z", {5, 2}, scores_predictions); + + test.Run(); +} + +TEST(MLOpTest, SVMClassifierSVCDouble) { + OpTester test("SVMClassifier", 1, onnxruntime::kMLDomain); + + std::vector coefficients = {1.14360327f, 1.95968249f, -1.175683f, -1.92760275f, -1.32575698f, -1.32575698f, + 0.66332785f, 0.66242913f, 0.53120854f, 0.53510444f, -1.06631298f, -1.06631298f, + 0.66332785f, 0.66242913f, 0.53120854f, 0.53510444f, 1.f, -1.f}; + std::vector support_vectors = {0.f, 0.5f, 32.f, 2.f, 2.9f, -32.f, + 1.f, 1.5f, 1.f, 3.f, 13.3f, -11.f, + 12.f, 12.9f, -312.f, 43.f, 413.3f, -114.f}; + std::vector rho = {0.5279583f}; + std::vector kernel_params = {0.001f, 0.f, 3.f}; //gamma, coef0, degree + std::vector classes = {0, 1}; + std::vector vectors_per_class = {3, 3}; + + std::vector X = {1.f, 0.0f, 0.4f, + 3.0f, 44.0f, -3.f, + 12.0f, 12.9f, -312.f, + 23.0f, 11.3f, -222.f, + 23.0f, 11.3f, -222.f}; + std::vector scores_predictions = { + 0.95695829391479492f, -0.95695829391479492f, + 0.1597825288772583f, -0.1597825288772583f, + 0.797798752784729f, -0.797798752784729f, + -0.52760261297225952f, 0.52760261297225952f, + -0.52760261297225952f, 0.52760261297225952f}; + std::vector class_predictions = {1, 1, 1, 0, 0}; + + test.AddAttribute("kernel_type", std::string("RBF")); + test.AddAttribute("coefficients", coefficients); + test.AddAttribute("support_vectors", support_vectors); + test.AddAttribute("vectors_per_class", vectors_per_class); + test.AddAttribute("rho", rho); + test.AddAttribute("kernel_params", kernel_params); + test.AddAttribute("classlabels_ints", classes); + + test.AddInput("X", {5, 3}, X); + test.AddOutput("Y", {5}, class_predictions); + test.AddOutput("Z", {5, 2}, scores_predictions); + + test.Run(); +} + } // namespace test } // namespace onnxruntime