Skip to content

Commit

Permalink
fix: comparison operators only accept base pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
tearfur authored Feb 16, 2024
1 parent d066984 commit 697a546
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 94 deletions.
185 changes: 91 additions & 94 deletions include/small/detail/iterator/pointer_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#ifndef SMALL_DETAIL_ITERATOR_POINTER_WRAPPER_HPP
#define SMALL_DETAIL_ITERATOR_POINTER_WRAPPER_HPP

#include <small/detail/traits/ptr_to_const.hpp>
#include <algorithm>
#include <iterator>
#include <type_traits>
Expand Down Expand Up @@ -146,186 +147,182 @@ namespace small {

public /* friends */:
/// \brief Get distance between iterators
template <class Iter1, class Iter2>
template <class Pointer1, class Pointer2>
constexpr friend auto
operator-(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept
-> decltype(x.base() - y.base());

/// \brief Sum iterators
template <class Iter1>
constexpr friend pointer_wrapper<Iter1> operator+(
typename pointer_wrapper<Iter1>::difference_type,
pointer_wrapper<Iter1>) noexcept;
template <class Pointer1>
constexpr friend pointer_wrapper<Pointer1> operator+(
typename pointer_wrapper<Pointer1>::difference_type,
pointer_wrapper<Pointer1>) noexcept;

private:
/// Base pointer
iterator_type base_;
};

template <class Iter1, class Iter2>
template <class Pointer1, class Pointer2 = Pointer1>
inline constexpr bool
operator==(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept {
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept {
return x.base() == y.base();
}

template <class Iter1, class Iter2>
template <class Pointer1, class Pointer2 = Pointer1>
inline constexpr bool
operator!=(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept {
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept {
return !(x == y);
}

template <class Iter1, class Iter2>
template <class Pointer1, class Pointer2 = Pointer1>
inline constexpr bool
operator<(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept {
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept {
return x.base() < y.base();
}

template <class Iter1, class Iter2>
template <class Pointer1, class Pointer2 = Pointer1>
inline constexpr bool
operator>(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept {
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept {
return y < x;
}

template <class Iter1, class Iter2>
template <class Pointer1, class Pointer2 = Pointer1>
inline constexpr bool
operator>=(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept {
return !(x < y);
operator<=(
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept {
return !(y < x);
}

template <class Iter1, class Iter2>
template <class Pointer1, class Pointer2 = Pointer1>
inline constexpr bool
operator<=(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept {
return !(y < x);
operator>=(
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept {
return !(x < y);
}

template <class Iter>
template <class Pointer>
inline constexpr bool
operator==(
const pointer_wrapper<Iter> &x,
const pointer_wrapper<Iter> &y) noexcept {
return x.base() == y.base();
const pointer_wrapper<Pointer> &x,
ptr_to_const_t<Pointer> y) noexcept {
return x.base() == y;
}

template <class Iter>
template <class Pointer>
inline constexpr bool
operator!=(
const pointer_wrapper<Iter> &x,
const pointer_wrapper<Iter> &y) noexcept {
const pointer_wrapper<Pointer> &x,
ptr_to_const_t<Pointer> y) noexcept {
return !(x == y);
}

template <class Iter>
template <class Pointer>
inline constexpr bool
operator>(
const pointer_wrapper<Iter> &x,
const pointer_wrapper<Iter> &y) noexcept {
return y < x;
operator<(
const pointer_wrapper<Pointer> &x,
ptr_to_const_t<Pointer> y) noexcept {
return x.base() < y;
}

template <class Iter>
template <class Pointer>
inline constexpr bool
operator>=(
const pointer_wrapper<Iter> &x,
const pointer_wrapper<Iter> &y) noexcept {
return !(x < y);
operator>(
const pointer_wrapper<Pointer> &x,
ptr_to_const_t<Pointer> y) noexcept {
return y < x.base();
}

template <class Iter>
template <class Pointer>
inline constexpr bool
operator<=(
const pointer_wrapper<Iter> &x,
const pointer_wrapper<Iter> &y) noexcept {
const pointer_wrapper<Pointer> &x,
ptr_to_const_t<Pointer> y) noexcept {
return !(y < x);
}

template <class Iter, class BaseIter>
template <class Pointer>
inline constexpr bool
operator==(const pointer_wrapper<Iter> &x, const BaseIter &y) noexcept {
return x.base() == y;
}

template <class Iter, class BaseIter>
inline constexpr bool
operator!=(const pointer_wrapper<Iter> &x, const BaseIter &y) noexcept {
return !(x == y);
}

template <class Iter, class BaseIter>
inline constexpr bool
operator>(const pointer_wrapper<Iter> &x, const BaseIter &y) noexcept {
return y < x;
}

template <class Iter, class BaseIter>
inline constexpr bool
operator>=(const pointer_wrapper<Iter> &x, const BaseIter &y) noexcept {
operator>=(
const pointer_wrapper<Pointer> &x,
ptr_to_const_t<Pointer> y) noexcept {
return !(x < y);
}

template <class Iter, class BaseIter>
template <class Pointer>
inline constexpr bool
operator<=(const pointer_wrapper<Iter> &x, const BaseIter &y) noexcept {
return !(y < x);
}

template <class Iter, class BaseIter>
inline constexpr bool
operator==(const BaseIter &x, const pointer_wrapper<Iter> &y) noexcept {
return x.base() == y.base();
operator==(
ptr_to_const_t<Pointer> x,
const pointer_wrapper<Pointer> &y) noexcept {
return x == y.base();
}

template <class Iter, class BaseIter>
template <class Pointer>
inline constexpr bool
operator!=(const BaseIter &x, const pointer_wrapper<Iter> &y) noexcept {
operator!=(
ptr_to_const_t<Pointer> x,
const pointer_wrapper<Pointer> &y) noexcept {
return !(x == y);
}

template <class Iter, class BaseIter>
template <class Pointer>
inline constexpr bool
operator>(const BaseIter &x, const pointer_wrapper<Iter> &y) noexcept {
operator>(
ptr_to_const_t<Pointer> x,
const pointer_wrapper<Pointer> &y) noexcept {
return y < x;
}

template <class Iter, class BaseIter>
template <class Pointer>
inline constexpr bool
operator>=(const BaseIter &x, const pointer_wrapper<Iter> &y) noexcept {
return !(x < y);
operator<(
ptr_to_const_t<Pointer> x,
const pointer_wrapper<Pointer> &y) noexcept {
return y > x;
}

template <class Iter, class BaseIter>
template <class Pointer>
inline constexpr bool
operator<=(const BaseIter &x, const pointer_wrapper<Iter> &y) noexcept {
operator<=(
ptr_to_const_t<Pointer> x,
const pointer_wrapper<Pointer> &y) noexcept {
return !(y < x);
}

template <class Iter1, class Iter2>
template <class Pointer>
inline constexpr bool
operator>=(
ptr_to_const_t<Pointer> x,
const pointer_wrapper<Pointer> &y) noexcept {
return !(x < y);
}

template <class Pointer1, class Pointer2 = Pointer1>
inline constexpr auto
operator-(
const pointer_wrapper<Iter1> &x,
const pointer_wrapper<Iter2> &y) noexcept
const pointer_wrapper<Pointer1> &x,
const pointer_wrapper<Pointer2> &y) noexcept
-> decltype(x.base() - y.base()) {
return x.base() - y.base();
}

template <class Iter>
inline constexpr pointer_wrapper<Iter>
template <class Pointer>
inline constexpr pointer_wrapper<Pointer>
operator+(
typename pointer_wrapper<Iter>::difference_type n,
pointer_wrapper<Iter> x) noexcept {
typename pointer_wrapper<Pointer>::difference_type n,
pointer_wrapper<Pointer> x) noexcept {
x += n;
return x;
}
Expand Down
25 changes: 25 additions & 0 deletions include/small/detail/traits/ptr_to_const.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//
// Copyright (c) 2024 Yat Ho ([email protected])
//
// Distributed under the Boost Software License, Version 1.0.
// https://www.boost.org/LICENSE_1_0.txt
//

#ifndef SMALL_DETAIL_TRAITS_PTR_TO_CONST_HPP
#define SMALL_DETAIL_TRAITS_PTR_TO_CONST_HPP

#include <type_traits>

namespace small {
namespace detail {
/// Convert a pointer to pointer-to-const
template <typename T, typename = std::enable_if_t<std::is_pointer_v<T>>>
using ptr_to_const = std::add_pointer<
std::add_const_t<std::remove_pointer_t<T>>>;

template <typename T>
using ptr_to_const_t = typename ptr_to_const<T>::type;
} // namespace detail
} // namespace small

#endif // SMALL_DETAIL_TRAITS_PTR_TO_CONST_HPP
41 changes: 41 additions & 0 deletions test/unit/string_small_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,36 @@ TEST_CASE("String Vector") {

SECTION("Iterators") {
vector<std::string, 5> a = { "one", "two", "three" };
auto iter = a.begin();

// LegacyContiguousIterator
REQUIRE(*(a.begin() + 2) == *(std::addressof(*a.begin()) + 2));

// LegacyRandomAccessIterator
REQUIRE((iter += 2) == a.begin() + 2);
REQUIRE(a.begin() + 1 == 1 + a.begin());
REQUIRE((iter -= 2) == a.begin());
REQUIRE(a.end() - 2 == a.begin() + 1);
REQUIRE(a.end() - a.begin() == 3);
REQUIRE(a.begin()[1] == "two");
REQUIRE(a.end()[-1] == "three");

REQUIRE(a.begin() == a.begin());
REQUIRE(a.begin() != a.begin() + 2);
REQUIRE(a.begin() < a.begin() + 2);
REQUIRE(a.begin() + 2 > a.begin());

// LegacyBidirectionalIterator + LegacyForwardIterator
iter = a.begin() + 1;
REQUIRE(--iter == a.begin());
REQUIRE(iter++ == a.begin());
REQUIRE(iter-- == a.begin() + 1);
iter = a.begin() + 1;
REQUIRE(*iter-- == "two");

// LegacyInputIterator/LegacyIterator
iter = a.begin();
REQUIRE(++iter == a.begin() + 1);

REQUIRE(a.begin() == a.data());
REQUIRE(a.end() == a.data() + a.size());
Expand All @@ -224,6 +254,17 @@ TEST_CASE("String Vector") {

REQUIRE(*a.crbegin() == "three");
REQUIRE(*std::prev(a.crend()) == "one");

// Custom comparison operators
REQUIRE(a.begin() == a.data());
REQUIRE(a.begin() != a.data() + 2);
REQUIRE(a.begin() < a.data() + 2);
REQUIRE(a.begin() + 2 > a.data());

REQUIRE(a.data() == a.begin());
REQUIRE(a.data() != a.begin() + 2);
REQUIRE(a.data() < a.begin() + 2);
REQUIRE(a.data() + 2 > a.begin());
}

SECTION("Capacity") {
Expand Down

0 comments on commit 697a546

Please sign in to comment.