Skip to content

Commit

Permalink
Add Allreduce and all MPI OP
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Dec 22, 2024
1 parent fd2fa97 commit cc620c6
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 14 deletions.
12 changes: 9 additions & 3 deletions ndsl/comm/caching_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from ndsl.comm.comm_abc import Comm, Request
from ndsl.comm.comm_abc import Comm, ReductionOperator, Request


T = TypeVar("T")
Expand Down Expand Up @@ -147,9 +147,12 @@ def Split(self, color, key) -> "CachingCommReader":
new_data = self._data.get_split()
return CachingCommReader(data=new_data)

def allreduce(self, sendobj, op=None) -> Any:
def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
    """Return the object provided by ``self._data.get_generic_obj()``.

    ``sendobj`` and ``op`` are accepted for interface compatibility but are
    not read by this method.
    """
    return self._data.get_generic_obj()

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
    """Not supported by this class.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("CachingCommReader.Allreduce")

@classmethod
def load(cls, file: BinaryIO) -> "CachingCommReader":
data = CachingCommData.load(file)
Expand Down Expand Up @@ -229,7 +232,10 @@ def Split(self, color, key) -> "CachingCommWriter":
def dump(self, file: BinaryIO):
    """Delegate to ``self._data.dump`` with the given binary file object."""
    self._data.dump(file)

def allreduce(self, sendobj, op=None) -> Any:
def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
    """Run ``allreduce`` on the wrapped communicator and record the outcome.

    The value returned by ``self._comm.allreduce(sendobj, op)`` is handed
    back to the caller, and an independent deep copy of it is appended to
    ``self._data.generic_obj_buffers``.
    """
    reduced = self._comm.allreduce(sendobj, op)
    # Store a deep copy so later mutation of the returned object by the
    # caller cannot alter what was recorded.
    snapshot = copy.deepcopy(reduced)
    self._data.generic_obj_buffers.append(snapshot)
    return reduced

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
    """Not supported by this class.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("CachingCommWriter.Allreduce")
26 changes: 25 additions & 1 deletion ndsl/comm/comm_abc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
import abc
import enum
from typing import List, Optional, TypeVar


T = TypeVar("T")


@enum.unique
class ReductionOperator(enum.Enum):
    """Reduction operator names used by ``Comm.allreduce`` / ``Comm.Allreduce``.

    Member values are produced by ``enum.auto()`` in declaration order, so
    reordering or inserting members changes their values. ``@enum.unique``
    rejects any two members sharing a value.
    """

    # NOTE(review): the names look like the MPI reduction operator
    # constants (MPI_MAX, MPI_SUM, ...) - confirm before adding members.
    OP_NULL = enum.auto()
    MAX = enum.auto()
    MIN = enum.auto()
    SUM = enum.auto()
    PROD = enum.auto()
    LAND = enum.auto()
    BAND = enum.auto()
    LOR = enum.auto()
    BOR = enum.auto()
    LXOR = enum.auto()
    BXOR = enum.auto()
    MAXLOC = enum.auto()
    MINLOC = enum.auto()
    REPLACE = enum.auto()
    NO_OP = enum.auto()


class Request(abc.ABC):
@abc.abstractmethod
def wait(self):
Expand Down Expand Up @@ -69,5 +89,9 @@ def Split(self, color, key) -> "Comm":
...

@abc.abstractmethod
def allreduce(self, sendobj: T, op=None) -> T:
def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
    """Interface for ``allreduce``; the body is intentionally empty.

    Takes the object to reduce and an optional ``ReductionOperator`` and is
    annotated to return an object of the same type as ``sendobj``. Concrete
    communicators define the actual behavior.
    """
    ...

@abc.abstractmethod
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
    """Abstract ``Allreduce`` taking a send object, a receive object and an operator.

    Unlike ``allreduce``, ``op`` has no default and must be supplied.
    Concrete communicators define the actual behavior.
    """
    ...
19 changes: 15 additions & 4 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast

import numpy as np
from mpi4py import MPI

import ndsl.constants as constants
from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer
from ndsl.comm.boundary import Boundary
from ndsl.comm.comm_abc import ReductionOperator
from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner
from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
from ndsl.performance.timer import NullTimer, Timer
Expand Down Expand Up @@ -109,10 +109,13 @@ def _create_all_reduce_quantity(
)
return all_reduce_quantity

def all_reduce_sum(
self, input_quantity: Quantity, output_quantity: Quantity = None
def all_reduce(
self,
input_quantity: Quantity,
op: ReductionOperator,
output_quantity: Quantity = None,
):
reduced_quantity_data = self.comm.allreduce(input_quantity.data, MPI.SUM)
reduced_quantity_data = self.comm.allreduce(input_quantity.data, op)
if output_quantity is None:
all_reduce_quantity = self._create_all_reduce_quantity(
input_quantity.metadata, reduced_quantity_data
Expand All @@ -126,6 +129,14 @@ def all_reduce_sum(

output_quantity.data = reduced_quantity_data

def all_reduce_per_element(
    self,
    input_quantity: Quantity,
    output_quantity: Quantity,
    op: ReductionOperator,
):
    """Call ``self.comm.Allreduce`` on the data of two quantities.

    ``input_quantity.data`` is passed as the send object,
    ``output_quantity.data`` as the receive object, and ``op`` as the
    operator. The return value of ``Allreduce`` is discarded and this
    method returns ``None``.
    """
    # presumably the communicator fills output_quantity.data in place,
    # since nothing is assigned back here - verify per Comm implementation
    self.comm.Allreduce(input_quantity.data, output_quantity.data, op)

def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs):
with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
numpy_module.zeros, recvbuf
Expand Down
31 changes: 28 additions & 3 deletions ndsl/comm/mpi.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
try:
import mpi4py
from mpi4py import MPI
except ImportError:
MPI = None
from typing import List, Optional, TypeVar, cast
from typing import Dict, List, Optional, TypeVar, cast

from ndsl.comm.comm_abc import Comm, Request
from ndsl.comm.comm_abc import Comm, ReductionOperator, Request
from ndsl.logging import ndsl_log


T = TypeVar("T")


class MPIComm(Comm):
# Translation table from ReductionOperator members to mpi4py operators.
# The mpi4py import at the top of this module is guarded by
# try/except ImportError and leaves ``MPI = None`` on failure; the table is
# left empty in that case so that evaluating this class body does not raise
# NameError and ``__init__`` can report "MPI not available" instead.
_op_mapping: Dict[ReductionOperator, "MPI.Op"] = (
    {}
    if MPI is None
    else {
        ReductionOperator.OP_NULL: MPI.OP_NULL,
        ReductionOperator.MAX: MPI.MAX,
        ReductionOperator.MIN: MPI.MIN,
        ReductionOperator.SUM: MPI.SUM,
        ReductionOperator.PROD: MPI.PROD,
        ReductionOperator.LAND: MPI.LAND,
        ReductionOperator.BAND: MPI.BAND,
        ReductionOperator.LOR: MPI.LOR,
        ReductionOperator.BOR: MPI.BOR,
        ReductionOperator.LXOR: MPI.LXOR,
        ReductionOperator.BXOR: MPI.BXOR,
        ReductionOperator.MAXLOC: MPI.MAXLOC,
        ReductionOperator.MINLOC: MPI.MINLOC,
        ReductionOperator.REPLACE: MPI.REPLACE,
        ReductionOperator.NO_OP: MPI.NO_OP,
    }
)

def __init__(self):
if MPI is None:
raise RuntimeError("MPI not available")
Expand Down Expand Up @@ -72,8 +91,14 @@ def Split(self, color, key) -> "Comm":
)
return self._comm.Split(color, key)

def allreduce(self, sendobj: T, op=None) -> T:
def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
    """Call ``allreduce`` on the wrapped communicator.

    Args:
        sendobj: object forwarded to the wrapped communicator.
        op: reduction operator. ``ReductionOperator`` members are translated
            through ``_op_mapping``; any other value (including ``None``)
            is forwarded unchanged.

    Returns:
        Whatever the wrapped communicator's ``allreduce`` returns.

    Raises:
        KeyError: if ``op`` is a ``ReductionOperator`` missing from
            ``_op_mapping``.
    """
    ndsl_log.debug(
        "allreduce on rank %s with operator %s", self._comm.Get_rank(), op
    )
    # The wrapped communicator does not know ReductionOperator; translate it
    # the same way Allreduce does instead of forwarding the enum member.
    mpi_op = self._op_mapping[op] if isinstance(op, ReductionOperator) else op
    return self._comm.allreduce(sendobj, mpi_op)

def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
    """Call ``Allreduce`` on the wrapped communicator with a translated operator.

    Args:
        sendobj: send object forwarded to the wrapped communicator.
        recvobj: receive object forwarded to the wrapped communicator.
        op: reduction operator, translated through ``_op_mapping``.

    Returns:
        Whatever the wrapped communicator's ``Allreduce`` returns.

    Raises:
        KeyError: if ``op`` is not a key of ``_op_mapping``.
    """
    # Capitalized label so debug logs distinguish this from ``allreduce``.
    ndsl_log.debug(
        "Allreduce on rank %s with operator %s", self._comm.Get_rank(), op
    )
    return self._comm.Allreduce(sendobj, recvobj, self._op_mapping[op])
10 changes: 7 additions & 3 deletions ndsl/comm/null_comm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
from typing import Any, Mapping
from typing import Any, Mapping, Optional

from ndsl.comm.comm_abc import Comm, Request
from ndsl.comm.comm_abc import Comm, ReductionOperator, Request


class NullAsyncResult(Request):
Expand Down Expand Up @@ -91,5 +91,9 @@ def Split(self, color, key):
self._split_comms[color].append(new_comm)
return new_comm

def allreduce(self, sendobj, op=None) -> Any:
def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
    """Return ``self._fill_value``; ``sendobj`` and ``op`` are not read."""
    return self._fill_value

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
    """Copy ``sendobj`` into ``recvobj`` in place and return the result.

    No reduction is performed and ``op`` is not read. The receive object
    is written through item assignment so that callers who keep their own
    reference to it (and ignore the return value) observe the data.

    Args:
        sendobj: source data.
        recvobj: destination; filled in place when it supports
            ``recvobj[...] = value``.
        op: unused.

    Returns:
        ``recvobj`` after the copy, or ``sendobj`` when ``recvobj`` does
        not support ellipsis item assignment.
    """
    try:
        # Rebinding the local name would leave the caller's buffer untouched.
        recvobj[...] = sendobj
    except TypeError:
        # Receive object is not an assignable buffer; hand back the send
        # object, which is what this method previously always returned.
        return sendobj
    return recvobj

0 comments on commit cc620c6

Please sign in to comment.