Skip to content

Commit

Permalink
[Lgbm] support multi classification (#3234)
Browse files Browse the repository at this point in the history
  • Loading branch information
ewan0x79 authored Jun 4, 2024
1 parent ee9fd35 commit c17307b
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,15 @@ public static Pair<Integer, ByteBuffer> inferenceMat(
SWIGTYPE_p_p_void model, int iterations, LgbmNDArray a) {
SWIGTYPE_p_long_long outLength = lightgbmlib.new_int64_tp();
SWIGTYPE_p_double outBuffer = null;
SWIGTYPE_p_int numClasses = lightgbmlib.new_intp();
try {
outBuffer = lightgbmlib.new_doubleArray(2L * a.getRows());
int outFlag =
lightgbmlib.LGBM_BoosterGetNumClasses(
lightgbmlib.voidpp_value(model), numClasses);
checkCall(outFlag);
int classes = lightgbmlib.intp_value(numClasses);

outBuffer = lightgbmlib.new_doubleArray((long) classes * a.getRows());
int result =
lightgbmlib.LGBM_BoosterPredictForMat(
lightgbmlib.voidpp_value(model),
Expand Down Expand Up @@ -130,6 +137,7 @@ public static Pair<Integer, ByteBuffer> inferenceMat(
if (outBuffer != null) {
lightgbmlib.delete_doubleArray(outBuffer);
}
lightgbmlib.delete_intp(numClasses);
}
}

Expand Down

0 comments on commit c17307b

Please sign in to comment.