Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add process set support for some remaining functions [TensorFlow] #3054

Merged
merged 1 commit into the base branch on Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions horovod/tensorflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,47 @@
from horovod.tensorflow.mpi_ops import allgather, broadcast
from horovod.tensorflow.mpi_ops import rank, size
from horovod.tensorflow.util import _cache, _executing_eagerly, _make_subgraph
from horovod.common.process_sets import ProcessSet, global_process_set


@_cache
def _make_broadcast_group_fn():
    """Build (and cache) a function that broadcasts a group of variables.

    The returned callable has signature
    ``broadcast_group(variables, root_rank, process_set)`` and assigns to each
    variable the value broadcast from ``root_rank`` within ``process_set``.

    NOTE(review): the scraped diff interleaved the pre- and post-change lines
    of this function; this is the reconstructed post-merge version, in which
    every ``broadcast`` call forwards ``process_set``.
    """
    if _executing_eagerly():
        # Eager mode will parallelize independent control flow
        def broadcast_group(variables, root_rank, process_set: ProcessSet):
            for var in variables:
                var.assign(broadcast(var, root_rank, process_set=process_set))

        return _make_subgraph(broadcast_group)
    else:
        # Graph mode requires an Op
        def broadcast_group(variables, root_rank, process_set: ProcessSet):
            return tf.group(*[var.assign(broadcast(var, root_rank, process_set=process_set))
                              for var in variables])

        return broadcast_group


def broadcast_variables(variables, root_rank, process_set=global_process_set):
    """
    Broadcasts variables from root rank to all other processes
    in a process set (defaults to all Horovod processes).

    Arguments:
        variables: variables for broadcast
        root_rank: rank of the process from which global variables will be broadcasted
                   to all other processes.
        process_set: Process set object to limit this operation to a subset of
            Horovod processes. Default is the global process set.
    """
    # NOTE(review): reconstructed post-merge version of this function — the
    # scraped diff interleaved old and new lines here.
    broadcast_group = _make_broadcast_group_fn()
    return broadcast_group(variables, root_rank, process_set)


def broadcast_object(obj, root_rank=0, session=None, name=None):
def broadcast_object(obj, root_rank=0, session=None, name=None, process_set=global_process_set):
"""
Serializes and broadcasts an object from root rank to all other processes.
Serializes and broadcasts an object from root rank to all other processes
in a process set (defaults to all Horovod processes).

Arguments:
obj: An object capable of being serialized without losing any context.
Expand All @@ -67,6 +73,8 @@ def broadcast_object(obj, root_rank=0, session=None, name=None):
session: Session for TensorFlow v1 compatibility.
name: Optional name to use during broadcast, will default to the class
type.
process_set: Process set object to limit this operation to a subset of
Horovod processes. Default is the global process set.
Returns:
The object that was broadcast from the `root_rank`.
"""
Expand All @@ -85,13 +93,13 @@ def to_numpy(v):
cloudpickle.dump(obj, b)
t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8)
sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32)
to_numpy(broadcast(sz, root_rank, name + '.sz'))
to_numpy(broadcast(sz, root_rank, name + '.sz', process_set=process_set))
else:
sz = tf.convert_to_tensor([0], dtype=tf.int32)
sz = to_numpy(broadcast(sz, root_rank, name + '.sz'))
sz = to_numpy(broadcast(sz, root_rank, name + '.sz', process_set=process_set))
t = tf.zeros(sz.tolist()[0], dtype=tf.uint8)

t = to_numpy(broadcast(t, root_rank, name + '.t'))
t = to_numpy(broadcast(t, root_rank, name + '.t', process_set=process_set))

if rank() != root_rank:
buf = io.BytesIO(t.tobytes())
Expand All @@ -100,14 +108,14 @@ def to_numpy(v):
return obj


def broadcast_object_fn(root_rank=0, session=None, name=None):
def broadcast_object_fn(root_rank=0, session=None, name=None, process_set=global_process_set):
name = name or 'broadcast_object_fn'

sz = tf.placeholder(tf.int32, [1], name='bcast_object_size')
bcast_size = broadcast(sz, root_rank, name + '.sz')
bcast_size = broadcast(sz, root_rank, name + '.sz', process_set=process_set)

t = tf.placeholder(tf.uint8, [None], name='bcast_object_data')
bcast_data = broadcast(t, root_rank, name + '.t')
bcast_data = broadcast(t, root_rank, name + '.t', process_set=process_set)

session = session or ops.get_default_session()

Expand All @@ -133,7 +141,7 @@ def _bcast(obj):
return _bcast


def allgather_object(obj, session=None, name=None):
def allgather_object(obj, session=None, name=None, process_set=global_process_set):
"""
Serializes and allgathers an object from all other processes.

Expand All @@ -142,6 +150,8 @@ def allgather_object(obj, session=None, name=None):
session: Session for TensorFlow v1 compatibility.
name: Optional name to use during allgather, will default to the class
type.
process_set: Process set object to limit this operation to a subset of
Horovod processes. Default is the global process set.

Returns:
The list of objects that were allgathered across all ranks.
Expand All @@ -166,12 +176,12 @@ def to_numpy(v):
t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8)
sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32)

sizes = to_numpy(allgather(sz, name=name + '.sz'))
gathered = to_numpy(allgather(t, name=name + '.t'))
sizes = to_numpy(allgather(sz, name=name + '.sz', process_set=process_set))
gathered = to_numpy(allgather(t, name=name + '.t', process_set=process_set))

def select(i):
start = sum(sizes[:i])
end = start + sizes[i]
return gathered[start:end]

return [load(select(i)) for i in range(size())]
return [load(select(i)) for i in range(process_set.size())]
174 changes: 174 additions & 0 deletions test/parallel/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2335,6 +2335,44 @@ def test_horovod_broadcast_gpu_process_sets(self):
hvd.remove_process_set(odd_set)
hvd.remove_process_set(even_set)

def test_broadcast_variables_process_sets(self):
    """Within each process set, hvd.broadcast_variables must overwrite every
    member's variable with the set-local root's value."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    if hvd.ccl_built():
        self.skipTest("Multiple process sets currently do not support CCL.")

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    # Partition the world into even- and odd-ranked process sets.
    even_ranks = list(range(0, size, 2))
    odd_ranks = list(range(1, size, 2))

    even_set = hvd.add_process_set(even_ranks)
    odd_set = hvd.add_process_set(odd_ranks)

    if rank % 2 == 0:
        set_ranks, this_set = even_ranks, even_set
    else:
        set_ranks, this_set = odd_ranks, odd_set
    root_rank = set_ranks[0]

    with tf.device("/cpu:0"):
        var = tf.Variable(initial_value=[rank], dtype=tf.int32)
        if not hvd._executing_eagerly():
            # TF1 graph mode: variables must be initialized explicitly.
            self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(
            hvd.broadcast_variables([var], root_rank=root_rank, process_set=this_set))
        self.assertListEqual(list(self.evaluate(var)), [root_rank])

    hvd.remove_process_set(odd_set)
    hvd.remove_process_set(even_set)

def test_horovod_broadcast_error(self):
"""Test that the broadcast returns an error if any dimension besides
Expand Down Expand Up @@ -3569,6 +3607,49 @@ def test_broadcast_object(self):
obj = hvd.broadcast_object(obj, root_rank=0)
self.assertDictEqual(obj, expected_obj)

def test_broadcast_object_process_sets(self):
    """ This should best be tested with more than two Horovod processes """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    if hvd.ccl_built():
        self.skipTest("Multiple process sets currently do not support CCL.")

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    # Split ranks into an even-rank set and an odd-rank set.
    even_ranks = list(range(0, size, 2))
    odd_ranks = list(range(1, size, 2))
    even_set = hvd.add_process_set(even_ranks)
    odd_set = hvd.add_process_set(odd_ranks)
    if rank % 2 == 0:
        set_ranks, this_set = even_ranks, even_set
    else:
        set_ranks, this_set = odd_ranks, odd_set
    root_rank = set_ranks[0]

    with tf.device("/cpu:0"):
        # Each set broadcasts a distinct payload so cross-set leakage
        # would be detected.
        if this_set == even_set:
            expected_obj = {
                'even': 123,
                0: [1, 2]
            }
        else:
            expected_obj = {
                'odd': 456,
                1: [1, 2, 3, 4]
            }
        obj = expected_obj if hvd.rank() == root_rank else {}

        obj = hvd.broadcast_object(obj, root_rank=root_rank, process_set=this_set)
        self.assertDictEqual(obj, expected_obj)

    hvd.remove_process_set(odd_set)
    hvd.remove_process_set(even_set)

def test_broadcast_object_fn(self):
if hvd._executing_eagerly() or _IS_TF2:
# Only for TF 1.0 in graph mode
Expand All @@ -3587,6 +3668,55 @@ def test_broadcast_object_fn(self):
obj = bcast(obj)
self.assertDictEqual(obj, expected_obj)

def test_broadcast_object_fn_process_sets(self):
    """ This should best be tested with more than two Horovod processes """
    if hvd._executing_eagerly() or _IS_TF2:
        # Only for TF 1.0 in graph mode
        return

    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    if hvd.ccl_built():
        self.skipTest("Multiple process sets currently do not support CCL.")

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    # Split ranks into an even-rank set and an odd-rank set.
    even_ranks = list(range(0, size, 2))
    odd_ranks = list(range(1, size, 2))
    even_set = hvd.add_process_set(even_ranks)
    odd_set = hvd.add_process_set(odd_ranks)
    if rank % 2 == 0:
        set_ranks, this_set = even_ranks, even_set
    else:
        set_ranks, this_set = odd_ranks, odd_set
    root_rank = set_ranks[0]

    with tf.device("/cpu:0"):
        # Each set broadcasts a distinct payload so cross-set leakage
        # would be detected.
        if this_set == even_set:
            expected_obj = {
                'even': 123,
                0: [1, 2]
            }
        else:
            expected_obj = {
                'odd': 456,
                1: [1, 2, 3, 4]
            }
        obj = expected_obj if hvd.rank() == root_rank else {}

        bcast = hvd.broadcast_object_fn(root_rank=root_rank, process_set=this_set)
        self.assertDictEqual(bcast(obj), expected_obj)

    hvd.remove_process_set(odd_set)
    hvd.remove_process_set(even_set)


def test_allgather_object(self):
hvd.init()

Expand All @@ -3604,6 +3734,50 @@ def test_allgather_object(self):
self.assertEqual(len(results), hvd.size())
self.assertListEqual(results, expected)


def test_allgather_object_process_sets(self):
    """ This should best be tested with more than two Horovod processes """
    hvd.init()

    rank = hvd.rank()
    size = hvd.size()

    if hvd.ccl_built():
        self.skipTest("Multiple process sets currently do not support CCL.")

    # This test does not apply if there is only one worker.
    if size == 1:
        self.skipTest("Only one worker available")

    # Split ranks into an even-rank set and an odd-rank set.
    even_ranks = list(range(0, size, 2))
    odd_ranks = list(range(1, size, 2))
    even_set = hvd.add_process_set(even_ranks)
    odd_set = hvd.add_process_set(odd_ranks)
    if rank % 2 == 0:
        set_ranks, this_set = even_ranks, even_set
    else:
        set_ranks, this_set = odd_ranks, odd_set

    with tf.device("/cpu:0"):
        # The second member of each set contributes an extra, set-specific key.
        special = 42 if this_set == even_set else 23
        payload = {'metric_val_1': hvd.rank()}
        if this_set.rank() == 1:
            payload['metric_val_2'] = special

        results = hvd.allgather_object(payload, process_set=this_set)

        expected = [{'metric_val_1': rk} for rk in set_ranks]
        if this_set.size() > 1:
            expected[1]['metric_val_2'] = special

        self.assertEqual(len(results), this_set.size())
        self.assertListEqual(results, expected)

    hvd.remove_process_set(odd_set)
    hvd.remove_process_set(even_set)

def test_elastic_state(self):
if not hvd._executing_eagerly() and _IS_TF2:
# Only support TF 2.0 in eager mode
Expand Down