Skip to content

Commit

Permalink
Merge pull request #1339 from ywcui1990/SDRClassifier
Browse files Browse the repository at this point in the history
Allow SDR classifier to handle multiple categories
  • Loading branch information
scottpurdy authored Jun 7, 2017
2 parents 08900cf + e9a06bf commit 1b6520f
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 69 deletions.
92 changes: 55 additions & 37 deletions src/nupic/algorithms/SDRClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,28 @@ namespace nupic
}

void SDRClassifier::compute(
UInt recordNum, const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, bool category, bool learn, bool infer,
UInt recordNum, const vector<UInt>& patternNZ, const vector<UInt>& bucketIdxList,
const vector<Real64>& actValueList, bool category, bool learn, bool infer,
ClassifierResult* result)
{
// update pattern history
patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end());
recordNumHistory_.push_back(recordNum);
if (patternNZHistory_.size() > maxSteps_)
// ensures that recordNum increases monotonically
UInt lastRecordNum = -1;
if (recordNumHistory_.size() > 0)
{
patternNZHistory_.pop_front();
recordNumHistory_.pop_front();
lastRecordNum = recordNumHistory_[recordNumHistory_.size()-1];
if (recordNum < lastRecordNum)
NTA_THROW << "the record number has to increase monotonically";
}
// update pattern history if this is a new record
if (recordNumHistory_.size() == 0 || recordNum > lastRecordNum)
{
patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end());
recordNumHistory_.push_back(recordNum);
if (patternNZHistory_.size() > maxSteps_)
{
patternNZHistory_.pop_front();
recordNumHistory_.pop_front();
}
}

// if input pattern has greater index than previously seen, update
Expand All @@ -116,38 +127,43 @@ namespace nupic
// if in inference mode, compute likelihood and update return value
if (infer)
{
infer_(patternNZ, bucketIdx, actValue, result);
infer_(patternNZ, actValueList, result);
}

// update weights if in learning mode
if (learn)
{
// if bucket is greater, update maxBucketIdx_ and augment weight
// matrix with zero-padding
if (bucketIdx > maxBucketIdx_)
for(size_t categoryI=0; categoryI < bucketIdxList.size(); categoryI++)
{
maxBucketIdx_ = bucketIdx;
for (const auto& step : steps_)
UInt bucketIdx = bucketIdxList[categoryI];
Real64 actValue = actValueList[categoryI];
// if bucket is greater, update maxBucketIdx_ and augment weight
// matrix with zero-padding
if (bucketIdx > maxBucketIdx_)
{
Matrix& weights = weightMatrix_.at(step);
weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1);
maxBucketIdx_ = bucketIdx;
for (const auto& step : steps_)
{
Matrix& weights = weightMatrix_.at(step);
weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1);
}
}
}

// update rolling averages of bucket values
while (actualValues_.size() <= maxBucketIdx_)
{
actualValues_.push_back(0.0);
actualValuesSet_.push_back(false);
}
if (!actualValuesSet_[bucketIdx] || category)
{
actualValues_[bucketIdx] = actValue;
actualValuesSet_[bucketIdx] = true;
} else {
actualValues_[bucketIdx] =
((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) +
(actValueAlpha_ * actValue);
// update rolling averages of bucket values
while (actualValues_.size() <= maxBucketIdx_)
{
actualValues_.push_back(0.0);
actualValuesSet_.push_back(false);
}
if (!actualValuesSet_[bucketIdx] || category)
{
actualValues_[bucketIdx] = actValue;
actualValuesSet_[bucketIdx] = true;
} else {
actualValues_[bucketIdx] =
((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) +
(actValueAlpha_ * actValue);
}
}

// compute errors and update weights
Expand All @@ -162,7 +178,7 @@ namespace nupic
// update weights
if (binary_search(steps_.begin(), steps_.end(), nSteps))
{
vector<Real64> error = calculateError_(bucketIdx,
vector<Real64> error = calculateError_(bucketIdxList,
learnPatternNZ, nSteps);
Matrix& weights = weightMatrix_.at(nSteps);
for (auto& bit : learnPatternNZ)
Expand All @@ -184,8 +200,8 @@ namespace nupic
return s.str().size();
}

void SDRClassifier::infer_(const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, ClassifierResult* result)
void SDRClassifier::infer_(const vector<UInt>& patternNZ,
const vector<Real64>& actValue, ClassifierResult* result)
{
// add the actual values to the return value. For buckets that haven't
// been seen yet, the actual value doesn't matter since it will have
Expand All @@ -204,7 +220,7 @@ namespace nupic
{
(*actValueVector)[i] = 0;
} else {
(*actValueVector)[i] = actValue;
(*actValueVector)[i] = actValue[0];
}
}
}
Expand All @@ -227,7 +243,7 @@ namespace nupic
}
}

vector<Real64> SDRClassifier::calculateError_(UInt bucketIdx,
vector<Real64> SDRClassifier::calculateError_(const vector<UInt>& bucketIdxList,
const vector<UInt> patternNZ, UInt step)
{
// compute predicted likelihoods
Expand All @@ -244,7 +260,9 @@ namespace nupic

// compute target likelihoods
vector<Real64> targetDistribution (maxBucketIdx_ + 1, 0.0);
targetDistribution[bucketIdx] = 1.0;
Real64 numCategories = (Real64)bucketIdxList.size();
for(size_t i=0; i<bucketIdxList.size(); i++)
targetDistribution[bucketIdxList[i]] = 1.0 / numCategories;

axby(-1.0, likelihoods, 1.0, targetDistribution);
return likelihoods;
Expand Down
12 changes: 6 additions & 6 deletions src/nupic/algorithms/SDRClassifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ namespace nupic
* used when predicting each bucket.
*/
virtual void compute(
UInt recordNum, const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, bool category, bool learn, bool infer,
UInt recordNum, const vector<UInt>& patternNZ, const vector<UInt>& bucketIdxList,
const vector<Real64>& actValueList, bool category, bool learn, bool infer,
ClassifierResult* result);

/**
Expand Down Expand Up @@ -161,12 +161,12 @@ namespace nupic

private:
// Helper function for inference mode
void infer_(const vector<UInt>& patternNZ, UInt bucketIdx,
Real64 actValue, ClassifierResult* result);
void infer_(const vector<UInt>& patternNZ,
const vector<Real64>& actValue, ClassifierResult* result);

// Helper function to compute the error signal in learning mode
vector<Real64> calculateError_(UInt bucketIdx, const vector<UInt>,
UInt step);
vector<Real64> calculateError_(const vector<UInt>& bucketIdxList,
const vector<UInt> patternNZ, UInt step);

// The list of prediction steps to learn and infer.
vector<UInt> steps_;
Expand Down
23 changes: 15 additions & 8 deletions src/nupic/bindings/algorithms.i
Original file line number Diff line number Diff line change
Expand Up @@ -1446,25 +1446,31 @@ void forceRetentionOfImageSensorLiteLibrary(void) {
noneSentinel = 3.14159

if type(classification["actValue"]) in (int, float):
actValue = classification["actValue"]
actValueList = [classification["actValue"]]
bucketIdxList = [classification["bucketIdx"]]
category = False
elif classification["actValue"] is None:
# Use the sentinel value so we know if it gets used in actualValues
# returned.
actValue = noneSentinel
actValueList = [noneSentinel]
# Turn learning off this step.
learn = False
category = False
# This does not get used when learning is disabled anyway.
classification["bucketIdx"] = 0
bucketIdxList = [0]
isNone = True
elif type(classification["actValue"]) is list:
actValueList = classification["actValue"]
bucketIdxList = classification["bucketIdx"]
category = False
else:
actValue = int(classification["bucketIdx"])
actValueList = [int(classification["bucketIdx"])]
bucketIdxList = [classification["bucketIdx"]]
category = True

result = self.convertedCompute(
recordNum, patternNZ, int(classification["bucketIdx"]),
actValue, category, learn, infer)
recordNum, patternNZ, bucketIdxList,
actValueList, category, learn, infer)

if isNone:
for i, v in enumerate(result["actualValues"]):
Expand Down Expand Up @@ -1549,11 +1555,12 @@ void forceRetentionOfImageSensorLiteLibrary(void) {
}

PyObject* convertedCompute(UInt recordNum, const vector<UInt>& patternNZ,
UInt bucketIdx, Real64 actValue, bool category,
const vector<UInt>& bucketIdxList,
const vector<Real64>& actValueList, bool category,
bool learn, bool infer)
{
ClassifierResult result;
self->compute(recordNum, patternNZ, bucketIdx, actValue, category,
self->compute(recordNum, patternNZ, bucketIdxList, actValueList, category,
learn, infer, &result);
PyObject* d = PyDict_New();
for (map<Int, vector<Real64>*>::const_iterator it = result.begin();
Expand Down
Loading

0 comments on commit 1b6520f

Please sign in to comment.