From 2d085dc066fef059956c92c6a7ad6f3b360bc0a9 Mon Sep 17 00:00:00 2001
From: Przemyslaw Witek
Date: Thu, 5 Dec 2019 12:11:12 +0100
Subject: [PATCH] Rename dependent_variable_type to prediction_field_type as that's how the field is really used in C++ code

---
 ...ataFrameTrainBoostedTreeClassifierRunner.h |  2 +-
 ...taFrameTrainBoostedTreeClassifierRunner.cc | 20 +++++++++----------
 ...ameTrainBoostedTreeClassifierRunnerTest.cc | 12 +++++------
 3 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
index 3f0964af98..ccf0ccccf1 100644
--- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
+++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
@@ -59,7 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
 
 private:
     std::size_t m_NumTopClasses;
-    std::string m_DependentVariableType;
+    std::string m_PredictionFieldType;
 };
 
 //! \brief Makes a core::CDataFrame boosted tree classification runner.
diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
index 3aaab55e32..e60bcf5814 100644
--- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
+++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
@@ -32,10 +32,10 @@ using TSizeVec = std::vector;
 
 // Configuration
 const std::string NUM_TOP_CLASSES{"num_top_classes"};
-const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"};
-const std::string DEPENDENT_VARIABLE_TYPE_STRING{"string"};
-const std::string DEPENDENT_VARIABLE_TYPE_INT{"int"};
-const std::string DEPENDENT_VARIABLE_TYPE_BOOL{"bool"};
+const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"};
+const std::string PREDICTION_FIELD_TYPE_STRING{"string"};
+const std::string PREDICTION_FIELD_TYPE_INT{"int"};
+const std::string PREDICTION_FIELD_TYPE_BOOL{"bool"};
 const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};
 
 // Output
@@ -51,7 +51,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
     static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
         auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
         theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
-        theReader.addParameter(DEPENDENT_VARIABLE_TYPE,
+        theReader.addParameter(PREDICTION_FIELD_TYPE,
                                CDataFrameAnalysisConfigReader::E_OptionalParameter);
         theReader.addParameter(BALANCED_CLASS_LOSS,
                                CDataFrameAnalysisConfigReader::E_OptionalParameter);
@@ -66,8 +66,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
     : CDataFrameTrainBoostedTreeRunner{spec, parameters} {
 
     m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
-    m_DependentVariableType =
-        parameters[DEPENDENT_VARIABLE_TYPE].fallback(DEPENDENT_VARIABLE_TYPE_STRING);
+    m_PredictionFieldType =
+        parameters[PREDICTION_FIELD_TYPE].fallback(PREDICTION_FIELD_TYPE_STRING);
     this->boostedTreeFactory().balanceClassTrainingLoss(
         parameters[BALANCED_CLASS_LOSS].fallback(true));
 
@@ -170,16 +170,16 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
     const std::string& categoryValue,
     core::CRapidJsonConcurrentLineWriter& writer) const {
 
-    if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_INT) {
+    if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_INT) {
         double doubleValue;
         if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
             writer.Int64(static_cast(doubleValue));
             return;
         }
-    } else if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_BOOL) {
+    } else if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_BOOL) {
         double doubleValue;
         if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
-            writer.Bool(static_cast(doubleValue) == 1);
+            writer.Bool(static_cast(doubleValue) == 1.0);
             return;
         }
     }
diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
index aade937d2b..1cb338d097 100644
--- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
+++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
@@ -47,7 +47,7 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
 
 template
 void testWriteOneRow(const std::string& dependentVariableField,
-                     const std::string& dependentVariableType,
+                     const std::string& predictionFieldType,
                      T (rapidjson::Value::*extract)() const,
                      const std::vector& expectedPredictions) {
     // Prepare input data frame
@@ -77,7 +77,7 @@ void testWriteOneRow(const std::string& dependentVariableField,
     rapidjson::Document jsonParameters;
     jsonParameters.Parse("{"
                          " \"dependent_variable\": \"" + dependentVariableField + "\","
-                         " \"dependent_variable_type\": \"" + dependentVariableType + "\""
+                         " \"prediction_field_type\": \"" + predictionFieldType + "\""
                          "}");
     const auto parameters{
         api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
@@ -123,21 +123,21 @@ void testWriteOneRow(const std::string& dependentVariableField,
     }
 }
 
-BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsInt) {
+BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsInt) {
     testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5});
 }
 
-BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsBool) {
+BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsBool) {
     testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool,
                     {true, true, true, false, false});
 }
 
-BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsString) {
+BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsString) {
     testWriteOneRow("x5", "string", &rapidjson::Value::GetString,
                     {"cat", "cat", "cat", "dog", "dog"});
 }
 
-BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableTypeIsMissing) {
+BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsMissing) {
     testWriteOneRow("x5", "", &rapidjson::Value::GetString,
                     {"cat", "cat", "cat", "dog", "dog"});
 }