diff --git a/metrix-integration/pom.xml b/metrix-integration/pom.xml index 61f24e19..c8f379bc 100644 --- a/metrix-integration/pom.xml +++ b/metrix-integration/pom.xml @@ -76,6 +76,12 @@ powsybl-commons-test test + + com.powsybl + powsybl-computation + test-jar + test + com.powsybl powsybl-config-test diff --git a/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy b/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy index 05a17ab4..6975b8ee 100755 --- a/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy +++ b/metrix-integration/src/main/groovy/com/powsybl/metrix/integration/MetrixDslDataLoader.groovy @@ -15,7 +15,9 @@ import com.powsybl.metrix.mapping.TimeSeriesMappingConfigLoader import com.powsybl.timeseries.ReadOnlyTimeSeriesStore import com.powsybl.timeseries.TimeSeriesFilter import com.powsybl.timeseries.dsl.CalculatedTimeSeriesGroovyDslLoader +import groovy.transform.ThreadInterrupt import org.codehaus.groovy.control.CompilerConfiguration +import org.codehaus.groovy.control.customizers.ASTTransformationCustomizer import org.codehaus.groovy.control.customizers.ImportCustomizer import java.nio.charset.StandardCharsets @@ -104,11 +106,17 @@ class MetrixDslDataLoader { imports.addStaticStars("com.powsybl.metrix.integration.MetrixGeneratorsBinding.ReferenceVariable") def config = CalculatedTimeSeriesGroovyDslLoader.createCompilerConfig() config.addCompilationCustomizers(imports) + + // Add a check on thread interruption in every loop (for, while) in the script + config.addCompilationCustomizers(new ASTTransformationCustomizer(ThreadInterrupt.class)) } static void evaluate(GroovyCodeSource dslSrc, Binding binding) { - def config = createCompilerConfig() - def shell = new GroovyShell(binding, config) + def shell = new GroovyShell(binding, createCompilerConfig()) + + // Check for thread interruption right before beginning the evaluation + if (Thread.currentThread().isInterrupted()) throw new InterruptedException("Execution Interrupted") + shell.evaluate(dslSrc) } diff --git a/metrix-integration/src/test/java/com/powsybl/metrix/integration/MetrixDslDataLoaderInterruptionTest.java b/metrix-integration/src/test/java/com/powsybl/metrix/integration/MetrixDslDataLoaderInterruptionTest.java new file mode 100644 index 00000000..8d495c25 --- /dev/null +++ b/metrix-integration/src/test/java/com/powsybl/metrix/integration/MetrixDslDataLoaderInterruptionTest.java @@ -0,0 +1,125 @@ +package com.powsybl.metrix.integration; + +import com.google.common.jimfs.Configuration; +import com.google.common.jimfs.Jimfs; +import com.powsybl.computation.AbstractTaskInterruptionTest; +import com.powsybl.iidm.network.Network; +import com.powsybl.iidm.serde.NetworkSerDe; +import com.powsybl.metrix.mapping.DataTableStore; +import com.powsybl.metrix.mapping.MappingParameters; +import com.powsybl.metrix.mapping.TimeSeriesDslLoader; +import com.powsybl.metrix.mapping.TimeSeriesMappingConfig; +import com.powsybl.timeseries.ReadOnlyTimeSeriesStore; +import com.powsybl.timeseries.ReadOnlyTimeSeriesStoreCache; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.io.IOException; +import java.io.Writer; +import java.nio.charset.StandardCharsets; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Objects; + +/** + * @author Nicolas Rol {@literal } + */ +class MetrixDslDataLoaderInterruptionTest extends AbstractTaskInterruptionTest { + + private FileSystem fileSystem; + private Path dslFile; + private Path mappingFile; + private Network network; + private final MetrixParameters parameters = new MetrixParameters(); + private final MappingParameters mappingParameters = MappingParameters.load(); + + @BeforeEach + public void setUp() throws Exception { + fileSystem = Jimfs.newFileSystem(Configuration.unix()); + dslFile = fileSystem.getPath("/test.dsl"); + mappingFile = fileSystem.getPath("/mapping.dsl"); + network = NetworkSerDe.read(Objects.requireNonNull(getClass().getResourceAsStream("/simpleNetwork.xml"))); + + // Create mapping file for use in all tests + try (Writer writer = Files.newBufferedWriter(mappingFile, StandardCharsets.UTF_8)) { + writer.write(String.join(System.lineSeparator(), + "timeSeries['tsN'] = 1000", + "timeSeries['tsN_1'] = 2000", + "timeSeries['tsITAM'] = 3000", + "timeSeries['ts1'] = 100", + "timeSeries['ts2'] = 200", + "timeSeries['ts3'] = 300", + "timeSeries['ts4'] = 400", + "timeSeries['ts5'] = 500" + )); + } + } + + @AfterEach + public void tearDown() throws Exception { + fileSystem.close(); + } + + @ParameterizedTest + @Timeout(10) + @ValueSource(booleans = {false, true}) + void testCancelMetrixDslDataLoaderShort(boolean isDelayed) throws Exception { + + try (Writer writer = Files.newBufferedWriter(dslFile, StandardCharsets.UTF_8)) { + writer.write(String.join(System.lineSeparator(), + "load('FVALDI11_L') {", + " preventiveSheddingPercentage 20", + "}", + "load('FVALDI11_L2') {", + " preventiveSheddingPercentage 30", + " preventiveSheddingCost 12000", + "}", + "load('FVERGE11_L') {", + " preventiveSheddingPercentage 0", + " preventiveSheddingCost 10000", + "}", + "load('FSSV.O11_L') {", + " curativeSheddingPercentage 40", + "}")); + } + ReadOnlyTimeSeriesStore store = new ReadOnlyTimeSeriesStoreCache(); + TimeSeriesMappingConfig tsConfig = new TimeSeriesDslLoader(mappingFile).load(network, mappingParameters, store, new DataTableStore(), null); + testCancelShortTask(isDelayed, () -> MetrixDslDataLoader.load(dslFile, network, parameters, store, tsConfig)); + } + + @ParameterizedTest + @Timeout(10) + @ValueSource(booleans = {false, true}) + void testCancelMetrixDslDataLoaderLong(boolean isDelayed) throws IOException, InterruptedException { + + try (Writer writer = Files.newBufferedWriter(dslFile, StandardCharsets.UTF_8)) { + writer.write(""" + for (int i = 0; i < 10; i++) { + sleep(500) + print(i) + } + """ + String.join(System.lineSeparator(), + "load('FVALDI11_L') {", + " preventiveSheddingPercentage 20", + "}", + "load('FVALDI11_L2') {", + " preventiveSheddingPercentage 30", + " preventiveSheddingCost 12000", + "}", + "load('FVERGE11_L') {", + " preventiveSheddingPercentage 0", + " preventiveSheddingCost 10000", + "}", + "load('FSSV.O11_L') {", + " curativeSheddingPercentage 40", + "}")); + } + ReadOnlyTimeSeriesStore store = new ReadOnlyTimeSeriesStoreCache(); + TimeSeriesMappingConfig tsConfig = new TimeSeriesDslLoader(mappingFile).load(network, mappingParameters, store, new DataTableStore(), null); + testCancelLongTask(isDelayed, () -> MetrixDslDataLoader.load(dslFile, network, parameters, store, tsConfig)); + } +} diff --git a/metrix-mapping/pom.xml b/metrix-mapping/pom.xml index ca0f7ee9..53e698ac 100644 --- a/metrix-mapping/pom.xml +++ b/metrix-mapping/pom.xml @@ -74,6 +74,12 @@ powsybl-commons-test test + + com.powsybl + powsybl-computation + test-jar + test + com.powsybl powsybl-iidm-impl diff --git a/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy b/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy index 60dac5b6..036ef47b 100644 --- a/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy +++ b/metrix-mapping/src/main/groovy/com/powsybl/metrix/mapping/TimeSeriesDslLoader.groovy @@ -15,8 +15,10 @@ import com.powsybl.timeseries.ReadOnlyTimeSeriesStore import com.powsybl.timeseries.TimeSeriesFilter import com.powsybl.timeseries.ast.NodeCalc import com.powsybl.timeseries.dsl.CalculatedTimeSeriesGroovyDslLoader +import groovy.transform.ThreadInterrupt import org.apache.commons.lang3.StringUtils import org.codehaus.groovy.control.CompilerConfiguration +import org.codehaus.groovy.control.customizers.ASTTransformationCustomizer import org.codehaus.groovy.control.customizers.ImportCustomizer import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -93,6 +95,9 @@ class TimeSeriesDslLoader { getStaticStars().forEach(staticStars -> imports.addStaticStars(staticStars)) def config = CalculatedTimeSeriesGroovyDslLoader.createCompilerConfig() config.addCompilationCustomizers(imports) + + // Add a check on thread interruption in every loop (for, while) in the script + config.addCompilationCustomizers(new ASTTransformationCustomizer(ThreadInterrupt.class)) } static void bind(Binding binding, Network network, ReadOnlyTimeSeriesStore store, DataTableStore dataTableStore, MappingParameters parameters, TimeSeriesMappingConfig config, TimeSeriesMappingConfigLoader loader, LogDslLoader logDslLoader, ComputationRange computationRange) { @@ -282,6 +287,10 @@ class TimeSeriesDslLoader { } def shell = new GroovyShell(binding, createCompilerConfig()) + + // Check for thread interruption right before beginning the evaluation + if (Thread.currentThread().isInterrupted()) throw new InterruptedException("Execution Interrupted") + shell.evaluate(dslSrc) TimeSeriesMappingConfigChecker configChecker = new TimeSeriesMappingConfigChecker(config) diff --git a/metrix-mapping/src/test/java/com/powsybl/metrix/mapping/TimeSeriesDslLoaderInterruptionTest.java b/metrix-mapping/src/test/java/com/powsybl/metrix/mapping/TimeSeriesDslLoaderInterruptionTest.java new file mode 100644 index 00000000..00cf10a1 --- /dev/null +++ b/metrix-mapping/src/test/java/com/powsybl/metrix/mapping/TimeSeriesDslLoaderInterruptionTest.java @@ -0,0 +1,59 @@ +package com.powsybl.metrix.mapping; + +import com.powsybl.computation.AbstractTaskInterruptionTest; +import com.powsybl.iidm.network.Network; +import com.powsybl.timeseries.ReadOnlyTimeSeriesStore; +import com.powsybl.timeseries.ReadOnlyTimeSeriesStoreCache; +import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * @author Nicolas Rol {@literal } + */ +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class TimeSeriesDslLoaderInterruptionTest extends AbstractTaskInterruptionTest { + + private final MappingParameters parameters = MappingParameters.load(); + private final Network network = MappingTestNetwork.create(); + private final ReadOnlyTimeSeriesStore store = new ReadOnlyTimeSeriesStoreCache(); + + @ParameterizedTest + @Timeout(10) + @Order(1) + @ValueSource(booleans = {false, true}) + void testCancelTaskJava(boolean isDelayed) throws Exception { + testCancelLongTask(isDelayed, () -> { + try { + Thread.sleep(5000L); + return 0; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + } + + @ParameterizedTest + @Timeout(10) + @Order(2) + @ValueSource(booleans = {false, true}) + void testCancelTaskGroovyLong(boolean isDelayed) throws Exception { + String script = """ + for (int i = 0; i < 10; i++) { + sleep(500) + } + """; + TimeSeriesDslLoader dsl = new TimeSeriesDslLoader(script); + testCancelLongTask(isDelayed, () -> dsl.load(network, parameters, store, new DataTableStore(), null)); + } + + @ParameterizedTest + @Timeout(10) + @Order(3) + @ValueSource(booleans = {false, true}) + void testCancelTaskGroovyShort(boolean isDelayed) throws Exception { + String script = "writeLog(\"LOG_TYPE\", \"LOG_SECTION\", \"LOG_MESSAGE\")"; + TimeSeriesDslLoader dsl = new TimeSeriesDslLoader(script); + testCancelShortTask(isDelayed, () -> dsl.load(network, parameters, store, new DataTableStore(), null)); + } +} diff --git a/pom.xml b/pom.xml index d70fae2b..21dd30c0 100644 --- a/pom.xml +++ b/pom.xml @@ -248,6 +248,13 @@ ${powsyblcore.version} test + + com.powsybl + powsybl-computation + test-jar + test + ${powsyblcore.version} +