Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basesolver init function now accepts a comm #30

Merged
merged 6 commits into from
Jan 7, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions baseclasses/BaseSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@ class BaseSolver(object):
Abstract Class for a basic Solver Object
"""

def __init__(self, name, category={}, def_options={}, options={}, immutableOptions=set(), deprecatedOptions={}):
def __init__(
self, name, category, defaultOptions={}, options={}, immutableOptions=set(), deprecatedOptions={}, comm=None, informs={}
):
"""
Solver Class Initialization
"""

self.name = name
self.category = category
self.options = CaseInsensitiveDict()
self.defaultOptions = CaseInsensitiveDict(def_options)
self.defaultOptions = CaseInsensitiveDict(defaultOptions)
self.immutableOptions = CaseInsensitiveSet(immutableOptions)
self.deprecatedOptions = CaseInsensitiveDict(deprecatedOptions)
self.comm = comm
self.informs = informs
self.solverCreated = False
self.comm = None

# Initialize Options
for key, (optionType, optionValue) in self.defaultOptions.items():
Expand Down Expand Up @@ -146,7 +149,7 @@ def printModifiedOptions(self):
tmpDict = {}
for key in self.options:
defaultType, defaultValue = self.defaultOptions[key]
if defaultType == list and not isinstance(defaultValue, list):
if defaultType is list and not isinstance(defaultValue, list):
defaultValue = defaultValue[0]
optionValue = self.getOption(key)
if optionValue != defaultValue:
Expand All @@ -156,7 +159,7 @@ def printModifiedOptions(self):
def pp(self, obj):
"""
This method prints ``obj`` (via pprint) on the root proc of ``self.comm`` if it exists.
Otherswise it will just print ``obj``.
Otherwise it will just print ``obj``.

Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion baseclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.2.4"
__version__ = "1.2.5"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now that we have broken backwards compatibility, we should bump to 1.3.0 at a minimum.

Copy link
Contributor Author

@eirikurj eirikurj Jan 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree, will bump


from .pyAero_problem import AeroProblem
from .pyTransi_problem import TransiProblem
Expand Down
14 changes: 12 additions & 2 deletions baseclasses/pyAero_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@ class AeroSolver(BaseSolver):
"""

def __init__(
self, name, category={}, def_options={}, informs={}, options={}, immutableOptions=set(), deprecatedOptions={}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think changing the order will break pyFriction so we should change that as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I will submit PR for the remaining repos that use baseclasses

self,
name,
category,
defaultOptions={},
options={},
immutableOptions=set(),
deprecatedOptions={},
comm=None,
informs={},
):

"""
Expand All @@ -37,10 +45,12 @@ def __init__(
super().__init__(
name,
category=category,
def_options=def_options,
defaultOptions=defaultOptions,
options=options,
immutableOptions=immutableOptions,
deprecatedOptions=deprecatedOptions,
comm=comm,
informs=informs,
)
self.families = CaseInsensitiveDict()
self._updateGeomInfo = False
Expand Down
48 changes: 45 additions & 3 deletions tests/test_BaseSolver.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
try:
from mpi4py import MPI
except ImportError:
MPI = None

import unittest
from baseclasses import BaseSolver
from baseclasses.utils import Error


class SOLVER(BaseSolver):
def __init__(self, name, options=None, *args, **kwargs):
def __init__(self, name, options={}, comm=None, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe args and kwargs are used so those can probably be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, will remove


"""Create an artificial class for testing"""

category = "Solver for testing BaseSolver"
def_opts = {
defaultOptions = {
"boolOption": [bool, True],
"floatOption": [float, 10.0],
"intOption": [int, [1, 2, 3]],
Expand All @@ -22,9 +27,22 @@ def __init__(self, name, options=None, *args, **kwargs):
"oldOption": "Use boolOption instead.",
}

informs = {
-1: "Failure -1",
0: "Success",
1: "Failure 1",
}

# Initialize the inherited BaseSolver
super().__init__(
name, category, def_opts, options, immutableOptions=immutableOptions, deprecatedOptions=deprecatedOptions
name,
category,
defaultOptions=defaultOptions,
options=options,
immutableOptions=immutableOptions,
deprecatedOptions=deprecatedOptions,
comm=comm,
informs=informs,
)


Expand Down Expand Up @@ -85,3 +103,27 @@ def test_options(self):
solver.setOption("strOPTION", "str2") # test immutableOptions
with self.assertRaises(Error):
solver.setOption("oldoption", 4) # test deprecatedOptions


class TestComm(unittest.TestCase):

N_PROCS = 2

@unittest.skipIf(MPI is None, "mpi4py not imported")
def test_comm_with_mpi(self):
# initialize solver
solver = SOLVER("testComm", comm=MPI.COMM_WORLD)
self.assertFalse(solver.comm is None)
solver.printCurrentOptions()

def test_comm_without_mpi(self):
# initialize solver
solver = SOLVER("testComm", comm=None)
self.assertTrue(solver.comm is None)
solver.printCurrentOptions()


class TestInforms(unittest.TestCase):
def test_informs(self):
solver = SOLVER("testInforms")
self.assertEqual(solver.informs[0], "Success")