From 11205b650eedd6d23244c6369d904264e30992c1 Mon Sep 17 00:00:00 2001 From: Bernard Knueven Date: Tue, 29 Jun 2021 10:54:42 -0600 Subject: [PATCH] adding config_file interface & tests for it --- prescient/simulator/__init__.py | 1 + prescient/simulator/prescient.py | 12 +++++++++++- tests/simulator_tests/test_sim_rts_mod.py | 22 ++++++++++++++++++---- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/prescient/simulator/__init__.py b/prescient/simulator/__init__.py index d70c67a7..121f6fd2 100644 --- a/prescient/simulator/__init__.py +++ b/prescient/simulator/__init__.py @@ -14,3 +14,4 @@ from .stats_manager import StatsManager from .time_manager import TimeManager from .options import Options +from .prescient import Prescient diff --git a/prescient/simulator/prescient.py b/prescient/simulator/prescient.py index a364acb7..e5d0ae6f 100644 --- a/prescient/simulator/prescient.py +++ b/prescient/simulator/prescient.py @@ -24,6 +24,7 @@ from .oracle_manager import OracleManager from .stats_manager import StatsManager from .reporting_manager import ReportingManager +from prescient.scripts import runner from prescient.stats.overall_stats import OverallStats from prescient.engine.egret import EgretEngine as Engine @@ -55,7 +56,16 @@ def simulate(self, **options): prescient.plugins.internal.clear_plugins() prescient.simulator.config.clear_prescient_config() - if 'plugin' in options: + if 'config_file' in options: + config_file = options.pop('config_file') + if options: + raise RuntimeError(f"If using a config_file, all options must be specified in the configuration file") + script, config_options = runner.parse_commands(config_file) + if script != 'simulator.py': + raise RuntimeError(f"config_file must be a simulator configuration text file, got {script}") + options = parse_args(args=config_options) + + elif 'plugin' in options: # parse using the Config plugin_options = self.CONFIG({ 'plugin':options['plugin'] }) for plugin in plugin_options.plugin: diff --git a/tests/simulator_tests/test_sim_rts_mod.py b/tests/simulator_tests/test_sim_rts_mod.py index 805b04b8..ee64b58f 100644 --- a/tests/simulator_tests/test_sim_rts_mod.py +++ b/tests/simulator_tests/test_sim_rts_mod.py @@ -18,7 +18,7 @@ from prescient.scripts import runner from tests.simulator_tests import simulator_diff -from prescient.simulator.prescient import Prescient +from prescient.simulator import Prescient this_file_path = os.path.dirname(os.path.realpath(__file__)) @@ -126,10 +126,10 @@ def _assert_column_equality(self, filename, column_name): diff = df_a[column_name].equals(df_b[column_name]) assert diff, f"Column: '{column_name}' of File: '{filename}.csv' diverges." - -class TestSimulatorModRtsGmlcCopperSheet_csv(_SimulatorModRTSGMLC, unittest.TestCase): +# test runner.py with plugin +class TestSimulatorModRtsGmlcCopperSheet(_SimulatorModRTSGMLC, unittest.TestCase): def _set_names(self): - self.simulator_config_filename = 'simulate_deterministic_csv.txt' + self.simulator_config_filename = 'simulate_deterministic.txt' self.results_dir_name = 'deterministic_simulation_csv_output' self.baseline_dir_name = 'deterministic_simulation_output_baseline' @@ -152,6 +152,19 @@ def _set_names(self): 'no_startup_shutdown_curves':True, } +# test csv / text file configuration +class TestSimulatorModRtsGmlcCopperSheet_csv_python_config_file(_SimulatorModRTSGMLC, unittest.TestCase): + def _set_names(self): + self.simulator_config_filename = 'simulate_deterministic_csv.txt' + self.results_dir_name = 'deterministic_simulation_csv_output' + self.baseline_dir_name = 'deterministic_simulation_output_baseline' + + def _run_simulator(self): + os.chdir(self.test_cases_path) + options = {'config_file' : self.simulator_config_filename} + Prescient().simulate(**options) + +# test plugin with Python and *.dat files class TestSimulatorModRtsGmlcCopperSheet_python(_SimulatorModRTSGMLC, unittest.TestCase): def _set_names(self): @@ -167,6 +180,7 @@ def _run_simulator(self): options['print_callback_message'] = True Prescient().simulate(**options) +# test options are correctly re-freshed, Python, and network class TestSimulatorModRtsGmlcNetwork_python(_SimulatorModRTSGMLC, unittest.TestCase): def _set_names(self):