
Commit d32d2fb
Fix svm runtime (replaces #91, #33) (#101)
* Fix svm runtime
xadupre authored May 14, 2019
1 parent 406770c commit d32d2fb
Showing 4 changed files with 262 additions and 162 deletions.
81 changes: 49 additions & 32 deletions onnxruntime/core/providers/cpu/ml/ml_common.h
@@ -287,39 +287,56 @@ static inline void ComputeSoftmaxZero(std::vector<float>& values) {
 template <typename T>
 void write_scores(std::vector<T>& scores, POST_EVAL_TRANSFORM post_transform, int64_t write_index, Tensor* Z,
                   int add_second_class) {
-  if (post_transform == POST_EVAL_TRANSFORM::PROBIT && scores.size() == 1) {
-    scores[0] = ComputeProbit(scores[0]);
-  } else if (scores.size() >= 2) {  //multiclass
-    if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) {
-      for (float& score : scores) {
-        score = ComputeLogistic(score);
-      }
-    } else if (post_transform == POST_EVAL_TRANSFORM::SOFTMAX) {
-      ComputeSoftmax(scores);
-    } else if (post_transform == POST_EVAL_TRANSFORM::SOFTMAX_ZERO) {
-      ComputeSoftmaxZero(scores);
+  if (scores.size() >= 2) {
+    switch (post_transform) {
+      case POST_EVAL_TRANSFORM::PROBIT:
+        for (float& score : scores)
+          score = ComputeProbit(score);
+        break;
+      case POST_EVAL_TRANSFORM::LOGISTIC:
+        for (float& score : scores)
+          score = ComputeLogistic(score);
+        break;
+      case POST_EVAL_TRANSFORM::SOFTMAX:
+        ComputeSoftmax(scores);
+        break;
+      case POST_EVAL_TRANSFORM::SOFTMAX_ZERO:
+        ComputeSoftmaxZero(scores);
+        break;
+      default:
+      case POST_EVAL_TRANSFORM::NONE:
+        break;
+    }
-  } else {  //binary case
-    if (add_second_class == 0 && scores.size() == 1) {  //0=all positive weights, winning class is positive
-      scores.push_back(scores[0]);
-      scores[0] = 1.f - scores[0];  //put opposite score in positive slot
-    } else if (add_second_class == 1 && scores.size() == 1) {  //1 = all positive weights, winning class is negative
-      scores.push_back(scores[0]);
-      scores[0] = 1.f - scores[0];  //put opposite score in positive slot
-    } else if (add_second_class == 2 && scores.size() == 1) {  //2 = mixed weights, winning class is positive
-      if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) {
-        scores.push_back(ComputeLogistic(scores[0]));
-        scores[0] = ComputeLogistic(-scores[0]);
-      } else {
-        scores.push_back(scores[0]);
-        scores[0] = -scores[0];
-      }
-    } else if (add_second_class == 3 && scores.size() == 1) {  //3 = mixed weights, winning class is negative
-      if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) {
-        scores.push_back(ComputeLogistic(scores[0]));
-        scores[0] = ComputeLogistic(-scores[0]);
-      } else {
-        scores.push_back(-scores[0]);
+  } else if (scores.size() == 1) {  //binary case
+    if (post_transform == POST_EVAL_TRANSFORM::PROBIT) {
+      scores[0] = ComputeProbit(scores[0]);
+    } else {
+      switch (add_second_class) {
+        case 0:  //0=all positive weights, winning class is positive
+          scores.push_back(scores[0]);
+          scores[0] = 1.f - scores[0];  //put opposite score in positive slot
+          break;
+        case 1:  //1 = all positive weights, winning class is negative
+          scores.push_back(scores[0]);
+          scores[0] = 1.f - scores[0];  //put opposite score in positive slot
+          break;
+        case 2:  //2 = mixed weights, winning class is positive
+          if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) {
+            scores.push_back(ComputeLogistic(scores[0]));  //ml_logit(scores[k]);
+            scores[0] = ComputeLogistic(-scores[0]);
+          } else {
+            scores.push_back(scores[0]);
+            scores[0] = -scores[0];
+          }
+          break;
+        case 3:  //3 = mixed weights, winning class is negative
+          if (post_transform == POST_EVAL_TRANSFORM::LOGISTIC) {
+            scores.push_back(ComputeLogistic(scores[0]));  //ml_logit(scores[k]);
+            scores[0] = ComputeLogistic(-scores[0]);
+          } else {
+            scores.push_back(-scores[0]);
+          }
+          break;
+      }
+    }
   }
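The new binary branch expands a single raw score s into a {negative, positive} pair whose contents depend on add_second_class. A minimal standalone sketch of that mapping (an editor's illustration, not part of the commit; logistic stands in for ComputeLogistic):

    #include <cmath>
    #include <vector>

    static float logistic(float v) { return 1.f / (1.f + std::exp(-v)); }

    // Mirrors the add_second_class cases of write_scores for one score s.
    std::vector<float> expand_binary(float s, int add_second_class, bool use_logistic) {
      switch (add_second_class) {
        case 0:  // all positive weights, winning class is positive
        case 1:  // all positive weights, winning class is negative
          return {1.f - s, s};
        case 2:  // mixed weights, winning class is positive
          return use_logistic ? std::vector<float>{logistic(-s), logistic(s)}
                              : std::vector<float>{-s, s};
        case 3:  // mixed weights, winning class is negative
          return use_logistic ? std::vector<float>{logistic(-s), logistic(s)}
                              : std::vector<float>{s, -s};
        default:
          return {s};
      }
    }

For example, expand_binary(0.8f, 0, false) returns {0.2f, 0.8f}: the winning score stays in the positive column and its complement fills the negative one.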
210 changes: 99 additions & 111 deletions onnxruntime/core/providers/cpu/ml/svmclassifier.cc
@@ -75,6 +75,32 @@ SVMClassifier<T>::SVMClassifier(const OpKernelInfo& info)
   }
 }
 
+template <typename LabelType>
+int _set_score_svm(Tensor* Y, float max_weight, const int64_t maxclass, const int64_t n,
+                   POST_EVAL_TRANSFORM post_transform_, const std::vector<float>& proba_, bool weights_are_all_positive_,
+                   const std::vector<LabelType>& classlabels, LabelType posclass, LabelType negclass) {
+  int write_additional_scores = -1;
+  auto output_data = Y->template MutableData<LabelType>();
+  if (classlabels.size() == 2) {
+    write_additional_scores = post_transform_ == POST_EVAL_TRANSFORM::NONE ? 2 : 0;
+    if (proba_.size() == 0) {
+      if (weights_are_all_positive_ && max_weight >= 0.5)
+        output_data[n] = classlabels[1];
+      else if (max_weight > 0 && !weights_are_all_positive_)
+        output_data[n] = classlabels[1];
+      else
+        output_data[n] = classlabels[maxclass];
+    } else {
+      output_data[n] = classlabels[maxclass];
+    }
+  } else if (max_weight > 0) {
+    output_data[n] = posclass;
+  } else {
+    output_data[n] = negclass;
+  }
+  return write_additional_scores;
+}
+
 template <typename T>
 Status SVMClassifier<T>::Compute(OpKernelContext* ctx) const {
   const Tensor* X = ctx->Input<Tensor>(0);
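The two-class, no-probability rule inside _set_score_svm reduces to a single predicate. A compact restatement (illustrative sketch, not part of the commit; positive_label_wins is a hypothetical name):

    // classlabels[1] wins when all-positive weights put the max weight at or
    // above 0.5, or when mixed weights leave it strictly positive; every other
    // case falls back to classlabels[maxclass].
    bool positive_label_wins(float max_weight, bool weights_are_all_positive) {
      return (weights_are_all_positive && max_weight >= 0.5f) ||
             (!weights_are_all_positive && max_weight > 0.f);
    }

With more than two class labels, the helper instead writes posclass when max_weight > 0 and negclass otherwise.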
@@ -83,40 +109,51 @@ Status SVMClassifier<T>::Compute(OpKernelContext* ctx) const {
   int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
 
   Tensor* Y = ctx->Output(0, TensorShape({N}));
-  Tensor* Z;
 
-  std::vector<int64_t> dims;
-  if (mode_ == SVM_TYPE::SVM_SVC && proba_.size() == 0)
-    dims = {static_cast<int64_t>(N), static_cast<int64_t>(class_count_ * (class_count_ - 1) / 2)};
-  else
-    dims = {static_cast<int64_t>(N), static_cast<int64_t>(class_count_)};
-  Z = ctx->Output(1, TensorShape(dims));
+  int64_t nb_columns = class_count_;
+  if (proba_.size() == 0 && vector_count_ > 0) {
+    if (class_count_ > 2)
+      nb_columns = class_count_ * (class_count_ - 1) / 2;
+    else
+      nb_columns = 2;
+  }
+
+  std::vector<int64_t> dims{N, nb_columns};
+  Tensor* Z = ctx->Output(1, TensorShape(dims));
 
-  const auto* x_data = X->template Data<T>();
+  const T* x_data = X->template Data<T>();
   int64_t zindex = 0;
 
   for (int64_t n = 0; n < N; n++)  //for each example
   {
     int64_t current_weight_0 = n * stride;
     int64_t maxclass = -1;
-    double maxweight = 0.f;
     std::vector<float> decisions;
     std::vector<float> scores;
     std::vector<float> kernels;
     std::vector<int64_t> votes;
 
-    if (mode_ == SVM_TYPE::SVM_SVC) {
+    if (vector_count_ == 0 && mode_ == SVM_TYPE::SVM_LINEAR) {
+      for (int64_t j = 0; j < class_count_; j++) {  //for each class
+        auto val = kernel_dot(x_data, current_weight_0, coefficients_, feature_count_ * j,
+                              feature_count_, get_kernel_type());
+        val += rho_[0];
+        scores.push_back(val);
+      }
+    } else {
+      if (vector_count_ == 0)
+        return Status(common::ONNXRUNTIME, common::FAIL, "No support vectors.");
+      int evals = 0;
+
       for (int64_t j = 0; j < vector_count_; j++) {
-        float val = kernel_dot(x_data, current_weight_0, support_vectors_, feature_count_ * j, feature_count_, get_kernel_type());
+        auto val = kernel_dot(x_data, current_weight_0, support_vectors_, feature_count_ * j,
+                              feature_count_, get_kernel_type());
         kernels.push_back(val);
       }
-      for (int64_t j = 0; j < class_count_; j++) {
-        votes.push_back(0);
-      }
-      int evals = 0;
-      for (int64_t i = 0; i < class_count_; i++) {        //for each class
-        for (int64_t j = i + 1; j < class_count_; j++) {  //for each class
-          float sum = 0;
+      votes.resize(class_count_, 0);
+      for (int64_t i = 0; i < class_count_; i++) {        // for each class
+        for (int64_t j = i + 1; j < class_count_; j++) {  // for each class
+          double sum = 0;
           int64_t start_index_i = starting_vector_[i];  // *feature_count_;
           int64_t start_index_j = starting_vector_[j];  // *feature_count_;
 
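The heart of the SVC path above and below is one-vs-one voting: each class pair (i, j) produces a decision value, a positive value votes for i, and a non-positive one votes for j. A condensed sketch of that tally (editor's illustration; ovo_winner is a hypothetical name, the commit inlines this logic):

    #include <algorithm>
    #include <cstdint>
    #include <iterator>
    #include <vector>

    // Given the class_count * (class_count - 1) / 2 pairwise decision values
    // in the order (0,1), (0,2), ..., (1,2), ..., return the class with the
    // most votes.
    int64_t ovo_winner(const std::vector<double>& decisions, int64_t class_count) {
      std::vector<int64_t> votes(class_count, 0);
      int64_t k = 0;
      for (int64_t i = 0; i < class_count; ++i)
        for (int64_t j = i + 1; j < class_count; ++j, ++k)
          ++votes[decisions[k] > 0 ? i : j];
      return std::distance(votes.begin(), std::max_element(votes.begin(), votes.end()));
    }

The same std::max_element/std::distance pattern replaces the hand-rolled argmax loops later in this function.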
@@ -125,120 +162,71 @@ Status SVMClassifier<T>::Compute(OpKernelContext* ctx) const {

           int64_t pos1 = (vector_count_) * (j - 1);
           int64_t pos2 = (vector_count_) * (i);
-          for (int64_t m = 0; m < class_i_support_count; m++) {
-            float val1 = coefficients_[pos1 + start_index_i + m];
-            float val2 = kernels[start_index_i + m];
-            sum += val1 * val2;
-          }
-          for (int64_t m = 0; m < class_j_support_count; m++) {
-            float val1 = coefficients_[pos2 + start_index_j + m];
-            float val2 = kernels[start_index_j + m];
-            sum += val1 * val2;
-          }
+          const float* val1 = &(coefficients_[pos1 + start_index_i]);
+          const float* val2 = &(kernels[start_index_i]);
+          for (int64_t m = 0; m < class_i_support_count; ++m, ++val1, ++val2)
+            sum += *val1 * *val2;
+
+          val1 = &(coefficients_[pos2 + start_index_j]);
+          val2 = &(kernels[start_index_j]);
+          for (int64_t m = 0; m < class_j_support_count; ++m, ++val1, ++val2)
+            sum += *val1 * *val2;
+
           sum += rho_[evals];
-          scores.push_back(sum);
-          if (sum > 0) {
-            votes[i]++;
-          } else {
-            votes[j]++;
-          }
-          evals++;  //index into rho
+          scores.push_back((float)sum);
+          ++(votes[sum > 0 ? i : j]);
+          ++evals;  //index into rho
         }
       }
-    } else if (mode_ == SVM_TYPE::SVM_LINEAR) {  //liblinear
-      for (int64_t j = 0; j < class_count_; j++) {  //for each class
-        float val = kernel_dot(x_data, current_weight_0, coefficients_, feature_count_ * j, feature_count_, get_kernel_type());
-        val += rho_[0];
-        scores.push_back(val);
-      }
     }
 
     if (proba_.size() > 0 && mode_ == SVM_TYPE::SVM_SVC) {
       //compute probabilities from the scores
-      std::vector<float> estimates;
-      std::vector<float> probsp2;
       int64_t num = class_count_ * class_count_;
-      for (int64_t m = 0; m < num; m++) {
-        probsp2.push_back(0.f);  //min prob
-      }
-      for (int64_t m = 0; m < class_count_; m++) {
-        estimates.push_back(0.f);  //min prob
-      }
+      std::vector<float> probsp2(num, 0.f);
+      std::vector<float> estimates(class_count_, 0.f);
       int64_t index = 0;
-      for (int64_t i = 0; i < class_count_; i++) {
-        for (int64_t j = i + 1; j < class_count_; j++) {
+      for (int64_t i = 0; i < class_count_; ++i) {
+        int64_t p1 = i * class_count_ + i + 1;
+        int64_t p2 = (i + 1) * class_count_ + i;
+        for (int64_t j = i + 1; j < class_count_; ++j, ++index) {
           float val1 = sigmoid_probability(scores[index], proba_[index], probb_[index]);
           float val2 = std::max(val1, 1.0e-7f);
-          probsp2[i * class_count_ + j] = std::min(val2, 1 - 1.0e-7f);
-          probsp2[j * class_count_ + i] = 1 - probsp2[i * class_count_ + j];
-          index++;
+          val2 = std::min(val2, 1 - 1.0e-7f);
+          probsp2[p1] = val2;
+          probsp2[p2] = 1 - val2;
+          ++p1;
+          p2 += class_count_;
         }
       }
       multiclass_probability(class_count_, probsp2, estimates);
-      //copy probabilities back into scores
+      // copy probabilities back into scores
      scores.resize(estimates.size());
-      for (int64_t k = 0; k < static_cast<int64_t>(estimates.size()); k++) {
-        scores[k] = estimates[k];
-      }
+      std::copy(estimates.begin(), estimates.end(), scores.begin());
     }
-    int64_t maxvotes = 0;
+
+    float max_weight = 0;
     if (votes.size() > 0) {
-      for (int64_t k = 0; k < static_cast<int64_t>(votes.size()); k++) {
-        if (votes[k] > maxvotes) {
-          maxvotes = votes[k];
-          maxclass = k;
-        }
-      }
+      auto it_maxvotes = std::max_element(votes.begin(), votes.end());
+      maxclass = std::distance(votes.begin(), it_maxvotes);
    } else {
-      for (int64_t k = 0; k < static_cast<int64_t>(scores.size()); k++) {
-        if (scores[k] > maxweight) {
-          maxclass = k;
-          maxweight = scores[k];
-        }
-      }
+      auto it_max_weight = std::max_element(scores.begin(), scores.end());
+      maxclass = std::distance(scores.begin(), it_max_weight);
+      max_weight = *it_max_weight;
    }
-    //write top class
+
+    // write top class
+    // onnx specs expects one column per class.
     int write_additional_scores = -1;
-    if (rho_.size() == 1)  //binary
-    {
+    if (rho_.size() == 1) {
       if (using_strings_) {
-        if (classlabels_strings_.size() == 2 && weights_are_all_positive_ && maxweight >= 0.5 && proba_.size() == 0) {
-          Y->template MutableData<std::string>()[n] = classlabels_strings_[1];  //positive label
-          write_additional_scores = 0;
-        } else if (classlabels_strings_.size() == 2 && maxweight > 0 && !weights_are_all_positive_ && proba_.size() == 0) {
-          Y->template MutableData<std::string>()[n] = classlabels_strings_[1];  //positive label
-          write_additional_scores = 0;
-        } else if (classlabels_strings_.size() == 2 && proba_.size() > 0) {  //this case all classes are in their rightful spot
-          Y->template MutableData<std::string>()[n] = classlabels_strings_[maxclass];  //whichever label
-          write_additional_scores = -1;
-        } else if (classlabels_strings_.size() == 2) {
-          Y->template MutableData<std::string>()[n] = classlabels_strings_[0];  //negative label
-          write_additional_scores = 1;
-        } else if (maxweight > 0) {
-          Y->template MutableData<std::string>()[n] = "1";  //positive label
-        } else {
-          Y->template MutableData<std::string>()[n] = "0";  //negative label
-        }
-      } else  //no strings
-      {
-        if (classlabels_ints_.size() == 2 && weights_are_all_positive_ && maxweight >= 0.5 && proba_.size() == 0) {
-          Y->template MutableData<int64_t>()[n] = classlabels_ints_[1];  //positive label
-          write_additional_scores = 0;
-        } else if (classlabels_ints_.size() == 2 && maxweight > 0 && !weights_are_all_positive_ && proba_.size() == 0) {
-          Y->template MutableData<int64_t>()[n] = classlabels_ints_[0];  //pos label
-          write_additional_scores = 0;
-        } else if (classlabels_ints_.size() == 2 && proba_.size() > 0)  //this case all classes are in their rightful spot
-        {
-          Y->template MutableData<int64_t>()[n] = classlabels_ints_[maxclass];  //whichever label
-          write_additional_scores = -1;
-        } else if (classlabels_ints_.size() == 2) {
-          Y->template MutableData<int64_t>()[n] = classlabels_ints_[0];  //negative label
-          write_additional_scores = 1;
-        } else if (maxweight > 0) {
-          Y->template MutableData<int64_t>()[n] = 1;  //positive label
-        } else {
-          Y->template MutableData<int64_t>()[n] = 0;  //negative label
-        }
+        write_additional_scores = _set_score_svm<std::string>(
+            Y, max_weight, maxclass, n, post_transform_, proba_,
+            weights_are_all_positive_, classlabels_strings_, "1", "0");
+      } else {
+        write_additional_scores = _set_score_svm<int64_t>(
+            Y, max_weight, maxclass, n, post_transform_, proba_,
+            weights_are_all_positive_, classlabels_ints_, 1, 0);
      }
    } else {  //multiclass
      if (using_strings_) {
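The probability branch above couples pairwise Platt-scaled estimates through multiclass_probability. For reference, the conventional libsvm-style form behind sigmoid_probability (an assumption about its definition, stated for context; the actual implementation lives in ml_common.h):

    #include <cmath>

    // Assumed Platt scaling: P(positive | score) = 1 / (1 + exp(a * score + b))
    // with trained parameters a and b; the caller clamps the result to
    // [1e-7, 1 - 1e-7] before filling probsp2, as in the loop above.
    float platt_probability(float score, float a, float b) {
      float v = a * score + b;
      // evaluate the logistic stably for both signs of v
      return v < 0 ? 1.f / (1.f + std::exp(v))
                   : std::exp(-v) / (1.f + std::exp(-v));
    }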
