From e868be8ecbf9ae1bd7cf13744156454c573bbb69 Mon Sep 17 00:00:00 2001 From: Alexander Condello Date: Thu, 14 Nov 2024 09:50:28 -0800 Subject: [PATCH] temp --- .../include/dwave-optimization/utils.hpp | 11 ++++++ dwave/optimization/src/array.cpp | 1 + dwave/optimization/src/nodes/indexing.cpp | 38 ++++++++++++++++--- dwave/optimization/src/nodes/testing.cpp | 14 +++++++ 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/utils.hpp b/dwave/optimization/include/dwave-optimization/utils.hpp index c0caa4a5..055b8cf2 100644 --- a/dwave/optimization/include/dwave-optimization/utils.hpp +++ b/dwave/optimization/include/dwave-optimization/utils.hpp @@ -127,6 +127,17 @@ class fraction { return lhs; } + constexpr friend fraction operator+(const fraction& lhs, const fraction& rhs) { + return fraction(lhs.numerator_ * rhs.denominator_ + rhs.numerator_ * lhs.denominator_, + lhs.denominator_ * rhs.denominator_); + } + constexpr friend fraction operator+(const fraction& lhs, const std::integral auto& rhs) { + return lhs + fraction(rhs); + } + constexpr friend fraction operator+(const std::integral auto& lhs, fraction& rhs) { + return fraction(lhs) + rhs; + } + /// Fractions can be printed friend std::ostream& operator<<(std::ostream& os, const fraction& rhs) { os << "fraction(" << rhs.numerator(); diff --git a/dwave/optimization/src/array.cpp b/dwave/optimization/src/array.cpp index d4fb85a7..3cf0fa91 100644 --- a/dwave/optimization/src/array.cpp +++ b/dwave/optimization/src/array.cpp @@ -44,6 +44,7 @@ bool SizeInfo::operator==(const SizeInfo& other) const { SizeInfo SizeInfo::substitute(ssize_t max_depth) const { if (max_depth <= 0) return *this; + if (this->array_ptr == nullptr) return *this; SizeInfo sizeinfo = this->array_ptr->sizeinfo(); diff --git a/dwave/optimization/src/nodes/indexing.cpp b/dwave/optimization/src/nodes/indexing.cpp index 36f46158..57352b32 100644 --- a/dwave/optimization/src/nodes/indexing.cpp +++ b/dwave/optimization/src/nodes/indexing.cpp @@ -641,12 +641,38 @@ ssize_t AdvancedIndexingNode::size(const State& state) const { } SizeInfo AdvancedIndexingNode::sizeinfo() const { + // easy case, fixed size if (!dynamic()) return SizeInfo(size()); - // when we get around to supporting broadcasting this will need to change - assert(predecessors().size() >= 2); - assert(!dynamic_cast(predecessors()[0])->dynamic() && - "sizeinfo for dynamic base arrays not supported"); - return SizeInfo(dynamic_cast(predecessors()[1])); + + return SizeInfo(this); + + // // if the base array is dynamic AND the first indexer is a slice then + // // our size is derived from the base array via alchemy + // assert(indices_.size() >= 1); // should always be true + // if (array_ptr_->dynamic() && std::holds_alternative(indices_[0])) { + // // there are additional subcases we can investigate here to try to be + // // more specific, e.g. whether or not any of the indexing arrays are + // // dynamic etc... But for now let's just handwave a bit, say our size + // // is derived from ourselves, and so the best we can with the upper bound. + // std::optional max; + // for (const Node* ptr : predecessors()) { + // // 100 is a magic number... we really need a better way to do this. + // auto sizeinfo = dynamic_cast(ptr)->sizeinfo().substitute(100); + // // assert(false); + // if (max && sizeinfo.max) { + // if (*sizeinfo.max < *max) max = sizeinfo.max; + // } else if (sizeinfo.max) { + // max = sizeinfo.max; + // } + // } + // return SizeInfo(this, std::nullopt, max); + // } + + // // If the first indexer is not a slice, then whether or not we are dynamic + // // we derive our size from indexing arrays. + // // If we eventually add broadcasting then this will need to change. + // assert(predecessors().size() >= 2); // should always be true + // return SizeInfo(dynamic_cast(predecessors()[1])); } std::span AdvancedIndexingNode::shape(const State& state) const { @@ -1515,7 +1541,7 @@ ssize_t BasicIndexingNode::size(const State& state) const { } SizeInfo BasicIndexingNode::sizeinfo() const { - if (size_ >= 0) return SizeInfo(size_); + if (!dynamic()) return SizeInfo(size_); auto sizeinfo = SizeInfo(array_ptr_); diff --git a/dwave/optimization/src/nodes/testing.cpp b/dwave/optimization/src/nodes/testing.cpp index 6d20970b..aa1d7b32 100644 --- a/dwave/optimization/src/nodes/testing.cpp +++ b/dwave/optimization/src/nodes/testing.cpp @@ -46,6 +46,7 @@ void check_shape(const std::span& dynamic_shape, ArrayValidationNode::ArrayValidationNode(ArrayNode* node_ptr) : array_ptr(node_ptr) { assert(array_ptr->ndim() == static_cast(array_ptr->shape().size())); assert(array_ptr->dynamic() == (array_ptr->size() == -1)); + node_ptr->sizeinfo(); // smoke check add_predecessor(node_ptr); } @@ -183,6 +184,19 @@ void ArrayValidationNode::propagate(State& state) const { assert(std::ranges::max(array_ptr->view(state)) <= array_ptr->max()); assert(!array_ptr->integral() || std::ranges::all_of(array_ptr->view(state), is_integer)); } + + // check that whatever sizeinfo the array reports is accurate + auto sizeinfo = array_ptr->sizeinfo(); + if (sizeinfo.array_ptr != nullptr) { + // the size is at least theoretically derived from another array, so let's check that the + // reported multiplier/offset are correct + assert(array_ptr->size(state) == + sizeinfo.multiplier * sizeinfo.array_ptr->size(state) + sizeinfo.offset); + } else { + assert(array_ptr->size(state) == sizeinfo.offset); + } + assert(!sizeinfo.max || array_ptr->size(state) <= *sizeinfo.max); + assert(!sizeinfo.min || array_ptr->size(state) >= *sizeinfo.min); } void ArrayValidationNode::revert(State& state) const {