diff --git a/CMakeLists.txt b/CMakeLists.txt index 85bb5cf..8922c70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,7 +61,7 @@ if(MSVC) NAME Boost URL - "https://github.com/boostorg/boost/releases/download/boost-1.83.0/boost-1.83.0.zip" + "https://github.com/boostorg/boost/releases/download/boost-1.85.0/boost-1.85.0.zip" OPTIONS "BOOST_INCLUDE_LIBRARIES iostreams\\\\;pfr" "BOOST_IOSTREAMS_ENABLE_ZLIB") @@ -73,7 +73,7 @@ else() GITHUB_REPOSITORY boostorg/pfr GIT_TAG - 2.1.0) + 2.2.0) find_package(Boost REQUIRED COMPONENTS iostreams) endif() diff --git a/include/data_structures/common.hpp b/include/data_structures/common.hpp index 4933206..77ba9f1 100644 --- a/include/data_structures/common.hpp +++ b/include/data_structures/common.hpp @@ -52,7 +52,7 @@ namespace detail { requires std::is_arithmetic_v auto vectorize_helper(T d, It it, bool onehotEnum) -> It { - *it++ = static_cast(d); + *it++ = static_cast>(d); return it; } @@ -69,7 +69,7 @@ namespace detail { requires std::is_arithmetic_v> auto vectorize_helper(const T &d, It it, bool onehotEnum) -> It { - return std::ranges::transform(d, it, [](auto e) { return static_cast(e); }).out; + return std::ranges::transform(d, it, [](auto e) { return static_cast>(e); }).out; } /** @@ -85,7 +85,7 @@ namespace detail { requires std::is_enum_v auto vectorize_helper(T d, It it, bool onehotEnum) -> It { - using value_type = It::container_type::value_type; + using value_type = std::iter_value_t; if (onehotEnum) { it = std::ranges::copy(enumToOneHot(d), it).out; } else { @@ -111,6 +111,38 @@ namespace detail { d, [&it, onehotEnum](const auto &field) { it = detail::vectorize_helper(field, it, onehotEnum); }); return it; } + + /** + * @brief Does the main lifting for determining the size of a vectorized struct + * + * @tparam T Type to understand vectorization size + * @tparam oneHotEnum Flag to expand enums to onehot + * @return std::size_t The size of the vectorized type + */ + template + requires std::is_aggregate_v + consteval auto vectorizedSizeHelper() -> std::size_t + { + T d{};// Make plane prototype for pfr::for_each_field + std::size_t sum{ 0 }; + boost::pfr::for_each_field(d, [&sum](auto field) { + using field_type = decltype(field); + if constexpr (std::is_arithmetic_v) { + sum += 1; + } else if constexpr (std::is_enum_v && oneHotEnum) { + sum += numEnumValues(); + } else if constexpr (std::is_enum_v && !oneHotEnum) { + sum += 1; + } else if constexpr (std::is_aggregate_v) { + sum += vectorizedSizeHelper(); + } else { + static_assert(always_false_v, "Failed to match type"); + } + }); + + return sum; + } + }// namespace detail /** @@ -123,7 +155,7 @@ namespace detail { * @return Incremented output iterator */ template - requires std::is_aggregate_v && std::is_arithmetic_v + requires std::is_aggregate_v && std::is_arithmetic_v> [[maybe_unused]] auto vectorize(S s, It it, bool onehotEnum = false) -> It { boost::pfr::for_each_field( @@ -131,6 +163,23 @@ template return it; } +/** + * @brief Get the size of the struct if its vectorized with vectorize + * @tparam T Type to query + * @param onehotEnum Flag if oneHotEncodings are expanded or not + * @return Number of elements of the vectorized type + */ +template + requires std::is_aggregate_v +[[nodiscard]] constexpr auto getVectorizedSize(bool onehotEnum) -> std::size_t +{ + if (onehotEnum) { + return detail::vectorizedSizeHelper(); + } else { + return detail::vectorizedSizeHelper(); + } +} + /** * @brief Vectorize Struct of data to vector * @tparam T Output arithmetic type of vector @@ -141,10 +190,15 @@ template */ template requires std::is_aggregate_v && std::is_arithmetic_v -auto vectorize(S s, bool onehotEnum = false) -> std::vector +[[nodiscard]] auto vectorize(S s, bool onehotEnum = false) -> std::vector { - std::vector out; - vectorize(s, std::back_inserter(out), onehotEnum); + std::vector out(getVectorizedSize(onehotEnum)); + const auto end = vectorize(s, out.begin(), onehotEnum); + const auto writtenSize = std::distance(out.begin(), end); + if (writtenSize != out.size()) { + throw std::out_of_range("Expected vectorization with size " + std::to_string(out.size()) + " but got " + + std::to_string(writtenSize)); + } return out; } @@ -467,7 +521,7 @@ struct Action template auto enumToOneHot(Action::TargetType e) noexcept -> std::vector { using E = Action::TargetType; - constexpr std::array vals = { E::Self, E::OtherUnit, E::Position }; + constexpr std::array vals{ E::Self, E::OtherUnit, E::Position }; return detail::enumToOneHot_helper(e, vals); } diff --git a/include/data_structures/enums.hpp b/include/data_structures/enums.hpp index b8965f8..6d784b8 100644 --- a/include/data_structures/enums.hpp +++ b/include/data_structures/enums.hpp @@ -18,18 +18,23 @@ namespace cvt { -/** - * @brief Converts an enum value to a one-hot encoding - * @tparam E enum type to convert - * @tparam T element type of output vector - * @param e enum to convert - * @return one-hot encoding of enum - */ -template - requires std::is_enum_v -auto enumToOneHot(E e) noexcept -> std::vector; namespace detail { + /** + * @brief Helper type + * + * @tparam T + */ + template struct always_false : std::false_type + { + }; + + /** + * @brief Always false type to help printing type info at a compile time error + * + * @tparam T type to print + */ + template constexpr bool always_false_v = always_false::value; /** @@ -52,97 +57,76 @@ namespace detail { enum class Alliance : char { Self = 1, Ally = 2, Neutral = 3, Enemy = 4 }; -/** - * @brief Convert alliance value to one-hot encoding - * @tparam T output value type - * @param e enum to convert - * @return one-hot encoding - */ -template auto enumToOneHot(Alliance e) noexcept -> std::vector -{ - constexpr std::array vals = std::array{ Alliance::Self, Alliance::Ally, Alliance::Neutral, Alliance::Enemy }; - static_assert(std::is_sorted(vals.begin(), vals.end())); - return detail::enumToOneHot_helper(e, vals); -} - enum class CloakState : char { Unknown = 0, Cloaked = 1, Detected = 2, UnCloaked = 3, Allied = 4 }; -/** - * @brief Convert cloak state value to one-hot encoding - * @tparam T output value type - * @param e enum to convert - * @return one-hot encoding - */ -template auto enumToOneHot(CloakState e) noexcept -> std::vector -{ - constexpr std::array vals = { - CloakState::Unknown, CloakState::Cloaked, CloakState::Detected, CloakState::UnCloaked, CloakState::Allied - }; - static_assert(std::is_sorted(vals.begin(), vals.end())); - return detail::enumToOneHot_helper(e, vals); -} - enum class Visibility : char { Visible = 1, Snapshot = 2, Hidden = 3 }; -/** - * @brief Convert visibility value to one-hot encoding - * @tparam T output value type - * @param e enum to convert - * @return one-hot encoding - */ -template auto enumToOneHot(Visibility e) noexcept -> std::vector -{ - constexpr std::array vals = { Visibility::Visible, Visibility::Snapshot, Visibility::Hidden }; - static_assert(std::is_sorted(vals.begin(), vals.end())); - return detail::enumToOneHot_helper(e, vals); -} - enum class AddOn : char { None = 0, Reactor = 1, TechLab = 2 }; +enum class Race : char { Terran, Zerg, Protoss, Random }; + +enum class Result : char { Win, Loss, Tie, Undecided }; + /** - * @brief Convert addon value to one-hot encoding - * @tparam T output value type - * @param e enum to convert - * @return one-hot encoding + * @brief Get all the possible values for a particular enum + * + * @tparam E enum type + * @return Array of all possible enum values */ -template auto enumToOneHot(AddOn e) noexcept -> std::vector +template + requires std::is_enum_v +[[nodiscard]] consteval auto getEnumValues() { - constexpr std::array vals = { AddOn::None, AddOn::Reactor, AddOn::TechLab }; - static_assert(std::is_sorted(vals.begin(), vals.end())); - return detail::enumToOneHot_helper(e, vals); + if constexpr (std::same_as) { + return std::array{ Alliance::Self, Alliance::Ally, Alliance::Neutral, Alliance::Enemy }; + } else if constexpr (std::same_as) { + return std::array{ + CloakState::Unknown, CloakState::Cloaked, CloakState::Detected, CloakState::UnCloaked, CloakState::Allied + }; + } else if constexpr (std::same_as) { + return std::array{ Visibility::Visible, Visibility::Snapshot, Visibility::Hidden }; + } else if constexpr (std::same_as) { + return std::array{ AddOn::None, AddOn::Reactor, AddOn::TechLab }; + } else if constexpr (std::same_as) { + return std::array{ Race::Terran, Race::Zerg, Race::Protoss, Race::Random }; + } else if constexpr (std::same_as) { + return std::array{ Result::Win, Result::Loss, Result::Tie, Result::Undecided }; + } else { + static_assert(detail::always_false_v, "Failed to match type"); + } } -enum class Race : char { Terran, Zerg, Protoss, Random }; - /** - * @brief Convert race value to one-hot encoding - * @tparam T output value type - * @param e enum to convert - * @return one-hot encoding + * @brief The number of possible values of an enum + * @tparam E type of enum + * @return Number of possible enum values */ -template auto enumToOneHot(Race e) noexcept -> std::vector +template + requires std::is_enum_v +[[nodiscard]] constexpr auto numEnumValues() -> std::size_t { - constexpr std::array vals = { Race::Terran, Race::Zerg, Race::Protoss, Race::Random }; - static_assert(std::is_sorted(vals.begin(), vals.end())); - return detail::enumToOneHot_helper(e, vals); + return getEnumValues().size(); } -enum class Result : char { Win, Loss, Tie, Undecided }; - /** - * @brief Convert result value to one-hot encoding - * @tparam T output value type - * @param e enum to convert - * @return one-hot encoding + * @brief Converts an enum value to a one-hot encoding + * @tparam E enum type to convert + * @tparam T element type of output vector + * @param e enum to one-hot encode + * @return one-hot encoding of enum */ -template auto enumToOneHot(Result e) noexcept -> std::vector +template + requires std::is_enum_v +auto enumToOneHot(E e) noexcept -> std::vector { - constexpr std::array vals = { Result::Win, Result::Loss, Result::Tie, Result::Undecided }; - static_assert(std::is_sorted(vals.begin(), vals.end())); - return detail::enumToOneHot_helper(e, vals); + constexpr auto enumValues = getEnumValues(); + static_assert(std::is_sorted(enumValues.begin(), enumValues.end())); + auto it = std::ranges::find(enumValues, e); + std::vector ret(enumValues.size()); + ret[std::distance(enumValues.begin(), it)] = static_cast(1); + return ret; } - }// namespace cvt diff --git a/include/replay_parsing.hpp b/include/replay_parsing.hpp index 16a97b1..e85bfcb 100644 --- a/include/replay_parsing.hpp +++ b/include/replay_parsing.hpp @@ -142,22 +142,14 @@ template { // Return empty array if no units if (units.empty()) { return py::array_t(); } - // Lambda wrapper around vectorize to set onehot to true - auto vecFn = [](const UnitT &unit) { return vectorize(unit, true); }; // Create numpy array based on unit feature size - const auto firstUnitFeats = vecFn(units.front()); - py::array_t featureArray({ units.size(), firstUnitFeats.size() }); - - // Interpret numpy array as contiguous span to copy transformed data - std::span rawData(featureArray.mutable_data(), units.size() * firstUnitFeats.size()); - auto rawDataIt = rawData.begin(); - - // Start off by copying the already transformed data, then loop over the rest - rawDataIt = std::copy(firstUnitFeats.begin(), firstUnitFeats.end(), rawDataIt); - for (const auto &unitFeats : units | std::views::drop(1) | std::views::transform(vecFn)) { - rawDataIt = std::copy(unitFeats.begin(), unitFeats.end(), rawDataIt); - } + const bool expandOneHot{ true }; + constexpr std::size_t unitDim = getVectorizedSize(expandOneHot); + py::array_t featureArray({ units.size(), unitDim }); + auto rawDataIt = featureArray.mutable_data(); + // cppcheck-suppress useStlAlgorithm + for (const auto &unit : units) { rawDataIt = vectorize(unit, rawDataIt, expandOneHot); } return featureArray; } @@ -185,25 +177,22 @@ template return pyDict; } - // Lambda wrapper around vectorize to set onehot to true - auto vecFn = [](const Unit &unit) { return vectorize(unit, true); }; - std::unordered_map> groupedUnitFeatures = { { cvt::Alliance::Self, {} }, { cvt::Alliance::Ally, {} }, { cvt::Alliance::Enemy, {} }, { cvt::Alliance::Neutral, {} } }; + const bool expandOneHot{ true }; for (auto &&unit : units) { auto &group = groupedUnitFeatures.at(unit.alliance); - const auto unitFeat = vecFn(unit); - std::ranges::copy(unitFeat, std::back_inserter(group)); + std::ranges::copy(vectorize(unit, expandOneHot), std::back_inserter(group)); } - const std::size_t featureSize = vecFn(units.front()).size(); py::dict pyReturn; + constexpr std::size_t unitDim = getVectorizedSize(expandOneHot); for (auto &&[group, features] : groupedUnitFeatures) { - const std::size_t nUnits = features.size() / featureSize; - py::array_t pyArray({ nUnits, featureSize }); + const std::size_t nUnits = features.size() / unitDim; + py::array_t pyArray({ nUnits, unitDim }); std::ranges::copy(features, pyArray.mutable_data()); pyReturn[py::cast(enum2str.at(group))] = pyArray; } diff --git a/src/sc2_serializer/unit_features.py b/src/sc2_serializer/unit_features.py index 026fcc4..b57a844 100644 --- a/src/sc2_serializer/unit_features.py +++ b/src/sc2_serializer/unit_features.py @@ -44,10 +44,26 @@ class Unit(IntEnum): is_burrowed = auto() is_powered = auto() in_cargo = auto() - order0 = auto() - order1 = auto() - order2 = auto() - order3 = auto() + order0_ability_id = auto() + order0_progress = auto() + order0_target_id = auto() + order0_target_x = auto() + order0_target_y = auto() + order1_ability_id = auto() + order1_progress = auto() + order1_target_id = auto() + order1_target_x = auto() + order1_target_y = auto() + order2_ability_id = auto() + order2_progress = auto() + order2_target_id = auto() + order2_target_x = auto() + order2_target_y = auto() + order3_ability_id = auto() + order3_progress = auto() + order3_target_id = auto() + order3_target_x = auto() + order3_target_y = auto() class UnitOH(IntEnum): @@ -95,10 +111,26 @@ class UnitOH(IntEnum): is_burrowed = auto() is_powered = auto() in_cargo = auto() - order0 = auto() - order1 = auto() - order2 = auto() - order3 = auto() + order0_ability_id = auto() + order0_progress = auto() + order0_target_id = auto() + order0_target_x = auto() + order0_target_y = auto() + order1_ability_id = auto() + order1_progress = auto() + order1_target_id = auto() + order1_target_x = auto() + order1_target_y = auto() + order2_ability_id = auto() + order2_progress = auto() + order2_target_id = auto() + order2_target_x = auto() + order2_target_y = auto() + order3_ability_id = auto() + order3_progress = auto() + order3_target_id = auto() + order3_target_x = auto() + order3_target_y = auto() class NeutralUnit(IntEnum):