Skip to content

Commit

Permalink
Moved some functions behind compilation barrier
Browse files Browse the repository at this point in the history
  • Loading branch information
azucca committed Sep 27, 2024
1 parent 0337e08 commit 4c484f3
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 67 deletions.
11 changes: 0 additions & 11 deletions dwave/optimization/include/dwave-optimization/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,17 +738,6 @@ std::vector<ssize_t> broadcast_shape(const std::span<const ssize_t> lhs,
std::vector<ssize_t> broadcast_shape(std::initializer_list<ssize_t> lhs,
std::initializer_list<ssize_t> rhs);

/// For "partial reduction", get the shape of the resulting shape of the array when preforming a
/// reduction over an axis
std::vector<ssize_t> partial_reduce_shape(const std::span<const ssize_t> array_shape,
const ssize_t axis);
std::vector<ssize_t> partial_reduce_shape(std::initializer_list<ssize_t> input_shape,
const ssize_t axis);

/// Gets the strides of a n-dimensional array assuming contiguous memory
std::vector<ssize_t> as_contiguous_strides(const std::span<const ssize_t> shape);
std::vector<ssize_t> as_contiguous_strides(std::initializer_list<ssize_t> shape);

/// Convert a flat index to multi-index
std::vector<ssize_t> unravel_index(const std::span<const ssize_t> strides, ssize_t index);

Expand Down
29 changes: 0 additions & 29 deletions dwave/optimization/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,35 +246,6 @@ std::vector<ssize_t> broadcast_shape(std::initializer_list<ssize_t> lhs,
return broadcast_shape(std::span(lhs), std::span(rhs));
}

std::vector<ssize_t> partial_reduce_shape(const std::span<const ssize_t> input_shape,
const ssize_t axis) {
std::vector<ssize_t> shape;
shape.assign(input_shape.begin(), input_shape.end());
shape.erase(shape.begin() + axis);
return shape;
}
std::vector<ssize_t> partial_reduce_shape(std::initializer_list<ssize_t> input_shape,
const ssize_t axis) {
return partial_reduce_shape(std::span(input_shape), axis);
}

std::vector<ssize_t> as_contiguous_strides(const std::span<const ssize_t> shape) {
ssize_t ndim = static_cast<ssize_t>(shape.size());

assert(ndim >= 0);
std::vector<ssize_t> strides(ndim);
// otherwise strides are a function of the shape
strides[ndim - 1] = sizeof(double);
for (auto i = ndim - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
}

std::vector<ssize_t> as_contiguous_strides(std::initializer_list<ssize_t> shape) {
return as_contiguous_strides(std::span(shape));
}

std::vector<ssize_t> unravel_index(const std::span<const ssize_t> strides, ssize_t index) {
ssize_t ndim = static_cast<ssize_t>(strides.size());
std::vector<ssize_t> indices;
Expand Down
21 changes: 21 additions & 0 deletions dwave/optimization/src/nodes/mathematical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,27 @@ template class NaryOpNode<std::multiplies<double>>;
template class NaryOpNode<std::plus<double>>;

// PartialReduceNode *****************************************************************
std::vector<ssize_t> partial_reduce_shape(const std::span<const ssize_t> input_shape,
const ssize_t axis) {
std::vector<ssize_t> shape;
shape.assign(input_shape.begin(), input_shape.end());
shape.erase(shape.begin() + axis);
return shape;
}

std::vector<ssize_t> as_contiguous_strides(const std::span<const ssize_t> shape) {
ssize_t ndim = static_cast<ssize_t>(shape.size());

assert(ndim >= 0);
std::vector<ssize_t> strides(ndim);
// otherwise strides are a function of the shape
strides[ndim - 1] = sizeof(double);
for (auto i = ndim - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
}

/// TODO: support multiple axes
template <class BinaryOp>
PartialReduceNode<BinaryOp>::PartialReduceNode(ArrayNode* node_ptr, std::span<const ssize_t> axes,
Expand Down
27 changes: 0 additions & 27 deletions tests/cpp/tests/test_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,35 +526,9 @@ TEST_CASE("Test resulting_shape()") {
CHECK_THROWS_WITH(broadcast_shape({2, 1}, {8, 4, 3}),
"operands could not be broadcast together with shapes (2,1) (8,4,3)");
}

SECTION("Reduce (2, 3, 4), axis=0") {
CHECK(std::ranges::equal(partial_reduce_shape({2, 3, 4}, 0), std::vector{3, 4}));
}

SECTION("Reduce (2, 3, 4), axis=1") {
CHECK(std::ranges::equal(partial_reduce_shape({2, 3, 4}, 1), std::vector{2, 4}));
}

SECTION("Reduce (2, 3, 4), axis=2") {
CHECK(std::ranges::equal(partial_reduce_shape({2, 3, 4}, 2), std::vector{2, 3}));
}
}

TEST_CASE("Ravelling-unravelling indices") {
SECTION("Shape (10, 3, 6)") {
auto strides = as_contiguous_strides({10, 3, 6});
ssize_t index = 15;
ssize_t last_element_flat = 10 * 3 * 6 - 1;
std::vector<ssize_t> last_element_multi{9, 2, 5};

CHECK(std::ranges::equal(strides, std::vector{144, 48, 8}));
CHECK(ravel_multi_index(strides, unravel_index(strides, index)) == index);
CHECK(ravel_multi_index(strides, last_element_multi) == last_element_flat);

// Check one element within the range
CHECK(std::ranges::equal(unravel_index(strides, ravel_multi_index(strides, {3, 1, 2})), std::vector{3, 1, 2}));
}

SECTION("On constant array of shape (3, 4, 5)") {
auto state = State();
class Array3d : public ArrayOutputMixin<Array> {
Expand All @@ -578,7 +552,6 @@ TEST_CASE("Ravelling-unravelling indices") {
std::vector<ssize_t> shape_ = {3, 4, 5};
};
auto arr = Array3d();
CHECK(std::ranges::equal(arr.strides(), as_contiguous_strides(arr.shape())));
auto last_element_flat = arr.size() - 1;
CHECK(ravel_multi_index(arr.strides(), {2, 3, 4}) == last_element_flat);
CHECK(ravel_multi_index(arr.strides(), unravel_index(arr.strides(), last_element_flat)) == last_element_flat);
Expand Down

0 comments on commit 4c484f3

Please sign in to comment.