Skip to content

Commit

Permalink
fix: simplification of union types (#897)
Browse files Browse the repository at this point in the history
### Summary of Changes

Consider the following program:

```
class C<T> {
    attr a: union<Int, T>
}

segment s(p: C<String>) {
    val v = p.a;
}
```

Previously, the type of the placeholder `v` was inferred to be `String`,
because we incorrectly simplified union types that contained entries
that were not fully substituted. Here, we replaced the entry `Int` with
`T`, since `Int` is a subtype of `T`. This can change, however, once we
substitute `T`, as shown here.

Now, we never consider types that are not fully substituted for
replacement and compute the correct type `union<Int, String>`.
  • Loading branch information
lars-reimann authored Feb 19, 2024
1 parent b81bef9 commit 4c577a3
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 250 deletions.
199 changes: 88 additions & 111 deletions packages/safe-ds-lang/src/language/typing/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ export abstract class Type {
*/
abstract isExplicitlyNullable: boolean;

/**
* Whether the type does not contain type parameter types anymore.
*/
abstract isFullySubstituted: boolean;

/**
* Returns whether the type is equal to another type.
*/
Expand Down Expand Up @@ -60,6 +65,7 @@ export abstract class Type {

export class CallableType extends Type {
private readonly factory: SafeDsTypeFactory;

override isExplicitlyNullable: boolean = false;

constructor(
Expand All @@ -74,6 +80,10 @@ export class CallableType extends Type {
this.factory = services.types.TypeFactory;
}

override get isFullySubstituted(): boolean {
return this.inputType.isFullySubstituted && this.outputType.isFullySubstituted;
}

/**
* Returns the type of the parameter at the given index. If the index is out of bounds, returns `undefined`.
*/
Expand Down Expand Up @@ -114,7 +124,7 @@ export class CallableType extends Type {
}

override substituteTypeParameters(substitutions: TypeParameterSubstitutions): CallableType {
if (isEmpty(substitutions)) {
if (isEmpty(substitutions) || this.isFullySubstituted) {
return this;
}

Expand All @@ -135,102 +145,21 @@ export class CallableType extends Type {
}
}

export class IntersectionType extends Type {
private readonly coreTypes: SafeDsCoreTypes;
private readonly factory: SafeDsTypeFactory;

readonly types: Type[];
private _isExplicitlyNullable: boolean | undefined;

constructor(services: SafeDsServices, types: Type[]) {
super();

this.coreTypes = services.types.CoreTypes;
this.factory = services.types.TypeFactory;

this.types = types;
}

override get isExplicitlyNullable(): boolean {
if (this._isExplicitlyNullable === undefined) {
this._isExplicitlyNullable = this.types.every((it) => it.isExplicitlyNullable);
}

return this._isExplicitlyNullable;
}

override equals(other: unknown): boolean {
if (other === this) {
return true;
} else if (!(other instanceof IntersectionType)) {
return false;
}

return this.types.length === other.types.length && this.types.every((type, i) => type.equals(other.types[i]));
}

override toString(): string {
return `$intersection<${this.types.join(', ')}>`;
}

override simplify(): Type {
// Flatten nested intersections
const newTypes = this.types.flatMap((type) => {
const unwrappedType = type.simplify();
if (unwrappedType instanceof IntersectionType) {
return unwrappedType.types;
} else {
return unwrappedType;
}
});

// Remove the outer intersection if there's only one type left
if (newTypes.length === 1) {
return newTypes[0]!;
}

return this.factory.createIntersectionType(...newTypes);
}

override substituteTypeParameters(substitutions: TypeParameterSubstitutions): IntersectionType {
if (isEmpty(substitutions)) {
return this;
}

return this.factory.createIntersectionType(
...this.types.map((it) => it.substituteTypeParameters(substitutions)),
);
}

override withExplicitNullability(isExplicitlyNullable: boolean): Type {
if (isEmpty(this.types)) {
return this.coreTypes.Any.withExplicitNullability(isExplicitlyNullable);
}

if (this.isExplicitlyNullable && !isExplicitlyNullable) {
return this.factory.createIntersectionType(...this.types.map((it) => it.withExplicitNullability(false)));
} else if (!this.isExplicitlyNullable && isExplicitlyNullable) {
return this.factory.createIntersectionType(...this.types.map((it) => it.withExplicitNullability(true)));
} else {
return this;
}
}
}

export class LiteralType extends Type {
private readonly coreTypes: SafeDsCoreTypes;
private readonly factory: SafeDsTypeFactory;

readonly constants: Constant[];
private _isExplicitlyNullable: boolean | undefined;
override readonly isFullySubstituted = true;

constructor(services: SafeDsServices, constants: Constant[]) {
constructor(
services: SafeDsServices,
readonly constants: Constant[],
) {
super();

this.coreTypes = services.types.CoreTypes;
this.factory = services.types.TypeFactory;

this.constants = constants;
}

override get isExplicitlyNullable(): boolean {
Expand Down Expand Up @@ -304,8 +233,10 @@ export class LiteralType extends Type {

export class NamedTupleType<T extends SdsDeclaration> extends Type {
private readonly factory: SafeDsTypeFactory;

readonly entries: NamedTupleEntry<T>[];
override readonly isExplicitlyNullable = false;
private _isFullySubstituted: boolean | undefined;

constructor(services: SafeDsServices, entries: NamedTupleEntry<T>[]) {
super();
Expand All @@ -314,6 +245,14 @@ export class NamedTupleType<T extends SdsDeclaration> extends Type {
this.entries = entries;
}

override get isFullySubstituted(): boolean {
if (this._isFullySubstituted === undefined) {
this._isFullySubstituted = this.entries.every((it) => it.type.isFullySubstituted);
}

return this._isFullySubstituted;
}

/**
* The length of this tuple.
*/
Expand Down Expand Up @@ -357,7 +296,7 @@ export class NamedTupleType<T extends SdsDeclaration> extends Type {
}

override substituteTypeParameters(substitutions: TypeParameterSubstitutions): NamedTupleType<T> {
if (isEmpty(substitutions)) {
if (isEmpty(substitutions) || this.isFullySubstituted) {
return this;
}

Expand Down Expand Up @@ -399,7 +338,7 @@ export class NamedTupleEntry<T extends SdsDeclaration> {
}

substituteTypeParameters(substitutions: TypeParameterSubstitutions): NamedTupleEntry<T> {
if (isEmpty(substitutions)) {
if (isEmpty(substitutions) || this.type.isFullySubstituted) {
/* c8 ignore next 2 */
return this;
}
Expand Down Expand Up @@ -433,6 +372,8 @@ export abstract class NamedType<T extends SdsDeclaration> extends Type {
}

export class ClassType extends NamedType<SdsClass> {
private _isFullySubstituted: boolean | undefined;

constructor(
declaration: SdsClass,
readonly substitutions: TypeParameterSubstitutions,
Expand All @@ -441,6 +382,14 @@ export class ClassType extends NamedType<SdsClass> {
super(declaration);
}

override get isFullySubstituted(): boolean {
if (this._isFullySubstituted === undefined) {
this._isFullySubstituted = stream(this.substitutions.values()).every((it) => it.isFullySubstituted);
}

return this._isFullySubstituted;
}

getTypeParameterTypeByIndex(index: number): Type {
const typeParameter = getTypeParameters(this.declaration)[index];
if (!typeParameter) {
Expand Down Expand Up @@ -486,7 +435,7 @@ export class ClassType extends NamedType<SdsClass> {
}

override substituteTypeParameters(substitutions: TypeParameterSubstitutions): ClassType {
if (isEmpty(substitutions)) {
if (isEmpty(substitutions) || this.isFullySubstituted) {
return this;
}

Expand All @@ -507,6 +456,8 @@ export class ClassType extends NamedType<SdsClass> {
}

export class EnumType extends NamedType<SdsEnum> {
override readonly isFullySubstituted = true;

constructor(
declaration: SdsEnum,
override readonly isExplicitlyNullable: boolean,
Expand Down Expand Up @@ -538,6 +489,8 @@ export class EnumType extends NamedType<SdsEnum> {
}

export class EnumVariantType extends NamedType<SdsEnumVariant> {
override readonly isFullySubstituted = true;

constructor(
declaration: SdsEnumVariant,
override readonly isExplicitlyNullable: boolean,
Expand Down Expand Up @@ -569,6 +522,8 @@ export class EnumVariantType extends NamedType<SdsEnumVariant> {
}

export class TypeParameterType extends NamedType<SdsTypeParameter> {
override readonly isFullySubstituted = false;

constructor(
declaration: SdsTypeParameter,
override readonly isExplicitlyNullable: boolean,
Expand Down Expand Up @@ -608,12 +563,13 @@ export class TypeParameterType extends NamedType<SdsTypeParameter> {
}

/**
* A type that represents an actual class, enum, or enum variant instead of an instance of it.
* A type that represents an actual named type declaration instead of an instance of it.
*/
export class StaticType extends Type {
private readonly factory: SafeDsTypeFactory;

override readonly isExplicitlyNullable = false;
override readonly isFullySubstituted = true;

constructor(
services: SafeDsServices,
Expand Down Expand Up @@ -664,6 +620,7 @@ export class UnionType extends Type {

readonly types: Type[];
private _isExplicitlyNullable: boolean | undefined;
private _isFullySubstituted: boolean | undefined;

constructor(services: SafeDsServices, types: Type[]) {
super();
Expand All @@ -683,6 +640,14 @@ export class UnionType extends Type {
return this._isExplicitlyNullable;
}

override get isFullySubstituted(): boolean {
if (this._isFullySubstituted === undefined) {
this._isFullySubstituted = this.types.every((it) => it.isFullySubstituted);
}

return this._isFullySubstituted;
}

override equals(other: unknown): boolean {
if (other === this) {
return true;
Expand Down Expand Up @@ -717,6 +682,8 @@ export class UnionType extends Type {
// occurrence of duplicate types. It's also makes splicing easier.
for (let i = newTypes.length - 1; i >= 0; i--) {
const currentType = newTypes[i]!;
const currentTypeIsNothing = currentType.equals(this.coreTypes.Nothing);
const currentTypeIsNothingOrNull = currentType.equals(this.coreTypes.NothingOrNull);

for (let j = newTypes.length - 1; j >= 0; j--) {
if (i === j) {
Expand All @@ -725,10 +692,28 @@ export class UnionType extends Type {

const otherType = newTypes[j]!;

// Remove identical types
if (currentType.equals(otherType)) {
// Remove the current type
newTypes.splice(i, 1);
break;
}

// We can always attempt to replace `Nothing` or `Nothing?` with other types, since they are the bottom
// types. But otherwise, we cannot use a type that is not fully substituted as a replacement. After
// substitution, we might lose information about the original type:
//
// Consider the type `union<C, T>`, where `C` is a class and `T` is a type parameter without an upper
// bound. While `C` is a subtype of `T`, we cannot replace the union type with `T`, since we might later
// substitute `T` with a type that is not a supertype of `C`.
if (!currentTypeIsNothing && !currentTypeIsNothingOrNull && !otherType.isFullySubstituted) {
continue;
}

// Don't merge `Nothing?` into callable types, named tuple types or static types, since that would
// create another union type.
if (
currentType.equals(this.coreTypes.NothingOrNull) &&
currentTypeIsNothingOrNull &&
(otherType instanceof CallableType ||
otherType instanceof NamedTupleType ||
otherType instanceof StaticType)
Expand All @@ -741,31 +726,22 @@ export class UnionType extends Type {
// Other type always occurs before current type
const newConstants = [...otherType.constants, ...currentType.constants];
const newLiteralType = this.factory.createLiteralType(...newConstants).simplify();

// Replace the other type with the new literal type
newTypes.splice(j, 1, newLiteralType);
// Remove the current type
newTypes.splice(i, 1);
break;
}

// Remove subtypes of other types. Type parameter types are a special case, since there might be a
// subtype relation between `currentType` and `otherType` in both directions. We always keep the type
// parameter type in this case for consistency, since it can be narrower if it has a lower bound.
if (currentType instanceof TypeParameterType) {
const candidateType = currentType.withExplicitNullability(
currentType.isExplicitlyNullable || otherType.isExplicitlyNullable,
);

if (this.typeChecker.isSubtypeOf(otherType, candidateType)) {
newTypes.splice(j, 1, candidateType);
newTypes.splice(i, 1);
break;
}
}

// Remove subtypes of other types
const candidateType = otherType.withExplicitNullability(
currentType.isExplicitlyNullable || otherType.isExplicitlyNullable,
);
if (this.typeChecker.isSubtypeOf(currentType, candidateType)) {
newTypes.splice(j, 1, candidateType); // Update nullability
if (this.typeChecker.isSupertypeOf(candidateType, currentType)) {
// Replace the other type with the candidate type (updated nullability)
newTypes.splice(j, 1, candidateType);
// Remove the current type
newTypes.splice(i, 1);
break;
}
Expand All @@ -780,7 +756,7 @@ export class UnionType extends Type {
}

override substituteTypeParameters(substitutions: TypeParameterSubstitutions): UnionType {
if (isEmpty(substitutions)) {
if (isEmpty(substitutions) || this.isFullySubstituted) {
return this;
}

Expand All @@ -803,7 +779,8 @@ export class UnionType extends Type {
}

class UnknownTypeClass extends Type {
readonly isExplicitlyNullable = false;
override readonly isExplicitlyNullable = false;
override readonly isFullySubstituted = true;

override equals(other: unknown): boolean {
return other instanceof UnknownTypeClass;
Expand Down
Loading

0 comments on commit 4c577a3

Please sign in to comment.