Commit 136a053
delete MPIRequest.wait()
mtar committed Sep 15, 2020
1 parent 720a878 commit 136a053
Showing 8 changed files with 70 additions and 81 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -49,6 +49,7 @@
 - [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer`
 - [#666](https://github.com/helmholtz-analytics/heat/pull/666) New feature: distributed prepend/append for diff().
 - [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter
+- [#672](https://github.com/helmholtz-analytics/heat/pull/672) Bug / Enhancement: Remove `MPIRequest.wait()`, rewrite calls with capital letters

 # v0.4.0

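For context, the capital-letter spelling follows mpi4py's convention: capitalized communicator and request methods (Isend, Irecv, Wait, ...) operate on buffer-like objects such as NumPy arrays or torch tensors, while broadly the lowercase variants work with pickled Python objects. After this commit, heat's MPIRequest wrapper exposes only the capitalized Wait(). A minimal mpi4py sketch of the pattern the rewritten call sites follow (not taken from this commit; the payload, tag, and process count are illustrative):

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if comm.Get_size() >= 2:
    if rank == 0:
        # non-blocking, buffer-based send of a NumPy array
        req = comm.Isend(np.arange(4.0), dest=1, tag=7)
        req.Wait()  # capitalized completion call, as used throughout this commit
    elif rank == 1:
        buf = np.zeros(4)
        req = comm.Irecv(buf, source=0, tag=7)
        req.Wait()  # blocks until the data has arrived in buf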
4 changes: 2 additions & 2 deletions heat/core/arithmetics.py
@@ -387,7 +387,7 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
 ret.lloc[diff_slice] = dif

 if rank > 0:
-snd.wait() # wait for the send to finish
+snd.Wait() # wait for the send to finish
 if rank < size - 1:
 cr_slice = [slice(None)] * len(a.shape)
 # slice of 1 element in the selected axis for the shape creation
@@ -399,7 +399,7 @@
 axis_slice_end = [slice(None)] * len(a.shape)
 # select the last elements in the selected axis
 axis_slice_end[axis] = slice(-1, None)
-rec.wait()
+rec.Wait()
 # diff logic
 ret.lloc[axis_slice_end] = (
 recv_data.reshape(ret.lloc[axis_slice_end].shape) - ret.lloc[axis_slice_end]
12 changes: 0 additions & 12 deletions heat/core/communication.py
@@ -1103,18 +1103,6 @@ def Wait(self, status=None):
 self.recvbuf = self.recvbuf.permute(self.permutation)
 self.tensor.copy_(self.recvbuf)

-def wait(self, status=None):
-self.handle.wait(status)
-if (
-self.tensor is not None
-and isinstance(self.tensor, torch.Tensor)
-and self.tensor.is_cuda
-and not CUDA_AWARE_MPI
-):
-if self.permutation is not None:
-self.recvbuf = self.recvbuf.permute(self.permutation)
-self.tensor.copy_(self.recvbuf)
-
 def __getattr__(self, name):
 """
 Default pass-through for the communicator methods.
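The deleted lowercase wait() duplicated the capitalized Wait() defined directly above it (the tail of that method is visible in the context lines). For reference, a sketch of the surviving method, reconstructed from the deleted body with the handle call capitalized; it is a method of heat's MPIRequest wrapper, whose handle, tensor, recvbuf and permutation attributes are set when the request is created, and the actual code in heat/core/communication.py may differ in detail:

def Wait(self, status=None):
    # complete the wrapped mpi4py request
    self.handle.Wait(status)
    # when receiving into a CUDA tensor without CUDA-aware MPI, the data arrived
    # in a host-side staging buffer; undo any axis permutation and copy it back
    if (
        self.tensor is not None
        and isinstance(self.tensor, torch.Tensor)
        and self.tensor.is_cuda
        and not CUDA_AWARE_MPI
    ):
        if self.permutation is not None:
            self.recvbuf = self.recvbuf.permute(self.permutation)
        self.tensor.copy_(self.recvbuf)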
4 changes: 2 additions & 2 deletions heat/core/dndarray.py
@@ -353,7 +353,7 @@ def get_halo(self, halo_size):
 req_list.append(self.comm.Irecv(res_next, source=self.comm.rank - 1))

 for req in req_list:
-req.wait()
+req.Wait()

 self.__halo_next = res_prev
 self.__halo_prev = res_next
@@ -2775,7 +2775,7 @@ def resplit_(self, axis=None):
 lp_arr = None
 for k in lp_keys:
 if rcv[k][0] is not None:
-rcv[k][0].wait()
+rcv[k][0].Wait()
 if lp_arr is None:
 lp_arr = rcv[k][1]
 else:
18 changes: 9 additions & 9 deletions heat/core/linalg/basics.py
@@ -364,7 +364,7 @@ def matmul(a, b, allow_resplit=False):
 if any(lshape_map[:, 0, :][:, 1] == 1):
 a_d1_1s_flag = True

-index_map_comm.wait()
+index_map_comm.Wait()
 for pr in range(a.comm.size):
 start0 = index_map[pr, 0, 0, 0].item()
 stop0 = index_map[pr, 0, 0, 1].item()
@@ -382,7 +382,7 @@
 a_block_map[pr, dim0, dim1] = torch.tensor(
 (dim0 * mB, dim1 * kB), dtype=torch.int, device=a._DNDarray__array.device
 )
-rem_map_comm.wait()
+rem_map_comm.Wait()
 if b.split == 0:
 # the blocks are shifted in the 2nd dimension of A for as many remainders
 # there are between the blocks in the first dim of B
@@ -440,7 +440,7 @@
 b_block_map[:, cnt:, :, 0] += 1

 # work loop: loop over all processes (also will incorporate the remainder calculations)
-c_wait.wait()
+c_wait.Wait()

 if split_0_flag:
 # need to send b here and not a
@@ -484,7 +484,7 @@

 # receive the data from the last loop and do the calculation with that
 if pr != 0:
-req[pr - 1].wait()
+req[pr - 1].Wait()
 # after receiving the last loop's bcast
 __mm_c_block_setter(
 b_proc=pr - 1,
@@ -518,7 +518,7 @@

 # need to wait if its the last loop, also need to collect the remainders
 if pr == b.comm.size - 1:
-req[pr].wait()
+req[pr].Wait()
 __mm_c_block_setter(
 b_proc=pr,
 a_proc=a.comm.rank,
@@ -610,7 +610,7 @@
 # receive the data from the last loop and do the calculation with that
 if pr != 0:
 # after receiving the last loop's bcast
-req[pr - 1].wait()
+req[pr - 1].Wait()
 __mm_c_block_setter(
 a_proc=pr - 1,
 b_proc=b.comm.rank,
@@ -645,7 +645,7 @@

 # need to wait if its the last loop, also need to collect the remainders
 if pr == b.comm.size - 1:
-req[pr].wait()
+req[pr].Wait()
 __mm_c_block_setter(
 a_proc=pr,
 b_proc=a.comm.rank,
@@ -706,7 +706,7 @@

 # receive the data from the last loop and do the calculation with that
 if pr != 0:
-req[pr - 1].wait()
+req[pr - 1].Wait()
 # after receiving the last loop's bcast
 st0 = index_map[pr - 1, 0, 0, 0].item()
 sp0 = index_map[pr - 1, 0, 0, 1].item() + 1
@@ -717,7 +717,7 @@
 del b_lp_data[pr - 1]

 if pr == b.comm.size - 1:
-req[pr].wait()
+req[pr].Wait()
 st0 = index_map[pr, 0, 0, 0].item()
 sp0 = index_map[pr, 0, 0, 1].item() + 1
 st1 = index_map[pr, 1, 1, 0].item()
8 changes: 4 additions & 4 deletions heat/core/linalg/qr.py
@@ -671,7 +671,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
 if col in q_dict_waits.keys():
 for key in q_dict_waits[col].keys():
 new_key = q_dict_waits[col][key][3] + key + "e"
-q_dict_waits[col][key][0][1].wait()
+q_dict_waits[col][key][0][1].Wait()
 q_dict[col][new_key] = [
 q_dict_waits[col][key][0][0],
 q_dict_waits[col][key][1].wait(),
@@ -728,7 +728,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
 for pr in range(diag_process, active_procs[-1] + 1):
 if local_merge_q[pr][1] is not None:
 # receive q from the other processes
-local_merge_q[pr][1].wait()
+local_merge_q[pr][1].Wait()
 if rank in active_procs:
 sum_row = sum(q0_tiles.tile_rows_per_process[:pr])
 end_row = q0_tiles.tile_rows_per_process[pr] + sum_row
@@ -790,7 +790,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
 )
 for ind in qi_mult[qi_col]:
 if global_merge_dict[ind][1] is not None:
-global_merge_dict[ind][1].wait()
+global_merge_dict[ind][1].Wait()
 lp_q = global_merge_dict[ind][0]
 if mult_qi_col.shape[1] < lp_q.shape[1]:
 new_mult = torch.zeros(
@@ -810,7 +810,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
 q0_tiles.arr.lloc[:, write_inds[2] : write_inds[2] + hold.shape[1]] = hold
 else:
 for ind in merge_dict_keys:
-global_merge_dict[ind][1].wait()
+global_merge_dict[ind][1].Wait()
 if col in q_dict.keys():
 del q_dict[col]

12 changes: 6 additions & 6 deletions heat/core/manipulations.py
@@ -317,8 +317,8 @@ def concatenate(arrays, axis=0):
 chunk_map[arr0.comm.rank, i] = chk[i].stop - chk[i].start
 chunk_map_comm = arr0.comm.Iallreduce(MPI.IN_PLACE, chunk_map, MPI.SUM)

-lshape_map_comm.wait()
-chunk_map_comm.wait()
+lshape_map_comm.Wait()
+chunk_map_comm.Wait()

 if s0 is not None:
 send_slice = [slice(None)] * arr0.ndim
@@ -341,7 +341,7 @@
 tag=pr + arr0.comm.size + spr,
 )
 arr0._DNDarray__array = arr0.lloc[keep_slice].clone()
-send.wait()
+send.Wait()
 for pr in range(spr):
 snt = abs((chunk_map[pr, s0] - lshape_map[0, pr, s0]).item())
 snt = (
@@ -388,7 +388,7 @@
 tag=pr + arr1.comm.size + spr,
 )
 arr1._DNDarray__array = arr1.lloc[keep_slice].clone()
-send.wait()
+send.Wait()
 for pr in range(arr1.comm.size - 1, spr, -1):
 snt = abs((chunk_map[pr, axis] - lshape_map[1, pr, axis]).item())
 snt = (
@@ -2010,9 +2010,9 @@ def resplit(arr, axis=None):
 buf = torch.zeros_like(new_tiles[key])
 rcv_waits[key] = [arr.comm.Irecv(buf=buf, source=spr, tag=spr), buf]
 for w in waits:
-w.wait()
+w.Wait()
 for k in rcv_waits.keys():
-rcv_waits[k][0].wait()
+rcv_waits[k][0].Wait()
 new_tiles[k] = rcv_waits[k][1]

 return new_arr
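Taken together, every call-site change above is the same mechanical rewrite of req.wait() to req.Wait() on heat's request objects. A minimal usage sketch of the surviving API, modeled on the Irecv/Wait pairs in resplit() above; it assumes an MPI launch with at least two processes (e.g. mpirun -n 2), and the Isend signature and tag value are illustrative assumptions rather than taken from this commit:

import torch
import heat as ht

x = ht.zeros((4,), split=0)   # any distributed DNDarray, used here only for its communicator
comm = x.comm
rank, size = comm.rank, comm.size

if size >= 2:
    if rank == 0:
        req = comm.Isend(torch.arange(4, dtype=torch.float32), dest=1, tag=1)
        req.Wait()   # capitalized: the only completion method left on MPIRequest
    elif rank == 1:
        buf = torch.zeros(4)
        req = comm.Irecv(buf=buf, source=0, tag=1)
        req.Wait()   # returns once the data is in buf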