Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP SP: add equals operator== #1384

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions src/nupic/algorithms/SpatialPooler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,119 @@ void SpatialPooler::stripUnlearnedColumns(UInt activeArray[]) const
}
}

bool SpatialPooler::operator==(const SpatialPooler &sp) const
{
const UInt numColumns = this->getNumColumns();
const UInt numInputs = this->getNumInputs();

ASSERT_TRUE(this->getNumColumns() == sp.getNumColumns());
ASSERT_TRUE(this->getNumInputs() == sp.getNumInputs());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RFC: I'm looking for a way how to approach implementation of these comparisons.

The method in SPTest class used man ASSERT_TRUE() calls, the advantage is that when there is an inequality you'll know what exactly caused it.

Now for equals== implementation we need to turn the asserts into a conjunction of conditions (a1==b1 && a2==b2 &&...).

I was thinking of turning this

ASSERT_TRUE(this->getNumColumns() == sp.getNumColumns());
ASSERT_TRUE(this->getNumInputs() == sp.getNumInputs());

Into

  • A) (a1==b1) || return false ; but this result in losing the information on which specific check the comparison failed. Is that a problem?
  • B) If the above is a problem, we could return a huge conjunction
    return (this->getNumColumns() == sp.getNumColumns()) && (this->getNumInputs()==sp.getNumInputs()) && ..

I think the b) would inform us in gtest at which part the conjunction failed, keeping the desired behavior.
The only problem with this is the readability and that some checks have pre-processing or even are inside for-loops, but we could workaround that.

  • C) keep the asserts (EXPECT_EQ) for information and just add && for equals.

EDIT:
add option C)

ASSERT_TRUE(this->getPotentialRadius() ==
sp.getPotentialRadius());
ASSERT_TRUE(this->getPotentialPct() == sp.getPotentialPct());
ASSERT_TRUE(this->getGlobalInhibition() ==
sp.getGlobalInhibition());
ASSERT_TRUE(this->getNumActiveColumnsPerInhArea() ==
sp.getNumActiveColumnsPerInhArea());
ASSERT_TRUE(almost_eq(this->getLocalAreaDensity(),
sp.getLocalAreaDensity()));
ASSERT_TRUE(this->getStimulusThreshold() ==
sp.getStimulusThreshold());
ASSERT_TRUE(this->getDutyCyclePeriod() == sp.getDutyCyclePeriod());
ASSERT_TRUE(almost_eq(this->getBoostStrength(), sp.getBoostStrength()));
ASSERT_TRUE(this->getIterationNum() == sp.getIterationNum());
ASSERT_TRUE(this->getIterationLearnNum() ==
sp.getIterationLearnNum());
ASSERT_TRUE(this->getSpVerbosity() == sp.getSpVerbosity());
ASSERT_TRUE(this->getWrapAround() == sp.getWrapAround());
ASSERT_TRUE(this->getUpdatePeriod() == sp.getUpdatePeriod());
ASSERT_TRUE(almost_eq(this->getSynPermTrimThreshold(),
sp.getSynPermTrimThreshold()));
cout << "check: " << this->getSynPermActiveInc() << " " <<
sp.getSynPermActiveInc() << endl;
ASSERT_TRUE(almost_eq(this->getSynPermActiveInc(),
sp.getSynPermActiveInc()));
ASSERT_TRUE(almost_eq(this->getSynPermInactiveDec(),
sp.getSynPermInactiveDec()));
ASSERT_TRUE(almost_eq(this->getSynPermBelowStimulusInc(),
sp.getSynPermBelowStimulusInc()));
ASSERT_TRUE(almost_eq(this->getSynPermConnected(),
sp.getSynPermConnected()));
ASSERT_TRUE(almost_eq(this->getMinPctOverlapDutyCycles(),
sp.getMinPctOverlapDutyCycles()));


auto boostFactors1 = new Real[numColumns];
auto boostFactors2 = new Real[numColumns];
this->getBoostFactors(boostFactors1);
sp.getBoostFactors(boostFactors2);
ASSERT_TRUE(check_vector_eq(boostFactors1, boostFactors2, numColumns));
delete[] boostFactors1;
delete[] boostFactors2;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like some optimization. Q: keep the delete[] (saves memory), or remove it (saves cycles, as the local memory is unallocated anyway at the end of the function/scope)


auto overlapDutyCycles1 = new Real[numColumns];
auto overlapDutyCycles2 = new Real[numColumns];
this->getOverlapDutyCycles(overlapDutyCycles1);
sp.getOverlapDutyCycles(overlapDutyCycles2);
ASSERT_TRUE(check_vector_eq(overlapDutyCycles1, overlapDutyCycles2, numColumns));
delete[] overlapDutyCycles1;
delete[] overlapDutyCycles2;

auto activeDutyCycles1 = new Real[numColumns];
auto activeDutyCycles2 = new Real[numColumns];
this->getActiveDutyCycles(activeDutyCycles1);
sp.getActiveDutyCycles(activeDutyCycles2);
ASSERT_TRUE(check_vector_eq(activeDutyCycles1, activeDutyCycles2, numColumns));
delete[] activeDutyCycles1;
delete[] activeDutyCycles2;

auto minOverlapDutyCycles1 = new Real[numColumns];
auto minOverlapDutyCycles2 = new Real[numColumns];
this->getMinOverlapDutyCycles(minOverlapDutyCycles1);
sp.getMinOverlapDutyCycles(minOverlapDutyCycles2);
ASSERT_TRUE(check_vector_eq(minOverlapDutyCycles1, minOverlapDutyCycles2, numColumns));
delete[] minOverlapDutyCycles1;
delete[] minOverlapDutyCycles2;

for (UInt i = 0; i < numColumns; i++) {
auto potential1 = new UInt[numInputs];
auto potential2 = new UInt[numInputs];
this->getPotential(i, potential1);
sp.getPotential(i, potential2);
ASSERT_TRUE(check_vector_eq(potential1, potential2, numInputs));
delete[] potential1;
delete[] potential2;
}

for (UInt i = 0; i < numColumns; i++) {
auto perm1 = new Real[numInputs];
auto perm2 = new Real[numInputs];
this->getPermanence(i, perm1);
sp.getPermanence(i, perm2);
ASSERT_TRUE(check_vector_eq(perm1, perm2, numInputs));
delete[] perm1;
delete[] perm2;
}

for (UInt i = 0; i < numColumns; i++) {
auto con1 = new UInt[numInputs];
auto con2 = new UInt[numInputs];
this->getConnectedSynapses(i, con1);
sp.getConnectedSynapses(i, con2);
ASSERT_TRUE(check_vector_eq(con1, con2, numInputs));
delete[] con1;
delete[] con2;
}
auto conCounts1 = new UInt[numColumns];
auto conCounts2 = new UInt[numColumns];
this->getConnectedCounts(conCounts1);
sp.getConnectedCounts(conCounts2);
ASSERT_TRUE(check_vector_eq(conCounts1, conCounts2, numColumns));
delete[] conCounts1;
delete[] conCounts2;

return true;
}

void SpatialPooler::toDense_(vector<UInt>& sparse,
UInt dense[],
Expand Down
6 changes: 6 additions & 0 deletions src/nupic/algorithms/SpatialPooler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,12 @@ namespace nupic
const vector<Real>& getBoostedOverlaps() const;


/**
Equals operator
*/
bool operator==(const SpatialPooler &sp) const;


///////////////////////////////////////////////////////////
//
// Implementation methods. all methods below this line are
Expand Down
112 changes: 0 additions & 112 deletions src/test/unit/algorithms/SpatialPoolerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,118 +126,6 @@ namespace {
return true;
}

void check_spatial_eq(SpatialPooler sp1, SpatialPooler sp2)
{
UInt numColumns = sp1.getNumColumns();
UInt numInputs = sp2.getNumInputs();

ASSERT_TRUE(sp1.getNumColumns() == sp2.getNumColumns());
ASSERT_TRUE(sp1.getNumInputs() == sp2.getNumInputs());
ASSERT_TRUE(sp1.getPotentialRadius() ==
sp2.getPotentialRadius());
ASSERT_TRUE(sp1.getPotentialPct() == sp2.getPotentialPct());
ASSERT_TRUE(sp1.getGlobalInhibition() ==
sp2.getGlobalInhibition());
ASSERT_TRUE(sp1.getNumActiveColumnsPerInhArea() ==
sp2.getNumActiveColumnsPerInhArea());
ASSERT_TRUE(almost_eq(sp1.getLocalAreaDensity(),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: many uses of almost_eq() can be replaced with EXCEPT_NEAR(a, b, diff); and the function can be removed from SPTest class.
https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md

Although, for the comparisons we need a boolean, not an assert; so we'd have to keep the function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll just implement the small almost_eq method ; see "twice" in the example
#1384 (comment)

sp2.getLocalAreaDensity()));
ASSERT_TRUE(sp1.getStimulusThreshold() ==
sp2.getStimulusThreshold());
ASSERT_TRUE(sp1.getDutyCyclePeriod() == sp2.getDutyCyclePeriod());
ASSERT_TRUE(almost_eq(sp1.getBoostStrength(), sp2.getBoostStrength()));
ASSERT_TRUE(sp1.getIterationNum() == sp2.getIterationNum());
ASSERT_TRUE(sp1.getIterationLearnNum() ==
sp2.getIterationLearnNum());
ASSERT_TRUE(sp1.getSpVerbosity() == sp2.getSpVerbosity());
ASSERT_TRUE(sp1.getWrapAround() == sp2.getWrapAround());
ASSERT_TRUE(sp1.getUpdatePeriod() == sp2.getUpdatePeriod());
ASSERT_TRUE(almost_eq(sp1.getSynPermTrimThreshold(),
sp2.getSynPermTrimThreshold()));
cout << "check: " << sp1.getSynPermActiveInc() << " " <<
sp2.getSynPermActiveInc() << endl;
ASSERT_TRUE(almost_eq(sp1.getSynPermActiveInc(),
sp2.getSynPermActiveInc()));
ASSERT_TRUE(almost_eq(sp1.getSynPermInactiveDec(),
sp2.getSynPermInactiveDec()));
ASSERT_TRUE(almost_eq(sp1.getSynPermBelowStimulusInc(),
sp2.getSynPermBelowStimulusInc()));
ASSERT_TRUE(almost_eq(sp1.getSynPermConnected(),
sp2.getSynPermConnected()));
ASSERT_TRUE(almost_eq(sp1.getMinPctOverlapDutyCycles(),
sp2.getMinPctOverlapDutyCycles()));


auto boostFactors1 = new Real[numColumns];
auto boostFactors2 = new Real[numColumns];
sp1.getBoostFactors(boostFactors1);
sp2.getBoostFactors(boostFactors2);
ASSERT_TRUE(check_vector_eq(boostFactors1, boostFactors2, numColumns));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another function is check_vector_eq(); used in two scenarios:

  • comparing array[] with vector<>. this could be avoided ei by vector.data()
  • for 2 vectors of Real, it calls almost_eq element-wise; I'm not sure if that can be done implicitly without the need for such function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_vector_eq should be implemented as std::equal() with the "twice" example:
http://www.tenouk.com/cpluscodesnippet/cplusstlvectoralgorithmequal.html

delete[] boostFactors1;
delete[] boostFactors2;

auto overlapDutyCycles1 = new Real[numColumns];
auto overlapDutyCycles2 = new Real[numColumns];
sp1.getOverlapDutyCycles(overlapDutyCycles1);
sp2.getOverlapDutyCycles(overlapDutyCycles2);
ASSERT_TRUE(check_vector_eq(overlapDutyCycles1, overlapDutyCycles2, numColumns));
delete[] overlapDutyCycles1;
delete[] overlapDutyCycles2;

auto activeDutyCycles1 = new Real[numColumns];
auto activeDutyCycles2 = new Real[numColumns];
sp1.getActiveDutyCycles(activeDutyCycles1);
sp2.getActiveDutyCycles(activeDutyCycles2);
ASSERT_TRUE(check_vector_eq(activeDutyCycles1, activeDutyCycles2, numColumns));
delete[] activeDutyCycles1;
delete[] activeDutyCycles2;

auto minOverlapDutyCycles1 = new Real[numColumns];
auto minOverlapDutyCycles2 = new Real[numColumns];
sp1.getMinOverlapDutyCycles(minOverlapDutyCycles1);
sp2.getMinOverlapDutyCycles(minOverlapDutyCycles2);
ASSERT_TRUE(check_vector_eq(minOverlapDutyCycles1, minOverlapDutyCycles2, numColumns));
delete[] minOverlapDutyCycles1;
delete[] minOverlapDutyCycles2;

for (UInt i = 0; i < numColumns; i++) {
auto potential1 = new UInt[numInputs];
auto potential2 = new UInt[numInputs];
sp1.getPotential(i, potential1);
sp2.getPotential(i, potential2);
ASSERT_TRUE(check_vector_eq(potential1, potential2, numInputs));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the case where check is within a for loop, so we'll have to create a variable that covers the whole loop:

bool allPotentialsEq = true;
for (..)
  allPotentialsEq == allPotEq && check_vector_eq(..)

delete[] potential1;
delete[] potential2;
}

for (UInt i = 0; i < numColumns; i++) {
auto perm1 = new Real[numInputs];
auto perm2 = new Real[numInputs];
sp1.getPermanence(i, perm1);
sp2.getPermanence(i, perm2);
ASSERT_TRUE(check_vector_eq(perm1, perm2, numInputs));
delete[] perm1;
delete[] perm2;
}

for (UInt i = 0; i < numColumns; i++) {
auto con1 = new UInt[numInputs];
auto con2 = new UInt[numInputs];
sp1.getConnectedSynapses(i, con1);
sp2.getConnectedSynapses(i, con2);
ASSERT_TRUE(check_vector_eq(con1, con2, numInputs));
delete[] con1;
delete[] con2;
}

auto conCounts1 = new UInt[numColumns];
auto conCounts2 = new UInt[numColumns];
sp1.getConnectedCounts(conCounts1);
sp2.getConnectedCounts(conCounts2);
ASSERT_TRUE(check_vector_eq(conCounts1, conCounts2, numColumns));
delete[] conCounts1;
delete[] conCounts2;
}

void setup(SpatialPooler& sp, UInt numInputs,
UInt numColumns)
Expand Down