Merge pull request #244 from PrincetonUniversity/direct_mpi_calc_myrange
[LIBSTELL] Use a more direct (i.e. no loops) implementation of `mpi_calc_myrange`
lazersos authored Jun 4, 2024
2 parents d94626d + 954a0fe commit 87f189c
Showing 1 changed file with 71 additions and 25 deletions.
96 changes: 71 additions & 25 deletions LIBSTELL/Sources/Modules/mpi_params.f
@@ -23,7 +23,7 @@ MODULE mpi_params
INTEGER :: MPI_COMM_WORKERS_OK=-1, worker_id_ok=-1 !communicator subgroup, vmec ran ok
INTEGER :: MPI_COMM_SHARMEM = 718, myid_sharmem=-1 !communicator for shared memory
INTEGER :: MPI_COMM_STEL = 327 !communicator which is a copy of MPI_COMM_WORLD (user must set this up)
INTEGER :: MPI_COMM_MYWORLD = 411 !communicator
INTEGER :: MPI_COMM_FIELDLINES = 328 !communicator for FIELDLINES code
INTEGER :: MPI_COMM_TORLINES = 329 !communicator for TORLINES code
INTEGER :: MPI_COMM_BEAMS = 330 !communicator for BEAMS3D code
@@ -34,12 +34,12 @@ MODULE mpi_params
INTEGER :: MPI_COMM_PARVMEC = 101 !communicator for PARVMEC code

CONTAINS

SUBROUTINE mpi_stel_abort(error)
#if defined(MPI_OPT)
USE MPI
#endif
IMPLICIT NONE
INTEGER, INTENT(in) :: error
INTEGER :: length, temp
CHARACTER(LEN=MPI_MAX_ERROR_STRING) :: message
@@ -48,7 +48,7 @@ SUBROUTINE mpi_stel_abort(error)
WRITE(6,*) '!!!!!!!!!!!!MPI_ERROR DETECTED!!!!!!!!!!!!!!'
WRITE(6,*) ' MESSAGE: ',message(1:length)
WRITE(6,*) '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
CALL FLUSH(6)
#else
WRITE(6,*) '!!!!!!!!!!!!MPI_ERROR DETECTED!!!!!!!!!!!!!!'
WRITE(6,*) ' MPI_STEL_ABORT CALLED BUT NO MPI'
@@ -57,34 +57,80 @@ SUBROUTINE mpi_stel_abort(error)
!CALL MPI_ABORT(MPI_COMM_STEL,1,temp)
END SUBROUTINE mpi_stel_abort

!> Distribute the workload of operating over the range n1:n2 (inclusive)
!> among the compute ranks available in the given communicator
!> and return the local range to be worked on in mystart and myend.
!
!> This routine must __always__ run,
!> hence no `STOP` statements or similar are allowed here.
!> If more ranks than work items are available in the communicator,
!> this routine returns `myend` < `mystart` on the surplus ranks,
!> which implies that loops like `DO i = mystart, myend` are simply skipped
!> on ranks that do not get a share of the workload.
SUBROUTINE MPI_CALC_MYRANGE(comm,n1,n2,mystart,myend)
#if defined(MPI_OPT)
USE mpi
#endif
IMPLICIT NONE
INTEGER, INTENT(inout) :: comm
INTEGER, INTENT(in) :: n1, n2
INTEGER, INTENT(out) :: mystart, myend
INTEGER :: delta, local_size, local_rank, istat, maxend, k, i
mystart = n1; myend = n2
INTEGER, INTENT(inout) :: comm !< communicator to distribute work over
INTEGER, INTENT(in) :: n1 !< lower bound of range to work on
INTEGER, INTENT(in) :: n2 !< upper bound of range to work on (inclusive)
INTEGER, INTENT(out) :: mystart !< lower bound of chunk this rank should work on
INTEGER, INTENT(out) :: myend !< upper bound of chunk this rank should work on (inclusive)

INTEGER :: local_size, local_rank, istat
INTEGER :: total_work, work_per_rank, work_remainder

! Default if not using MPI: just work on full range
mystart = n1
myend = n2

#if defined(MPI_OPT)
CALL MPI_COMM_SIZE( comm, local_size, istat)
CALL MPI_COMM_RANK( comm, local_rank, istat )
delta = CEILING(DBLE(n2-n1+1)/DBLE(local_size))
mystart = n1 + local_rank*delta
myend = mystart + delta - 1
maxend = local_size*delta
IF (maxend>n2) THEN
k = maxend-n2
DO i = (local_size-k), local_size-1
IF (local_rank > i) THEN
mystart = mystart - 1
myend = myend - 1
ELSEIF (local_rank==i) THEN
myend = myend - 1
END IF
END DO

! `local_size` is the number of available ranks.
! We assume it is always > 0, i.e., 1, 2, 3, ...
CALL MPI_COMM_SIZE(comm, local_size, istat)

! `local_rank` is the ID of the rank to compute `mystart` and `myend` for.
! We assume it is always >= 0, i.e., 0, 1, 2, ...
! and only up to `local_size - 1` (inclusive).
CALL MPI_COMM_RANK(comm, local_rank, istat)

! Total number of items to work on.
! NOTE: n2 is the upper range bound, inclusive!
total_work = n2 - n1 + 1

! size of the chunk of work that every rank gets
! (Note that we use integer division here intentionally,
! since the remainder is handled explicitly
! via the `work_remainder` variable below.)
work_per_rank = total_work / local_size

! number of work items that remain after distributing
! equal chunks of work to all ranks
work_remainder = MODULO(total_work, local_size)

! ranges corresponding to working on evenly distributed chunks
! `myend` is inclusive, i.e., the indices to work on are
! { mystart, mystart+1, ..., myend-1, myend }.
! Thus, one can use code like `DO i = mystart, myend`.
mystart = n1 + local_rank * work_per_rank
myend = n1 + (local_rank + 1) * work_per_rank - 1

IF (local_rank .lt. work_remainder) THEN
! The first `work_remainder` ranks get one additional item to work on.
! This takes care of the additional `work_remainder` items
! that need to be worked on, on top of the evenly distributed chunks.
mystart = mystart + local_rank
myend = myend + local_rank + 1
ELSE
! All following ranks after the first `work_remainder` ones
! get their ranges just shifted by a constant offset,
! since they don't do any additional work.
mystart = mystart + work_remainder
myend = myend + work_remainder
END IF
#endif
RETURN
END SUBROUTINE MPI_CALC_MYRANGE
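
As a quick sanity check of the closed-form partition above, here is a minimal standalone sketch (not part of the commit) that reproduces the same work_per_rank / work_remainder arithmetic in serial. The values n1:n2 = 1:10 and local_size = 4, the program name, and the DO loop over ranks are assumptions chosen purely for illustration; no MPI is needed to run it.

      PROGRAM demo_myrange
      IMPLICIT NONE
      INTEGER, PARAMETER :: n1 = 1, n2 = 10, local_size = 4
      INTEGER :: local_rank, total_work
      INTEGER :: work_per_rank, work_remainder
      INTEGER :: mystart, myend

      total_work     = n2 - n1 + 1
      work_per_rank  = total_work / local_size
      work_remainder = MODULO(total_work, local_size)

      DO local_rank = 0, local_size - 1
         ! evenly distributed chunk for this rank
         mystart = n1 + local_rank * work_per_rank
         myend   = n1 + (local_rank + 1) * work_per_rank - 1
         IF (local_rank < work_remainder) THEN
            ! the first work_remainder ranks take one extra item
            mystart = mystart + local_rank
            myend   = myend + local_rank + 1
         ELSE
            ! remaining ranks are shifted by a constant offset
            mystart = mystart + work_remainder
            myend   = myend + work_remainder
         END IF
         WRITE(6,*) 'rank', local_rank, ':', mystart, '..', myend
      END DO
      END PROGRAM demo_myrange

With these numbers the ranks receive the ranges 1:3, 4:6, 7:8 and 9:10, i.e. the first work_remainder = 2 ranks get one extra item and all 10 items are covered exactly once.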
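
Running the same arithmetic with more ranks than items, e.g. local_size = 12 for the same range 1:10, gives work_per_rank = 0 and work_remainder = 10: ranks 0 through 9 each receive exactly one item, while ranks 10 and 11 end up with mystart = 11 and myend = 10. Because myend < mystart, a loop such as DO i = mystart, myend performs no iterations on those surplus ranks, which is the behaviour documented in the header comment of MPI_CALC_MYRANGE.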
