Skip to content

Commit

Permalink
Emit predicted category using an appropriate JSON type. (elastic#877)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Dec 6, 2019
1 parent dcf7370 commit 5dc608d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[
(See {ml-pull}818[#818].)
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
* Reduce runtime of classification and regression. (See {ml-pull}863[#863].)
* Emit `prediction_field_name` in ml results using the type provided as
`prediction_field_type` parameter. (See {ml-pull}877[#877].)

=== Bug Fixes
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
Expand Down
11 changes: 11 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ namespace api {
class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
: public CDataFrameTrainBoostedTreeRunner {
public:
enum EPredictionFieldType {
E_PredictionFieldTypeString,
E_PredictionFieldTypeInt,
E_PredictionFieldTypeBool
};

static const CDataFrameAnalysisConfigReader& parameterReader();

//! This is not intended to be called directly: use CDataFrameTrainBoostedTreeClassifierRunnerFactory.
Expand All @@ -44,6 +50,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
const TRowRef& row,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! Write the predicted category value as string, int or bool.
void writePredictedCategoryValue(const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! \return A serialisable definition of the trained classification model.
TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames,
Expand All @@ -55,6 +65,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final

private:
std::size_t m_NumTopClasses;
EPredictionFieldType m_PredictionFieldType;
};

//! \brief Makes a core::CDataFrame boosted tree classification runner.
Expand Down
41 changes: 39 additions & 2 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using TSizeVec = std::vector<std::size_t>;

// Configuration
const std::string NUM_TOP_CLASSES{"num_top_classes"};
const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"};
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};

// Output
Expand All @@ -45,8 +46,16 @@ const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"};
const CDataFrameAnalysisConfigReader&
CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
const std::string typeString{"string"};
const std::string typeInt{"int"};
const std::string typeBool{"bool"};
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(PREDICTION_FIELD_TYPE,
CDataFrameAnalysisConfigReader::E_OptionalParameter,
{{typeString, int{E_PredictionFieldTypeString}},
{typeInt, int{E_PredictionFieldTypeInt}},
{typeBool, int{E_PredictionFieldTypeBool}}});
theReader.addParameter(BALANCED_CLASS_LOSS,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
return theReader;
Expand All @@ -60,6 +69,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
: CDataFrameTrainBoostedTreeRunner{spec, parameters} {

m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
m_PredictionFieldType =
parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString);
this->boostedTreeFactory().balanceClassTrainingLoss(
parameters[BALANCED_CLASS_LOSS].fallback(true));

Expand Down Expand Up @@ -119,7 +130,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(

writer.StartObject();
writer.Key(this->predictionFieldName());
writer.String(categoryValues[predictedCategoryId]);
writePredictedCategoryValue(categoryValues[predictedCategoryId], writer);
writer.Key(PREDICTION_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[predictedCategoryId]);
writer.Key(IS_TRAINING_FIELD_NAME);
Expand All @@ -135,7 +146,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writer.String(categoryValues[categoryIds[i]]);
writePredictedCategoryValue(categoryValues[categoryIds[i]], writer);
writer.Key(CLASS_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[i]);
writer.EndObject();
Expand All @@ -158,6 +169,32 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
columnHoldingPrediction, row, writer);
}

void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const {

double doubleValue;
switch (m_PredictionFieldType) {
case E_PredictionFieldTypeString:
writer.String(categoryValue);
break;
case E_PredictionFieldTypeInt:
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Int64(static_cast<std::int64_t>(doubleValue));
} else {
writer.String(categoryValue);
}
break;
case E_PredictionFieldTypeBool:
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Bool(doubleValue != 0.0);
} else {
writer.String(categoryValue);
}
break;
}
}

CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame,
std::size_t dependentVariableColumn) const {
Expand Down
65 changes: 50 additions & 15 deletions lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes].");
}

BOOST_AUTO_TEST_CASE(testWriteOneRow) {
template<typename T>
void testWriteOneRow(const std::string& dependentVariableField,
const std::string& predictionFieldType,
T (rapidjson::Value::*extract)() const,
const std::vector<T>& expectedPredictions) {
// Prepare input data frame
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
const TStrVec categoricalColumns{"x1", "x2", "x5"};
const std::string predictionField = dependentVariableField + "_prediction";
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField};
const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"};
const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"},
{"a", "b", "2.0", "2.0", "cat", "-0.5"},
{"a", "b", "5.0", "5.0", "dog", "-0.1"},
{"c", "d", "5.0", "5.0", "dog", "1.0"},
{"e", "f", "5.0", "5.0", "dog", "1.5"}};
{"a", "b", "1.0", "1.0", "cat", "-0.5"},
{"a", "b", "5.0", "0.0", "dog", "-0.1"},
{"c", "d", "5.0", "0.0", "dog", "1.0"},
{"e", "f", "5.0", "0.0", "dog", "1.5"}};
std::unique_ptr<core::CDataFrame> frame =
core::makeMainStorageDataFrame(columnNames.size()).first;
frame->columnNames(columnNames);
Expand All @@ -67,10 +72,21 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

// Create classification analysis runner object
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
"classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0,
categoricalColumns)};
"classification", dependentVariableField, rows.size(),
columnNames.size(), 13000000, 0, 0, categoricalColumns)};
rapidjson::Document jsonParameters;
jsonParameters.Parse("{\"dependent_variable\": \"x5\"}");
if (predictionFieldType.empty()) {
jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}");
} else {
jsonParameters.Parse("{"
" \"dependent_variable\": \"" +
dependentVariableField +
"\","
" \"prediction_field_type\": \"" +
predictionFieldType +
"\""
"}");
}
const auto parameters{
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
Expand All @@ -83,10 +99,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
const auto columnHoldingDependentVariable{
std::find(columnNames.begin(), columnNames.end(), "x5") -
std::find(columnNames.begin(), columnNames.end(), dependentVariableField) -
columnNames.begin()};
const auto columnHoldingPrediction{
std::find(columnNames.begin(), columnNames.end(), "x5_prediction") -
std::find(columnNames.begin(), columnNames.end(), predictionField) -
columnNames.begin()};
for (auto row = beginRows; row != endRows; ++row) {
runner.writeOneRow(*frame, columnHoldingDependentVariable,
Expand All @@ -95,17 +111,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
});
}
// Verify results
const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
rapidjson::Document arrayDoc;
arrayDoc.Parse<rapidjson::kParseDefaultFlags>(output.str().c_str());
BOOST_TEST_REQUIRE(arrayDoc.IsArray());
BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size());
BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size());
for (std::size_t i = 0; i < arrayDoc.Size(); ++i) {
BOOST_TEST_CONTEXT("Result for row " << i) {
const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)];
BOOST_TEST_REQUIRE(object.IsObject());
BOOST_TEST_REQUIRE(object.HasMember("x5_prediction"));
BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() ==
BOOST_TEST_REQUIRE(object.HasMember(predictionField));
BOOST_TEST_REQUIRE((object[predictionField].*extract)() ==
expectedPredictions[i]);
BOOST_TEST_REQUIRE(object.HasMember("prediction_probability"));
BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5);
Expand All @@ -115,4 +131,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
}
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsInt) {
testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsBool) {
testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool,
{true, true, true, false, false});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsString) {
testWriteOneRow("x5", "string", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsMissing) {
testWriteOneRow("x5", "", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 5dc608d

Please sign in to comment.