diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index ba44e4e70a197..a0159d8169d63 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -1218,7 +1218,9 @@ class RequirementRepr { void print(raw_ostream &OS) const; void print(ASTPrinter &Printer) const; }; - + +using GenericParamSource = PointerUnion; + /// GenericParamList - A list of generic parameters that is part of a generic /// function or type, along with extra requirements placed on those generic /// parameters and types derived from them. diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 9215622fdfc41..d9630eaf6d9e0 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1650,10 +1650,9 @@ ERROR(redundant_class_requirement,none, "redundant 'class' requirement", ()) ERROR(late_class_requirement,none, "'class' must come first in the requirement list", ()) -ERROR(where_without_generic_params,none, - "'where' clause cannot be attached to " - "%select{a non-generic|a protocol|an associated type}0 " - "declaration", (unsigned)) +ERROR(where_toplevel_nongeneric,none, + "'where' clause cannot be attached to non-generic " + "top-level declaration", ()) ERROR(where_inside_brackets,none, "'where' clause next to generic parameters is obsolete, " "must be written following the declaration's type", ()) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index cd83a44cfc04c..d382d0d942ef3 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2751,6 +2751,10 @@ ERROR(dynamic_self_stored_property_init,none, ERROR(dynamic_self_default_arg,none, "covariant 'Self' type cannot be referenced from a default argument expression", ()) +ERROR(where_nongeneric_ctx,none, + "'where' clause on non-generic member declaration requires a " + "generic context", ()) + //------------------------------------------------------------------------------ // MARK: Type Check Attributes //------------------------------------------------------------------------------ diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index ed67e20294ba9..0caa439d73f6a 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -1108,7 +1108,7 @@ class InferredGenericSignatureRequest : public SimpleRequest, SmallVector, bool), @@ -1124,7 +1124,7 @@ class InferredGenericSignatureRequest : evaluate(Evaluator &evaluator, ModuleDecl *module, GenericSignatureImpl *baseSignature, - GenericParamList *gpl, + GenericParamSource paramSource, SmallVector addedRequirements, SmallVector inferenceSources, bool allowConcreteGenericParams) const; diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index 2c3e517bb5a92..c4d0c823dcd35 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -83,7 +83,7 @@ SWIFT_REQUEST(TypeChecker, HasDynamicMemberLookupAttributeRequest, bool(CanType), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest, GenericSignature (ModuleDecl *, GenericSignatureImpl *, - GenericParamList *, + GenericParamSource, SmallVector, SmallVector, bool), Cached, NoLocationInfo) diff --git a/include/swift/Basic/SimpleDisplay.h b/include/swift/Basic/SimpleDisplay.h index d54b90cd8c1ec..3a553359091aa 100644 --- a/include/swift/Basic/SimpleDisplay.h +++ b/include/swift/Basic/SimpleDisplay.h @@ -135,6 +135,15 @@ namespace swift { } out << "}"; } + + template + void simple_display(llvm::raw_ostream &out, + const llvm::PointerUnion &ptrUnion) { + if (const auto t = ptrUnion.template dyn_cast()) + simple_display(out, t); + else + simple_display(out, ptrUnion.template get()); + } } #endif // SWIFT_BASIC_SIMPLE_DISPLAY_H diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h index be2a0aa3797a6..486855aa94339 100644 --- a/include/swift/Parse/Parser.h +++ b/include/swift/Parse/Parser.h @@ -1066,7 +1066,7 @@ class Parser { bool allowClassRequirement, bool allowAnyObject); ParserStatus parseDeclItem(bool &PreviousHadSemi, - Parser::ParseDeclOptions Options, + ParseDeclOptions Options, llvm::function_ref handler); std::pair, Optional> parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag, @@ -1634,14 +1634,10 @@ class Parser { void diagnoseWhereClauseInGenericParamList(const GenericParamList *GenericParams); - enum class WhereClauseKind : unsigned { - Declaration, - Protocol, - AssociatedType - }; ParserStatus - parseFreestandingGenericWhereClause(GenericParamList *GPList, - WhereClauseKind kind=WhereClauseKind::Declaration); + parseFreestandingGenericWhereClause(GenericContext *genCtx, + GenericParamList *&GPList, + ParseDeclOptions flags); ParserStatus parseGenericWhereClause( SourceLoc &WhereLoc, SmallVectorImpl &Requirements, diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 3ce2b4ebd3ee0..a0e3ce76366ca 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -4470,19 +4470,22 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base, assert(isa(base) || isa(base)); assert(isa(derived) || isa(derived)); - auto baseClass = base->getDeclContext()->getSelfClassDecl(); - auto derivedClass = derived->getDeclContext()->getSelfClassDecl(); + const auto baseClass = base->getDeclContext()->getSelfClassDecl(); + const auto derivedClass = derived->getDeclContext()->getSelfClassDecl(); assert(baseClass != nullptr); assert(derivedClass != nullptr); - auto baseGenericSig = base->getAsGenericContext()->getGenericSignature(); - auto derivedGenericSig = derived->getAsGenericContext()->getGenericSignature(); + const auto baseGenericSig = + base->getAsGenericContext()->getGenericSignature(); + const auto derivedGenericSig = + derived->getAsGenericContext()->getGenericSignature(); if (base == derived) return derivedGenericSig; - if (derivedClass->getSuperclass().isNull()) + const auto derivedSuperclass = derivedClass->getSuperclass(); + if (derivedSuperclass.isNull()) return nullptr; if (derivedGenericSig.isNull()) @@ -4491,12 +4494,6 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base, if (baseGenericSig.isNull()) return derivedGenericSig; - auto baseClassSig = baseClass->getGenericSignature(); - auto subMap = derivedClass->getSuperclass()->getContextSubstitutionMap( - derivedClass->getModuleContext(), baseClass); - - unsigned derivedDepth = 0; - auto key = OverrideSignatureKey(baseGenericSig, derivedGenericSig, derivedClass); @@ -4506,22 +4503,25 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base, return getImpl().overrideSigCache.lookup(key); } - if (auto derivedSig = derivedClass->getGenericSignature()) - derivedDepth = derivedSig->getGenericParams().back()->getDepth() + 1; + const auto derivedClassSig = derivedClass->getGenericSignature(); + + unsigned derivedDepth = 0; + unsigned baseDepth = 0; + if (derivedClassSig) + derivedDepth = derivedClassSig->getGenericParams().back()->getDepth() + 1; + if (const auto baseClassSig = baseClass->getGenericSignature()) + baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1; SmallVector addedGenericParams; - if (auto *gpList = derived->getAsGenericContext()->getGenericParams()) { + if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) { for (auto gp : *gpList) { addedGenericParams.push_back( gp->getDeclaredInterfaceType()->castTo()); } } - unsigned baseDepth = 0; - - if (baseClassSig) { - baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1; - } + const auto subMap = derivedSuperclass->getContextSubstitutionMap( + derivedClass->getModuleContext(), baseClass); auto substFn = [&](SubstitutableType *type) -> Type { auto *gp = cast(type); @@ -4553,7 +4553,7 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base, auto genericSig = evaluateOrDefault( evaluator, AbstractGenericSignatureRequest{ - derivedClass->getGenericSignature().getPointer(), + derivedClassSig.getPointer(), std::move(addedGenericParams), std::move(addedRequirements)}, GenericSignature()); diff --git a/lib/AST/ASTScopeLookup.cpp b/lib/AST/ASTScopeLookup.cpp index 8f165166e89a5..26f946fb0164e 100644 --- a/lib/AST/ASTScopeLookup.cpp +++ b/lib/AST/ASTScopeLookup.cpp @@ -194,8 +194,8 @@ bool ASTScopeImpl::doesContextMatchStartingContext( // For a SubscriptDecl with generic parameters, the call tries to do lookups // with startingContext equal to either the get or set subscript // AbstractFunctionDecls. Since the generic parameters are in the -// SubScriptDeclScope, and not the AbstractFunctionDecl scopes (after all how -// could one parameter be in two scopes?), GenericParamScoped intercepts the +// SubscriptDeclScope, and not the AbstractFunctionDecl scopes (after all how +// could one parameter be in two scopes?), GenericParamScope intercepts the // match query here and tests against the accessor DeclContexts. bool GenericParamScope::doesContextMatchStartingContext( const DeclContext *context) const { diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index e44d2dc513b10..b20291653cdbf 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -126,6 +126,31 @@ class Traversal : public ASTVisitorgetGenericParams()) { + visitGenericParamList(params); + } + return true; + } + return false; + } + + bool visitTrailingRequirements(GenericContext *GC) { + if (const auto Where = GC->getTrailingWhereClause()) { + for (auto &Req: Where->getRequirements()) + if (doIt(Req)) + return true; + } else if (!isa(GC)) { + if (const auto GP = GC->getGenericParams()) + for (auto Req: GP->getTrailingRequirements()) + if (doIt(Req)) + return true; + } + return false; + } + bool visitImportDecl(ImportDecl *ID) { return false; } @@ -138,12 +163,9 @@ class Traversal : public ASTVisitorgetTrailingWhereClause()) { - for(auto &Req: Where->getRequirements()) { - if (doIt(Req)) - return true; - } - } + if (visitTrailingRequirements(ED)) + return true; + for (Decl *M : ED->getMembers()) { if (doIt(M)) return true; @@ -223,15 +245,13 @@ class Traversal : public ASTVisitorgetGenericParams()) { - if (visitGenericParamList(TAD->getGenericParams())) - return true; - } + bool WalkGenerics = visitGenericParamListIfNeeded(TAD); if (auto typerepr = TAD->getUnderlyingTypeRepr()) if (doIt(typerepr)) return true; - return false; + + return WalkGenerics && visitTrailingRequirements(TAD); } bool visitOpaqueTypeDecl(OpaqueTypeDecl *OTD) { @@ -269,20 +289,9 @@ class Traversal : public ASTVisitor Reqs = None; - if (auto *Protocol = dyn_cast(NTD)) { - if (auto *WhereClause = Protocol->getTrailingWhereClause()) - Reqs = WhereClause->getRequirements(); - } else { - Reqs = NTD->getGenericParams()->getTrailingRequirements(); - } - for (auto Req: Reqs) { - if (doIt(Req)) - return true; - } - } - + if (WalkGenerics && visitTrailingRequirements(NTD)) + return true; + for (Decl *Member : NTD->getMembers()) { if (doIt(Member)) return true; @@ -325,13 +334,9 @@ class Traversal : public ASTVisitorgetElementTypeLoc())) return true; - if (WalkGenerics) { - // Visit generic requirements - for (auto Req : SD->getGenericParams()->getTrailingRequirements()) { - if (doIt(Req)) - return true; - } - } + // Visit trailing requirements + if (WalkGenerics && visitTrailingRequirements(SD)) + return true; if (!Walker.shouldWalkAccessorsTheOldWay()) { for (auto *AD : SD->getAllAccessors()) @@ -364,13 +369,9 @@ class Traversal : public ASTVisitorgetBodyResultTypeLoc())) return true; - if (WalkGenerics) { - // Visit trailing requirments - for (auto Req : AFD->getGenericParams()->getTrailingRequirements()) { - if (doIt(Req)) - return true; - } - } + // Visit trailing requirements + if (WalkGenerics && visitTrailingRequirements(AFD)) + return true; if (AFD->getBody(/*canSynthesize=*/false)) { AbstractFunctionDecl::BodyKind PreservedKind = AFD->getBodyKind(); @@ -1323,17 +1324,6 @@ class Traversal : public ASTVisitorgetGenericParams()) { - visitGenericParamList(params); - return true; - } - } - return false; - } }; } // end anonymous namespace diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index d40e73beaed4a..34db43f73aa5e 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -1077,9 +1077,12 @@ void GenericContext::setGenericSignature(GenericSignature genericSig) { } SourceRange GenericContext::getGenericTrailingWhereClauseSourceRange() const { - if (!isGeneric()) - return SourceRange(); - return getGenericParams()->getTrailingWhereClauseSourceRange(); + if (isGeneric()) + return getGenericParams()->getTrailingWhereClauseSourceRange(); + else if (const auto *where = getTrailingWhereClause()) + return where->getSourceRange(); + + return SourceRange(); } ImportDecl *ImportDecl::create(ASTContext &Ctx, DeclContext *DC, diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp index d48ec86f034f3..33f7159a4fe4c 100644 --- a/lib/AST/GenericSignatureBuilder.cpp +++ b/lib/AST/GenericSignatureBuilder.cpp @@ -7459,7 +7459,7 @@ llvm::Expected InferredGenericSignatureRequest::evaluate( Evaluator &evaluator, ModuleDecl *parentModule, GenericSignatureImpl *parentSig, - GenericParamList *gpl, + GenericParamSource paramSource, SmallVector addedRequirements, SmallVector inferenceSources, bool allowConcreteGenericParams) const { @@ -7470,78 +7470,99 @@ InferredGenericSignatureRequest::evaluate( // from that context. builder.addGenericSignature(parentSig); - // Type check the generic parameters, treating all generic type - // parameters as dependent, unresolved. - SmallVector gpLists; - if (gpl->getOuterParameters() && !parentSig) { - for (auto *outerParams = gpl; + DeclContext *lookupDC = nullptr; + + const auto visitRequirement = [&](const Requirement &req, + RequirementRepr *reqRepr) { + const auto source = FloatingRequirementSource::forExplicit(reqRepr); + + // If we're extending a protocol and adding a redundant requirement, + // for example, `extension Foo where Self: Foo`, then emit a + // diagnostic. + + if (auto decl = lookupDC->getAsDecl()) { + if (auto extDecl = dyn_cast(decl)) { + auto extType = extDecl->getDeclaredInterfaceType(); + auto extSelfType = extDecl->getSelfInterfaceType(); + auto reqLHSType = req.getFirstType(); + auto reqRHSType = req.getSecondType(); + + if (extType->isExistentialType() && + reqLHSType->isEqual(extSelfType) && + reqRHSType->isEqual(extType)) { + + auto &ctx = extDecl->getASTContext(); + ctx.Diags.diagnose(extDecl->getLoc(), + diag::protocol_extension_redundant_requirement, + extType->getString(), + extSelfType->getString(), + reqRHSType->getString()); + } + } + } + + builder.addRequirement(req, reqRepr, source, nullptr, + lookupDC->getParentModule()); + return false; + }; + + GenericParamList *genericParams = nullptr; + if (auto params = paramSource.dyn_cast()) + genericParams = params; + else + genericParams = paramSource.get()->getGenericParams(); + + if (genericParams) { + // Extensions never have a parent signature. + if (genericParams->getOuterParameters()) + assert(parentSig == nullptr); + + // Type check the generic parameters, treating all generic type + // parameters as dependent, unresolved. + SmallVector gpLists; + for (auto *outerParams = genericParams; outerParams != nullptr; outerParams = outerParams->getOuterParameters()) { gpLists.push_back(outerParams); } - } else { - gpLists.push_back(gpl); - } - // The generic parameter lists MUST appear from innermost to outermost. - // We walk them backwards to order outer requirements before - // inner requirements. - for (auto &genericParams : llvm::reverse(gpLists)) { - assert(genericParams->size() > 0 && - "Parsed an empty generic parameter list?"); + // The generic parameter lists MUST appear from innermost to outermost. + // We walk them backwards to order outer requirements before + // inner requirements. + for (auto &genericParams : llvm::reverse(gpLists)) { + assert(genericParams->size() > 0 && + "Parsed an empty generic parameter list?"); - // Determine where and how to perform name lookup. - DeclContext *lookupDC = genericParams->begin()[0]->getDeclContext(); + // First, add the generic parameters to the generic signature builder. + // Do this before checking the inheritance clause, since it may + // itself be dependent on one of these parameters. + for (const auto param : *genericParams) + builder.addGenericParameter(param); - // First, add the generic parameters to the generic signature builder. - // Do this before checking the inheritance clause, since it may - // itself be dependent on one of these parameters. - for (auto param : *genericParams) - builder.addGenericParameter(param); + // Add the requirements for each of the generic parameters to the builder. + // Now, check the inheritance clauses of each parameter. + for (const auto param : *genericParams) + builder.addGenericParameterRequirements(param); - // Add the requirements for each of the generic parameters to the builder. - // Now, check the inheritance clauses of each parameter. - for (auto param : *genericParams) - builder.addGenericParameterRequirements(param); + // Determine where and how to perform name lookup. + lookupDC = genericParams->begin()[0]->getDeclContext(); - // Add the requirements clause to the builder. + // Add the requirements clause to the builder. + WhereClauseOwner(lookupDC, genericParams) + .visitRequirements(TypeResolutionStage::Structural, + visitRequirement); + } + } else { + // The declaration has a where clause, but no generic parameters of its own. + const auto ctx = paramSource.get(); - using FloatingRequirementSource = - GenericSignatureBuilder::FloatingRequirementSource; - WhereClauseOwner(lookupDC, genericParams).visitRequirements( - TypeResolutionStage::Structural, - [&](const Requirement &req, RequirementRepr *reqRepr) { - auto source = FloatingRequirementSource::forExplicit(reqRepr); - - // If we're extending a protocol and adding a redundant requirement, - // for example, `extension Foo where Self: Foo`, then emit a - // diagnostic. - - if (auto decl = lookupDC->getAsDecl()) { - if (auto extDecl = dyn_cast(decl)) { - auto extType = extDecl->getDeclaredInterfaceType(); - auto extSelfType = extDecl->getSelfInterfaceType(); - auto reqLHSType = req.getFirstType(); - auto reqRHSType = req.getSecondType(); - - if (extType->isExistentialType() && - reqLHSType->isEqual(extSelfType) && - reqRHSType->isEqual(extType)) { - - auto &ctx = extDecl->getASTContext(); - ctx.Diags.diagnose(extDecl->getLoc(), - diag::protocol_extension_redundant_requirement, - extType->getString(), - extSelfType->getString(), - reqRHSType->getString()); - } - } - } - - builder.addRequirement(req, reqRepr, source, nullptr, - lookupDC->getParentModule()); - return false; - }); + assert(ctx->getTrailingWhereClause() && "No params or where clause"); + + // Determine where and how to perform name lookup. + lookupDC = ctx; + + WhereClauseOwner(ctx).visitRequirements( + TypeResolutionStage::Structural, visitRequirement); } /// Perform any remaining requirement inference. diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 9b6c2341de92a..0f68884435807 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -4424,7 +4424,7 @@ void Parser::diagnoseConsecutiveIDs(StringRef First, SourceLoc FirstLoc, /// Parse a Decl item in decl list. ParserStatus Parser::parseDeclItem(bool &PreviousHadSemi, - Parser::ParseDeclOptions Options, + ParseDeclOptions Options, llvm::function_ref handler) { if (Tok.is(tok::semi)) { // Consume ';' without preceding decl. @@ -5008,7 +5008,7 @@ parseDeclTypeAlias(Parser::ParseDeclOptions Flags, DeclAttributes &Attributes) { // Parse a 'where' clause if present, adding it to our GenericParamList. if (Tok.is(tok::kw_where)) { ContextChange CC(*this, TAD); - Status |= parseFreestandingGenericWhereClause(genericParams); + Status |= parseFreestandingGenericWhereClause(TAD, genericParams, Flags); } if (UnderlyingTy.isNull()) { @@ -6348,7 +6348,7 @@ ParserResult Parser::parseDeclFunc(SourceLoc StaticLoc, if (Tok.is(tok::kw_where)) { ContextChange CC(*this, FD); - Status |= parseFreestandingGenericWhereClause(GenericParams); + Status |= parseFreestandingGenericWhereClause(FD, GenericParams, Flags); if (Status.hasCodeCompletion() && !CodeCompletion) { // Trigger delayed parsing, no need to continue. return Status; @@ -6603,12 +6603,13 @@ ParserResult Parser::parseDeclEnum(ParseDeclOptions Flags, // Parse a 'where' clause if present, adding it to our GenericParamList. if (Tok.is(tok::kw_where)) { - auto whereStatus = parseFreestandingGenericWhereClause(GenericParams); - Status |= whereStatus; + auto whereStatus = + parseFreestandingGenericWhereClause(ED, GenericParams, Flags); if (whereStatus.hasCodeCompletion() && !CodeCompletion) { // Trigger delayed parsing, no need to continue. return whereStatus; } + Status |= whereStatus; } SyntaxParsingContext BlockContext(SyntaxContext, SyntaxKind::MemberDeclBlock); @@ -6889,12 +6890,13 @@ ParserResult Parser::parseDeclStruct(ParseDeclOptions Flags, // Parse a 'where' clause if present, adding it to our GenericParamList. if (Tok.is(tok::kw_where)) { - auto whereStatus = parseFreestandingGenericWhereClause(GenericParams); - Status |= whereStatus; + auto whereStatus = + parseFreestandingGenericWhereClause(SD, GenericParams, Flags); if (whereStatus.hasCodeCompletion() && !CodeCompletion) { // Trigger delayed parsing, no need to continue. return whereStatus; } + Status |= whereStatus; } // Make the entities of the struct as a code block. @@ -7005,12 +7007,13 @@ ParserResult Parser::parseDeclClass(ParseDeclOptions Flags, // Parse a 'where' clause if present, adding it to our GenericParamList. if (Tok.is(tok::kw_where)) { - auto whereStatus = parseFreestandingGenericWhereClause(GenericParams); - Status |= whereStatus; + auto whereStatus = + parseFreestandingGenericWhereClause(CD, GenericParams, Flags); if (whereStatus.hasCodeCompletion() && !CodeCompletion) { // Trigger delayed parsing, no need to continue. return whereStatus; } + Status |= whereStatus; } SyntaxParsingContext BlockContext(SyntaxContext, SyntaxKind::MemberDeclBlock); @@ -7257,7 +7260,8 @@ Parser::parseDeclSubscript(SourceLoc StaticLoc, if (Tok.is(tok::kw_where)) { ContextChange CC(*this, Subscript); - Status |= parseFreestandingGenericWhereClause(GenericParams); + Status |= parseFreestandingGenericWhereClause(Subscript, GenericParams, + Flags); if (Status.hasCodeCompletion() && !CodeCompletion) { // Trigger delayed parsing, no need to continue. return Status; @@ -7399,7 +7403,7 @@ Parser::parseDeclInit(ParseDeclOptions Flags, DeclAttributes &Attributes) { if (Tok.is(tok::kw_where)) { ContextChange(*this, CD); - Status |= parseFreestandingGenericWhereClause(GenericParams); + Status |= parseFreestandingGenericWhereClause(CD, GenericParams, Flags); if (Status.hasCodeCompletion() && !CodeCompletion) { // Trigger delayed parsing, no need to continue. return Status; diff --git a/lib/Parse/ParseGeneric.cpp b/lib/Parse/ParseGeneric.cpp index 0eb76c2611c40..d2e7099d41c24 100644 --- a/lib/Parse/ParseGeneric.cpp +++ b/lib/Parse/ParseGeneric.cpp @@ -392,21 +392,15 @@ ParserStatus Parser::parseGenericWhereClause( } -/// Parse a free-standing where clause attached to a declaration, adding it to -/// a generic parameter list that may (or may not) already exist. +/// Parse a free-standing where clause attached to a declaration, +/// adding it to a generic parameter list, if any, or to the given +/// generic context representing the declaration. ParserStatus Parser:: -parseFreestandingGenericWhereClause(GenericParamList *genericParams, - WhereClauseKind kind) { +parseFreestandingGenericWhereClause(GenericContext *genCtx, + GenericParamList *&genericParams, + ParseDeclOptions flags) { assert(Tok.is(tok::kw_where) && "Shouldn't call this without a where"); - - // Push the generic arguments back into a local scope so that references will - // find them. - Scope S(this, ScopeKind::Generics); - - if (genericParams) - for (auto pd : genericParams->getParams()) - addToScope(pd); - + SmallVector Requirements; SourceLoc WhereLoc; bool FirstTypeInComplete; @@ -415,10 +409,23 @@ parseFreestandingGenericWhereClause(GenericParamList *genericParams, if (result.shouldStopParsing() || Requirements.empty()) return result; - if (!genericParams) - diagnose(WhereLoc, diag::where_without_generic_params, unsigned(kind)); - else + if (genericParams) { + // Push the generic arguments back into a local scope so that references will + // find them. + Scope S(this, ScopeKind::Generics); + for (auto pd : genericParams->getParams()) + addToScope(pd); + genericParams->addTrailingWhereClause(Context, WhereLoc, Requirements); + + // A where clause that references only outer generic parameters? + } else if (flags.contains(PD_HasContainerType)) { + genCtx->setTrailingWhereClause( + TrailingWhereClause::create(Context, WhereLoc, Requirements)); + } else { + diagnose(WhereLoc, diag::where_toplevel_nongeneric); + } + return ParserStatus(); } diff --git a/lib/Sema/TypeCheckGeneric.cpp b/lib/Sema/TypeCheckGeneric.cpp index 70ffd4360b4bb..39e03030533d1 100644 --- a/lib/Sema/TypeCheckGeneric.cpp +++ b/lib/Sema/TypeCheckGeneric.cpp @@ -450,16 +450,17 @@ void TypeChecker::checkReferencedGenericParams(GenericContext *dc) { /// GenericSignature TypeChecker::checkGenericSignature( - GenericParamList *genericParamList, + GenericParamSource paramSource, DeclContext *dc, GenericSignature parentSig, bool allowConcreteGenericParams, SmallVector additionalRequirements, SmallVector inferenceSources) { - assert(genericParamList && "Missing generic parameters?"); + if (auto genericParamList = paramSource.dyn_cast()) + assert(genericParamList && "Missing generic parameters?"); auto request = InferredGenericSignatureRequest{ - dc->getParentModule(), parentSig.getPointer(), genericParamList, + dc->getParentModule(), parentSig.getPointer(), paramSource, additionalRequirements, inferenceSources, allowConcreteGenericParams}; auto sig = evaluateOrDefault(dc->getASTContext().evaluator, @@ -489,7 +490,7 @@ GenericSignature TypeChecker::checkGenericSignature( /// extension's list of generic parameters. static Type formExtensionInterfaceType( ExtensionDecl *ext, Type type, - GenericParamList *genericParams, + const GenericParamList *genericParams, SmallVectorImpl &sameTypeReqs, bool &mustInferRequirements) { if (type->is()) @@ -602,28 +603,45 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator, return sig; } - // We can fast-path computing the generic signature of non-generic - // declarations by re-using the parent context's signature. - auto *gp = GC->getGenericParams(); - if (!gp) { - return GC->getParent()->getGenericSignatureOfContext(); - } + bool allowConcreteGenericParams = false; + const auto *genericParams = GC->getGenericParams(); + if (genericParams) { + // Setup the depth of the generic parameters. + const_cast(genericParams) + ->setDepth(GC->getGenericContextDepth()); + + // Accessors can always use the generic context of their storage + // declarations. This is a compile-time optimization since it lets us + // avoid the requirements-gathering phase, but it also simplifies that + // work for accessors which don't mention the value type in their formal + // signatures (like the read and modify coroutines, since yield types + // aren't tracked in the AST type yet). + if (auto accessor = dyn_cast(GC->getAsDecl())) { + return cast(accessor->getStorage())->getGenericSignature(); + } + + // ...or we may have a where clause dependent on outer generic parameters. + } else if (const auto *where = GC->getTrailingWhereClause()) { + // If there is no generic context for the where clause to + // rely on, diagnose that now and bail out. + if (!GC->isGenericContext()) { + GC->getASTContext().Diags.diagnose(where->getWhereLoc(), + diag::where_nongeneric_ctx); + return nullptr; + } - // Setup the depth of the generic parameters. - gp->setDepth(GC->getGenericContextDepth()); - - // Accessors can always use the generic context of their storage - // declarations. This is a compile-time optimization since it lets us - // avoid the requirements-gathering phase, but it also simplifies that - // work for accessors which don't mention the value type in their formal - // signatures (like the read and modify coroutines, since yield types - // aren't tracked in the AST type yet). - if (auto accessor = dyn_cast(GC->getAsDecl())) { - return cast(accessor->getStorage())->getGenericSignature(); + allowConcreteGenericParams = true; + } else { + // We can fast-path computing the generic signature of non-generic + // declarations by re-using the parent context's signature. + if (auto accessor = dyn_cast(GC->getAsDecl())) + if (auto subscript = dyn_cast(accessor->getStorage())) + return subscript->getGenericSignature(); + + return GC->getParent()->getGenericSignatureOfContext(); } auto parentSig = GC->getParent()->getGenericSignatureOfContext(); - bool allowConcreteGenericParams = false; SmallVector inferenceSources; SmallVector sameTypeReqs; if (auto VD = dyn_cast_or_null(GC->getAsDecl())) { @@ -685,11 +703,11 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator, bool mustInferRequirements = false; Type extInterfaceType = formExtensionInterfaceType(ext, ext->getExtendedType(), - gp, sameTypeReqs, + genericParams, sameTypeReqs, mustInferRequirements); auto cannotReuseNominalSignature = [&]() -> bool { - const auto finalDepth = gp->getParams().back()->getDepth(); + const auto finalDepth = genericParams->getParams().back()->getDepth(); return mustInferRequirements || !sameTypeReqs.empty() || ext->getTrailingWhereClause() @@ -717,7 +735,7 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator, } return TypeChecker::checkGenericSignature( - gp, GC, parentSig, + GC, GC, parentSig, allowConcreteGenericParams, sameTypeReqs, inferenceSources); } diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index e0ffe53d8327c..f315fae69d82a 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -629,9 +629,9 @@ static bool isPointerToVoid(ASTContext &Ctx, Type Ty, bool &IsMutable) { return BGT->getGenericArgs().front()->isVoid(); } -static Type checkConstrainedExtensionRequirements(Type type, - SourceLoc loc, - DeclContext *dc) { +static Type checkContextualRequirements(Type type, + SourceLoc loc, + DeclContext *dc) { // Even if the type is not generic, it might be inside of a generic // context, so we need to check requirements. GenericTypeDecl *decl; @@ -646,25 +646,34 @@ static Type checkConstrainedExtensionRequirements(Type type, return type; } - // FIXME: Some day the type might also have its own 'where' clause, even - // if its not generic. - - auto *ext = dyn_cast(decl->getDeclContext()); - if (!ext || !ext->isConstrainedExtension()) - return type; - - if (parentTy->hasUnboundGenericType() || + if (!parentTy || parentTy->hasUnboundGenericType() || parentTy->hasTypeVariable()) { return type; } - auto subMap = parentTy->getContextSubstitutions(ext); + // We are interested in either a contextual where clause or + // a constrained extension context. + TypeSubstitutionMap subMap; + GenericSignature genericSig; + SourceLoc noteLoc; + if (decl->getTrailingWhereClause()) { + subMap = parentTy->getContextSubstitutions(decl->getDeclContext()); + genericSig = decl->getGenericSignature(); + noteLoc = decl->getLoc(); + } else { + const auto ext = dyn_cast(decl->getDeclContext()); + if (ext && ext->isConstrainedExtension()) { + subMap = parentTy->getContextSubstitutions(ext); + genericSig = ext->getGenericSignature(); + noteLoc = ext->getLoc(); + } else { + return type; + } + } - SourceLoc noteLoc = ext->getLoc(); if (noteLoc.isInvalid()) noteLoc = loc; - auto genericSig = ext->getGenericSignature(); auto result = TypeChecker::checkGenericArguments( dc, loc, noteLoc, type, @@ -722,7 +731,7 @@ static Type applyGenericArguments(Type type, if (resolution.getStage() == TypeResolutionStage::Structural) return type; - return checkConstrainedExtensionRequirements(type, loc, dc); + return checkContextualRequirements(type, loc, dc); } if (type->hasError()) { diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 2034687f1f3e2..ff12303bfd1de 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -696,7 +696,9 @@ class TypeChecker final { /// Construct a new generic environment for the given declaration context. /// - /// \param genericParams The generic parameters to validate. + /// \param paramSource The source of generic info: either a generic parameter + /// list or a generic context with a \c where clause dependent on outer + /// generic parameters. /// /// \param dc The declaration context in which to perform the validation. /// @@ -714,7 +716,7 @@ class TypeChecker final { /// /// \returns the resulting generic signature. static GenericSignature checkGenericSignature( - GenericParamList *genericParams, + GenericParamSource paramSource, DeclContext *dc, GenericSignature outerSignature, bool allowConcreteGenericParams, diff --git a/test/Generics/invalid.swift b/test/Generics/invalid.swift index b15394b1843cb..2d92cae026b70 100644 --- a/test/Generics/invalid.swift +++ b/test/Generics/invalid.swift @@ -1,11 +1,16 @@ // RUN: %target-typecheck-verify-swift -func bet() where A : B {} // expected-error {{'where' clause cannot be attached to a non-generic declaration}} +func bet() where A : B {} // expected-error {{'where' clause cannot be attached to non-generic top-level declaration}} -typealias gimel where A : B // expected-error {{'where' clause cannot be attached to a non-generic declaration}} -// expected-error@-1 {{expected '=' in type alias declaration}} +typealias gimel = Int where A : B // expected-error {{'where' clause cannot be attached to non-generic top-level declaration}} -class dalet where A : B {} // expected-error {{'where' clause cannot be attached to a non-generic declaration}} +class dalet where A : B {} // expected-error {{'where' clause cannot be attached to non-generic top-level declaration}} + +struct Where { + func bet() where A == B {} // expected-error {{'where' clause on non-generic member declaration requires a generic context}} + typealias gimel = Int where A : B // expected-error {{'where' clause on non-generic member declaration requires a generic context}} + class dalet where A : B {} // expected-error {{'where' clause on non-generic member declaration requires a generic context}} +} protocol he where A : B { // expected-error {{use of undeclared type 'A'}} // expected-error@-1 {{use of undeclared type 'B'}} diff --git a/test/Generics/where_clause_contextually_generic_decls.swift b/test/Generics/where_clause_contextually_generic_decls.swift new file mode 100644 index 0000000000000..68c41b6c6d095 --- /dev/null +++ b/test/Generics/where_clause_contextually_generic_decls.swift @@ -0,0 +1,144 @@ +// RUN: %target-typecheck-verify-swift -typecheck %s -verify -swift-version 4 + +// Make sure Self: ... is correctly diagnosed in classes + +class SelfInGenericClass { + // expected-error@+1 {{type 'Self' in conformance requirement does not refer to a generic parameter or associated type}} + func foo() where Self: Equatable { } + // expected-error@+1 {{generic signature requires types 'Self' and 'Bool' to be the same}} + func bar() where Self == Bool { } +} + +protocol Whereable { + associatedtype Assoc + associatedtype Bssoc + + // expected-error@+1 {{instance method requirement 'requirement1()' cannot add constraint 'Self.Assoc: Sequence' on 'Self'}} + func requirement1() where Assoc: Sequence + // expected-error@+1 {{instance method requirement 'requirement2()' cannot add constraint 'Self.Bssoc == Never' on 'Self'}} + func requirement2() where Bssoc == Never +} + +extension Whereable { + // expected-note@+1 {{where 'Self' = 'T1'}} + static func staticExtensionFunc(arg: Self.Element) -> Self.Element + where Self: Sequence { + return arg + } + + // expected-note@+1 {{where 'Self.Assoc' = 'T1.Assoc', 'Self.Bssoc' = 'T1.Bssoc'}} + func extensionFunc() where Assoc == Bssoc { } + + + // expected-note@+1 {{where 'Self.Assoc' = 'T1.Assoc'}} + subscript() -> Assoc where Assoc: Whereable { + fatalError() + } +} + +func testProtocolExtensions(t1: T1, t2: T2, t3: T3, t4: T4) + where T1: Whereable, + T2: Whereable & Sequence, + T3: Whereable, T3.Assoc == T3.Bssoc, + T4: Whereable, T4.Assoc: Whereable { + _ = T1.staticExtensionFunc // expected-error {{static method 'staticExtensionFunc(arg:)' requires that 'T1' conform to 'Sequence'}} + _ = T2.staticExtensionFunc + + t1.extensionFunc() // expected-error {{instance method 'extensionFunc()' requires the types 'T1.Assoc' and 'T1.Bssoc' be equivalent}} + t3.extensionFunc() + + _ = t1[] // expected-error {{subscript 'subscript()' requires that 'T1.Assoc' conform to 'Whereable'}} + _ = t4[] +} + +class Class { + // expected-note@+1 {{where 'T' = 'T}} // expected-note@+1 {{where 'T.Assoc' = 'T.Assoc'}} + static func staticFunc() where T: Whereable, T.Assoc == Int { } + + // expected-note@+1 {{candidate requires that the types 'T' and 'Bool' be equivalent}} + func func1() where T == Bool { } + // FIXME: The rhs type at the end of the error message is not persistent across compilations. + // expected-note@+1 {{candidate requires that the types 'T' and 'Int' be equivalent (requirement specified as 'T' == }} + func func1() where T == Int { } + + func func2() where T == Int { } // expected-note {{where 'T' = 'T'}} + + subscript() -> T.Element where T: Sequence { // expected-note {{where 'T' = 'T'}} + fatalError() + } +} + +extension Class { + static func staticExtensionFunc() where T: Class { } // expected-note {{where 'T' = 'T'}} + + subscript(arg: T.Element) -> T.Element where T == Array { + fatalError() + } +} + +extension Class where T: Equatable { + func extensionFunc() where T: Comparable { } // expected-note {{where 'T' = 'T'}} + + // expected-error@+1 {{same-type constraint type 'Class' does not conform to required protocol 'Equatable'}} + func badRequirement1() where T == Class { } +} + +extension Class where T == Bool { + // expected-error@+1 {{generic parameter 'T' cannot be equal to both 'Int' and 'Bool'}} + func badRequirement2() where T == Int { } +} + +func testMemberDeclarations(arg1: Class, arg2: Class) { + // expected-error@+2 {{static method 'staticFunc()' requires the types 'T.Assoc' and 'Int' be equivalent}} + // expected-error@+1 {{static method 'staticFunc()' requires that 'T' conform to 'Whereable'}} + Class.staticFunc() + Class.staticExtensionFunc() // expected-error {{static method 'staticExtensionFunc()' requires that 'T' inherit from 'Class'}} + Class>.staticExtensionFunc() + + arg1.func1() // expected-error {{no exact matches in call to instance method 'func1'}} + arg1.func2() // expected-error {{instance method 'func2()' requires the types 'T' and 'Int' be equivalent}} + arg1.extensionFunc() // expected-error {{instance method 'extensionFunc()' requires that 'T' conform to 'Comparable'}} + arg2.extensionFunc() + Class().func1() + Class().func2() + + arg1[] // expected-error {{subscript 'subscript()' requires that 'T' conform to 'Sequence'}} + _ = Class>()[Int.zero] +} + +// Test nested types and requirements. + +struct Container { + typealias NestedAlias = Bool where T == Int + // expected-note@-1 {{'NestedAlias' previously declared here}} + typealias NestedAlias = Bool where T == Bool + // expected-error@-1 {{invalid redeclaration of 'NestedAlias}} + typealias NestedAlias2 = T.Magnitude where T: FixedWidthInteger + + class NestedClass where T: Equatable {} +} + +extension Container where T: Sequence { + struct NestedStruct {} + + struct NestedStruct2 where T.Element: Comparable { + enum NestedEnum where T.Element == Double {} // expected-note {{requirement specified as 'T.Element' == 'Double' [with T = String]}} + } + + struct NestedStruct3 {} +} + +extension Container.NestedStruct3 { + func foo(arg: U) where U.Assoc == T {} +} + +_ = Container.NestedAlias2.self // expected-error {{type 'String' does not conform to protocol 'FixedWidthInteger'}} +_ = Container>.NestedClass.self // expected-error {{type 'Container' does not conform to protocol 'Equatable'}} +_ = Container.NestedStruct.self // expected-error {{type 'Void' does not conform to protocol 'Sequence'}} +_ = Container>.NestedStruct2.self // expected-error {{type 'Void' does not conform to protocol 'Comparable'}} +_ = Container.NestedStruct2.NestedEnum.self // expected-error {{'Container.NestedStruct2.NestedEnum' requires the types 'String.Element' (aka 'Character') and 'Double' be equivalent}} +_ = Container.NestedAlias2.self +_ = Container.NestedClass.self +_ = Container.NestedStruct.self +_ = Container>.NestedStruct2.self +_ = Container>.NestedStruct2.NestedEnum.self diff --git a/test/IDE/coloring.swift b/test/IDE/coloring.swift index baa895b4efb6b..aeb8b3f848cfa 100644 --- a/test/IDE/coloring.swift +++ b/test/IDE/coloring.swift @@ -522,3 +522,22 @@ enum E { // CHECK: var _ = 10 @available(iOS 99, *) var _ = 10 + +// CHECK: Array<T> where T: Equatable +typealias GenericAlias = Array where T: Equatable + +// Where clauses on contextually generic declarations +// +struct FreeWhere { + // CHECK: func foo() where T == Int + func foo() where T == Int {} + + // CHECK: subscript() -> Int where T: Sequence + subscript() -> Int where T: Sequence {} + + // CHECK: enum Enum where T == Int + enum Enum where T == Int {} + + // CHECK: typealias Alias = Int where T == Int + typealias Alias = Int where T == Int +} diff --git a/test/Interpreter/where_clause_contextually_generic_decl.swift b/test/Interpreter/where_clause_contextually_generic_decl.swift new file mode 100644 index 0000000000000..812aea9220d37 --- /dev/null +++ b/test/Interpreter/where_clause_contextually_generic_decl.swift @@ -0,0 +1,101 @@ +// RUN: %target-run-simple-swift | %FileCheck %s +// REQUIRES: executable_test + +protocol Protocol {} +extension Protocol { + func foo() { + print("I survived in \(Self.self).\(#function)") + } +} + +struct Foo { + struct Inner1 where T: Protocol { + let bar: T + } + + struct Inner2 where T == String { + func getString() -> T { + return "I survived in \(#function), T := \(T.self)" + } + } + + struct Inner3 { + func isLessThan(lhs: T, rhs: T) -> U where T: Comparable, U == Bool { + return lhs < rhs + } + + init() { + print("This is the unconstrained \(#function)") + } + init() where U == T { + print("I survived in \(#function), T := \(T.self), U := \(U.self)") + } + } +} + +struct ProtocolAdopter: Protocol, Equatable {} + +// CHECK: I survived in ProtocolAdopter.foo() +Foo.Inner1(bar: ProtocolAdopter()).bar.foo() + +// CHECK: I survived in getString(), T := String +print(Foo.Inner2().getString()) + +// CHECK: This is the unconstrained init() +// CHECK: false +print(Foo.Inner3().isLessThan(lhs: .zero, rhs: .zero)) + +// CHECK: I survived in init(), T := Bool, U := Bool +_ = Foo.Inner3() + +protocol RefinedProtocol: Protocol { + associatedtype Assoc = Self + associatedtype Bssoc: RefinedProtocol + + func overload() +} + +extension RefinedProtocol { + func callOverload() { + overload() + print("Assoc := \(Assoc.self), Bssoc := \(Bssoc.self)") + } + + func overload() where Assoc == Bssoc { + print("I survived in \(Self.self).\(#function) (1)") + } + func overload() where Assoc == Self { + print("I survived in \(Self.self).\(#function) (2)") + } + func overload() where Assoc: Sequence, Bssoc == Assoc.Element { + print("I survived in \(Self.self).\(#function) (3)") + } +} + +struct RefinedProtocolAdopter1: RefinedProtocol { + typealias Assoc = RefinedProtocolAdopter2 + typealias Bssoc = RefinedProtocolAdopter2 +} +struct RefinedProtocolAdopter2: RefinedProtocol { + typealias Bssoc = RefinedProtocolAdopter1 +} +struct RefinedProtocolAdopter3: RefinedProtocol { + typealias Assoc = Array + typealias Bssoc = RefinedProtocolAdopter3 +} + +@inline(never) +func callThroughToOverload(arg: T) { + arg.callOverload() +} + +// CHECK: I survived in RefinedProtocolAdopter1.overload() (1) +// CHECK: Assoc := RefinedProtocolAdopter2, Bssoc := RefinedProtocolAdopter2 +callThroughToOverload(arg: RefinedProtocolAdopter1()) +// CHECK: I survived in RefinedProtocolAdopter2.overload() (2) +// CHECK: Assoc := RefinedProtocolAdopter2, Bssoc := RefinedProtocolAdopter1 +callThroughToOverload(arg: RefinedProtocolAdopter2()) +// CHECK: I survived in RefinedProtocolAdopter3.overload() (3) +// CHECK: Assoc := Array, Bssoc := RefinedProtocolAdopter3 +callThroughToOverload(arg: RefinedProtocolAdopter3()) + diff --git a/test/Runtime/demangleToMetadata.swift b/test/Runtime/demangleToMetadata.swift index 2b36ff5090f46..47c66867bfb68 100644 --- a/test/Runtime/demangleToMetadata.swift +++ b/test/Runtime/demangleToMetadata.swift @@ -245,11 +245,18 @@ struct ConformsToP1: P1 { } struct ConformsToP2: P2 { } struct ConformsToP3: P3 { } +struct ContextualWhere1 { + class Nested1 where T: P1 { } + struct Nested2 where T == Int { } +} + DemangleToMetadataTests.test("protocol conformance requirements") { expectEqual(CG4.self, _typeByName("4main3CG4CyAA12ConformsToP1VAA12ConformsToP2VG")!) expectEqual(CG4.InnerGeneric.self, _typeByName("4main3CG4C12InnerGenericVyAA12ConformsToP1VAA12ConformsToP2V_AA12ConformsToP3VG")!) + expectEqual(ContextualWhere1.Nested1.self, + _typeByName("4main16ContextualWhere1V7Nested1CyAA12ConformsToP1V_G")!) // Failure cases: failed conformance requirements. expectNil(_typeByName("4main3CG4CyAA12ConformsToP1VAA12ConformsToP1VG")) @@ -274,9 +281,16 @@ struct ConformsToP4c : P4 { typealias Assoc2 = ConformsToP2 } +struct ContextualWhere2 { + struct Nested1 where U.Assoc1: P1, U.Assoc2: P2 { } + enum Nested2 where U.Assoc1 == U.Assoc2 { } +} + DemangleToMetadataTests.test("associated type conformance requirements") { expectEqual(SG5.self, _typeByName("4main3SG5VyAA13ConformsToP4aVG")!) + expectEqual(ContextualWhere2.Nested1.self, + _typeByName("4main16ContextualWhere2V7Nested1VyAA13ConformsToP4aV_G")!) // Failure cases: failed conformance requirements. expectNil(_typeByName("4main3SG5VyAA13ConformsToP4bVG")) @@ -297,12 +311,16 @@ DemangleToMetadataTests.test("same-type requirements") { // Concrete type. expectEqual(SG7.self, _typeByName("4main3SG7VyAA1SVG")!) + expectEqual(ContextualWhere1.Nested2.self, + _typeByName("4main16ContextualWhere1V7Nested2VySi_G")!) // Other associated type. expectEqual(SG6.self, _typeByName("4main3SG6VyAA13ConformsToP4bVG")!) expectEqual(SG6.self, _typeByName("4main3SG6VyAA13ConformsToP4cVG")!) + expectEqual(ContextualWhere2.Nested2.self, + _typeByName("4main16ContextualWhere2V7Nested2OyAA13ConformsToP4bV_G")!) // Structural type. expectEqual(SG8.self, diff --git a/test/SILGen/generic_signatures.swift b/test/SILGen/generic_signatures.swift index d73c76796467a..a5e24ab02b265 100644 --- a/test/SILGen/generic_signatures.swift +++ b/test/SILGen/generic_signatures.swift @@ -1,4 +1,4 @@ -// RUN: %target-swift-emit-silgen -parse-stdlib %s +// RUN: %target-swift-emit-silgen %s | %FileCheck %s protocol P { associatedtype Assoc @@ -81,3 +81,62 @@ func concreteJungle(_: T, f: @escaping (T.Foo) -> C) -> T.Foo where T : Fooab let ff: (C) -> T.Foo = f return ff(C()) } + +protocol Whereable { + associatedtype Assoc + associatedtype Bssoc: Whereable +} +extension Whereable { + // CHECK-LABEL sil hidden [ossa] @$s18generic_signatures9WhereablePAAE19staticExtensionFunc3arg7ElementSTQz8IteratorSTQz_tSTRzrlFZ : $@convention(method) (@in_guaranteed Self.Iterator, @thick Self.Type) -> @out Self.Element + static func staticExtensionFunc(arg: Self.Iterator) -> Self.Element + where Self: Sequence { + fatalError() + } + + // CHECK-LABEL sil hidden [ossa] @$s18generic_signatures9WhereablePAAE13extensionFuncyy5BssocQz5AssocRtzrlF : $@convention(method) (@in_guaranteed Self) -> () + func extensionFunc() where Assoc == Bssoc { } + + // CHECK-LABEL sil hidden [ossa] @$s18generic_signatures9WhereablePAAE5AssocQzSgycAabERQAD_5BssocQZAGRtzrluig : $@convention(method) (@in_guaranteed Self) -> @out Optional + subscript() -> Assoc + where Assoc: Whereable, Bssoc == Assoc.Bssoc { + fatalError() + } + + // CHECK-LABEL sil hidden [ossa] @$s18generic_signatures9WhereablePAAE5AssocQzSgycAabERQ5Bssoc_ADQZAERSrluig : $@convention(method) (@in_guaranteed Self) -> @out Optional + subscript() -> Assoc + where Assoc: Whereable, Assoc == Bssoc.Assoc { + fatalError() + } +} + +struct W1 {} +struct W2 {} + +class Class { + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassC9classFuncyyAA9WhereableRz5AssocQzRszlFZ : $@convention(method) (@thick Class.Type) -> () + class func classFunc() where T: Whereable, T.Assoc == T { } + + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassC5func1yyAA7FooableRzlF : $@convention(method) (@guaranteed Class) -> () + func func1() where T: Fooable { } + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassC5func2yyAA2W1VRszlF : $@convention(method) (@guaranteed Class) -> () + func func2() where T == W1 { } + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassC5func2yyAA2W2VRszlF : $@convention(method) (@guaranteed Class) -> () + func func2() where T == W2 { } + + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassC5AssocQzycAA9WhereableRzluig : $@convention(method) (@guaranteed Class) -> @out T.Assoc + subscript() -> T.Assoc where T: Whereable { + fatalError() + } + + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassC06NestedC0CAEyx_Gycfc : $@convention(method) (@owned Class.NestedClass) -> @owned Class.NestedClass + class NestedClass where T: Fooable { } +} + +extension Class where T: Whereable { + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassCA2A9WhereableRzlE13extensionFuncyyAA7FooableRzrlF : $@convention(method) (@guaranteed Class) -> () + func extensionFunc() where T: Fooable { } +} +extension Class.NestedClass { + // CHECK-LABEL: sil hidden [ossa] @$s18generic_signatures5ClassC06NestedC0C3foo3argyx_tAA9WhereableRz3FooAA7FooablePQz5AssocAaHPRtzrlF : $@convention(method) (@in_guaranteed T, @guaranteed Class.NestedClass) -> () + func foo(arg: T) where T: Whereable, T.Foo == T.Assoc { } +} diff --git a/test/attr/attr_override.swift b/test/attr/attr_override.swift index 86d501ff75cea..af33027432976 100644 --- a/test/attr/attr_override.swift +++ b/test/attr/attr_override.swift @@ -579,6 +579,60 @@ class SR_4206_DerivedGeneric_6: SR_4206_BaseConcrete_6 { override func foo(arg: T) {} // expected-error {{overridden method 'foo' has generic signature which is incompatible with base method's generic signature ; expected generic signature to be }} } +// Where clauses on contextually generic declarations + +class SR_4206_Base_7 { + func foo1() where T: SR_4206_Protocol_1 {} // expected-note {{overridden declaration is here}} + func foo2() where T: SR_4206_Protocol_1 {} +} + +class SR_4206_Derived_7: SR_4206_Base_7 { + override func foo1() where T: SR_4206_Protocol_2 {} // expected-error {{overridden method 'foo1' has generic signature which is incompatible with base method's generic signature ; expected generic signature to be }} + + override func foo2() {} // OK +} + +// Subclass with new constraint on inherited generic param + +class SR_4206_Base_8 { + func foo() where T: SR_4206_Protocol_1 {} +} +class SR_4206_Derived_8: SR_4206_Base_8 { + override func foo() where T: SR_4206_Protocol_1 {} // OK +} + +// Same-type to conformance visibility reabstraction + +class SR_4206_Base_9 { + func foo() where T == Int {} +} +class SR_4206_Derived_9: SR_4206_Base_9 { + override func foo() where T: FixedWidthInteger {} // OK +} + +// Override with constraint on a non-inherited generic param + +class SR_4206_Base_10 { + func foo() where T: SR_4206_Protocol_1 {} // expected-note {{overridden declaration is here}} +} +class SR_4206_Derived_10: SR_4206_Base_10 { + override func foo() where U: SR_4206_Protocol_1 {} // expected-error {{overridden method 'foo' has generic signature which is incompatible with base method's generic signature ; expected generic signature to be }} +} + +// Override with return type specialization + +class SR_4206_Base_11 { + // The fact that the return type matches the substitution + // for T must hold across overrides. + func foo() -> T where T: FixedWidthInteger { fatalError() } // expected-note {{potential overridden instance method 'foo()' here}} +} +class SR_4206_Derived_11: SR_4206_Base_11 { + override func foo() -> Int { return .zero } // OK +} +class SR_4206_Derived2_11: SR_4206_Base_11 { + override func foo() -> Int { return .zero } // expected-error {{method does not override any method from its superclass}} +} + // Misc // protocol SR_4206_Key {}