Skip to content

Commit

Permalink
Add function to determine the size of a vectorized type. Add function…
Browse files Browse the repository at this point in the history
… to get the possible values of an enum and another to get the size. Change vectorize to preallocate rather than back_inserter. Improve efficiency of transformUnits by writing directly to output array. Bump boost/pfr versions. Bugfix python enum unit orders not being expanded.

Signed-off-by: Bryce Ferenczi <[email protected]>
  • Loading branch information
5had3z committed Jul 18, 2024
1 parent 5a9174c commit 5b027d5
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 120 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ if(MSVC)
NAME
Boost
URL
"https://github.com/boostorg/boost/releases/download/boost-1.83.0/boost-1.83.0.zip"
"https://github.com/boostorg/boost/releases/download/boost-1.85.0/boost-1.85.0.zip"
OPTIONS
"BOOST_INCLUDE_LIBRARIES iostreams\\\\;pfr"
"BOOST_IOSTREAMS_ENABLE_ZLIB")
Expand All @@ -73,7 +73,7 @@ else()
GITHUB_REPOSITORY
boostorg/pfr
GIT_TAG
2.1.0)
2.2.0)

find_package(Boost REQUIRED COMPONENTS iostreams)
endif()
Expand Down
70 changes: 62 additions & 8 deletions include/data_structures/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace detail {
requires std::is_arithmetic_v<T>
auto vectorize_helper(T d, It it, bool onehotEnum) -> It
{
*it++ = static_cast<It::container_type::value_type>(d);
*it++ = static_cast<std::iter_value_t<It>>(d);
return it;
}

Expand All @@ -69,7 +69,7 @@ namespace detail {
requires std::is_arithmetic_v<std::ranges::range_value_t<T>>
auto vectorize_helper(const T &d, It it, bool onehotEnum) -> It
{
return std::ranges::transform(d, it, [](auto e) { return static_cast<It::container_type::value_type>(e); }).out;
return std::ranges::transform(d, it, [](auto e) { return static_cast<std::iter_value_t<It>>(e); }).out;
}

/**
Expand All @@ -85,7 +85,7 @@ namespace detail {
requires std::is_enum_v<T>
auto vectorize_helper(T d, It it, bool onehotEnum) -> It
{
using value_type = It::container_type::value_type;
using value_type = std::iter_value_t<It>;
if (onehotEnum) {
it = std::ranges::copy(enumToOneHot<value_type>(d), it).out;
} else {
Expand All @@ -111,6 +111,38 @@ namespace detail {
d, [&it, onehotEnum](const auto &field) { it = detail::vectorize_helper(field, it, onehotEnum); });
return it;
}

/**
* @brief Does the main lifting for determining the size of a vectorized struct
*
* @tparam T Type to understand vectorization size
* @tparam oneHotEnum Flag to expand enums to onehot
* @return std::size_t The size of the vectorized type
*/
template<typename T, bool oneHotEnum>
requires std::is_aggregate_v<T>
consteval auto vectorizedSizeHelper() -> std::size_t
{
T d{};// Make plane prototype for pfr::for_each_field
std::size_t sum{ 0 };
boost::pfr::for_each_field(d, [&sum](auto field) {
using field_type = decltype(field);
if constexpr (std::is_arithmetic_v<field_type>) {
sum += 1;
} else if constexpr (std::is_enum_v<field_type> && oneHotEnum) {
sum += numEnumValues<field_type>();
} else if constexpr (std::is_enum_v<field_type> && !oneHotEnum) {
sum += 1;
} else if constexpr (std::is_aggregate_v<field_type>) {
sum += vectorizedSizeHelper<field_type, oneHotEnum>();
} else {
static_assert(always_false_v<field_type>, "Failed to match type");
}
});

return sum;
}

}// namespace detail

/**
Expand All @@ -123,14 +155,31 @@ namespace detail {
* @return Incremented output iterator
*/
template<typename S, typename It>
requires std::is_aggregate_v<S> && std::is_arithmetic_v<typename It::container_type::value_type>
requires std::is_aggregate_v<S> && std::is_arithmetic_v<std::iter_value_t<It>>
[[maybe_unused]] auto vectorize(S s, It it, bool onehotEnum = false) -> It
{
boost::pfr::for_each_field(
s, [&it, onehotEnum](const auto &field) { it = detail::vectorize_helper(field, it, onehotEnum); });
return it;
}

/**
* @brief Get the size of the struct if its vectorized with vectorize
* @tparam T Type to query
* @param onehotEnum Flag if oneHotEncodings are expanded or not
* @return Number of elements of the vectorized type
*/
template<typename T>
requires std::is_aggregate_v<T>
[[nodiscard]] constexpr auto getVectorizedSize(bool onehotEnum) -> std::size_t
{
if (onehotEnum) {
return detail::vectorizedSizeHelper<T, true>();
} else {
return detail::vectorizedSizeHelper<T, false>();
}
}

/**
* @brief Vectorize Struct of data to vector
* @tparam T Output arithmetic type of vector
Expand All @@ -141,10 +190,15 @@ template<typename S, typename It>
*/
template<typename T, typename S>
requires std::is_aggregate_v<S> && std::is_arithmetic_v<T>
auto vectorize(S s, bool onehotEnum = false) -> std::vector<T>
[[nodiscard]] auto vectorize(S s, bool onehotEnum = false) -> std::vector<T>
{
std::vector<T> out;
vectorize(s, std::back_inserter(out), onehotEnum);
std::vector<T> out(getVectorizedSize<S>(onehotEnum));
const auto end = vectorize(s, out.begin(), onehotEnum);
const auto writtenSize = std::distance(out.begin(), end);
if (writtenSize != out.size()) {
throw std::out_of_range("Expected vectorization with size " + std::to_string(out.size()) + " but got "
+ std::to_string(writtenSize));
}
return out;
}

Expand Down Expand Up @@ -467,7 +521,7 @@ struct Action
template<typename T> auto enumToOneHot(Action::TargetType e) noexcept -> std::vector<T>
{
using E = Action::TargetType;
constexpr std::array vals = { E::Self, E::OtherUnit, E::Position };
constexpr std::array vals{ E::Self, E::OtherUnit, E::Position };
return detail::enumToOneHot_helper<T>(e, vals);
}

Expand Down
144 changes: 64 additions & 80 deletions include/data_structures/enums.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@

namespace cvt {

/**
* @brief Converts an enum value to a one-hot encoding
* @tparam E enum type to convert
* @tparam T element type of output vector
* @param e enum to convert
* @return one-hot encoding of enum
*/
template<typename E, typename T>
requires std::is_enum_v<E>
auto enumToOneHot(E e) noexcept -> std::vector<T>;

namespace detail {
/**
* @brief Helper type
*
* @tparam T
*/
template<typename T> struct always_false : std::false_type
{
};

/**
* @brief Always false type to help printing type info at a compile time error
*
* @tparam T type to print
*/
template<typename T> constexpr bool always_false_v = always_false<T>::value;


/**
Expand All @@ -52,97 +57,76 @@ namespace detail {

enum class Alliance : char { Self = 1, Ally = 2, Neutral = 3, Enemy = 4 };

/**
* @brief Convert alliance value to one-hot encoding
* @tparam T output value type
* @param e enum to convert
* @return one-hot encoding
*/
template<typename T> auto enumToOneHot(Alliance e) noexcept -> std::vector<T>
{
constexpr std::array vals = std::array{ Alliance::Self, Alliance::Ally, Alliance::Neutral, Alliance::Enemy };
static_assert(std::is_sorted(vals.begin(), vals.end()));
return detail::enumToOneHot_helper<T>(e, vals);
}

enum class CloakState : char { Unknown = 0, Cloaked = 1, Detected = 2, UnCloaked = 3, Allied = 4 };

/**
* @brief Convert cloak state value to one-hot encoding
* @tparam T output value type
* @param e enum to convert
* @return one-hot encoding
*/
template<typename T> auto enumToOneHot(CloakState e) noexcept -> std::vector<T>
{
constexpr std::array vals = {
CloakState::Unknown, CloakState::Cloaked, CloakState::Detected, CloakState::UnCloaked, CloakState::Allied
};
static_assert(std::is_sorted(vals.begin(), vals.end()));
return detail::enumToOneHot_helper<T>(e, vals);
}

enum class Visibility : char { Visible = 1, Snapshot = 2, Hidden = 3 };

/**
* @brief Convert visibility value to one-hot encoding
* @tparam T output value type
* @param e enum to convert
* @return one-hot encoding
*/
template<typename T> auto enumToOneHot(Visibility e) noexcept -> std::vector<T>
{
constexpr std::array vals = { Visibility::Visible, Visibility::Snapshot, Visibility::Hidden };
static_assert(std::is_sorted(vals.begin(), vals.end()));
return detail::enumToOneHot_helper<T>(e, vals);
}

enum class AddOn : char { None = 0, Reactor = 1, TechLab = 2 };

enum class Race : char { Terran, Zerg, Protoss, Random };

enum class Result : char { Win, Loss, Tie, Undecided };

/**
* @brief Convert addon value to one-hot encoding
* @tparam T output value type
* @param e enum to convert
* @return one-hot encoding
* @brief Get all the possible values for a particular enum
*
* @tparam E enum type
* @return Array of all possible enum values
*/
template<typename T> auto enumToOneHot(AddOn e) noexcept -> std::vector<T>
template<typename E>
requires std::is_enum_v<E>
[[nodiscard]] consteval auto getEnumValues()
{
constexpr std::array vals = { AddOn::None, AddOn::Reactor, AddOn::TechLab };
static_assert(std::is_sorted(vals.begin(), vals.end()));
return detail::enumToOneHot_helper<T>(e, vals);
if constexpr (std::same_as<E, Alliance>) {
return std::array{ Alliance::Self, Alliance::Ally, Alliance::Neutral, Alliance::Enemy };
} else if constexpr (std::same_as<E, CloakState>) {
return std::array{
CloakState::Unknown, CloakState::Cloaked, CloakState::Detected, CloakState::UnCloaked, CloakState::Allied
};
} else if constexpr (std::same_as<E, Visibility>) {
return std::array{ Visibility::Visible, Visibility::Snapshot, Visibility::Hidden };
} else if constexpr (std::same_as<E, AddOn>) {
return std::array{ AddOn::None, AddOn::Reactor, AddOn::TechLab };
} else if constexpr (std::same_as<E, Race>) {
return std::array{ Race::Terran, Race::Zerg, Race::Protoss, Race::Random };
} else if constexpr (std::same_as<E, Result>) {
return std::array{ Result::Win, Result::Loss, Result::Tie, Result::Undecided };
} else {
static_assert(detail::always_false_v<E>, "Failed to match type");
}
}

enum class Race : char { Terran, Zerg, Protoss, Random };

/**
* @brief Convert race value to one-hot encoding
* @tparam T output value type
* @param e enum to convert
* @return one-hot encoding
* @brief The number of possible values of an enum
* @tparam E type of enum
* @return Number of possible enum values
*/
template<typename T> auto enumToOneHot(Race e) noexcept -> std::vector<T>
template<typename E>
requires std::is_enum_v<E>
[[nodiscard]] constexpr auto numEnumValues() -> std::size_t
{
constexpr std::array vals = { Race::Terran, Race::Zerg, Race::Protoss, Race::Random };
static_assert(std::is_sorted(vals.begin(), vals.end()));
return detail::enumToOneHot_helper<T>(e, vals);
return getEnumValues<E>().size();
}

enum class Result : char { Win, Loss, Tie, Undecided };

/**
* @brief Convert result value to one-hot encoding
* @tparam T output value type
* @param e enum to convert
* @return one-hot encoding
* @brief Converts an enum value to a one-hot encoding
* @tparam E enum type to convert
* @tparam T element type of output vector
* @param e enum to one-hot encode
* @return one-hot encoding of enum
*/
template<typename T> auto enumToOneHot(Result e) noexcept -> std::vector<T>
template<typename T, typename E>
requires std::is_enum_v<E>
auto enumToOneHot(E e) noexcept -> std::vector<T>
{
constexpr std::array vals = { Result::Win, Result::Loss, Result::Tie, Result::Undecided };
static_assert(std::is_sorted(vals.begin(), vals.end()));
return detail::enumToOneHot_helper<T>(e, vals);
constexpr auto enumValues = getEnumValues<E>();
static_assert(std::is_sorted(enumValues.begin(), enumValues.end()));
auto it = std::ranges::find(enumValues, e);
std::vector<T> ret(enumValues.size());
ret[std::distance(enumValues.begin(), it)] = static_cast<T>(1);
return ret;
}


}// namespace cvt


Expand Down
33 changes: 11 additions & 22 deletions include/replay_parsing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,22 +142,14 @@ template<typename T, typename UnitT>
{
// Return empty array if no units
if (units.empty()) { return py::array_t<T>(); }
// Lambda wrapper around vectorize to set onehot to true
auto vecFn = [](const UnitT &unit) { return vectorize<T>(unit, true); };

// Create numpy array based on unit feature size
const auto firstUnitFeats = vecFn(units.front());
py::array_t<T> featureArray({ units.size(), firstUnitFeats.size() });

// Interpret numpy array as contiguous span to copy transformed data
std::span<T> rawData(featureArray.mutable_data(), units.size() * firstUnitFeats.size());
auto rawDataIt = rawData.begin();

// Start off by copying the already transformed data, then loop over the rest
rawDataIt = std::copy(firstUnitFeats.begin(), firstUnitFeats.end(), rawDataIt);
for (const auto &unitFeats : units | std::views::drop(1) | std::views::transform(vecFn)) {
rawDataIt = std::copy(unitFeats.begin(), unitFeats.end(), rawDataIt);
}
const bool expandOneHot{ true };
constexpr std::size_t unitDim = getVectorizedSize<UnitT>(expandOneHot);
py::array_t<T> featureArray({ units.size(), unitDim });
auto rawDataIt = featureArray.mutable_data();
// cppcheck-suppress useStlAlgorithm
for (const auto &unit : units) { rawDataIt = vectorize(unit, rawDataIt, expandOneHot); }
return featureArray;
}

Expand Down Expand Up @@ -185,25 +177,22 @@ template<typename T>
return pyDict;
}

// Lambda wrapper around vectorize to set onehot to true
auto vecFn = [](const Unit &unit) { return vectorize<T>(unit, true); };

std::unordered_map<cvt::Alliance, std::vector<T>> groupedUnitFeatures = { { cvt::Alliance::Self, {} },
{ cvt::Alliance::Ally, {} },
{ cvt::Alliance::Enemy, {} },
{ cvt::Alliance::Neutral, {} } };

const bool expandOneHot{ true };
for (auto &&unit : units) {
auto &group = groupedUnitFeatures.at(unit.alliance);
const auto unitFeat = vecFn(unit);
std::ranges::copy(unitFeat, std::back_inserter(group));
std::ranges::copy(vectorize<T>(unit, expandOneHot), std::back_inserter(group));
}

const std::size_t featureSize = vecFn(units.front()).size();
py::dict pyReturn;
constexpr std::size_t unitDim = getVectorizedSize<Unit>(expandOneHot);
for (auto &&[group, features] : groupedUnitFeatures) {
const std::size_t nUnits = features.size() / featureSize;
py::array_t<T> pyArray({ nUnits, featureSize });
const std::size_t nUnits = features.size() / unitDim;
py::array_t<T> pyArray({ nUnits, unitDim });
std::ranges::copy(features, pyArray.mutable_data());
pyReturn[py::cast(enum2str.at(group))] = pyArray;
}
Expand Down
Loading

0 comments on commit 5b027d5

Please sign in to comment.