diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts index acde31016..d57c60021 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts @@ -49,16 +49,15 @@ import { SdsAbstractCall, SdsAbstractResult, SdsAssignee, - type SdsBlockLambda, SdsCall, SdsCallableType, SdsClass, SdsDeclaration, SdsExpression, - type SdsExpressionLambda, SdsFunction, SdsIndexedAccess, SdsInfixOperation, + SdsLambda, SdsLiteralType, SdsMemberAccess, SdsNamedType, @@ -126,6 +125,12 @@ export class SafeDsTypeComputer { private readonly partialEvaluator: SafeDsPartialEvaluator; private readonly typeChecker: SafeDsTypeChecker; + /** + * Contains all lambda parameters that are currently being computed. When computing the types of lambda parameters, + * they must only access the type of the containing lambda, if they are not contained in this set themselves. + * Otherwise, this would cause endless recursion. + */ + private readonly incompleteLambdaParameters = new Set(); private readonly nodeTypeCache: WorkspaceCache; constructor(services: SafeDsServices) { @@ -149,27 +154,22 @@ export class SafeDsTypeComputer { * simplified as much as possible. */ computeType(node: AstNode | undefined, substitutions: TypeParameterSubstitutions = NO_SUBSTITUTIONS): Type { - return this.computeTypeWithRecursionCheck(node, {}, substitutions); - } - - computeTypeWithRecursionCheck( - node: AstNode | undefined, - state: ComputeTypeState, - substitutions: TypeParameterSubstitutions = NO_SUBSTITUTIONS, - ): Type { if (!node) { return UnknownType; } - // Don't cache the result if we are inferring type of lambda parameter from context - if (state?.computingSubstitutionsForCall) { - return this.doComputeTypeWithRecursionCheck(node, state).simplify(); + const id = this.getNodeId(node); + + // Only cache fully substituted types + let unsubstitutedType: Type | undefined = this.nodeTypeCache.get(id); + if (!unsubstitutedType) { + unsubstitutedType = this.doComputeType(node).simplify(); + + if (unsubstitutedType.isFullySubstituted) { + this.nodeTypeCache.set(id, unsubstitutedType); + } } - // Ignore type parameter substitutions for caching - const unsubstitutedType = this.nodeTypeCache.get(this.getNodeId(node), () => - this.doComputeTypeWithRecursionCheck(node, {}).simplify(), - ); if (isEmpty(substitutions)) { return unsubstitutedType; } @@ -187,23 +187,23 @@ export class SafeDsTypeComputer { return `${documentUri}~${nodePath}`; } - private doComputeTypeWithRecursionCheck(node: AstNode | undefined, state: ComputeTypeState): Type { + private doComputeType(node: AstNode | undefined): Type { if (isSdsAssignee(node)) { - return this.computeTypeOfAssignee(node, state); + return this.computeTypeOfAssignee(node); } else if (isSdsDeclaration(node)) { - return this.computeTypeOfDeclaration(node, state); + return this.computeTypeOfDeclaration(node); } else if (isSdsExpression(node)) { - return this.computeTypeOfExpression(node, state); + return this.computeTypeOfExpression(node); } else if (isSdsType(node)) { - return this.computeTypeOfType(node, state); + return this.computeTypeOfType(node); } else if (isSdsTypeArgument(node)) { - return this.computeTypeOfType(node.value, state); + return this.computeTypeOfType(node.value); } /* c8 ignore start */ else { return UnknownType; } /* c8 ignore stop */ } - private computeTypeOfAssignee(node: SdsAssignee, state: ComputeTypeState): Type { + private computeTypeOfAssignee(node: SdsAssignee): Type { const containingAssignment = AstUtils.getContainerOfType(node, isSdsAssignment); if (!containingAssignment) { /* c8 ignore next 2 */ @@ -211,7 +211,7 @@ export class SafeDsTypeComputer { } const assigneePosition = node.$containerIndex ?? -1; - const expressionType = this.computeTypeWithRecursionCheck(containingAssignment?.expression, state); + const expressionType = this.computeType(containingAssignment?.expression); if (expressionType instanceof NamedTupleType) { return expressionType.getTypeOfEntryByIndex(assigneePosition); } else if (assigneePosition === 0) { @@ -221,10 +221,10 @@ export class SafeDsTypeComputer { return UnknownType; } - private computeTypeOfDeclaration(node: SdsDeclaration, state: ComputeTypeState): Type { + private computeTypeOfDeclaration(node: SdsDeclaration): Type { if (isSdsAnnotation(node)) { const parameterEntries = getParameters(node).map( - (it) => new NamedTupleEntry(it, it.name, this.computeTypeWithRecursionCheck(it.type, state)), + (it) => new NamedTupleEntry(it, it.name, this.computeType(it.type)), ); return this.factory.createCallableType( @@ -234,7 +234,7 @@ export class SafeDsTypeComputer { this.factory.createNamedTupleType(), ); } else if (isSdsAttribute(node)) { - return this.computeTypeWithRecursionCheck(node.type, state); + return this.computeType(node.type); } else if (isSdsClass(node)) { return this.factory.createClassType(node, NO_SUBSTITUTIONS, false); } else if (isSdsEnum(node)) { @@ -242,17 +242,17 @@ export class SafeDsTypeComputer { } else if (isSdsEnumVariant(node)) { return this.factory.createEnumVariantType(node, false); } else if (isSdsFunction(node)) { - return this.computeTypeOfCallableWithManifestTypes(node, state); + return this.computeTypeOfCallableWithManifestTypes(node); } else if (isSdsParameter(node)) { - return this.computeTypeOfParameter(node, state); + return this.computeTypeOfParameter(node); } else if (isSdsPipeline(node)) { return UnknownType; } else if (isSdsResult(node)) { - return this.computeTypeWithRecursionCheck(node.type, state); + return this.computeType(node.type); } else if (isSdsSchema(node)) { return UnknownType; } else if (isSdsSegment(node)) { - return this.computeTypeOfCallableWithManifestTypes(node, state); + return this.computeTypeOfCallableWithManifestTypes(node); } else if (isSdsTypeParameter(node)) { return this.factory.createTypeVariable(node, false); } /* c8 ignore start */ else { @@ -260,15 +260,12 @@ export class SafeDsTypeComputer { } /* c8 ignore stop */ } - private computeTypeOfCallableWithManifestTypes( - node: SdsFunction | SdsSegment | SdsCallableType, - state: ComputeTypeState, - ): Type { + private computeTypeOfCallableWithManifestTypes(node: SdsFunction | SdsSegment | SdsCallableType): Type { const parameterEntries = getParameters(node).map( - (it) => new NamedTupleEntry(it, it.name, this.computeTypeWithRecursionCheck(it.type, state)), + (it) => new NamedTupleEntry(it, it.name, this.computeType(it.type)), ); const resultEntries = getResults(node.resultList).map( - (it) => new NamedTupleEntry(it, it.name, this.computeTypeWithRecursionCheck(it.type, state)), + (it) => new NamedTupleEntry(it, it.name, this.computeType(it.type)), ); return this.factory.createCallableType( @@ -279,15 +276,15 @@ export class SafeDsTypeComputer { ); } - private computeTypeOfParameter(node: SdsParameter, state: ComputeTypeState): Type { + private computeTypeOfParameter(node: SdsParameter): Type { // Manifest type if (node.type) { - const type = this.computeTypeWithRecursionCheck(node.type, state); + const type = this.computeType(node.type); return this.rememberParameterInCallableType(node, type); } // Infer type from context - const contextType = this.computeTypeOfParameterContext(node, state); + const contextType = this.computeTypeOfParameterContext(node); if (!(contextType instanceof CallableType)) { return UnknownType; } @@ -297,7 +294,7 @@ export class SafeDsTypeComputer { return this.rememberParameterInCallableType(node, type); } - private computeTypeOfParameterContext(node: SdsParameter, state: ComputeTypeState): Type { + private computeTypeOfParameterContext(node: SdsParameter): Type { const containingCallable = AstUtils.getContainerOfType(node, isSdsCallable); if (!isSdsLambda(containingCallable)) { return UnknownType; @@ -307,32 +304,23 @@ export class SafeDsTypeComputer { // Lambda passed as argument if (isSdsArgument(containerOfLambda)) { + // Lookup parameter type in lambda unless the lambda is being computed. These contain the correct + // substitutions for type parameters. + if (!this.incompleteLambdaParameters.has(node)) { + return this.computeType(containingCallable); + } + const parameter = this.nodeMapper.argumentToParameter(containerOfLambda); if (!parameter) { return UnknownType; } - // Don't continue if we are already computing substitutions for a call. Otherwise, we would end up in an - // infinite loop. - const parameterType = this.computeTypeWithRecursionCheck(parameter, state); - if (state?.computingSubstitutionsForCall) { - return parameterType; - } - - // Compute substitutions for containing call - const containingCall = AstUtils.getContainerOfType(containerOfLambda, isSdsCall); - if (!containingCall) { - /* c8 ignore next 2 */ - return parameterType; - } - - const substitutions = this.computeSubstitutionsForCall(containingCall); - return parameterType.substituteTypeParameters(substitutions); + return this.computeType(parameter); } // Lambda passed as default value if (isSdsParameter(containerOfLambda)) { - return this.computeTypeWithRecursionCheck(containerOfLambda, state); + return this.computeType(containerOfLambda); } // Yielded lambda @@ -341,7 +329,7 @@ export class SafeDsTypeComputer { if (!isSdsYield(firstAssignee)) { return UnknownType; } - return this.computeTypeWithRecursionCheck(firstAssignee.result?.ref, state); + return this.computeType(firstAssignee.result?.ref); } return UnknownType; @@ -355,10 +343,10 @@ export class SafeDsTypeComputer { } } - private computeTypeOfExpression(node: SdsExpression, state: ComputeTypeState): Type { + private computeTypeOfExpression(node: SdsExpression): Type { // Type cast if (isSdsTypeCast(node)) { - return this.computeTypeWithRecursionCheck(node.type, state); + return this.computeType(node.type); } // Partial evaluation (definitely handles SdsBoolean, SdsFloat, SdsInt, SdsNull, and SdsString) @@ -376,15 +364,11 @@ export class SafeDsTypeComputer { // Recursive cases else if (isSdsArgument(node)) { - return this.computeTypeWithRecursionCheck(node.value, state); - } else if (isSdsBlockLambda(node)) { - return this.computeTypeOfBlockLambda(node, state); + return this.computeType(node.value); } else if (isSdsCall(node)) { - return this.computeTypeOfCall(node, state); - } else if (isSdsExpressionLambda(node)) { - return this.computeTypeOfExpressionLambda(node, state); + return this.computeTypeOfCall(node); } else if (isSdsIndexedAccess(node)) { - return this.computeTypeOfIndexedAccess(node, state); + return this.computeTypeOfIndexedAccess(node); } else if (isSdsInfixOperation(node)) { switch (node.operator) { // Boolean operators @@ -411,26 +395,24 @@ export class SafeDsTypeComputer { case '-': case '*': case '/': - return this.computeTypeOfArithmeticInfixOperation(node, state); + return this.computeTypeOfArithmeticInfixOperation(node); // Elvis operator case '?:': - return this.computeTypeOfElvisOperation(node, state); + return this.computeTypeOfElvisOperation(node); // Unknown operator /* c8 ignore next 2 */ default: return UnknownType; } + } else if (isSdsLambda(node)) { + return this.computeTypeOfLambda(node); } else if (isSdsList(node)) { - const elementType = this.lowestCommonSupertype( - node.elements.map((it) => this.computeTypeWithRecursionCheck(it, state)), - ); + const elementType = this.lowestCommonSupertype(node.elements.map((it) => this.computeType(it))); return this.coreTypes.List(elementType); } else if (isSdsMap(node)) { - let keyType = this.lowestCommonSupertype( - node.entries.map((it) => this.computeTypeWithRecursionCheck(it.key, state)), - ); + let keyType = this.lowestCommonSupertype(node.entries.map((it) => this.computeType(it.key))); // Keeping literal types for keys is too strict: We would otherwise infer the key type of `{"a": 1, "b": 2}` // as `Literal<"a", "b">`. But then we would be unable to pass an unknown `String` as the key in an indexed @@ -438,20 +420,18 @@ export class SafeDsTypeComputer { // evaluator. keyType = this.computeClassTypeForLiteralType(keyType); - const valueType = this.lowestCommonSupertype( - node.entries.map((it) => this.computeTypeWithRecursionCheck(it.value, state)), - ); + const valueType = this.lowestCommonSupertype(node.entries.map((it) => this.computeType(it.value))); return this.coreTypes.Map(keyType, valueType); } else if (isSdsMemberAccess(node)) { - return this.computeTypeOfMemberAccess(node, state); + return this.computeTypeOfMemberAccess(node); } else if (isSdsParenthesizedExpression(node)) { - return this.computeTypeWithRecursionCheck(node.expression, state); + return this.computeType(node.expression); } else if (isSdsPrefixOperation(node)) { switch (node.operator) { case 'not': return this.coreTypes.Boolean; case '-': - return this.computeTypeOfArithmeticPrefixOperation(node, state); + return this.computeTypeOfArithmeticPrefixOperation(node); // Unknown operator /* c8 ignore next 2 */ @@ -459,32 +439,16 @@ export class SafeDsTypeComputer { return UnknownType; } } else if (isSdsReference(node)) { - return this.computeTypeOfReference(node, state); + return this.computeTypeOfReference(node); } else if (isSdsThis(node)) { - return this.computeTypeOfThis(node, state); + return this.computeTypeOfThis(node); } /* c8 ignore start */ else { return UnknownType; } /* c8 ignore stop */ } - private computeTypeOfBlockLambda(node: SdsBlockLambda, state: ComputeTypeState): Type { - const parameterEntries = getParameters(node).map( - (it) => new NamedTupleEntry(it, it.name, this.computeTypeWithRecursionCheck(it, state)), - ); - const resultEntries = streamBlockLambdaResults(node) - .map((it) => new NamedTupleEntry(it, it.name, this.computeTypeWithRecursionCheck(it, state))) - .toArray(); - - return this.factory.createCallableType( - node, - undefined, - this.factory.createNamedTupleType(...parameterEntries), - this.factory.createNamedTupleType(...resultEntries), - ); - } - - private computeTypeOfCall(node: SdsCall, state: ComputeTypeState): Type { - const receiverType = this.computeTypeWithRecursionCheck(node.receiver, state); + private computeTypeOfCall(node: SdsCall): Type { + const receiverType = this.computeType(node.receiver); const nonNullableReceiverType = this.computeNonNullableType(receiverType); let result: Type = UnknownType; @@ -520,28 +484,8 @@ export class SafeDsTypeComputer { return result.withExplicitNullability(receiverType.isExplicitlyNullable && node.isNullSafe); } - private computeTypeOfExpressionLambda(node: SdsExpressionLambda, state: ComputeTypeState): Type { - const parameterEntries = getParameters(node).map( - (it) => new NamedTupleEntry(it, it.name, this.computeTypeWithRecursionCheck(it, state)), - ); - const resultEntries = [ - new NamedTupleEntry( - undefined, - 'result', - this.computeTypeWithRecursionCheck(node.result, state), - ), - ]; - - return this.factory.createCallableType( - node, - undefined, - this.factory.createNamedTupleType(...parameterEntries), - this.factory.createNamedTupleType(...resultEntries), - ); - } - - private computeTypeOfIndexedAccess(node: SdsIndexedAccess, state: ComputeTypeState): Type { - const receiverType = this.computeTypeWithRecursionCheck(node.receiver, state); + private computeTypeOfIndexedAccess(node: SdsIndexedAccess): Type { + const receiverType = this.computeType(node.receiver); if (!(receiverType instanceof ClassType) && !(receiverType instanceof TypeVariable)) { return UnknownType; } @@ -565,9 +509,9 @@ export class SafeDsTypeComputer { return UnknownType; } - private computeTypeOfArithmeticInfixOperation(node: SdsInfixOperation, state: ComputeTypeState): Type { - const leftOperandType = this.computeTypeWithRecursionCheck(node.leftOperand, state); - const rightOperandType = this.computeTypeWithRecursionCheck(node.rightOperand, state); + private computeTypeOfArithmeticInfixOperation(node: SdsInfixOperation): Type { + const leftOperandType = this.computeType(node.leftOperand); + const rightOperandType = this.computeType(node.rightOperand); if ( this.typeChecker.isSubtypeOf(leftOperandType, this.coreTypes.Int) && @@ -579,18 +523,56 @@ export class SafeDsTypeComputer { } } - private computeTypeOfElvisOperation(node: SdsInfixOperation, state: ComputeTypeState): Type { - const leftOperandType = this.computeTypeWithRecursionCheck(node.leftOperand, state); + private computeTypeOfElvisOperation(node: SdsInfixOperation): Type { + const leftOperandType = this.computeType(node.leftOperand); if (leftOperandType.isExplicitlyNullable) { - const rightOperandType = this.computeTypeWithRecursionCheck(node.rightOperand, state); + const rightOperandType = this.computeType(node.rightOperand); return this.lowestCommonSupertype([leftOperandType.withExplicitNullability(false), rightOperandType]); } else { return leftOperandType; } } - private computeTypeOfMemberAccess(node: SdsMemberAccess, state: ComputeTypeState) { - const memberType = this.computeTypeWithRecursionCheck(node.member, state); + private computeTypeOfLambda(node: SdsLambda): Type { + // Remember lambda parameters + const parameters = getParameters(node); + parameters.forEach((it) => { + this.incompleteLambdaParameters.add(it); + }); + + const parameterEntries = parameters.map((it) => new NamedTupleEntry(it, it.name, this.computeType(it))); + const resultEntries = this.buildLambdaResultEntries(node); + + const unsubstitutedType = this.factory.createCallableType( + node, + undefined, + this.factory.createNamedTupleType(...parameterEntries), + this.factory.createNamedTupleType(...resultEntries), + ); + const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType); + + // Forget lambda parameters + parameters.forEach((it) => { + this.incompleteLambdaParameters.delete(it); + }); + + return unsubstitutedType.substituteTypeParameters(substitutions); + } + + private buildLambdaResultEntries(node: SdsLambda): NamedTupleEntry[] { + if (isSdsExpressionLambda(node)) { + return [new NamedTupleEntry(undefined, 'result', this.computeType(node.result))]; + } else if (isSdsBlockLambda(node)) { + return streamBlockLambdaResults(node) + .map((it) => new NamedTupleEntry(it, it.name, this.computeType(it))) + .toArray(); + } /* c8 ignore start */ else { + return []; + } /* c8 ignore stop */ + } + + private computeTypeOfMemberAccess(node: SdsMemberAccess) { + const memberType = this.computeType(node.member); // A member access of an enum variant without parameters always yields an instance, even if it is not in a call if (memberType instanceof StaticType && !isSdsCall(node.$container)) { @@ -601,7 +583,7 @@ export class SafeDsTypeComputer { } } - const receiverType = this.computeTypeWithRecursionCheck(node.receiver, state); + const receiverType = this.computeType(node.receiver); let result: Type = memberType; // Substitute type parameters @@ -614,8 +596,8 @@ export class SafeDsTypeComputer { ); } - private computeTypeOfArithmeticPrefixOperation(node: SdsPrefixOperation, state: ComputeTypeState): Type { - const operandType = this.computeTypeWithRecursionCheck(node.operand, state); + private computeTypeOfArithmeticPrefixOperation(node: SdsPrefixOperation): Type { + const operandType = this.computeType(node.operand); if (this.typeChecker.isSubtypeOf(operandType, this.coreTypes.Int)) { return this.coreTypes.Int; @@ -624,9 +606,9 @@ export class SafeDsTypeComputer { } } - private computeTypeOfReference(node: SdsReference, state: ComputeTypeState): Type { + private computeTypeOfReference(node: SdsReference): Type { const target = node.target.ref; - const instanceType = this.computeTypeWithRecursionCheck(target, state); + const instanceType = this.computeType(target); if (isSdsNamedTypeDeclaration(target) && instanceType instanceof NamedType) { return this.factory.createStaticType(instanceType.withExplicitNullability(false)); @@ -635,11 +617,11 @@ export class SafeDsTypeComputer { } } - private computeTypeOfThis(node: SdsThis, state: ComputeTypeState): Type { + private computeTypeOfThis(node: SdsThis): Type { // If closest callable is a class, return the class type const containingCallable = AstUtils.getContainerOfType(node, isSdsCallable); if (isSdsClass(containingCallable)) { - return this.computeTypeWithRecursionCheck(containingCallable, state); + return this.computeType(containingCallable); } // Invalid if the callable is not a class member or static @@ -649,22 +631,22 @@ export class SafeDsTypeComputer { // Otherwise, return the type of the containing class or unknown if not in a class const containingClass = AstUtils.getContainerOfType(containingCallable, isSdsClass); - return this.computeTypeWithRecursionCheck(containingClass, state); + return this.computeType(containingClass); } - private computeTypeOfType(node: SdsType, state: ComputeTypeState): Type { + private computeTypeOfType(node: SdsType): Type { if (isSdsCallableType(node)) { - return this.computeTypeOfCallableWithManifestTypes(node, state); + return this.computeTypeOfCallableWithManifestTypes(node); } else if (isSdsLiteralType(node)) { return this.computeTypeOfLiteralType(node); } else if (isSdsMemberType(node)) { - return this.computeTypeWithRecursionCheck(node.member, state); + return this.computeType(node.member); } else if (isSdsNamedType(node)) { - return this.computeTypeOfNamedType(node, state); + return this.computeTypeOfNamedType(node); } else if (isSdsUnionType(node)) { const typeArguments = getTypeArguments(node.typeArgumentList); return this.factory.createUnionType( - ...typeArguments.map((typeArgument) => this.computeTypeWithRecursionCheck(typeArgument.value, state)), + ...typeArguments.map((typeArgument) => this.computeType(typeArgument.value)), ); } else if (isSdsUnknownType(node)) { return UnknownType; @@ -682,11 +664,8 @@ export class SafeDsTypeComputer { } /* c8 ignore stop */ } - private computeTypeOfNamedType(node: SdsNamedType, state: ComputeTypeState) { - const unparameterizedType = this.computeTypeWithRecursionCheck( - node.declaration?.ref, - state, - ).withExplicitNullability(node.isNullable); + private computeTypeOfNamedType(node: SdsNamedType) { + const unparameterizedType = this.computeType(node.declaration?.ref).withExplicitNullability(node.isNullable); if (!(unparameterizedType instanceof ClassType)) { return unparameterizedType; } @@ -819,6 +798,13 @@ export class SafeDsTypeComputer { * @returns The computed substitutions for the type parameters of the callable. */ computeSubstitutionsForCall(node: SdsAbstractCall): TypeParameterSubstitutions { + return this.doComputeSubstitutionsForCall(node); + } + + private doComputeSubstitutionsForCall( + node: SdsAbstractCall, + precomputedArgumentTypes?: Map, + ): TypeParameterSubstitutions { // Compute substitutions for member access const substitutionsFromReceiver = isSdsCall(node) && isSdsMemberAccess(node.receiver) @@ -840,11 +826,12 @@ export class SafeDsTypeComputer { const argument = parametersToArguments.get(parameter); return [ this.computeType(parameter.type), - this.computeTypeWithRecursionCheck(argument?.value ?? parameter.defaultValue, { - computingSubstitutionsForCall: true, - }), + // Use precomputed argument types (lambdas) if available. This prevents infinite recursion. + precomputedArgumentTypes?.get(argument?.value) ?? + this.computeType(argument?.value ?? parameter.defaultValue), ]; }); + const substitutionsFromArguments = this.computeSubstitutionsForArguments( typeParameters, parameterTypesToArgumentTypes, @@ -882,6 +869,22 @@ export class SafeDsTypeComputer { return this.computeSubstitutionsForArguments(ownTypeParameters, ownTypesToOverriddenTypes); } + private computeSubstitutionsForLambda(node: SdsLambda, unsubstitutedType: Type): TypeParameterSubstitutions { + const containerOfLambda = node.$container; + if (!isSdsArgument(containerOfLambda)) { + return NO_SUBSTITUTIONS; + } + + const containingCall = AstUtils.getContainerOfType(containerOfLambda, isSdsCall); + if (!containingCall) { + /* c8 ignore next 2 */ + return NO_SUBSTITUTIONS; + } + + const precomputedArgumentTypes = new Map([[node, unsubstitutedType]]); + return this.doComputeSubstitutionsForCall(containingCall, precomputedArgumentTypes); + } + private computeSubstitutionsForMemberAccess(node: SdsMemberAccess): TypeParameterSubstitutions { const receiverType = this.computeType(node.receiver); if (receiverType instanceof ClassType) { @@ -1705,18 +1708,6 @@ interface ComputeUpperBoundOptions { stopAtTypeVariable?: boolean; } -interface ComputeTypeState { - /** - * Indicates that we are currently computing substitutions for a call. This is used to avoid infinite recursion: - * - * 1. The type of the lambda parameter gets inferred from the context. If the lambda is passed as an argument, the - * result might include type parameters. - * 2. Substitutions are computed for the type parameters from the context (i.e. the call). This involves - * computing the type of the lambda parameter. - */ - computingSubstitutionsForCall?: boolean; -} - interface ComputeSubstitutionsForParametersState { substitutions: TypeParameterSubstitutions; remainingVariances: Map; diff --git a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev index 5501690de..337e11215 100644 --- a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev +++ b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev @@ -10,6 +10,8 @@ class MyClass(param: T) sub MySuperclass { @Pure fun myFunction(p: T, callback: (p: T) -> ()) +@Pure fun myFunction2(callback: (p: T) -> (r: T)) + segment mySegment() { // $TEST$ serialization literal<1> MyClass(1).myMethod((»p«) {}); @@ -19,4 +21,7 @@ segment mySegment() { // $TEST$ serialization literal<1> myFunction(1, (»p«) {}); + + // $TEST$ serialization literal<""> + myFunction2((»p«) -> ""); } diff --git a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev index e63797446..e1cd35145 100644 --- a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev +++ b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev @@ -10,6 +10,8 @@ class MyClass(param: T) sub MySuperclass { @Pure fun myFunction(p: T, callback: (p: T) -> ()) +@Pure fun myFunction2(callback: (p: T) -> (r: T)) + segment mySegment() { // $TEST$ serialization literal<1> MyClass(1).myMethod((»p«) -> ""); @@ -19,4 +21,7 @@ segment mySegment() { // $TEST$ serialization literal<1> myFunction(1, (»p«) -> ""); + + // $TEST$ serialization literal<""> + myFunction2((»p«) -> ""); } diff --git a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/scope provider should not cause infinite recursion/main.sdsdev b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/scope provider should not cause infinite recursion/main.sdsdev new file mode 100644 index 000000000..1dd0d0c3b --- /dev/null +++ b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/scope provider should not cause infinite recursion/main.sdsdev @@ -0,0 +1,30 @@ +package tests.typing.declarations.parameters.scopeProviderShouldNotCauseInfiniteRecursion + +/* + * This test is related to the first attempt to fix infinite recursion when computing the type of a lambda parameter. + * There, we passed a flag around as a parameter inside the type computer to indicate that we are currently computing + * type arguments for a call. + * + * However, this approach was not sufficient to fix the issue, since the type computer (implicitly) calls the scope + * provider, which in turn calls the type computer again. The flag was lost along the way, again opening the door for + * infinite recursion. + */ + +class MyCell { + @Pure + fun ^not() -> result: MyCell +} + +class MyColumn(values: List) { + @Pure + fun transform( + transformer: (cell: MyCell) -> (transformedCell: MyCell) + ) -> result: MyColumn +} + +pipeline myPipelines { + val column = MyColumn([1, 2, 3]); + + // $TEST$ serialization MyColumn + val »transformedColumn« = column.transform((cell) -> cell.^not()); +}