diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts index c772e8bd5..d57c60021 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts @@ -49,13 +49,11 @@ import { SdsAbstractCall, SdsAbstractResult, SdsAssignee, - type SdsBlockLambda, SdsCall, SdsCallableType, SdsClass, SdsDeclaration, SdsExpression, - type SdsExpressionLambda, SdsFunction, SdsIndexedAccess, SdsInfixOperation, @@ -127,6 +125,12 @@ export class SafeDsTypeComputer { private readonly partialEvaluator: SafeDsPartialEvaluator; private readonly typeChecker: SafeDsTypeChecker; + /** + * Contains all lambda parameters that are currently being computed. When computing the types of lambda parameters, + * they must only access the type of the containing lambda, if they are not contained in this set themselves. + * Otherwise, this would cause endless recursion. + */ + private readonly incompleteLambdaParameters = new Set(); private readonly nodeTypeCache: WorkspaceCache; constructor(services: SafeDsServices) { @@ -154,10 +158,18 @@ export class SafeDsTypeComputer { return UnknownType; } - // Ignore type parameter substitutions for caching - const unsubstitutedType = this.nodeTypeCache.get(this.getNodeId(node), () => - this.doComputeType(node).simplify(), - ); + const id = this.getNodeId(node); + + // Only cache fully substituted types + let unsubstitutedType: Type | undefined = this.nodeTypeCache.get(id); + if (!unsubstitutedType) { + unsubstitutedType = this.doComputeType(node).simplify(); + + if (unsubstitutedType.isFullySubstituted) { + this.nodeTypeCache.set(id, unsubstitutedType); + } + } + if (isEmpty(substitutions)) { return unsubstitutedType; } @@ -292,6 +304,12 @@ export class SafeDsTypeComputer { // Lambda passed as argument if (isSdsArgument(containerOfLambda)) { + // Lookup parameter type in lambda unless the lambda is being computed. These contain the correct + // substitutions for type parameters. + if (!this.incompleteLambdaParameters.has(node)) { + return this.computeType(containingCallable); + } + const parameter = this.nodeMapper.argumentToParameter(containerOfLambda); if (!parameter) { return UnknownType; @@ -347,12 +365,8 @@ export class SafeDsTypeComputer { // Recursive cases else if (isSdsArgument(node)) { return this.computeType(node.value); - } else if (isSdsBlockLambda(node)) { - return this.computeTypeOfBlockLambda(node); } else if (isSdsCall(node)) { return this.computeTypeOfCall(node); - } else if (isSdsExpressionLambda(node)) { - return this.computeTypeOfExpressionLambda(node); } else if (isSdsIndexedAccess(node)) { return this.computeTypeOfIndexedAccess(node); } else if (isSdsInfixOperation(node)) { @@ -392,6 +406,8 @@ export class SafeDsTypeComputer { default: return UnknownType; } + } else if (isSdsLambda(node)) { + return this.computeTypeOfLambda(node); } else if (isSdsList(node)) { const elementType = this.lowestCommonSupertype(node.elements.map((it) => this.computeType(it))); return this.coreTypes.List(elementType); @@ -431,24 +447,6 @@ export class SafeDsTypeComputer { } /* c8 ignore stop */ } - private computeTypeOfBlockLambda(node: SdsBlockLambda): Type { - const parameterEntries = getParameters(node).map( - (it) => new NamedTupleEntry(it, it.name, this.computeType(it)), - ); - const resultEntries = streamBlockLambdaResults(node) - .map((it) => new NamedTupleEntry(it, it.name, this.computeType(it))) - .toArray(); - - const unsubstitutedType = this.factory.createCallableType( - node, - undefined, - this.factory.createNamedTupleType(...parameterEntries), - this.factory.createNamedTupleType(...resultEntries), - ); - const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType); - return unsubstitutedType.substituteTypeParameters(substitutions); - } - private computeTypeOfCall(node: SdsCall): Type { const receiverType = this.computeType(node.receiver); const nonNullableReceiverType = this.computeNonNullableType(receiverType); @@ -486,24 +484,6 @@ export class SafeDsTypeComputer { return result.withExplicitNullability(receiverType.isExplicitlyNullable && node.isNullSafe); } - private computeTypeOfExpressionLambda(node: SdsExpressionLambda): Type { - const parameterEntries = getParameters(node).map( - (it) => new NamedTupleEntry(it, it.name, this.computeType(it)), - ); - const resultEntries = [ - new NamedTupleEntry(undefined, 'result', this.computeType(node.result)), - ]; - - const unsubstitutedType = this.factory.createCallableType( - node, - undefined, - this.factory.createNamedTupleType(...parameterEntries), - this.factory.createNamedTupleType(...resultEntries), - ); - const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType); - return unsubstitutedType.substituteTypeParameters(substitutions); - } - private computeTypeOfIndexedAccess(node: SdsIndexedAccess): Type { const receiverType = this.computeType(node.receiver); if (!(receiverType instanceof ClassType) && !(receiverType instanceof TypeVariable)) { @@ -553,6 +533,44 @@ export class SafeDsTypeComputer { } } + private computeTypeOfLambda(node: SdsLambda): Type { + // Remember lambda parameters + const parameters = getParameters(node); + parameters.forEach((it) => { + this.incompleteLambdaParameters.add(it); + }); + + const parameterEntries = parameters.map((it) => new NamedTupleEntry(it, it.name, this.computeType(it))); + const resultEntries = this.buildLambdaResultEntries(node); + + const unsubstitutedType = this.factory.createCallableType( + node, + undefined, + this.factory.createNamedTupleType(...parameterEntries), + this.factory.createNamedTupleType(...resultEntries), + ); + const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType); + + // Forget lambda parameters + parameters.forEach((it) => { + this.incompleteLambdaParameters.delete(it); + }); + + return unsubstitutedType.substituteTypeParameters(substitutions); + } + + private buildLambdaResultEntries(node: SdsLambda): NamedTupleEntry[] { + if (isSdsExpressionLambda(node)) { + return [new NamedTupleEntry(undefined, 'result', this.computeType(node.result))]; + } else if (isSdsBlockLambda(node)) { + return streamBlockLambdaResults(node) + .map((it) => new NamedTupleEntry(it, it.name, this.computeType(it))) + .toArray(); + } /* c8 ignore start */ else { + return []; + } /* c8 ignore stop */ + } + private computeTypeOfMemberAccess(node: SdsMemberAccess) { const memberType = this.computeType(node.member);