Skip to content

Commit

Permalink
Rename dependent_variable_type to prediction_field_type as that's how…
Browse files Browse the repository at this point in the history
… the field is really used in C++ code
  • Loading branch information
przemekwitek committed Dec 5, 2019
1 parent 8acd4c3 commit 2d085dc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final

private:
std::size_t m_NumTopClasses;
std::string m_DependentVariableType;
std::string m_PredictionFieldType;
};

//! \brief Makes a core::CDataFrame boosted tree classification runner.
Expand Down
20 changes: 10 additions & 10 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ using TSizeVec = std::vector<std::size_t>;

// Configuration
const std::string NUM_TOP_CLASSES{"num_top_classes"};
const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"};
const std::string DEPENDENT_VARIABLE_TYPE_STRING{"string"};
const std::string DEPENDENT_VARIABLE_TYPE_INT{"int"};
const std::string DEPENDENT_VARIABLE_TYPE_BOOL{"bool"};
const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"};
const std::string PREDICTION_FIELD_TYPE_STRING{"string"};
const std::string PREDICTION_FIELD_TYPE_INT{"int"};
const std::string PREDICTION_FIELD_TYPE_BOOL{"bool"};
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};

// Output
Expand All @@ -51,7 +51,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(DEPENDENT_VARIABLE_TYPE,
theReader.addParameter(PREDICTION_FIELD_TYPE,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(BALANCED_CLASS_LOSS,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
Expand All @@ -66,8 +66,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
: CDataFrameTrainBoostedTreeRunner{spec, parameters} {

m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
m_DependentVariableType =
parameters[DEPENDENT_VARIABLE_TYPE].fallback(DEPENDENT_VARIABLE_TYPE_STRING);
m_PredictionFieldType =
parameters[PREDICTION_FIELD_TYPE].fallback(PREDICTION_FIELD_TYPE_STRING);
this->boostedTreeFactory().balanceClassTrainingLoss(
parameters[BALANCED_CLASS_LOSS].fallback(true));

Expand Down Expand Up @@ -170,16 +170,16 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const {

if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_INT) {
if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_INT) {
double doubleValue;
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Int64(static_cast<std::int64_t>(doubleValue));
return;
}
} else if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_BOOL) {
} else if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_BOOL) {
double doubleValue;
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Bool(static_cast<std::int64_t>(doubleValue) == 1);
writer.Bool(static_cast<std::int64_t>(doubleValue) == 1.0);
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {

template<typename T>
void testWriteOneRow(const std::string& dependentVariableField,
const std::string& dependentVariableType,
const std::string& predictionFieldType,
T (rapidjson::Value::*extract)() const,
const std::vector<T>& expectedPredictions) {
// Prepare input data frame
Expand Down Expand Up @@ -77,7 +77,7 @@ void testWriteOneRow(const std::string& dependentVariableField,
rapidjson::Document jsonParameters;
jsonParameters.Parse("{"
" \"dependent_variable\": \"" + dependentVariableField + "\","
" \"dependent_variable_type\": \"" + dependentVariableType + "\""
" \"prediction_field_type\": \"" + predictionFieldType + "\""
"}");
const auto parameters{
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
Expand Down Expand Up @@ -123,21 +123,21 @@ void testWriteOneRow(const std::string& dependentVariableField,
}
}

BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsInt) {
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsInt) {
testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsBool) {
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsBool) {
testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool,
{true, true, true, false, false});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsString) {
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsString) {
testWriteOneRow("x5", "string", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableTypeIsMissing) {
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsMissing) {
testWriteOneRow("x5", "", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}
Expand Down

0 comments on commit 2d085dc

Please sign in to comment.