diff --git a/frontend/include/chpl/framework/all-global-strings.h b/frontend/include/chpl/framework/all-global-strings.h index 315c5355ba0d..08ec0ef01d75 100644 --- a/frontend/include/chpl/framework/all-global-strings.h +++ b/frontend/include/chpl/framework/all-global-strings.h @@ -43,6 +43,7 @@ X(c_ptr , "c_ptr") X(c_ptrConst , "c_ptrConst") X(c_char , "c_char") X(class_ , "class") +X(defaultDist , "defaultDist") X(deinit , "deinit") X(deserialize , "deserialize") X(dmapped , "dmapped") @@ -66,6 +67,7 @@ X(index , "index") X(init , "init") X(initequals , "init=") X(int_ , "int") +X(instance_ , "_instance") X(isCoercible , "isCoercible") X(leader , "leader") X(locale , "locale") diff --git a/frontend/include/chpl/parsing/parsing-queries.h b/frontend/include/chpl/parsing/parsing-queries.h index 75be77d154b5..fc86c312b086 100644 --- a/frontend/include/chpl/parsing/parsing-queries.h +++ b/frontend/include/chpl/parsing/parsing-queries.h @@ -454,15 +454,29 @@ std::string getExistingFileInModuleSearchPath(Context* context, */ const uast::Module* getToplevelModule(Context* context, UniqueString name); +struct IdAndName { + ID id; + UniqueString name; +}; + + +/** + Given a particular (presumably standard) module, return the ID of a + symbol with the given name in that module. Beyond creating the ID, this also + ensures that the standard module is parsed, and thus, that 'idToAst' on the + returned ID will return a non-null value. + */ +ID getSymbolIdFromTopLevelModule(Context* context, + const char* modName, + const char* symName); + /** - Given a particular (presumably standard) module, return the ID of a symbol - with the given name in that module. Beyond creating the ID, this also ensures - that the standard module is parsed, and thus, that 'idToAst' on the returned - ID will return a non-null value. + Like getSymbolId..., but return also contains the name of the given symbol for + convenience. */ -ID getSymbolFromTopLevelModule(Context* context, - const char* modName, - const char* symName); +IdAndName getSymbolFromTopLevelModule(Context* context, + const char* modName, + const char* symName); /** This query parses a submodule for 'include submodule'. diff --git a/frontend/include/chpl/resolution/resolution-error-classes-list.h b/frontend/include/chpl/resolution/resolution-error-classes-list.h index 6cf7b7d2d3b8..224636d7b04e 100644 --- a/frontend/include/chpl/resolution/resolution-error-classes-list.h +++ b/frontend/include/chpl/resolution/resolution-error-classes-list.h @@ -54,6 +54,7 @@ ERROR_CLASS(IncompatibleKinds, types::QualifiedType::Kind, const uast::AstNode*, ERROR_CLASS(IncompatibleRangeBounds, const uast::Range*, types::QualifiedType, types::QualifiedType) ERROR_CLASS(IncompatibleTypeAndInit, const uast::AstNode*, const uast::AstNode*, const uast::AstNode*, const types::Type*, const types::Type*) ERROR_CLASS(InvalidClassCast, const uast::PrimCall*, types::QualifiedType) +ERROR_CLASS(InvalidDomainCall, const uast::FnCall*, std::vector) ERROR_CLASS(InvalidIndexCall, const uast::FnCall*, types::QualifiedType) ERROR_CLASS(InvalidNewTarget, const uast::New*, types::QualifiedType) ERROR_CLASS(InvalidParamCast, const uast::AstNode*, types::QualifiedType, types::QualifiedType) diff --git a/frontend/include/chpl/resolution/resolution-types.h b/frontend/include/chpl/resolution/resolution-types.h index 0218151badc3..7c16ab17a38f 100644 --- a/frontend/include/chpl/resolution/resolution-types.h +++ b/frontend/include/chpl/resolution/resolution-types.h @@ -1992,6 +1992,7 @@ class AssociatedAction { REDUCE_SCAN, // resolution of "generate" for a reduce/scan operation. INFER_TYPE, COMPARE, // == , e.g., for select-statements + RUNTIME_TYPE, // create runtime type }; private: diff --git a/frontend/include/chpl/types/CompositeType.h b/frontend/include/chpl/types/CompositeType.h index 0ef8fedec5f0..62c96a590c69 100644 --- a/frontend/include/chpl/types/CompositeType.h +++ b/frontend/include/chpl/types/CompositeType.h @@ -225,6 +225,9 @@ class CompositeType : public Type { /** Get the chpl_localeID_t type */ static const RecordType* getLocaleIDType(Context* context); + /** Get the _distribution type */ + static const RecordType* getDistributionType(Context* context); + /** Get the record _owned implementing owned */ static const RecordType* getOwnedRecordType(Context* context, const BasicClassType* bct); diff --git a/frontend/include/chpl/types/DomainType.h b/frontend/include/chpl/types/DomainType.h index 4cffbee6175f..aeda1ae87492 100644 --- a/frontend/include/chpl/types/DomainType.h +++ b/frontend/include/chpl/types/DomainType.h @@ -56,7 +56,7 @@ class DomainType final : public CompositeType { // TODO: distributions Kind kind_; - // Will compute idxType, rank, and stridable from 'subs' + // Will compute idxType, rank, and strides from 'subs' DomainType(ID id, UniqueString name, const DomainType* instantiatedFrom, SubstitutionsMap subs, @@ -91,15 +91,19 @@ class DomainType final : public CompositeType { /** Return a rectangular domain type */ static const DomainType* getRectangularType(Context* context, + const QualifiedType& instance, const QualifiedType& rank, const QualifiedType& idxType, - const QualifiedType& stridable); + const QualifiedType& strides); /** Return an associative domain type */ static const DomainType* getAssociativeType(Context* context, const QualifiedType& idxType, const QualifiedType& parSafe); + /** Get the default distribution type */ + static const QualifiedType& getDefaultDistType(Context* context); + Kind kind() const { return kind_; } @@ -123,7 +127,7 @@ class DomainType final : public CompositeType { } } - const QualifiedType& stridable() const { + const QualifiedType& strides() const { CHPL_ASSERT(kind_ == Kind::Rectangular); return subs_.at(ID(UniqueString(), 2, 0)); } diff --git a/frontend/lib/parsing/parsing-queries.cpp b/frontend/lib/parsing/parsing-queries.cpp index 1d9de57fb700..2c1e665a7e14 100644 --- a/frontend/lib/parsing/parsing-queries.cpp +++ b/frontend/lib/parsing/parsing-queries.cpp @@ -1109,9 +1109,9 @@ const Module* getToplevelModule(Context* context, UniqueString name) { return getToplevelModuleQuery(context, name); } -ID getSymbolFromTopLevelModule(Context* context, - const char* modName, - const char* symName) { +ID getSymbolIdFromTopLevelModule(Context* context, + const char* modName, + const char* symName) { std::ignore = getToplevelModule(context, UniqueString::get(context, modName)); // Performance: this has to concatenate the two strings at runtime. @@ -1127,6 +1127,13 @@ ID getSymbolFromTopLevelModule(Context* context, return ID(UniqueString::get(context, fullPath)); } +IdAndName getSymbolFromTopLevelModule(Context* context, + const char* modName, + const char* symName) { + return {getSymbolIdFromTopLevelModule(context, modName, symName), + UniqueString::get(context, symName)}; +} + static const Module* const& getIncludedSubmoduleQuery(Context* context, ID includeModuleId) { QUERY_BEGIN(getIncludedSubmoduleQuery, context, includeModuleId); diff --git a/frontend/lib/resolution/InitResolver.cpp b/frontend/lib/resolution/InitResolver.cpp index 15c0c37ad409..8261f3f78827 100644 --- a/frontend/lib/resolution/InitResolver.cpp +++ b/frontend/lib/resolution/InitResolver.cpp @@ -280,6 +280,57 @@ bool InitResolver::isFinalReceiverStateValid(void) { return ret; } +// Extract domain type information from _instance substitution +static const DomainType* domainTypeFromSubsHelper( + Context* context, const CompositeType::SubstitutionsMap& subs) { + auto genericDomain = DomainType::getGenericDomainType(context); + + // Expect one substitution for _instance + if (subs.size() != 1) return genericDomain; + + const QualifiedType instanceQt = subs.begin()->second; + + if (auto instance = instanceQt.type()) { + if (auto instanceCt = instance->toClassType()) { + if (auto instanceBct = instanceCt->basicClassType()) { + // Get BaseRectangularDom parent subs for rectangular domain info + if (auto baseDom = instanceBct->parentClassType()) { + auto& rf = fieldsForTypeDecl(context, baseDom, + DefaultsPolicy::IGNORE_DEFAULTS); + if (baseDom->id().symbolPath() == "ChapelDistribution.BaseRectangularDom") { + CHPL_ASSERT(rf.numFields() == 3); + QualifiedType rank; + QualifiedType idxType; + QualifiedType strides; + for (int i = 0; i < rf.numFields(); i++) { + if (rf.fieldName(i) == "rank") { + rank = rf.fieldType(i); + } else if (rf.fieldName(i) == "idxType") { + idxType = rf.fieldType(i); + } else if (rf.fieldName(i) == "strides") { + strides = rf.fieldType(i); + } + } + + return DomainType::getRectangularType(context, instanceQt, rank, + idxType, strides); + } else if (baseDom->id().symbolPath() == "ChapelDistribution.BaseAssociativeDom") { + // TODO: support associative domains + } else if (baseDom->id().symbolPath() == "ChapelDistribution.BaseSparseDom") { + // TODO: support sparse domains + } else { + // not a recognized domain type + return genericDomain; + } + } + } + } + } + + // If we reach here, we weren't able to resolve the domain type + return genericDomain; +} + static const Type* ctFromSubs(Context* context, const Type* receiverType, const BasicClassType* superType, @@ -313,6 +364,8 @@ static const Type* ctFromSubs(Context* context, auto manager = AnyOwnedType::get(context); auto dec = ClassTypeDecorator(ClassTypeDecorator::BORROWED_NONNIL); ret = ClassType::get(context, basic, manager, dec); + } else if (receiverType->isDomainType()) { + ret = domainTypeFromSubsHelper(context, subs); } else { CHPL_ASSERT(false && "Not handled!"); } @@ -400,9 +453,12 @@ const Type* InitResolver::computeReceiverTypeConsideringState(void) { QualifiedType::Kind InitResolver::determineReceiverIntent(void) { if (initialRecvType_->isClassType()) { return QualifiedType::CONST_IN; - } else { - CHPL_ASSERT(initialRecvType_->isRecordType()); + } else if (initialRecvType_->isRecordType() || + initialRecvType_->isDomainType()) { return QualifiedType::REF; + } else { + CHPL_ASSERT(false && "Not handled"); + return QualifiedType::UNKNOWN; } } diff --git a/frontend/lib/resolution/Resolver.cpp b/frontend/lib/resolution/Resolver.cpp index 9863b729b7e9..bd9c7f9d8d26 100644 --- a/frontend/lib/resolution/Resolver.cpp +++ b/frontend/lib/resolution/Resolver.cpp @@ -2155,19 +2155,23 @@ bool Resolver::resolveSpecialNewCall(const Call* call) { return true; } - // Remove nilability from e.g., 'new C?()' for the init call (or else it - // will not resolve because the receiver formal is 'nonnil borrowed'). const Type* initReceiverType = qtNewExpr.type(); + if (auto clsType = qtNewExpr.type()->toClassType()) { + // Remove nilability from e.g., 'new C?()' for the init call (or else it + // will not resolve because the receiver formal is 'nonnil borrowed'). + // always set the receiver to be borrowed non-nil b/c we don't want to // call initializers for '_owned' when the receiver is 'owned(MyClass)' auto newDecor = ClassTypeDecorator(ClassTypeDecorator::BORROWED_NONNIL); initReceiverType = clsType->withDecorator(context, newDecor); - CHPL_ASSERT(initReceiverType); + } else if (auto recordType = qtNewExpr.type()->toRecordType()) { + // Rewrite 'new dmap' to 'new _distribution' + if (recordType->id().symbolPath() == "ChapelArray.dmap") { + initReceiverType = CompositeType::getDistributionType(context); + } } - - // The 'new' will produce an 'init' call as a side effect. - UniqueString name = USTR("init"); + CHPL_ASSERT(initReceiverType); /* auto cls = qtNewExpr.type()->toClassType(); @@ -2186,12 +2190,11 @@ bool Resolver::resolveSpecialNewCall(const Call* call) { actuals.push_back(std::move(receiverInfo)); // Remaining actuals. - if (call->numActuals()) { - prepareCallInfoActuals(call, actuals, questionArg); - CHPL_ASSERT(!questionArg); - } + prepareCallInfoActuals(call, actuals, questionArg); + CHPL_ASSERT(!questionArg); - auto ci = CallInfo(name, calledType, isMethodCall, + // The 'new' will produce an 'init' call as a side effect. + auto ci = CallInfo(USTR("init"), calledType, isMethodCall, /* hasQuestionArg */ questionArg != nullptr, /* isParenless */ false, std::move(actuals)); @@ -2306,6 +2309,8 @@ bool Resolver::resolveSpecialKeywordCall(const Call* call) { auto fnCall = call->toFnCall(); if (!fnCall->calledExpression()->isIdentifier()) return false; + auto& r = byPostorder.byAst(call); + auto fnName = fnCall->calledExpression()->toIdentifier()->name(); if (fnName == "index") { auto runResult = context->runAndTrackErrors([&](Context* ctx) { @@ -2318,7 +2323,6 @@ bool Resolver::resolveSpecialKeywordCall(const Call* call) { auto inScopes = CallScopeInfo::forNormalCall(scope, poiScope); auto result = resolveGeneratedCall(context, call, ci, inScopes); - auto& r = byPostorder.byAst(call); handleResolvedCall(r, call, ci, result); return result; }); @@ -2331,25 +2335,81 @@ bool Resolver::resolveSpecialKeywordCall(const Call* call) { CHPL_REPORT(context, InvalidIndexCall, fnCall, firstActual); } return true; - } + } else if (fnName == "domain") { + auto& rCalledExp = byPostorder.byAst(fnCall->calledExpression()); + CHPL_ASSERT(rCalledExp.type().hasTypePtr()); + // Try resolving 'domain(?)' as a special case. + if (call->numActuals() == 1 && call->actual(0)->isIdentifier() && + call->actual(0)->toIdentifier()->name() == "?") { + // 'domain(?)' is equivalent to just 'domain', the generic domain + // type. + // Copy the result of resolving 'domain' as the called identifier. + r.setType(rCalledExp.type()); + } else { + // Get type by resolving the type of corresponding '_domain' init call + // TODO: prohibit associative domain with idxType 'domain' + const AstNode* questionArg = nullptr; + std::vector actuals; + // Set up receiver + auto receiverType = + QualifiedType(QualifiedType::INIT_RECEIVER, rCalledExp.type().type()); + auto receiverArg = CallInfoActual(receiverType, USTR("this")); + actuals.push_back(std::move(receiverArg)); + // Set up distribution arg + auto defaultDistArg = CallInfoActual( + DomainType::getDefaultDistType(context), UniqueString()); + actuals.push_back(std::move(defaultDistArg)); + // Remaining given args from domain() call as written + prepareCallInfoActuals(call, actuals, questionArg); + CHPL_ASSERT(!questionArg); + + auto ci = + CallInfo(USTR("init"), + /* calledType */ receiverType, + /* isMethodCall */ true, + /* hasQuestionArg */ false, + /* isParenless */ false, + actuals); - return false; -} + auto scope = scopeStack.back(); + auto inScopes = CallScopeInfo::forNormalCall(scope, poiScope); + auto runResult = context->runAndTrackErrors([&](Context* ctx) { + return resolveGeneratedCall(context, call, ci, inScopes); + }); -bool Resolver::resolveSpecialCall(const Call* call) { - if (resolveSpecialOpCall(call)) { - return true; - } else if (resolveSpecialPrimitiveCall(call)) { - return true; - } else if (resolveSpecialNewCall(call)) { - return true; - } else if (resolveSpecialKeywordCall(call)) { + // Use the init call's receiver type as the resulting TYPE + QualifiedType receiverTy; + if (runResult.ranWithoutErrors()) { + auto result = runResult.result(); + if (auto initMsc = result.mostSpecific().only()) { + handleResolvedCall(r, call, ci, result, + {{AssociatedAction::RUNTIME_TYPE, fnCall->id()}}); + receiverTy = initMsc.fn()->formalType(0); + } + } + if (!receiverTy.type()) { + std::vector actualTypesForErr; + for (auto it = actuals.begin() + 2; it != actuals.end(); ++it) { + actualTypesForErr.push_back(it->type()); + } + receiverTy = CHPL_TYPE_ERROR(context, InvalidDomainCall, fnCall, + actualTypesForErr); + } + r.setType(QualifiedType(QualifiedType::TYPE, receiverTy.type())); + } return true; } return false; } +bool Resolver::resolveSpecialCall(const Call* call) { + return resolveSpecialOpCall(call) || + resolveSpecialPrimitiveCall(call) || + resolveSpecialNewCall(call) || + resolveSpecialKeywordCall(call); +} + static QualifiedType lookupFieldType(Resolver& rv, const CompositeType* ct, const ID& idField) { if (!ct || !idField) return {}; @@ -3746,10 +3806,55 @@ bool Resolver::enter(const uast::Domain* decl) { } void Resolver::exit(const uast::Domain* decl) { + if (scopeResolveOnly) { + return; + } + + const DomainType* genericDomainType = DomainType::getGenericDomainType(context); + if (CompositeType::isMissingBundledRecordType(context, genericDomainType->id())) { + // If we don't have the standard library code backing the Domain type, leave + // it unresolved. + return; + } + if (decl->numExprs() == 0) { + // Generic domain, as in a generic-domain array auto& re = byPostorder.byAst(decl); - auto dt = QualifiedType(QualifiedType::CONST_VAR, DomainType::getGenericDomainType(context)); + auto dt = QualifiedType(QualifiedType::CONST_VAR, genericDomainType); re.setType(dt); + } else { + // Call appropriate domain builder proc. Use ensureDomainExpr when the + // domain is declared without curly braces (within an array type). + const char* domainBuilderProc = decl->usedCurlyBraces() + ? "chpl__buildDomainExpr" + : "chpl__ensureDomainExpr"; + + // Add key or range actuals + std::vector actuals; + for (auto expr : decl->exprs()) { + actuals.emplace_back(byPostorder.byAst(expr).type(), UniqueString()); + } + + // Add definedConst actual if appropriate + if (decl->usedCurlyBraces()) { + actuals.emplace_back( + QualifiedType(QualifiedType::PARAM, BoolType::get(context), + BoolParam::get(context, true)), + UniqueString()); + } + + auto ci = CallInfo(/* name */ UniqueString::get(context, domainBuilderProc), + /* calledType */ QualifiedType(), + /* isMethodCall */ false, + /* hasQuestionArg */ false, + /* isParenless */ false, + actuals); + auto scope = scopeStack.back(); + auto inScopes = CallScopeInfo::forNormalCall(scope, poiScope); + auto c = resolveGeneratedCall(context, decl, ci, inScopes); + + ResolvedExpression& r = byPostorder.byAst(decl); + handleResolvedCall(r, decl, ci, c); } } @@ -3864,6 +3969,8 @@ static const Type* getGenericType(Context* context, const Type* recv) { auto m = getGenericType(context, cur->manageableType()); gen = ClassType::get(context, m->toManageableType(), cur->manager(), cur->decorator()); + } else if (recv->isDomainType()) { + gen = DomainType::getGenericDomainType(context); } return gen; } @@ -4464,26 +4571,18 @@ static void resolveNewForClass(Resolver& rv, const New* node, re.setType(qt); } -static void resolveNewForRecord(Resolver& rv, const New* node, - const RecordType* recordType) { - ResolvedExpression& re = rv.byPostorder.byAst(node); +static void resolveNewForRecordLike(Resolver& rv, const New* node, + const CompositeType* recordLikeType) { + CHPL_ASSERT(recordLikeType->isRecordType() || + recordLikeType->isDomainType() || + recordLikeType->isUnionType()); - if (node->management() != New::DEFAULT_MANAGEMENT) { - CHPL_REPORT(rv.context, MemManagementNonClass, node, recordType); - } else { - auto qt = QualifiedType(QualifiedType::INIT_RECEIVER, recordType); - re.setType(qt); - } -} - -static void resolveNewForUnion(Resolver& rv, const New* node, - const UnionType* unionType) { ResolvedExpression& re = rv.byPostorder.byAst(node); if (node->management() != New::DEFAULT_MANAGEMENT) { - CHPL_REPORT(rv.context, MemManagementNonClass, node, unionType); + CHPL_REPORT(rv.context, MemManagementNonClass, node, recordLikeType); } else { - auto qt = QualifiedType(QualifiedType::INIT_RECEIVER, unionType); + auto qt = QualifiedType(QualifiedType::INIT_RECEIVER, recordLikeType); re.setType(qt); } } @@ -4514,18 +4613,14 @@ void Resolver::exit(const New* node) { return; } - if (qtTypeExpr.type()->isBasicClassType()) { + auto type = qtTypeExpr.type(); + if (type->isBasicClassType()) { CHPL_ASSERT(false && "Expected fully decorated class type"); - - } else if (auto classType = qtTypeExpr.type()->toClassType()) { + } else if (auto classType = type->toClassType()) { resolveNewForClass(*this, node, classType); - - } else if (auto recordType = qtTypeExpr.type()->toRecordType()) { - resolveNewForRecord(*this, node, recordType); - - } else if (auto unionType = qtTypeExpr.type()->toUnionType()) { - resolveNewForUnion(*this, node, unionType); - + } else if (type->isRecordType() || type->isDomainType() || + type->isUnionType()) { + resolveNewForRecordLike(*this, node, type->toCompositeType()); } else { if (node->management() != New::DEFAULT_MANAGEMENT) { CHPL_REPORT(context, MemManagementNonClass, node, qtTypeExpr.type()); diff --git a/frontend/lib/resolution/Resolver.h b/frontend/lib/resolution/Resolver.h index 67a2295e1ac0..6825419aee6d 100644 --- a/frontend/lib/resolution/Resolver.h +++ b/frontend/lib/resolution/Resolver.h @@ -512,7 +512,7 @@ struct Resolver { // own logic for traversing actuals etc. bool resolveSpecialPrimitiveCall(const uast::Call* call); - // resolve a keyword call like index(D) + // resolve a keyword call like index(D) or domain(1) bool resolveSpecialKeywordCall(const uast::Call* call); // Resolve a || or && operation. diff --git a/frontend/lib/resolution/default-functions.cpp b/frontend/lib/resolution/default-functions.cpp index ecf41a27207e..96f188ede659 100644 --- a/frontend/lib/resolution/default-functions.cpp +++ b/frontend/lib/resolution/default-functions.cpp @@ -142,25 +142,7 @@ needCompilerGeneratedMethod(Context* context, const Type* type, } } - // Some basic getter methods for domain properties - // - // TODO: We can eventually replace these for calls on a domain *value* by - // looking at the property from the _instance implementation. But that won't - // work if we want to support these methods on a domain type-expression. - // - // TODO: calling these within a method doesn't work - if (type->isDomainType()) { - if (parenless) { - if (name == "idxType" || name == "rank" || name == "stridable" || - name == "parSafe") { - return true; - } - } else { - if (name == "isRectangular" || name == "isAssociative") { - return true; - } - } - } else if (type->isArrayType()) { + if (type->isArrayType()) { if (name == "domain" || name == "eltType") { return true; } @@ -219,6 +201,8 @@ generateInitParts(Context* context, CHPL_ASSERT(receiverType); qtReceiver = QualifiedType(QualifiedType::CONST_IN, receiverType); + } else if (CompositeType::isMissingBundledType(context, compType->id())) { + // ignore } else { CHPL_ASSERT(false && "Not possible!"); } @@ -706,47 +690,6 @@ generateDeSerialize(Context* context, const CompositeType* compType, return ret; } -static const TypedFnSignature* -generateDomainMethod(Context* context, - const DomainType* dt, - UniqueString name) { - // Build a basic function signature for methods querying some aspect of - // a domain's type. - // TODO: we should really have a way to just set the return type here - const TypedFnSignature* result = nullptr; - std::vector formals; - std::vector formalTypes; - - formals.push_back( - UntypedFnSignature::FormalDetail(USTR("this"), - UntypedFnSignature::DK_NO_DEFAULT, - nullptr)); - formalTypes.push_back(QualifiedType(QualifiedType::CONST_REF, dt)); - - auto ufs = UntypedFnSignature::get(context, - /*id*/ dt->id(), - /*name*/ name, - /*isMethod*/ true, - /*isTypeConstructor*/ false, - /*isCompilerGenerated*/ true, - /*throws*/ false, - /*idTag*/ parsing::idToTag(context, dt->id()), - /*kind*/ uast::Function::Kind::PROC, - /*formals*/ std::move(formals), - /*whereClause*/ nullptr); - - // now build the other pieces of the typed signature - result = TypedFnSignature::get(context, ufs, std::move(formalTypes), - TypedFnSignature::WHERE_NONE, - /* needsInstantiation */ false, - /* instantiatedFrom */ nullptr, - /* parentFn */ nullptr, - /* formalsInstantiated */ Bitmap(), - /* outerVariables */ {}); - - return result; -} - static const TypedFnSignature* generateArrayMethod(Context* context, const ArrayType* at, @@ -1137,8 +1080,6 @@ getCompilerGeneratedMethodQuery(Context* context, QualifiedType receiverType, result = generateDeSerialize(context, compType, name, "writer", "serializer"); } else if (name == USTR("deserialize")) { result = generateDeSerialize(context, compType, name, "reader", "deserializer"); - } else if (auto domainType = type->toDomainType()) { - result = generateDomainMethod(context, domainType, name); } else if (auto arrayType = type->toArrayType()) { result = generateArrayMethod(context, arrayType, name); } else if (auto tupleType = type->toTupleType()) { diff --git a/frontend/lib/resolution/prims.cpp b/frontend/lib/resolution/prims.cpp index 06079acc8fcf..9d7d08f7bed8 100644 --- a/frontend/lib/resolution/prims.cpp +++ b/frontend/lib/resolution/prims.cpp @@ -286,6 +286,7 @@ static QualifiedType primCallResolves(ResolutionContext* rc, static QualifiedType computeDomainType(Context* context, const CallInfo& ci) { if (ci.numActuals() == 3) { auto type = DomainType::getRectangularType(context, + QualifiedType(), ci.actual(0).type(), ci.actual(1).type(), ci.actual(2).type()); diff --git a/frontend/lib/resolution/resolution-error-classes-list.cpp b/frontend/lib/resolution/resolution-error-classes-list.cpp index 636687101093..13f37b53297a 100644 --- a/frontend/lib/resolution/resolution-error-classes-list.cpp +++ b/frontend/lib/resolution/resolution-error-classes-list.cpp @@ -25,6 +25,7 @@ #include "chpl/uast/VisibilityClause.h" #include "chpl/uast/AstTag.h" #include "chpl/types/all-types.h" +#include #include #include @@ -615,6 +616,55 @@ void ErrorInvalidClassCast::write(ErrorWriterBase& wr) const { } } +void ErrorInvalidDomainCall::write(ErrorWriterBase& wr) const { + auto fnCall = std::get(info_); + auto actualTypes = std::get>(info_); + + wr.heading(kind_, type_, fnCall, "invalid use of the 'domain' keyword."); + wr.codeForLocation(fnCall); + wr.message( + "The 'domain' keyword should be used with a valid domain type " + "expression."); + + if (fnCall->numActuals() == 0) { + wr.message("However, 'domain' here did not have any actuals."); + } else if (fnCall->numActuals() == 1) { + // Could be rectangular or associative. Error if we have an actual type + // that's wrong for both. + auto qt = actualTypes[0]; + if (!qt.isType() && !(qt.type() && qt.type()->isIntType())) { + wr.message("However, the first actual was ", decayToValue(qt), + " rather than an 'int' (for rectangular) or a type (for " + "associative)."); + wr.code(fnCall, {fnCall->actual(0)}); + } + } else if (fnCall->numActuals() <= 3) { + // Should be rectangular, must have one or more actual type(s) wrong after + // the first. + wr.message( + "This 'domain' call is structured like a rectangular domain type."); + bool erroredForIdxType = false; + if (fnCall->numActuals() >= 2) { + auto idxTypeQt = actualTypes[1]; + if (!idxTypeQt.isType()) { + erroredForIdxType = true; + wr.message("However, the second actual ('idxType') was ", + decayToValue(idxTypeQt), " rather than a type as required."); + wr.code(fnCall, {fnCall->actual(1)}); + } + } + if (fnCall->numActuals() == 3) { + auto stridesQt = actualTypes[2]; + wr.message((erroredForIdxType ? "Additionally" : "However"), + ", the third actual ('strides') was ", decayToValue(stridesQt), + " rather than a 'strideKind' as required."); + wr.code(fnCall, {fnCall->actual(2)}); + } + } else { + wr.message("However, 'domain' here had too many actuals."); + } +} + void ErrorInvalidIndexCall::write(ErrorWriterBase& wr) const { auto fnCall = std::get(info_); auto& type = std::get(info_); diff --git a/frontend/lib/resolution/resolution-queries.cpp b/frontend/lib/resolution/resolution-queries.cpp index 9c9d9797b085..0bee98a575a0 100644 --- a/frontend/lib/resolution/resolution-queries.cpp +++ b/frontend/lib/resolution/resolution-queries.cpp @@ -1006,7 +1006,8 @@ const ResolvedFields& resolveForwardingExprs(Context* context, static bool typeUsesForwarding(Context* context, const Type* receiverType) { if (auto ct = receiverType->getCompositeType()) { - if (ct->isBasicClassType() || ct->isRecordType() || ct->isUnionType()) { + if (ct->isBasicClassType() || ct->isRecordType() || ct->isDomainType() || + ct->isUnionType()) { ID ctId = ct->id(); if (!ctId.isEmpty()) { return parsing::aggregateUsesForwarding(context, ctId); @@ -1638,6 +1639,8 @@ typeConstructorInitialQuery(Context* context, const Type* t) idTag = uast::asttags::Class; } else if (t->isRecordType()) { idTag = uast::asttags::Record; + } else if (t->isDomainType()) { + idTag = uast::asttags::Record; } else if (t->isUnionType()) { idTag = uast::asttags::Union; } @@ -3976,24 +3979,7 @@ static bool resolveFnCallSpecialType(Context* context, // the type. // // TODO: sync, single - if (ci.name() == "domain") { - // TODO: a compiler-generated type constructor would be simpler, but we - // don't support default values on compiler-generated methods because the - // default values require existing AST. - - // Note: 'dmapped' is treated like a binary operator at the moment, so - // we don't need to worry about distribution type for 'domain(...)' exprs. - - // Transform domain type expressions like `domain(arg1, ...)` into: - // _domain.static_type(arg1, ...) - auto genericDom = DomainType::getGenericDomainType(context); - auto recv = QualifiedType(QualifiedType::TYPE, genericDom); - auto typeCtorName = UniqueString::get(context, "static_type"); - auto ctorCall = CallInfo::createWithReceiver(ci, recv, typeCtorName); - - result = resolveCall(rc, call, ctorCall, inScopes); - return true; - } else if (ci.name() == "atomic") { + if (ci.name() == "atomic") { auto newName = UniqueString::get(context, "chpl__atomicType"); auto ctorCall = CallInfo::copyAndRename(ci, newName); result = resolveCall(rc, call, ctorCall, inScopes); diff --git a/frontend/lib/resolution/resolution-types.cpp b/frontend/lib/resolution/resolution-types.cpp index e136683c58f2..b0abde2db5af 100644 --- a/frontend/lib/resolution/resolution-types.cpp +++ b/frontend/lib/resolution/resolution-types.cpp @@ -1175,6 +1175,8 @@ const char* AssociatedAction::kindToString(Action a) { return "infer-type"; case COMPARE: return "compare"; + case RUNTIME_TYPE: + return "runtime-type"; // no default to get a warning if new Actions are added } diff --git a/frontend/lib/resolution/return-type-inference.cpp b/frontend/lib/resolution/return-type-inference.cpp index b9f38d680b12..2781c698b4d9 100644 --- a/frontend/lib/resolution/return-type-inference.cpp +++ b/frontend/lib/resolution/return-type-inference.cpp @@ -161,12 +161,11 @@ const CompositeType* helpGetTypeForDecl(Context* context, insnFromBct, std::move(filteredSubs)); } else if (auto r = ad->toRecord()) { - if (r->id().symbolPath() == "ChapelDomain._domain") { + if (r->id() == DomainType::getGenericDomainType(context)->id()) { ret = DomainType::getGenericDomainType(context); - // TODO: update this to call a method on ArrayType to get the id or path - } else if (r->id().symbolPath() == "ChapelArray._array") { + } else if (r->id() == ArrayType::getGenericArrayType(context)->id()) { ret = ArrayType::getGenericArrayType(context); - } else if (r->id().symbolPath() == "ChapelLocale._locale") { + } else if (r->id() == CompositeType::getLocaleType(context)->id()) { ret = CompositeType::getLocaleType(context); } else { const RecordType* insnFromRec = nullptr; @@ -975,34 +974,6 @@ static bool helpComputeCompilerGeneratedReturnType(Context* context, result = QualifiedType(QualifiedType::REF, ft.type()); } return true; - } else if (untyped->isMethod() && sig->formalType(0).type()->isDomainType()) { - auto dt = sig->formalType(0).type()->toDomainType(); - - if (untyped->name() == "idxType") { - result = dt->idxType(); - } else if (untyped->name() == "rank") { - // Can't use `RankType::rank` because `D.rank` is defined for associative - // domains, even though they don't have a matching substitution. - result = QualifiedType(QualifiedType::PARAM, - IntType::get(context, 64), - IntParam::get(context, dt->rankInt())); - } else if (untyped->name() == "stridable") { - result = dt->stridable(); - } else if (untyped->name() == "parSafe") { - result = dt->parSafe(); - } else if (untyped->name() == "isRectangular") { - auto val = BoolParam::get(context, dt->kind() == DomainType::Kind::Rectangular); - auto type = BoolType::get(context); - result = QualifiedType(QualifiedType::PARAM, type, val); - } else if (untyped->name() == "isAssociative") { - auto val = BoolParam::get(context, dt->kind() == DomainType::Kind::Associative); - auto type = BoolType::get(context); - result = QualifiedType(QualifiedType::PARAM, type, val); - } else { - CHPL_ASSERT(false && "unhandled compiler-generated domain method"); - return true; - } - return true; } else if (untyped->isMethod() && sig->formalType(0).type()->isArrayType()) { auto at = sig->formalType(0).type()->toArrayType(); diff --git a/frontend/lib/types/ArrayType.cpp b/frontend/lib/types/ArrayType.cpp index 24d836a60da9..9b66b0d96000 100644 --- a/frontend/lib/types/ArrayType.cpp +++ b/frontend/lib/types/ArrayType.cpp @@ -46,7 +46,8 @@ void ArrayType::stringify(std::ostream& ss, } static ID getArrayID(Context* context) { - return parsing::getSymbolFromTopLevelModule(context, "ChapelArray", "_array"); + return parsing::getSymbolIdFromTopLevelModule(context, "ChapelArray", + "_array"); } const owned& @@ -61,8 +62,8 @@ ArrayType::getArrayTypeQuery(Context* context, ID id, UniqueString name, const ArrayType* ArrayType::getGenericArrayType(Context* context) { - auto name = UniqueString::get(context, "_array"); auto id = getArrayID(context); + auto name = id.symbolName(context); SubstitutionsMap subs; const ArrayType* instantiatedFrom = nullptr; return getArrayTypeQuery(context, id, name, instantiatedFrom, subs).get(); @@ -75,8 +76,8 @@ ArrayType::getArrayType(Context* context, SubstitutionsMap subs; subs.emplace(ArrayType::domainId, domainType); subs.emplace(ArrayType::eltTypeId, eltType); - auto name = UniqueString::get(context, "_array"); auto id = getArrayID(context); + auto name = id.symbolName(context); auto instantiatedFrom = getGenericArrayType(context); return getArrayTypeQuery(context, id, name, instantiatedFrom, subs).get(); } diff --git a/frontend/lib/types/BasicClassType.cpp b/frontend/lib/types/BasicClassType.cpp index 0abbe356fd79..0c5d9ce0fe96 100644 --- a/frontend/lib/types/BasicClassType.cpp +++ b/frontend/lib/types/BasicClassType.cpp @@ -44,7 +44,7 @@ BasicClassType::get(Context* context, ID id, UniqueString name, const BasicClassType* parentType, const BasicClassType* instantiatedFrom, SubstitutionsMap subs) { - // getObjectType should be used to construct object + // getRootClassType should be used to construct RootClass // everything else should have a parent type. CHPL_ASSERT(parentType != nullptr); return getBasicClassType(context, id, name, @@ -65,8 +65,8 @@ BasicClassType::getRootClassType(Context* context) { const BasicClassType* BasicClassType::getReduceScanOpType(Context* context) { - auto name = UniqueString::get(context, "ReduceScanOp"); - auto id = parsing::getSymbolFromTopLevelModule(context, "ChapelReduce", "ReduceScanOp"); + auto [id, name] = parsing::getSymbolFromTopLevelModule( + context, "ChapelReduce", "ReduceScanOp"); auto objectType = getRootClassType(context); return getBasicClassType(context, id, name, diff --git a/frontend/lib/types/CPtrType.cpp b/frontend/lib/types/CPtrType.cpp index d1e6116a029b..f510086be5d0 100644 --- a/frontend/lib/types/CPtrType.cpp +++ b/frontend/lib/types/CPtrType.cpp @@ -82,7 +82,8 @@ const CPtrType* CPtrType::getCVoidPtrType(Context* context) { const ID& CPtrType::getId(Context* context) { QUERY_BEGIN(getId, context); - ID result = parsing::getSymbolFromTopLevelModule(context, "CTypes", "c_ptr"); + ID result = + parsing::getSymbolIdFromTopLevelModule(context, "CTypes", "c_ptr"); return QUERY_END(result); } @@ -97,7 +98,8 @@ const CPtrType* CPtrType::withoutConst(Context* context) const { const ID& CPtrType::getConstId(Context* context) { QUERY_BEGIN(getConstId, context); - ID result = parsing::getSymbolFromTopLevelModule(context, "CTypes", "c_ptrConst"); + ID result = + parsing::getSymbolIdFromTopLevelModule(context, "CTypes", "c_ptrConst"); return QUERY_END(result); } diff --git a/frontend/lib/types/CompositeType.cpp b/frontend/lib/types/CompositeType.cpp index f8af87b92e62..87220933d4eb 100644 --- a/frontend/lib/types/CompositeType.cpp +++ b/frontend/lib/types/CompositeType.cpp @@ -26,7 +26,10 @@ #include "chpl/types/BasicClassType.h" #include "chpl/types/ClassType.h" #include "chpl/types/ClassTypeDecorator.h" +#include "chpl/types/CPtrType.h" +#include "chpl/types/DomainType.h" #include "chpl/types/RecordType.h" +#include "chpl/types/TupleType.h" #include "chpl/uast/Decl.h" #include "chpl/uast/NamedDecl.h" @@ -156,55 +159,67 @@ void CompositeType::stringify(std::ostream& ss, } const RecordType* CompositeType::getStringType(Context* context) { - auto name = UniqueString::get(context, "_string"); - auto id = parsing::getSymbolFromTopLevelModule(context, "String", "_string"); + auto [id, name] = + parsing::getSymbolFromTopLevelModule(context, "String", "_string"); return RecordType::get(context, id, name, - /* instantiatedFrom */ nullptr, - SubstitutionsMap()); + /* instantiatedFrom */ nullptr, SubstitutionsMap()); } const RecordType* CompositeType::getRangeType(Context* context) { - auto name = UniqueString::get(context, "_range"); - auto id = parsing::getSymbolFromTopLevelModule(context, "ChapelRange", "_range"); + auto [id, name] = + parsing::getSymbolFromTopLevelModule(context, "ChapelRange", "_range"); return RecordType::get(context, id, name, - /* instantiatedFrom */ nullptr, - SubstitutionsMap()); + /* instantiatedFrom */ nullptr, SubstitutionsMap()); } const RecordType* CompositeType::getBytesType(Context* context) { - auto name = UniqueString::get(context, "_bytes"); - auto id = parsing::getSymbolFromTopLevelModule(context, "Bytes", "_bytes"); + auto [id, name] = + parsing::getSymbolFromTopLevelModule(context, "Bytes", "_bytes"); return RecordType::get(context, id, name, - /* instantiatedFrom */ nullptr, - SubstitutionsMap()); + /* instantiatedFrom */ nullptr, SubstitutionsMap()); } const RecordType* CompositeType::getLocaleType(Context* context) { - auto name = UniqueString::get(context, "_locale"); - auto id = parsing::getSymbolFromTopLevelModule(context, "ChapelLocale", "_locale"); + auto [id, name] = + parsing::getSymbolFromTopLevelModule(context, "ChapelLocale", "_locale"); return RecordType::get(context, id, name, /* instantiatedFrom */ nullptr, SubstitutionsMap()); } const RecordType* CompositeType::getLocaleIDType(Context* context) { - auto name = UniqueString::get(context, "chpl_localeID_t"); auto id = ID(); + auto name = UniqueString::get(context, "chpl_localeID_t"); return RecordType::get(context, id, name, /* instantiatedFrom */ nullptr, SubstitutionsMap()); } +const RecordType* CompositeType::getDistributionType(Context* context) { + auto [id, name] = parsing::getSymbolFromTopLevelModule( + context, "ChapelDistribution", "_distribution"); + return RecordType::get(context, id, name, + /* instantiatedFrom */ nullptr, SubstitutionsMap()); +} + +static const ID getOwnedRecordId(Context* context) { + return parsing::getSymbolIdFromTopLevelModule(context, "OwnedObject", + "_owned"); +} + +static const ID getSharedRecordId(Context* context) { + return parsing::getSymbolIdFromTopLevelModule(context, "SharedObject", + "_shared"); +} + static const RecordType* tryCreateManagerRecord(Context* context, - const char* moduleName, - const char* recordName, + const ID& recordId, const BasicClassType* bct) { const RecordType* instantiatedFrom = nullptr; SubstitutionsMap subs; if (bct != nullptr) { instantiatedFrom = tryCreateManagerRecord(context, - moduleName, - recordName, + recordId, /*bct*/ nullptr); auto fields = fieldsForTypeDecl(context, @@ -225,21 +240,20 @@ static const RecordType* tryCreateManagerRecord(Context* context, } } - auto name = UniqueString::get(context, recordName); - auto id = parsing::getSymbolFromTopLevelModule(context, moduleName, recordName); - return RecordType::get(context, id, name, + auto name = recordId.symbolName(context); + return RecordType::get(context, recordId, name, instantiatedFrom, std::move(subs)); } const RecordType* CompositeType::getOwnedRecordType(Context* context, const BasicClassType* bct) { - return tryCreateManagerRecord(context, "OwnedObject", "_owned", bct); + return tryCreateManagerRecord(context, getOwnedRecordId(context), bct); } const RecordType* CompositeType::getSharedRecordType(Context* context, const BasicClassType* bct) { - return tryCreateManagerRecord(context, "SharedObject", "_shared", bct); + return tryCreateManagerRecord(context, getSharedRecordId(context), bct); } bool CompositeType::isMissingBundledType(Context* context, ID id) { @@ -250,13 +264,14 @@ bool CompositeType::isMissingBundledType(Context* context, ID id) { bool CompositeType::isMissingBundledRecordType(Context* context, ID id) { bool noLibrary = parsing::bundledModulePath(context).isEmpty(); if (noLibrary) { - auto path = id.symbolPath(); - return path == "String._string" || - path == "ChapelRange._range" || - path == "ChapelTuple._tuple" || - path == "Bytes._bytes" || - path == "OwnedObject._owned" || - path == "SharedObject._shared"; + return id == CompositeType::getStringType(context)->id() || + id == CompositeType::getRangeType(context)->id() || + id == TupleType::getGenericTupleType(context)->id() || + id == CompositeType::getBytesType(context)->id() || + id == CompositeType::getDistributionType(context)->id() || + id == DomainType::getGenericDomainType(context)->id() || + id == getOwnedRecordId(context) || + id == getSharedRecordId(context); } return false; @@ -265,19 +280,18 @@ bool CompositeType::isMissingBundledRecordType(Context* context, ID id) { bool CompositeType::isMissingBundledClassType(Context* context, ID id) { bool noLibrary = parsing::bundledModulePath(context).isEmpty(); if (noLibrary) { - auto path = id.symbolPath(); - return path == "ChapelReduce.ReduceScanOp" || - path == "Errors.Error" || - path == "CTypes.c_ptr" || - path == "CTypes.c_ptrConst"; + return id == BasicClassType::getReduceScanOpType(context)->id() || + id == CompositeType::getErrorType(context)->basicClassType()->id() || + id == CPtrType::getId(context) || + id == CPtrType::getConstId(context); } return false; } const ClassType* CompositeType::getErrorType(Context* context) { - auto name = UniqueString::get(context, "Error"); - auto id = parsing::getSymbolFromTopLevelModule(context, "Errors", "Error"); + auto [id, name] = + parsing::getSymbolFromTopLevelModule(context, "Errors", "Error"); auto dec = ClassTypeDecorator(ClassTypeDecorator::GENERIC_NONNIL); auto bct = BasicClassType::get(context, id, name, diff --git a/frontend/lib/types/DomainType.cpp b/frontend/lib/types/DomainType.cpp index e536f0f9f831..f51adb8e5e7f 100644 --- a/frontend/lib/types/DomainType.cpp +++ b/frontend/lib/types/DomainType.cpp @@ -21,6 +21,7 @@ #include "chpl/framework/query-impl.h" #include "chpl/parsing/parsing-queries.h" +#include "chpl/resolution/resolution-queries.h" #include "chpl/resolution/intents.h" #include "chpl/types/Param.h" #include "chpl/types/TupleType.h" @@ -37,7 +38,7 @@ void DomainType::stringify(std::ostream& ss, ss << ","; idxType().type()->stringify(ss, stringKind); ss << ","; - stridable().param()->stringify(ss, stringKind); + strides().param()->stringify(ss, stringKind); ss << ")"; } else if (kind_ == Kind::Associative) { ss << "domain("; @@ -53,7 +54,8 @@ void DomainType::stringify(std::ostream& ss, } static ID getDomainID(Context* context) { - return parsing::getSymbolFromTopLevelModule(context, "ChapelDomain", "_domain"); + return parsing::getSymbolIdFromTopLevelModule(context, "ChapelDomain", + "_domain"); } const owned& @@ -69,8 +71,8 @@ DomainType::getDomainType(Context* context, ID id, UniqueString name, const DomainType* DomainType::getGenericDomainType(Context* context) { - auto name = UniqueString::get(context, "_domain"); auto id = getDomainID(context); + auto name = id.symbolName(context); SubstitutionsMap subs; const DomainType* instantiatedFrom = nullptr; return getDomainType(context, id, name, instantiatedFrom, subs).get(); @@ -78,18 +80,44 @@ DomainType::getGenericDomainType(Context* context) { const DomainType* DomainType::getRectangularType(Context* context, + const QualifiedType& instance, const QualifiedType& rank, const QualifiedType& idxType, - const QualifiedType& stridable) { + const QualifiedType& strides) { + auto genericDomain = getGenericDomainType(context); + SubstitutionsMap subs; + CHPL_ASSERT(rank.isParam() && rank.param()->isIntParam()); subs.emplace(ID(UniqueString(), 0, 0), rank); + CHPL_ASSERT(idxType.isType()); subs.emplace(ID(UniqueString(), 1, 0), idxType); - subs.emplace(ID(UniqueString(), 2, 0), stridable); + CHPL_ASSERT(strides.isParam() && strides.param()->isEnumParam() && + strides.param()->toEnumParam()->value().id.symbolPath() == + "ChapelRange.strideKind"); + subs.emplace(ID(UniqueString(), 2, 0), strides); + + + // Add substitution for _instance field + auto& rf = fieldsForTypeDecl(context, genericDomain, + resolution::DefaultsPolicy::IGNORE_DEFAULTS, + /* syntaxOnly */ true); + ID instanceFieldId; + for (int i = 0; i < rf.numFields(); i++) { + if (rf.fieldName(i) == USTR("_instance")) { + instanceFieldId = rf.fieldDeclId(i); + break; + } + } + if (instanceFieldId.isEmpty()) { + CHPL_ASSERT(isMissingBundledRecordType(context, genericDomain->id())); + instanceFieldId = ID(USTR("_instance"), 0, 0); + } + subs.emplace(instanceFieldId, instance); + auto name = UniqueString::get(context, "_domain"); auto id = getDomainID(context); - auto instantiatedFrom = getGenericDomainType(context); - return getDomainType(context, id, name, instantiatedFrom, subs, - DomainType::Kind::Rectangular).get(); + return getDomainType(context, id, name, /* instantiatedFrom*/ genericDomain, + subs, DomainType::Kind::Rectangular).get(); } const DomainType* @@ -97,6 +125,7 @@ DomainType::getAssociativeType(Context* context, const QualifiedType& idxType, const QualifiedType& parSafe) { SubstitutionsMap subs; + // TODO: assert validity of sub types subs.emplace(ID(UniqueString(), 0, 0), idxType); subs.emplace(ID(UniqueString(), 1, 0), parSafe); auto name = UniqueString::get(context, "_domain"); @@ -106,6 +135,25 @@ DomainType::getAssociativeType(Context* context, DomainType::Kind::Associative).get(); } +const QualifiedType& DomainType::getDefaultDistType(Context* context) { + QUERY_BEGIN(getDefaultDistType, context); + + QualifiedType result; + + if (auto mod = parsing::getToplevelModule( + context, UniqueString::get(context, "DefaultRectangular"))) { + for (auto stmt : mod->children()) { + auto decl = stmt->toNamedDecl(); + if (decl && decl->name() == USTR("defaultDist")) { + auto res = resolution::resolveModuleStmt(context, stmt->id()); + result = res.byId(stmt->id()).type(); + } + } + } + + return QUERY_END(result); +} + int DomainType::rankInt() const { if (kind_ == Kind::Rectangular) { return rank().param()->toIntParam()->value(); diff --git a/frontend/lib/types/EnumType.cpp b/frontend/lib/types/EnumType.cpp index 52cd457f43c6..7a4eb12e98b0 100644 --- a/frontend/lib/types/EnumType.cpp +++ b/frontend/lib/types/EnumType.cpp @@ -61,14 +61,12 @@ const EnumType* EnumType::get(Context* context, ID id, UniqueString name) { } const EnumType* EnumType::getBoundKindType(Context* context) { - auto name = UniqueString::get(context, "boundKind"); - auto id = parsing::getSymbolFromTopLevelModule(context, "ChapelRange", "boundKind"); + auto [id, name] = parsing::getSymbolFromTopLevelModule(context, "ChapelRange", "boundKind"); return EnumType::get(context, id, name); } const EnumType* EnumType::getIterKindType(Context* context) { - auto name = UniqueString::get(context, "iterKind"); - auto id = parsing::getSymbolFromTopLevelModule(context, "ChapelBase", "iterKind"); + auto [id, name] = parsing::getSymbolFromTopLevelModule(context, "ChapelBase", "iterKind"); return EnumType::get(context, id, name); } diff --git a/frontend/lib/types/TupleType.cpp b/frontend/lib/types/TupleType.cpp index 73ed247e3716..5e3fa7f865cf 100644 --- a/frontend/lib/types/TupleType.cpp +++ b/frontend/lib/types/TupleType.cpp @@ -100,8 +100,8 @@ TupleType::getTupleType(Context* context, const TupleType* instantiatedFrom, QUERY_BEGIN(getTupleType, context, instantiatedFrom, subs, isVarArgTuple); - auto name = UniqueString::get(context, "_tuple"); - auto id = parsing::getSymbolFromTopLevelModule(context, "ChapelTuple", "_tuple"); + auto [id, name] = + parsing::getSymbolFromTopLevelModule(context, "ChapelTuple", "_tuple"); auto result = toOwned(new TupleType(id, name, instantiatedFrom, std::move(subs), isVarArgTuple)); diff --git a/frontend/test/resolution/testArrays.cpp b/frontend/test/resolution/testArrays.cpp index c1b81d02ceee..32630f8e5e90 100644 --- a/frontend/test/resolution/testArrays.cpp +++ b/frontend/test/resolution/testArrays.cpp @@ -43,8 +43,11 @@ static QualifiedType findVarType(const Module* m, static void testArray(std::string domainType, std::string eltType) { - Context ctx; + Context::Configuration config; + config.chplHome = getenv("CHPL_HOME"); + Context ctx(config); Context* context = &ctx; + setupModuleSearchPaths(context, false, false, {}, {}); ErrorGuard guard(context); // a different element type from the one we were given @@ -53,7 +56,7 @@ static void testArray(std::string domainType, altElt = "string"; } - std::string program = DomainModule + ArrayModule + + std::string program = ArrayModule + R"""( module M { use ChapelArray; @@ -94,7 +97,7 @@ module M { setFileText(context, path, std::move(program)); const ModuleVec& vec = parseToplevel(context, path); - const Module* m = vec[2]; + const Module* m = vec[1]; const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); @@ -137,7 +140,7 @@ module M { assert(call->byAst(ETGood).type().type() == AType.type()); } - assert(guard.errors().size() == 0); + assert(guard.realizeErrors() == 0); std::string arrayText; arrayText += "[" + domainType + "] " + eltType; @@ -148,7 +151,9 @@ int main() { testArray("domain(1)", "int"); testArray("domain(1)", "string"); testArray("domain(2)", "int"); - testArray("domain(int)", "int"); + + // TODO: re-enable once associative domains are working + // testArray("domain(int)", "int"); return 0; } diff --git a/frontend/test/resolution/testDomains.cpp b/frontend/test/resolution/testDomains.cpp index 14dd1fc1a6a9..28c9b5b17fce 100644 --- a/frontend/test/resolution/testDomains.cpp +++ b/frontend/test/resolution/testDomains.cpp @@ -18,7 +18,6 @@ */ #include "test-resolution.h" -#include "test-minimal-modules.h" #include "chpl/parsing/parsing-queries.h" #include "chpl/resolution/resolution-queries.h" @@ -37,19 +36,18 @@ static QualifiedType findVarType(const Module* m, return rr.byAst(var).type(); } -static void testRectangular(std::string domainType, +static void testRectangular(Context* context, + std::string domainType, int rank, std::string idxType, - bool stridable) { - Context ctx; - Context* context = &ctx; + std::string strides) { + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); ErrorGuard guard(context); - std::string program = DomainModule + + std::string program = R"""( module M { - use ChapelDomain; - var d : )""" + domainType + R"""(; param rg = )""" + std::to_string(rank) + R"""(; type ig = )""" + idxType + R"""(; @@ -57,11 +55,11 @@ module M { param r = d.rank; type i = d.idxType; - param s = d.stridable; + param s = d.strides; param rk = d.isRectangular(); param ak = d.isAssociative(); - var p = d.pid(); + var p = d.pid; for loopI in d { var z = loopI; @@ -86,19 +84,40 @@ module M { setFileText(context, path, std::move(program)); const ModuleVec& vec = parseToplevel(context, path); - const Module* m = vec[1]; + const Module* m = vec[0]; const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); - QualifiedType dType = findVarType(m, rr, "d"); + const Variable* d = m->stmt(0)->toVariable(); + assert(d); + assert(d->name() == "d"); + + QualifiedType dQt = rr.byAst(d).type(); + assert(dQt.type()); + auto dType = dQt.type()->toDomainType(); + assert(dType); + + auto dTypeExpr = d->typeExpression(); + assert(dTypeExpr); + auto typeRe = rr.byAst(dTypeExpr); + auto& aa = typeRe.associatedActions()[0]; + assert(!aa.id().isEmpty()); + assert(aa.action() == AssociatedAction::RUNTIME_TYPE); QualifiedType fullIndexType = findVarType(m, rr, "fullIndex"); + (void)fullIndexType; - assert(findVarType(m, rr, "r").param()->toIntParam()->value() == rank); + auto rankVarTy = findVarType(m, rr, "r"); + assert(rankVarTy == dType->rank()); + assert(rankVarTy.param()->toIntParam()->value() == rank); - assert(findVarType(m, rr, "ig") == findVarType(m, rr, "i")); + auto idxTypeVarTy = findVarType(m, rr, "i"); + assert(idxTypeVarTy == dType->idxType()); + assert(findVarType(m, rr, "ig") == idxTypeVarTy); - assert(findVarType(m, rr, "s").param()->toBoolParam()->value() == stridable); + auto stridesVarTy = findVarType(m, rr, "s"); + assert(stridesVarTy == dType->strides()); + assert(stridesVarTy.param()->toEnumParam()->value().str == strides); assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == true); @@ -118,7 +137,7 @@ module M { assert(call->signature()->instantiatedFrom() != nullptr); const Variable* GT = findOnlyNamed(m, "GT")->toVariable(); - assert(call->byAst(GT).type().type() == dType.type()); + assert(call->byAst(GT).type().type() == dType); } { @@ -131,121 +150,120 @@ module M { assert(call->signature()->instantiatedFrom() == nullptr); const Variable* CT = findOnlyNamed(m, "CT")->toVariable(); - assert(call->byAst(CT).type().type() == dType.type()); + assert(call->byAst(CT).type().type() == dType); } - assert(guard.errors().size() == 0); + assert(guard.realizeErrors() == 0); printf("Success: %s\n", domainType.c_str()); } -static void testAssociative(std::string domainType, - std::string idxType, - bool parSafe) { - Context ctx; - Context* context = &ctx; - ErrorGuard guard(context); +// static void testAssociative(Context* context, +// std::string domainType, +// std::string idxType, +// bool parSafe) { +// context->advanceToNextRevision(false); +// setupModuleSearchPaths(context, false, false, {}, {}); +// ErrorGuard guard(context); - std::string program = DomainModule + -R"""( -module M { - use ChapelDomain; - - var d : )""" + domainType + R"""(; - type ig = )""" + idxType + R"""(; +// std::string program = +// R"""( +// module M { +// var d : )""" + domainType + R"""(; +// type ig = )""" + idxType + R"""(; - type i = d.idxType; - param s = d.parSafe; - param rk = d.isRectangular(); - param ak = d.isAssociative(); +// type i = d.idxType; +// param s = d.parSafe; +// param rk = d.isRectangular(); +// param ak = d.isAssociative(); - var p = d.pid(); +// var p = d.pid(); - for loopI in d { - var z = loopI; - } +// for loopI in d { +// var z = loopI; +// } - proc generic(arg: domain) { - type GT = arg.type; - return 42; - } +// proc generic(arg: domain) { +// type GT = arg.type; +// return 42; +// } - proc concrete(arg: )""" + domainType + R"""() { - type CT = arg.type; - return 42; - } +// proc concrete(arg: )""" + domainType + R"""() { +// type CT = arg.type; +// return 42; +// } - var g_ret = generic(d); - var c_ret = concrete(d); -} -)"""; - // TODO: generic checks +// var g_ret = generic(d); +// var c_ret = concrete(d); +// } +// )"""; +// // TODO: generic checks - auto path = UniqueString::get(context, "input.chpl"); - setFileText(context, path, std::move(program)); +// auto path = UniqueString::get(context, "input.chpl"); +// setFileText(context, path, std::move(program)); - const ModuleVec& vec = parseToplevel(context, path); - const Module* m = vec[1]; +// const ModuleVec& vec = parseToplevel(context, path); +// const Module* m = vec[1]; - const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); +// const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); - QualifiedType dType = findVarType(m, rr, "d"); +// QualifiedType dType = findVarType(m, rr, "d"); +// assert(dType.type()->isDomainType()); - auto fullIndexType = findVarType(m, rr, "i"); - assert(findVarType(m, rr, "ig") == fullIndexType); +// auto fullIndexType = findVarType(m, rr, "i"); +// assert(findVarType(m, rr, "ig") == fullIndexType); - assert(findVarType(m, rr, "s").param()->toBoolParam()->value() == parSafe); +// assert(findVarType(m, rr, "s").param()->toBoolParam()->value() == parSafe); - assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == false); +// assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == false); - assert(findVarType(m, rr, "ak").param()->toBoolParam()->value() == true); +// assert(findVarType(m, rr, "ak").param()->toBoolParam()->value() == true); - assert(findVarType(m, rr, "p").type() == IntType::get(context, 0)); +// assert(findVarType(m, rr, "p").type() == IntType::get(context, 0)); - assert(findVarType(m, rr, "z").type() == fullIndexType.type()); +// assert(findVarType(m, rr, "z").type() == fullIndexType.type()); - { - const Variable* g_ret = findOnlyNamed(m, "g_ret")->toVariable(); - auto res = rr.byAst(g_ret); - assert(res.type().type()->isIntType()); +// { +// const Variable* g_ret = findOnlyNamed(m, "g_ret")->toVariable(); +// auto res = rr.byAst(g_ret); +// assert(res.type().type()->isIntType()); - auto call = resolveOnlyCandidate(context, rr.byAst(g_ret->initExpression())); - // Generic function, should have been instantiated - assert(call->signature()->instantiatedFrom() != nullptr); +// auto call = resolveOnlyCandidate(context, rr.byAst(g_ret->initExpression())); +// // Generic function, should have been instantiated +// assert(call->signature()->instantiatedFrom() != nullptr); - const Variable* GT = findOnlyNamed(m, "GT")->toVariable(); - assert(call->byAst(GT).type().type() == dType.type()); - } +// const Variable* GT = findOnlyNamed(m, "GT")->toVariable(); +// assert(call->byAst(GT).type().type() == dType.type()); +// } - { - const Variable* c_ret = findOnlyNamed(m, "c_ret")->toVariable(); - auto res = rr.byAst(c_ret); - assert(res.type().type()->isIntType()); +// { +// const Variable* c_ret = findOnlyNamed(m, "c_ret")->toVariable(); +// auto res = rr.byAst(c_ret); +// assert(res.type().type()->isIntType()); - auto call = resolveOnlyCandidate(context, rr.byAst(c_ret->initExpression())); - // Concrete function, should not be instantiated - assert(call->signature()->instantiatedFrom() == nullptr); +// auto call = resolveOnlyCandidate(context, rr.byAst(c_ret->initExpression())); +// // Concrete function, should not be instantiated +// assert(call->signature()->instantiatedFrom() == nullptr); - const Variable* CT = findOnlyNamed(m, "CT")->toVariable(); - assert(call->byAst(CT).type().type() == dType.type()); - } +// const Variable* CT = findOnlyNamed(m, "CT")->toVariable(); +// assert(call->byAst(CT).type().type() == dType.type()); +// } - assert(guard.errors().size() == 0); +// assert(guard.errors().size() == 0); - printf("Success: %s\n", domainType.c_str()); -} +// printf("Success: %s\n", domainType.c_str()); +// } -static void testBadPass(std::string argType, std::string actualType) { +static void testBadPass(Context* context, std::string argType, + std::string actualType) { // Ensure that we can't, e.g., pass a domain(1) to a domain(2) - Context ctx; - Context* context = &ctx; + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); ErrorGuard guard(context); - std::string program = DomainModule + + std::string program = R"""( module M { - use ChapelDomain; - proc foo(arg: )""" + argType + R"""() { return 42; } @@ -260,7 +278,7 @@ module M { setFileText(context, path, std::move(program)); const ModuleVec& vec = parseToplevel(context, path); - const Module* m = vec[1]; + const Module* m = vec[0]; const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); @@ -276,18 +294,16 @@ module M { guard.clearErrors(); } -static void testIndex(std::string domainType, +static void testIndex(Context* context, + std::string domainType, std::string expectedType) { - Context ctx; - Context* context = &ctx; + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); ErrorGuard guard(context); - std::string program = DomainModule + ArrayModule + + std::string program = R"""( module M { - use ChapelDomain; - use ChapelArray; - var d : )""" + domainType + R"""(; type t = )""" + expectedType + R"""(; type i = index(d); @@ -300,38 +316,121 @@ module M { setFileText(context, path, std::move(program)); const ModuleVec& vec = parseToplevel(context, path); - const Module* m = vec[2]; + const Module* m = vec[0]; const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); - findVarType(m, rr, "d").dump(); - findVarType(m, rr, "t").dump(); - findVarType(m, rr, "i").dump(); + + assert(!findVarType(m, rr, "d").isUnknownOrErroneous()); + assert(!findVarType(m, rr, "t").isUnknownOrErroneous()); + assert(!findVarType(m, rr, "i").isUnknownOrErroneous()); assert(findVarType(m, rr, "equal").isParamTrue()); + + // assert(guard.realizeErrors() == 0); + + printf("Success: index(%s) == %s\n", domainType.c_str(), + expectedType.c_str()); +} + +static void testBadDomainHelper(std::string domainType, Context* context, + ErrorGuard& guard) { + std::string program = +R"""( +module M { + var d : )""" + domainType + R"""(; +} +)"""; + + auto path = UniqueString::get(context, "input.chpl"); + setFileText(context, path, std::move(program)); + + const ModuleVec& vec = parseToplevel(context, path); + const Module* m = vec[0]; + + const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); + + const Variable* d = m->stmt(0)->toVariable(); + assert(d); + assert(d->name() == "d"); + QualifiedType dType = rr.byAst(d).type(); + assert(dType.type()->isErroneousType()); + + assert(guard.errors().size() == 1); + auto& e = guard.errors()[0]; + assert(e->type() == chpl::InvalidDomainCall); + + guard.clearErrors(); +} + +// Ensure we gracefully error for bad domain type expressions, with or without +// the standard modules available. +static void testBadDomain(Context* contextWithStd, std::string domainType) { + // Without standard modules + { + Context ctx; + Context* context = &ctx; + ErrorGuard guard(context); + + testBadDomainHelper(domainType, context, guard); + } + + // With standard modules + { + contextWithStd->advanceToNextRevision(false); + setupModuleSearchPaths(contextWithStd, false, false, {}, {}); + ErrorGuard guard(contextWithStd); + + testBadDomainHelper(domainType, contextWithStd, guard); + } + + printf("Success: cannot resolve %s\n", + domainType.c_str()); } int main() { - testRectangular("domain(1)", 1, "int", false); - testRectangular("domain(2)", 2, "int", false); - testRectangular("domain(1, stridable=true)", 1, "int", true); - testRectangular("domain(2, int(8))", 2, "int(8)", false); - testRectangular("domain(3, int(16), true)", 3, "int(16)", true); - testRectangular("domain(stridable=false, idxType=int, rank=1)", 1, "int", false); - - testAssociative("domain(int)", "int", true); - testAssociative("domain(int, false)", "int", false); - testAssociative("domain(string)", "string", true); - - testBadPass("domain(1)", "domain(2)"); - testBadPass("domain(int)", "domain(string)"); - testBadPass("domain(1)", "domain(int)"); - - testIndex("domain(1)", "int"); - testIndex("domain(2)", "2*int"); - testIndex("domain(1, bool)", "bool"); - testIndex("domain(2, bool)", "2*bool"); - testIndex("domain(int)", "int"); - testIndex("domain(string)", "string"); + // Set up context with standard modules, re-used between tests for + // performance. + auto ctx = buildStdContext(); + auto context = ctx.get(); + + testRectangular(context, "domain(1)", 1, "int", "one"); + testRectangular(context, "domain(2)", 2, "int", "one"); + testRectangular(context, "domain(1, strides=strideKind.one)", 1, "int", "one"); + testRectangular(context, "domain(2, int(8))", 2, "int(8)", "one"); + testRectangular(context, "domain(3, int(16), strideKind.negOne)", 3, "int(16)", "negOne"); + testRectangular(context, "domain(strides=strideKind.negative, idxType=int, rank=1)", 1, "int", "negative"); + context->collectGarbage(); + + // TODO: re-enable associative + // testAssociative(context, "domain(int)", "int", true); + // testAssociative(context, "domain(int, false)", "int", false); + // testAssociative(context, "domain(string)", "string", true); + // context->collectGarbage(); + + testBadPass(context, "domain(1)", "domain(2)"); + testBadPass(context, "domain(1, int(16))", "domain(1, int(8))"); + testBadPass(context, "domain(1, int(8))", "domain(1, int(16))"); + testBadPass(context, "domain(1, strides=strideKind.negOne)", "domain(1, strides=strideKind.one)"); + // TODO: re-enable associative badPass + // testBadPass(context, "domain(int)", "domain(string)"); + // testBadPass(context, "domain(1)", "domain(int)"); + context->collectGarbage(); + + testIndex(context, "domain(1)", "int"); + testIndex(context, "domain(2)", "2*int"); + testIndex(context, "domain(1, bool)", "bool"); + testIndex(context, "domain(2, bool)", "2*bool"); + // TODO: re-enable associative indexes + // testIndex(context, "domain(int)", "int"); + // testIndex(context, "domain(string)", "string"); + context->collectGarbage(); + + testBadDomain(context, "domain()"); + testBadDomain(context, "domain(1, 2, 3, 4)"); + testBadDomain(context, "domain(\"asdf\")"); + testBadDomain(context, "domain(\"asdf\", \"asdf2\")"); + testBadDomain(context, "domain(1, \"asdf\")"); + testBadDomain(context, "domain(1, int, \"asdf\")"); return 0; } diff --git a/frontend/test/test-minimal-modules.h b/frontend/test/test-minimal-modules.h index 1518222c9ea5..98c9c0a59e94 100644 --- a/frontend/test/test-minimal-modules.h +++ b/frontend/test/test-minimal-modules.h @@ -22,45 +22,6 @@ #include -static std::string DomainModule = -R"""( -module ChapelDomain { - record _domain { - var _pid: int; - var _instance; - var _unowned:bool; - } - - proc type _domain.static_type(param rank : int, type idxType=int, param stridable: bool = false) type { - return __primitive("static domain type", rank, idxType, stridable); - } - - proc type _domain.static_type(type idxType, param parSafe: bool = true) type { - return __primitive("static domain type", idxType, parSafe); - } - - proc computeIndexType(arg: domain) type { - if arg.isRectangular() { - if arg.rank == 1 then return arg.idxType; - else return arg.rank*arg.idxType; - } else { - return arg.idxType; - } - } - - iter _domain.these() { - var ret : computeIndexType(this); - yield ret; - } - - // Prove that fields and methods on '_domain' work - proc _domain.pid() { - return _pid; - } -} -)"""; - - static std::string ArrayModule = R"""( module ChapelArray { @@ -98,6 +59,12 @@ module ChapelArray { proc chpl__buildIndexType(d: domain) type do return chpl__buildIndexType(d.rank, d.idxType); + + param nullPid = -1; + + proc _isPrivatized(value) param do return false; + + record dmap { } } )""";