Skip to content

Commit

Permalink
Merge pull request #23489 from AnthonyLatsis/where-clause-nongeneric-…
Browse files Browse the repository at this point in the history
…decl

[SE] Allow where clauses on non-generic declarations in generic contexts
  • Loading branch information
slavapestov authored Mar 6, 2020
2 parents 0bbd8de + f762644 commit adbf8da
Show file tree
Hide file tree
Showing 24 changed files with 692 additions and 228 deletions.
4 changes: 3 additions & 1 deletion include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,9 @@ class RequirementRepr {
void print(raw_ostream &OS) const;
void print(ASTPrinter &Printer) const;
};


using GenericParamSource = PointerUnion<GenericContext *, GenericParamList *>;

/// GenericParamList - A list of generic parameters that is part of a generic
/// function or type, along with extra requirements placed on those generic
/// parameters and types derived from them.
Expand Down
7 changes: 3 additions & 4 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1650,10 +1650,9 @@ ERROR(redundant_class_requirement,none,
"redundant 'class' requirement", ())
ERROR(late_class_requirement,none,
"'class' must come first in the requirement list", ())
ERROR(where_without_generic_params,none,
"'where' clause cannot be attached to "
"%select{a non-generic|a protocol|an associated type}0 "
"declaration", (unsigned))
ERROR(where_toplevel_nongeneric,none,
"'where' clause cannot be attached to non-generic "
"top-level declaration", ())
ERROR(where_inside_brackets,none,
"'where' clause next to generic parameters is obsolete, "
"must be written following the declaration's type", ())
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2751,6 +2751,10 @@ ERROR(dynamic_self_stored_property_init,none,
ERROR(dynamic_self_default_arg,none,
"covariant 'Self' type cannot be referenced from a default argument expression", ())

ERROR(where_nongeneric_ctx,none,
"'where' clause on non-generic member declaration requires a "
"generic context", ())

//------------------------------------------------------------------------------
// MARK: Type Check Attributes
//------------------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ class InferredGenericSignatureRequest :
public SimpleRequest<InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *,
GenericSignatureImpl *,
GenericParamList *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>,
bool),
Expand All @@ -1124,7 +1124,7 @@ class InferredGenericSignatureRequest :
evaluate(Evaluator &evaluator,
ModuleDecl *module,
GenericSignatureImpl *baseSignature,
GenericParamList *gpl,
GenericParamSource paramSource,
SmallVector<Requirement, 2> addedRequirements,
SmallVector<TypeLoc, 2> inferenceSources,
bool allowConcreteGenericParams) const;
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ SWIFT_REQUEST(TypeChecker, HasDynamicMemberLookupAttributeRequest,
bool(CanType), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *, GenericSignatureImpl *,
GenericParamList *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>, bool),
Cached, NoLocationInfo)
Expand Down
9 changes: 9 additions & 0 deletions include/swift/Basic/SimpleDisplay.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ namespace swift {
}
out << "}";
}

template<typename T, typename U>
void simple_display(llvm::raw_ostream &out,
const llvm::PointerUnion<T, U> &ptrUnion) {
if (const auto t = ptrUnion.template dyn_cast<T>())
simple_display(out, t);
else
simple_display(out, ptrUnion.template get<U>());
}
}

#endif // SWIFT_BASIC_SIMPLE_DISPLAY_H
12 changes: 4 additions & 8 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ class Parser {
bool allowClassRequirement,
bool allowAnyObject);
ParserStatus parseDeclItem(bool &PreviousHadSemi,
Parser::ParseDeclOptions Options,
ParseDeclOptions Options,
llvm::function_ref<void(Decl*)> handler);
std::pair<std::vector<Decl *>, Optional<std::string>>
parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag,
Expand Down Expand Up @@ -1637,14 +1637,10 @@ class Parser {
void
diagnoseWhereClauseInGenericParamList(const GenericParamList *GenericParams);

enum class WhereClauseKind : unsigned {
Declaration,
Protocol,
AssociatedType
};
ParserStatus
parseFreestandingGenericWhereClause(GenericParamList *GPList,
WhereClauseKind kind=WhereClauseKind::Declaration);
parseFreestandingGenericWhereClause(GenericContext *genCtx,
GenericParamList *&GPList,
ParseDeclOptions flags);

ParserStatus parseGenericWhereClause(
SourceLoc &WhereLoc, SmallVectorImpl<RequirementRepr> &Requirements,
Expand Down
40 changes: 20 additions & 20 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4470,19 +4470,22 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
assert(isa<AbstractFunctionDecl>(base) || isa<SubscriptDecl>(base));
assert(isa<AbstractFunctionDecl>(derived) || isa<SubscriptDecl>(derived));

auto baseClass = base->getDeclContext()->getSelfClassDecl();
auto derivedClass = derived->getDeclContext()->getSelfClassDecl();
const auto baseClass = base->getDeclContext()->getSelfClassDecl();
const auto derivedClass = derived->getDeclContext()->getSelfClassDecl();

assert(baseClass != nullptr);
assert(derivedClass != nullptr);

auto baseGenericSig = base->getAsGenericContext()->getGenericSignature();
auto derivedGenericSig = derived->getAsGenericContext()->getGenericSignature();
const auto baseGenericSig =
base->getAsGenericContext()->getGenericSignature();
const auto derivedGenericSig =
derived->getAsGenericContext()->getGenericSignature();

if (base == derived)
return derivedGenericSig;

if (derivedClass->getSuperclass().isNull())
const auto derivedSuperclass = derivedClass->getSuperclass();
if (derivedSuperclass.isNull())
return nullptr;

if (derivedGenericSig.isNull())
Expand All @@ -4491,12 +4494,6 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
if (baseGenericSig.isNull())
return derivedGenericSig;

auto baseClassSig = baseClass->getGenericSignature();
auto subMap = derivedClass->getSuperclass()->getContextSubstitutionMap(
derivedClass->getModuleContext(), baseClass);

unsigned derivedDepth = 0;

auto key = OverrideSignatureKey(baseGenericSig,
derivedGenericSig,
derivedClass);
Expand All @@ -4506,22 +4503,25 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
return getImpl().overrideSigCache.lookup(key);
}

if (auto derivedSig = derivedClass->getGenericSignature())
derivedDepth = derivedSig->getGenericParams().back()->getDepth() + 1;
const auto derivedClassSig = derivedClass->getGenericSignature();

unsigned derivedDepth = 0;
unsigned baseDepth = 0;
if (derivedClassSig)
derivedDepth = derivedClassSig->getGenericParams().back()->getDepth() + 1;
if (const auto baseClassSig = baseClass->getGenericSignature())
baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1;

SmallVector<GenericTypeParamType *, 2> addedGenericParams;
if (auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
for (auto gp : *gpList) {
addedGenericParams.push_back(
gp->getDeclaredInterfaceType()->castTo<GenericTypeParamType>());
}
}

unsigned baseDepth = 0;

if (baseClassSig) {
baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1;
}
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
derivedClass->getModuleContext(), baseClass);

auto substFn = [&](SubstitutableType *type) -> Type {
auto *gp = cast<GenericTypeParamType>(type);
Expand Down Expand Up @@ -4553,7 +4553,7 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
auto genericSig = evaluateOrDefault(
evaluator,
AbstractGenericSignatureRequest{
derivedClass->getGenericSignature().getPointer(),
derivedClassSig.getPointer(),
std::move(addedGenericParams),
std::move(addedRequirements)},
GenericSignature());
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/ASTScopeLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ bool ASTScopeImpl::doesContextMatchStartingContext(
// For a SubscriptDecl with generic parameters, the call tries to do lookups
// with startingContext equal to either the get or set subscript
// AbstractFunctionDecls. Since the generic parameters are in the
// SubScriptDeclScope, and not the AbstractFunctionDecl scopes (after all how
// could one parameter be in two scopes?), GenericParamScoped intercepts the
// SubscriptDeclScope, and not the AbstractFunctionDecl scopes (after all how
// could one parameter be in two scopes?), GenericParamScope intercepts the
// match query here and tests against the accessor DeclContexts.
bool GenericParamScope::doesContextMatchStartingContext(
const DeclContext *context) const {
Expand Down
90 changes: 40 additions & 50 deletions lib/AST/ASTWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,31 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
// Decls
//===--------------------------------------------------------------------===//

bool visitGenericParamListIfNeeded(GenericContext *GC) {
// Must check this first in case extensions have not been bound yet
if (Walker.shouldWalkIntoGenericParams()) {
if (auto *params = GC->getGenericParams()) {
visitGenericParamList(params);
}
return true;
}
return false;
}

bool visitTrailingRequirements(GenericContext *GC) {
if (const auto Where = GC->getTrailingWhereClause()) {
for (auto &Req: Where->getRequirements())
if (doIt(Req))
return true;
} else if (!isa<ExtensionDecl>(GC)) {
if (const auto GP = GC->getGenericParams())
for (auto Req: GP->getTrailingRequirements())
if (doIt(Req))
return true;
}
return false;
}

bool visitImportDecl(ImportDecl *ID) {
return false;
}
Expand All @@ -138,12 +163,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
if (doIt(Inherit))
return true;
}
if (auto *Where = ED->getTrailingWhereClause()) {
for(auto &Req: Where->getRequirements()) {
if (doIt(Req))
return true;
}
}
if (visitTrailingRequirements(ED))
return true;

for (Decl *M : ED->getMembers()) {
if (doIt(M))
return true;
Expand Down Expand Up @@ -223,15 +245,13 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
}

bool visitTypeAliasDecl(TypeAliasDecl *TAD) {
if (Walker.shouldWalkIntoGenericParams() && TAD->getGenericParams()) {
if (visitGenericParamList(TAD->getGenericParams()))
return true;
}
bool WalkGenerics = visitGenericParamListIfNeeded(TAD);

if (auto typerepr = TAD->getUnderlyingTypeRepr())
if (doIt(typerepr))
return true;
return false;

return WalkGenerics && visitTrailingRequirements(TAD);
}

bool visitOpaqueTypeDecl(OpaqueTypeDecl *OTD) {
Expand Down Expand Up @@ -269,20 +289,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
}

// Visit requirements
if (WalkGenerics) {
ArrayRef<swift::RequirementRepr> Reqs = None;
if (auto *Protocol = dyn_cast<ProtocolDecl>(NTD)) {
if (auto *WhereClause = Protocol->getTrailingWhereClause())
Reqs = WhereClause->getRequirements();
} else {
Reqs = NTD->getGenericParams()->getTrailingRequirements();
}
for (auto Req: Reqs) {
if (doIt(Req))
return true;
}
}

if (WalkGenerics && visitTrailingRequirements(NTD))
return true;

for (Decl *Member : NTD->getMembers()) {
if (doIt(Member))
return true;
Expand Down Expand Up @@ -325,13 +334,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
if (doIt(SD->getElementTypeLoc()))
return true;

if (WalkGenerics) {
// Visit generic requirements
for (auto Req : SD->getGenericParams()->getTrailingRequirements()) {
if (doIt(Req))
return true;
}
}
// Visit trailing requirements
if (WalkGenerics && visitTrailingRequirements(SD))
return true;

if (!Walker.shouldWalkAccessorsTheOldWay()) {
for (auto *AD : SD->getAllAccessors())
Expand Down Expand Up @@ -364,13 +369,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
if (doIt(FD->getBodyResultTypeLoc()))
return true;

if (WalkGenerics) {
// Visit trailing requirments
for (auto Req : AFD->getGenericParams()->getTrailingRequirements()) {
if (doIt(Req))
return true;
}
}
// Visit trailing requirements
if (WalkGenerics && visitTrailingRequirements(AFD))
return true;

if (AFD->getBody(/*canSynthesize=*/false)) {
AbstractFunctionDecl::BodyKind PreservedKind = AFD->getBodyKind();
Expand Down Expand Up @@ -1323,17 +1324,6 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
}
return false;
}

private:
bool visitGenericParamListIfNeeded(GenericContext *gc) {
if (Walker.shouldWalkIntoGenericParams()) {
if (auto *params = gc->getGenericParams()) {
visitGenericParamList(params);
return true;
}
}
return false;
}
};

} // end anonymous namespace
Expand Down
9 changes: 6 additions & 3 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,9 +1077,12 @@ void GenericContext::setGenericSignature(GenericSignature genericSig) {
}

SourceRange GenericContext::getGenericTrailingWhereClauseSourceRange() const {
if (!isGeneric())
return SourceRange();
return getGenericParams()->getTrailingWhereClauseSourceRange();
if (isGeneric())
return getGenericParams()->getTrailingWhereClauseSourceRange();
else if (const auto *where = getTrailingWhereClause())
return where->getSourceRange();

return SourceRange();
}

ImportDecl *ImportDecl::create(ASTContext &Ctx, DeclContext *DC,
Expand Down
Loading

0 comments on commit adbf8da

Please sign in to comment.