diff --git a/src/nupic/bindings/algorithms.i b/src/nupic/bindings/algorithms.i index 5a1f94003b..aa4483ce54 100644 --- a/src/nupic/bindings/algorithms.i +++ b/src/nupic/bindings/algorithms.i @@ -42,6 +42,8 @@ else: from nupic.proto.SvmProto_capnp import (SvmDenseProto, Svm01Proto) from nupic.proto.TemporalMemoryProto_capnp import TemporalMemoryProto +# Capnp reader traveral limit (see capnp::ReaderOptions) +_TRAVERSAL_LIMIT_IN_WORDS = 1 << 63 _ALGORITHMS = _algorithms %} @@ -338,7 +340,8 @@ void forceRetentionOfImageSensorLiteLibrary(void) { :param: Destination SvmDenseProto message builder """ - reader = SvmDenseProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = SvmDenseProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy @classmethod @@ -450,7 +453,8 @@ void forceRetentionOfImageSensorLiteLibrary(void) { :param: Destination Svm01Proto message builder """ - reader = Svm01Proto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = Svm01Proto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy @classmethod @@ -899,7 +903,8 @@ void forceRetentionOfImageSensorLiteLibrary(void) { :param: Destination Cells4Proto message builder """ - reader = Cells4Proto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = Cells4Proto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy %} @@ -1148,7 +1153,8 @@ void forceRetentionOfImageSensorLiteLibrary(void) { :param: Destination SpatialPoolerProto message builder """ - reader = SpatialPoolerProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = SpatialPoolerProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy @@ -1465,7 +1471,8 @@ void forceRetentionOfImageSensorLiteLibrary(void) { :param: Destination SdrClassifierProto message builder """ - reader = SdrClassifierProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = SdrClassifierProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy @classmethod @@ -1581,7 +1588,8 @@ void forceRetentionOfImageSensorLiteLibrary(void) { :param: Destination ConnectionsProto message builder """ - reader = ConnectionsProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = ConnectionsProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy @@ -1782,7 +1790,8 @@ void forceRetentionOfImageSensorLiteLibrary(void) { :param: Destination TemporalMemoryProto message builder """ - reader = TemporalMemoryProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = TemporalMemoryProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy diff --git a/src/nupic/bindings/engine_internal.i b/src/nupic/bindings/engine_internal.i index cc8c7f5168..cff72d3256 100644 --- a/src/nupic/bindings/engine_internal.i +++ b/src/nupic/bindings/engine_internal.i @@ -35,6 +35,9 @@ except ImportError: else: from nupic.proto.NetworkProto_capnp import NetworkProto from nupic.proto.PyRegionProto_capnp import PyRegionProto + +# Capnp reader traveral limit (see capnp::ReaderOptions) +_TRAVERSAL_LIMIT_IN_WORDS = 1 << 63 %} %{ @@ -287,7 +290,8 @@ class IterablePair(object): :param: Destination NetworkProto message builder """ - reader = NetworkProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = NetworkProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy @@ -364,7 +368,8 @@ class _PyCapnpHelper(object): :returns: The deserialized python region instance. """ - pyRegionProto = PyRegionProto.from_bytes(pyRegionProtoBytes) + pyRegionProto = PyRegionProto.from_bytes(pyRegionProtoBytes, + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) return getattr(regionCls, methodName)(pyRegionProto) diff --git a/src/nupic/bindings/math.i b/src/nupic/bindings/math.i index 1e23235bcb..e6aa47e20e 100644 --- a/src/nupic/bindings/math.i +++ b/src/nupic/bindings/math.i @@ -33,6 +33,9 @@ except ImportError: else: from nupic.proto.RandomProto_capnp import RandomProto +# Capnp reader traveral limit (see capnp::ReaderOptions) +_TRAVERSAL_LIMIT_IN_WORDS = 1 << 63 + _MATH = _math %} @@ -195,7 +198,8 @@ def write(self, pyBuilder): :param: Destination RandomProto message builder """ - reader = RandomProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = RandomProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy diff --git a/src/nupic/bindings/sparse_matrix.i b/src/nupic/bindings/sparse_matrix.i index a8b925352f..5fbe97061e 100644 --- a/src/nupic/bindings/sparse_matrix.i +++ b/src/nupic/bindings/sparse_matrix.i @@ -45,6 +45,9 @@ except ImportError: else: from nupic.proto.SparseMatrixProto_capnp import SparseMatrixProto from nupic.proto.SparseBinaryMatrixProto_capnp import SparseBinaryMatrixProto + +# Capnp reader traveral limit (see capnp::ReaderOptions) +_TRAVERSAL_LIMIT_IN_WORDS = 1 << 63 %} @@ -423,7 +426,8 @@ def write(self, pyBuilder): :param: Destination SparseMatrixProto message builder """ - reader = SparseMatrixProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = SparseMatrixProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy @classmethod @@ -2973,7 +2977,8 @@ def write(self, pyBuilder): :param: Destination SparseBinaryMatrixProto message builder """ - reader = SparseBinaryMatrixProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = SparseBinaryMatrixProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy def read(self, proto): @@ -3484,7 +3489,8 @@ def write(self, pyBuilder): :param: Destination SparseBinaryMatrixProto message builder """ - reader = SparseBinaryMatrixProto.from_bytes(self._writeAsCapnpPyBytes()) # copy + reader = SparseBinaryMatrixProto.from_bytes(self._writeAsCapnpPyBytes(), + traversal_limit_in_words=_TRAVERSAL_LIMIT_IN_WORDS) pyBuilder.from_dict(reader.to_dict()) # copy def read(self, proto): diff --git a/src/nupic/engine/Region.cpp b/src/nupic/engine/Region.cpp index 3694614e57..845df508b2 100644 --- a/src/nupic/engine/Region.cpp +++ b/src/nupic/engine/Region.cpp @@ -436,79 +436,14 @@ const Timer &Region::getExecuteTimer() const { return executeTimer_; } bool Region::operator==(const Region &o) const { - if (name_ != o.name_ || type_ != o.type_ || dims_ != o.dims_ || - phases_ != o.phases_ || dimensionInfo_ != o.dimensionInfo_ || - initialized_ != o.initialized_ || outputs_.size() != o.outputs_.size() || + if (initialized_ != o.initialized_ || outputs_.size() != o.outputs_.size() || inputs_.size() != o.inputs_.size()) { return false; } - if (spec_ != nullptr && o.spec_ != nullptr) { - // Compare specs - if (spec_->singleNodeOnly != o.spec_->singleNodeOnly || - spec_->description != o.spec_->description) { - return false; - } - - // Parameters - for (size_t i = 0; i < spec_->parameters.getCount(); ++i) { - const std::pair &p1 = - spec_->parameters.getByIndex(i); - const std::pair &p2 = - o.spec_->parameters.getByIndex(i); - if (p1.first != p2.first || p1.second.count != p2.second.count || - p1.second.description != p2.second.description || - p1.second.constraints != p2.second.constraints || - p1.second.defaultValue != p2.second.defaultValue || - p1.second.dataType != p2.second.dataType || - p1.second.accessMode != p2.second.accessMode) { - return false; - } - } - // Outputs - for (size_t i = 0; i < spec_->outputs.getCount(); ++i) { - const std::pair &p1 = - spec_->outputs.getByIndex(i); - const std::pair &p2 = - o.spec_->outputs.getByIndex(i); - if (p1.first != p2.first || p1.second.count != p2.second.count || - p1.second.regionLevel != p2.second.regionLevel || - p1.second.isDefaultOutput != p2.second.isDefaultOutput || - p1.second.sparse != p2.second.sparse || - p1.second.description != p2.second.description || - p1.second.dataType != p2.second.dataType) { - return false; - } - } - // Outputs - for (size_t i = 0; i < spec_->inputs.getCount(); ++i) { - const std::pair &p1 = spec_->inputs.getByIndex(i); - const std::pair &p2 = - o.spec_->inputs.getByIndex(i); - if (p1.first != p2.first || p1.second.count != p2.second.count || - p1.second.regionLevel != p2.second.regionLevel || - p1.second.isDefaultInput != p2.second.isDefaultInput || - p1.second.sparse != p2.second.sparse || - p1.second.requireSplitterMap != p2.second.requireSplitterMap || - p1.second.required != p2.second.required || - p1.second.description != p2.second.description || - p1.second.dataType != p2.second.dataType) { - return false; - } - } - // Commands - for (size_t i = 0; i < spec_->commands.getCount(); ++i) { - const std::pair &p1 = - spec_->commands.getByIndex(i); - const std::pair &p2 = - o.spec_->commands.getByIndex(i); - if (p1.first != p2.first || - p1.second.description != p2.second.description) { - return false; - } - } - } else if (spec_ != o.spec_) { - // One of them is not null + if (name_ != o.name_ || type_ != o.type_ || dims_ != o.dims_ || + spec_ != o.spec_ || phases_ != o.phases_ || + dimensionInfo_ != o.dimensionInfo_) { return false; } @@ -552,6 +487,6 @@ bool Region::operator==(const Region &o) const { } return true; -} +} // namespace nupic } // namespace nupic diff --git a/src/nupic/engine/Spec.cpp b/src/nupic/engine/Spec.cpp index 8896a751b4..a09afa02c4 100644 --- a/src/nupic/engine/Spec.cpp +++ b/src/nupic/engine/Spec.cpp @@ -31,7 +31,14 @@ Implementation of Spec API namespace nupic { Spec::Spec() : singleNodeOnly(false), description("") {} - +bool Spec::operator==(const Spec &o) const { + if (singleNodeOnly != o.singleNodeOnly || description != o.description || + parameters != o.parameters || outputs != o.outputs || + inputs != o.inputs || commands != o.commands) { + return false; + } + return true; +} std::string Spec::getDefaultInputName() const { if (inputs.getCount() == 0) return ""; @@ -87,16 +94,29 @@ InputSpec::InputSpec(std::string description, NTA_BasicType dataType, required(required), regionLevel(regionLevel), isDefaultInput(isDefaultInput), requireSplitterMap(requireSplitterMap), sparse(sparse) {} - +bool InputSpec::operator==(const InputSpec &o) const { + return required == o.required && regionLevel == o.regionLevel && + isDefaultInput == o.isDefaultInput && sparse == o.sparse && + requireSplitterMap == o.requireSplitterMap && dataType == o.dataType && + count == o.count && description == o.description; +} OutputSpec::OutputSpec(std::string description, NTA_BasicType dataType, size_t count, bool regionLevel, bool isDefaultOutput, bool sparse) : description(std::move(description)), dataType(dataType), count(count), regionLevel(regionLevel), isDefaultOutput(isDefaultOutput), sparse(sparse) {} +bool OutputSpec::operator==(const OutputSpec &o) const { + return regionLevel == o.regionLevel && isDefaultOutput == o.isDefaultOutput && + sparse == o.sparse && dataType == o.dataType && count == o.count && + description == o.description; +} CommandSpec::CommandSpec(std::string description) : description(std::move(description)) {} +bool CommandSpec::operator==(const CommandSpec &o) const { + return description == o.description; +} ParameterSpec::ParameterSpec(std::string description, NTA_BasicType dataType, size_t count, std::string constraints, @@ -109,6 +129,11 @@ ParameterSpec::ParameterSpec(std::string description, NTA_BasicType dataType, if (dataType == NTA_BasicType_Byte && count > 0) NTA_THROW << "Parameters of type 'byte' are not supported"; } +bool ParameterSpec::operator==(const ParameterSpec &o) const { + return dataType == o.dataType && count == o.count && + description == o.description && constraints == o.constraints && + defaultValue == o.defaultValue && accessMode == o.accessMode; +} std::string Spec::toString() const { // TODO -- minimal information here; fill out with the rest of diff --git a/src/nupic/engine/Spec.hpp b/src/nupic/engine/Spec.hpp index 457cc2757f..e517eaa63b 100644 --- a/src/nupic/engine/Spec.hpp +++ b/src/nupic/engine/Spec.hpp @@ -39,7 +39,10 @@ class InputSpec { InputSpec(std::string description, NTA_BasicType dataType, UInt32 count, bool required, bool regionLevel, bool isDefaultInput, bool requireSplitterMap = true, bool sparse = false); - + bool operator==(const InputSpec &other) const; + inline bool operator!=(const InputSpec &other) const { + return !operator==(other); + } std::string description; NTA_BasicType dataType; // TBD: Omit? isn't it always of unknown size? @@ -59,7 +62,10 @@ class OutputSpec { OutputSpec(std::string description, const NTA_BasicType dataType, size_t count, bool regionLevel, bool isDefaultOutput, bool sparse = false); - + bool operator==(const OutputSpec &other) const; + inline bool operator!=(const OutputSpec &other) const { + return !operator==(other); + } std::string description; NTA_BasicType dataType; // Size, in number of elements. If size is fixed, specify it here. @@ -74,7 +80,10 @@ class CommandSpec { public: CommandSpec() {} CommandSpec(std::string description); - + bool operator==(const CommandSpec &other) const; + inline bool operator!=(const CommandSpec &other) const { + return !operator==(other); + } std::string description; }; @@ -89,7 +98,10 @@ class ParameterSpec { ParameterSpec(std::string description, NTA_BasicType dataType, size_t count, std::string constraints, std::string defaultValue, AccessMode accessMode); - + bool operator==(const ParameterSpec &other) const; + inline bool operator!=(const ParameterSpec &other) const { + return !operator==(other); + } std::string description; // [open: current basic types are bytes/{u}int16/32/64, real32/64, BytePtr. Is @@ -109,7 +121,8 @@ struct Spec { // TODO: should this be in the base API or layered? In the API right // now since we do not build layered libraries. std::string toString() const; - + bool operator==(const Spec &other) const; + inline bool operator!=(const Spec &other) const { return !operator==(other); } // Some RegionImpls support only a single node in a region. // Such regions always have dimension [1] bool singleNodeOnly; diff --git a/src/nupic/ntypes/Collection.cpp b/src/nupic/ntypes/Collection.cpp index 79ec6089f9..582d9bdcac 100644 --- a/src/nupic/ntypes/Collection.cpp +++ b/src/nupic/ntypes/Collection.cpp @@ -42,7 +42,14 @@ namespace nupic { template Collection::Collection() {} template Collection::~Collection() {} - +template +bool Collection::operator==(const Collection &o) const { + const static auto compare = [](std::pair a, + std::pair b) { + return a.first == b.first && a.second == b.second; + }; + return std::equal(vec_.begin(), vec_.end(), o.vec_.begin(), compare); +} template size_t Collection::getCount() const { return vec_.size(); } diff --git a/src/nupic/ntypes/Collection.hpp b/src/nupic/ntypes/Collection.hpp index eefc44c5ca..c613343a2a 100644 --- a/src/nupic/ntypes/Collection.hpp +++ b/src/nupic/ntypes/Collection.hpp @@ -36,7 +36,10 @@ template class Collection { public: Collection(); virtual ~Collection(); - + bool operator==(const Collection &other) const; + inline bool operator!=(const Collection &other) const { + return !operator==(other); + } size_t getCount() const; // This method provides access by index to the contents of the collection diff --git a/src/nupic/py_support/PyCapnp.hpp b/src/nupic/py_support/PyCapnp.hpp index b29897e161..51b9903a0d 100644 --- a/src/nupic/py_support/PyCapnp.hpp +++ b/src/nupic/py_support/PyCapnp.hpp @@ -111,7 +111,9 @@ class PyCapnpHelper { kj::Array array = kj::heapArray(srcNumWords); memcpy(array.asBytes().begin(), srcBytes, srcNumBytes); // copy - capnp::FlatArrayMessageReader reader(array.asPtr()); // copy ? + capnp::ReaderOptions options; + options.traversalLimitInWords = kj::maxValue; // Don't limit. + capnp::FlatArrayMessageReader reader(array.asPtr(), options); // copy ? typename MessageType::Reader proto = reader.getRoot(); obj.read(proto); #else diff --git a/src/nupic/regions/PyRegion.cpp b/src/nupic/regions/PyRegion.cpp index 35e1d583a6..fbd3f5c787 100644 --- a/src/nupic/regions/PyRegion.cpp +++ b/src/nupic/regions/PyRegion.cpp @@ -435,7 +435,9 @@ void PyRegion::write(capnp::AnyPointer::Builder &proto) const { kj::Array array = Helper::serialize(node_); // Initialize PyRegionProto::Reader from serialized python region - capnp::FlatArrayMessageReader reader(array.asPtr()); + capnp::ReaderOptions options; + options.traversalLimitInWords = kj::maxValue; // Don't limit. + capnp::FlatArrayMessageReader reader(array.asPtr(), options); PyRegionProto::Reader pyRegionReader = reader.getRoot(); // Assign python region's serialization output to the builder diff --git a/src/nupic/types/Serializable.hpp b/src/nupic/types/Serializable.hpp index 0598cc2825..a921daa898 100644 --- a/src/nupic/types/Serializable.hpp +++ b/src/nupic/types/Serializable.hpp @@ -52,8 +52,9 @@ template class Serializable { void read(std::istream &stream) { kj::std::StdInputStream in(stream); - - capnp::InputStreamMessageReader message(in); + capnp::ReaderOptions options; + options.traversalLimitInWords = kj::maxValue; // Don't limit. + capnp::InputStreamMessageReader message(in, options); typename ProtoT::Reader proto = message.getRoot(); read(proto); } diff --git a/src/test/unit/ntypes/CollectionTest.cpp b/src/test/unit/ntypes/CollectionTest.cpp index d546117b0f..110f9a5953 100644 --- a/src/test/unit/ntypes/CollectionTest.cpp +++ b/src/test/unit/ntypes/CollectionTest.cpp @@ -41,6 +41,10 @@ struct CollectionTest : public ::testing::Test { Item() : x(-1) {} Item(int x) : x(x) {} + inline bool operator==(const Item &other) const { return x == other.x; }; + inline bool operator!=(const Item &other) const { + return !operator==(other); + } }; };