Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OPF prediction models #3665

Merged
merged 11 commits into from
Jun 7, 2017
5 changes: 2 additions & 3 deletions docs/examples/network/complete-network-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,8 @@ def runHotgym(numRecords):
numRecords = min(numRecords, dataSource.getDataRowCount())
network = createNetwork(dataSource)

# Set predicted field index. It needs to be the same index as the data source.
predictedIdx = dataSource.getFieldNames().index("consumption")
network.regions["sensor"].setParameter("predictedFieldIdx", predictedIdx)
# Set predicted field
network.regions["sensor"].setParameter("predictedField", "consumption")

# Enable learning for all regions.
network.regions["SP"].setParameter("learningMode", 1)
Expand Down
4 changes: 1 addition & 3 deletions docs/examples/network/example-set-predicted-field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
predictedIdx = dataSource.getFieldNames().index("consumption")

network.regions["sensor"].setParameter("predictedFieldIdx", predictedIdx)
network.regions["sensor"].setParameter("predictedField", "consumption")
5 changes: 1 addition & 4 deletions examples/network/hierarchy_network_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,7 @@ def createRecordSensor(network, name, dataSource):
sensorRegion.encoder = createEncoder()

# Specify which sub-encoder should be used for "actValueOut"
predictedFieldIdx = dataSource.getFieldNames().index("consumption")
network.regions[name].setParameter("predictedFieldIdx",
numpy.array([predictedFieldIdx],
dtype="uint32"))
network.regions[name].setParameter("predictedField", "consumption")

# Specify the dataSource as a file record stream instance
sensorRegion.dataSource = dataSource
Expand Down
11 changes: 11 additions & 0 deletions src/nupic/frameworks/opf/htm_prediction_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ def setFieldStatistics(self, fieldStats):
encoder.setFieldStats('',fieldStats)


def enableInference(self, inferenceArgs=None):
  """Enable inference and pass the predicted field on to the sensor region.

  The parent implementation is invoked first with the same arguments. If
  ``inferenceArgs`` is supplied and has a "predictedField" entry, that value
  is then set as the "predictedField" parameter of the region returned by
  ``self._getSensorRegion()``.

  :param inferenceArgs: optional mapping of inference arguments; may be None.
  """
  super(HTMPredictionModel, self).enableInference(inferenceArgs)

  # Nothing to forward when no arguments were given or no field was named.
  if inferenceArgs is None or "predictedField" not in inferenceArgs:
    return

  sensorRegion = self._getSensorRegion()
  sensorRegion.setParameter("predictedField", inferenceArgs["predictedField"])


def enableLearning(self):
super(HTMPredictionModel, self).enableLearning()
self.setEncoderLearning(True)
Expand Down Expand Up @@ -1165,6 +1172,10 @@ def __createCLANetwork(self, sensorParams, spEnable, spParams, tmEnable,
clParams))
n.addRegion("Classifier", "py.%s" % str(clRegionName), json.dumps(clParams))

n.link("sensor", "Classifier", "UniformLink", "", srcOutput="actValueOut",
destInput="actValueIn")
n.link("sensor", "Classifier", "UniformLink", "", srcOutput="bucketIdxOut",
destInput="bucketIdxIn")
n.link("sensor", "Classifier", "UniformLink", "", srcOutput="categoryOut",
destInput="categoryIn")

Expand Down
45 changes: 20 additions & 25 deletions src/nupic/regions/record_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def getSpec(cls):
actValueOut=dict(
description="Actual value of the field to predict. The index of the "
"field to predict must be specified via the parameter "
"predictedFieldIdx. If this parameter is not set, then "
"predictedField. If this parameter is not set, then "
"actValueOut won't be populated.",
dataType="Real32",
count=0,
Expand All @@ -121,7 +121,7 @@ def getSpec(cls):
description="Active index of the encoder bucket for the "
"actual value of the field to predict. The index of the "
"field to predict must be specified via the parameter "
"predictedFieldIdx. If this parameter is not set, then "
"predictedField. If this parameter is not set, then "
"actValueOut won't be populated.",
dataType="UInt64",
count=0,
Expand Down Expand Up @@ -206,15 +206,13 @@ def getSpec(cls):
accessMode="ReadWrite",
count=1,
constraints=""),
predictedFieldIdx=dict(
description="Index of the field to be predicted. Needs to be "
"consistent with the data source indexing. "
"Default value is < 0 which means that no particular "
"field is selected. This will result in the outputs "
"actValueOut and bucketIdxOut not being populated.",
dataType="UInt32",
predictedField=dict(
description="The field to be predicted. This will result in the "
"outputs actValueOut and bucketIdxOut not being "
"populated.",
dataType="Byte",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dataType is confusing because it's a string. Should the description say "The field name to be predicted..."?

accessMode="ReadWrite",
count=1,
count=0,
defaultValue=-1,
constraints=""),
topDownMode=dict(
Expand Down Expand Up @@ -246,10 +244,10 @@ def __init__(self, verbosity=0, numCategories=1):
self.numCategories = numCategories
self._iterNum = 0

# Optional index of the field for which we want to populate bucketIdxOut
# and actValueOut. If predictedFieldIdx < 0, then bucketIdxOut and
# Optional field for which we want to populate bucketIdxOut
# and actValueOut. If predictedField is None, then bucketIdxOut and
# actValueOut won't be populated.
self.predictedFieldIdx = -1
self.predictedField = None

# lastRecord is the last record returned. Used for debugging only
self.lastRecord = None
Expand Down Expand Up @@ -380,14 +378,8 @@ def compute(self, inputs, outputs):
self.encoder.encodeIntoArray(data, outputs["dataOut"])

# If there is a field to predict, set bucketIdxOut and actValueOut.
if self.predictedFieldIdx >= 0:
fields = self.dataSource.getFieldNames()
if self.predictedFieldIdx >= len(fields):
raise ValueError("predictedFieldIdx (%s) must be strictly less than "
"the number of fields (%s). Fields: %s."
% (self.predictedFieldIdx, len(fields), fields))
predictedField = fields[self.predictedFieldIdx]
encoders = [e for e in self.encoder.encoders if e[0] == predictedField]
if self.predictedField is not None:
encoders = [e for e in self.encoder.encoders if e[0] == self.predictedField]
if len(encoders) == 0:
raise ValueError("There is no encoder for set for the predicted "
"field: %s" % predictedField)
Expand All @@ -397,9 +389,12 @@ def compute(self, inputs, outputs):
else:
encoder = encoders[0][1]

actualValue = data[predictedField]
actualValue = data[self.predictedField]
outputs["bucketIdxOut"][:] = encoder.getBucketIndices(actualValue)
outputs["actValueOut"][:] = actualValue
if isinstance(actualValue, str):
outputs["actValueOut"][:] = encoder.getBucketIndices(actualValue)
else:
outputs["actValueOut"][:] = actualValue

# Write out the scalar values obtained from the data source.
outputs["sourceOut"][:] = self.encoder.getScalars(data)
Expand Down Expand Up @@ -594,8 +589,8 @@ def setParameter(self, parameterName, index, parameterValue):
"""
if parameterName == 'topDownMode':
self.topDownMode = parameterValue
elif parameterName == 'predictedFieldIdx':
self.predictedFieldIdx = parameterValue
elif parameterName == 'predictedField':
self.predictedField = parameterValue
else:
raise Exception('Unknown parameter: ' + parameterName)

Expand Down
3 changes: 1 addition & 2 deletions tests/unit/nupic/regions/record_sensor_region_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def _createNetwork():
sensorRegion.dataSource = dataSource

# Set the name of the field we want to predict.
predictedIdx = dataSource.getFieldNames().index('consumption')
network.regions['sensor'].setParameter('predictedFieldIdx', predictedIdx)
network.regions['sensor'].setParameter('predictedField', 'consumption')

return network

Expand Down
3 changes: 1 addition & 2 deletions tests/unit/nupic/regions/sdr_classifier_region_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def _createNetwork():
sensorRegion.dataSource = dataSource

# Set the name of the field we want to predict.
predictedIdx = dataSource.getFieldNames().index('consumption')
network.regions['sensor'].setParameter('predictedFieldIdx', predictedIdx)
network.regions['sensor'].setParameter('predictedField', 'consumption')

return network

Expand Down