-
Notifications
You must be signed in to change notification settings - Fork 279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP SP: add equals operator== #1384
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -639,6 +639,119 @@ void SpatialPooler::stripUnlearnedColumns(UInt activeArray[]) const | |
} | ||
} | ||
|
||
bool SpatialPooler::operator==(const SpatialPooler &sp) const | ||
{ | ||
const UInt numColumns = this->getNumColumns(); | ||
const UInt numInputs = this->getNumInputs(); | ||
|
||
ASSERT_TRUE(this->getNumColumns() == sp.getNumColumns()); | ||
ASSERT_TRUE(this->getNumInputs() == sp.getNumInputs()); | ||
ASSERT_TRUE(this->getPotentialRadius() == | ||
sp.getPotentialRadius()); | ||
ASSERT_TRUE(this->getPotentialPct() == sp.getPotentialPct()); | ||
ASSERT_TRUE(this->getGlobalInhibition() == | ||
sp.getGlobalInhibition()); | ||
ASSERT_TRUE(this->getNumActiveColumnsPerInhArea() == | ||
sp.getNumActiveColumnsPerInhArea()); | ||
ASSERT_TRUE(almost_eq(this->getLocalAreaDensity(), | ||
sp.getLocalAreaDensity())); | ||
ASSERT_TRUE(this->getStimulusThreshold() == | ||
sp.getStimulusThreshold()); | ||
ASSERT_TRUE(this->getDutyCyclePeriod() == sp.getDutyCyclePeriod()); | ||
ASSERT_TRUE(almost_eq(this->getBoostStrength(), sp.getBoostStrength())); | ||
ASSERT_TRUE(this->getIterationNum() == sp.getIterationNum()); | ||
ASSERT_TRUE(this->getIterationLearnNum() == | ||
sp.getIterationLearnNum()); | ||
ASSERT_TRUE(this->getSpVerbosity() == sp.getSpVerbosity()); | ||
ASSERT_TRUE(this->getWrapAround() == sp.getWrapAround()); | ||
ASSERT_TRUE(this->getUpdatePeriod() == sp.getUpdatePeriod()); | ||
ASSERT_TRUE(almost_eq(this->getSynPermTrimThreshold(), | ||
sp.getSynPermTrimThreshold())); | ||
cout << "check: " << this->getSynPermActiveInc() << " " << | ||
sp.getSynPermActiveInc() << endl; | ||
ASSERT_TRUE(almost_eq(this->getSynPermActiveInc(), | ||
sp.getSynPermActiveInc())); | ||
ASSERT_TRUE(almost_eq(this->getSynPermInactiveDec(), | ||
sp.getSynPermInactiveDec())); | ||
ASSERT_TRUE(almost_eq(this->getSynPermBelowStimulusInc(), | ||
sp.getSynPermBelowStimulusInc())); | ||
ASSERT_TRUE(almost_eq(this->getSynPermConnected(), | ||
sp.getSynPermConnected())); | ||
ASSERT_TRUE(almost_eq(this->getMinPctOverlapDutyCycles(), | ||
sp.getMinPctOverlapDutyCycles())); | ||
|
||
|
||
auto boostFactors1 = new Real[numColumns]; | ||
auto boostFactors2 = new Real[numColumns]; | ||
this->getBoostFactors(boostFactors1); | ||
sp.getBoostFactors(boostFactors2); | ||
ASSERT_TRUE(check_vector_eq(boostFactors1, boostFactors2, numColumns)); | ||
delete[] boostFactors1; | ||
delete[] boostFactors2; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like some optimization. Q: keep the |
||
|
||
auto overlapDutyCycles1 = new Real[numColumns]; | ||
auto overlapDutyCycles2 = new Real[numColumns]; | ||
this->getOverlapDutyCycles(overlapDutyCycles1); | ||
sp.getOverlapDutyCycles(overlapDutyCycles2); | ||
ASSERT_TRUE(check_vector_eq(overlapDutyCycles1, overlapDutyCycles2, numColumns)); | ||
delete[] overlapDutyCycles1; | ||
delete[] overlapDutyCycles2; | ||
|
||
auto activeDutyCycles1 = new Real[numColumns]; | ||
auto activeDutyCycles2 = new Real[numColumns]; | ||
this->getActiveDutyCycles(activeDutyCycles1); | ||
sp.getActiveDutyCycles(activeDutyCycles2); | ||
ASSERT_TRUE(check_vector_eq(activeDutyCycles1, activeDutyCycles2, numColumns)); | ||
delete[] activeDutyCycles1; | ||
delete[] activeDutyCycles2; | ||
|
||
auto minOverlapDutyCycles1 = new Real[numColumns]; | ||
auto minOverlapDutyCycles2 = new Real[numColumns]; | ||
this->getMinOverlapDutyCycles(minOverlapDutyCycles1); | ||
sp.getMinOverlapDutyCycles(minOverlapDutyCycles2); | ||
ASSERT_TRUE(check_vector_eq(minOverlapDutyCycles1, minOverlapDutyCycles2, numColumns)); | ||
delete[] minOverlapDutyCycles1; | ||
delete[] minOverlapDutyCycles2; | ||
|
||
for (UInt i = 0; i < numColumns; i++) { | ||
auto potential1 = new UInt[numInputs]; | ||
auto potential2 = new UInt[numInputs]; | ||
this->getPotential(i, potential1); | ||
sp.getPotential(i, potential2); | ||
ASSERT_TRUE(check_vector_eq(potential1, potential2, numInputs)); | ||
delete[] potential1; | ||
delete[] potential2; | ||
} | ||
|
||
for (UInt i = 0; i < numColumns; i++) { | ||
auto perm1 = new Real[numInputs]; | ||
auto perm2 = new Real[numInputs]; | ||
this->getPermanence(i, perm1); | ||
sp.getPermanence(i, perm2); | ||
ASSERT_TRUE(check_vector_eq(perm1, perm2, numInputs)); | ||
delete[] perm1; | ||
delete[] perm2; | ||
} | ||
|
||
for (UInt i = 0; i < numColumns; i++) { | ||
auto con1 = new UInt[numInputs]; | ||
auto con2 = new UInt[numInputs]; | ||
this->getConnectedSynapses(i, con1); | ||
sp.getConnectedSynapses(i, con2); | ||
ASSERT_TRUE(check_vector_eq(con1, con2, numInputs)); | ||
delete[] con1; | ||
delete[] con2; | ||
} | ||
auto conCounts1 = new UInt[numColumns]; | ||
auto conCounts2 = new UInt[numColumns]; | ||
this->getConnectedCounts(conCounts1); | ||
sp.getConnectedCounts(conCounts2); | ||
ASSERT_TRUE(check_vector_eq(conCounts1, conCounts2, numColumns)); | ||
delete[] conCounts1; | ||
delete[] conCounts2; | ||
|
||
return true; | ||
} | ||
|
||
void SpatialPooler::toDense_(vector<UInt>& sparse, | ||
UInt dense[], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,118 +126,6 @@ namespace { | |
return true; | ||
} | ||
|
||
void check_spatial_eq(SpatialPooler sp1, SpatialPooler sp2) | ||
{ | ||
UInt numColumns = sp1.getNumColumns(); | ||
UInt numInputs = sp2.getNumInputs(); | ||
|
||
ASSERT_TRUE(sp1.getNumColumns() == sp2.getNumColumns()); | ||
ASSERT_TRUE(sp1.getNumInputs() == sp2.getNumInputs()); | ||
ASSERT_TRUE(sp1.getPotentialRadius() == | ||
sp2.getPotentialRadius()); | ||
ASSERT_TRUE(sp1.getPotentialPct() == sp2.getPotentialPct()); | ||
ASSERT_TRUE(sp1.getGlobalInhibition() == | ||
sp2.getGlobalInhibition()); | ||
ASSERT_TRUE(sp1.getNumActiveColumnsPerInhArea() == | ||
sp2.getNumActiveColumnsPerInhArea()); | ||
ASSERT_TRUE(almost_eq(sp1.getLocalAreaDensity(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: many uses of Although, for the comparisons we need a boolean, not an assert; so we'd have to keep the function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'll just implement the small |
||
sp2.getLocalAreaDensity())); | ||
ASSERT_TRUE(sp1.getStimulusThreshold() == | ||
sp2.getStimulusThreshold()); | ||
ASSERT_TRUE(sp1.getDutyCyclePeriod() == sp2.getDutyCyclePeriod()); | ||
ASSERT_TRUE(almost_eq(sp1.getBoostStrength(), sp2.getBoostStrength())); | ||
ASSERT_TRUE(sp1.getIterationNum() == sp2.getIterationNum()); | ||
ASSERT_TRUE(sp1.getIterationLearnNum() == | ||
sp2.getIterationLearnNum()); | ||
ASSERT_TRUE(sp1.getSpVerbosity() == sp2.getSpVerbosity()); | ||
ASSERT_TRUE(sp1.getWrapAround() == sp2.getWrapAround()); | ||
ASSERT_TRUE(sp1.getUpdatePeriod() == sp2.getUpdatePeriod()); | ||
ASSERT_TRUE(almost_eq(sp1.getSynPermTrimThreshold(), | ||
sp2.getSynPermTrimThreshold())); | ||
cout << "check: " << sp1.getSynPermActiveInc() << " " << | ||
sp2.getSynPermActiveInc() << endl; | ||
ASSERT_TRUE(almost_eq(sp1.getSynPermActiveInc(), | ||
sp2.getSynPermActiveInc())); | ||
ASSERT_TRUE(almost_eq(sp1.getSynPermInactiveDec(), | ||
sp2.getSynPermInactiveDec())); | ||
ASSERT_TRUE(almost_eq(sp1.getSynPermBelowStimulusInc(), | ||
sp2.getSynPermBelowStimulusInc())); | ||
ASSERT_TRUE(almost_eq(sp1.getSynPermConnected(), | ||
sp2.getSynPermConnected())); | ||
ASSERT_TRUE(almost_eq(sp1.getMinPctOverlapDutyCycles(), | ||
sp2.getMinPctOverlapDutyCycles())); | ||
|
||
|
||
auto boostFactors1 = new Real[numColumns]; | ||
auto boostFactors2 = new Real[numColumns]; | ||
sp1.getBoostFactors(boostFactors1); | ||
sp2.getBoostFactors(boostFactors2); | ||
ASSERT_TRUE(check_vector_eq(boostFactors1, boostFactors2, numColumns)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another function is
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check_vector_eq should be implemented as |
||
delete[] boostFactors1; | ||
delete[] boostFactors2; | ||
|
||
auto overlapDutyCycles1 = new Real[numColumns]; | ||
auto overlapDutyCycles2 = new Real[numColumns]; | ||
sp1.getOverlapDutyCycles(overlapDutyCycles1); | ||
sp2.getOverlapDutyCycles(overlapDutyCycles2); | ||
ASSERT_TRUE(check_vector_eq(overlapDutyCycles1, overlapDutyCycles2, numColumns)); | ||
delete[] overlapDutyCycles1; | ||
delete[] overlapDutyCycles2; | ||
|
||
auto activeDutyCycles1 = new Real[numColumns]; | ||
auto activeDutyCycles2 = new Real[numColumns]; | ||
sp1.getActiveDutyCycles(activeDutyCycles1); | ||
sp2.getActiveDutyCycles(activeDutyCycles2); | ||
ASSERT_TRUE(check_vector_eq(activeDutyCycles1, activeDutyCycles2, numColumns)); | ||
delete[] activeDutyCycles1; | ||
delete[] activeDutyCycles2; | ||
|
||
auto minOverlapDutyCycles1 = new Real[numColumns]; | ||
auto minOverlapDutyCycles2 = new Real[numColumns]; | ||
sp1.getMinOverlapDutyCycles(minOverlapDutyCycles1); | ||
sp2.getMinOverlapDutyCycles(minOverlapDutyCycles2); | ||
ASSERT_TRUE(check_vector_eq(minOverlapDutyCycles1, minOverlapDutyCycles2, numColumns)); | ||
delete[] minOverlapDutyCycles1; | ||
delete[] minOverlapDutyCycles2; | ||
|
||
for (UInt i = 0; i < numColumns; i++) { | ||
auto potential1 = new UInt[numInputs]; | ||
auto potential2 = new UInt[numInputs]; | ||
sp1.getPotential(i, potential1); | ||
sp2.getPotential(i, potential2); | ||
ASSERT_TRUE(check_vector_eq(potential1, potential2, numInputs)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the case where check is within a for loop, so we'll have to create a variable that covers the whole loop:
|
||
delete[] potential1; | ||
delete[] potential2; | ||
} | ||
|
||
for (UInt i = 0; i < numColumns; i++) { | ||
auto perm1 = new Real[numInputs]; | ||
auto perm2 = new Real[numInputs]; | ||
sp1.getPermanence(i, perm1); | ||
sp2.getPermanence(i, perm2); | ||
ASSERT_TRUE(check_vector_eq(perm1, perm2, numInputs)); | ||
delete[] perm1; | ||
delete[] perm2; | ||
} | ||
|
||
for (UInt i = 0; i < numColumns; i++) { | ||
auto con1 = new UInt[numInputs]; | ||
auto con2 = new UInt[numInputs]; | ||
sp1.getConnectedSynapses(i, con1); | ||
sp2.getConnectedSynapses(i, con2); | ||
ASSERT_TRUE(check_vector_eq(con1, con2, numInputs)); | ||
delete[] con1; | ||
delete[] con2; | ||
} | ||
|
||
auto conCounts1 = new UInt[numColumns]; | ||
auto conCounts2 = new UInt[numColumns]; | ||
sp1.getConnectedCounts(conCounts1); | ||
sp2.getConnectedCounts(conCounts2); | ||
ASSERT_TRUE(check_vector_eq(conCounts1, conCounts2, numColumns)); | ||
delete[] conCounts1; | ||
delete[] conCounts2; | ||
} | ||
|
||
void setup(SpatialPooler& sp, UInt numInputs, | ||
UInt numColumns) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RFC: I'm looking for a way how to approach implementation of these comparisons.
The method in SPTest class used man
ASSERT_TRUE()
calls, the advantage is that when there is an inequality you'll know what exactly caused it.Now for
equals==
implementation we need to turn the asserts into a conjunction of conditions(a1==b1 && a2==b2 &&...)
.I was thinking of turning this
Into
(a1==b1) || return false
; but this result in losing the information on which specific check the comparison failed. Is that a problem?return (this->getNumColumns() == sp.getNumColumns()) && (this->getNumInputs()==sp.getNumInputs()) && ..
I think the b) would inform us in gtest at which part the conjunction failed, keeping the desired behavior.
The only problem with this is the readability and that some checks have pre-processing or even are inside for-loops, but we could workaround that.
EXPECT_EQ
) for information and just add && for equals.EDIT:
add option C)