Skip to content

Commit

Permalink
notifyPendingCacheItems in ProgramExecutionSupport before executionSe…
Browse files Browse the repository at this point in the history
…rvice.execute
  • Loading branch information
JaroslavTulach committed Sep 26, 2022
1 parent ee488ae commit 2c7652d
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ object ContextRegistryProtocol {
/** An information about computed expression. */
case object Value extends Payload

case class Pending(message: Option[String], progress: Option[Double]) extends Payload;
case class Pending(message: Option[String], progress: Option[Double])
extends Payload;

/** Indicates that the expression was computed to an error.
*
Expand Down Expand Up @@ -217,7 +218,7 @@ object ContextRegistryProtocol {

case m: Payload.Pending =>
Encoder[Payload.Pending]
.apply(m)
.apply(m)
.deepMerge(
Json.obj(CodecField.Type -> PayloadType.Pending.asJson)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,9 @@ object CacheInvalidation {
*/
def runAll(
stack: Iterable[InstrumentFrame],
instructions: Iterable[CacheInvalidation],
invalidatedKeys: java.util.Set[UUID]
instructions: Iterable[CacheInvalidation]
): Unit =
instructions.foreach(run(stack, _, invalidatedKeys))
instructions.foreach(run(stack, _))

/** Run a sequence of invalidation instructions on all visualisations.
*
Expand Down Expand Up @@ -158,14 +157,13 @@ object CacheInvalidation {
*/
def run(
stack: Iterable[InstrumentFrame],
instruction: CacheInvalidation,
invalidatedKeys: java.util.Set[UUID]
instruction: CacheInvalidation
): Unit = {
val frames = instruction.elements match {
case StackSelector.All => stack
case StackSelector.Top => stack.headOption.toSeq
}
run(frames, instruction.command, instruction.indexes, invalidatedKeys)
run(frames, instruction.command, instruction.indexes)
}

/** Run cache invalidation of a multiple instrument frames.
Expand All @@ -177,12 +175,9 @@ object CacheInvalidation {
private def run(
frames: Iterable[InstrumentFrame],
command: Command,
indexes: Set[IndexSelector],
invalidatedKeys: java.util.Set[UUID]
indexes: Set[IndexSelector]
): Unit = {
frames.foreach(frame =>
run(frame.cache, frame.syncState, command, indexes, invalidatedKeys)
)
frames.foreach(frame => run(frame.cache, frame.syncState, command, indexes))
}

/** Run cache invalidation of a single instrument frame.
Expand Down Expand Up @@ -226,25 +221,21 @@ object CacheInvalidation {
cache: RuntimeCache,
syncState: UpdatesSynchronizationState,
command: Command,
indexes: Set[IndexSelector],
invalidatedKeys: java.util.Set[UUID]
indexes: Set[IndexSelector]
): Unit =
command match {
case Command.InvalidateAll =>
invalidatedKeys.addAll(cache.getKeys)
cache.clear()
indexes.foreach(clearIndex(_, cache))
case Command.InvalidateKeys(keys) =>
keys.foreach { key =>
cache.remove(key)
invalidatedKeys.add(key)
indexes.foreach(clearIndexKey(key, _, cache))
}
case Command.InvalidateStale(scope) =>
val staleKeys = cache.getKeys.asScala.diff(scope.toSet)
staleKeys.foreach { key =>
cache.remove(key)
invalidatedKeys.add(key)
indexes.foreach(clearIndexKey(key, _, cache))
syncState.invalidate(key)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ class RecomputeContextCmd(
.map(CacheInvalidation(CacheInvalidation.StackSelector.Top, _))
CacheInvalidation.runAll(
stack,
cacheInvalidationCommands,
new java.util.HashSet[java.util.UUID]()
cacheInvalidationCommands
)
reply(Api.RecomputeContextResponse(request.contextId))
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import java.util.logging.Level

import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._
import java.util.UUID

/** A job that ensures that specified files are compiled.
*
Expand All @@ -46,7 +45,7 @@ final class EnsureCompiledJob(protected val files: Iterable[File])

try {
val compilationResult = ensureCompiledFiles(files)
setCacheWeights(new java.util.HashSet[java.util.UUID]())
setCacheWeights()
compilationResult
} finally {
ctx.locking.releaseWriteCompilationLock()
Expand Down Expand Up @@ -306,54 +305,17 @@ final class EnsureCompiledJob(protected val files: Iterable[File])
changeset,
module.getSource.getCharacters
)
val invalidatedKeys = new java.util.HashSet[java.util.UUID]();
ctx.contextManager.getAllContexts.values
.foreach { stack =>
if (stack.nonEmpty && isStackInModule(module.getName, stack)) {
CacheInvalidation.runAll(stack, invalidationCommands, invalidatedKeys)
CacheInvalidation.runAll(stack, invalidationCommands)
}
}
CacheInvalidation.runAllVisualisations(
ctx.contextManager.getVisualisations(module.getName),
invalidationCommands
)

if (!invalidatedKeys.isEmpty()) {
System.err.println("Invalidated: " + invalidatedKeys)
val invalidatedKeysScala = invalidatedKeys.asScala.toSet
ctx.contextManager.getAllContexts.foreachEntry((contextId, stack) => {
val knownKeys = stack.top.cache.getWeights.entrySet
val cachedKeys = stack.top.cache.getKeys
val pendingKeys = new java.util.HashSet[UUID]()

knownKeys.forEach(e => {
if (e.getValue > 0) {
if (!cachedKeys.contains(e.getKey)) {
pendingKeys.add(e.getKey)
System.out.println(" found key with " + e.getValue + " key: " + e.getKey)
}
}
});
val ids = invalidatedKeysScala.map { key =>
// pendingKeys.asScala.toSet.map { key =>
Api.ExpressionUpdate(
key,
None,
None,
Vector.empty,
true,
Api.ExpressionUpdate.Payload.Pending(None, None)
)
}

System.err.println(" ignore pendingKeys: " + pendingKeys)
val msg = Api.Response(
Api.ExpressionUpdates(contextId, ids)
)
ctx.endpoint.sendToClient(msg)
})
}

val invalidatedVisualisations =
ctx.contextManager.getInvalidatedVisualisations(
module.getName,
Expand Down Expand Up @@ -409,18 +371,15 @@ final class EnsureCompiledJob(protected val files: Iterable[File])
else
CompilationStatus.Success

private def setCacheWeights(
invalidatedKeys: java.util.Set[java.util.UUID]
)(implicit ctx: RuntimeContext): Unit = {
private def setCacheWeights()(implicit ctx: RuntimeContext): Unit = {
ctx.contextManager.getAllContexts.values.foreach { stack =>
getCacheMetadata(stack).foreach { metadata =>
CacheInvalidation.run(
stack,
CacheInvalidation(
CacheInvalidation.StackSelector.Top,
CacheInvalidation.Command.SetMetadata(metadata)
),
invalidatedKeys
)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import org.enso.polyglot.LanguageInfo
import org.enso.polyglot.runtime.Runtime.Api
import org.enso.polyglot.runtime.Runtime.Api.ContextId

import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._

/** Provides support for executing Enso code. Adds convenient methods to
Expand Down Expand Up @@ -95,12 +96,45 @@ object ProgramExecutionSupport {
enterables += fun.getExpressionId -> fun.getCall
}

def notifyPendingCacheItems(cache: RuntimeCache): Unit = {
val knownKeys = cache.getWeights.entrySet
val cachedKeys = cache.getKeys
val pendingKeys = new java.util.HashSet[UUID]()

knownKeys.forEach(e => {
if (e.getValue > 0) {
if (!cachedKeys.contains(e.getKey)) {
pendingKeys.add(e.getKey)
System.out.println(
" found key with " + e.getValue + " key: " + e.getKey
)
}
}
});
val ids = pendingKeys.asScala.toSet.map { key =>
Api.ExpressionUpdate(
key,
None,
None,
Vector.empty,
true,
Api.ExpressionUpdate.Payload.Pending(None, None)
)
}

val msg = Api.Response(
Api.ExpressionUpdates(contextId, ids)
)
ctx.endpoint.sendToClient(msg)
}

executionFrame match {
case ExecutionFrame(
ExecutionItem.Method(module, cons, function),
cache,
syncState
) =>
notifyPendingCacheItems(cache)
ctx.executionService.execute(
module.toString,
cons.item,
Expand All @@ -125,6 +159,7 @@ object ProgramExecutionSupport {
.orElseThrow(() =>
new ModuleNotFoundForExpressionIdException(expressionId)
)
notifyPendingCacheItems(cache)
ctx.executionService.execute(
module,
callData,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ class InstrumentTestContext {
n: Int,
timeoutSeconds: Long = 60
): List[Api.Response] = {
receiveNWithFilter(n, {
receiveNWithFilter(
n,
{
case Some(Api.Response(None, Api.ExpressionUpdates(_, _))) => false
case _ => true
}, timeoutSeconds)
case _ => true
},
timeoutSeconds
)
}

def receiveNIgnoreStdLib(
Expand All @@ -45,7 +49,6 @@ class InstrumentTestContext {
receiveNWithFilter(n, (_ => true), timeoutSeconds)
}


private def receiveNWithFilter(
n: Int,
f: (Any => Boolean),
Expand Down

0 comments on commit 2c7652d

Please sign in to comment.