Skip to content

Commit

Permalink
Dyno: Implement param enum-to-string casts and fix param enum stringi…
Browse files Browse the repository at this point in the history
…fy (chapel-lang#25837)

Prior to this change, param enums could not be stringified because
EnumParam only stored an ID and the stringify signature lacks a
``Context*`` needed to map the ID to uAST. Without the uAST we cannot
compute the original name of the enum element. Instead, store a pair of
``ID`` and ``string`` inside EnumParam and compute the element's name
ahead of time.

``EnumParam::get`` is supplanted by ``Param::getEnumParam``, which is a
helper that uses the provided context to fetch the enum element's name
as a string, then builds the stored pair.

Using this capability, this PR easily implements param enum to string
casts.

After this PR, we can now print types with param-enum fields more
cleanly. Rather than ``R(mod.colors@0)`` we simply print, e.g.,
``R(green)``.

[reviewed-by @DanilaFe]
  • Loading branch information
benharsh authored Sep 3, 2024
2 parents 3d1c539 + f8a8fc8 commit 5950fce
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 28 deletions.
1 change: 1 addition & 0 deletions doc/util/nitpick_ignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ cpp:identifier ParamTag
cpp:identifier PragmaTag
cpp:identifier TypeTag
cpp:identifier PrimitiveTag
cpp:identifier EnumParam
cpp:identifier chpl::uast::PrimitiveTag
cpp:identifier types::RealParam
cpp:identifier types::IntParam
Expand Down
56 changes: 54 additions & 2 deletions frontend/include/chpl/types/Param.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,23 @@ class Param {
return 0;
}
};
struct EnumValue {
ID id;
std::string str;

EnumValue(ID id, std::string str)
: id(id), str(str)
{ }
bool operator==(const EnumValue& other) const {
return this->id == other.id && this->str == other.str;
}
bool operator!=(const EnumValue& other) const {
return !(*this == other);
}
size_t hash() const {
return chpl::hash(id, str);
}
};

private:
ParamTag tag_;
Expand Down Expand Up @@ -114,8 +131,8 @@ class Param {
static std::string valueToString(NoneValue v) {
return "none";
}
static std::string valueToString(ID id) {
return id.str();
static std::string valueToString(EnumValue v) {
return v.str;
}
static std::string valueToString(bool v) {
return v ? "true" : "false";
Expand Down Expand Up @@ -211,6 +228,8 @@ class Param {
#undef PARAM_NODE
#undef PARAM_TO

static const EnumParam* getEnumParam(Context* context, ID id);

/// \cond DO_NOT_DOCUMENT
DECLARE_DUMP;
/// \endcond DO_NOT_DOCUMENT
Expand Down Expand Up @@ -278,6 +297,19 @@ template<> struct stringify<chpl::types::Param::NoneValue> {
}
};

template<> struct stringify<chpl::types::Param::EnumValue> {
void operator()(std::ostream& streamOut,
chpl::StringifyKind stringKind,
const chpl::types::Param::EnumValue& stringMe) const {
if (stringKind == chpl::StringifyKind::CHPL_SYNTAX) {
streamOut << stringMe.str;
} else {
streamOut << stringMe.str;
streamOut << " (" << stringMe.id.str() << ")";
}
}
};

template<> struct serialize<types::Param::ComplexDouble> {
void operator()(Serializer& ser, types::Param::ComplexDouble val) const {
ser.write(val.re);
Expand Down Expand Up @@ -305,6 +337,21 @@ template<> struct deserialize<types::Param::NoneValue> {
}
};

template<> struct serialize<types::Param::EnumValue> {
void operator()(Serializer& ser, types::Param::EnumValue val) const {
ser.write(val.id);
ser.write(val.str);
}
};

template<> struct deserialize<types::Param::EnumValue> {
types::Param::EnumValue operator()(Deserializer& des) {
auto id = des.read<ID>();
auto str = des.read<std::string>();
return types::Param::EnumValue(id, str);
}
};

/// \endcond DO_NOT_DOCUMENT
} // end namespace chpl

Expand All @@ -321,6 +368,11 @@ namespace std {
return key.hash();
}
};
template<> struct hash<chpl::types::Param::EnumValue> {
size_t operator()(const chpl::types::Param::EnumValue key) const {
return key.hash();
}
};
} // end namespace std

#endif
2 changes: 1 addition & 1 deletion frontend/include/chpl/types/param-classes-list.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

PARAM_NODE(BoolParam, bool)
PARAM_NODE(ComplexParam, ComplexDouble)
PARAM_NODE(EnumParam, ID)
PARAM_NODE(EnumParam, EnumValue)
PARAM_NODE(IntParam, int64_t)
PARAM_NODE(NoneParam, NoneValue)
PARAM_NODE(RealParam, double)
Expand Down
9 changes: 5 additions & 4 deletions frontend/lib/resolution/Resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3886,7 +3886,7 @@ Resolver::typeForScopeResolvedEnumElement(const EnumType* enumType,
bool ambiguous) {
if (!refersToId.isEmpty()) {
// Found a single enum element, the type can be a param.
auto newParam = EnumParam::get(context, refersToId);
auto newParam = Param::getEnumParam(context, refersToId);
return QualifiedType(QualifiedType::PARAM, enumType, newParam);
} else if (ambiguous) {
// multiple candidates. but the expression most likely has a type given by
Expand Down Expand Up @@ -4433,9 +4433,7 @@ resolveIterTypeWithTag(Resolver& rv,
// The iterand is an unresolved call, or it is a resolved iterator but
// not the one that we need. Regather existing actuals and reuse the
// receiver if it is present.
} else {
auto call = iterand->toCall();
CHPL_ASSERT(call);
} else if (auto call = iterand->toCall()) {

bool raiseErrors = false;
auto tmp = CallInfo::create(context, call, rv.byPostorder, raiseErrors);
Expand All @@ -4445,6 +4443,9 @@ resolveIterTypeWithTag(Resolver& rv,
callIsMethodCall = tmp.isMethodCall();
callIsParenless = tmp.isParenless();
for (auto& a : tmp.actuals()) callActuals.push_back(a);
} else {
CHPL_UNIMPL("unknown iterand");
return error;
}

if (!needSerial) {
Expand Down
11 changes: 11 additions & 0 deletions frontend/lib/resolution/resolution-queries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3590,6 +3590,17 @@ static bool resolveFnCallSpecial(Context* context,
if (isParamTypeCast) {
auto srcEnumType = src.type()->toEnumType();
auto dstEnumType = dst.type()->toEnumType();

if (srcEnumType && dst.type()->isStringType()) {
std::ostringstream oss;
src.param()->stringify(oss, chpl::StringifyKind::CHPL_SYNTAX);
auto ustr = UniqueString::get(context, oss.str());
exprTypeOut = QualifiedType(QualifiedType::PARAM,
RecordType::getStringType(context),
StringParam::get(context, ustr));
return true;
}

if (srcEnumType && srcEnumType->isAbstract()) {
exprTypeOut = CHPL_TYPE_ERROR(context, EnumAbstract, astForErr, "from", srcEnumType, dst.type());
return true;
Expand Down
4 changes: 2 additions & 2 deletions frontend/lib/resolution/return-type-inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ static bool helpComputeOrderToEnumReturnType(Context* context,
uint64_t counter = 0;
for (auto elem : ast->enumElements()) {
if (counter == whichValue) {
param = EnumParam::get(context, elem->id());
param = Param::getEnumParam(context, elem->id());
break;
}
counter++;
Expand Down Expand Up @@ -920,7 +920,7 @@ static bool helpComputeEnumToOrderReturnType(Context* context,
parsing::idToAst(context, et->id())->toEnum();
int counter = 0;
for (auto elem : ast->enumElements()) {
if (elem->id() == inputParam->value()) {
if (elem->id() == inputParam->value().id) {
param = IntParam::get(context, counter);
break;
}
Expand Down
2 changes: 1 addition & 1 deletion frontend/lib/types/EnumType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ getParamConstantsMapQuery(Context* context, const EnumType* et) {
auto ast = parsing::idToAst(context, et->id());
if (auto e = ast->toEnum()) {
for (auto elem : e->enumElements()) {
auto param = EnumParam::get(context, elem->id());
auto param = Param::getEnumParam(context, elem->id());
auto k = UniqueString::get(context, elem->name().str());
QualifiedType v(QualifiedType::PARAM, et, param);
ret.insert({std::move(k), std::move(v)});
Expand Down
12 changes: 9 additions & 3 deletions frontend/lib/types/Param.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,10 @@ optional<Immediate> paramToImmediate(Context* context,

if (ep) {
auto numericValueOpt =
computeNumericValueOfEnumElement(context, ep->value());
computeNumericValueOfEnumElement(context, ep->value().id);

if (!numericValueOpt) {
auto eltAst = parsing::idToAst(context, ep->value())->toEnumElement();
auto eltAst = parsing::idToAst(context, ep->value().id)->toEnumElement();
auto qt = CHPL_TYPE_ERROR(context, EnumValueAbstract, astForErr, et, eltAst);

// In order to be able to compose multiple calls to this function,
Expand Down Expand Up @@ -482,7 +482,7 @@ static QualifiedType enumParamFromNumericValue(Context* context,
if (elemId) {
return QualifiedType(QualifiedType::PARAM,
enumType,
EnumParam::get(context, elemId));
Param::getEnumParam(context, elemId));
} else {
return CHPL_TYPE_ERROR(context, NoMatchingEnumValue,
astForErr, enumType, numericValue);
Expand Down Expand Up @@ -946,6 +946,12 @@ IMPLEMENT_DUMP(Param);
// clear the macros
#undef PARAM_NODE

const EnumParam* Param::getEnumParam(Context* context, ID id) {
auto ast = parsing::idToAst(context, id)->toEnumElement();
CHPL_ASSERT(ast && "expecting EnumElement");
auto value = EnumValue(id, ast->name().str());
return EnumParam::get(context, value);
}

} // end namespace types
} // end namespace chpl
58 changes: 50 additions & 8 deletions frontend/test/resolution/testEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ static void test1() {

auto et = qt.type()->toEnumType();
auto ep = qt.param()->toEnumParam();
assert(et->id().contains(ep->value()));
assert(et->id().contains(ep->value().id));
auto enumAst = parsing::idToAst(context, et->id());
assert(enumAst && enumAst->isEnum());
auto elemAst = parsing::idToAst(context, ep->value());
auto elemAst = parsing::idToAst(context, ep->value().id);
assert(elemAst && elemAst->isEnumElement());
}

Expand Down Expand Up @@ -366,11 +366,11 @@ static void test11() {

auto param0 = vars.at("a").param();
assert(param0 && param0->isEnumParam());
assert(param0->toEnumParam()->value().postOrderId() == 1);
assert(param0->toEnumParam()->value().id.postOrderId() == 1);

auto param1 = vars.at("b").param();
assert(param1 && param1->isEnumParam());
assert(param1->toEnumParam()->value().postOrderId() == 5);
assert(param1->toEnumParam()->value().id.postOrderId() == 5);
}

static void test12() {
Expand Down Expand Up @@ -579,22 +579,22 @@ static void test17() {
assert(vars.at("c").type()->isEnumType());
assert(vars.at("c").param());
assert(vars.at("c").param()->isEnumParam());
assert(vars.at("c").param()->toEnumParam()->value().postOrderId() == 0);
assert(vars.at("c").param()->toEnumParam()->value().id.postOrderId() == 0);
assert(vars.at("d").type());
assert(vars.at("d").type()->isEnumType());
assert(vars.at("d").param());
assert(vars.at("d").param()->isEnumParam());
assert(vars.at("d").param()->toEnumParam()->value().postOrderId() == 0);
assert(vars.at("d").param()->toEnumParam()->value().id.postOrderId() == 0);
assert(vars.at("e").type());
assert(vars.at("e").type()->isEnumType());
assert(vars.at("e").param());
assert(vars.at("e").param()->isEnumParam());
assert(vars.at("e").param()->toEnumParam()->value().postOrderId() == 1);
assert(vars.at("e").param()->toEnumParam()->value().id.postOrderId() == 1);
assert(vars.at("f").type());
assert(vars.at("f").type()->isEnumType());
assert(vars.at("f").param());
assert(vars.at("f").param()->isEnumParam());
assert(vars.at("f").param()->toEnumParam()->value().postOrderId() == 2);
assert(vars.at("f").param()->toEnumParam()->value().id.postOrderId() == 2);

assert(guard.realizeErrors() == 1);
}
Expand Down Expand Up @@ -671,6 +671,46 @@ static void test19() {
assert(vars.at("res").type()->isEnumType());
}

static void test20() {
Context ctx;
auto context = &ctx;
ErrorGuard guard(context);

auto vars = resolveTypesOfVariables(context,
R"""(
enum colors {red, green, blue};
param c = colors.red;
param s = c:string;
param x = colors.red:string;
param y = colors.green:string;
param z = colors.blue:string;
record R {
param p : colors;
}
var r = new R(colors.green);
)""", {"s", "x", "y", "z", "r"});

assert(guard.realizeErrors() == 0);

auto check = [] (QualifiedType qt, std::string text) {
assert(qt.type()->isStringType());
assert(qt.param()->toStringParam()->value() == text);
};

check(vars.at("s"), "red");
check(vars.at("x"), "red");
check(vars.at("y"), "green");
check(vars.at("z"), "blue");

std::ostringstream oss;
vars.at("r").type()->stringify(oss, StringifyKind::CHPL_SYNTAX);
assert(oss.str() == "R(green)");
}

int main() {
test1();
test2();
Expand All @@ -691,5 +731,7 @@ int main() {
test17();
test18();
test19();
test20();

return 0;
}
2 changes: 1 addition & 1 deletion frontend/test/resolution/testParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static void test4() {
assert(bestFn->formalType(0).isParam());
assert(bestFn->formalName(0) == UniqueString::get(context, "this"));
assert(bestFn->formalType(0).param()->isEnumParam());
assert(bestFn->formalType(0).param()->toEnumParam()->value() == greenEnum->id());
assert(bestFn->formalType(0).param()->toEnumParam()->value().id == greenEnum->id());
const ResolvedFunction* rfn = scopeResolveFunction(context, isBlueFn->id());
const auto tsi = typedSignatureInitial(context, rfn->signature()->untyped());
assert(tsi->formalType(0).isParam());
Expand Down
4 changes: 2 additions & 2 deletions frontend/test/resolution/testRanges.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ getRangeInfo(Context* context, const RecordType* r) {
assert(bounded.type()->isEnumType());
assert(bounded.param() != nullptr);
auto boundedValue = bounded.param()->toEnumParam();
auto id = boundedValue->value();
auto id = boundedValue->value().id;
auto astNode = idToAst(context, id)->toNamedDecl();
assert(astNode != nullptr);
std::string boundTypeStr = astNode->name().str();
Expand All @@ -46,7 +46,7 @@ getRangeInfo(Context* context, const RecordType* r) {
assert(stridable.type()->isEnumType());
assert(stridable.param() != nullptr);
auto stridableValue = stridable.param()->toEnumParam();
auto idS = stridableValue->value();
auto idS = stridableValue->value().id;
auto astNodeS = idToAst(context, idS)->toNamedDecl();
assert(astNodeS != nullptr);
std::string stridesStr = astNodeS->name().str();
Expand Down
5 changes: 2 additions & 3 deletions frontend/test/resolution/testTypeConstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1540,15 +1540,14 @@ static void test44() {
assert(pt.type()->isEnumType());
assert(pt.param()->isEnumParam());
auto param = pt.param()->toEnumParam();
// TODO: properly stringify enum params
assert(param->value().str() == "M.coords@0");
assert(param->value().str == "x");


auto parent = xt->basicClassType()->parentClassType();
auto pf = parent->substitutions();
assert(pf.size() == 1);
assert(pf.begin()->second.type()->isEnumType());
assert(pf.begin()->second.param()->toEnumParam()->value().str() == "M.Other.color@0");
assert(pf.begin()->second.param()->toEnumParam()->value().str == "red");
}

int main() {
Expand Down
2 changes: 1 addition & 1 deletion tools/chapel-py/src/method-tables/param-methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ CLASS_END(BoolParam)

CLASS_BEGIN(EnumParam)
PLAIN_GETTER(EnumParam, value, "Get the value of this enum Param",
const chpl::uast::AstNode*, return parsing::idToAst(context, node->value()))
const chpl::uast::AstNode*, return parsing::idToAst(context, node->value().id))
CLASS_END(EnumParam)

CLASS_BEGIN(IntParam)
Expand Down

0 comments on commit 5950fce

Please sign in to comment.