From 9a91b7bcc61ec4b72cdfb25aadafe08fcbe93ae8 Mon Sep 17 00:00:00 2001 From: Jaroslav Tulach Date: Fri, 9 Feb 2024 04:51:45 +0100 Subject: [PATCH] Store whole IR.Module in .bindings cache (#8924) --- .../compiler/context/CompilerContext.java | 5 +- .../scala/org/enso/compiler/Compiler.scala | 53 +- .../org/enso/compiler/PackageRepository.scala | 4 +- .../enso/compiler/phase/ImportResolver.scala | 22 +- .../instrument/job/SerializeModuleJob.java | 9 +- .../DeserializeLibrarySuggestionsJob.scala | 8 +- .../org/enso/interpreter/caches/Cache.java | 316 ++++---- .../enso/interpreter/caches/CacheUtils.java | 99 +++ .../interpreter/caches/ImportExportCache.java | 106 +-- .../enso/interpreter/caches/ModuleCache.java | 106 ++- .../interpreter/caches/SuggestionsCache.java | 75 +- .../org/enso/interpreter/runtime/Module.java | 13 +- .../runtime/SerializationPool.java | 206 ++++++ .../interpreter/runtime/ThreadExecutors.java | 22 +- .../runtime/TruffleCompilerContext.java | 392 +++++++++- .../runtime/DefaultPackageRepository.scala | 7 +- .../runtime/SerializationManager.scala | 694 ------------------ .../org/enso/compiler/SerdeCompilerTest.java | 5 +- .../compiler/SerializationManagerTest.java | 15 +- .../org/enso/compiler/SerializerTest.java | 11 +- .../interpreter/caches/ModuleCacheTest.java | 12 +- .../java/org/enso/persist/PerInputImpl.java | 5 +- .../java/org/enso/persist/Persistance.java | 18 +- 23 files changed, 1113 insertions(+), 1090 deletions(-) create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/caches/CacheUtils.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/runtime/SerializationPool.java delete mode 100644 engine/runtime/src/main/scala/org/enso/interpreter/runtime/SerializationManager.scala diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/context/CompilerContext.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/context/CompilerContext.java index f6526c008299..effa33360432 100644 --- 
a/engine/runtime-compiler/src/main/java/org/enso/compiler/context/CompilerContext.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/context/CompilerContext.java @@ -104,10 +104,13 @@ Future serializeLibrary( Compiler compiler, LibraryName libraryName, boolean useGlobalCacheLocations); Future serializeModule( - Compiler compiler, Module module, boolean useGlobalCacheLocations); + Compiler compiler, Module module, boolean useGlobalCacheLocations, boolean usePool); boolean deserializeModule(Compiler compiler, Module module); + scala.Option> deserializeSuggestions(LibraryName libraryName) + throws InterruptedException; + void shutdown(boolean waitForPendingJobCompletion); public static interface Updater { diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/Compiler.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/Compiler.scala index 565a58292b32..67e51a9d6fb7 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/Compiler.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/Compiler.scala @@ -24,7 +24,7 @@ import org.enso.compiler.core.ir.module.scope.Export import org.enso.compiler.core.ir.module.scope.Import import org.enso.compiler.core.ir.module.scope.imports import org.enso.compiler.core.EnsoParser -import org.enso.compiler.data.{BindingsMap, CompilerConfig} +import org.enso.compiler.data.CompilerConfig import org.enso.compiler.exception.CompilationAbortedException import org.enso.compiler.pass.PassManager import org.enso.compiler.pass.analyse._ @@ -200,7 +200,11 @@ class Compiler( */ def generateDocs(module: Module): Module = { initialize() - parseModule(module, isGenDocs = true) + parseModule( + module, + irCachingEnabled && !context.isInteractive(module), + isGenDocs = true + ) module } @@ -244,7 +248,7 @@ class Compiler( initialize() modules.foreach(m => try { - parseModule(m) + parseModule(m, irCachingEnabled && !context.isInteractive(m)) } catch { case e: Throwable => 
context.log( @@ -286,12 +290,12 @@ class Compiler( ) ) context.updateModule(module, _.invalidateCache()) - parseModule(module) + parseModule(module, irCachingEnabled && !context.isInteractive(module)) importedModules .filter(isLoadedFromSource) .map(m => { if (m.getBindingsMap() == null) { - parseModule(m) + parseModule(m, irCachingEnabled && !context.isInteractive(module)) } }) runImportsAndExportsResolution(module, generateCode) @@ -302,7 +306,7 @@ class Compiler( if (irCachingEnabled) { requiredModules.foreach { module => - ensureParsed(module) + ensureParsed(module, !context.isInteractive(module)) } } requiredModules.foreach { module => @@ -404,6 +408,7 @@ class Compiler( if (shouldCompileDependencies || isModuleInRootPackage(module)) { val shouldStoreCache = + generateCode && irCachingEnabled && !context.wasLoadedFromCache(module) if ( shouldStoreCache && !hasErrors(module) && @@ -415,7 +420,8 @@ class Compiler( context.serializeModule( this, module, - useGlobalCacheLocations + useGlobalCacheLocations, + true ) } } @@ -481,7 +487,7 @@ class Compiler( private def ensureParsedAndAnalyzed(module: Module): Unit = { if (module.getBindingsMap() == null) { - ensureParsed(module) + ensureParsed(module, irCachingEnabled && !context.isInteractive(module)) } if (context.isSynthetic(module)) { // Synthetic modules need to be import-analyzed @@ -490,19 +496,10 @@ class Compiler( // TODO: consider generating IR for synthetic modules, if possible. 
importExportBindings(module) match { case Some(bindings) => - val converted = bindings - .toConcrete(packageRepository.getModuleMap) - .map { concreteBindings => - concreteBindings - } - ensureParsed(module) - val currentLocal = module.getBindingsMap() - currentLocal.resolvedImports = - converted.map(_.resolvedImports).getOrElse(Nil) - currentLocal.resolvedExports = - converted.map(_.resolvedExports).getOrElse(Nil) - currentLocal.exportedSymbols = - converted.map(_.exportedSymbols).getOrElse(Map.empty) + context.updateModule( + module, + _.ir(bindings) + ) case _ => } } @@ -520,7 +517,7 @@ class Compiler( * @param module - the scope from which docs are generated. */ def gatherImportStatements(module: Module): Array[String] = { - ensureParsed(module) + ensureParsed(module, irCachingEnabled && !context.isInteractive(module)) val importedModules = context.getIr(module).imports.flatMap { case imp: Import.Module => imp.name.parts.take(2).map(_.name) match { @@ -543,6 +540,7 @@ class Compiler( private def parseModule( module: Module, + useCaches: Boolean, isGenDocs: Boolean = false ): Unit = { context.log( @@ -552,7 +550,7 @@ class Compiler( ) context.updateModule(module, _.resetScope()) - if (irCachingEnabled && !context.isInteractive(module)) { + if (useCaches) { if (context.deserializeModule(this, module)) { return } @@ -566,7 +564,7 @@ class Compiler( * @param module module which is conssidered * @return module's bindings, if available in libraries' bindings cache */ - def importExportBindings(module: Module): Option[BindingsMap] = { + def importExportBindings(module: Module): Option[IRModule] = { if (irCachingEnabled && !context.isInteractive(module)) { val libraryName = Option(module.getPackage).map(_.libraryName) libraryName.flatMap( @@ -646,6 +644,11 @@ class Compiler( * @param module the module to ensure is parsed. 
*/ def ensureParsed(module: Module): Unit = { + val useCaches = irCachingEnabled && !context.isInteractive(module) + ensureParsed(module, useCaches) + } + + def ensureParsed(module: Module, useCaches: Boolean): Unit = { if ( !context .getCompilationStage(module) @@ -653,7 +656,7 @@ class Compiler( CompilationStage.AFTER_PARSING ) ) { - parseModule(module) + parseModule(module, useCaches) } } diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/PackageRepository.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/PackageRepository.scala index 0ab999e5faa7..57fa6f27d8f1 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/PackageRepository.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/PackageRepository.scala @@ -3,7 +3,7 @@ package org.enso.compiler import com.oracle.truffle.api.TruffleFile import org.enso.editions.LibraryName import org.enso.compiler.context.CompilerContext -import org.enso.compiler.data.BindingsMap +import org.enso.compiler.core.ir.{Module => IRModule} import org.enso.pkg.{ComponentGroups, Package, QualifiedName} import scala.collection.immutable.ListSet @@ -113,7 +113,7 @@ trait PackageRepository { libraryName: LibraryName, moduleName: QualifiedName, context: CompilerContext - ): Option[BindingsMap] + ): Option[IRModule] } diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/phase/ImportResolver.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/phase/ImportResolver.scala index 9ca7d62a3dda..e7c7372396ff 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/phase/ImportResolver.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/phase/ImportResolver.scala @@ -66,7 +66,7 @@ class ImportResolver(compiler: Compiler) { u.invalidateCache() } ) - compiler.ensureParsed(current) + compiler.ensureParsed(current, false) return analyzeModule(current) } // put the list of resolved imports in the module metadata @@ -123,18 
+123,14 @@ class ImportResolver(compiler: Compiler) { // - no - ensure they are parsed (load them from cache) and add them to the import/export resolution compiler.importExportBindings(current) match { case Some(bindings) => - val converted = bindings - .toConcrete(compiler.packageRepository.getModuleMap) - .map { concreteBindings => - compiler.context.updateModule( - current, - { u => - u.bindingsMap(concreteBindings) - u.loadedFromCache(true) - } - ) - concreteBindings + compiler.context.updateModule( + current, + u => { + u.ir(bindings) + u.loadedFromCache(true) } + ) + val converted = Option(current.getBindingsMap()) ( converted .map( @@ -183,7 +179,7 @@ class ImportResolver(compiler: Compiler) { u.compilationStage(CompilationStage.INITIAL) } ) - compiler.ensureParsed(mod) + compiler.ensureParsed(mod, false) b = mod.getBindingsMap() } diff --git a/engine/runtime-instrument-common/src/main/java/org/enso/interpreter/instrument/job/SerializeModuleJob.java b/engine/runtime-instrument-common/src/main/java/org/enso/interpreter/instrument/job/SerializeModuleJob.java index 408226bd64b2..6d9ce4ebc0a8 100644 --- a/engine/runtime-instrument-common/src/main/java/org/enso/interpreter/instrument/job/SerializeModuleJob.java +++ b/engine/runtime-instrument-common/src/main/java/org/enso/interpreter/instrument/job/SerializeModuleJob.java @@ -2,7 +2,6 @@ import java.util.logging.Level; import org.enso.interpreter.instrument.execution.RuntimeContext; -import org.enso.interpreter.runtime.SerializationManager; import org.enso.pkg.QualifiedName; import org.enso.polyglot.CompilationStage; @@ -22,7 +21,6 @@ public SerializeModuleJob(QualifiedName moduleName) { public Void run(RuntimeContext ctx) { var ensoContext = ctx.executionService().getContext(); var compiler = ensoContext.getCompiler(); - SerializationManager serializationManager = SerializationManager.apply(compiler.context()); boolean useGlobalCacheLocations = ensoContext.isUseGlobalCache(); var writeLockTimestamp = 
ctx.locking().acquireWriteCompilationLock(); try { @@ -40,9 +38,10 @@ public Void run(RuntimeContext ctx) { new Object[] {module.getName(), module.getCompilationStage()}); return; } - - serializationManager.serializeModule( - compiler, module.asCompilerModule(), useGlobalCacheLocations, false); + compiler + .context() + .serializeModule( + compiler, module.asCompilerModule(), useGlobalCacheLocations, false); }); } finally { ctx.locking().releaseWriteCompilationLock(); diff --git a/engine/runtime-instrument-common/src/main/scala/org/enso/interpreter/instrument/job/DeserializeLibrarySuggestionsJob.scala b/engine/runtime-instrument-common/src/main/scala/org/enso/interpreter/instrument/job/DeserializeLibrarySuggestionsJob.scala index 77e2950162a9..7f9277f19bdb 100644 --- a/engine/runtime-instrument-common/src/main/scala/org/enso/interpreter/instrument/job/DeserializeLibrarySuggestionsJob.scala +++ b/engine/runtime-instrument-common/src/main/scala/org/enso/interpreter/instrument/job/DeserializeLibrarySuggestionsJob.scala @@ -1,7 +1,6 @@ package org.enso.interpreter.instrument.job import org.enso.editions.LibraryName -import org.enso.interpreter.runtime.SerializationManager import org.enso.interpreter.instrument.execution.RuntimeContext import org.enso.polyglot.runtime.Runtime.Api @@ -33,17 +32,14 @@ final class DeserializeLibrarySuggestionsJob( "Deserializing suggestions for library [{}].", libraryName ) - val serializationManager = SerializationManager( - ctx.executionService.getContext.getCompiler.context - ) - serializationManager + ctx.executionService.getContext.getCompiler.context .deserializeSuggestions(libraryName) .foreach { cachedSuggestions => ctx.endpoint.sendToClient( Api.Response( Api.SuggestionsDatabaseSuggestionsLoadedNotification( libraryName, - cachedSuggestions.getSuggestions.asScala.toVector + cachedSuggestions.asScala.toVector ) ) ) diff --git a/engine/runtime/src/main/java/org/enso/interpreter/caches/Cache.java 
b/engine/runtime/src/main/java/org/enso/interpreter/caches/Cache.java index b66e0e394b34..3aaeef35abc4 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/caches/Cache.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/caches/Cache.java @@ -2,21 +2,19 @@ import com.oracle.truffle.api.TruffleFile; import com.oracle.truffle.api.TruffleLogger; +import java.io.File; import java.io.IOException; import java.io.OutputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; import java.nio.file.NoSuchFileException; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.Comparator; -import java.util.List; import java.util.Optional; import java.util.logging.Level; import org.enso.interpreter.runtime.EnsoContext; import org.enso.logger.masking.MaskedPath; -import org.enso.pkg.SourceFile; -import org.enso.text.Hex; /** * Cache encapsulates a common functionality needed to serialize and de-serialize objects, while @@ -25,11 +23,14 @@ * @param type of the cached data * @param type of the metadata associated with the data */ -public abstract class Cache { +public final class Cache { private final Object LOCK = new Object(); + /** implementation of the serialize/deserialize operations */ + private final Spi spi; + /** Returns a default level of logging for this Cache. */ - protected final Level logLevel; + private final Level logLevel; /** Log name to use in log messages */ private final String logName; @@ -57,17 +58,41 @@ public abstract class Cache { * @param needsDataDigestVerification Flag indicating if the de-serialization process should * compute the hash of the stored cache and compare it with the stored metadata entry. 
*/ - protected Cache( + private Cache( + Cache.Spi spi, Level logLevel, String logName, boolean needsSourceDigestVerification, boolean needsDataDigestVerification) { + this.spi = spi; this.logLevel = logLevel; this.logName = logName; this.needsDataDigestVerification = needsDataDigestVerification; this.needsSourceDigestVerification = needsSourceDigestVerification; } + /** + * Factory method to create new cache instance. + * + * @param spi the implementation logic of the cache + * @param logLevel logging level + * @param logName name to use in logs + * @param needsSourceDigestVerification Flag indicating if the de-serialization process should + * compute the hash of the sources from which the cache was created and compare it with the + * stored metadata entry. + * @param needsDataDigestVerification Flag indicating if the de-serialization process should + * compute the hash of the stored cache and compare it with the stored metadata entry. + */ + static Cache create( + Cache.Spi spi, + Level logLevel, + String logName, + boolean needsSourceDigestVerification, + boolean needsDataDigestVerification) { + return new Cache<>( + spi, logLevel, logName, needsSourceDigestVerification, needsDataDigestVerification); + } + /** * Saves data to a cache file. 
* @@ -80,7 +105,7 @@ protected Cache( public final Optional save( T entry, EnsoContext context, boolean useGlobalCacheLocations) { TruffleLogger logger = context.getLogger(this.getClass()); - return getCacheRoots(context) + return spi.getCacheRoots(context) .flatMap( roots -> { try { @@ -123,14 +148,14 @@ private boolean saveCacheTo( EnsoContext context, TruffleFile cacheRoot, T entry, TruffleLogger logger) throws IOException, ClassNotFoundException { if (ensureRoot(cacheRoot)) { - byte[] bytesToWrite = serialize(context, entry); + byte[] bytesToWrite = spi.serialize(context, entry); - String blobDigest = computeDigestFromBytes(bytesToWrite); - String sourceDigest = computeDigest(entry, logger).get(); + String blobDigest = CacheUtils.computeDigestFromBytes(ByteBuffer.wrap(bytesToWrite)); + String sourceDigest = spi.computeDigest(entry, logger).get(); if (sourceDigest == null) { throw new ClassNotFoundException("unable to compute digest"); } - byte[] metadataBytes = metadata(sourceDigest, blobDigest, entry); + byte[] metadataBytes = spi.metadata(sourceDigest, blobDigest, entry); TruffleFile cacheDataFile = getCacheDataPath(cacheRoot); TruffleFile metadataFile = getCacheMetadataPath(cacheRoot); @@ -166,18 +191,6 @@ private boolean ensureRoot(TruffleFile cacheRoot) { } } - /** - * Return serialized representation of data's metadata. - * - * @param sourceDigest digest of data's source - * @param blobDigest digest of serialized data - * @param entry data to serialize - * @return raw bytes representing serialized metadata - * @throws java.io.IOException in case of I/O error - */ - protected abstract byte[] metadata(String sourceDigest, String blobDigest, T entry) - throws IOException; - /** * Loads cache for this data, if possible. 
* @@ -187,7 +200,7 @@ protected abstract byte[] metadata(String sourceDigest, String blobDigest, T ent public final Optional load(EnsoContext context) { synchronized (LOCK) { TruffleLogger logger = context.getLogger(this.getClass()); - return getCacheRoots(context) + return spi.getCacheRoots(context) .flatMap( roots -> { try { @@ -251,24 +264,34 @@ private Optional loadCacheFrom( M meta = optMeta.get(); boolean sourceDigestValid = !needsSourceDigestVerification - || computeDigestFromSource(context, logger) - .map(digest -> digest.equals(meta.sourceHash())) + || spi.computeDigestFromSource(context, logger) + .map(digest -> digest.equals(spi.sourceHash(meta))) .orElseGet(() -> false); - byte[] blobBytes = dataPath.readAllBytes(); + var file = new File(dataPath.toUri()); + ByteBuffer blobBytes; + var threeMbs = 3 * 1024 * 1024; + if (file.exists() && file.length() > threeMbs) { + logger.log(Level.FINE, "Cache file " + file + " mmapped with " + file.length() + " size"); + var raf = new RandomAccessFile(file, "r"); + blobBytes = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, file.length()); + } else { + blobBytes = ByteBuffer.wrap(dataPath.readAllBytes()); + } boolean blobDigestValid = - !needsDataDigestVerification || computeDigestFromBytes(blobBytes).equals(meta.blobHash()); + !needsDataDigestVerification + || CacheUtils.computeDigestFromBytes(blobBytes).equals(spi.blobHash(meta)); if (sourceDigestValid && blobDigestValid) { T cachedObject = null; try { long now = System.currentTimeMillis(); - cachedObject = deserialize(context, blobBytes, meta, logger); + cachedObject = spi.deserialize(context, blobBytes, meta, logger); long took = System.currentTimeMillis() - now; if (cachedObject != null) { logger.log( Level.FINEST, "Loaded cache for {0} with {1} bytes in {2} ms", - new Object[] {logName, blobBytes.length, took}); + new Object[] {logName, blobBytes.limit(), took}); return Optional.of(cachedObject); } else { logger.log(logLevel, "`{0}` was corrupt on disk.", 
logName); @@ -300,21 +323,6 @@ private Optional loadCacheFrom( } } - /** - * Deserializes and validates data by returning the expected cached entry, or {@code null}. - * - * @param context the context - * @param data data to deserialize object from - * @param meta metadata corresponding to the `obj` - * @param logger Truffle's logger - * @return {@code data} transformed to a cached entry or {@code null} - * @throws ClassNotFoundException exception thrown on unexpected deserialized data - * @throws IOException when I/O goes wrong - * @throws ClassNotFoundException on problems with deserializaiton of Java classes - */ - protected abstract T deserialize(EnsoContext context, byte[] data, M meta, TruffleLogger logger) - throws IOException, ClassNotFoundException, ClassNotFoundException; - /** * Read metadata representation from the provided location * @@ -323,115 +331,12 @@ protected abstract T deserialize(EnsoContext context, byte[] data, M meta, Truff */ private Optional loadCacheMetadata(TruffleFile path, TruffleLogger logger) throws IOException { if (path.isReadable()) { - return metadataFromBytes(path.readAllBytes(), logger); + return spi.metadataFromBytes(path.readAllBytes(), logger); } else { return Optional.empty(); } } - /** - * De-serializes raw bytes to data's metadata. 
- * - * @param bytes raw bytes representing metadata - * @param logger logger to use - * @return non-empty metadata, if de-serialization was successful - * @throws IOException in case of I/O error - */ - protected abstract Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) - throws IOException; - - /** - * Compute digest of cache's data - * - * @param entry data for which digest should be computed - * @param logger Truffle's logger - * @return non-empty digest, if successful - */ - protected abstract Optional computeDigest(T entry, TruffleLogger logger); - - /** - * Compute digest of data's source - * - * @param context the language context in which loading is taking place - * @param logger Truffle's logger - * @return non-empty digest, if successful - */ - protected abstract Optional computeDigestFromSource( - EnsoContext context, TruffleLogger logger); - - /** - * Computes digest from an array of bytes using a default hashing algorithm. - * - * @param bytes bytes for which hash will be computed - * @return string representation of bytes' hash - */ - protected final String computeDigestFromBytes(byte[] bytes) { - return Hex.toHexString(messageDigest().digest(bytes)); - } - - /** - * Computes digest from package sources using a default hashing algorithm. - * - * @param pkgSources the list of package sources - * @param logger the truffle logger - * @return string representation of bytes' hash - */ - protected final String computeDigestOfLibrarySources( - List> pkgSources, TruffleLogger logger) { - pkgSources.sort(Comparator.comparing(o -> o.qualifiedName().toString())); - - var digest = messageDigest(); - pkgSources.forEach( - source -> { - try { - digest.update(source.file().readAllBytes()); - } catch (IOException e) { - logger.log( - logLevel, "failed to compute digest for " + source.qualifiedName().toString(), e); - } - }); - return Hex.toHexString(digest.digest()); - } - - /** - * Returns a default hashing algorithm used for Enso caches. 
- * - * @return digest used for computing hashes - */ - protected MessageDigest messageDigest() { - try { - return MessageDigest.getInstance("SHA-1"); - } catch (NoSuchAlgorithmException ex) { - throw new IllegalStateException("Unreachable", ex); - } - } - - /** - * Returns locations where caches can be located - * - * @param context the language context in which loading is taking place - * @return non-empty if the locations have been inferred successfully, empty otherwise - */ - protected abstract Optional getCacheRoots(EnsoContext context); - - /** - * Returns the exact data to be serialized. Override in subclasses to turn an {@code entry} into - * an array of bytes to persist - * - * @param context context we operate in - * @param entry entry to persist - * @return array of bytes - * @throws java.io.IOException if something goes wrong - */ - protected abstract byte[] serialize(EnsoContext context, T entry) throws IOException; - - protected String entryName; - - /** Suffix to be used */ - protected String dataSuffix; - - protected String metadataSuffix; - /** * Gets the path to the cache data within the `cacheRoot`. 
* @@ -439,11 +344,11 @@ protected MessageDigest messageDigest() { * @return the name of the data file for this entry's cache */ private TruffleFile getCacheDataPath(TruffleFile cacheRoot) { - return cacheRoot.resolve(cacheFileName(dataSuffix)); + return cacheRoot.resolve(cacheFileName(spi.dataSuffix())); } private TruffleFile getCacheMetadataPath(TruffleFile cacheRoot) { - return cacheRoot.resolve(cacheFileName(metadataSuffix)); + return cacheRoot.resolve(cacheFileName(spi.metadataSuffix())); } /** @@ -453,7 +358,7 @@ private TruffleFile getCacheMetadataPath(TruffleFile cacheRoot) { * @return the cache file name with the provided `ext` */ private String cacheFileName(String suffix) { - return entryName + suffix; + return spi.entryName() + suffix; } /** @@ -503,7 +408,7 @@ private void doDeleteAt(TruffleFile cacheRoot, TruffleFile file, TruffleLogger l public final void invalidate(EnsoContext context) { synchronized (LOCK) { TruffleLogger logger = context.getLogger(this.getClass()); - getCacheRoots(context) + spi.getCacheRoots(context) .ifPresent( roots -> { invalidateCache(roots.globalCacheRoot, logger); @@ -512,6 +417,10 @@ public final void invalidate(EnsoContext context) { } } + final T asSpi(Class type) { + return type.cast(spi); + } + /** * Roots encapsulates two possible locations where caches can be stored. * @@ -539,9 +448,96 @@ private static MaskedPath toMaskedPath(TruffleFile truffleFile) { return new MaskedPath(Path.of(truffleFile.getPath())); } - interface Metadata { - String sourceHash(); - - String blobHash(); + /** + * Set of methods to be implemented by those who want to cache something. + * + * @param + */ + public static interface Spi { + /** + * Deserializes and validates data by returning the expected cached entry, or {@code null}. 
+ * + * @param context the context + * @param data data to deserialize object from + * @param meta metadata corresponding to the `obj` + * @param logger Truffle's logger + * @return {@code data} transformed to a cached entry or {@code null} + * @throws ClassNotFoundException exception thrown on unexpected deserialized data + * @throws IOException when I/O goes wrong + */ + public abstract T deserialize( + EnsoContext context, ByteBuffer data, M meta, TruffleLogger logger) + throws IOException, ClassNotFoundException; + + /** + * Returns the exact data to be serialized. Override in subclasses to turn an {@code entry} into + * an array of bytes to persist + * + * @param context context we operate in + * @param entry entry to persist + * @return array of bytes + * @throws java.io.IOException if something goes wrong + */ + public abstract byte[] serialize(EnsoContext context, T entry) throws IOException; + + /** + * Return serialized representation of data's metadata. + * + * @param sourceDigest digest of data's source + * @param blobDigest digest of serialized data + * @param entry data to serialize + * @return raw bytes representing serialized metadata + * @throws java.io.IOException in case of I/O error + */ + public abstract byte[] metadata(String sourceDigest, String blobDigest, T entry) + throws IOException; + + /** + * De-serializes raw bytes to data's metadata. 
+ * + * @param bytes raw bytes representing metadata + * @param logger logger to use + * @return non-empty metadata, if de-serialization was successful + * @throws IOException in case of I/O error + */ + public abstract Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) + throws IOException; + + /** + * Compute digest of cache's data + * + * @param entry data for which digest should be computed + * @param logger Truffle's logger + * @return non-empty digest, if successful + */ + public abstract Optional computeDigest(T entry, TruffleLogger logger); + + /** + * Compute digest of data's source + * + * @param context the language context in which loading is taking place + * @param logger Truffle's logger + * @return non-empty digest, if successful + */ + public abstract Optional computeDigestFromSource( + EnsoContext context, TruffleLogger logger); + + /** + * Returns locations where caches can be located + * + * @param context the language context in which loading is taking place + * @return non-empty if the locations have been inferred successfully, empty otherwise + */ + public abstract Optional getCacheRoots(EnsoContext context); + + public abstract String entryName(); + + public abstract String dataSuffix(); + + public abstract String metadataSuffix(); + + public abstract String sourceHash(M meta); + + public abstract String blobHash(M meta); } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/caches/CacheUtils.java b/engine/runtime/src/main/java/org/enso/interpreter/caches/CacheUtils.java new file mode 100644 index 000000000000..e14818113729 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/caches/CacheUtils.java @@ -0,0 +1,99 @@ +package org.enso.interpreter.caches; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Comparator; +import java.util.List; +import java.util.UUID; +import 
java.util.function.Function; + +import org.enso.compiler.context.CompilerContext; +import org.enso.compiler.core.ir.ProcessingPass; +import org.enso.pkg.SourceFile; +import org.enso.text.Hex; + +import com.oracle.truffle.api.TruffleFile; + +final class CacheUtils { + private CacheUtils() { + } + + static Function writeReplace(CompilerContext context) { + return (obj) -> switch (obj) { + case ProcessingPass.Metadata metadata -> metadata.prepareForSerialization(context); + case UUID _ -> null; + case null -> null; + default -> obj; + }; + } + + static Function readResolve(CompilerContext context) { + return (obj) -> switch (obj) { + case ProcessingPass.Metadata metadata -> { + var option = metadata.restoreFromSerialization(context); + if (option.nonEmpty()) { + yield option.get(); + } else { + throw raise(RuntimeException.class, new IOException("Cannot convert " + metadata)); + } + } + case null -> null; + default -> obj; + }; + } + + /** + * Returns a default hashing algorithm used for Enso caches. + * + * @return digest used for computing hashes + */ + private static MessageDigest messageDigest() { + try { + return MessageDigest.getInstance("SHA-1"); + } catch (NoSuchAlgorithmException ex) { + throw raise(RuntimeException.class, ex); + } + } + + /** + * Computes digest from an array of bytes using a default hashing algorithm. + * + * @param bytes bytes for which hash will be computed + * @return string representation of bytes' hash + */ + static String computeDigestFromBytes(ByteBuffer bytes) { + var sha = messageDigest(); + sha.update(bytes); + return Hex.toHexString(sha.digest()); + } + + /** + * Computes digest from package sources using a default hashing algorithm. 
+ * + * @param pkgSources the list of package sources + * @return string representation of bytes' hash + */ + static final String computeDigestOfLibrarySources( + List> pkgSources + ) { + pkgSources.sort(Comparator.comparing(o -> o.qualifiedName().toString())); + + try { + var digest = messageDigest(); + for (var source : pkgSources) { + digest.update(source.file().readAllBytes()); + } + return Hex.toHexString(digest.digest()); + } catch (IOException ex) { + throw raise(RuntimeException.class, ex); + } + } + + @SuppressWarnings("unchecked") + static T raise(Class cls, Exception e) throws T { + throw (T) e; + } + +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/caches/ImportExportCache.java b/engine/runtime/src/main/java/org/enso/interpreter/caches/ImportExportCache.java index 2299e24ad122..526714353a78 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/caches/ImportExportCache.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/caches/ImportExportCache.java @@ -8,6 +8,7 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; import java.util.logging.Level; @@ -22,62 +23,84 @@ import org.enso.pkg.QualifiedName; import org.enso.pkg.SourceFile; import org.openide.util.lookup.ServiceProvider; -import scala.Option; -import scala.Tuple2; -import scala.collection.immutable.Map; @Persistable(clazz = QualifiedName.class, id = 30300) public final class ImportExportCache - extends Cache { + implements Cache.Spi { private final LibraryName libraryName; - public ImportExportCache(LibraryName libraryName) { - super(Level.FINEST, libraryName.toString(), true, false); + private ImportExportCache(LibraryName libraryName) { this.libraryName = libraryName; - this.entryName = libraryName.name(); - this.dataSuffix = bindingsCacheDataExtension; - this.metadataSuffix = bindingsCacheMetadataExtension; + } + + public static Cache 
create( + LibraryName libraryName) { + var impl = new ImportExportCache(libraryName); + return Cache.create(impl, Level.FINEST, libraryName.toString(), true, false); + } + + @Override + public String metadataSuffix() { + return bindingsCacheMetadataExtension; + } + + @Override + public String dataSuffix() { + return bindingsCacheDataExtension; + } + + @Override + public String entryName() { + return libraryName.name(); } @Override - protected byte[] metadata(String sourceDigest, String blobDigest, CachedBindings entry) + public byte[] metadata(String sourceDigest, String blobDigest, CachedBindings entry) throws IOException { return new Metadata(sourceDigest, blobDigest).toBytes(); } @Override - protected CachedBindings deserialize( - EnsoContext context, byte[] data, Metadata meta, TruffleLogger logger) + public byte[] serialize(EnsoContext context, CachedBindings entry) throws IOException { + var arr = + Persistance.write( + entry.bindings(), CacheUtils.writeReplace(context.getCompiler().context())); + return arr; + } + + @Override + public CachedBindings deserialize( + EnsoContext context, ByteBuffer data, Metadata meta, TruffleLogger logger) throws ClassNotFoundException, IOException, ClassNotFoundException { - var ref = Persistance.read(data, null); + var ref = Persistance.read(data, CacheUtils.readResolve(context.getCompiler().context())); var bindings = ref.get(MapToBindings.class); return new CachedBindings(libraryName, bindings, Optional.empty()); } @Override - protected Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) + public Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) throws IOException { return Optional.of(Metadata.read(bytes)); } @Override - protected Optional computeDigest(CachedBindings entry, TruffleLogger logger) { - return entry.sources().map(sources -> computeDigestOfLibrarySources(sources, logger)); + public Optional computeDigest(CachedBindings entry, TruffleLogger logger) { + return entry.sources().map(sources 
-> CacheUtils.computeDigestOfLibrarySources(sources)); } @Override @SuppressWarnings("unchecked") - protected Optional computeDigestFromSource(EnsoContext context, TruffleLogger logger) { + public Optional computeDigestFromSource(EnsoContext context, TruffleLogger logger) { return context .getPackageRepository() .getPackageForLibraryJava(libraryName) - .map(pkg -> computeDigestOfLibrarySources(pkg.listSourcesJava(), logger)); + .map(pkg -> CacheUtils.computeDigestOfLibrarySources(pkg.listSourcesJava())); } @Override @SuppressWarnings("unchecked") - protected Optional getCacheRoots(EnsoContext context) { + public Optional getCacheRoots(EnsoContext context) { return context .getPackageRepository() .getPackageForLibraryJava(libraryName) @@ -104,56 +127,43 @@ protected Optional getCacheRoots(EnsoContext context) { } @Override - protected byte[] serialize(EnsoContext context, CachedBindings entry) throws IOException { - var arr = Persistance.write(entry.bindings(), null); - return arr; + public String sourceHash(Metadata meta) { + return meta.sourceHash(); + } + + @Override + public String blobHash(Metadata meta) { + return meta.blobHash(); } public static final class MapToBindings { - private final Map> entries; + private final java.util.Map entries; - public MapToBindings(Map> entries) { + public MapToBindings(java.util.Map entries) { this.entries = entries; } - public Option findForModule(QualifiedName moduleName) { - var ref = entries.get(moduleName); - if (ref.isEmpty()) { - return Option.empty(); - } - return Option.apply(ref.get().get(BindingsMap.class)); + public org.enso.compiler.core.ir.Module findForModule(QualifiedName moduleName) { + return entries.get(moduleName); } } @ServiceProvider(service = Persistance.class) public static final class PersistMapToBindings extends Persistance { public PersistMapToBindings() { - super(MapToBindings.class, false, 364); + super(MapToBindings.class, false, 3642); } @Override protected void writeObject(MapToBindings obj, 
Output out) throws IOException { - out.writeInt(obj.entries.size()); - var it = obj.entries.iterator(); - while (it.hasNext()) { - var e = it.next(); - out.writeInline(QualifiedName.class, e._1()); - out.writeObject(e._2().get(BindingsMap.class)); - } + out.writeInline(java.util.Map.class, obj.entries); } @Override @SuppressWarnings("unchecked") protected MapToBindings readObject(Input in) throws IOException, ClassNotFoundException { - var size = in.readInt(); - var b = Map.newBuilder(); - b.sizeHint(size); - while (size-- > 0) { - var name = in.readInline(QualifiedName.class); - var value = in.readReference(BindingsMap.class); - b.addOne(Tuple2.apply(name, value)); - } - return new MapToBindings((Map) b.result()); + var map = in.readInline(java.util.Map.class); + return new MapToBindings(map); } } @@ -162,7 +172,7 @@ public static record CachedBindings( MapToBindings bindings, Optional>> sources) {} - public record Metadata(String sourceHash, String blobHash) implements Cache.Metadata { + public record Metadata(String sourceHash, String blobHash) { byte[] toBytes() throws IOException { try (var os = new ByteArrayOutputStream(); var dos = new DataOutputStream(os)) { diff --git a/engine/runtime/src/main/java/org/enso/interpreter/caches/ModuleCache.java b/engine/runtime/src/main/java/org/enso/interpreter/caches/ModuleCache.java index 896733e8a321..e65bb3b24476 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/caches/ModuleCache.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/caches/ModuleCache.java @@ -8,66 +8,74 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Optional; -import java.util.UUID; import java.util.logging.Level; import org.apache.commons.lang3.StringUtils; import org.enso.compiler.core.ir.Module; -import 
org.enso.compiler.core.ir.ProcessingPass; import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.runtime.builtin.Builtins; import org.enso.persist.Persistance; import org.enso.polyglot.CompilationStage; -public final class ModuleCache extends Cache { - +public final class ModuleCache + implements Cache.Spi { private final org.enso.interpreter.runtime.Module module; - public ModuleCache(org.enso.interpreter.runtime.Module module) { - super(Level.FINEST, module.getName().toString(), true, false); + private ModuleCache(org.enso.interpreter.runtime.Module module) { this.module = module; - this.entryName = module.getName().item(); - this.dataSuffix = irCacheDataExtension; - this.metadataSuffix = irCacheMetadataExtension; + } + + public static Cache create( + org.enso.interpreter.runtime.Module module) { + var mc = new ModuleCache(module); + return Cache.create(mc, Level.FINEST, module.getName().toString(), true, false); } @Override - protected byte[] metadata(String sourceDigest, String blobDigest, CachedModule entry) + public String metadataSuffix() { + return irCacheMetadataExtension; + } + + @Override + public String dataSuffix() { + return irCacheDataExtension; + } + + @Override + public String entryName() { + return module.getName().item(); + } + + @Override + public byte[] metadata(String sourceDigest, String blobDigest, CachedModule entry) throws IOException { return new Metadata(sourceDigest, blobDigest, entry.compilationStage().toString()).toBytes(); } @Override - protected CachedModule deserialize( - EnsoContext context, byte[] data, Metadata meta, TruffleLogger logger) + public byte[] serialize(EnsoContext context, CachedModule entry) throws IOException { + var arr = + Persistance.write( + entry.moduleIR(), CacheUtils.writeReplace(context.getCompiler().context())); + return arr; + } + + @Override + public CachedModule deserialize( + EnsoContext context, ByteBuffer data, Metadata meta, TruffleLogger logger) throws ClassNotFoundException, 
IOException, ClassNotFoundException { - var ref = - Persistance.read( - data, - (obj) -> - switch (obj) { - case ProcessingPass.Metadata metadata -> { - var option = metadata.restoreFromSerialization(context.getCompiler().context()); - if (option.nonEmpty()) { - yield option.get(); - } else { - throw raise( - RuntimeException.class, new IOException("Cannot convert " + metadata)); - } - } - case null -> null; - default -> obj; - }); + var ref = Persistance.read(data, CacheUtils.readResolve(context.getCompiler().context())); var mod = ref.get(Module.class); return new CachedModule( mod, CompilationStage.valueOf(meta.compilationStage()), module.getSource()); } @Override - protected Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) + public Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) throws IOException { return Optional.of(Metadata.read(bytes)); } @@ -80,29 +88,29 @@ private Optional computeDigestOfModuleSources(Source source) { } else { sourceBytes = source.getCharacters().toString().getBytes(StandardCharsets.UTF_8); } - return Optional.of(computeDigestFromBytes(sourceBytes)); + return Optional.of(CacheUtils.computeDigestFromBytes(ByteBuffer.wrap(sourceBytes))); } else { return Optional.empty(); } } @Override - protected Optional computeDigest(CachedModule entry, TruffleLogger logger) { + public Optional computeDigest(CachedModule entry, TruffleLogger logger) { return computeDigestOfModuleSources(entry.source()); } @Override - protected Optional computeDigestFromSource(EnsoContext context, TruffleLogger logger) { + public Optional computeDigestFromSource(EnsoContext context, TruffleLogger logger) { try { return computeDigestOfModuleSources(module.getSource()); } catch (IOException e) { - logger.log(logLevel, "failed to retrieve the source of " + module.getName(), e); + logger.log(Level.FINEST, "failed to retrieve the source of " + module.getName(), e); return Optional.empty(); } } @Override - protected Optional getCacheRoots(EnsoContext 
context) { + public Optional getCacheRoots(EnsoContext context) { if (module != context.getBuiltins().getModule()) { return context .getPackageOf(module.getSourceFile()) @@ -147,25 +155,18 @@ protected Optional getCacheRoots(EnsoContext context) { } @Override - protected byte[] serialize(EnsoContext context, CachedModule entry) throws IOException { - var arr = - Persistance.write( - entry.moduleIR(), - (obj) -> - switch (obj) { - case ProcessingPass.Metadata metadata -> metadata.prepareForSerialization( - context.getCompiler().context()); - case UUID uuid -> null; - case null -> null; - default -> obj; - }); - return arr; + public String sourceHash(Metadata meta) { + return meta.sourceHash(); + } + + @Override + public String blobHash(Metadata meta) { + return meta.blobHash(); } public record CachedModule(Module moduleIR, CompilationStage compilationStage, Source source) {} - public record Metadata(String sourceHash, String blobHash, String compilationStage) - implements Cache.Metadata { + public record Metadata(String sourceHash, String blobHash, String compilationStage) { byte[] toBytes() throws IOException { try (var os = new ByteArrayOutputStream(); var dos = new DataOutputStream(os)) { @@ -187,9 +188,4 @@ static Metadata read(byte[] arr) throws IOException { private static final String irCacheDataExtension = ".ir"; private static final String irCacheMetadataExtension = ".meta"; - - @SuppressWarnings("unchecked") - private static T raise(Class cls, Exception e) throws T { - throw (T) e; - } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/caches/SuggestionsCache.java b/engine/runtime/src/main/java/org/enso/interpreter/caches/SuggestionsCache.java index 92088ee2d052..ea31635a0155 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/caches/SuggestionsCache.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/caches/SuggestionsCache.java @@ -8,9 +8,11 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import 
java.io.IOException; +import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; import java.util.logging.Level; @@ -21,32 +23,55 @@ import org.enso.polyglot.Suggestion; public final class SuggestionsCache - extends Cache { - + implements Cache.Spi { private static final String SUGGESTIONS_CACHE_DATA_EXTENSION = ".suggestions"; private static final String SUGGESTIONS_CACHE_METADATA_EXTENSION = ".suggestions.meta"; final LibraryName libraryName; - public SuggestionsCache(LibraryName libraryName) { - super(Level.FINEST, libraryName.toString(), true, false); + private SuggestionsCache(LibraryName libraryName) { this.libraryName = libraryName; - this.entryName = libraryName.name(); - this.dataSuffix = SUGGESTIONS_CACHE_DATA_EXTENSION; - this.metadataSuffix = SUGGESTIONS_CACHE_METADATA_EXTENSION; + } + + public static Cache create( + LibraryName libraryName) { + var impl = new SuggestionsCache(libraryName); + return Cache.create(impl, Level.FINEST, libraryName.toString(), true, false); + } + + @Override + public String metadataSuffix() { + return SUGGESTIONS_CACHE_METADATA_EXTENSION; } @Override - protected byte[] metadata(String sourceDigest, String blobDigest, CachedSuggestions entry) + public String dataSuffix() { + return SUGGESTIONS_CACHE_DATA_EXTENSION; + } + + @Override + public String entryName() { + return libraryName.name(); + } + + @Override + public byte[] metadata(String sourceDigest, String blobDigest, CachedSuggestions entry) throws IOException { return new Metadata(sourceDigest, blobDigest).toBytes(); } @Override - protected CachedSuggestions deserialize( - EnsoContext context, byte[] data, Metadata meta, TruffleLogger logger) - throws ClassNotFoundException, ClassNotFoundException, IOException { - try (var stream = new ObjectInputStream(new ByteArrayInputStream(data))) { + public CachedSuggestions deserialize( 
+ EnsoContext context, ByteBuffer data, Metadata meta, TruffleLogger logger) + throws ClassNotFoundException, IOException { + class BufferInputStream extends InputStream { + @Override + public int read() throws IOException { + return data.get() & 0xff; + } + } + + try (var stream = new ObjectInputStream(new BufferInputStream())) { if (stream.readObject() instanceof Suggestions suggestions) { return new CachedSuggestions(libraryName, suggestions, Optional.empty()); } else { @@ -57,26 +82,26 @@ protected CachedSuggestions deserialize( } @Override - protected Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) + public Optional metadataFromBytes(byte[] bytes, TruffleLogger logger) throws IOException { return Optional.of(Metadata.read(bytes)); } @Override - protected Optional computeDigest(CachedSuggestions entry, TruffleLogger logger) { - return entry.getSources().map(sources -> computeDigestOfLibrarySources(sources, logger)); + public Optional computeDigest(CachedSuggestions entry, TruffleLogger logger) { + return entry.getSources().map(sources -> CacheUtils.computeDigestOfLibrarySources(sources)); } @Override - protected Optional computeDigestFromSource(EnsoContext context, TruffleLogger logger) { + public Optional computeDigestFromSource(EnsoContext context, TruffleLogger logger) { return context .getPackageRepository() .getPackageForLibraryJava(libraryName) - .map(pkg -> computeDigestOfLibrarySources(pkg.listSourcesJava(), logger)); + .map(pkg -> CacheUtils.computeDigestOfLibrarySources(pkg.listSourcesJava())); } @Override - protected Optional getCacheRoots(EnsoContext context) { + public Optional getCacheRoots(EnsoContext context) { return context .getPackageRepository() .getPackageForLibraryJava(libraryName) @@ -103,7 +128,7 @@ protected Optional getCacheRoots(EnsoContext context) { } @Override - protected byte[] serialize(EnsoContext context, CachedSuggestions entry) throws IOException { + public byte[] serialize(EnsoContext context, 
CachedSuggestions entry) throws IOException { var byteStream = new ByteArrayOutputStream(); try (ObjectOutputStream stream = new ObjectOutputStream(byteStream)) { stream.writeObject(entry.getSuggestionsObjectToSerialize()); @@ -111,6 +136,16 @@ protected byte[] serialize(EnsoContext context, CachedSuggestions entry) throws return byteStream.toByteArray(); } + @Override + public String sourceHash(Metadata meta) { + return meta.sourceHash(); + } + + @Override + public String blobHash(Metadata meta) { + return meta.blobHash(); + } + // Suggestions class is not a record because of a Frgaal bug leading to invalid compilation error. public static final class Suggestions implements Serializable { @@ -160,7 +195,7 @@ public List getSuggestions() { } } - record Metadata(String sourceHash, String blobHash) implements Cache.Metadata { + record Metadata(String sourceHash, String blobHash) { byte[] toBytes() throws IOException { try (var os = new ByteArrayOutputStream(); var dos = new DataOutputStream(os)) { diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java index d05122db51e1..cfde6757d1a6 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java @@ -27,6 +27,7 @@ import org.enso.compiler.context.SimpleUpdate; import org.enso.compiler.core.IR; import org.enso.compiler.core.ir.Expression; +import org.enso.interpreter.caches.Cache; import org.enso.interpreter.caches.ModuleCache; import org.enso.interpreter.node.callable.dispatch.CallOptimiserNode; import org.enso.interpreter.node.callable.dispatch.LoopingCallOptimiserNode; @@ -60,7 +61,7 @@ public final class Module implements EnsoObject { private org.enso.compiler.core.ir.Module ir; private Map uuidsMap; private QualifiedName name; - private final ModuleCache cache; + private final Cache cache; private boolean 
wasLoadedFromCache; private final boolean synthetic; @@ -86,7 +87,7 @@ public Module(QualifiedName name, Package pkg, TruffleFile sourceFi this.sources = ModuleSources.NONE.newWith(sourceFile); this.pkg = pkg; this.name = name; - this.cache = new ModuleCache(this); + this.cache = ModuleCache.create(this); this.wasLoadedFromCache = false; this.synthetic = false; } @@ -103,7 +104,7 @@ public Module(QualifiedName name, Package pkg, String literalSource this.sources = ModuleSources.NONE.newWith(Rope.apply(literalSource)); this.pkg = pkg; this.name = name; - this.cache = new ModuleCache(this); + this.cache = ModuleCache.create(this); this.wasLoadedFromCache = false; this.patchedValues = new PatchedModuleValues(this); this.synthetic = false; @@ -121,7 +122,7 @@ public Module(QualifiedName name, Package pkg, Rope literalSource) this.sources = ModuleSources.NONE.newWith(literalSource); this.pkg = pkg; this.name = name; - this.cache = new ModuleCache(this); + this.cache = ModuleCache.create(this); this.wasLoadedFromCache = false; this.patchedValues = new PatchedModuleValues(this); this.synthetic = false; @@ -142,7 +143,7 @@ private Module( this.scope = new ModuleScope(this); this.pkg = pkg; this.compilationStage = synthetic ? 
CompilationStage.INITIAL : CompilationStage.AFTER_CODEGEN; - this.cache = new ModuleCache(this); + this.cache = ModuleCache.create(this); this.wasLoadedFromCache = false; this.synthetic = synthetic; } @@ -522,7 +523,7 @@ public boolean isInteractive() { /** * @return the cache for this module */ - public ModuleCache getCache() { + public Cache getCache() { return cache; } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/SerializationPool.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/SerializationPool.java new file mode 100644 index 000000000000..d8f607033654 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/SerializationPool.java @@ -0,0 +1,206 @@ +package org.enso.interpreter.runtime; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import org.enso.pkg.QualifiedName; + +/** + * Manages threading aspects of serialization. The goal of {@code SerializationPool} is to + * encapsulate working with threads: + * + *
    + *
  • serialization is done asynchronously in a single background thread + *
  • deserialization is done synchronously and tries to wait for possible background work to + * finish + *
+ * + * It is good to keep in mind, that serialization isn't the primary goal while Enso program is + * running. When a program is running as much of the CPU time should be dedicated to compilation and + * execution. Only when the Enso program execution is over, flushing the pending caches becomes a + * priority. Future rewrites of this class may optimize towards such direction. + */ +final class SerializationPool { + private final TruffleCompilerContext context; + + /** + * A set of the modules that are currently being serialized. + * + *

This set is accessed concurrently. This is safe as it is backed by a [[ConcurrentHashMap]] + * whose keys are the names currently being serialized (each is mapped to {@code true}). + */ + private final Map isSerializing = new ConcurrentHashMap<>(); + + /** + * A map of the modules awaiting serialization to their associated tasks + * + *

This map is accessed concurrently. + */ + private final Map> isWaitingForSerialization = new ConcurrentHashMap<>(); + + /** The thread pool that handles serialization. */ + private final ExecutorService pool; + + /** all associated threads */ + private final Set threads = Collections.synchronizedSet(new HashSet<>()); + + SerializationPool(TruffleCompilerContext context) { + this.context = context; + this.pool = + Executors.newSingleThreadExecutor( + (r) -> { + var t = context.createSystemThread(r); + t.setName("SerializationPool background thread"); + threads.add(t); + return t; + }); + } + + /** + * @return `true` if there are remaining serialization jobs, `false` otherwise + */ + private boolean hasJobsRemaining() { + synchronized (isWaitingForSerialization) { + return !isWaitingForSerialization.isEmpty() || !isSerializing.isEmpty(); + } + } + + /** + * Performs shutdown actions for the serialization manager. + * + * @param waitForPendingJobCompletion whether or not shutdown should wait for pending + * serialization jobs + */ + void shutdown(boolean waitForPendingJobCompletion) throws InterruptedException { + if (!pool.isShutdown()) { + if (waitForPendingJobCompletion && this.hasJobsRemaining()) { + int waitingCount; + int jobCount; + synchronized (isWaitingForSerialization) { + waitingCount = isWaitingForSerialization.size(); + jobCount = waitingCount + isSerializing.size(); + } + context.logSerializationManager( + Level.FINE, "Waiting for #{0} serialization jobs to complete.", jobCount); + + // Bound the waiting loop + int maxCount = 60; + int counter = 0; + while (this.hasJobsRemaining() && counter < maxCount) { + counter += 1; + synchronized (isWaitingForSerialization) { + isWaitingForSerialization.wait(1000); + } + } + } + + pool.shutdown(); + + // Bound the waiting loop + int maxCount = 10; + int counter = 0; + while (!pool.isTerminated() && counter < maxCount) { + pool.awaitTermination(500, TimeUnit.MILLISECONDS); + counter += 1; + } + + 
pool.shutdownNow(); + context.logSerializationManager(Level.FINE, "Serialization manager shutdownNow."); + + for (var t : threads.toArray(new Thread[0])) { + context.logSerializationManager(Level.FINEST, "Serialization manager has been shut down."); + t.join(); + } + context.logSerializationManager(Level.FINE, "Serialization manager has been shut down."); + } + } + + boolean isWaitingForSerialization(QualifiedName key) { + synchronized (isWaitingForSerialization) { + return isWaitingForSerialization.containsKey(key); + } + } + + /** + * Checks if the provided key is waiting for serialization. + * + * @param key the library to check + * @return {@code true} if there is a pending serialization for given key, {@code false} otherwise + */ + boolean abort(QualifiedName key) { + synchronized (isWaitingForSerialization) { + if (isWaitingForSerialization(key)) { + var prev = isWaitingForSerialization.remove(key); + isWaitingForSerialization.notifyAll(); + if (prev != null) { + return prev.cancel(false); + } else { + return false; + } + } else { + return false; + } + } + } + + void startSerializing(QualifiedName name) { + synchronized (isWaitingForSerialization) { + isWaitingForSerialization.remove(name); + isSerializing.put(name, true); + isWaitingForSerialization.notifyAll(); + } + } + + /** + * Sets the {@code key} as finished with serialization. 
+ * + * @param name the key to set as having finished serialization + */ + void finishSerializing(QualifiedName name) { + synchronized (isWaitingForSerialization) { + isSerializing.remove(name); + isWaitingForSerialization.notifyAll(); + } + } + + Future submitTask(Callable task, boolean useThreadPool, QualifiedName key) { + if (useThreadPool) { + synchronized (isWaitingForSerialization) { + var future = pool.submit(task); + isWaitingForSerialization.put(key, future); + return future; + } + } else { + try { + return CompletableFuture.completedFuture(task.call()); + } catch (Throwable e) { + context.logSerializationManager( + Level.WARNING, "Serialization task failed for [" + key + "].", e); + return CompletableFuture.failedFuture(e); + } + } + } + + /** + * Waits for a given key to finish serialization, if there is one pending. + * + * @param name the key + * @throws InterruptedException if the wait is interrupted + */ + void waitWhileSerializing(QualifiedName name) throws InterruptedException { + synchronized (isWaitingForSerialization) { + while (isSerializing.containsKey(name)) { + isWaitingForSerialization.wait(100); + } + } + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/ThreadExecutors.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/ThreadExecutors.java index a5454e6767c4..8153776e9deb 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/ThreadExecutors.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/ThreadExecutors.java @@ -1,5 +1,6 @@ package org.enso.interpreter.runtime; +import java.util.Collections; import java.util.Map; import java.util.WeakHashMap; import java.util.concurrent.ExecutorService; @@ -11,7 +12,9 @@ final class ThreadExecutors { private final EnsoContext context; - private final Map pools = new WeakHashMap<>(); + private final Map pools = + Collections.synchronizedMap(new WeakHashMap<>()); + private final Map threads = Collections.synchronizedMap(new 
WeakHashMap<>()); ThreadExecutors(EnsoContext context) { this.context = context; @@ -30,6 +33,22 @@ final ExecutorService newFixedThreadPool(int cnt, String name, boolean systemThr } public void shutdown() { + synchronized (pools) { + shutdownPools(); + } + synchronized (threads) { + for (var t : threads.keySet()) { + try { + t.join(); + } catch (InterruptedException ex) { + context.getLogger().log(Level.WARNING, "Cannot shutdown {0} thread", t.getName()); + } + } + } + } + + private void shutdownPools() { + assert Thread.holdsLock(pools); var it = pools.entrySet().iterator(); while (it.hasNext()) { var next = it.next(); @@ -61,6 +80,7 @@ private final class Factory implements ThreadFactory { public Thread newThread(Runnable r) { var thread = context.createThread(system, r); thread.setName(prefix + "-" + counter.incrementAndGet()); + threads.put(thread, thread.getName()); return thread; } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/TruffleCompilerContext.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/TruffleCompilerContext.java index 4e79d6c81833..e5d7500b8d24 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/TruffleCompilerContext.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/TruffleCompilerContext.java @@ -1,13 +1,20 @@ package org.enso.interpreter.runtime; +import static org.enso.interpreter.util.ScalaConversions.cons; +import static org.enso.interpreter.util.ScalaConversions.nil; + import com.oracle.truffle.api.TruffleFile; import com.oracle.truffle.api.TruffleLogger; import com.oracle.truffle.api.source.Source; import java.io.IOException; import java.io.PrintStream; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.function.Consumer; import 
java.util.logging.Level; @@ -15,7 +22,10 @@ import org.enso.compiler.PackageRepository; import org.enso.compiler.Passes; import org.enso.compiler.context.CompilerContext; +import org.enso.compiler.context.ExportsBuilder; +import org.enso.compiler.context.ExportsMap; import org.enso.compiler.context.FreshNameSupply; +import org.enso.compiler.context.SuggestionBuilder; import org.enso.compiler.core.ir.Diagnostic; import org.enso.compiler.core.ir.IdentifiedLocation; import org.enso.compiler.data.BindingsMap; @@ -23,15 +33,18 @@ import org.enso.compiler.pass.analyse.BindingAnalysis$; import org.enso.editions.LibraryName; import org.enso.interpreter.caches.Cache; +import org.enso.interpreter.caches.ImportExportCache; +import org.enso.interpreter.caches.ImportExportCache.MapToBindings; import org.enso.interpreter.caches.ModuleCache; +import org.enso.interpreter.caches.SuggestionsCache; import org.enso.interpreter.runtime.type.Types; import org.enso.interpreter.runtime.util.DiagnosticFormatter; import org.enso.pkg.Package; import org.enso.pkg.QualifiedName; import org.enso.polyglot.CompilationStage; import org.enso.polyglot.LanguageInfo; +import org.enso.polyglot.Suggestion; import org.enso.polyglot.data.TypeGraph; -import scala.Option; final class TruffleCompilerContext implements CompilerContext { @@ -39,13 +52,13 @@ final class TruffleCompilerContext implements CompilerContext { private final TruffleLogger loggerCompiler; private final TruffleLogger loggerSerializationManager; private final RuntimeStubsGenerator stubsGenerator; - private final SerializationManager serializationManager; + private final SerializationPool serializationPool; TruffleCompilerContext(EnsoContext context) { this.context = context; this.loggerCompiler = context.getLogger(Compiler.class); - this.loggerSerializationManager = context.getLogger(SerializationManager.class); - this.serializationManager = new SerializationManager(this); + this.loggerSerializationManager = 
context.getLogger(SerializationPool.class); + this.serializationPool = new SerializationPool(this); this.stubsGenerator = new RuntimeStubsGenerator(context.getBuiltins()); } @@ -74,10 +87,6 @@ public PackageRepository getPackageRepository() { return context.getPackageRepository(); } - final SerializationManager getSerializationManager() { - return serializationManager; - } - @Override public PrintStream getErr() { return context.getErr(); @@ -205,10 +214,7 @@ public void initializeBuiltinsIr( builtins.initializeBuiltinsSource(); if (irCachingEnabled) { - if (serializationManager.deserialize(compiler, builtinsModule) instanceof Option op - && op.isDefined() - && op.get() instanceof Boolean b - && b) { + if (deserializeModule(compiler, builtinsModule)) { // Ensure that builtins doesn't try and have codegen run on it. updateModule(builtinsModule, u -> u.compilationStage(CompilationStage.AFTER_CODEGEN)); } else { @@ -219,7 +225,7 @@ public void initializeBuiltinsIr( } if (irCachingEnabled && !wasLoadedFromCache(builtinsModule)) { - serializationManager.serializeModule(compiler, builtinsModule, true, true); + serializeModule(compiler, builtinsModule, true, true); } } } @@ -278,29 +284,362 @@ public String formatDiagnostic( @Override public Future serializeLibrary( Compiler compiler, LibraryName libraryName, boolean useGlobalCacheLocations) { - Object res = - serializationManager.serializeLibrary(compiler, libraryName, useGlobalCacheLocations); - return (Future) res; + logSerializationManager(Level.INFO, "Requesting serialization for library [{0}].", libraryName); + + var task = doSerializeLibrary(compiler, libraryName, useGlobalCacheLocations); + + return serializationPool.submitTask( + task, isCreateThreadAllowed(), toQualifiedName(libraryName)); } + /** + * Requests that `module` be serialized. + * + *

This method will attempt to schedule the provided module and IR for serialization regardless + * of whether or not it is appropriate to do so. If there are preconditions needed for + * serialization, these should be checked before calling this method. + * + *

In addition, this method handles breaking links between modules contained in the IR to + * ensure safe serialization. + * + *

It is responsible for taking a "snapshot" of the relevant module state at the point at which + * serialization is requested. This is due to the fact that serialization happens in a separate + * thread and the module may be mutated beneath it. + * + * @param module the module to serialize + * @param useGlobalCacheLocations if true, will use global caches location, local one otherwise + * @param useThreadPool if true, will perform serialization asynchronously + * @return Future referencing the serialization task. On completion Future will return `true` if + * `module` has been successfully serialized, `false` otherwise + */ @SuppressWarnings("unchecked") @Override public Future serializeModule( - Compiler compiler, CompilerContext.Module module, boolean useGlobalCacheLocations) { - Object res = - serializationManager.serializeModule(compiler, module, useGlobalCacheLocations, true); - return (Future) res; + Compiler compiler, + CompilerContext.Module module, + boolean useGlobalCacheLocations, + boolean useThreadPool) { + if (module.isSynthetic()) { + throw new IllegalStateException( + "Cannot serialize synthetic module [" + module.getName() + "]"); + } + logSerializationManager( + Level.FINE, "Requesting serialization for module [{0}].", module.getName()); + var ir = module.getIr(); + var dupl = + ir.duplicate( + ir.duplicate$default$1(), ir.duplicate$default$2(), ir.duplicate$default$3(), true); + var duplicatedIr = compiler.updateMetadata(ir, dupl); + Source src; + try { + src = module.getSource(); + } catch (IOException ex) { + logSerializationManager(Level.WARNING, "Cannot get source for " + module.getName(), ex); + return CompletableFuture.failedFuture(ex); + } + var task = + doSerializeModule( + ((Module) module).getCache(), + duplicatedIr, + module.getCompilationStage(), + module.getName(), + src, + useGlobalCacheLocations); + return serializationPool.submitTask(task, useThreadPool, module.getName()); + } + + /** + * Create the task that serializes the provided 
module IR when it is run. + * + * @param cache the cache manager for the module being serialized + * @param ir the IR for the module being serialized + * @param stage the compilation stage of the module + * @param name the name of the module being serialized + * @param source the source of the module being serialized + * @param useGlobalCacheLocations if true, will use global caches location, local one otherwise + * @return the task that serialies the provided `ir` + */ + private Callable doSerializeModule( + Cache cache, + org.enso.compiler.core.ir.Module ir, + CompilationStage stage, + QualifiedName name, + Source source, + boolean useGlobalCacheLocations) { + return () -> { + var pool = serializationPool; + pool.waitWhileSerializing(name); + + logSerializationManager(Level.FINE, "Running serialization for module [{0}].", name); + pool.startSerializing(name); + try { + var fixedStage = + stage.isAtLeast(CompilationStage.AFTER_STATIC_PASSES) + ? CompilationStage.AFTER_STATIC_PASSES + : stage; + var optionallySaved = + saveCache( + cache, + new ModuleCache.CachedModule(ir, fixedStage, source), + useGlobalCacheLocations); + return optionallySaved.isPresent(); + } catch (Throwable e) { + logSerializationManager( + Level.SEVERE, + "Serialization of module `" + name + "` failed: " + e.getMessage() + "`", + e); + throw e; + } finally { + pool.finishSerializing(name); + } + }; } + private final Map known = new HashMap<>(); + @Override public boolean deserializeModule(Compiler compiler, CompilerContext.Module module) { - var result = serializationManager.deserialize(compiler, module); - return result.nonEmpty(); + var level = Level.FINE; + if (module.getPackage() != null) { + var library = module.getPackage().libraryName(); + var bindings = known.get(library); + if (bindings == null) { + try { + var cached = deserializeLibraryBindings(library); + if (cached.isDefined()) { + bindings = cached.get().bindings(); + known.put(library, bindings); + } + } catch 
(InterruptedException ex) { + // proceed + } + } + if (bindings != null) { + var ir = bindings.findForModule(module.getName()); + loggerSerializationManager.log( + Level.FINE, + "Deserializing module " + module.getName() + " from library: " + (ir != null)); + if (ir != null) { + compiler + .context() + .updateModule( + module, + (u) -> { + u.ir(ir); + u.compilationStage(CompilationStage.AFTER_STATIC_PASSES); + u.loadedFromCache(true); + }); + return true; + } + } + level = "Standard".equals(library.namespace()) ? Level.WARNING : Level.FINE; + } + try { + var result = deserializeModuleDirect(module); + loggerSerializationManager.log( + result ? level : Level.FINE, + "Deserializing module " + module.getName() + " from IR file: " + result); + return result; + } catch (InterruptedException e) { + loggerSerializationManager.log( + Level.WARNING, "Deserializing module " + module.getName() + " from IR file", e); + return false; + } + } + + /** + * Deserializes the requested module from the cache if possible. + * + *

If the requested module is currently being serialized it will wait for completion before + * loading. If the module is queued for serialization it will evict it and not load from the cache + * (this is usually indicative of a programming bug). + * + * @param module the module to deserialize from the cache. + * @return {@code true} if the deserialization succeeded + */ + private boolean deserializeModuleDirect(CompilerContext.Module module) + throws InterruptedException { + var pool = serializationPool; + if (pool.isWaitingForSerialization(module.getName())) { + pool.abort(module.getName()); + return false; + } else { + pool.waitWhileSerializing(module.getName()); + + var loaded = loadCache(((Module) module).getCache()); + if (loaded.isPresent()) { + updateModule( + module, + (u) -> { + u.ir(loaded.get().moduleIR()); + u.compilationStage(loaded.get().compilationStage()); + u.loadedFromCache(true); + }); + logSerializationManager( + Level.FINE, + "Restored IR from cache for module [{0}] at stage [{1}].", + module.getName(), + loaded.get().compilationStage()); + return true; + } else { + logSerializationManager( + Level.FINE, "Unable to load a cache for module [{0}].", module.getName()); + return false; + } + } + } + + @SuppressWarnings("unchecked") + Callable doSerializeLibrary( + Compiler compiler, LibraryName libraryName, boolean useGlobalCacheLocations) { + return () -> { + var pool = serializationPool; + pool.waitWhileSerializing(toQualifiedName(libraryName)); + + logSerializationManager(Level.FINE, "Running serialization for bindings [{0}].", libraryName); + pool.startSerializing(toQualifiedName(libraryName)); + var map = new HashMap(); + var it = context.getPackageRepository().getModulesForLibrary(libraryName); + while (it.nonEmpty()) { + var module = it.head(); + map.put(module.getName(), module.getIr()); + it = + (scala.collection.immutable.List) + it.tail(); + } + var snd = + context + .getPackageRepository() + .getPackageForLibraryJava(libraryName) + .map(x 
-> x.listSourcesJava()); + + var bindingsCache = + new ImportExportCache.CachedBindings( + libraryName, new ImportExportCache.MapToBindings(map), snd); + try { + boolean result = + doSerializeLibrarySuggestions(compiler, libraryName, useGlobalCacheLocations); + try { + var cache = ImportExportCache.create(libraryName); + var file = saveCache(cache, bindingsCache, useGlobalCacheLocations); + result &= file.isPresent(); + } catch (Throwable e) { + logSerializationManager( + Level.SEVERE, + "Serialization of bindings `" + libraryName + "` failed: " + e.getMessage() + "`", + e); + throw e; + } + return result; + } finally { + pool.finishSerializing(toQualifiedName(libraryName)); + } + }; + } + + private boolean doSerializeLibrarySuggestions( + Compiler compiler, LibraryName libraryName, boolean useGlobalCacheLocations) { + var exportsBuilder = new ExportsBuilder(); + var exportsMap = new ExportsMap(); + var suggestions = new java.util.ArrayList(); + + try { + var libraryModules = context.getPackageRepository().getModulesForLibrary(libraryName); + libraryModules + .flatMap( + module -> { + var sug = + SuggestionBuilder.apply(module, compiler) + .build(module.getName(), module.getIr()) + .toVector() + .filter(Suggestion::isGlobal); + var exports = exportsBuilder.build(module.getName(), module.getIr()); + exportsMap.addAll(module.getName(), exports); + return sug; + }) + .map( + suggestion -> { + var reexport = exportsMap.get(suggestion).map(s -> s.toString()); + return suggestion.withReexport(reexport); + }) + .foreach(suggestions::add); + + var cachedSuggestions = + new SuggestionsCache.CachedSuggestions( + libraryName, + new SuggestionsCache.Suggestions(suggestions), + context + .getPackageRepository() + .getPackageForLibraryJava(libraryName) + .map(p -> p.listSourcesJava())); + var cache = SuggestionsCache.create(libraryName); + var file = saveCache(cache, cachedSuggestions, useGlobalCacheLocations); + return file.isPresent(); + } catch (Throwable e) { + 
logSerializationManager( + Level.SEVERE, + "Serialization of suggestions `" + libraryName + "` failed: " + e.getMessage() + "`", + e); + e.printStackTrace(); + throw e; + } + } + + public scala.Option> deserializeSuggestions( + LibraryName libraryName) throws InterruptedException { + var option = deserializeSuggestionsImpl(libraryName); + return option.map(s -> s.getSuggestions()); + } + + private scala.Option deserializeSuggestionsImpl( + LibraryName libraryName) throws InterruptedException { + var pool = serializationPool; + if (pool.isWaitingForSerialization(toQualifiedName(libraryName))) { + pool.abort(toQualifiedName(libraryName)); + return scala.Option.empty(); + } else { + pool.waitWhileSerializing(toQualifiedName(libraryName)); + var cache = SuggestionsCache.create(libraryName); + var loaded = loadCache(cache); + if (loaded.isPresent()) { + logSerializationManager(Level.FINE, "Restored suggestions for library [{0}].", libraryName); + return scala.Option.apply(loaded.get()); + } else { + logSerializationManager( + Level.FINE, "Unable to load suggestions for library [{0}].", libraryName); + return scala.Option.empty(); + } + } + } + + scala.Option deserializeLibraryBindings(LibraryName libraryName) + throws InterruptedException { + var pool = serializationPool; + if (pool.isWaitingForSerialization(toQualifiedName(libraryName))) { + pool.abort(toQualifiedName(libraryName)); + return scala.Option.empty(); + } else { + pool.waitWhileSerializing(toQualifiedName(libraryName)); + var cache = ImportExportCache.create(libraryName); + var loaded = loadCache(cache); + if (loaded.isPresent()) { + logSerializationManager(Level.FINE, "Restored bindings for library [{0}].", libraryName); + return scala.Option.apply(loaded.get()); + } else { + logSerializationManager( + Level.FINEST, "Unable to load bindings for library [{0}].", libraryName); + return scala.Option.empty(); + } + } } @Override public void shutdown(boolean waitForPendingJobCompletion) { - 
serializationManager.shutdown(waitForPendingJobCompletion); + try { + serializationPool.shutdown(waitForPendingJobCompletion); + } catch (InterruptedException ex) { + logSerializationManager(Level.WARNING, ex.getMessage(), ex); + } } private final class ModuleUpdater implements Updater, AutoCloseable { @@ -435,7 +774,7 @@ public List getDirectModulesRefs() { return module.getDirectModulesRefs(); } - public ModuleCache getCache() { + public Cache getCache() { return module.getCache(); } @@ -492,4 +831,9 @@ public String toString() { } private static void emitIOException() throws IOException {} + + private static QualifiedName toQualifiedName(LibraryName libraryName) { + var namespace = cons(libraryName.namespace(), nil()); + return new QualifiedName(namespace, libraryName.name()); + } } diff --git a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/DefaultPackageRepository.scala b/engine/runtime/src/main/scala/org/enso/interpreter/runtime/DefaultPackageRepository.scala index 932e24f9ecd3..6254e79a0545 100644 --- a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/DefaultPackageRepository.scala +++ b/engine/runtime/src/main/scala/org/enso/interpreter/runtime/DefaultPackageRepository.scala @@ -2,7 +2,7 @@ package org.enso.interpreter.runtime import org.enso.compiler.PackageRepository import org.enso.compiler.context.CompilerContext -import org.enso.compiler.data.BindingsMap +import org.enso.compiler.core.ir.{Module => IRModule} import com.oracle.truffle.api.TruffleFile import com.typesafe.scalalogging.Logger import org.apache.commons.lang3.StringUtils @@ -574,19 +574,18 @@ private class DefaultPackageRepository( libraryName: LibraryName, moduleName: QualifiedName, context: CompilerContext - ): Option[BindingsMap] = { + ): Option[IRModule] = { val cache = ensurePackageIsLoaded(libraryName).toOption.flatMap { _ => if (!loadedLibraryBindings.contains(libraryName)) { loadedPackages.get(libraryName).flatten.foreach(loadDependencies(_)) val 
cachedBindingOption = context .asInstanceOf[TruffleCompilerContext] - .getSerializationManager() .deserializeLibraryBindings(libraryName) loadedLibraryBindings.addOne((libraryName, cachedBindingOption)) } loadedLibraryBindings.get(libraryName) } - cache.flatMap(_.flatMap(_.bindings.findForModule(moduleName))) + cache.flatMap(_.map(_.bindings.findForModule(moduleName))) } private def loadDependencies(pkg: Package[TruffleFile]): Unit = { diff --git a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/SerializationManager.scala b/engine/runtime/src/main/scala/org/enso/interpreter/runtime/SerializationManager.scala deleted file mode 100644 index 4567eabf76d8..000000000000 --- a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/SerializationManager.scala +++ /dev/null @@ -1,694 +0,0 @@ -package org.enso.interpreter.runtime - -import com.oracle.truffle.api.source.Source -import org.enso.compiler.Compiler -import org.enso.compiler.core.Implicits.AsMetadata -import org.enso.compiler.core.ir.{Module => IRModule} -import org.enso.compiler.context.{ExportsBuilder, ExportsMap, SuggestionBuilder} -import org.enso.compiler.context.CompilerContext -import org.enso.compiler.context.CompilerContext.Module -import org.enso.compiler.pass.analyse.BindingAnalysis -import org.enso.editions.LibraryName -import org.enso.pkg.QualifiedName -import org.enso.polyglot.Suggestion -import org.enso.polyglot.CompilationStage -import org.enso.interpreter.caches.ImportExportCache -import org.enso.interpreter.caches.ModuleCache -import org.enso.interpreter.caches.SuggestionsCache - -import java.io.NotSerializableException -import java.util -import java.util.concurrent.{ - Callable, - CompletableFuture, - ConcurrentHashMap, - Future, - LinkedBlockingDeque, - ThreadPoolExecutor, - TimeUnit -} -import java.util.logging.Level - -import scala.collection.mutable -import scala.jdk.OptionConverters.RichOptional - -final class SerializationManager(private val context: TruffleCompilerContext) 
{ - - def this(compiler: Compiler) = { - this(compiler.context.asInstanceOf[TruffleCompilerContext]) - } - - import SerializationManager._ - - /** The debug logging level. */ - private val debugLogLevel = Level.FINE - - /** A set of the modules that are currently being serialized. - * - * This set is accessed concurrently. This is safe as it is backed by a - * [[ConcurrentHashMap]] and is wrapped with the scala [[mutable.Set]] - * interface. - */ - private val isSerializing: mutable.Set[QualifiedName] = buildConcurrentHashSet - - /** A map of the modules awaiting serialization to their associated tasks - * - * This map is accessed concurrently. - */ - private val isWaitingForSerialization = - collection.concurrent.TrieMap[QualifiedName, Future[Boolean]]() - - /** The thread pool that handles serialization. */ - private val pool: ThreadPoolExecutor = new ThreadPoolExecutor( - SerializationManager.startingThreadCount, - SerializationManager.maximumThreadCount, - SerializationManager.threadKeepalive, - TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - (runnable: Runnable) => { - context.createSystemThread(runnable) - } - ) - - // Make sure it is started to avoid races with language shutdown with low job - // count. - if (context.isCreateThreadAllowed) { - pool.prestartAllCoreThreads() - } - - // === Interface ============================================================ - - /** Requests that `module` be serialized. - * - * This method will attempt to schedule the provided module and IR for - * serialization regardless of whether or not it is appropriate to do so. If - * there are preconditions needed for serialization, these should be checked - * before calling this method. - * - * In addition, this method handles breaking links between modules contained - * in the IR to ensure safe serialization. - * - * It is responsible for taking a "snapshot" of the relevant module state at - * the point at which serialization is requested. 
This is due to the fact - * that serialization happens in a separate thread and the module may be - * mutated beneath it. - * - * @param module the module to serialize - * @param useGlobalCacheLocations if true, will use global caches location, local one otherwise - * @param useThreadPool if true, will perform serialization asynchronously - * @return Future referencing the serialization task. On completion Future will return - * `true` if `module` has been successfully serialized, `false` otherwise - */ - def serializeModule( - compiler: Compiler, - module: Module, - useGlobalCacheLocations: Boolean, - useThreadPool: Boolean = true - ): Future[Boolean] = { - if (module.isSynthetic) { - throw new IllegalStateException( - "Cannot serialize synthetic module [" + module.getName + "]" - ); - } - context.logSerializationManager( - debugLogLevel, - "Requesting serialization for module [{0}].", - module.getName - ) - val duplicatedIr = compiler.updateMetadata( - module.getIr, - module.getIr.duplicate(keepIdentifiers = true) - ) - val task = doSerializeModule( - getCache(module), - duplicatedIr, - module.getCompilationStage, - module.getName, - module.getSource, - useGlobalCacheLocations - ) - if (useThreadPool) { - isWaitingForSerialization.synchronized { - val future = pool.submit(task) - isWaitingForSerialization.put(module.getName, future) - future - } - } else { - try { - CompletableFuture.completedFuture(task.call()) - } catch { - case e: Throwable => - context.logSerializationManager( - debugLogLevel, - s"Serialization task failed in module [${module.getName}].", - e - ) - CompletableFuture.completedFuture(false) - } - } - } - - def serializeLibrary( - compiler: Compiler, - libraryName: LibraryName, - useGlobalCacheLocations: Boolean - ): Future[Boolean] = { - context.logSerializationManager( - Level.INFO, - "Requesting serialization for library [{0}].", - libraryName - ) - - val task: Callable[Boolean] = - doSerializeLibrary(compiler, libraryName, 
useGlobalCacheLocations) - if (context.isCreateThreadAllowed) { - isWaitingForSerialization.synchronized { - val future = pool.submit(task) - isWaitingForSerialization.put(libraryName.toQualifiedName, future) - future - } - } else { - try { - CompletableFuture.completedFuture(task.call()) - } catch { - case e: Throwable => - context.logSerializationManager( - debugLogLevel, - s"Serialization task failed for library [$libraryName].", - e - ) - CompletableFuture.completedFuture(false) - } - } - } - - private def doSerializeLibrary( - compiler: Compiler, - libraryName: LibraryName, - useGlobalCacheLocations: Boolean - ): Callable[Boolean] = () => { - while (isSerializingLibrary(libraryName)) { - Thread.sleep(100) - } - - context.logSerializationManager( - debugLogLevel, - "Running serialization for bindings [{0}].", - libraryName - ) - startSerializing(libraryName.toQualifiedName) - val bindingsCache = new ImportExportCache.CachedBindings( - libraryName, - new ImportExportCache.MapToBindings( - context - .getPackageRepository() - .getModulesForLibrary(libraryName) - .map { module => - val ir = module.getIr - val bindings = ir.unsafeGetMetadata( - BindingAnalysis, - "Non-parsed module used in ImportResolver" - ) - val abstractBindings = - bindings.prepareForSerialization(compiler.context) - ( - module.getName, - org.enso.persist.Persistance.Reference.of(abstractBindings) - ) - } - .toMap - ), - context - .getPackageRepository() - .getPackageForLibraryJava(libraryName) - .map(_.listSourcesJava()) - ) - try { - val result = - try { - val cache = new ImportExportCache(libraryName) - val file = context.saveCache( - cache, - bindingsCache, - useGlobalCacheLocations - ) - file.isPresent - } catch { - case e: NotSerializableException => - context.logSerializationManager( - Level.SEVERE, - s"Could not serialize bindings [$libraryName].", - e - ) - throw e - case e: Throwable => - context.logSerializationManager( - Level.SEVERE, - s"Serialization of bindings `$libraryName` 
failed: ${e.getMessage}`", - e - ) - throw e - } - - doSerializeLibrarySuggestions( - compiler, - libraryName, - useGlobalCacheLocations - ) - - result - } finally { - finishSerializing(libraryName.toQualifiedName) - } - } - - private def doSerializeLibrarySuggestions( - compiler: Compiler, - libraryName: LibraryName, - useGlobalCacheLocations: Boolean - ): Boolean = { - val exportsBuilder = new ExportsBuilder - val exportsMap = new ExportsMap - val suggestions = new util.ArrayList[Suggestion]() - - try { - val libraryModules = - context.getPackageRepository().getModulesForLibrary(libraryName) - libraryModules - .flatMap { module => - val suggestions = SuggestionBuilder(module, compiler) - .build(module.getName, module.getIr) - .toVector - .filter(Suggestion.isGlobal) - val exports = exportsBuilder.build(module.getName, module.getIr) - exportsMap.addAll(module.getName, exports) - suggestions - } - .map { suggestion => - val reexport = exportsMap.get(suggestion).map(_.toString) - suggestion.withReexport(reexport) - } - .foreach(suggestions.add) - val cachedSuggestions = - new SuggestionsCache.CachedSuggestions( - libraryName, - new SuggestionsCache.Suggestions(suggestions), - context - .getPackageRepository() - .getPackageForLibraryJava(libraryName) - .map(_.listSourcesJava()) - ) - val cache = new SuggestionsCache(libraryName) - val file = context.saveCache( - cache, - cachedSuggestions, - useGlobalCacheLocations - ) - file.isPresent - } catch { - case e: NotSerializableException => - context.logSerializationManager( - Level.SEVERE, - s"Could not serialize suggestions [$libraryName].", - e - ) - throw e - case e: Throwable => - context.logSerializationManager( - Level.SEVERE, - s"Serialization of suggestions `$libraryName` failed: ${e.getMessage}`", - e - ) - throw e - } - } - - def deserializeSuggestions( - libraryName: LibraryName - ): Option[SuggestionsCache.CachedSuggestions] = { - if (isWaitingForSerialization(libraryName)) { - abort(libraryName) - None - } 
else { - while (isSerializingLibrary(libraryName)) { - Thread.sleep(100) - } - val cache = new SuggestionsCache(libraryName) - context.loadCache(cache).toScala match { - case result @ Some(_: SuggestionsCache.CachedSuggestions) => - context.logSerializationManager( - Level.FINE, - "Restored suggestions for library [{0}].", - libraryName - ) - result - case None => - context.logSerializationManager( - Level.FINE, - "Unable to load suggestions for library [{0}].", - libraryName - ) - None - } - } - } - - def deserializeLibraryBindings( - libraryName: LibraryName - ): Option[ImportExportCache.CachedBindings] = { - if (isWaitingForSerialization(libraryName)) { - abort(libraryName) - None - } else { - while (isSerializingLibrary(libraryName)) { - Thread.sleep(100) - } - val cache = new ImportExportCache(libraryName) - context.loadCache(cache).toScala match { - case result @ Some(_: ImportExportCache.CachedBindings) => - context.logSerializationManager( - Level.FINE, - "Restored bindings for library [{0}].", - libraryName - ) - result - case _ => - context.logSerializationManager( - Level.FINEST, - "Unable to load bindings for library [{0}].", - libraryName - ) - None - } - - } - } - - /** Deserializes the requested module from the cache if possible. - * - * If the requested module is currently being serialized it will wait for - * completion before loading. If the module is queued for serialization it - * will evict it and not load from the cache (this is usually indicative of a - * programming bug). - * - * @param module the module to deserialize from the cache. - * @return [[Some]] when deserialization was successful, with `true` for - * relinking being successful and `false` otherwise. [[None]] if the - * cache could not be deserialized. 
- */ - def deserialize( - compiler: Compiler, - module: Module - ): Option[Boolean] = { - compiler.getClass() - if (isWaitingForSerialization(module)) { - abort(module) - None - } else { - while (isSerializingModule(module.getName)) { - Thread.sleep(100) - } - - context.loadCache(getCache(module)).toScala match { - case Some(loadedCache) => - context.updateModule( - module, - { u => - u.ir(loadedCache.moduleIR) - u.compilationStage(loadedCache.compilationStage) - u.loadedFromCache(true) - } - ) - context.logSerializationManager( - debugLogLevel, - "Restored IR from cache for module [{0}] at stage [{1}].", - module.getName, - loadedCache.compilationStage - ) - Some(true) - case None => - context.logSerializationManager( - debugLogLevel, - "Unable to load a cache for module [{0}].", - module.getName - ) - None - } - } - } - - /** Checks if the provided module is in the process of being serialized. - * - * @param module the module to check - * @return `true` if `module` is currently being serialized, `false` - * otherwise - */ - private def isSerializingModule(module: QualifiedName): Boolean = { - isSerializing.contains(module) - } - - private def isSerializingLibrary(library: LibraryName): Boolean = { - isSerializing.contains(library.toQualifiedName) - } - - private def isWaitingForSerialization(name: QualifiedName): Boolean = { - isWaitingForSerialization.synchronized { - isWaitingForSerialization.contains(name) - } - } - - /** Checks if the provided module is waiting for serialization. - * - * @param module the module to check - * @return `true` if `module` is waiting for serialization, `false` otherwise - */ - private def isWaitingForSerialization( - module: Module - ): Boolean = { - isWaitingForSerialization(module.getName) - } - - /** Checks if the provided library's bindings are waiting for serialization. 
- * - * @param library the library to check - * @return `true` if `library` is waiting for serialization, `false` otherwise - */ - private def isWaitingForSerialization(library: LibraryName): Boolean = { - isWaitingForSerialization(library.toQualifiedName) - } - - private def abort(name: QualifiedName): Boolean = { - isWaitingForSerialization.synchronized { - if (isWaitingForSerialization(name)) { - isWaitingForSerialization - .remove(name) - .map(_.cancel(false)) - .getOrElse(false) - } else false - } - } - - /** Requests that serialization of `module` be aborted. - * - * If the module is already in the process of serialization it will not be - * aborted. - * - * @param module the module for which to abort serialization - * @return `true` if serialization for `module` was aborted, `false` - * otherwise - */ - private def abort(module: Module): Boolean = { - abort(module.getName) - } - - /** Requests that serialization of library's bindings be aborted. - * - * If the library is already in the process of serialization it will not be - * aborted. - * - * @param library the library for which to abort serialization - * @return `true` if serialization for `library` was aborted, `false` - * otherwise - */ - private def abort(library: LibraryName): Boolean = { - abort(library.toQualifiedName) - } - - /** Performs shutdown actions for the serialization manager. 
- * - * @param waitForPendingJobCompletion whether or not shutdown should wait for - * pending serialization jobs - */ - def shutdown(waitForPendingJobCompletion: Boolean = false): Unit = { - if (!pool.isShutdown) { - if (waitForPendingJobCompletion && this.hasJobsRemaining) { - val waitingCount = isWaitingForSerialization.synchronized { - isWaitingForSerialization.size - } - val jobCount = waitingCount + isSerializing.size - context.logSerializationManager( - debugLogLevel, - "Waiting for #{0} serialization jobs to complete.", - jobCount - ) - - // Bound the waiting loop - val maxCount = 60 - var counter = 0 - while (this.hasJobsRemaining && counter < maxCount) { - counter += 1 - Thread.sleep(1 * 1000) - } - } - - pool.shutdown() - - // Bound the waiting loop - val maxCount = 10 - var counter = 0 - while (!pool.isTerminated && counter < maxCount) { - pool.awaitTermination(500, TimeUnit.MILLISECONDS) - counter += 1 - } - - pool.shutdownNow() - Thread.sleep(100) - context.logSerializationManager( - debugLogLevel, - "Serialization manager has been shut down." - ) - } - } - - // === Internals ============================================================ - - /** @return `true` if there are remaining serialization jobs, `false` - * otherwise - */ - private def hasJobsRemaining: Boolean = { - isWaitingForSerialization.synchronized { - isWaitingForSerialization.nonEmpty || isSerializing.nonEmpty - } - } - - /** Create the task that serializes the provided module IR when it is run. 
- * - * @param cache the cache manager for the module being serialized - * @param ir the IR for the module being serialized - * @param stage the compilation stage of the module - * @param name the name of the module being serialized - * @param source the source of the module being serialized - * @param useGlobalCacheLocations if true, will use global caches location, local one otherwise - * @return the task that serialies the provided `ir` - */ - private def doSerializeModule( - cache: ModuleCache, - ir: IRModule, - stage: CompilationStage, - name: QualifiedName, - source: Source, - useGlobalCacheLocations: Boolean - ): Callable[Boolean] = { () => - while (isSerializingModule(name)) { - Thread.sleep(100) - } - - context.logSerializationManager( - debugLogLevel, - "Running serialization for module [{0}].", - name - ) - startSerializing(name) - try { - val fixedStage = - if (stage.isAtLeast(CompilationStage.AFTER_STATIC_PASSES)) { - CompilationStage.AFTER_STATIC_PASSES - } else stage - context - .saveCache( - cache, - new ModuleCache.CachedModule(ir, fixedStage, source), - useGlobalCacheLocations - ) - .map(_ => true) - .orElse(false) - } catch { - case e: NotSerializableException => - context.logSerializationManager( - Level.SEVERE, - s"Could not serialize module [$name].", - e - ) - throw e - case e: Throwable => - context.logSerializationManager( - Level.SEVERE, - s"Serialization of module `$name` failed: ${e.getMessage}`", - e - ) - throw e - } finally { - finishSerializing(name) - } - } - - /** Sets the module described by `name` as serializing. - * - * @param name the name of the module to set as serializing - */ - private def startSerializing(name: QualifiedName): Unit = { - isWaitingForSerialization.synchronized { - isWaitingForSerialization.remove(name) - } - isSerializing.add(name) - } - - /** Sets the module described by `name` as finished with serialization. 
- * - * @param name the name of the module to set as having finished serialization - */ - private def finishSerializing(name: QualifiedName): Unit = { - isSerializing.remove(name) - } - - /** Builds a [[mutable.Set]] that is backed by a [[ConcurrentHashMap]] and is - * hence safe for concurrent access. - * - * @tparam T the type of the set elements - * @return a concurrent [[mutable.Set]] - */ - private def buildConcurrentHashSet[T]: mutable.Set[T] = { - import scala.jdk.CollectionConverters._ - java.util.Collections - .newSetFromMap( - new ConcurrentHashMap[T, java.lang.Boolean]() - ) - .asScala - } - - private def getCache(module: Module): ModuleCache = { - module.asInstanceOf[TruffleCompilerContext.Module].getCache - } -} - -object SerializationManager { - - /** The maximum number of serialization threads allowed. */ - val maximumThreadCount: Integer = 2 - - /** The number of threads at compiler start. */ - val startingThreadCount: Integer = maximumThreadCount - - /** The thread keep-alive time in seconds. 
*/ - val threadKeepalive: Long = 3 - - implicit private class LibraryOps(val libraryName: LibraryName) - extends AnyVal { - def toQualifiedName: QualifiedName = - QualifiedName(List(libraryName.namespace), libraryName.name) - } - - def apply(context: CompilerContext): SerializationManager = { - context.asInstanceOf[TruffleCompilerContext].getSerializationManager() - } -} diff --git a/engine/runtime/src/test/java/org/enso/compiler/SerdeCompilerTest.java b/engine/runtime/src/test/java/org/enso/compiler/SerdeCompilerTest.java index 862a9c42920d..fe513ec4b4f1 100644 --- a/engine/runtime/src/test/java/org/enso/compiler/SerdeCompilerTest.java +++ b/engine/runtime/src/test/java/org/enso/compiler/SerdeCompilerTest.java @@ -19,7 +19,6 @@ import java.util.logging.SimpleFormatter; import org.enso.compiler.core.ir.Module; import org.enso.interpreter.runtime.EnsoContext; -import org.enso.interpreter.runtime.SerializationManager$; import org.enso.pkg.PackageManager; import org.enso.polyglot.LanguageInfo; import org.enso.polyglot.MethodNames; @@ -60,15 +59,13 @@ private void parseSerializedModule(String projectName, String forbiddenMessage) .filter((m) -> !m.getPackage().libraryName().namespace().equals("Standard")); assertEquals("Two non-standard library modules are compiled", nonStandard.size(), 2); assertEquals(result.compiledModules().exists(m -> m == module), true); - var serializationManager = - SerializationManager$.MODULE$.apply(ensoContext.getCompiler().context()); var futures = new ArrayList>(); result .compiledModules() .filter((m) -> !m.isSynthetic()) .foreach( (m) -> { - var future = serializationManager.serializeModule(compiler, m, true, true); + var future = compiler.context().serializeModule(compiler, m, true, true); futures.add(future); return null; }); diff --git a/engine/runtime/src/test/java/org/enso/compiler/SerializationManagerTest.java b/engine/runtime/src/test/java/org/enso/compiler/SerializationManagerTest.java index 3e714de3a4bc..6639c46ed8bb 100644 --- 
a/engine/runtime/src/test/java/org/enso/compiler/SerializationManagerTest.java +++ b/engine/runtime/src/test/java/org/enso/compiler/SerializationManagerTest.java @@ -10,9 +10,7 @@ import java.util.stream.Stream; import org.apache.commons.io.FileUtils; import org.enso.editions.LibraryName; -import org.enso.interpreter.caches.SuggestionsCache; import org.enso.interpreter.runtime.EnsoContext; -import org.enso.interpreter.runtime.SerializationManager; import org.enso.interpreter.runtime.util.TruffleFileSystem; import org.enso.interpreter.test.InterpreterContext; import org.enso.pkg.Package; @@ -20,7 +18,10 @@ import org.enso.polyglot.LanguageInfo; import org.enso.polyglot.MethodNames; import org.enso.polyglot.Suggestion; -import org.junit.*; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; public class SerializationManagerTest { @@ -78,7 +79,6 @@ private void clearLibraryCache(LibraryName libraryName) { @Test public void serializeLibrarySuggestions() throws ExecutionException, InterruptedException, TimeoutException { - SerializationManager serializationManager = new SerializationManager(ensoContext.getCompiler()); LibraryName standardBaseLibrary = new LibraryName("Standard", "Base"); Package standardBasePackage = getLibraryPackage(standardBaseLibrary); ensoContext @@ -92,13 +92,12 @@ public void serializeLibrarySuggestions() .get(COMPILE_TIMEOUT_SECONDS, TimeUnit.SECONDS); Assert.assertEquals(Boolean.TRUE, result); - SuggestionsCache.CachedSuggestions cachedSuggestions = - serializationManager.deserializeSuggestions(standardBaseLibrary).get(); - Assert.assertEquals(standardBaseLibrary, cachedSuggestions.getLibraryName()); + var cachedSuggestions = + ensoContext.getCompiler().context().deserializeSuggestions(standardBaseLibrary).get(); Supplier> cachedConstructorSuggestions = () -> - cachedSuggestions.getSuggestions().stream() + cachedSuggestions.stream() .flatMap( suggestion -> { if (suggestion instanceof 
Suggestion.Constructor constructor) { diff --git a/engine/runtime/src/test/java/org/enso/compiler/SerializerTest.java b/engine/runtime/src/test/java/org/enso/compiler/SerializerTest.java index 1aa6cdd34ec9..a156ec863f41 100644 --- a/engine/runtime/src/test/java/org/enso/compiler/SerializerTest.java +++ b/engine/runtime/src/test/java/org/enso/compiler/SerializerTest.java @@ -2,6 +2,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import java.io.File; import java.io.IOException; @@ -9,7 +10,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import org.enso.interpreter.runtime.EnsoContext; -import org.enso.interpreter.runtime.SerializationManager; import org.enso.pkg.PackageManager; import org.enso.polyglot.LanguageInfo; import org.enso.polyglot.MethodNames; @@ -57,14 +57,13 @@ public void testSerializationOfFQNs() throws Exception { ctx.enter(); var result = compiler.run(module); assertEquals(result.compiledModules().exists(m -> m == module), true); - var serializationManager = new SerializationManager(ensoContext.getCompiler()); var useThreadPool = compiler.context().isCreateThreadAllowed(); - var future = serializationManager.serializeModule(compiler, module, true, useThreadPool); + var future = compiler.context().serializeModule(compiler, module, true, useThreadPool); var serialized = future.get(5, TimeUnit.SECONDS); assertEquals(serialized, true); - var deserialized = serializationManager.deserialize(compiler, module); - assertEquals(deserialized.isDefined() && (Boolean) deserialized.get(), true); - serializationManager.shutdown(true); + var deserialized = compiler.context().deserializeModule(compiler, module); + assertTrue("Deserialized", deserialized); + compiler.context().shutdown(true); ctx.leave(); ctx.close(); } diff --git a/engine/runtime/src/test/java/org/enso/interpreter/caches/ModuleCacheTest.java 
b/engine/runtime/src/test/java/org/enso/interpreter/caches/ModuleCacheTest.java index 4bd7a1185fe9..eb4fb9187d94 100644 --- a/engine/runtime/src/test/java/org/enso/interpreter/caches/ModuleCacheTest.java +++ b/engine/runtime/src/test/java/org/enso/interpreter/caches/ModuleCacheTest.java @@ -4,6 +4,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import java.nio.ByteBuffer; import org.enso.compiler.CompilerTest; import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.test.TestBase; @@ -49,10 +50,12 @@ public void testCompareList() throws Exception { var module = option.get(); var ir = module.getIr().duplicate(true, true, true, true); var cm = new ModuleCache.CachedModule(ir, CompilationStage.AFTER_CODEGEN, module.getSource()); - byte[] arr = module.getCache().serialize(ensoCtx, cm); + + var mc = module.getCache().asSpi(ModuleCache.class); + byte[] arr = mc.serialize(ensoCtx, cm); var meta = new ModuleCache.Metadata("hash", "code", CompilationStage.AFTER_CODEGEN.toString()); - var cachedIr = module.getCache().deserialize(ensoCtx, arr, meta, null); + var cachedIr = mc.deserialize(ensoCtx, ByteBuffer.wrap(arr), meta, null); assertNotNull("IR read", cachedIr); CompilerTest.assertIR(name, ir, cachedIr.moduleIR()); } @@ -78,10 +81,11 @@ public void testCompareWithWarning() throws Exception { var module = option.get(); var ir = module.getIr().duplicate(true, true, true, true); var cm = new ModuleCache.CachedModule(ir, CompilationStage.AFTER_CODEGEN, module.getSource()); - byte[] arr = module.getCache().serialize(ensoCtx, cm); + var mc = module.getCache().asSpi(ModuleCache.class); + byte[] arr = mc.serialize(ensoCtx, cm); var meta = new ModuleCache.Metadata("hash", "code", CompilationStage.AFTER_CODEGEN.toString()); - var cachedIr = module.getCache().deserialize(ensoCtx, arr, meta, null); + var cachedIr = mc.deserialize(ensoCtx, ByteBuffer.wrap(arr), meta, null); assertNotNull("IR read", cachedIr); 
CompilerTest.assertIR(name, ir, cachedIr.moduleIR()); } diff --git a/lib/java/persistance/src/main/java/org/enso/persist/PerInputImpl.java b/lib/java/persistance/src/main/java/org/enso/persist/PerInputImpl.java index cf48f24ef482..37d16d2fb611 100644 --- a/lib/java/persistance/src/main/java/org/enso/persist/PerInputImpl.java +++ b/lib/java/persistance/src/main/java/org/enso/persist/PerInputImpl.java @@ -23,14 +23,13 @@ final class PerInputImpl implements Input { this.at = at; } - static Reference readObject(byte[] arr, Function readResolve) + static Reference readObject(ByteBuffer buf, Function readResolve) throws IOException { for (var i = 0; i < PerGenerator.HEADER.length; i++) { - if (arr[i] != PerGenerator.HEADER[i]) { + if (buf.get(i) != PerGenerator.HEADER[i]) { throw new IOException("Wrong header"); } } - var buf = ByteBuffer.wrap(arr); var version = buf.getInt(4); var cache = new InputCache(buf, readResolve); if (version != cache.map().versionStamp) { diff --git a/lib/java/persistance/src/main/java/org/enso/persist/Persistance.java b/lib/java/persistance/src/main/java/org/enso/persist/Persistance.java index e736ffd50c9a..f9b6db9856f4 100644 --- a/lib/java/persistance/src/main/java/org/enso/persist/Persistance.java +++ b/lib/java/persistance/src/main/java/org/enso/persist/Persistance.java @@ -5,6 +5,7 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.function.Function; /** @@ -169,7 +170,22 @@ final T readWith(Input in) { */ public static Reference read(byte[] arr, Function readResolve) throws IOException { - return PerInputImpl.readObject(arr, readResolve); + return read(ByteBuffer.wrap(arr), readResolve); + } + + /** + * Read object written down by {@link #write} from a byte buffer. 
+ * + * @param <T> expected type of object + * @param buf the stored bytes + * @param readResolve either {@code null} or function to call for each object being read to + * provide a replacement + * @return the read object + * @throws java.io.IOException when an I/O problem happens + */ + public static Reference read(ByteBuffer buf, Function readResolve) + throws IOException { + return PerInputImpl.readObject(buf, readResolve); } /**