Skip to content

Commit

Permalink
WIP: remove IR from TypeRepresentation, abstracting ResolvedType for …
Browse files Browse the repository at this point in the history
…serialization and bringing it back
  • Loading branch information
radeusgd committed Mar 2, 2024
1 parent 2c92a40 commit e4c1292
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import scala.collection.immutable.Seq$;
import scala.jdk.javaapi.CollectionConverters;
import scala.jdk.javaapi.CollectionConverters$;
import scala.util.Either;
import scala.util.Right;

import java.util.*;

Expand Down Expand Up @@ -70,10 +72,11 @@ public Module runModule(Module ir, ModuleContext moduleContext) {
Option.empty()
);

BindingsMap bindingsMap = getMetadata(ir, BindingAnalysis$.MODULE$, BindingsMap.class);
var mappedBindings = ir.bindings().map((def) -> switch (def) {
case Method.Explicit b -> {
var mapped = def.mapExpressions(
(expression) -> runExpression(expression, ctx)
(expression) -> analyzeExpression(expression, ctx, LocalBindingsTyping.create(), bindingsMap)
);

var inferredType = getInferredType(b.body());
Expand All @@ -95,10 +98,10 @@ public Module runModule(Module ir, ModuleContext moduleContext) {

@Override
public Expression runExpression(Expression ir, InlineContext inlineContext) {
return analyzeExpression(ir, inlineContext, LocalBindingsTyping.create());
return analyzeExpression(ir, inlineContext, LocalBindingsTyping.create(), inlineContext.bindingsAnalysis());
}

private Expression analyzeExpression(Expression ir, InlineContext inlineContext, LocalBindingsTyping localBindingsTyping) {
private Expression analyzeExpression(Expression ir, InlineContext inlineContext, LocalBindingsTyping localBindingsTyping, BindingsMap bindingsMap) {
// We first run the inner expressions, as most basic inference is propagating types in a bottom-up manner.
var mappedIr = switch (ir) {
case Function.Lambda lambda -> {
Expand All @@ -108,27 +111,27 @@ private Expression analyzeExpression(Expression ir, InlineContext inlineContext,
registerBinding(arg, type, localBindingsTyping);
}
}
var newBody = analyzeExpression(lambda.body(), inlineContext, localBindingsTyping);
var newBody = analyzeExpression(lambda.body(), inlineContext, localBindingsTyping, bindingsMap);
yield lambda.copy(lambda.arguments(), newBody, lambda.location(), lambda.canBeTCO(), lambda.passData(), lambda.diagnostics(), lambda.id());
}
case Case.Expr caseExpr -> {
var newScrutinee = analyzeExpression(caseExpr.scrutinee(), inlineContext, localBindingsTyping);
var newScrutinee = analyzeExpression(caseExpr.scrutinee(), inlineContext, localBindingsTyping, bindingsMap);
List<Case.Branch> newBranches = CollectionConverters$.MODULE$.asJava(caseExpr.branches()).stream().map((branch) -> {
// TODO once we will be implementing type equality constraints*, we will need to copy localBindingsTyping here, to ensure independent typing of branches
// (*) (case x of _ : Integer -> e) ==> x : Integer within e
var myBranchLocalBindingsTyping = localBindingsTyping;
registerPattern(branch.pattern(), myBranchLocalBindingsTyping);
var newExpression = analyzeExpression(branch.expression(), inlineContext, myBranchLocalBindingsTyping);
var newExpression = analyzeExpression(branch.expression(), inlineContext, myBranchLocalBindingsTyping, bindingsMap);
return branch.copy(branch.pattern(), newExpression, branch.terminalBranch(), branch.location(), branch.passData(), branch.diagnostics(), branch.id());
}).toList();
yield caseExpr.copy(newScrutinee, CollectionConverters$.MODULE$.asScala(newBranches).toSeq(), caseExpr.isNested(), caseExpr.location(), caseExpr.passData(), caseExpr.diagnostics(), caseExpr.id());
}
default -> ir.mapExpressions(
(expression) -> analyzeExpression(expression, inlineContext, localBindingsTyping)
(expression) -> analyzeExpression(expression, inlineContext, localBindingsTyping, bindingsMap)
);
};

processTypePropagation(mappedIr, localBindingsTyping);
processTypePropagation(mappedIr, bindingsMap, localBindingsTyping);

// The ascriptions are processed later, because we want them to _overwrite_ any type that was inferred.
processTypeAscription(mappedIr);
Expand Down Expand Up @@ -160,7 +163,7 @@ private void processTypeAscription(Expression ir) {
}
}

private void processTypePropagation(Expression ir, LocalBindingsTyping localBindingsTyping) {
private void processTypePropagation(Expression ir, BindingsMap bindingsMap, LocalBindingsTyping localBindingsTyping) {
switch (ir) {
case Name.Literal l -> processName(l, localBindingsTyping);
case Application.Force f -> {
Expand All @@ -172,7 +175,7 @@ private void processTypePropagation(Expression ir, LocalBindingsTyping localBind
case Application.Prefix p -> {
var functionType = getInferredType(p.function());
if (functionType != null) {
var inferredType = processApplication(functionType.type(), p.arguments(), p);
var inferredType = processApplication(functionType.type(), p.arguments(), p, bindingsMap);
if (inferredType != null) {
setInferredType(p, new InferredType(inferredType));
}
Expand Down Expand Up @@ -310,26 +313,26 @@ private void processLiteral(Literal literal) {
}

@SuppressWarnings("unchecked")
private TypeRepresentation processApplication(TypeRepresentation functionType, scala.collection.immutable.List<CallArgument> arguments, Application.Prefix relatedIR) {
private TypeRepresentation processApplication(TypeRepresentation functionType, scala.collection.immutable.List<CallArgument> arguments, Application.Prefix relatedIR, BindingsMap bindingsMap) {
if (arguments.isEmpty()) {
logger.warn("processApplication: {} - unexpected - no arguments in a function application", relatedIR.showCode());
return functionType;
}

var firstArgument = arguments.head();
var firstResult = processSingleApplication(functionType, firstArgument, relatedIR);
var firstResult = processSingleApplication(functionType, firstArgument, relatedIR, bindingsMap);
if (firstResult == null) {
return null;
}

if (arguments.length() == 1) {
return firstResult;
} else {
return processApplication(firstResult, (scala.collection.immutable.List<CallArgument>) arguments.tail(), relatedIR);
return processApplication(firstResult, (scala.collection.immutable.List<CallArgument>) arguments.tail(), relatedIR, bindingsMap);
}
}

private TypeRepresentation processSingleApplication(TypeRepresentation functionType, CallArgument argument, Application.Prefix relatedIR) {
private TypeRepresentation processSingleApplication(TypeRepresentation functionType, CallArgument argument, Application.Prefix relatedIR, BindingsMap bindingsMap) {
if (argument.name().isDefined()) {
// TODO named arguments are not yet supported
return null;
Expand All @@ -345,7 +348,7 @@ private TypeRepresentation processSingleApplication(TypeRepresentation functionT
}

case TypeRepresentation.UnresolvedSymbol unresolvedSymbol -> {
return processUnresolvedSymbolApplication(unresolvedSymbol, argument.value());
return processUnresolvedSymbolApplication(bindingsMap, unresolvedSymbol, argument.value());
}

case TypeRepresentation.TopType() -> {
Expand All @@ -360,33 +363,28 @@ private TypeRepresentation processSingleApplication(TypeRepresentation functionT
return null;
}

private BindingsMap getBindingsMap(IR ir) {
return getMetadata(ir, BindingAnalysis$.MODULE$, BindingsMap.class);
}

private BindingsMap.Type findTypeDescription(QualifiedName typeName, IR someIr) {
var bindingsMap = getBindingsMap(someIr);
var result = bindingsMap.resolveQualifiedName(typeName.path());
if (result.isLeft()) {
throw new IllegalStateException("Internal error: type signature contained _already resolved_ reference to type " + typeName + ", but now that type cannot be found in the bindings map.");
}
private BindingsMap.Type findResolvedType(BindingsMap bindingsMap, QualifiedName typeName) {
var resolved = switch (bindingsMap.resolveQualifiedName(typeName.fullPath())) {
case Right<BindingsMap.ResolutionError, BindingsMap.ResolvedName> right ->
right.value();
default -> throw new IllegalStateException("Internal error: type signature contained _already resolved_ reference to type " + typeName + ", but now that type cannot be found in the bindings map.");
};

var resolved = result.right().get();
return switch (resolved) {
case BindingsMap.ResolvedType resolvedType -> resolvedType.tp();
default -> throw new IllegalStateException("Internal error: type signature contained _already resolved_ reference to type " + typeName + ", but now that type is not a type, but " + resolved + ".");
};
}

private TypeRepresentation processUnresolvedSymbolApplication(TypeRepresentation.UnresolvedSymbol function, Expression argument) {
private TypeRepresentation processUnresolvedSymbolApplication(BindingsMap bindingsMap, TypeRepresentation.UnresolvedSymbol function, Expression argument) {
var argumentType = getInferredType(argument);
if (argumentType == null) {
return null;
}

switch (argumentType.type()) {
case TypeRepresentation.TypeObject typeObject -> {
var typeDescription = findTypeDescription(typeObject.name(), argument);
var typeDescription = findResolvedType(bindingsMap, typeObject.name());
Option<BindingsMap.Cons> ctorCandidate = typeDescription.members().find((ctor) -> ctor.name().equals(function.name()));
if (ctorCandidate.isDefined()) {
return buildAtomConstructorType(typeObject, ctorCandidate.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.enso.compiler.context

import org.enso.compiler.PackageRepository
import org.enso.compiler.context.LocalScope
import org.enso.compiler.data.CompilerConfig
import org.enso.compiler.data.{BindingsMap, CompilerConfig}
import org.enso.compiler.pass.PassConfiguration

/** A type containing the information about the execution context for an inline
Expand All @@ -26,7 +26,7 @@ case class InlineContext(
passConfiguration: Option[PassConfiguration] = None,
pkgRepo: Option[PackageRepository] = None
) {
def bindingsAnalysis() = module.bindingsAnalysis()
def bindingsAnalysis(): BindingsMap = module.bindingsAnalysis()
}
object InlineContext {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package org.enso.compiler.context

import org.enso.compiler.PackageRepository
import org.enso.compiler.data.CompilerConfig
import org.enso.compiler.data.{BindingsMap, CompilerConfig}
import org.enso.compiler.pass.PassConfiguration
import org.enso.pkg.Package;
import org.enso.pkg.QualifiedName;
import org.enso.pkg.Package
import org.enso.pkg.QualifiedName
import com.oracle.truffle.api.source.Source
import org.enso.compiler.data.BindingsMap.ModuleReference

Expand All @@ -25,8 +25,8 @@ case class ModuleContext(
isGeneratingDocs: Boolean = false,
pkgRepo: Option[PackageRepository] = None
) {
def isSynthetic() = module.isSynthetic()
def bindingsAnalysis() = module.getBindingsMap()
def isSynthetic(): Boolean = module.isSynthetic()
def bindingsAnalysis(): BindingsMap = module.getBindingsMap()
def getName(): QualifiedName = module.getName()
def getPackage(): Package[_] = module.getPackage()
def getSource(): Source = module.getSource()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.enso.compiler.core.ir.expression.errors
import org.enso.compiler.data.BindingsMap.{DefinedEntity, ModuleReference}
import org.enso.compiler.core.CompilerError
import org.enso.compiler.core.ir.Expression
import org.enso.compiler.core.ir.module.scope.Definition
import org.enso.compiler.pass.IRPass
import org.enso.compiler.pass.analyse.BindingAnalysis
import org.enso.compiler.pass.resolve.MethodDefinitions
Expand Down Expand Up @@ -739,13 +740,40 @@ object BindingsMap {
* @param members the member names
* @param builtinType true if constructor is annotated with @Builtin_Type, false otherwise.
*/
case class Type(// TODO if type can hold Cons->Argument->Expression, we need to make it have toAbstract
case class Type(
override val name: String,
params: Seq[String],
members: Seq[Cons],
builtinType: Boolean
) extends DefinedEntity {
override def canExport: Boolean = true

def toAbstract: Type =
this.copy(members = Seq())

def toConcrete(concreteModule: ModuleReference.Concrete): Option[Type] = {
val ir = concreteModule.unsafeAsModule().getIr
ir.bindings.collectFirst {
case typeIr: Definition.Type if typeIr.name.name == name => typeIr
}.map { typeIr => Type.fromIr(typeIr, builtinType) }
}
}

object Type {
def fromIr(ir: Definition.Type, isBuiltinType: Boolean): Type =
BindingsMap.Type(
ir.name.name,
ir.params.map(_.name.name),
ir.members.map(m =>
Cons(
m.name.name,
m.arguments.map(arg =>
BindingsMap.Argument(arg.name.name, arg.defaultValue.isDefined, arg.ascribedType)
)
)
),
isBuiltinType
)
}

/** A representation of an imported polyglot symbol.
Expand Down Expand Up @@ -777,16 +805,16 @@ object BindingsMap {
}

/** @inheritdoc */
override def toAbstract: ResolvedType = {
this.copy(module = module.toAbstract)
}
override def toAbstract: ResolvedType =
this.copy(module = module.toAbstract, tp = tp.toAbstract)

/** @inheritdoc */
override def toConcrete(
moduleMap: ModuleMap
): Option[ResolvedType] = {
module.toConcrete(moduleMap).map(module => this.copy(module = module))
}
): Option[ResolvedType] = for {
concreteModule <- module.toConcrete(moduleMap)
concreteTp <- tp.toConcrete(concreteModule)
} yield ResolvedType(concreteModule, concreteTp)

override def qualifiedName: QualifiedName =
module.getName.createChild(tp.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import org.enso.compiler.core.ir.module.scope.definition
import org.enso.compiler.core.ir.module.scope.imports
import org.enso.compiler.core.ir.MetadataStorage.MetadataPair
import org.enso.compiler.data.BindingsMap
import org.enso.compiler.data.BindingsMap.Cons
import org.enso.compiler.pass.IRPass
import org.enso.compiler.pass.desugar.{
ComplexType,
Expand Down Expand Up @@ -56,19 +55,7 @@ case object BindingAnalysis extends IRPass {
val isBuiltinType = sumType
.getMetadata(ModuleAnnotations)
.exists(_.annotations.exists(_.name == "@Builtin_Type"))
BindingsMap.Type(
sumType.name.name,
sumType.params.map(_.name.name),
sumType.members.map(m =>
Cons(
m.name.name,
m.arguments.map(arg =>
BindingsMap.Argument(arg.name.name, arg.defaultValue.isDefined, arg.ascribedType)
)
)
),
isBuiltinType
)
BindingsMap.Type.fromIr(sumType, isBuiltinType)
}
val importedPolyglot = ir.imports.collect { case poly: imports.Polyglot =>
BindingsMap.PolyglotSymbol(poly.getVisibleName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import scala.jdk.javaapi.CollectionConverters;

public class TypeInferenceTest extends CompilerTest {
@Ignore
@Test
public void zeroAryCheck() throws Exception {
final URI uri = new URI("memory://zeroAryCheck.enso");
Expand Down
2 changes: 2 additions & 0 deletions lib/scala/pkg/src/main/scala/org/enso/pkg/QualifiedName.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ case class QualifiedName(path: List[String], item: String) {
def pathAsJava(): java.util.List[String] = {
path.asJava
}

def fullPath(): List[String] = path :+ item
}

object QualifiedName {
Expand Down

0 comments on commit e4c1292

Please sign in to comment.