diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts index 70b4f0f6a..9660e7cd1 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts @@ -375,6 +375,21 @@ export class SafeDsTypeChecker { } }; + /** + * Returns whether {@link type} can be `null`. Compared to {@link Type.isNullable}, this method also considers the + * upper bound of type parameter types. + */ + canBeNull = (type: Type): boolean => { + if (type.isNullable) { + return true; + } else if (type instanceof TypeParameterType) { + const upperBound = this.typeComputer().computeUpperBound(type); + return upperBound.isNullable; + } else { + return false; + } + }; + /** * Checks whether {@link type} is allowed as the type of a constant parameter. */ diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts index 424cb8d1c..73a625dcf 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts @@ -791,7 +791,7 @@ export class SafeDsTypeComputer { * If invalid lower bounds are specified (e.g. because of an unresolved reference or a cycle), `$unknown` is * returned. The result is simplified as much as possible. */ - computeLowerBound(nodeOrType: SdsTypeParameter | TypeParameterType): Type { + computeLowerBound(nodeOrType: SdsTypeParameter | TypeParameterType, options: ComputeBoundOptions = {}): Type { let type: TypeParameterType; if (nodeOrType instanceof TypeParameterType) { type = nodeOrType; @@ -799,10 +799,14 @@ export class SafeDsTypeComputer { type = this.computeType(nodeOrType) as TypeParameterType; } - return this.doComputeLowerBound(type, new Set()); + return this.doComputeLowerBound(type, new Set(), options); } - private doComputeLowerBound(type: TypeParameterType, visited: Set): Type { + private doComputeLowerBound( + type: TypeParameterType, + visited: Set, + options: ComputeBoundOptions, + ): Type { // Check for cycles if (visited.has(type.declaration)) { return UnknownType; @@ -817,10 +821,10 @@ export class SafeDsTypeComputer { const boundType = this.computeLowerBoundType(lowerBounds[0]!); if (!(boundType instanceof NamedType)) { return UnknownType; - } else if (!(boundType instanceof TypeParameterType)) { + } else if (options.stopAtTypeParameterType || !(boundType instanceof TypeParameterType)) { return boundType; } else { - return this.doComputeLowerBound(boundType, visited); + return this.doComputeLowerBound(boundType, visited, options); } } @@ -837,7 +841,7 @@ export class SafeDsTypeComputer { * invalid upper bounds are specified, but are invalid (e.g. because of an unresolved reference or a cycle), * `$unknown` is returned. The result is simplified as much as possible. */ - computeUpperBound(nodeOrType: SdsTypeParameter | TypeParameterType): Type { + computeUpperBound(nodeOrType: SdsTypeParameter | TypeParameterType, options: ComputeBoundOptions = {}): Type { let type: TypeParameterType; if (nodeOrType instanceof TypeParameterType) { type = nodeOrType; @@ -845,11 +849,15 @@ export class SafeDsTypeComputer { type = this.computeType(nodeOrType) as TypeParameterType; } - const result = this.doComputeUpperBound(type, new Set()); + const result = this.doComputeUpperBound(type, new Set(), options); return result.updateNullability(result.isNullable || type.isNullable); } - private doComputeUpperBound(type: TypeParameterType, visited: Set): Type { + private doComputeUpperBound( + type: TypeParameterType, + visited: Set, + options: ComputeBoundOptions, + ): Type { // Check for cycles if (visited.has(type.declaration)) { return UnknownType; @@ -864,10 +872,10 @@ export class SafeDsTypeComputer { const boundType = this.computeUpperBoundType(upperBounds[0]!); if (!(boundType instanceof NamedType)) { return UnknownType; - } else if (!(boundType instanceof TypeParameterType)) { + } else if (options.stopAtTypeParameterType || !(boundType instanceof TypeParameterType)) { return boundType; } else { - return this.doComputeUpperBound(boundType, visited); + return this.doComputeUpperBound(boundType, visited, options); } } @@ -1140,6 +1148,17 @@ export class SafeDsTypeComputer { } } +/** + * Options for {@link computeLowerBound} and {@link computeUpperBound}. + */ +interface ComputeBoundOptions { + /** + * If `true`, the computation stops at type parameter types and returns them as is. Otherwise, it finds the bounds + * for the type parameter types recursively. + */ + stopAtTypeParameterType?: boolean; +} + interface GroupTypesResult { classTypes: ClassType[]; constants: Constant[]; diff --git a/packages/safe-ds-lang/src/language/validation/other/expressions/chainedExpressions.ts b/packages/safe-ds-lang/src/language/validation/other/expressions/chainedExpressions.ts index b991422b3..e8880338f 100644 --- a/packages/safe-ds-lang/src/language/validation/other/expressions/chainedExpressions.ts +++ b/packages/safe-ds-lang/src/language/validation/other/expressions/chainedExpressions.ts @@ -15,25 +15,21 @@ export const chainedExpressionsMustBeNullSafeIfReceiverIsNullable = (services: S } const receiverType = typeComputer.computeType(node.receiver); - if (receiverType === UnknownType) { + if (receiverType === UnknownType || !typeChecker.canBeNull(receiverType)) { return; } - if (isSdsCall(node) && receiverType.isNullable && typeChecker.canBeCalled(receiverType)) { + if (isSdsCall(node) && typeChecker.canBeCalled(receiverType)) { accept('error', 'The receiver can be null so a null-safe call must be used.', { node, code: CODE_CHAINED_EXPRESSION_MISSING_NULL_SAFETY, }); - } else if ( - isSdsIndexedAccess(node) && - receiverType.isNullable && - typeChecker.canBeAccessedByIndex(receiverType) - ) { + } else if (isSdsIndexedAccess(node) && typeChecker.canBeAccessedByIndex(receiverType)) { accept('error', 'The receiver can be null so a null-safe indexed access must be used.', { node, code: CODE_CHAINED_EXPRESSION_MISSING_NULL_SAFETY, }); - } else if (isSdsMemberAccess(node) && receiverType.isNullable) { + } else if (isSdsMemberAccess(node)) { accept('error', 'The receiver can be null so a null-safe member access must be used.', { node, code: CODE_CHAINED_EXPRESSION_MISSING_NULL_SAFETY, diff --git a/packages/safe-ds-lang/src/language/validation/style.ts b/packages/safe-ds-lang/src/language/validation/style.ts index e4f1f20aa..d0f9eb0f7 100644 --- a/packages/safe-ds-lang/src/language/validation/style.ts +++ b/packages/safe-ds-lang/src/language/validation/style.ts @@ -222,6 +222,7 @@ export const constraintListShouldNotBeEmpty = (services: SafeDsServices) => { export const elvisOperatorShouldBeNeeded = (services: SafeDsServices) => { const partialEvaluator = services.evaluation.PartialEvaluator; const settingsProvider = services.workspace.SettingsProvider; + const typeChecker = services.types.TypeChecker; const typeComputer = services.types.TypeComputer; return async (node: SdsInfixOperation, accept: ValidationAcceptor) => { @@ -235,7 +236,7 @@ export const elvisOperatorShouldBeNeeded = (services: SafeDsServices) => { } const leftType = typeComputer.computeType(node.leftOperand); - if (!leftType.isNullable) { + if (!typeChecker.canBeNull(leftType)) { accept( 'info', 'The left operand is never null, so the elvis operator is unnecessary (keep the left operand).', @@ -322,14 +323,14 @@ export const chainedExpressionNullSafetyShouldBeNeeded = (services: SafeDsServic } const receiverType = typeComputer.computeType(node.receiver); - if (receiverType === UnknownType) { + if (receiverType === UnknownType || typeChecker.canBeNull(receiverType)) { return; } if ( - (isSdsCall(node) && !receiverType.isNullable && typeChecker.canBeCalled(receiverType)) || - (isSdsIndexedAccess(node) && !receiverType.isNullable && typeChecker.canBeAccessedByIndex(receiverType)) || - (isSdsMemberAccess(node) && !receiverType.isNullable) + (isSdsCall(node) && typeChecker.canBeCalled(receiverType)) || + (isSdsIndexedAccess(node) && typeChecker.canBeAccessedByIndex(receiverType)) || + isSdsMemberAccess(node) ) { accept('info', 'The receiver is never null, so null-safety is unnecessary.', { node, diff --git a/packages/safe-ds-lang/tests/resources/validation/other/expressions/chained expression/missing null safety/main.sdstest b/packages/safe-ds-lang/tests/resources/validation/other/expressions/chained expression/missing null safety/main.sdstest index fe504d268..8047a2697 100644 --- a/packages/safe-ds-lang/tests/resources/validation/other/expressions/chained expression/missing null safety/main.sdstest +++ b/packages/safe-ds-lang/tests/resources/validation/other/expressions/chained expression/missing null safety/main.sdstest @@ -93,6 +93,24 @@ segment indexedAccess( »unresolved?[0]«; } +class IndexedAccess( + nullable: Nullable, + nonNullable: NonNullable, + + // $TEST$ error "The receiver can be null so a null-safe indexed access must be used." + p1: Any? = »nullable[0]«, + // $TEST$ no error "The receiver can be null so a null-safe indexed access must be used." + p2: Any? = »nonNullable[0]«, + + // $TEST$ no error "The receiver can be null so a null-safe indexed access must be used." + p3: Any? = »nullable?[0]«, + // $TEST$ no error "The receiver can be null so a null-safe indexed access must be used." + p4: Any? = »nonNullable?[0]«, +) where { + Nullable sub List?, + NonNullable sub List +} + segment memberAccess( myClass: MyClass, myClassOrNull: MyClass?, @@ -128,3 +146,21 @@ segment memberAccess( // $TEST$ no error "The receiver can be null so a null-safe member access must be used." »unresolved?.a«; } + +class MemberAccess( + nullable: Nullable, + nonNullable: NonNullable, + + // $TEST$ error "The receiver can be null so a null-safe member access must be used." + p1: Any? = »nullable.a«, + // $TEST$ no error "The receiver can be null so a null-safe member access must be used." + p2: Any? = »nonNullable.a«, + + // $TEST$ no error "The receiver can be null so a null-safe member access must be used." + p3: Any? = »nullable?.a«, + // $TEST$ no error "The receiver can be null so a null-safe member access must be used." + p4: Any? = »nonNullable?.a«, +) where { + Nullable sub MyClass?, + NonNullable sub MyClass +} diff --git a/packages/safe-ds-lang/tests/resources/validation/style/unnecessary elvis operator/main.sdstest b/packages/safe-ds-lang/tests/resources/validation/style/unnecessary elvis operator/main.sdstest index 3ea67fc90..7129938c7 100644 --- a/packages/safe-ds-lang/tests/resources/validation/style/unnecessary elvis operator/main.sdstest +++ b/packages/safe-ds-lang/tests/resources/validation/style/unnecessary elvis operator/main.sdstest @@ -1,6 +1,6 @@ package validation.style.unnecessaryElvisOperator -fun f() -> result: Any? +@Pure fun f() -> result: Any? pipeline test { @@ -40,3 +40,19 @@ pipeline test { // $TEST$ info "Both operands are always null, so the elvis operator is unnecessary (replace it with null)." »null ?: null«; } + +class TestsForTypeParameters( + nullable: Nullable, + nonNullable: NonNullable, + + // $TEST$ no info "The left operand is never null, so the elvis operator is unnecessary (keep the left operand)." + p1: Any? = »nullable ?: 2«, + // $TEST$ no info "The left operand is never null, so the elvis operator is unnecessary (keep the left operand)." + p2: Any? = »nullable ?: null«, + // $TEST$ info "The left operand is never null, so the elvis operator is unnecessary (keep the left operand)." + p3: Any? = »nonNullable ?: 2«, + // $TEST$ info "The left operand is never null, so the elvis operator is unnecessary (keep the left operand)." + p4: Any? = »nonNullable ?: null«, +) where { + NonNullable sub Any +} diff --git a/packages/safe-ds-lang/tests/resources/validation/style/unnecessary null safety/main.sdstest b/packages/safe-ds-lang/tests/resources/validation/style/unnecessary null safety/main.sdstest index f11834edc..07f9d8b90 100644 --- a/packages/safe-ds-lang/tests/resources/validation/style/unnecessary null safety/main.sdstest +++ b/packages/safe-ds-lang/tests/resources/validation/style/unnecessary null safety/main.sdstest @@ -93,6 +93,24 @@ segment indexedAccess( »unresolved?[0]«; } +class IndexedAccess( + nullable: Nullable, + nonNullable: NonNullable, + + // $TEST$ no info "The receiver is never null, so null-safety is unnecessary." + p1: Any? = »nullable[0]«, + // $TEST$ no info "The receiver is never null, so null-safety is unnecessary." + p2: Any? = »nonNullable[0]«, + + // $TEST$ no info "The receiver is never null, so null-safety is unnecessary." + p3: Any? = »nullable?[0]«, + // $TEST$ info "The receiver is never null, so null-safety is unnecessary." + p4: Any? = »nonNullable?[0]«, +) where { + Nullable sub List?, + NonNullable sub List +} + segment memberAccess( myClass: MyClass, myClassOrNull: MyClass?, @@ -128,3 +146,21 @@ segment memberAccess( // $TEST$ no info "The receiver is never null, so null-safety is unnecessary." »unresolved?.a«; } + +class MemberAccess( + nullable: Nullable, + nonNullable: NonNullable, + + // $TEST$ no info "The receiver is never null, so null-safety is unnecessary." + p1: Any? = »nullable.a«, + // $TEST$ no info "The receiver is never null, so null-safety is unnecessary." + p2: Any? = »nonNullable.a«, + + // $TEST$ no info "The receiver is never null, so null-safety is unnecessary." + p3: Any? = »nullable?.a«, + // $TEST$ info "The receiver is never null, so null-safety is unnecessary." + p4: Any? = »nonNullable?.a«, +) where { + Nullable sub MyClass?, + NonNullable sub MyClass +}