Skip to content

Commit

Permalink
NUP-2504: Add unit test to softmax overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
lscheinkman committed Apr 13, 2018
1 parent 7fcc68d commit f5bc76b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/nupic/algorithms/SDRClassifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ const UInt sdrClassifierVersion = 1;
typedef Dense<UInt, Real64> Matrix;

class SDRClassifier : public Serializable<SdrClassifierProto> {
// Make test class friend so it can unit test private members directly
friend class SDRClassifierTest;

public:
/**
* Constructor for use when deserializing.
Expand Down Expand Up @@ -204,6 +207,7 @@ class SDRClassifier : public Serializable<SdrClassifierProto> {
// Version and verbosity.
UInt version_;
UInt verbosity_;

}; // end of SDRClassifier class

} // end of namespace sdr_classifier
Expand Down
46 changes: 36 additions & 10 deletions src/test/unit/algorithms/SDRClassifierTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,43 @@
* Implementation of unit tests for SDRClassifier
*/

#include <cmath> // isnan
#include <iostream>
#include <limits> // numeric_limits
#include <sstream>
#include <stdio.h>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include <kj/std/iostream.h>

#include <nupic/algorithms/ClassifierResult.hpp>
#include <nupic/algorithms/SDRClassifier.hpp>
#include <nupic/math/StlIo.hpp>
#include <nupic/types/Types.hpp>
#include <nupic/utils/Log.hpp>

namespace nupic {
namespace algorithms {
namespace sdr_classifier {

// SDRClassifier friend class used to access private members
class SDRClassifierTest : public ::testing::Test {
protected:
typedef std::vector<double>::iterator Iterator;
void softmax_(SDRClassifier *self, Iterator begin, Iterator end) {
self->softmax_(begin, end);
};
};
} // namespace sdr_classifier
} // namespace algorithms
} // namespace nupic

using namespace std;
using namespace nupic;
using namespace nupic::algorithms::cla_classifier;
using namespace nupic::algorithms::sdr_classifier;

namespace {

TEST(SDRClassifierTest, Basic) {
TEST_F(SDRClassifierTest, Basic) {
vector<UInt> steps;
steps.push_back(1);
SDRClassifier c = SDRClassifier(steps, 0.1, 0.1, 0);
Expand Down Expand Up @@ -109,7 +127,7 @@ TEST(SDRClassifierTest, Basic) {
}
}

TEST(SDRClassifierTest, SingleValue) {
TEST_F(SDRClassifierTest, SingleValue) {
// Feed the same input 10 times, the corresponding probability should be
// very high
vector<UInt> steps;
Expand Down Expand Up @@ -145,7 +163,7 @@ TEST(SDRClassifierTest, SingleValue) {
}
}

TEST(SDRClassifierTest, ComputeComplex) {
TEST_F(SDRClassifierTest, ComputeComplex) {
// More complex classification
// This test is ported from the Python unit test
vector<UInt> steps;
Expand Down Expand Up @@ -255,7 +273,7 @@ TEST(SDRClassifierTest, ComputeComplex) {
}
}

TEST(SDRClassifierTest, MultipleCategory) {
TEST_F(SDRClassifierTest, MultipleCategory) {
// Test multiple category classification with single compute calls
// This test is ported from the Python unit test
vector<UInt> steps;
Expand Down Expand Up @@ -326,7 +344,7 @@ TEST(SDRClassifierTest, MultipleCategory) {
}
}

TEST(SDRClassifierTest, SaveLoad) {
TEST_F(SDRClassifierTest, SaveLoad) {
vector<UInt> steps;
steps.push_back(1);
SDRClassifier c1 = SDRClassifier(steps, 0.1, 0.1, 0);
Expand Down Expand Up @@ -362,7 +380,7 @@ TEST(SDRClassifierTest, SaveLoad) {
ASSERT_TRUE(result1 == result2);
}

TEST(SDRClassifierTest, WriteRead) {
TEST_F(SDRClassifierTest, WriteRead) {
vector<UInt> steps;
steps.push_back(1);
steps.push_back(2);
Expand Down Expand Up @@ -412,4 +430,12 @@ TEST(SDRClassifierTest, WriteRead) {
ASSERT_TRUE(result1 == result2);
}

TEST_F(SDRClassifierTest, testSoftmaxOverflow) {
SDRClassifier c = SDRClassifier({1}, 0.5, 0.5, 0);
std::vector<double> values = {numeric_limits<double>::max()};
softmax_(&c, values.begin(), values.end());
double result = values[0];
ASSERT_FALSE(std::isnan(result));
}

} // end namespace

0 comments on commit f5bc76b

Please sign in to comment.