diff --git a/src/nupic/algorithms/SDRClassifier.cpp b/src/nupic/algorithms/SDRClassifier.cpp index c037158914..94893e17df 100644 --- a/src/nupic/algorithms/SDRClassifier.cpp +++ b/src/nupic/algorithms/SDRClassifier.cpp @@ -84,17 +84,28 @@ namespace nupic } void SDRClassifier::compute( - UInt recordNum, const vector& patternNZ, UInt bucketIdx, - Real64 actValue, bool category, bool learn, bool infer, + UInt recordNum, const vector& patternNZ, const vector& bucketIdxList, + const vector& actValueList, bool category, bool learn, bool infer, ClassifierResult* result) { - // update pattern history - patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end()); - recordNumHistory_.push_back(recordNum); - if (patternNZHistory_.size() > maxSteps_) + // ensures that recordNum increases monotonically + UInt lastRecordNum = -1; + if (recordNumHistory_.size() > 0) { - patternNZHistory_.pop_front(); - recordNumHistory_.pop_front(); + lastRecordNum = recordNumHistory_[recordNumHistory_.size()-1]; + if (recordNum < lastRecordNum) + NTA_THROW << "the record number has to increase monotonically"; + } + // update pattern history if this is a new record + if (recordNumHistory_.size() == 0 || recordNum > lastRecordNum) + { + patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end()); + recordNumHistory_.push_back(recordNum); + if (patternNZHistory_.size() > maxSteps_) + { + patternNZHistory_.pop_front(); + recordNumHistory_.pop_front(); + } } // if input pattern has greater index than previously seen, update @@ -116,38 +127,43 @@ namespace nupic // if in inference mode, compute likelihood and update return value if (infer) { - infer_(patternNZ, bucketIdx, actValue, result); + infer_(patternNZ, actValueList, result); } // update weights if in learning mode if (learn) { - // if bucket is greater, update maxBucketIdx_ and augment weight - // matrix with zero-padding - if (bucketIdx > maxBucketIdx_) + for(size_t categoryI=0; categoryI < bucketIdxList.size(); categoryI++) { - maxBucketIdx_ = bucketIdx; - for (const auto& step : steps_) + UInt bucketIdx = bucketIdxList[categoryI]; + Real64 actValue = actValueList[categoryI]; + // if bucket is greater, update maxBucketIdx_ and augment weight + // matrix with zero-padding + if (bucketIdx > maxBucketIdx_) { - Matrix& weights = weightMatrix_.at(step); - weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1); + maxBucketIdx_ = bucketIdx; + for (const auto& step : steps_) + { + Matrix& weights = weightMatrix_.at(step); + weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1); + } } - } - // update rolling averages of bucket values - while (actualValues_.size() <= maxBucketIdx_) - { - actualValues_.push_back(0.0); - actualValuesSet_.push_back(false); - } - if (!actualValuesSet_[bucketIdx] || category) - { - actualValues_[bucketIdx] = actValue; - actualValuesSet_[bucketIdx] = true; - } else { - actualValues_[bucketIdx] = - ((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) + - (actValueAlpha_ * actValue); + // update rolling averages of bucket values + while (actualValues_.size() <= maxBucketIdx_) + { + actualValues_.push_back(0.0); + actualValuesSet_.push_back(false); + } + if (!actualValuesSet_[bucketIdx] || category) + { + actualValues_[bucketIdx] = actValue; + actualValuesSet_[bucketIdx] = true; + } else { + actualValues_[bucketIdx] = + ((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) + + (actValueAlpha_ * actValue); + } } // compute errors and update weights @@ -162,7 +178,7 @@ namespace nupic // update weights if (binary_search(steps_.begin(), steps_.end(), nSteps)) { - vector error = calculateError_(bucketIdx, + vector error = calculateError_(bucketIdxList, learnPatternNZ, nSteps); Matrix& weights = weightMatrix_.at(nSteps); for (auto& bit : learnPatternNZ) @@ -184,8 +200,8 @@ namespace nupic return s.str().size(); } - void SDRClassifier::infer_(const vector& patternNZ, UInt bucketIdx, - Real64 actValue, ClassifierResult* result) + void SDRClassifier::infer_(const vector& patternNZ, + const vector& actValue, ClassifierResult* result) { // add the actual values to the return value. For buckets that haven't // been seen yet, the actual value doesn't matter since it will have @@ -204,7 +220,7 @@ namespace nupic { (*actValueVector)[i] = 0; } else { - (*actValueVector)[i] = actValue; + (*actValueVector)[i] = actValue[0]; } } } @@ -227,7 +243,7 @@ namespace nupic } } - vector SDRClassifier::calculateError_(UInt bucketIdx, + vector SDRClassifier::calculateError_(const vector& bucketIdxList, const vector patternNZ, UInt step) { // compute predicted likelihoods @@ -244,7 +260,9 @@ namespace nupic // compute target likelihoods vector targetDistribution (maxBucketIdx_ + 1, 0.0); - targetDistribution[bucketIdx] = 1.0; + Real64 numCategories = (Real64)bucketIdxList.size(); + for(size_t i=0; i& patternNZ, UInt bucketIdx, - Real64 actValue, bool category, bool learn, bool infer, + UInt recordNum, const vector& patternNZ, const vector& bucketIdxList, + const vector& actValueList, bool category, bool learn, bool infer, ClassifierResult* result); /** @@ -161,12 +161,12 @@ namespace nupic private: // Helper function for inference mode - void infer_(const vector& patternNZ, UInt bucketIdx, - Real64 actValue, ClassifierResult* result); + void infer_(const vector& patternNZ, + const vector& actValue, ClassifierResult* result); // Helper function to compute the error signal in learning mode - vector calculateError_(UInt bucketIdx, const vector, - UInt step); + vector calculateError_(const vector& bucketIdxList, + const vector patternNZ, UInt step); // The list of prediction steps to learn and infer. vector steps_; diff --git a/src/nupic/bindings/algorithms.i b/src/nupic/bindings/algorithms.i index f09f7c307a..df47f76511 100644 --- a/src/nupic/bindings/algorithms.i +++ b/src/nupic/bindings/algorithms.i @@ -1446,25 +1446,31 @@ void forceRetentionOfImageSensorLiteLibrary(void) { noneSentinel = 3.14159 if type(classification["actValue"]) in (int, float): - actValue = classification["actValue"] + actValueList = [classification["actValue"]] + bucketIdxList = [classification["bucketIdx"]] category = False elif classification["actValue"] is None: # Use the sentinel value so we know if it gets used in actualValues # returned. - actValue = noneSentinel + actValueList = [noneSentinel] # Turn learning off this step. learn = False category = False # This does not get used when learning is disabled anyway. - classification["bucketIdx"] = 0 + bucketIdxList = [0] isNone = True + elif type(classification["actValue"]) is list: + actValueList = classification["actValue"] + bucketIdxList = classification["bucketIdx"] + category = False else: - actValue = int(classification["bucketIdx"]) + actValueList = [int(classification["bucketIdx"])] + bucketIdxList = [classification["bucketIdx"]] category = True result = self.convertedCompute( - recordNum, patternNZ, int(classification["bucketIdx"]), - actValue, category, learn, infer) + recordNum, patternNZ, bucketIdxList, + actValueList, category, learn, infer) if isNone: for i, v in enumerate(result["actualValues"]): @@ -1549,11 +1555,12 @@ void forceRetentionOfImageSensorLiteLibrary(void) { } PyObject* convertedCompute(UInt recordNum, const vector& patternNZ, - UInt bucketIdx, Real64 actValue, bool category, + const vector& bucketIdxList, + const vector& actValueList, bool category, bool learn, bool infer) { ClassifierResult result; - self->compute(recordNum, patternNZ, bucketIdx, actValue, category, + self->compute(recordNum, patternNZ, bucketIdxList, actValueList, category, learn, infer, &result); PyObject* d = PyDict_New(); for (map*>::const_iterator it = result.begin(); diff --git a/src/test/unit/algorithms/SDRClassifierTest.cpp b/src/test/unit/algorithms/SDRClassifierTest.cpp index 9465005276..134cd9e3f2 100644 --- a/src/test/unit/algorithms/SDRClassifierTest.cpp +++ b/src/test/unit/algorithms/SDRClassifierTest.cpp @@ -54,16 +54,24 @@ namespace input1.push_back(1); input1.push_back(5); input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); ClassifierResult result1; - c.compute(0, input1, 4, 34.7, false, true, true, &result1); + c.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &result1); // Create a vector of input bit indices vector input2; input2.push_back(1); input2.push_back(5); input2.push_back(9); + vector bucketIdxList2; + bucketIdxList2.push_back(4); + vector actValueList2; + actValueList2.push_back(34.7); ClassifierResult result2; - c.compute(1, input2, 4, 34.7, false, true, true, &result2); + c.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, &result2); { bool foundMinus1 = false; @@ -119,11 +127,15 @@ namespace input1.push_back(1); input1.push_back(5); input1.push_back(9); + vector bucketIdxList; + bucketIdxList.push_back(4); + vector actValueList; + actValueList.push_back(34.7); ClassifierResult result1; for (UInt i = 0; i < 10; ++i) { ClassifierResult result1; - c.compute(i, input1, 4, 34.7, false, true, true, &result1); + c.compute(i, input1, bucketIdxList, actValueList, false, true, true, &result1); } { @@ -151,38 +163,60 @@ namespace steps.push_back(1); SDRClassifier c = SDRClassifier(steps, 1.0, 0.1, 0); - // Create a input vectors + // Create a input vector vector input1; input1.push_back(1); input1.push_back(5); input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); - // Create a input vectors + // Create a input vector vector input2; input2.push_back(0); input2.push_back(6); input2.push_back(9); input2.push_back(11); + vector bucketIdxList2; + bucketIdxList2.push_back(5); + vector actValueList2; + actValueList2.push_back(41.7); - // Create a input vectors + // Create input vectors vector input3; input3.push_back(6); input3.push_back(9); + vector bucketIdxList3; + bucketIdxList3.push_back(5); + vector actValueList3; + actValueList3.push_back(44.9); + + vector bucketIdxList4; + bucketIdxList4.push_back(4); + vector actValueList4; + actValueList4.push_back(42.9); + + vector bucketIdxList5; + bucketIdxList5.push_back(4); + vector actValueList5; + actValueList5.push_back(34.7); ClassifierResult result1; - c.compute(0, input1, 4, 34.7, false, true, true, &result1); + c.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &result1); ClassifierResult result2; - c.compute(1, input2, 5, 41.7, false, true, true, &result2); + c.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, &result2); ClassifierResult result3; - c.compute(2, input3, 5, 44.9, false, true, true, &result3); + c.compute(2, input3, bucketIdxList3, actValueList3, false, true, true, &result3); ClassifierResult result4; - c.compute(3, input1, 4, 42.9, false, true, true, &result4); + c.compute(3, input1, bucketIdxList4, actValueList4, false, true, true, &result4); ClassifierResult result5; - c.compute(4, input1, 4, 34.7, false, true, true, &result5); + c.compute(4, input1, bucketIdxList5, actValueList5, false, true, true, &result5); { bool foundMinus1 = false; @@ -232,6 +266,78 @@ namespace } + TEST(SDRClassifierTest, MultipleCategory) + { + // Test multiple category classification with single compute calls + // This test is ported from the Python unit test + vector steps; + steps.push_back(0); + SDRClassifier c = SDRClassifier(steps, 1.0, 0.1, 0); + + // Create a input vectors + vector input1; + input1.push_back(1); + input1.push_back(3); + input1.push_back(5); + vector bucketIdxList1; + bucketIdxList1.push_back(0); + bucketIdxList1.push_back(1); + vector actValueList1; + actValueList1.push_back(0); + actValueList1.push_back(1); + + // Create a input vectors + vector input2; + input2.push_back(2); + input2.push_back(4); + input2.push_back(6); + vector bucketIdxList2; + bucketIdxList2.push_back(2); + bucketIdxList2.push_back(3); + vector actValueList2; + actValueList2.push_back(2); + actValueList2.push_back(3); + + int recordNum=0; + for(int i=0; i<1000; i++) + { + ClassifierResult result1; + ClassifierResult result2; + c.compute(recordNum, input1, bucketIdxList1, actValueList1, false, true, true, &result1); + recordNum += 1; + c.compute(recordNum, input2, bucketIdxList2, actValueList2, false, true, true, &result2); + recordNum += 1; + } + + ClassifierResult result1; + ClassifierResult result2; + c.compute(recordNum, input1, bucketIdxList1, actValueList1, false, true, true, &result1); + recordNum += 1; + c.compute(recordNum, input2, bucketIdxList2, actValueList2, false, true, true, &result2); + recordNum += 1; + + for (auto it = result1.begin(); it != result1.end(); ++it) + { + if (it->first == 0) { + ASSERT_LT(fabs(it->second->at(0) - 0.5), 0.1) + << "Incorrect prediction for bucket 0 (expected=0.5)"; + ASSERT_LT(fabs(it->second->at(1) - 0.5), 0.1) + << "Incorrect prediction for bucket 1 (expected=0.5)"; + } + } + + for (auto it = result2.begin(); it != result2.end(); ++it) + { + if (it->first == 0) { + ASSERT_LT(fabs(it->second->at(2) - 0.5), 0.1) + << "Incorrect prediction for bucket 2 (expected=0.5)"; + ASSERT_LT(fabs(it->second->at(3) - 0.5), 0.1) + << "Incorrect prediction for bucket 3 (expected=0.5)"; + } + } + + } + TEST(SDRClassifierTest, SaveLoad) { vector steps; @@ -244,8 +350,12 @@ namespace input1.push_back(1); input1.push_back(5); input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); ClassifierResult result; - c1.compute(0, input1, 4, 34.7, false, true, true, &result); + c1.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &result); { stringstream ss; @@ -256,8 +366,8 @@ namespace ASSERT_TRUE(c1 == c2); ClassifierResult result1, result2; - c1.compute(1, input1, 4, 35.7, false, true, true, &result1); - c2.compute(1, input1, 4, 35.7, false, true, true, &result2); + c1.compute(1, input1, bucketIdxList1, actValueList1, false, true, true, &result1); + c2.compute(1, input1, bucketIdxList1, actValueList1, false, true, true, &result2); ASSERT_TRUE(result1 == result2); } @@ -275,16 +385,24 @@ namespace input1.push_back(1); input1.push_back(5); input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); ClassifierResult trainResult1; - c1.compute(0, input1, 4, 34.7, false, true, true, &trainResult1); + c1.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &trainResult1); // Create a vector of input bit indices vector input2; input2.push_back(0); input2.push_back(8); input2.push_back(9); + vector bucketIdxList2; + bucketIdxList2.push_back(2); + vector actValueList2; + actValueList2.push_back(24.7); ClassifierResult trainResult2; - c1.compute(1, input2, 2, 24.7, false, true, true, &trainResult2); + c1.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, &trainResult2); { stringstream ss; @@ -295,8 +413,8 @@ namespace ASSERT_TRUE(c1 == c2); ClassifierResult result1, result2; - c1.compute(2, input1, 4, 35.7, false, true, true, &result1); - c2.compute(2, input1, 4, 35.7, false, true, true, &result2); + c1.compute(2, input1, bucketIdxList1, actValueList1, false, true, true, &result1); + c2.compute(2, input1, bucketIdxList1, actValueList1, false, true, true, &result2); ASSERT_TRUE(result1 == result2); }