Skip to content

Commit

Permalink
Julesghub/mpi update (#192)
Browse files Browse the repository at this point in the history
* importing underworld3.mpi earlier on to avoid circular imports

* Only get the communicator via underworld3.mpi.comm
  • Loading branch information
julesghub authored May 3, 2024
1 parent 15f0c52 commit 4fe6bf4
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/underworld3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def view():
from .utilities._nb_tools import *

# Needed everywhere
import underworld3.mpi
from underworld3.utilities import _api_tools

import underworld3.adaptivity
Expand All @@ -110,7 +111,6 @@ def view():
import underworld3.maths
import underworld3.utilities
import underworld3.kdtree
import underworld3.mpi
import underworld3.cython
import underworld3.scaling
import underworld3.visualisation
Expand Down
2 changes: 1 addition & 1 deletion src/underworld3/cython/petsc_generic_snes_solvers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2517,7 +2517,7 @@ class SNES_Stokes_SaddlePt(SolverBaseClass):

from mpi4py import MPI

comm = MPI.COMM_WORLD
comm = uw.mpi.comm
max_magvel_glob = comm.allreduce(max_magvel, op=MPI.MAX)

min_dx = self.mesh.get_min_radius()
Expand Down
10 changes: 4 additions & 6 deletions src/underworld3/swarm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Optional, Tuple, Union
import contextlib

import numpy as np
import sympy
import petsc4py.PETSc as PETSc
from mpi4py import MPI
import petsc4py.PETSc as PETSc

import underworld3 as uw
from underworld3.utilities._api_tools import Stateful
Expand All @@ -16,7 +15,7 @@
import os
import warnings

comm = MPI.COMM_WORLD
comm = uw.mpi.comm

from enum import Enum

Expand Down Expand Up @@ -451,7 +450,7 @@ def save(

if h5py.h5.get_config().mpi == True and not force_sequential:
with h5py.File(
f"{filename[:-3]}.h5", "w", driver="mpio", comm=MPI.COMM_WORLD
f"{filename[:-3]}.h5", "w", driver="mpio", comm=comm
) as h5f:
with self.swarm.access(self):
if compression == True:
Expand Down Expand Up @@ -1144,7 +1143,7 @@ def save(
data_copy = self.data[:].copy()

with h5py.File(
f"{filename[:-3]}.h5", "w", driver="mpio", comm=MPI.COMM_WORLD
f"{filename[:-3]}.h5", "w", driver="mpio", comm=comm
) as h5f:
if compression == True:
h5f.create_dataset(
Expand Down Expand Up @@ -1872,7 +1871,6 @@ def estimate_dt(self, V_fn):

from mpi4py import MPI

comm = MPI.COMM_WORLD
max_magvel_glob = comm.allreduce(max_magvel, op=MPI.MAX)

min_dx = self.mesh.get_min_radius()
Expand Down
8 changes: 4 additions & 4 deletions src/underworld3/systems/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,7 @@ def uw_scalar_function(self, user_uw_function):
# from mpi4py import MPI

# ## get global max dif value
# comm = MPI.COMM_WORLD
# comm = uw.mpi.comm
# diffusivity_glob = comm.allreduce(max_diffusivity, op=MPI.MAX)

# ### get the velocity values
Expand Down Expand Up @@ -1552,7 +1552,7 @@ def estimate_dt(self):
from mpi4py import MPI

## get global max dif value
comm = MPI.COMM_WORLD
comm = uw.mpi.comm
diffusivity_glob = comm.allreduce(max_diffusivity, op=MPI.MAX)

### get the velocity values
Expand Down Expand Up @@ -1864,7 +1864,7 @@ def solve(
# # max_diffusivity = self.k.data[:, 0].max()

# ## get global max dif value
# comm = MPI.COMM_WORLD
# comm = uw.mpi.comm
# diffusivity_glob = comm.allreduce(max_diffusivity, op=MPI.MAX)

# ### get the velocity values
Expand Down Expand Up @@ -2623,7 +2623,7 @@ def estimate_dt(self):
from mpi4py import MPI

## get global max dif value
comm = MPI.COMM_WORLD
comm = uw.mpi.comm
diffusivity_glob = comm.allreduce(max_diffusivity, op=MPI.MAX)

### get the velocity values
Expand Down

0 comments on commit 4fe6bf4

Please sign in to comment.