Skip to content

Commit

Permalink
Store whole IR.Module in .bindings cache
Browse files Browse the repository at this point in the history
  • Loading branch information
JaroslavTulach committed Feb 1, 2024
1 parent 41c7b5a commit 6803429
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.enso.compiler.core.ir.module.scope.Export
import org.enso.compiler.core.ir.module.scope.Import
import org.enso.compiler.core.ir.module.scope.imports;
import org.enso.compiler.core.EnsoParser
import org.enso.compiler.data.{BindingsMap, CompilerConfig}
import org.enso.compiler.data.CompilerConfig
import org.enso.compiler.exception.CompilationAbortedException
import org.enso.compiler.pass.PassManager
import org.enso.compiler.pass.analyse._
Expand Down Expand Up @@ -487,19 +487,12 @@ class Compiler(
// TODO: consider generating IR for synthetic modules, if possible.
importExportBindings(module) match {
case Some(bindings) =>
val converted = bindings
.toConcrete(packageRepository.getModuleMap)
.map { concreteBindings =>
concreteBindings
context.updateModule(
module,
u => {
u.ir(bindings)
}
ensureParsed(module)
val currentLocal = module.getBindingsMap()
currentLocal.resolvedImports =
converted.map(_.resolvedImports).getOrElse(Nil)
currentLocal.resolvedExports =
converted.map(_.resolvedExports).getOrElse(Nil)
currentLocal.exportedSymbols =
converted.map(_.exportedSymbols).getOrElse(Map.empty)
)
case _ =>
}
}
Expand Down Expand Up @@ -563,7 +556,7 @@ class Compiler(
* @param module module which is conssidered
* @return module's bindings, if available in libraries' bindings cache
*/
def importExportBindings(module: Module): Option[BindingsMap] = {
def importExportBindings(module: Module): Option[IRModule] = {
if (irCachingEnabled && !context.isInteractive(module)) {
val libraryName = Option(module.getPackage).map(_.libraryName)
libraryName.flatMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.enso.compiler
import com.oracle.truffle.api.TruffleFile
import org.enso.editions.LibraryName
import org.enso.compiler.context.CompilerContext
import org.enso.compiler.data.BindingsMap
import org.enso.compiler.core.ir.{Module => IRModule}
import org.enso.pkg.{ComponentGroups, Package, QualifiedName}

import scala.collection.immutable.ListSet
Expand Down Expand Up @@ -113,7 +113,7 @@ trait PackageRepository {
libraryName: LibraryName,
moduleName: QualifiedName,
context: CompilerContext
): Option[BindingsMap]
): Option[IRModule]

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,14 @@ class ImportResolver(compiler: Compiler) {
// - no - ensure they are parsed (load them from cache) and add them to the import/export resolution
compiler.importExportBindings(current) match {
case Some(bindings) =>
val converted = bindings
.toConcrete(compiler.packageRepository.getModuleMap)
.map { concreteBindings =>
compiler.context.updateModule(
current,
{ u =>
u.bindingsMap(concreteBindings)
u.loadedFromCache(true)
}
)
concreteBindings
compiler.context.updateModule(
current,
u => {
u.ir(bindings)
u.loadedFromCache(true)
}
)
val converted = Option(current.getBindingsMap())
(
converted
.map(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.enso.interpreter.caches;

import java.io.IOException;
import java.util.UUID;
import java.util.function.Function;

import org.enso.compiler.context.CompilerContext;
import org.enso.compiler.core.ir.ProcessingPass;

final class CacheUtils {
private CacheUtils() {
}

static Function<Object, Object> writeReplace(CompilerContext context) {
return (obj) -> switch (obj) {
case ProcessingPass.Metadata metadata -> metadata.prepareForSerialization(context);
case UUID _ -> null;
case null -> null;
default -> obj;
};
}

static Function<Object, Object> readResolve(CompilerContext context) {
return (obj) -> switch (obj) {
case ProcessingPass.Metadata metadata -> {
var option = metadata.restoreFromSerialization(context);
if (option.nonEmpty()) {
yield option.get();
} else {
throw raise(
RuntimeException.class, new IOException("Cannot convert " + metadata));
}
}
case null -> null;
default -> obj;
};

}

@SuppressWarnings("unchecked")
static <T extends Exception> T raise(Class<T> cls, Exception e) throws T {
throw (T) e;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
import org.enso.pkg.QualifiedName;
import org.enso.pkg.SourceFile;
import org.openide.util.lookup.ServiceProvider;
import scala.Option;
import scala.Tuple2;
import scala.collection.immutable.Map;

@Persistable(clazz = QualifiedName.class, id = 30300)
public final class ImportExportCache
Expand All @@ -46,11 +43,19 @@ protected byte[] metadata(String sourceDigest, String blobDigest, CachedBindings
return new Metadata(sourceDigest, blobDigest).toBytes();
}

@Override
protected byte[] serialize(EnsoContext context, CachedBindings entry) throws IOException {
var arr =
Persistance.write(
entry.bindings(), CacheUtils.writeReplace(context.getCompiler().context()));
return arr;
}

@Override
protected CachedBindings deserialize(
EnsoContext context, byte[] data, Metadata meta, TruffleLogger logger)
throws ClassNotFoundException, IOException, ClassNotFoundException {
var ref = Persistance.read(data, null);
var ref = Persistance.read(data, CacheUtils.readResolve(context.getCompiler().context()));
var bindings = ref.get(MapToBindings.class);
return new CachedBindings(libraryName, bindings, Optional.empty());
}
Expand Down Expand Up @@ -103,57 +108,34 @@ protected Optional<Cache.Roots> getCacheRoots(EnsoContext context) {
});
}

@Override
protected byte[] serialize(EnsoContext context, CachedBindings entry) throws IOException {
var arr = Persistance.write(entry.bindings(), null);
return arr;
}

public static final class MapToBindings {
private final Map<QualifiedName, Persistance.Reference<BindingsMap>> entries;
private final java.util.Map<QualifiedName, org.enso.compiler.core.ir.Module> entries;

public MapToBindings(Map<QualifiedName, Persistance.Reference<BindingsMap>> entries) {
public MapToBindings(java.util.Map<QualifiedName, org.enso.compiler.core.ir.Module> entries) {
this.entries = entries;
}

public Option<BindingsMap> findForModule(QualifiedName moduleName) {
var ref = entries.get(moduleName);
if (ref.isEmpty()) {
return Option.empty();
}
return Option.apply(ref.get().get(BindingsMap.class));
public org.enso.compiler.core.ir.Module findForModule(QualifiedName moduleName) {
return entries.get(moduleName);
}
}

@ServiceProvider(service = Persistance.class)
public static final class PersistMapToBindings extends Persistance<MapToBindings> {
public PersistMapToBindings() {
super(MapToBindings.class, false, 364);
super(MapToBindings.class, false, 3642);
}

@Override
protected void writeObject(MapToBindings obj, Output out) throws IOException {
out.writeInt(obj.entries.size());
var it = obj.entries.iterator();
while (it.hasNext()) {
var e = it.next();
out.writeInline(QualifiedName.class, e._1());
out.writeObject(e._2().get(BindingsMap.class));
}
out.writeInline(java.util.Map.class, obj.entries);
}

@Override
@SuppressWarnings("unchecked")
protected MapToBindings readObject(Input in) throws IOException, ClassNotFoundException {
var size = in.readInt();
var b = Map.newBuilder();
b.sizeHint(size);
while (size-- > 0) {
var name = in.readInline(QualifiedName.class);
var value = in.readReference(BindingsMap.class);
b.addOne(Tuple2.apply(name, value));
}
return new MapToBindings((Map) b.result());
var map = in.readInline(java.util.Map.class);
return new MapToBindings(map);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Optional;
import java.util.UUID;
import java.util.logging.Level;
import org.apache.commons.lang3.StringUtils;
import org.enso.compiler.core.ir.Module;
import org.enso.compiler.core.ir.ProcessingPass;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.builtin.Builtins;
import org.enso.persist.Persistance;
Expand All @@ -40,27 +38,19 @@ protected byte[] metadata(String sourceDigest, String blobDigest, CachedModule e
return new Metadata(sourceDigest, blobDigest, entry.compilationStage().toString()).toBytes();
}

@Override
protected byte[] serialize(EnsoContext context, CachedModule entry) throws IOException {
var arr =
Persistance.write(
entry.moduleIR(), CacheUtils.writeReplace(context.getCompiler().context()));
return arr;
}

@Override
protected CachedModule deserialize(
EnsoContext context, byte[] data, Metadata meta, TruffleLogger logger)
throws ClassNotFoundException, IOException, ClassNotFoundException {
var ref =
Persistance.read(
data,
(obj) ->
switch (obj) {
case ProcessingPass.Metadata metadata -> {
var option = metadata.restoreFromSerialization(context.getCompiler().context());
if (option.nonEmpty()) {
yield option.get();
} else {
throw raise(
RuntimeException.class, new IOException("Cannot convert " + metadata));
}
}
case null -> null;
default -> obj;
});
var ref = Persistance.read(data, CacheUtils.readResolve(context.getCompiler().context()));
var mod = ref.get(Module.class);
return new CachedModule(
mod, CompilationStage.valueOf(meta.compilationStage()), module.getSource());
Expand Down Expand Up @@ -146,22 +136,6 @@ protected Optional<Cache.Roots> getCacheRoots(EnsoContext context) {
}
}

@Override
protected byte[] serialize(EnsoContext context, CachedModule entry) throws IOException {
var arr =
Persistance.write(
entry.moduleIR(),
(obj) ->
switch (obj) {
case ProcessingPass.Metadata metadata -> metadata.prepareForSerialization(
context.getCompiler().context());
case UUID uuid -> null;
case null -> null;
default -> obj;
});
return arr;
}

public record CachedModule(Module moduleIR, CompilationStage compilationStage, Source source) {}

public record Metadata(String sourceHash, String blobHash, String compilationStage)
Expand All @@ -187,9 +161,4 @@ static Metadata read(byte[] arr) throws IOException {
private static final String irCacheDataExtension = ".ir";

private static final String irCacheMetadataExtension = ".meta";

@SuppressWarnings("unchecked")
private static <T extends Exception> T raise(Class<T> cls, Exception e) throws T {
throw (T) e;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import com.oracle.truffle.api.source.Source;
import java.io.IOException;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Future;
Expand All @@ -21,6 +23,7 @@
import org.enso.compiler.pass.analyse.BindingAnalysis$;
import org.enso.editions.LibraryName;
import org.enso.interpreter.caches.Cache;
import org.enso.interpreter.caches.ImportExportCache.MapToBindings;
import org.enso.interpreter.caches.ModuleCache;
import org.enso.interpreter.runtime.type.Types;
import org.enso.pkg.Package;
Expand Down Expand Up @@ -250,9 +253,41 @@ public Future<Boolean> serializeModule(
return (Future<Boolean>) res;
}

private final Map<LibraryName, MapToBindings> known = new HashMap<>();

@Override
public boolean deserializeModule(Compiler compiler, CompilerContext.Module module) {
var library = module.getPackage().libraryName();
var bindings = known.get(library);
if (bindings == null) {
var cached = serializationManager.deserializeLibraryBindings(library);
if (cached.isDefined()) {
bindings = cached.get().bindings();
known.put(library, bindings);
}
}
if (bindings != null) {
var ir = bindings.findForModule(module.getName());
loggerSerializationManager.log(
Level.FINE,
"Deserializing module " + module.getName() + " from library: " + (ir != null));
if (ir != null) {
compiler
.context()
.updateModule(
module,
(u) -> {
u.ir(ir);
u.compilationStage(CompilationStage.AFTER_STATIC_PASSES);
u.loadedFromCache(true);
});
return true;
}
}
var level = "Standard".equals(library.namespace()) ? Level.WARNING : Level.FINE;
var result = serializationManager.deserialize(compiler, module);
loggerSerializationManager.log(
level, "Deserializing module " + module.getName() + " from IR file: " + result.nonEmpty());
return result.nonEmpty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.enso.interpreter.runtime

import org.enso.compiler.PackageRepository
import org.enso.compiler.context.CompilerContext
import org.enso.compiler.data.BindingsMap
import org.enso.compiler.core.ir.{Module => IRModule}
import com.oracle.truffle.api.TruffleFile
import com.typesafe.scalalogging.Logger
import org.apache.commons.lang3.StringUtils
Expand Down Expand Up @@ -574,7 +574,7 @@ private class DefaultPackageRepository(
libraryName: LibraryName,
moduleName: QualifiedName,
context: CompilerContext
): Option[BindingsMap] = {
): Option[IRModule] = {
val cache = ensurePackageIsLoaded(libraryName).toOption.flatMap { _ =>
if (!loadedLibraryBindings.contains(libraryName)) {
loadedPackages.get(libraryName).flatten.foreach(loadDependencies(_))
Expand All @@ -586,7 +586,7 @@ private class DefaultPackageRepository(
}
loadedLibraryBindings.get(libraryName)
}
cache.flatMap(_.flatMap(_.bindings.findForModule(moduleName)))
cache.flatMap(_.map(_.bindings.findForModule(moduleName)))
}

private def loadDependencies(pkg: Package[TruffleFile]): Unit = {
Expand Down
Loading

0 comments on commit 6803429

Please sign in to comment.