I just started to look into parallel programming in JAX, with very little knowledge about how parallel programming works on GPUs. Basically, my goal is to write a program that runs in parallel on two GPUs. The main features of this program are: (1) each process executes a different calculation; (2) these two processes are executed in synchrony; and (3) these two processes occasionally communicate with each other. For example, a simple pseudo-code with these features is the following:
```python
# On GPU 0:
sum_0 = 0
for i in range(1000):
    sum_0 += 1

# On GPU 1:
sum_1 = 0
for j in range(1000):
    sum_1 += 2

# Need to make sure i == j at any given moment

# Communicate:
while (program is running):
    if i % 10 == 0:
        # exchange sum_0 and sum_1
        sum_tmp = sum_1
        sum_1 = sum_0    # on GPU 1
        sum_0 = sum_tmp  # on GPU 0
```
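The pseudo-code leaves the exact ordering between the two loops and the exchange open. As a point of reference, here is a single-process Python sketch of one possible interpretation. It assumes both counters advance together (so `i == j` always) and that the swap happens after every 10th step rather than at `i == 0`; the function name `reference` is hypothetical.

```python
def reference(n_steps=1000, exchange_every=10):
    """Sequential model of two lockstep 'devices' that swap sums periodically."""
    sum_0, sum_1 = 0, 0
    for step in range(1, n_steps + 1):  # one shared counter, so i == j always
        sum_0 += 1  # work done on GPU 0
        sum_1 += 2  # work done on GPU 1
        if step % exchange_every == 0:
            # exchange sum_0 and sum_1
            sum_0, sum_1 = sum_1, sum_0
    return sum_0, sum_1


print(reference())  # (1500, 1500)
```

After each pair of exchanges the two sums are equal again, which makes this a convenient value to compare a parallel version against.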
I'm wondering if there is a way to realize this program (or more generally, programs with these features) in JAX?
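One possible way to express this in JAX is the SPMD style: a single function is mapped over the devices with `jax.pmap`, and inside it `lax.axis_index` tells each copy which device it is on, `lax.switch` selects that device's calculation, `lax.fori_loop` provides the shared loop, and `lax.ppermute` performs the exchange. This is a minimal sketch under assumptions, not a definitive implementation: the axis name `"dev"`, the helpers `step_dev0`/`step_dev1`, and the choice to exchange after every 10 local steps are hypothetical. It falls back to `jax.vmap` (which also accepts `axis_index` and `ppermute` with an axis name) when fewer than two devices are available, so it can be tried on one machine.

```python
import jax
import jax.numpy as jnp
from jax import lax

AXIS = "dev"          # hypothetical name for the device axis
N_DEV = 2             # number of participating devices
N_STEPS = 1000        # total loop iterations per device
EXCHANGE_EVERY = 10   # communicate after every 10 local steps (assumption)


def step_dev0(x):
    return x + 1  # hypothetical calculation for device 0


def step_dev1(x):
    return x + 2  # hypothetical, different calculation for device 1


def per_device_program(x):
    # Index of the device this copy of the program runs on (0 or 1).
    dev = lax.axis_index(AXIS)

    def local_step(_, s):
        # Each device selects its own branch, so the two run different code.
        return lax.switch(dev, [step_dev0, step_dev1], s)

    def block(_, s):
        # EXCHANGE_EVERY local steps with no communication...
        s = lax.fori_loop(0, EXCHANGE_EVERY, local_step, s)
        # ...then swap values: device 0 sends to 1 and device 1 sends to 0.
        # Every device takes part in this collective, so it acts as a sync point.
        return lax.ppermute(s, AXIS, perm=[(0, 1), (1, 0)])

    return lax.fori_loop(0, N_STEPS // EXCHANGE_EVERY, block, x)


if jax.local_device_count() >= N_DEV:
    # One program instance per device (e.g. two GPUs).
    run = jax.pmap(per_device_program, axis_name=AXIS)
else:
    # Single-device fallback for testing; vmap emulates the device axis.
    run = jax.vmap(per_device_program, axis_name=AXIS)

result = run(jnp.zeros(N_DEV, dtype=jnp.int32))
print(result)  # [1500 1500]
```

Because both devices execute the same loop and meet at every `ppermute`, there is no separate counter to keep equal: the iteration index is shared by construction. Under the `vmap` fallback, `lax.switch` with a per-device index evaluates both branches and selects the result, so it only emulates the two-GPU behavior. Newer JAX versions also provide `shard_map`, which uses the same per-device programming model and may be preferable to `pmap`.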