diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java index 211adc0b89d9a..42f735106c0b8 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java @@ -24,6 +24,8 @@ import scala.collection.immutable.Seq$; import scala.jdk.javaapi.CollectionConverters; import scala.jdk.javaapi.CollectionConverters$; +import scala.util.Either; +import scala.util.Right; import java.util.*; @@ -70,10 +72,11 @@ public Module runModule(Module ir, ModuleContext moduleContext) { Option.empty() ); + BindingsMap bindingsMap = getMetadata(ir, BindingAnalysis$.MODULE$, BindingsMap.class); var mappedBindings = ir.bindings().map((def) -> switch (def) { case Method.Explicit b -> { var mapped = def.mapExpressions( - (expression) -> runExpression(expression, ctx) + (expression) -> analyzeExpression(expression, ctx, LocalBindingsTyping.create(), bindingsMap) ); var inferredType = getInferredType(b.body()); @@ -95,10 +98,10 @@ public Module runModule(Module ir, ModuleContext moduleContext) { @Override public Expression runExpression(Expression ir, InlineContext inlineContext) { - return analyzeExpression(ir, inlineContext, LocalBindingsTyping.create()); + return analyzeExpression(ir, inlineContext, LocalBindingsTyping.create(), inlineContext.bindingsAnalysis()); } - private Expression analyzeExpression(Expression ir, InlineContext inlineContext, LocalBindingsTyping localBindingsTyping) { + private Expression analyzeExpression(Expression ir, InlineContext inlineContext, LocalBindingsTyping localBindingsTyping, BindingsMap bindingsMap) { // We first run the inner expressions, as most basic inference is propagating types in a bottom-up manner. var mappedIr = switch (ir) { case Function.Lambda lambda -> { @@ -108,27 +111,27 @@ private Expression analyzeExpression(Expression ir, InlineContext inlineContext, registerBinding(arg, type, localBindingsTyping); } } - var newBody = analyzeExpression(lambda.body(), inlineContext, localBindingsTyping); + var newBody = analyzeExpression(lambda.body(), inlineContext, localBindingsTyping, bindingsMap); yield lambda.copy(lambda.arguments(), newBody, lambda.location(), lambda.canBeTCO(), lambda.passData(), lambda.diagnostics(), lambda.id()); } case Case.Expr caseExpr -> { - var newScrutinee = analyzeExpression(caseExpr.scrutinee(), inlineContext, localBindingsTyping); + var newScrutinee = analyzeExpression(caseExpr.scrutinee(), inlineContext, localBindingsTyping, bindingsMap); List newBranches = CollectionConverters$.MODULE$.asJava(caseExpr.branches()).stream().map((branch) -> { // TODO once we will be implementing type equality constraints*, we will need to copy localBindingsTyping here, to ensure independent typing of branches // (*) (case x of _ : Integer -> e) ==> x : Integer within e var myBranchLocalBindingsTyping = localBindingsTyping; registerPattern(branch.pattern(), myBranchLocalBindingsTyping); - var newExpression = analyzeExpression(branch.expression(), inlineContext, myBranchLocalBindingsTyping); + var newExpression = analyzeExpression(branch.expression(), inlineContext, myBranchLocalBindingsTyping, bindingsMap); return branch.copy(branch.pattern(), newExpression, branch.terminalBranch(), branch.location(), branch.passData(), branch.diagnostics(), branch.id()); }).toList(); yield caseExpr.copy(newScrutinee, CollectionConverters$.MODULE$.asScala(newBranches).toSeq(), caseExpr.isNested(), caseExpr.location(), caseExpr.passData(), caseExpr.diagnostics(), caseExpr.id()); } default -> ir.mapExpressions( - (expression) -> analyzeExpression(expression, inlineContext, localBindingsTyping) + (expression) -> analyzeExpression(expression, inlineContext, localBindingsTyping, bindingsMap) ); }; - processTypePropagation(mappedIr, localBindingsTyping); + processTypePropagation(mappedIr, bindingsMap, localBindingsTyping); // The ascriptions are processed later, because we want them to _overwrite_ any type that was inferred. processTypeAscription(mappedIr); @@ -160,7 +163,7 @@ private void processTypeAscription(Expression ir) { } } - private void processTypePropagation(Expression ir, LocalBindingsTyping localBindingsTyping) { + private void processTypePropagation(Expression ir, BindingsMap bindingsMap, LocalBindingsTyping localBindingsTyping) { switch (ir) { case Name.Literal l -> processName(l, localBindingsTyping); case Application.Force f -> { @@ -172,7 +175,7 @@ private void processTypePropagation(Expression ir, LocalBindingsTyping localBind case Application.Prefix p -> { var functionType = getInferredType(p.function()); if (functionType != null) { - var inferredType = processApplication(functionType.type(), p.arguments(), p); + var inferredType = processApplication(functionType.type(), p.arguments(), p, bindingsMap); if (inferredType != null) { setInferredType(p, new InferredType(inferredType)); } @@ -310,14 +313,14 @@ private void processLiteral(Literal literal) { } @SuppressWarnings("unchecked") - private TypeRepresentation processApplication(TypeRepresentation functionType, scala.collection.immutable.List arguments, Application.Prefix relatedIR) { + private TypeRepresentation processApplication(TypeRepresentation functionType, scala.collection.immutable.List arguments, Application.Prefix relatedIR, BindingsMap bindingsMap) { if (arguments.isEmpty()) { logger.warn("processApplication: {} - unexpected - no arguments in a function application", relatedIR.showCode()); return functionType; } var firstArgument = arguments.head(); - var firstResult = processSingleApplication(functionType, firstArgument, relatedIR); + var firstResult = processSingleApplication(functionType, firstArgument, relatedIR, bindingsMap); if (firstResult == null) { return null; } @@ -325,11 +328,11 @@ private TypeRepresentation processApplication(TypeRepresentation functionType, s if (arguments.length() == 1) { return firstResult; } else { - return processApplication(firstResult, (scala.collection.immutable.List) arguments.tail(), relatedIR); + return processApplication(firstResult, (scala.collection.immutable.List) arguments.tail(), relatedIR, bindingsMap); } } - private TypeRepresentation processSingleApplication(TypeRepresentation functionType, CallArgument argument, Application.Prefix relatedIR) { + private TypeRepresentation processSingleApplication(TypeRepresentation functionType, CallArgument argument, Application.Prefix relatedIR, BindingsMap bindingsMap) { if (argument.name().isDefined()) { // TODO named arguments are not yet supported return null; @@ -345,7 +348,7 @@ private TypeRepresentation processSingleApplication(TypeRepresentation functionT } case TypeRepresentation.UnresolvedSymbol unresolvedSymbol -> { - return processUnresolvedSymbolApplication(unresolvedSymbol, argument.value()); + return processUnresolvedSymbolApplication(bindingsMap, unresolvedSymbol, argument.value()); } case TypeRepresentation.TopType() -> { @@ -360,25 +363,20 @@ private TypeRepresentation processSingleApplication(TypeRepresentation functionT return null; } - private BindingsMap getBindingsMap(IR ir) { - return getMetadata(ir, BindingAnalysis$.MODULE$, BindingsMap.class); - } - - private BindingsMap.Type findTypeDescription(QualifiedName typeName, IR someIr) { - var bindingsMap = getBindingsMap(someIr); - var result = bindingsMap.resolveQualifiedName(typeName.path()); - if (result.isLeft()) { - throw new IllegalStateException("Internal error: type signature contained _already resolved_ reference to type " + typeName + ", but now that type cannot be found in the bindings map."); - } + private BindingsMap.Type findResolvedType(BindingsMap bindingsMap, QualifiedName typeName) { + var resolved = switch (bindingsMap.resolveQualifiedName(typeName.fullPath())) { + case Right right -> + right.value(); + default -> throw new IllegalStateException("Internal error: type signature contained _already resolved_ reference to type " + typeName + ", but now that type cannot be found in the bindings map."); + }; - var resolved = result.right().get(); return switch (resolved) { case BindingsMap.ResolvedType resolvedType -> resolvedType.tp(); default -> throw new IllegalStateException("Internal error: type signature contained _already resolved_ reference to type " + typeName + ", but now that type is not a type, but " + resolved + "."); }; } - private TypeRepresentation processUnresolvedSymbolApplication(TypeRepresentation.UnresolvedSymbol function, Expression argument) { + private TypeRepresentation processUnresolvedSymbolApplication(BindingsMap bindingsMap, TypeRepresentation.UnresolvedSymbol function, Expression argument) { var argumentType = getInferredType(argument); if (argumentType == null) { return null; @@ -386,7 +384,7 @@ private TypeRepresentation processUnresolvedSymbolApplication(TypeRepresentation switch (argumentType.type()) { case TypeRepresentation.TypeObject typeObject -> { - var typeDescription = findTypeDescription(typeObject.name(), argument); + var typeDescription = findResolvedType(bindingsMap, typeObject.name()); Option ctorCandidate = typeDescription.members().find((ctor) -> ctor.name().equals(function.name())); if (ctorCandidate.isDefined()) { return buildAtomConstructorType(typeObject, ctorCandidate.get()); diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/InlineContext.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/InlineContext.scala index b09eb92b7e63d..de2b9c0004ba4 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/InlineContext.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/InlineContext.scala @@ -2,7 +2,7 @@ package org.enso.compiler.context import org.enso.compiler.PackageRepository import org.enso.compiler.context.LocalScope -import org.enso.compiler.data.CompilerConfig +import org.enso.compiler.data.{BindingsMap, CompilerConfig} import org.enso.compiler.pass.PassConfiguration /** A type containing the information about the execution context for an inline @@ -26,7 +26,7 @@ case class InlineContext( passConfiguration: Option[PassConfiguration] = None, pkgRepo: Option[PackageRepository] = None ) { - def bindingsAnalysis() = module.bindingsAnalysis() + def bindingsAnalysis(): BindingsMap = module.bindingsAnalysis() } object InlineContext { diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/ModuleContext.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/ModuleContext.scala index e12753f9d0797..0e21f0a1116ad 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/ModuleContext.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/context/ModuleContext.scala @@ -1,10 +1,10 @@ package org.enso.compiler.context import org.enso.compiler.PackageRepository -import org.enso.compiler.data.CompilerConfig +import org.enso.compiler.data.{BindingsMap, CompilerConfig} import org.enso.compiler.pass.PassConfiguration -import org.enso.pkg.Package; -import org.enso.pkg.QualifiedName; +import org.enso.pkg.Package +import org.enso.pkg.QualifiedName import com.oracle.truffle.api.source.Source import org.enso.compiler.data.BindingsMap.ModuleReference @@ -25,8 +25,8 @@ case class ModuleContext( isGeneratingDocs: Boolean = false, pkgRepo: Option[PackageRepository] = None ) { - def isSynthetic() = module.isSynthetic() - def bindingsAnalysis() = module.getBindingsMap() + def isSynthetic(): Boolean = module.isSynthetic() + def bindingsAnalysis(): BindingsMap = module.getBindingsMap() def getName(): QualifiedName = module.getName() def getPackage(): Package[_] = module.getPackage() def getSource(): Source = module.getSource() diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/data/BindingsMap.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/data/BindingsMap.scala index 85526ee47a743..e88b3a2be73ad 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/data/BindingsMap.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/data/BindingsMap.scala @@ -9,6 +9,7 @@ import org.enso.compiler.core.ir.expression.errors import org.enso.compiler.data.BindingsMap.{DefinedEntity, ModuleReference} import org.enso.compiler.core.CompilerError import org.enso.compiler.core.ir.Expression +import org.enso.compiler.core.ir.module.scope.Definition import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.analyse.BindingAnalysis import org.enso.compiler.pass.resolve.MethodDefinitions @@ -739,13 +740,40 @@ object BindingsMap { * @param members the member names * @param builtinType true if constructor is annotated with @Builtin_Type, false otherwise. */ - case class Type(// TODO if type can hold Cons->Argument->Expression, we need to make it have toAbstract + case class Type( override val name: String, params: Seq[String], members: Seq[Cons], builtinType: Boolean ) extends DefinedEntity { override def canExport: Boolean = true + + def toAbstract: Type = + this.copy(members = Seq()) + + def toConcrete(concreteModule: ModuleReference.Concrete): Option[Type] = { + val ir = concreteModule.unsafeAsModule().getIr + ir.bindings.collectFirst { + case typeIr: Definition.Type if typeIr.name.name == name => typeIr + }.map { typeIr => Type.fromIr(typeIr, builtinType) } + } + } + + object Type { + def fromIr(ir: Definition.Type, isBuiltinType: Boolean): Type = + BindingsMap.Type( + ir.name.name, + ir.params.map(_.name.name), + ir.members.map(m => + Cons( + m.name.name, + m.arguments.map(arg => + BindingsMap.Argument(arg.name.name, arg.defaultValue.isDefined, arg.ascribedType) + ) + ) + ), + isBuiltinType + ) } /** A representation of an imported polyglot symbol. @@ -777,16 +805,16 @@ object BindingsMap { } /** @inheritdoc */ - override def toAbstract: ResolvedType = { - this.copy(module = module.toAbstract) - } + override def toAbstract: ResolvedType = + this.copy(module = module.toAbstract, tp = tp.toAbstract) /** @inheritdoc */ override def toConcrete( moduleMap: ModuleMap - ): Option[ResolvedType] = { - module.toConcrete(moduleMap).map(module => this.copy(module = module)) - } + ): Option[ResolvedType] = for { + concreteModule <- module.toConcrete(moduleMap) + concreteTp <- tp.toConcrete(concreteModule) + } yield ResolvedType(concreteModule, concreteTp) override def qualifiedName: QualifiedName = module.getName.createChild(tp.name) diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/BindingAnalysis.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/BindingAnalysis.scala index acbdf6871b688..8f8899f4cb475 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/BindingAnalysis.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/BindingAnalysis.scala @@ -8,7 +8,6 @@ import org.enso.compiler.core.ir.module.scope.definition import org.enso.compiler.core.ir.module.scope.imports import org.enso.compiler.core.ir.MetadataStorage.MetadataPair import org.enso.compiler.data.BindingsMap -import org.enso.compiler.data.BindingsMap.Cons import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.desugar.{ ComplexType, @@ -56,19 +55,7 @@ case object BindingAnalysis extends IRPass { val isBuiltinType = sumType .getMetadata(ModuleAnnotations) .exists(_.annotations.exists(_.name == "@Builtin_Type")) - BindingsMap.Type( - sumType.name.name, - sumType.params.map(_.name.name), - sumType.members.map(m => - Cons( - m.name.name, - m.arguments.map(arg => - BindingsMap.Argument(arg.name.name, arg.defaultValue.isDefined, arg.ascribedType) - ) - ) - ), - isBuiltinType - ) + BindingsMap.Type.fromIr(sumType, isBuiltinType) } val importedPolyglot = ir.imports.collect { case poly: imports.Polyglot => BindingsMap.PolyglotSymbol(poly.getVisibleName) diff --git a/engine/runtime-integration-tests/src/test/java/org/enso/compiler/TypeInferenceTest.java b/engine/runtime-integration-tests/src/test/java/org/enso/compiler/TypeInferenceTest.java index fe51b83dfdfef..5bd4212b3d3c5 100644 --- a/engine/runtime-integration-tests/src/test/java/org/enso/compiler/TypeInferenceTest.java +++ b/engine/runtime-integration-tests/src/test/java/org/enso/compiler/TypeInferenceTest.java @@ -31,7 +31,6 @@ import scala.jdk.javaapi.CollectionConverters; public class TypeInferenceTest extends CompilerTest { - @Ignore @Test public void zeroAryCheck() throws Exception { final URI uri = new URI("memory://zeroAryCheck.enso"); diff --git a/lib/scala/pkg/src/main/scala/org/enso/pkg/QualifiedName.scala b/lib/scala/pkg/src/main/scala/org/enso/pkg/QualifiedName.scala index 840041a5f92b1..f59ed1f16ff32 100644 --- a/lib/scala/pkg/src/main/scala/org/enso/pkg/QualifiedName.scala +++ b/lib/scala/pkg/src/main/scala/org/enso/pkg/QualifiedName.scala @@ -52,6 +52,8 @@ case class QualifiedName(path: List[String], item: String) { def pathAsJava(): java.util.List[String] = { path.asJava } + + def fullPath(): List[String] = path :+ item } object QualifiedName {