diff --git a/frontend/include/chpl/types/DomainType.h b/frontend/include/chpl/types/DomainType.h index aeda1ae87492..890faae90d44 100644 --- a/frontend/include/chpl/types/DomainType.h +++ b/frontend/include/chpl/types/DomainType.h @@ -98,8 +98,9 @@ class DomainType final : public CompositeType { /** Return an associative domain type */ static const DomainType* getAssociativeType(Context* context, - const QualifiedType& idxType, - const QualifiedType& parSafe); + const QualifiedType& instance, + const QualifiedType& idxType, + const QualifiedType& parSafe); /** Get the default distribution type */ static const QualifiedType& getDefaultDistType(Context* context); diff --git a/frontend/include/chpl/types/Type.h b/frontend/include/chpl/types/Type.h index dd80157d0a59..ce5022dd7613 100644 --- a/frontend/include/chpl/types/Type.h +++ b/frontend/include/chpl/types/Type.h @@ -243,6 +243,8 @@ class Type { */ bool isUserRecordType() const; + bool isRecordLike() const; + /** Returns true if the this type has the pragma 'p' attached to it. */ bool hasPragma(Context* context, uast::pragmatags::PragmaTag p) const; @@ -316,6 +318,8 @@ class Type { */ static bool isPod(Context* context, const Type* t); + static bool needsInitDeinitCall(const Type* t); + /// \cond DO_NOT_DOCUMENT DECLARE_DUMP; /// \endcond DO_NOT_DOCUMENT diff --git a/frontend/lib/resolution/InitResolver.cpp b/frontend/lib/resolution/InitResolver.cpp index 0c657eec12b3..d348ffe234dc 100644 --- a/frontend/lib/resolution/InitResolver.cpp +++ b/frontend/lib/resolution/InitResolver.cpp @@ -295,9 +295,9 @@ static const DomainType* domainTypeFromSubsHelper( if (auto instanceBct = instanceCt->basicClassType()) { // Get BaseRectangularDom parent subs for rectangular domain info if (auto baseDom = instanceBct->parentClassType()) { - auto& rf = fieldsForTypeDecl(context, baseDom, - DefaultsPolicy::IGNORE_DEFAULTS); if (baseDom->id().symbolPath() == "ChapelDistribution.BaseRectangularDom") { + auto& rf = fieldsForTypeDecl(context, baseDom, + DefaultsPolicy::IGNORE_DEFAULTS); CHPL_ASSERT(rf.numFields() == 3); QualifiedType rank; QualifiedType idxType; @@ -315,7 +315,24 @@ static const DomainType* domainTypeFromSubsHelper( return DomainType::getRectangularType(context, instanceQt, rank, idxType, strides); } else if (baseDom->id().symbolPath() == "ChapelDistribution.BaseAssociativeDom") { - // TODO: support associative domains + // Currently the relevant associative domain fields are defined + // on all the children of BaseAssociativeDom, so get information + // from there. + auto& rf = fieldsForTypeDecl(context, instanceBct, + DefaultsPolicy::IGNORE_DEFAULTS); + CHPL_ASSERT(rf.numFields() >= 2); + QualifiedType idxType; + QualifiedType parSafe; + for (int i = 0; i < rf.numFields(); i++) { + if (rf.fieldName(i) == "idxType") { + idxType = rf.fieldType(i); + } else if (rf.fieldName(i) == "parSafe") { + parSafe = rf.fieldType(i); + } + } + + return DomainType::getAssociativeType(context, instanceQt, idxType, + parSafe); } else if (baseDom->id().symbolPath() == "ChapelDistribution.BaseSparseDom") { // TODO: support sparse domains } else { diff --git a/frontend/lib/resolution/call-init-deinit.cpp b/frontend/lib/resolution/call-init-deinit.cpp index e508cbd96643..2a160de63ed8 100644 --- a/frontend/lib/resolution/call-init-deinit.cpp +++ b/frontend/lib/resolution/call-init-deinit.cpp @@ -368,46 +368,6 @@ void CallInitDeinit::processDeinitsAndPropagate(VarFrame* frame, } } -static bool isRecordLike(const Type* t) { - if (auto ct = t->toClassType()) { - auto decorator = ct->decorator(); - // no action needed for 'borrowed' or 'unmanaged' - // (these should just default initialized to 'nil', - // so nothing else needs to be resolved) - if (! (decorator.isBorrowed() || decorator.isUnmanaged() || - decorator.isUnknownManagement())) { - return true; - } - } else if (t->isRecordType() || t->isUnionType()) { - return true; - } - // TODO: tuples? - - return false; -} - -static bool typeNeedsInitDeinitCall(const Type* t) { - if (t == nullptr || t->isUnknownType() || t->isErroneousType()) { - // can't do anything with these - return false; - } else if (t->isPrimitiveType() || t->isBuiltinType() || t->isCStringType() || - t->isNilType() || t->isNothingType() || t->isVoidType()) { - // OK, we can always default initialize primitive numeric types, - // and as well we assume that for the non-generic builtin types - // e.g. TaskIdType. - // No need to resolve anything additional now. - return false; - } else if (t->isEnumType()) { - // OK, can default-initialize enums to first element - return false; - } else if (t->isFunctionType()) { - return false; - } - - return isRecordLike(t); -} - - void CallInitDeinit::resolveDefaultInit(const VarLikeDecl* ast, RV& rv) { // Type variables do not need default init. if (ast->storageKind() == Qualifier::TYPE) return; @@ -446,7 +406,7 @@ void CallInitDeinit::resolveDefaultInit(const VarLikeDecl* ast, RV& rv) { } } - if (!typeNeedsInitDeinitCall(vt)) { + if (!Type::needsInitDeinitCall(vt)) { // nothing to do here } else if (vt->isTupleType()) { // TODO: probably need to do something here, at least in some cases @@ -563,7 +523,7 @@ void CallInitDeinit::resolveCopyInit(const AstNode* ast, const QualifiedType& rhsType, bool forMoveInit, RV& rv) { - if (!typeNeedsInitDeinitCall(lhsType.type())) { + if (!Type::needsInitDeinitCall(lhsType.type())) { // TODO: we could resolve it anyway return; } @@ -672,7 +632,7 @@ void CallInitDeinit::processInit(VarFrame* frame, const QualifiedType& rhsType, RV& rv) { if (lhsType.type() == rhsType.type() && - !typeNeedsInitDeinitCall(lhsType.type())) { + !Type::needsInitDeinitCall(lhsType.type())) { // TODO: we should resolve it anyway return; } @@ -722,7 +682,7 @@ void CallInitDeinit::processInit(VarFrame* frame, } else { // it is copy initialization, so use init= for records // and assign for other stuff - if (lhsType.type() != nullptr && isRecordLike(lhsType.type())) { + if (lhsType.type() != nullptr && lhsType.type()->isRecordLike()) { resolveCopyInit(ast, rhsAst, lhsType, rhsType, /* forMoveInit */ false, @@ -741,7 +701,7 @@ void CallInitDeinit::resolveDeinit(const AstNode* ast, RV& rv) { // nothing to do for 'int' etc - if (!typeNeedsInitDeinitCall(type.type())) { + if (!Type::needsInitDeinitCall(type.type())) { return; } else if (type.type()->isTupleType()) { // TODO: probably need to do something here, at least in some cases @@ -968,7 +928,7 @@ void CallInitDeinit::handleInFormal(const FnCall* ast, const AstNode* actual, // is the copy for 'in' elided? if (elidedCopyFromIds.count(actual->id()) > 0 && isValue(actualType.kind()) && - typeNeedsInitDeinitCall(actualType.type())) { + Type::needsInitDeinitCall(actualType.type())) { // it is move initialization resolveMoveInit(actual, actual, formalType, actualType, rv); diff --git a/frontend/lib/resolution/prims.cpp b/frontend/lib/resolution/prims.cpp index 06fc14916779..02eed2edf33f 100644 --- a/frontend/lib/resolution/prims.cpp +++ b/frontend/lib/resolution/prims.cpp @@ -293,6 +293,7 @@ static QualifiedType computeDomainType(Context* context, const CallInfo& ci) { return QualifiedType(QualifiedType::TYPE, type); } else if (ci.numActuals() == 2) { auto type = DomainType::getAssociativeType(context, + QualifiedType(), ci.actual(0).type(), ci.actual(1).type()); return QualifiedType(QualifiedType::TYPE, type); @@ -389,6 +390,18 @@ static QualifiedType primTypeof(Context* context, PrimitiveTag prim, const CallI return QualifiedType(QualifiedType::TYPE, typePtr); } +static QualifiedType primPromotionType(Context* context, const CallInfo& ci) { + if (ci.numActuals() != 1) return QualifiedType(); + auto actualQt = ci.actual(0).type(); + + auto promoTy = getPromotionType(context, actualQt).type(); + + // We want a type result, even if the prim was passed a value. + auto promoQt = QualifiedType(QualifiedType::TYPE, promoTy); + + return promoQt; +} + static QualifiedType primGetSvecMember(Context* context, PrimitiveTag prim, const CallInfo& ci) { CHPL_ASSERT(prim == PRIM_GET_SVEC_MEMBER || @@ -985,6 +998,12 @@ static QualifiedType primIsPod(Context* context, const CallInfo& ci) { }); } +static QualifiedType primNeedsAutoDestroy(Context* context, const CallInfo& ci) { + return actualTypeHasProperty(context, ci, [=](auto t) { + return Type::needsInitDeinitCall(t) && !Type::isPod(context, t); + }); +} + static QualifiedType primIsCoercible(Context* context, const CallInfo& ci) { if (ci.numActuals() < 2) return QualifiedType(); @@ -1243,10 +1262,13 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, break; case PRIM_HAS_DEFAULT_VALUE: - case PRIM_NEEDS_AUTO_DESTROY: CHPL_UNIMPL("various primitives"); break; + case PRIM_NEEDS_AUTO_DESTROY: + type = primNeedsAutoDestroy(context, ci); + break; + case PRIM_CALL_RESOLVES: case PRIM_CALL_AND_FN_RESOLVES: case PRIM_METHOD_CALL_AND_FN_RESOLVES: @@ -1687,7 +1709,7 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, break; case PRIM_SCALAR_PROMOTION_TYPE: - CHPL_UNIMPL("misc primitives"); + type = primPromotionType(context, ci); break; case PRIM_STATIC_FIELD_TYPE: type = staticFieldType(context, ci); @@ -1744,7 +1766,6 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, case PRIM_BLOCK_UNLOCAL: case PRIM_LOGICAL_FOLDER: case PRIM_WIDE_MAKE: - case PRIM_WIDE_GET_LOCALE: case PRIM_REGISTER_GLOBAL_VAR: case PRIM_BROADCAST_GLOBAL_VARS: case PRIM_PRIVATE_BROADCAST: @@ -1769,6 +1790,11 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, CHPL_UNIMPL("misc primitives"); break; + case PRIM_WIDE_GET_LOCALE: + type = QualifiedType(QualifiedType::CONST_VAR, + CompositeType::getLocaleIDType(context)); + break; + case PRIM_GATHER_TESTS: type = primGatherTests(context, ci); break; diff --git a/frontend/lib/resolution/resolution-queries.cpp b/frontend/lib/resolution/resolution-queries.cpp index 3e82c365041e..791fb10c4f69 100644 --- a/frontend/lib/resolution/resolution-queries.cpp +++ b/frontend/lib/resolution/resolution-queries.cpp @@ -2735,9 +2735,11 @@ helpResolveFunction(ResolutionContext* rc, const TypedFnSignature* sig, // same function twice when working with inferred 'out' formals) sig = sig->inferredFrom(); - if (!sig->isInitializer() && sig->needsInstantiation()) { - CHPL_ASSERT(false && "Should only be called on concrete or fully " - "instantiated functions"); + if (!sig->isInitializer() && !sig->untyped()->isTypeConstructor() && + sig->needsInstantiation()) { + CHPL_ASSERT(false && + "Should only be called on concrete or fully " + "instantiated functions"); return nullptr; } @@ -3787,13 +3789,9 @@ static bool resolveFnCallSpecial(Context* context, } if ((ci.name() == USTR("==") || ci.name() == USTR("!="))) { - if (ci.numActuals() == 2 || ci.hasQuestionArg()) { + if (ci.numActuals() == 2) { auto lhs = ci.actual(0).type(); - - // support comparisons with '?' - auto rhs = ci.hasQuestionArg() ? - QualifiedType(QualifiedType::TYPE, AnyType::get(context)) : - ci.actual(1).type(); + auto rhs = ci.actual(1).type(); bool bothType = lhs.kind() == QualifiedType::TYPE && rhs.kind() == QualifiedType::TYPE; @@ -3806,6 +3804,27 @@ static bool resolveFnCallSpecial(Context* context, BoolParam::get(context, result)); return true; } + } else if (ci.numActuals() == 1 && ci.hasQuestionArg()) { + // support type and param comparisons with '?' + // TODO: will likely need adjustment once we are able to compare a + // partially-instantiated type's fields with '?' + auto arg = ci.actual(0).type(); + bool result = false; + bool haveResult = true; + if (arg.isType()) { + result = arg.type()->isAnyType(); + } else if (arg.isParam()) { + result = arg.param() == nullptr; + } else { + haveResult = false; + } + result = ci.name() == USTR("==") ? result : !result; + if (haveResult) { + exprTypeOut = + QualifiedType(QualifiedType::PARAM, BoolType::get(context), + BoolParam::get(context, result)); + return true; + } } } diff --git a/frontend/lib/types/CompositeType.cpp b/frontend/lib/types/CompositeType.cpp index 550fb9d3ee81..7b790c8bdd2f 100644 --- a/frontend/lib/types/CompositeType.cpp +++ b/frontend/lib/types/CompositeType.cpp @@ -188,8 +188,8 @@ const RecordType* CompositeType::getLocaleType(Context* context) { } const RecordType* CompositeType::getLocaleIDType(Context* context) { - auto id = ID(); - auto name = UniqueString::get(context, "chpl_localeID_t"); + auto [id, name] = parsing::getSymbolFromTopLevelModule( + context, "LocaleModelHelpRuntime", "chpl_localeID_t"); return RecordType::get(context, id, name, /* instantiatedFrom */ nullptr, SubstitutionsMap()); diff --git a/frontend/lib/types/DomainType.cpp b/frontend/lib/types/DomainType.cpp index 74469533b81b..2e15ce1bb3b1 100644 --- a/frontend/lib/types/DomainType.cpp +++ b/frontend/lib/types/DomainType.cpp @@ -118,17 +118,35 @@ DomainType::getRectangularType(Context* context, const DomainType* DomainType::getAssociativeType(Context* context, + const QualifiedType& instance, const QualifiedType& idxType, const QualifiedType& parSafe) { + auto genericDomain = getGenericDomainType(context); + SubstitutionsMap subs; - // TODO: assert validity of sub types subs.emplace(ID(UniqueString(), 0, 0), idxType); + CHPL_ASSERT(idxType.isType()); subs.emplace(ID(UniqueString(), 1, 0), parSafe); + CHPL_ASSERT(parSafe.isParam() && parSafe.param() && + parSafe.param()->isBoolParam()); + + // Add substitution for _instance field + auto& rf = fieldsForTypeDecl(context, genericDomain, + resolution::DefaultsPolicy::IGNORE_DEFAULTS, + /* syntaxOnly */ true); + ID instanceFieldId; + for (int i = 0; i < rf.numFields(); i++) { + if (rf.fieldName(i) == USTR("_instance")) { + instanceFieldId = rf.fieldDeclId(i); + break; + } + } + subs.emplace(instanceFieldId, instance); + auto name = UniqueString::get(context, "_domain"); auto id = getDomainID(context); - auto instantiatedFrom = getGenericDomainType(context); - return getDomainType(context, id, name, instantiatedFrom, subs, - DomainType::Kind::Associative).get(); + return getDomainType(context, id, name, /* instantiatedFrom */ genericDomain, + subs, DomainType::Kind::Associative).get(); } const QualifiedType& DomainType::getDefaultDistType(Context* context) { diff --git a/frontend/lib/types/Type.cpp b/frontend/lib/types/Type.cpp index 4e0149d9a5c3..eba4daeac301 100644 --- a/frontend/lib/types/Type.cpp +++ b/frontend/lib/types/Type.cpp @@ -228,6 +228,24 @@ bool Type::hasPragma(Context* context, uast::pragmatags::PragmaTag p) const { return false; } +bool Type::isRecordLike() const { + if (auto ct = this->toClassType()) { + auto decorator = ct->decorator(); + // no action needed for 'borrowed' or 'unmanaged' + // (these should just default initialized to 'nil', + // so nothing else needs to be resolved) + if (!(decorator.isBorrowed() || decorator.isUnmanaged() || + decorator.isUnknownManagement())) { + return true; + } + } else if (this->isRecordType() || this->isUnionType()) { + return true; + } + // TODO: tuples? + + return false; +} + const CompositeType* Type::getCompositeType() const { if (auto at = toCompositeType()) return at; @@ -306,5 +324,26 @@ bool Type::isPod(Context* context, const Type* t) { return true; } +bool Type::needsInitDeinitCall(const Type* t) { + if (t == nullptr || t->isUnknownType() || t->isErroneousType()) { + // can't do anything with these + return false; + } else if (t->isPrimitiveType() || t->isBuiltinType() || t->isCStringType() || + t->isNilType() || t->isNothingType() || t->isVoidType()) { + // OK, we can always default initialize primitive numeric types, + // and as well we assume that for the non-generic builtin types + // e.g. TaskIdType. + // No need to resolve anything additional now. + return false; + } else if (t->isEnumType()) { + // OK, can default-initialize enums to first element + return false; + } else if (t->isFunctionType()) { + return false; + } + + return t->isRecordLike(); +} + } // end namespace types } // end namespace chpl diff --git a/frontend/test/resolution/testDomains.cpp b/frontend/test/resolution/testDomains.cpp index 9ae87f455dbf..83c76a7c2c8d 100644 --- a/frontend/test/resolution/testDomains.cpp +++ b/frontend/test/resolution/testDomains.cpp @@ -105,7 +105,6 @@ module M { assert(aa.action() == AssociatedAction::RUNTIME_TYPE); QualifiedType fullIndexType = findVarType(m, rr, "fullIndex"); - (void)fullIndexType; auto rankVarTy = findVarType(m, rr, "r"); assert(rankVarTy == dType->rank()); @@ -158,101 +157,147 @@ module M { printf("Success: %s\n", domainType.c_str()); } -// static void testAssociative(Context* context, -// std::string domainType, -// std::string idxType, -// bool parSafe) { -// context->advanceToNextRevision(false); -// setupModuleSearchPaths(context, false, false, {}, {}); -// ErrorGuard guard(context); +static void testDomainLiteral(Context* context, + std::string domainLiteral, + DomainType::Kind domainKind) { + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); + ErrorGuard guard(context); -// std::string program = -// R"""( -// module M { -// var d : )""" + domainType + R"""(; -// type ig = )""" + idxType + R"""(; + std::string program = +R"""( +module M { + var d = )""" + domainLiteral + R"""(; -// type i = d.idxType; -// param s = d.parSafe; -// param rk = d.isRectangular(); -// param ak = d.isAssociative(); + type i = d.idxType; + param rk = d.isRectangular(); + param ak = d.isAssociative(); +} +)"""; -// var p = d.pid(); + auto path = UniqueString::get(context, "input.chpl"); + setFileText(context, path, std::move(program)); -// for loopI in d { -// var z = loopI; -// } + const ModuleVec& vec = parseToplevel(context, path); + const Module* m = vec[0]; -// proc generic(arg: domain) { -// type GT = arg.type; -// return 42; -// } + const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); -// proc concrete(arg: )""" + domainType + R"""() { -// type CT = arg.type; -// return 42; -// } + const Variable* d = m->stmt(0)->toVariable(); + assert(d); + assert(d->name() == "d"); -// var g_ret = generic(d); -// var c_ret = concrete(d); -// } -// )"""; -// // TODO: generic checks + QualifiedType dQt = rr.byAst(d).type(); + assert(dQt.type()); + auto dType = dQt.type()->toDomainType(); + assert(dType); -// auto path = UniqueString::get(context, "input.chpl"); -// setFileText(context, path, std::move(program)); + assert(findVarType(m, rr, "i") == dType->idxType()); -// const ModuleVec& vec = parseToplevel(context, path); -// const Module* m = vec[1]; + assert(dType->kind() == domainKind); + bool isRectangular = domainKind == DomainType::Kind::Rectangular; + assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == isRectangular); + assert(findVarType(m, rr, "ak").param()->toBoolParam()->value() == !isRectangular); -// const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); + assert(guard.realizeErrors() == 0); -// QualifiedType dType = findVarType(m, rr, "d"); -// assert(dType.type()->isDomainType()); + printf("Success: %s\n", domainLiteral.c_str()); +} -// auto fullIndexType = findVarType(m, rr, "i"); -// assert(findVarType(m, rr, "ig") == fullIndexType); +static void testAssociative(Context* context, + std::string domainType, + std::string idxType, + bool parSafe) { + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); + ErrorGuard guard(context); -// assert(findVarType(m, rr, "s").param()->toBoolParam()->value() == parSafe); + std::string program = +R"""( +module M { + var d : )""" + domainType + R"""(; + type ig = )""" + idxType + R"""(; -// assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == false); + type i = d.idxType; + param s = d.parSafe; + param rk = d.isRectangular(); + param ak = d.isAssociative(); -// assert(findVarType(m, rr, "ak").param()->toBoolParam()->value() == true); + var p = d.pid; -// assert(findVarType(m, rr, "p").type() == IntType::get(context, 0)); + for loopI in d { + var z = loopI; + } -// assert(findVarType(m, rr, "z").type() == fullIndexType.type()); + proc generic(arg: domain) { + type GT = arg.type; + return 42; + } -// { -// const Variable* g_ret = findOnlyNamed(m, "g_ret")->toVariable(); -// auto res = rr.byAst(g_ret); -// assert(res.type().type()->isIntType()); + proc concrete(arg: )""" + domainType + R"""() { + type CT = arg.type; + return 42; + } + + var g_ret = generic(d); + var c_ret = concrete(d); +} +)"""; + + auto path = UniqueString::get(context, "input.chpl"); + setFileText(context, path, std::move(program)); + + const ModuleVec& vec = parseToplevel(context, path); + const Module* m = vec[0]; + + const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); -// auto call = resolveOnlyCandidate(context, rr.byAst(g_ret->initExpression())); -// // Generic function, should have been instantiated -// assert(call->signature()->instantiatedFrom() != nullptr); + QualifiedType dType = findVarType(m, rr, "d"); + assert(dType.type()->isDomainType()); -// const Variable* GT = findOnlyNamed(m, "GT")->toVariable(); -// assert(call->byAst(GT).type().type() == dType.type()); -// } + auto fullIndexType = findVarType(m, rr, "i"); + assert(findVarType(m, rr, "ig") == fullIndexType); -// { -// const Variable* c_ret = findOnlyNamed(m, "c_ret")->toVariable(); -// auto res = rr.byAst(c_ret); -// assert(res.type().type()->isIntType()); + assert(findVarType(m, rr, "s").param()->toBoolParam()->value() == parSafe); -// auto call = resolveOnlyCandidate(context, rr.byAst(c_ret->initExpression())); -// // Concrete function, should not be instantiated -// assert(call->signature()->instantiatedFrom() == nullptr); + assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == false); -// const Variable* CT = findOnlyNamed(m, "CT")->toVariable(); -// assert(call->byAst(CT).type().type() == dType.type()); -// } + assert(findVarType(m, rr, "ak").param()->toBoolParam()->value() == true); -// assert(guard.errors().size() == 0); + assert(findVarType(m, rr, "p").type() == IntType::get(context, 0)); + + assert(findVarType(m, rr, "z").type() == fullIndexType.type()); -// printf("Success: %s\n", domainType.c_str()); -// } + { + const Variable* g_ret = findOnlyNamed(m, "g_ret")->toVariable(); + auto res = rr.byAst(g_ret); + assert(res.type().type()->isIntType()); + + auto call = resolveOnlyCandidate(context, rr.byAst(g_ret->initExpression())); + // Generic function, should have been instantiated + assert(call->signature()->instantiatedFrom() != nullptr); + + const Variable* GT = findOnlyNamed(m, "GT")->toVariable(); + assert(call->byAst(GT).type().type() == dType.type()); + } + + { + const Variable* c_ret = findOnlyNamed(m, "c_ret")->toVariable(); + auto res = rr.byAst(c_ret); + assert(res.type().type()->isIntType()); + + auto call = resolveOnlyCandidate(context, rr.byAst(c_ret->initExpression())); + // Concrete function, should not be instantiated + assert(call->signature()->instantiatedFrom() == nullptr); + + const Variable* CT = findOnlyNamed(m, "CT")->toVariable(); + assert(call->byAst(CT).type().type() == dType.type()); + } + + assert(guard.errors().size() == 0); + + printf("Success: %s\n", domainType.c_str()); +} static void testBadPass(Context* context, std::string argType, std::string actualType) { @@ -389,31 +434,30 @@ int main() { testRectangular(context, "domain(2, int(8))", 2, "int(8)", "one"); testRectangular(context, "domain(3, int(16), strideKind.negOne)", 3, "int(16)", "negOne"); testRectangular(context, "domain(strides=strideKind.negative, idxType=int, rank=1)", 1, "int", "negative"); - context->collectGarbage(); - // TODO: re-enable associative - // testAssociative(context, "domain(int)", "int", true); - // testAssociative(context, "domain(int, false)", "int", false); - // testAssociative(context, "domain(string)", "string", true); - // context->collectGarbage(); + testDomainLiteral(context, "{1..10}", DomainType::Kind::Rectangular); + testDomainLiteral(context, "{1..10, 1..10}", DomainType::Kind::Rectangular); + + testAssociative(context, "domain(int)", "int", true); + testAssociative(context, "domain(int, false)", "int", false); + testAssociative(context, "domain(string)", "string", true); + + testDomainLiteral(context, "{1, 2, 3}", DomainType::Kind::Associative); + testDomainLiteral(context, "{\"apple\", \"banana\"}", DomainType::Kind::Associative); testBadPass(context, "domain(1)", "domain(2)"); testBadPass(context, "domain(1, int(16))", "domain(1, int(8))"); testBadPass(context, "domain(1, int(8))", "domain(1, int(16))"); testBadPass(context, "domain(1, strides=strideKind.negOne)", "domain(1, strides=strideKind.one)"); - // TODO: re-enable associative badPass - // testBadPass(context, "domain(int)", "domain(string)"); - // testBadPass(context, "domain(1)", "domain(int)"); - context->collectGarbage(); + testBadPass(context, "domain(int)", "domain(string)"); + testBadPass(context, "domain(1)", "domain(int)"); testIndex(context, "domain(1)", "int"); testIndex(context, "domain(2)", "2*int"); testIndex(context, "domain(1, bool)", "bool"); testIndex(context, "domain(2, bool)", "2*bool"); - // TODO: re-enable associative indexes - // testIndex(context, "domain(int)", "int"); - // testIndex(context, "domain(string)", "string"); - context->collectGarbage(); + testIndex(context, "domain(int)", "int"); + testIndex(context, "domain(string)", "string"); testBadDomain(context, "domain()"); testBadDomain(context, "domain(1, 2, 3, 4)"); diff --git a/frontend/test/resolution/testResolve.cpp b/frontend/test/resolution/testResolve.cpp index 7ae6419ae3af..0a9b21eb2aea 100644 --- a/frontend/test/resolution/testResolve.cpp +++ b/frontend/test/resolution/testResolve.cpp @@ -1739,6 +1739,50 @@ static void testInfiniteCycleBug() { std::ignore = resolveQualifiedTypeOfX(context, program1); } +// Test use of the 'scalar promotion type' primitive. +// Implementation of getting promotion types is tested more thoroughly +// elsewhere, so this is just a very basic test the prims works as expected. +static void testPromotionPrim() { + Context* context = buildStdContext(); + ErrorGuard guard(context); + + std::string prog = + R"""( + var d : domain(1, real); + type t = __primitive("scalar promotion type", d); + param x = (t == real); + )"""; + + auto x = resolveTypeOfXInit(context, prog); + ensureParamBool(x, true); + + assert(guard.realizeErrors() == 0); +} + +// Test the '_wide_get_locale' primitive. +static void testGetLocalePrim() { + Context* context = buildStdContext(); + // TODO: we get a query system infinite recursion without this for some reason + context->collectGarbage(); + ErrorGuard guard(context); + + auto variables = resolveTypesOfVariables(context, + R"""( + var x : real; + var locId = __primitive("_wide_get_locale", x); + var sublocId = chpl_sublocFromLocaleID(locId); + )""", { "locId", "sublocId" }); + + auto locId = variables.at("locId"); + assert(locId.type()); + assert(locId.type() == CompositeType::getLocaleIDType(context)); + auto sublocId = variables.at("sublocId"); + assert(sublocId.type()); + assert(sublocId.type()->isIntType()); + + assert(guard.realizeErrors() == 0); +} + int main() { test1(); test2(); @@ -1770,5 +1814,8 @@ int main() { testInfiniteCycleBug(); + testPromotionPrim(); + testGetLocalePrim(); + return 0; }