Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SE] Allow where clauses on non-generic declarations in generic contexts #23489

Merged
Merged
4 changes: 3 additions & 1 deletion include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,9 @@ class RequirementRepr {
void print(raw_ostream &OS) const;
void print(ASTPrinter &Printer) const;
};


using GenericParamSource = PointerUnion<GenericContext *, GenericParamList *>;

/// GenericParamList - A list of generic parameters that is part of a generic
/// function or type, along with extra requirements placed on those generic
/// parameters and types derived from them.
Expand Down
7 changes: 3 additions & 4 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1650,10 +1650,9 @@ ERROR(redundant_class_requirement,none,
"redundant 'class' requirement", ())
ERROR(late_class_requirement,none,
"'class' must come first in the requirement list", ())
ERROR(where_without_generic_params,none,
"'where' clause cannot be attached to "
"%select{a non-generic|a protocol|an associated type}0 "
"declaration", (unsigned))
ERROR(where_toplevel_nongeneric,none,
"'where' clause cannot be attached to non-generic "
"top-level declaration", ())
AnthonyLatsis marked this conversation as resolved.
Show resolved Hide resolved
ERROR(where_inside_brackets,none,
"'where' clause next to generic parameters is obsolete, "
"must be written following the declaration's type", ())
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2751,6 +2751,10 @@ ERROR(dynamic_self_stored_property_init,none,
ERROR(dynamic_self_default_arg,none,
"covariant 'Self' type cannot be referenced from a default argument expression", ())

ERROR(where_nongeneric_ctx,none,
"'where' clause on non-generic member declaration requires a "
"generic context", ())

//------------------------------------------------------------------------------
// MARK: Type Check Attributes
//------------------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ class InferredGenericSignatureRequest :
public SimpleRequest<InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *,
GenericSignatureImpl *,
GenericParamList *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>,
bool),
Expand All @@ -1124,7 +1124,7 @@ class InferredGenericSignatureRequest :
evaluate(Evaluator &evaluator,
ModuleDecl *module,
GenericSignatureImpl *baseSignature,
GenericParamList *gpl,
GenericParamSource paramSource,
SmallVector<Requirement, 2> addedRequirements,
SmallVector<TypeLoc, 2> inferenceSources,
bool allowConcreteGenericParams) const;
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ SWIFT_REQUEST(TypeChecker, HasDynamicMemberLookupAttributeRequest,
bool(CanType), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *, GenericSignatureImpl *,
GenericParamList *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>, bool),
Cached, NoLocationInfo)
Expand Down
9 changes: 9 additions & 0 deletions include/swift/Basic/SimpleDisplay.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ namespace swift {
}
out << "}";
}

template<typename T, typename U>
void simple_display(llvm::raw_ostream &out,
const llvm::PointerUnion<T, U> &ptrUnion) {
if (const auto t = ptrUnion.template dyn_cast<T>())
simple_display(out, t);
else
simple_display(out, ptrUnion.template get<U>());
}
}

#endif // SWIFT_BASIC_SIMPLE_DISPLAY_H
12 changes: 4 additions & 8 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ class Parser {
bool allowClassRequirement,
bool allowAnyObject);
ParserStatus parseDeclItem(bool &PreviousHadSemi,
Parser::ParseDeclOptions Options,
ParseDeclOptions Options,
llvm::function_ref<void(Decl*)> handler);
std::pair<std::vector<Decl *>, Optional<std::string>>
parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag,
Expand Down Expand Up @@ -1634,14 +1634,10 @@ class Parser {
void
diagnoseWhereClauseInGenericParamList(const GenericParamList *GenericParams);

enum class WhereClauseKind : unsigned {
Declaration,
Protocol,
AssociatedType
};
ParserStatus
parseFreestandingGenericWhereClause(GenericParamList *GPList,
WhereClauseKind kind=WhereClauseKind::Declaration);
parseFreestandingGenericWhereClause(GenericContext *genCtx,
GenericParamList *&GPList,
ParseDeclOptions flags);

ParserStatus parseGenericWhereClause(
SourceLoc &WhereLoc, SmallVectorImpl<RequirementRepr> &Requirements,
Expand Down
40 changes: 20 additions & 20 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4470,19 +4470,22 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
assert(isa<AbstractFunctionDecl>(base) || isa<SubscriptDecl>(base));
assert(isa<AbstractFunctionDecl>(derived) || isa<SubscriptDecl>(derived));

auto baseClass = base->getDeclContext()->getSelfClassDecl();
auto derivedClass = derived->getDeclContext()->getSelfClassDecl();
const auto baseClass = base->getDeclContext()->getSelfClassDecl();
const auto derivedClass = derived->getDeclContext()->getSelfClassDecl();

assert(baseClass != nullptr);
assert(derivedClass != nullptr);

auto baseGenericSig = base->getAsGenericContext()->getGenericSignature();
auto derivedGenericSig = derived->getAsGenericContext()->getGenericSignature();
const auto baseGenericSig =
base->getAsGenericContext()->getGenericSignature();
const auto derivedGenericSig =
derived->getAsGenericContext()->getGenericSignature();

if (base == derived)
return derivedGenericSig;

if (derivedClass->getSuperclass().isNull())
const auto derivedSuperclass = derivedClass->getSuperclass();
if (derivedSuperclass.isNull())
return nullptr;

if (derivedGenericSig.isNull())
Expand All @@ -4491,12 +4494,6 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
if (baseGenericSig.isNull())
return derivedGenericSig;

auto baseClassSig = baseClass->getGenericSignature();
auto subMap = derivedClass->getSuperclass()->getContextSubstitutionMap(
derivedClass->getModuleContext(), baseClass);

unsigned derivedDepth = 0;

auto key = OverrideSignatureKey(baseGenericSig,
derivedGenericSig,
derivedClass);
Expand All @@ -4506,22 +4503,25 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
return getImpl().overrideSigCache.lookup(key);
}

if (auto derivedSig = derivedClass->getGenericSignature())
derivedDepth = derivedSig->getGenericParams().back()->getDepth() + 1;
const auto derivedClassSig = derivedClass->getGenericSignature();

unsigned derivedDepth = 0;
unsigned baseDepth = 0;
if (derivedClassSig)
derivedDepth = derivedClassSig->getGenericParams().back()->getDepth() + 1;
if (const auto baseClassSig = baseClass->getGenericSignature())
baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1;

SmallVector<GenericTypeParamType *, 2> addedGenericParams;
if (auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
for (auto gp : *gpList) {
addedGenericParams.push_back(
gp->getDeclaredInterfaceType()->castTo<GenericTypeParamType>());
}
}

unsigned baseDepth = 0;

if (baseClassSig) {
baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1;
}
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
derivedClass->getModuleContext(), baseClass);

auto substFn = [&](SubstitutableType *type) -> Type {
auto *gp = cast<GenericTypeParamType>(type);
Expand Down Expand Up @@ -4553,7 +4553,7 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
auto genericSig = evaluateOrDefault(
evaluator,
AbstractGenericSignatureRequest{
derivedClass->getGenericSignature().getPointer(),
derivedClassSig.getPointer(),
std::move(addedGenericParams),
std::move(addedRequirements)},
GenericSignature());
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/ASTScopeLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ bool ASTScopeImpl::doesContextMatchStartingContext(
// For a SubscriptDecl with generic parameters, the call tries to do lookups
// with startingContext equal to either the get or set subscript
// AbstractFunctionDecls. Since the generic parameters are in the
// SubScriptDeclScope, and not the AbstractFunctionDecl scopes (after all how
// could one parameter be in two scopes?), GenericParamScoped intercepts the
// SubscriptDeclScope, and not the AbstractFunctionDecl scopes (after all how
// could one parameter be in two scopes?), GenericParamScope intercepts the
// match query here and tests against the accessor DeclContexts.
bool GenericParamScope::doesContextMatchStartingContext(
const DeclContext *context) const {
Expand Down
90 changes: 40 additions & 50 deletions lib/AST/ASTWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,31 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
// Decls
//===--------------------------------------------------------------------===//

bool visitGenericParamListIfNeeded(GenericContext *GC) {
// Must check this first in case extensions have not been bound yet
if (Walker.shouldWalkIntoGenericParams()) {
if (auto *params = GC->getGenericParams()) {
visitGenericParamList(params);
}
return true;
}
return false;
}

bool visitTrailingRequirements(GenericContext *GC) {
if (const auto Where = GC->getTrailingWhereClause()) {
for (auto &Req: Where->getRequirements())
if (doIt(Req))
return true;
} else if (!isa<ExtensionDecl>(GC)) {
if (const auto GP = GC->getGenericParams())
for (auto Req: GP->getTrailingRequirements())
if (doIt(Req))
return true;
}
return false;
}

bool visitImportDecl(ImportDecl *ID) {
return false;
}
Expand All @@ -138,12 +163,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
if (doIt(Inherit))
return true;
}
if (auto *Where = ED->getTrailingWhereClause()) {
for(auto &Req: Where->getRequirements()) {
if (doIt(Req))
return true;
}
}
if (visitTrailingRequirements(ED))
return true;

for (Decl *M : ED->getMembers()) {
if (doIt(M))
return true;
Expand Down Expand Up @@ -223,15 +245,13 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
}

bool visitTypeAliasDecl(TypeAliasDecl *TAD) {
if (Walker.shouldWalkIntoGenericParams() && TAD->getGenericParams()) {
if (visitGenericParamList(TAD->getGenericParams()))
return true;
}
bool WalkGenerics = visitGenericParamListIfNeeded(TAD);

if (auto typerepr = TAD->getUnderlyingTypeRepr())
if (doIt(typerepr))
return true;
return false;

return WalkGenerics && visitTrailingRequirements(TAD);
}

bool visitOpaqueTypeDecl(OpaqueTypeDecl *OTD) {
Expand Down Expand Up @@ -269,20 +289,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
}

// Visit requirements
if (WalkGenerics) {
ArrayRef<swift::RequirementRepr> Reqs = None;
if (auto *Protocol = dyn_cast<ProtocolDecl>(NTD)) {
if (auto *WhereClause = Protocol->getTrailingWhereClause())
Reqs = WhereClause->getRequirements();
} else {
Reqs = NTD->getGenericParams()->getTrailingRequirements();
}
for (auto Req: Reqs) {
if (doIt(Req))
return true;
}
}

if (WalkGenerics && visitTrailingRequirements(NTD))
return true;

for (Decl *Member : NTD->getMembers()) {
if (doIt(Member))
return true;
Expand Down Expand Up @@ -325,13 +334,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
if (doIt(SD->getElementTypeLoc()))
return true;

if (WalkGenerics) {
// Visit generic requirements
for (auto Req : SD->getGenericParams()->getTrailingRequirements()) {
if (doIt(Req))
return true;
}
}
// Visit trailing requirements
if (WalkGenerics && visitTrailingRequirements(SD))
return true;

if (!Walker.shouldWalkAccessorsTheOldWay()) {
for (auto *AD : SD->getAllAccessors())
Expand Down Expand Up @@ -364,13 +369,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
if (doIt(FD->getBodyResultTypeLoc()))
return true;

if (WalkGenerics) {
// Visit trailing requirments
for (auto Req : AFD->getGenericParams()->getTrailingRequirements()) {
if (doIt(Req))
return true;
}
}
// Visit trailing requirements
if (WalkGenerics && visitTrailingRequirements(AFD))
return true;

if (AFD->getBody(/*canSynthesize=*/false)) {
AbstractFunctionDecl::BodyKind PreservedKind = AFD->getBodyKind();
Expand Down Expand Up @@ -1323,17 +1324,6 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
}
return false;
}

private:
bool visitGenericParamListIfNeeded(GenericContext *gc) {
if (Walker.shouldWalkIntoGenericParams()) {
if (auto *params = gc->getGenericParams()) {
visitGenericParamList(params);
return true;
}
}
return false;
}
};

} // end anonymous namespace
Expand Down
9 changes: 6 additions & 3 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,9 +1077,12 @@ void GenericContext::setGenericSignature(GenericSignature genericSig) {
}

SourceRange GenericContext::getGenericTrailingWhereClauseSourceRange() const {
if (!isGeneric())
return SourceRange();
return getGenericParams()->getTrailingWhereClauseSourceRange();
if (isGeneric())
return getGenericParams()->getTrailingWhereClauseSourceRange();
else if (const auto *where = getTrailingWhereClause())
return where->getSourceRange();

return SourceRange();
}

ImportDecl *ImportDecl::create(ASTContext &Ctx, DeclContext *DC,
Expand Down
Loading