Skip to content

Commit

Permalink
Fix #866, add bool support to rules params (#874)
Browse files Browse the repository at this point in the history
* Refactor Skirmish AI rules param functions

No logic or interface changes, just some readability
and preparation for further commits.

* Refactor the rules param selection filter

No logic or interface change, just readability (no longer
crams everything into a single statement due to macro) and
preparation for further commits.

* Fix #866, rules params can now be boolean

AI and selection filters read them as numbers, 0/1.

---------

Co-authored-by: lhog <[email protected]>
  • Loading branch information
sprunk and lhog authored Jul 20, 2023
1 parent 5335611 commit dd320bc
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 46 deletions.
2 changes: 2 additions & 0 deletions doc/changelog.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Lua:
- add `Spring.GetUnitIsBeingBuilt(unitID) -> bool beingBuilt, number buildProgress`. Note that this
doesn't bring new capability because `buildProgress` was already available from `GetUnitHealth`,
and `beingBuilt` from `GetUnitIsStunned`, but as you can see it wasn't terribly convenient/intuitive.
- rules params now support the boolean type. Skirmish AI and the rules param selection filter can
read them via existing numerical interface by using 0 and 1.
- add `Script.DelayByFrames(frameDelay, function, args...)`. Runs `function(args...)` after a delay
of the specified number of frames (at least 1). Multiple functions can be queued onto the same frame
and run in the order they were added, just before that frame's GameFrame call-in.
Expand Down
30 changes: 18 additions & 12 deletions rts/ExternalAI/SSkirmishAICallbackImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,21 @@ static float getRulesParamFloatValueByName(
const char* rulesParamName,
float defaultValue
) {
float value = defaultValue;
const std::string key(rulesParamName);
const auto it = params.find(key);

if (it == params.end())
return value;
return defaultValue;

if (modParamIsVisible(it->second, losMask))
value = it->second.valueInt;
const LuaRulesParams::Param& param = it->second;
if (!modParamIsVisible(param, losMask))
return defaultValue;

return value;
if (std::holds_alternative <std::string> (param.value))
return defaultValue;
else if (std::holds_alternative <bool> (param.value))
return std::get <bool> (param.value) ? 1.0f : 0.0f;
else
return std::get <float> (param.value);
}

static const char* getRulesParamStringValueByName(
Expand All @@ -225,17 +229,19 @@ static const char* getRulesParamStringValueByName(
const char* rulesParamName,
const char* defaultValue
) {
const char* value = defaultValue;
const std::string key(rulesParamName);
const auto it = params.find(key);

if (it == params.end())
return value;
return defaultValue;

const LuaRulesParams::Param& param = it->second;
if (!modParamIsVisible(it->second, losMask))
return defaultValue;

if (modParamIsVisible(it->second, losMask))
value = it->second.valueString.c_str();
if (!std::holds_alternative <std::string> (param.value))
return defaultValue;

return value;
return std::get <std::string> (param.value).c_str();
}


Expand Down
45 changes: 33 additions & 12 deletions rts/Game/UI/SelectionKeyHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,33 +199,54 @@ namespace {
groupNum = -1;
)

DECLARE_FILTER_EX(RulesParamEquals, 2, unit->modParams.find(param) != unit->modParams.end() &&
((wantedValueStr.empty()) ? unit->modParams.find(param)->second.valueInt == wantedValue
: unit->modParams.find(param)->second.valueString == wantedValueStr),
std::string param;
struct RulesParamEquals_Filter : public Filter {
std::string paramName;
std::string wantedValueStr;

float wantedValue;
float wantedValueNum;

RulesParamEquals_Filter()
: Filter("RulesParamEquals", 2)
, wantedValueNum (0.0f)
{ }

bool ShouldIncludeUnit(const CUnit* unit) const override {
const auto it = unit->modParams.find(paramName);
if (it == unit->modParams.end())
return false;

const auto& param = it->second;
if (!wantedValueStr.empty()) {
if (std::holds_alternative <std::string> (param.value))
return std::get <std::string> (param.value) == wantedValueStr;
else
return false;
} else {
if (std::holds_alternative <float> (param.value))
return std::get <float> (param.value) == wantedValueNum;
else if (std::holds_alternative <bool> (param.value))
return (std::get <bool> (param.value) ? 1.0f : 0.0f) == wantedValueNum;
else
return false;
}
}

void SetParam(int index, const std::string& value) override {
switch (index) {
case 0: {
param = value;
paramName = value;
} break;
case 1: {
const char* cstr = value.c_str();
char* endNumPos = nullptr;
wantedValue = strtof(cstr, &endNumPos);
wantedValueNum = strtof(cstr, &endNumPos);
if (endNumPos == cstr) wantedValueStr = value;
} break;
}
},
wantedValue = 0.0f;
)
}
} RulesParamEquals_filter_instance;

#undef DECLARE_FILTER_EX
#undef DECLARE_FILTER
#undef STRTOF
}


Expand Down
4 changes: 2 additions & 2 deletions rts/Lua/LuaRulesParams.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
/* This file is part of the Spring engine (GPL v2 or later), see LICENSE.html */

#include "LuaRulesParams.h"
#include "System/creg/STL_Variant.h"

using namespace LuaRulesParams;

CR_BIND(Param,)
CR_REG_METADATA(Param, (
CR_MEMBER(los),
CR_MEMBER(valueInt),
CR_MEMBER(valueString)
CR_MEMBER(value)
))
4 changes: 2 additions & 2 deletions rts/Lua/LuaRulesParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#define LUA_RULESPARAMS_H

#include <string>
#include <variant>

#include "System/UnorderedMap.hpp"
#include "System/creg/creg_cond.h"
Expand Down Expand Up @@ -31,8 +32,7 @@ namespace LuaRulesParams
CR_DECLARE_STRUCT(Param)

int los = RULESPARAMLOS_PRIVATE;
float valueInt = 0.0f;
std::string valueString;
std::variant <bool, float, std::string> value;
};

typedef spring::unordered_map<std::string, Param> Params;
Expand Down
7 changes: 4 additions & 3 deletions rts/Lua/LuaSyncedCtrl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,10 +1268,11 @@ void SetRulesParam(lua_State* L, const char* caller, int offset,

// set the value of the parameter
if (lua_israwnumber(L, valIndex)) {
param.valueInt = lua_tofloat(L, valIndex);
param.valueString.resize(0);
param.value.emplace <float> (lua_tofloat(L, valIndex));
} else if (lua_israwboolean(L, valIndex)) {
param.value.emplace <bool> (lua_toboolean(L, valIndex));
} else if (lua_isstring(L, valIndex)) {
param.valueString = lua_tostring(L, valIndex);
param.value.emplace <std::string> (lua_tostring(L, valIndex));
} else if (lua_isnoneornil(L, valIndex)) {
params.erase(key);
return; //no need to set los if param was erased
Expand Down
38 changes: 23 additions & 15 deletions rts/Lua/LuaSyncedRead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
#include "System/StringUtil.h"

#include <cctype>
#include <type_traits>


using std::min;
Expand Down Expand Up @@ -675,17 +676,21 @@ static int PushRulesParams(lua_State* L, const char* caller,
{
lua_createtable(L, 0, params.size());

for (auto& it: params) {
for (const auto& it: params) {
const std::string& name = it.first;
const LuaRulesParams::Param& param = it.second;
if (!(param.los & losStatus))
continue;

if (!param.valueString.empty()) {
LuaPushNamedString(L, name, param.valueString);
} else {
LuaPushNamedNumber(L, name, param.valueInt);
}
std::visit ([L, &name](auto&& value) {
using T = std::decay_t <decltype(value)>;
if constexpr (std::is_same_v <T, float>)
LuaPushNamedNumber(L, name, value);
else if constexpr (std::is_same_v <T, bool>)
LuaPushNamedBool(L, name, value);
else if constexpr (std::is_same_v <T, std::string>)
LuaPushNamedString(L, name, value);
}, param.value);
}

return 1;
Expand All @@ -702,17 +707,20 @@ static int GetRulesParam(lua_State* L, const char* caller, int index,
return 0;

const LuaRulesParams::Param& param = it->second;
if (!(param.los & losStatus))
return 0;

if (param.los & losStatus) {
if (!param.valueString.empty()) {
lua_pushsstring(L, param.valueString);
} else {
lua_pushnumber(L, param.valueInt);
}
return 1;
}
std::visit ([L](auto&& value) {
using T = std::decay_t <decltype(value)>;
if constexpr (std::is_same_v <T, float>)
lua_pushnumber(L, value);
else if constexpr (std::is_same_v <T, bool>)
lua_pushboolean(L, value);
else if constexpr (std::is_same_v <T, std::string>)
lua_pushsstring(L, value);
}, param.value);

return 0;
return 1;
}


Expand Down
71 changes: 71 additions & 0 deletions rts/System/creg/STL_Variant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifndef CR_STL_VARIANT_H
#define CR_STL_VARIANT_H

#include "creg_cond.h"
#include <variant>

#ifdef USING_CREG

namespace creg
{
template <typename T0, typename T1, typename T2>
class Variant3Type : public IType
{
public:
using VT = std::variant<T0,T1,T2>;
Variant3Type() : IType(sizeof(VT)) { }
~Variant3Type() { }

void Serialize(ISerializer* s, void* instance)
{
VT& p = *(VT*)instance;
if (s->IsWriting()) {
int index = p.index();
s->SerializeInt(&index, sizeof(int));
switch (index) {
case 0: DeduceType<T0>::Get()->Serialize(s, (void*) &(std::get<0>(p))); break;
case 1: DeduceType<T1>::Get()->Serialize(s, (void*) &(std::get<1>(p))); break;
case 2: DeduceType<T2>::Get()->Serialize(s, (void*) &(std::get<2>(p))); break;
}
} else {
int index;
s->SerializeInt(&index, sizeof(int));
switch (index) {
case 0: {
T0 x;
DeduceType<T0>::Get()->Serialize(s, (void*) &x);
p = std::move(x); // neither `p.emplace <T0>` nor `<0>` worked; still, should be safe
} break;
case 1: {
T1 x;
DeduceType<T1>::Get()->Serialize(s, (void*) &x);
p = std::move(x);
} break;
case 2: {
T2 x;
DeduceType<T2>::Get()->Serialize(s, (void*) &x);
p = std::move(x);
} break;
}
}
}
std::string GetName() const { return "variant<"
+ DeduceType<T0>::Get()->GetName() + ","
+ DeduceType<T1>::Get()->GetName() + ","
+ DeduceType<T2>::Get()->GetName() + ">";
}
};

/* FIXME: ideally this would support arbitrary variants, but that involves some
* recursive variadic template bullshit. Three just happened to be the first use case. */
template<typename T0, typename T1, typename T2>
struct DeduceType<std::variant<T0,T1,T2> > {
static std::unique_ptr<IType> Get() {
return std::unique_ptr<IType>(new Variant3Type<T0,T1,T2>());
}
};
}

#endif // USING_CREG

#endif // CR_STL_VARIANT_H

0 comments on commit dd320bc

Please sign in to comment.