Skip to content

Commit

Permalink
Merge pull request #510 from HappyTomatoo/feat-binary-class
Browse files Browse the repository at this point in the history
Feat binary class in tree_ensemble_classifier operator
  • Loading branch information
raphaelDkhn authored Jan 6, 2024
2 parents 8c387a9 + 1dbf5a6 commit 538eaa0
Show file tree
Hide file tree
Showing 2 changed files with 484 additions and 1 deletion.
151 changes: 150 additions & 1 deletion src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ impl TreeEnsembleClassifierImpl<
fn predict(ref self: TreeEnsembleClassifier<T>, X: Tensor<T>) -> (Span<usize>, MutMatrix::<T>) {
let leaves_index = self.ensemble.leave_index_tree(X);
let n_classes = self.classlabels.len();
assert(n_classes > 1, 'binary class not supported yet');
let mut res: MutMatrix<T> = MutMatrixImpl::new(*leaves_index.shape.at(0), n_classes);

// Set base values
Expand Down Expand Up @@ -402,6 +401,156 @@ impl TreeEnsembleClassifierImpl<
i += 1;
};

// Binary class
let mut binary = false;
let mut i: usize = 0;
let mut class_ids = self.class_ids;
let mut class_id: usize = 0;
// Get first class_id in class_ids
match class_ids.pop_front() {
Option::Some(c_id) => {
let mut class_id = *c_id;
},
Option::None(_) => {
let mut class_id: usize = 0;
}
};
loop {
if i == self.class_ids.len() {
break;
}
match class_ids.pop_front() {
Option::Some(c_id) => {
if *c_id == class_id {
binary = true;
continue;
}else{
binary = false;
break;
}

},
Option::None(_) => { break; }
};

};

// Clone res
if binary{
let mut new_res: MutMatrix<T> = MutMatrixImpl::new(res.rows, res.cols);
let mut i: usize = 0;
loop {
if i == res.rows {
break;
}
// Exchange
let res_ele_1 = match res.get(i, 0) {
Option::Some(res_0) => {
new_res.set(i, 1, res_0);
},
Option::None(_) => {
new_res.set(i, 1, NumberTrait::zero());
},
};
i+=1;
};
match self.post_transform {
POST_TRANSFORM::NONE => {
let mut i: usize = 0;
loop {
if i == res.rows {
break;
}
// Exchange
let res_ele_0 = match new_res.get(i, 1) {
Option::Some(res_1) => {
let value = NumberTrait::sub(NumberTrait::one(), res_1);
new_res.set(i, 0, value);
},
Option::None(_) => {
new_res.set(i, 0, NumberTrait::zero());
},
};
i+=1;
};
},
POST_TRANSFORM::SOFTMAX => {
let mut i: usize = 0;
loop {
if i == res.rows {
break;
}
// Exchange
let res_ele_0 = match new_res.get(i, 1) {
Option::Some(res_1) => {
new_res.set(i, 0, res_1.neg());
},
Option::None(_) => {
new_res.set(i, 0, NumberTrait::zero());
},
};
i+=1;
};
},
POST_TRANSFORM::LOGISTIC => {
let mut i: usize = 0;
loop {
if i == res.rows {
break;
}
// Exchange
let res_ele_0 = match new_res.get(i, 1) {
Option::Some(res_1) => {
new_res.set(i, 0, res_1.neg());
},
Option::None(_) => {
new_res.set(i, 0, NumberTrait::zero());
},
};
i+=1;
};
},
POST_TRANSFORM::SOFTMAXZERO => {
let mut i: usize = 0;
loop {
if i == res.rows {
break;
}
// Exchange
let res_ele_0 = match new_res.get(i, 1) {
Option::Some(res_1) => {
new_res.set(i, 0, res_1.neg());
},
Option::None(_) => {
new_res.set(i, 0, NumberTrait::zero());
},
};
i+=1;
};
},
POST_TRANSFORM::PROBIT => {
let mut i: usize = 0;
loop {
if i == res.rows {
break;
}
// Exchange
let res_ele_0 = match new_res.get(i, 1) {
Option::Some(res_1) => {
let value = NumberTrait::sub(NumberTrait::one(), res_1);
new_res.set(i, 0, value);
},
Option::None(_) => {
new_res.set(i, 0, NumberTrait::zero());
},
};
i+=1;
};
},
};
res = new_res;
}

// Post Transform
let mut new_scores = match self.post_transform {
POST_TRANSFORM::NONE => res, // No action required
Expand Down
Loading

0 comments on commit 538eaa0

Please sign in to comment.