diff --git a/metrix-integration/pom.xml b/metrix-integration/pom.xml
index 61f24e19..c8f379bc 100644
--- a/metrix-integration/pom.xml
+++ b/metrix-integration/pom.xml
@@ -76,6 +76,12 @@
powsybl-commons-test
test
+
+ com.powsybl
+ powsybl-computation
+ test-jar
+ test
+
com.powsybl
powsybl-config-test
diff --git a/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy b/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy
index 05a17ab4..6975b8ee 100755
--- a/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy
+++ b/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy
@@ -15,7 +15,9 @@ import com.powsybl.metrix.mapping.TimeSeriesMappingConfigLoader
import com.powsybl.timeseries.ReadOnlyTimeSeriesStore
import com.powsybl.timeseries.TimeSeriesFilter
import com.powsybl.timeseries.dsl.CalculatedTimeSeriesGroovyDslLoader
+import groovy.transform.ThreadInterrupt
import org.codehaus.groovy.control.CompilerConfiguration
+import org.codehaus.groovy.control.customizers.ASTTransformationCustomizer
import org.codehaus.groovy.control.customizers.ImportCustomizer
import java.nio.charset.StandardCharsets
@@ -104,11 +106,17 @@ class MetrixDslDataLoader {
imports.addStaticStars("com.powsybl.metrix.integration.MetrixGeneratorsBinding.ReferenceVariable")
def config = CalculatedTimeSeriesGroovyDslLoader.createCompilerConfig()
config.addCompilationCustomizers(imports)
+
+ // Add a check on thread interruption in every loop (for, while) in the script
+ config.addCompilationCustomizers(new ASTTransformationCustomizer(ThreadInterrupt.class))
}
static void evaluate(GroovyCodeSource dslSrc, Binding binding) {
- def config = createCompilerConfig()
- def shell = new GroovyShell(binding, config)
+ def shell = new GroovyShell(binding, createCompilerConfig())
+
+ // Check for thread interruption right before beginning the evaluation
+ if (Thread.currentThread().isInterrupted()) throw new InterruptedException("Execution Interrupted")
+
shell.evaluate(dslSrc)
}
diff --git a/metrix-integration/src/test/java/com/powsybl/metrix/integration/MetrixDslDataLoaderInterruptionTest.java b/metrix-integration/src/test/java/com/powsybl/metrix/integration/MetrixDslDataLoaderInterruptionTest.java
new file mode 100644
index 00000000..8d495c25
--- /dev/null
+++ b/metrix-integration/src/test/java/com/powsybl/metrix/integration/MetrixDslDataLoaderInterruptionTest.java
@@ -0,0 +1,125 @@
+package com.powsybl.metrix.integration;
+
+import com.google.common.jimfs.Configuration;
+import com.google.common.jimfs.Jimfs;
+import com.powsybl.computation.AbstractTaskInterruptionTest;
+import com.powsybl.iidm.network.Network;
+import com.powsybl.iidm.serde.NetworkSerDe;
+import com.powsybl.metrix.mapping.DataTableStore;
+import com.powsybl.metrix.mapping.MappingParameters;
+import com.powsybl.metrix.mapping.TimeSeriesDslLoader;
+import com.powsybl.metrix.mapping.TimeSeriesMappingConfig;
+import com.powsybl.timeseries.ReadOnlyTimeSeriesStore;
+import com.powsybl.timeseries.ReadOnlyTimeSeriesStoreCache;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import java.io.IOException;
+import java.io.Writer;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.FileSystem;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Objects;
+
+/**
+ * @author Nicolas Rol {@literal }
+ */
+class MetrixDslDataLoaderInterruptionTest extends AbstractTaskInterruptionTest {
+
+ private FileSystem fileSystem;
+ private Path dslFile;
+ private Path mappingFile;
+ private Network network;
+ private final MetrixParameters parameters = new MetrixParameters();
+ private final MappingParameters mappingParameters = MappingParameters.load();
+
+ @BeforeEach
+ public void setUp() throws Exception {
+ fileSystem = Jimfs.newFileSystem(Configuration.unix());
+ dslFile = fileSystem.getPath("/test.dsl");
+ mappingFile = fileSystem.getPath("/mapping.dsl");
+ network = NetworkSerDe.read(Objects.requireNonNull(getClass().getResourceAsStream("/simpleNetwork.xml")));
+
+ // Create mapping file for use in all tests
+ try (Writer writer = Files.newBufferedWriter(mappingFile, StandardCharsets.UTF_8)) {
+ writer.write(String.join(System.lineSeparator(),
+ "timeSeries['tsN'] = 1000",
+ "timeSeries['tsN_1'] = 2000",
+ "timeSeries['tsITAM'] = 3000",
+ "timeSeries['ts1'] = 100",
+ "timeSeries['ts2'] = 200",
+ "timeSeries['ts3'] = 300",
+ "timeSeries['ts4'] = 400",
+ "timeSeries['ts5'] = 500"
+ ));
+ }
+ }
+
+ @AfterEach
+ public void tearDown() throws Exception {
+ fileSystem.close();
+ }
+
+ @ParameterizedTest
+ @Timeout(10)
+ @ValueSource(booleans = {false, true})
+ void testCancelMetrixDslDataLoaderShort(boolean isDelayed) throws Exception {
+
+ try (Writer writer = Files.newBufferedWriter(dslFile, StandardCharsets.UTF_8)) {
+ writer.write(String.join(System.lineSeparator(),
+ "load('FVALDI11_L') {",
+ " preventiveSheddingPercentage 20",
+ "}",
+ "load('FVALDI11_L2') {",
+ " preventiveSheddingPercentage 30",
+ " preventiveSheddingCost 12000",
+ "}",
+ "load('FVERGE11_L') {",
+ " preventiveSheddingPercentage 0",
+ " preventiveSheddingCost 10000",
+ "}",
+ "load('FSSV.O11_L') {",
+ " curativeSheddingPercentage 40",
+ "}"));
+ }
+ ReadOnlyTimeSeriesStore store = new ReadOnlyTimeSeriesStoreCache();
+ TimeSeriesMappingConfig tsConfig = new TimeSeriesDslLoader(mappingFile).load(network, mappingParameters, store, new DataTableStore(), null);
+ testCancelShortTask(isDelayed, () -> MetrixDslDataLoader.load(dslFile, network, parameters, store, tsConfig));
+ }
+
+ @ParameterizedTest
+ @Timeout(10)
+ @ValueSource(booleans = {false, true})
+ void testCancelMetrixDslDataLoaderLong(boolean isDelayed) throws IOException, InterruptedException {
+
+ try (Writer writer = Files.newBufferedWriter(dslFile, StandardCharsets.UTF_8)) {
+ writer.write("""
+ for (int i = 0; i < 10; i++) {
+ sleep(500)
+ print(i)
+ }
+ """ + String.join(System.lineSeparator(),
+ "load('FVALDI11_L') {",
+ " preventiveSheddingPercentage 20",
+ "}",
+ "load('FVALDI11_L2') {",
+ " preventiveSheddingPercentage 30",
+ " preventiveSheddingCost 12000",
+ "}",
+ "load('FVERGE11_L') {",
+ " preventiveSheddingPercentage 0",
+ " preventiveSheddingCost 10000",
+ "}",
+ "load('FSSV.O11_L') {",
+ " curativeSheddingPercentage 40",
+ "}"));
+ }
+ ReadOnlyTimeSeriesStore store = new ReadOnlyTimeSeriesStoreCache();
+ TimeSeriesMappingConfig tsConfig = new TimeSeriesDslLoader(mappingFile).load(network, mappingParameters, store, new DataTableStore(), null);
+ testCancelLongTask(isDelayed, () -> MetrixDslDataLoader.load(dslFile, network, parameters, store, tsConfig));
+ }
+}
diff --git a/metrix-mapping/pom.xml b/metrix-mapping/pom.xml
index ca0f7ee9..53e698ac 100644
--- a/metrix-mapping/pom.xml
+++ b/metrix-mapping/pom.xml
@@ -74,6 +74,12 @@
powsybl-commons-test
test
+
+ com.powsybl
+ powsybl-computation
+ test-jar
+ test
+
com.powsybl
powsybl-iidm-impl
diff --git a/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy b/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy
index 60dac5b6..036ef47b 100644
--- a/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy
+++ b/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy
@@ -15,8 +15,10 @@ import com.powsybl.timeseries.ReadOnlyTimeSeriesStore
import com.powsybl.timeseries.TimeSeriesFilter
import com.powsybl.timeseries.ast.NodeCalc
import com.powsybl.timeseries.dsl.CalculatedTimeSeriesGroovyDslLoader
+import groovy.transform.ThreadInterrupt
import org.apache.commons.lang3.StringUtils
import org.codehaus.groovy.control.CompilerConfiguration
+import org.codehaus.groovy.control.customizers.ASTTransformationCustomizer
import org.codehaus.groovy.control.customizers.ImportCustomizer
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@@ -93,6 +95,9 @@ class TimeSeriesDslLoader {
getStaticStars().forEach(staticStars -> imports.addStaticStars(staticStars))
def config = CalculatedTimeSeriesGroovyDslLoader.createCompilerConfig()
config.addCompilationCustomizers(imports)
+
+ // Add a check on thread interruption in every loop (for, while) in the script
+ config.addCompilationCustomizers(new ASTTransformationCustomizer(ThreadInterrupt.class))
}
static void bind(Binding binding, Network network, ReadOnlyTimeSeriesStore store, DataTableStore dataTableStore, MappingParameters parameters, TimeSeriesMappingConfig config, TimeSeriesMappingConfigLoader loader, LogDslLoader logDslLoader, ComputationRange computationRange) {
@@ -282,6 +287,10 @@ class TimeSeriesDslLoader {
}
def shell = new GroovyShell(binding, createCompilerConfig())
+
+ // Check for thread interruption right before beginning the evaluation
+ if (Thread.currentThread().isInterrupted()) throw new InterruptedException("Execution Interrupted")
+
shell.evaluate(dslSrc)
TimeSeriesMappingConfigChecker configChecker = new TimeSeriesMappingConfigChecker(config)
diff --git a/metrix-mapping/src/test/java/com/powsybl/metrix/mapping/TimeSeriesDslLoaderInterruptionTest.java b/metrix-mapping/src/test/java/com/powsybl/metrix/mapping/TimeSeriesDslLoaderInterruptionTest.java
new file mode 100644
index 00000000..00cf10a1
--- /dev/null
+++ b/metrix-mapping/src/test/java/com/powsybl/metrix/mapping/TimeSeriesDslLoaderInterruptionTest.java
@@ -0,0 +1,59 @@
+package com.powsybl.metrix.mapping;
+
+import com.powsybl.computation.AbstractTaskInterruptionTest;
+import com.powsybl.iidm.network.Network;
+import com.powsybl.timeseries.ReadOnlyTimeSeriesStore;
+import com.powsybl.timeseries.ReadOnlyTimeSeriesStoreCache;
+import org.junit.jupiter.api.*;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+/**
+ * @author Nicolas Rol {@literal }
+ */
+@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
+class TimeSeriesDslLoaderInterruptionTest extends AbstractTaskInterruptionTest {
+
+ private final MappingParameters parameters = MappingParameters.load();
+ private final Network network = MappingTestNetwork.create();
+ private final ReadOnlyTimeSeriesStore store = new ReadOnlyTimeSeriesStoreCache();
+
+ @ParameterizedTest
+ @Timeout(10)
+ @Order(1)
+ @ValueSource(booleans = {false, true})
+ void testCancelTaskJava(boolean isDelayed) throws Exception {
+ testCancelLongTask(isDelayed, () -> {
+ try {
+ Thread.sleep(5000L);
+ return 0;
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ @ParameterizedTest
+ @Timeout(10)
+ @Order(2)
+ @ValueSource(booleans = {false, true})
+ void testCancelTaskGroovyLong(boolean isDelayed) throws Exception {
+ String script = """
+ for (int i = 0; i < 10; i++) {
+ sleep(500)
+ }
+ """;
+ TimeSeriesDslLoader dsl = new TimeSeriesDslLoader(script);
+ testCancelLongTask(isDelayed, () -> dsl.load(network, parameters, store, new DataTableStore(), null));
+ }
+
+ @ParameterizedTest
+ @Timeout(10)
+ @Order(3)
+ @ValueSource(booleans = {false, true})
+ void testCancelTaskGroovyShort(boolean isDelayed) throws Exception {
+ String script = "writeLog(\"LOG_TYPE\", \"LOG_SECTION\", \"LOG_MESSAGE\")";
+ TimeSeriesDslLoader dsl = new TimeSeriesDslLoader(script);
+ testCancelShortTask(isDelayed, () -> dsl.load(network, parameters, store, new DataTableStore(), null));
+ }
+}
diff --git a/pom.xml b/pom.xml
index d70fae2b..21dd30c0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -248,6 +248,13 @@
${powsyblcore.version}
test
+
+ com.powsybl
+ powsybl-computation
+ test-jar
+ test
+ ${powsyblcore.version}
+