Skip to content

Commit

Permalink
Merge pull request #3377 from kratman/feat/useTempDirectories
Browse files Browse the repository at this point in the history
tests: Use temporary directories in unit tests
  • Loading branch information
Saransh-cpp authored Oct 8, 2023
2 parents d02a78b + 708f721 commit 64d712e
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 189 deletions.
23 changes: 11 additions & 12 deletions tests/unit/test_batch_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import pybamm
import unittest
from tempfile import TemporaryDirectory

spm = pybamm.lithium_ion.SPM()
spm_uniform = pybamm.lithium_ion.SPM({"particle": "uniform profile"})
Expand Down Expand Up @@ -90,21 +91,19 @@ def test_solve(self):
self.assertIn(output_experiment, experiments_list)

def test_create_gif(self):
bs = pybamm.BatchStudy({"spm": pybamm.lithium_ion.SPM()})
bs.solve([0, 10])
with TemporaryDirectory() as dir_name:
bs = pybamm.BatchStudy({"spm": pybamm.lithium_ion.SPM()})
bs.solve([0, 10])

# Create a temporary file name
test_stub = "batch_study_test"
test_file = f"{test_stub}.gif"
# Create a temporary file name
test_file = os.path.join(dir_name, "batch_study_test.gif")

# create a GIF before calling the plot method
bs.create_gif(number_of_images=3, duration=1, output_filename=test_file)
# create a GIF before calling the plot method
bs.create_gif(number_of_images=3, duration=1, output_filename=test_file)

# create a GIF after calling the plot method
bs.plot(testing=True)
bs.create_gif(number_of_images=3, duration=1, output_filename=test_file)

os.remove(test_file)
# create a GIF after calling the plot method
bs.plot(testing=True)
bs.create_gif(number_of_images=3, duration=1, output_filename=test_file)


if __name__ == "__main__":
Expand Down
18 changes: 11 additions & 7 deletions tests/unit/test_expression_tree/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tests import TestCase
import os
import unittest
from tempfile import TemporaryDirectory

import numpy as np
from scipy.sparse import csr_matrix, coo_matrix
Expand Down Expand Up @@ -386,13 +387,16 @@ def test_symbol_repr(self):
)

def test_symbol_visualise(self):
c = pybamm.Variable("c", "negative electrode")
d = pybamm.Variable("d", "negative electrode")
sym = pybamm.div(c * pybamm.grad(c)) + (c / d + c - d) ** 5
sym.visualise("test_visualize.png")
self.assertTrue(os.path.exists("test_visualize.png"))
with self.assertRaises(ValueError):
sym.visualise("test_visualize")
with TemporaryDirectory() as dir_name:
test_stub = os.path.join(dir_name, "test_visualize")
test_name = f"{test_stub}.png"
c = pybamm.Variable("c", "negative electrode")
d = pybamm.Variable("d", "negative electrode")
sym = pybamm.div(c * pybamm.grad(c)) + (c / d + c - d) ** 5
sym.visualise(test_name)
self.assertTrue(os.path.exists(test_name))
with self.assertRaises(ValueError):
sym.visualise(test_stub)

def test_has_spatial_derivatives(self):
var = pybamm.Variable("var", domain="test")
Expand Down
12 changes: 7 additions & 5 deletions tests/unit/test_parameters/test_lead_acid_parameters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#
# Test for the standard lead acid parameters
#
import os
from tests import TestCase
import pybamm
from tests import get_discretisation_for_testing

from tempfile import TemporaryDirectory
import unittest


Expand All @@ -15,10 +16,11 @@ def test_scipy_constants(self):
self.assertAlmostEqual(constants.F.evaluate(), 96485, places=0)

def test_print_parameters(self):
parameters = pybamm.LeadAcidParameters()
parameter_values = pybamm.lead_acid.BaseModel().default_parameter_values
output_file = "lead_acid_parameters.txt"
parameter_values.print_parameters(parameters, output_file)
with TemporaryDirectory() as dir_name:
parameters = pybamm.LeadAcidParameters()
parameter_values = pybamm.lead_acid.BaseModel().default_parameter_values
output_file = os.path.join(dir_name, "lead_acid_parameters.txt")
parameter_values.print_parameters(parameters, output_file)

def test_parameters_defaults_lead_acid(self):
# Load parameters to be tested
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/test_parameters/test_lithium_ion_parameters.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
#
# Tests lithium ion parameters load and give expected values
# Tests lithium-ion parameters load and give expected values
#
import os
from tests import TestCase
import pybamm

from tempfile import TemporaryDirectory
import unittest
import numpy as np


class TestLithiumIonParameterValues(TestCase):
def test_print_parameters(self):
parameters = pybamm.LithiumIonParameters()
parameter_values = pybamm.lithium_ion.BaseModel().default_parameter_values
output_file = "lithium_ion_parameters.txt"
parameter_values.print_parameters(parameters, output_file)
with TemporaryDirectory() as dir_name:
parameters = pybamm.LithiumIonParameters()
parameter_values = pybamm.lithium_ion.BaseModel().default_parameter_values
output_file = os.path.join(dir_name, "lithium_ion_parameters.txt")
parameter_values.print_parameters(parameters, output_file)

def test_lithium_ion(self):
"""This test checks that all the parameters are being calculated
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest
from tests import TestCase
import numpy as np
from tempfile import TemporaryDirectory


class TestQuickPlot(TestCase):
Expand Down Expand Up @@ -290,12 +291,13 @@ def test_spm_simulation(self):
quick_plot.plot(0)

# test creating a GIF
test_stub = "spm_sim_test"
test_file = f"{test_stub}.gif"
quick_plot.create_gif(number_of_images=3, duration=3, output_filename=test_file)
assert not os.path.exists(f"{test_stub}*.png")
assert os.path.exists(test_file)
os.remove(test_file)
with TemporaryDirectory() as dir_name:
test_stub = os.path.join(dir_name, "spm_sim_test")
test_file = f"{test_stub}.gif"
quick_plot.create_gif(number_of_images=3, duration=3,
output_filename=test_file)
assert not os.path.exists(f"{test_stub}*.png")
assert os.path.exists(test_file)
pybamm.close_plots()

def test_loqs_spme(self):
Expand Down
135 changes: 71 additions & 64 deletions tests/unit/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import unittest
import uuid
from tempfile import TemporaryDirectory


class TestSimulation(TestCase):
Expand Down Expand Up @@ -248,32 +249,37 @@ def test_step_with_inputs(self):
)

def test_save_load(self):
model = pybamm.lead_acid.LOQS()
model.use_jacobian = True
sim = pybamm.Simulation(model)

sim.save("test.pickle")
sim_load = pybamm.load_sim("test.pickle")
self.assertEqual(sim.model.name, sim_load.model.name)

# save after solving
sim.solve([0, 600])
sim.save("test.pickle")
sim_load = pybamm.load_sim("test.pickle")
self.assertEqual(sim.model.name, sim_load.model.name)
with TemporaryDirectory() as dir_name:
test_name = os.path.join(dir_name, "tests.pickle")

model = pybamm.lead_acid.LOQS()
model.use_jacobian = True
sim = pybamm.Simulation(model)

sim.save(test_name)
sim_load = pybamm.load_sim(test_name)
self.assertEqual(sim.model.name, sim_load.model.name)

# save after solving
sim.solve([0, 600])
sim.save(test_name)
sim_load = pybamm.load_sim(test_name)
self.assertEqual(sim.model.name, sim_load.model.name)

# with python formats
model.convert_to_format = None
sim = pybamm.Simulation(model)
sim.solve([0, 600])
sim.save(test_name)
model.convert_to_format = "python"
sim = pybamm.Simulation(model)
sim.solve([0, 600])
with self.assertRaisesRegex(
NotImplementedError,
"Cannot save simulation if model format is python"
):
sim.save(test_name)

# with python formats
model.convert_to_format = None
sim = pybamm.Simulation(model)
sim.solve([0, 600])
sim.save("test.pickle")
model.convert_to_format = "python"
sim = pybamm.Simulation(model)
sim.solve([0, 600])
with self.assertRaisesRegex(
NotImplementedError, "Cannot save simulation if model format is python"
):
sim.save("test.pickle")

def test_load_param(self):
# Test load_sim for parameters imports
Expand All @@ -299,33 +305,36 @@ def test_load_param(self):
os.remove(filename)

def test_save_load_dae(self):
model = pybamm.lead_acid.LOQS({"surface form": "algebraic"})
model.use_jacobian = True
sim = pybamm.Simulation(model)

# save after solving
sim.solve([0, 600])
sim.save("test.pickle")
sim_load = pybamm.load_sim("test.pickle")
self.assertEqual(sim.model.name, sim_load.model.name)

# with python format
model.convert_to_format = None
sim = pybamm.Simulation(model)
sim.solve([0, 600])
sim.save("test.pickle")

# with Casadi solver & experiment
model.convert_to_format = "casadi"
sim = pybamm.Simulation(
model,
experiment="Discharge at 1C for 20 minutes",
solver=pybamm.CasadiSolver(),
)
sim.solve([0, 600])
sim.save("test.pickle")
sim_load = pybamm.load_sim("test.pickle")
self.assertEqual(sim.model.name, sim_load.model.name)
with TemporaryDirectory() as dir_name:
test_name = os.path.join(dir_name, "test.pickle")

model = pybamm.lead_acid.LOQS({"surface form": "algebraic"})
model.use_jacobian = True
sim = pybamm.Simulation(model)

# save after solving
sim.solve([0, 600])
sim.save(test_name)
sim_load = pybamm.load_sim(test_name)
self.assertEqual(sim.model.name, sim_load.model.name)

# with python format
model.convert_to_format = None
sim = pybamm.Simulation(model)
sim.solve([0, 600])
sim.save(test_name)

# with Casadi solver & experiment
model.convert_to_format = "casadi"
sim = pybamm.Simulation(
model,
experiment="Discharge at 1C for 20 minutes",
solver=pybamm.CasadiSolver(),
)
sim.solve([0, 600])
sim.save(test_name)
sim_load = pybamm.load_sim(test_name)
self.assertEqual(sim.model.name, sim_load.model.name)

def test_plot(self):
sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
Expand All @@ -340,21 +349,19 @@ def test_plot(self):
sim.plot(testing=True)

def test_create_gif(self):
sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
sim.solve(t_eval=[0, 10])
with TemporaryDirectory() as dir_name:
sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
sim.solve(t_eval=[0, 10])

# Create a temporary file name
test_stub = "test_sim"
test_file = f"{test_stub}.gif"
# Create a temporary file name
test_file = os.path.join(dir_name, "test_sim.gif")

# create a GIF without calling the plot method
sim.create_gif(number_of_images=3, duration=1, output_filename=test_file)

# call the plot method before creating the GIF
sim.plot(testing=True)
sim.create_gif(number_of_images=3, duration=1, output_filename=test_file)
# create a GIF without calling the plot method
sim.create_gif(number_of_images=3, duration=1, output_filename=test_file)

os.remove(test_file)
# call the plot method before creating the GIF
sim.plot(testing=True)
sim.create_gif(number_of_images=3, duration=1, output_filename=test_file)

def test_drive_cycle_interpolant(self):
model = pybamm.lithium_ion.SPM()
Expand Down
Loading

0 comments on commit 64d712e

Please sign in to comment.