Skip to content

Commit

Permalink
Merge pull request ESMCI#2 from mvertens/feature/update_cime
Browse files Browse the repository at this point in the history
updated to cime6.0.83
  • Loading branch information
mvertens authored Jan 3, 2023
2 parents a74992c + 2adace0 commit 2e7963f
Show file tree
Hide file tree
Showing 22 changed files with 742 additions and 74 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ jobs:
contents: write # for peaceiris/actions-gh-pages to push
pull-requests: write # to comment on pull requests
needs: check-changes
if: needs.check-changes.outputs.any_changed == 'true'
if: |
needs.check-changes.outputs.any_changed == 'true' &&
github.event.pull_request.head.repo.full_name == github.repository
name: Build and deploy documentation
runs-on: ubuntu-latest
steps:
Expand Down
3 changes: 1 addition & 2 deletions CIME/Tools/xmlchange
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Examples:
Several xml variables that have settings for each component have somewhat special treatment.
The variables that this currently applies to are:
NTASKS, NTHRDS, ROOTPE, PIO_TYPENAME, PIO_STRIDE, PIO_NUMTASKS
NTASKS, NTHRDS, ROOTPE, PIO_TYPENAME, PIO_STRIDE, PIO_NUMTASKS, PIO_ASYNC_INTERFACE
For example, to set the number of tasks for all components to 16, use:
./xmlchange NTASKS=16
To set just the number of tasks for the atm component, use:
Expand Down Expand Up @@ -303,7 +303,6 @@ def xmlchange(
% (pair),
)
(xmlid, xmlval) = pair

xmlchange_single_value(
case, xmlid, xmlval, subgroup, append, force, dryrun, env_test
)
Expand Down
41 changes: 26 additions & 15 deletions CIME/XML/env_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,14 +654,24 @@ def _get_argument(self, case, arg):

def _resolve_argument(self, case, flag, name, job):
submitargs = ""

if name.startswith("$"):
name = name[1:]
logger.debug("name is {}".format(name))
# if name.startswith("$"):
# name = name[1:]

if "$" in name:
# We have a complex expression and must rely on get_resolved_value.
# Hopefully, none of the values require subgroup
val = case.get_resolved_value(name)
parts = name.split("$")
logger.debug("parts are {}".format(parts))
val = ""
for part in parts:
if part != "":
logger.debug("part is {}".format(part))
resolved = case.get_value(part, subgroup=job)
if resolved:
val += resolved
else:
val += part
logger.debug("val is {}".format(name))
val = case.get_resolved_value(val)
else:
val = case.get_value(name, subgroup=job)

Expand All @@ -675,12 +685,9 @@ def _resolve_argument(self, case, flag, name, job):
else:
rval = val

if flag != "-P":
# We don't want floating-point data
try:
rval = int(round(float(rval)))
except ValueError:
pass
# We don't want floating-point data (ignore anything else)
if str(rval).replace(".", "", 1).isdigit():
rval = int(round(float(rval)))

# need a correction for tasks per node
if flag == "-n" and rval <= 0:
Expand Down Expand Up @@ -1110,9 +1117,13 @@ def set_batch_system_type(self, batchtype):

def get_job_id(self, output):
jobid_pattern = self.get_value("jobid_pattern", subgroup=None)
expect(
jobid_pattern is not None, "Could not find jobid_pattern in env_batch.xml"
)
if self._batchtype and self._batchtype != "none":
expect(
jobid_pattern is not None,
"Could not find jobid_pattern in env_batch.xml",
)
else:
return output
search_match = re.search(jobid_pattern, output)
expect(
search_match is not None,
Expand Down
34 changes: 31 additions & 3 deletions CIME/XML/env_mach_pes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
comp, ninst, ninst_max
),
)
if "NTASKS" in vid or "NTHRDS" in vid:
expect(value != 0, "Cannot set NTASKS or NTHRDS to 0")

if ("NTASKS" in vid or "NTHRDS" in vid) and vid != "PIO_ASYNCIO_NTASKS":
expect(value != 0, f"Cannot set NTASKS or NTHRDS to 0 {vid}")

return EnvBase.set_value(
self, vid, value, subgroup=subgroup, ignore_type=ignore_type
Expand All @@ -104,18 +105,41 @@ def get_max_thread_count(self, comp_classes):
max_threads = threads
return max_threads

def get_total_tasks(self, comp_classes):
def get_total_tasks(self, comp_classes, async_interface=False):
total_tasks = 0
maxinst = self.get_value("NINST")
asyncio_ntasks = 0
asyncio_rootpe = 0
asyncio_stride = 0
asyncio_tasks = []
if maxinst:
comp_interface = "nuopc"
if async_interface:
asyncio_ntasks = self.get_value("PIO_ASYNCIO_NTASKS")
asyncio_rootpe = self.get_value("PIO_ASYNCIO_ROOTPE")
asyncio_stride = self.get_value("PIO_ASYNCIO_STRIDE")
logger.debug(
"asyncio ntasks {} rootpe {} stride {}".format(
asyncio_ntasks, asyncio_rootpe, asyncio_stride
)
)
if asyncio_ntasks and asyncio_stride:
for i in range(
asyncio_rootpe,
asyncio_rootpe + (asyncio_ntasks * asyncio_stride),
asyncio_stride,
):
asyncio_tasks.append(i)
else:
comp_interface = "unknown"
maxinst = 1
tt = 0
maxrootpe = 0
for comp in comp_classes:
ntasks = self.get_value("NTASKS", attribute={"compclass": comp})
rootpe = self.get_value("ROOTPE", attribute={"compclass": comp})
pstrid = self.get_value("PSTRID", attribute={"compclass": comp})

esmf_aware_threading = self.get_value("ESMF_AWARE_THREADING")
# mct is unaware of threads and they should not be counted here
# if esmf is thread aware they are included
Expand All @@ -128,9 +152,13 @@ def get_total_tasks(self, comp_classes):
ninst = self.get_value("NINST", attribute={"compclass": comp})
maxinst = max(maxinst, ninst)
tt = rootpe + nthrds * ((ntasks - 1) * pstrid + 1)
maxrootpe = max(maxrootpe, rootpe)
total_tasks = max(tt, total_tasks)
if asyncio_tasks:
total_tasks = total_tasks + len(asyncio_tasks)
if self.get_value("MULTI_DRIVER"):
total_tasks *= maxinst
logger.debug("asyncio_tasks {}".format(asyncio_tasks))
return total_tasks

def get_tasks_per_node(self, total_tasks, max_thread_count):
Expand Down
78 changes: 71 additions & 7 deletions CIME/XML/env_mach_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,16 +581,12 @@ def get_module_system_cmd_path(self, lang):
else None
)

def get_mpirun(self, case, attribs, job, exe_only=False, overrides=None):
"""
Find best match, return (executable, {arg_name : text})
"""
def _find_best_mpirun_match(self, attribs):
mpirun_nodes = self.get_children("mpirun")
best_match = None
best_num_matched = -1
default_match = None
best_num_matched_default = -1
args = []
for mpirun_node in mpirun_nodes:
xml_attribs = self.attrib(mpirun_node)
all_match = True
Expand Down Expand Up @@ -638,14 +634,82 @@ def get_mpirun(self, case, attribs, job, exe_only=False, overrides=None):
and attribs["mpilib"] == "mpi-serial"
and best_match is None
):
return "", [], None, None
raise ValueError()

expect(
best_match is not None or default_match is not None,
"Could not find a matching MPI for attributes: {}".format(attribs),
)

the_match = best_match if best_match is not None else default_match
return best_match if best_match is not None else default_match

def get_aprun_mode(self, attribs):
    """
    Return the "aprun_mode" setting of the mpirun entry matching attribs.

    The best-matching mpirun entry is looked up with
    _find_best_mpirun_match. "default" is returned when that lookup
    raises ValueError or when the entry has no "aprun_mode" child.
    Otherwise exactly one such child is expected, and its text must be
    one of "ignore", "default" or "override".
    """
    valid_modes = ("ignore", "default", "override")
    fallback = "default"

    try:
        matched_mpirun = self._find_best_mpirun_match(attribs)
    except ValueError:
        # The matcher declined to pick an entry; fall back to the default.
        return fallback

    candidates = self.get_children("aprun_mode", root=matched_mpirun)

    if len(candidates) != 0:
        expect(len(candidates) == 1, 'Found multiple "aprun_mode" elements.')

        # Exactly one element remains, so its text is the requested mode.
        mode = self.text(candidates[0])

        expect(
            mode in valid_modes,
            f"Value {mode!r} for \"aprun_mode\" is not valid, options are {', '.join(valid_modes)!r}",
        )

        return mode

    return fallback

def get_aprun_args(self, case, attribs, job, overrides=None):
    """
    Collect the extra arguments declared on the matching mpirun entry.

    The best-matching mpirun entry is looked up with
    _find_best_mpirun_match. Each "arg" child of its optional
    "arguments" element is expanded with transform_vars (using case,
    job as the subgroup, overrides, and the element's "default"
    attribute) and recorded together with its "position" attribute.

    Returns a dict mapping each expanded argument string to
    {"position": <position>}, where position defaults to "per" when the
    attribute is absent. The dict is empty when the entry declares no
    arguments. Returns None when _find_best_mpirun_match raises
    ValueError.
    """
    try:
        the_match = self._find_best_mpirun_match(attribs)
    except ValueError:
        # The matcher declined to pick an entry; callers receive None.
        return None

    args = {}

    arguments_node = self.get_optional_child("arguments", root=the_match)

    # Test for absence explicitly instead of relying on the truthiness of
    # an XML node object, which can be falsy for a childless element.
    if arguments_node is None:
        return args

    # A distinct loop name keeps the <arguments> container from being
    # shadowed by the individual <arg> children.
    for arg_node in self.get_children("arg", root=arguments_node):
        position = self.get(arg_node, "position")

        if position is None:
            position = "per"

        arg_value = transform_vars(
            self.text(arg_node),
            case=case,
            subgroup=job,
            overrides=overrides,
            default=self.get(arg_node, "default"),
        )

        args[arg_value] = {"position": position}

    return args

def get_mpirun(self, case, attribs, job, exe_only=False, overrides=None):
"""
Find best match, return (executable, {arg_name : text})
"""
args = []

try:
the_match = self._find_best_mpirun_match(attribs)
except ValueError:
return "", [], None, None

# Now that we know the best match, compute the arguments
if not exe_only:
Expand Down
12 changes: 8 additions & 4 deletions CIME/XML/env_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
Returns the value or None if not found
subgroup is ignored in the general routine and applied in specific methods
"""
comp = None
if any(self._pio_async_interface.values()):
vid, comp, iscompvar = self.check_if_comp_var(vid, None)
if vid.startswith("PIO") and iscompvar:
Expand All @@ -58,9 +59,12 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
subgroup = "CPL"

if vid == "PIO_ASYNC_INTERFACE":
if type(value) == type(True):
self._pio_async_interface = value
else:
self._pio_async_interface = convert_to_type(value, "logical", vid)
if comp:
if type(value) == type(True):
self._pio_async_interface[comp] = value
else:
self._pio_async_interface[comp] = convert_to_type(
value, "logical", vid
)

return EnvBase.set_value(self, vid, value, subgroup, ignore_type)
34 changes: 25 additions & 9 deletions CIME/aprun.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def _get_aprun_cmd_for_case_impl(
compiler,
machine,
run_exe,
extra_args,
):
###############################################################################
"""
Expand All @@ -38,19 +39,22 @@ def _get_aprun_cmd_for_case_impl(
>>> compiler = "pgi"
>>> machine = "titan"
>>> run_exe = "e3sm.exe"
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe)
(' -S 4 -n 680 -N 8 -d 2 e3sm.exe : -S 2 -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
(' -S 4 -n 680 -N 8 -d 2 e3sm.exe : -S 2 -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> compiler = "intel"
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe)
(' -S 4 -cc numa_node -n 680 -N 8 -d 2 e3sm.exe : -S 2 -cc numa_node -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
(' -S 4 -cc numa_node -n 680 -N 8 -d 2 e3sm.exe : -S 2 -cc numa_node -n 128 -N 4 -d 4 e3sm.exe ', 117, 808, 4, 4)
>>> ntasks = [64, 64, 64, 64, 64, 64, 64, 64, 1]
>>> nthreads = [1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> rootpes = [0, 0, 0, 0, 0, 0, 0, 0, 0]
>>> pstrids = [1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe)
(' -S 8 -cc numa_node -n 64 -N 16 -d 1 e3sm.exe ', 4, 64, 16, 1)
>>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
(' -S 8 -cc numa_node -n 64 -N 16 -d 1 e3sm.exe ', 4, 64, 16, 1)
"""
if extra_args is None:
extra_args = {}

max_tasks_per_node = 1 if max_tasks_per_node < 1 else max_tasks_per_node

total_tasks = 0
Expand Down Expand Up @@ -78,6 +82,12 @@ def _get_aprun_cmd_for_case_impl(
if maxt[c1] < 1:
maxt[c1] = 1

global_flags = " ".join(
[x for x, y in extra_args.items() if y["position"] == "global"]
)

per_flags = " ".join([x for x, y in extra_args.items() if y["position"] == "per"])

# Compute task and thread settings for batch commands
(
tasks_per_node,
Expand All @@ -88,7 +98,7 @@ def _get_aprun_cmd_for_case_impl(
total_node_count,
total_task_count,
aprun_args,
) = (0, max_mpitasks_per_node, 1, maxt[0], maxt[0], 0, 0, "")
) = (0, max_mpitasks_per_node, 1, maxt[0], maxt[0], 0, 0, f" {global_flags}")
c1list = list(range(1, total_tasks))
c1list.append(None)
for c1 in c1list:
Expand All @@ -107,10 +117,11 @@ def _get_aprun_cmd_for_case_impl(
if compiler == "intel":
aprun_args += " -cc numa_node"

aprun_args += " -n {:d} -N {:d} -d {:d} {} {}".format(
aprun_args += " -n {:d} -N {:d} -d {:d} {} {} {}".format(
task_count,
tasks_per_node,
thread_count,
per_flags,
run_exe,
"" if c1 is None else ":",
)
Expand Down Expand Up @@ -140,7 +151,7 @@ def _get_aprun_cmd_for_case_impl(


###############################################################################
def get_aprun_cmd_for_case(case, run_exe, overrides=None):
def get_aprun_cmd_for_case(case, run_exe, overrides=None, extra_args=None):
###############################################################################
"""
Given a case, construct and return the aprun command and optimized node count
Expand All @@ -156,6 +167,10 @@ def get_aprun_cmd_for_case(case, run_exe, overrides=None):
the_list.append(case.get_value("_".join([item_name, model])))
max_tasks_per_node = case.get_value("MAX_TASKS_PER_NODE")
if overrides:
overrides = {
x: y if isinstance(y, int) or y is None else int(y)
for x, y in overrides.items()
}
if "max_tasks_per_node" in overrides:
max_tasks_per_node = overrides["max_tasks_per_node"]
if "total_tasks" in overrides:
Expand All @@ -175,4 +190,5 @@ def get_aprun_cmd_for_case(case, run_exe, overrides=None):
case.get_value("COMPILER"),
case.get_value("MACH"),
run_exe,
extra_args,
)
Loading

0 comments on commit 2e7963f

Please sign in to comment.