Skip to content

Commit

Permalink
Allow sharing of tensors whose size is unknown to some parties (#184)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #184

When a party wants to secret-share a tensor, the CrypTen API currently assumes that the other parties know what the size of that tensor will be. This may be problematic, for example, if a party wants to share a data tensor whose size is unknown a priori. To make matters worse, if a party makes an incorrect assumption about the size of the tensor they will receive, the PRZS initialization does not work correctly and the parties may end up with secret shares of a tensor that do not match in size (which, somewhat surprisingly, does not always appear to be caught by Gloo during computations).

The diff adds a `broadcast_size` option to the construction of secret-shared tensors. If set to `True`, the `src` party will broadcast the size of the tensor to all other parties so that they can execute the PRZS initialization correctly.

We could also check the size of the secret shares in all initializations, but I did not do that because it would add a communication round (somewhat defeating the purpose of doing PRZS initialization). This means that the failure case described above still exists; I added documentation describing it.

Reviewed By: knottb

Differential Revision: D24381193

fbshipit-source-id: 0b2ae19df973dc777f546944adc1e7676d3ef621
  • Loading branch information
lvdmaaten authored and facebook-github-bot committed Oct 21, 2020
1 parent 1cd3894 commit 0169987
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 27 deletions.
4 changes: 2 additions & 2 deletions crypten/communicator/distributed_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def send_obj(self, obj, dst, group=None):

buf = pickle.dumps(obj)
size = torch.tensor(len(buf), dtype=torch.int32)
arr = torch.from_numpy(numpy.frombuffer(buf, dtype=numpy.int8))
arr = torch.from_numpy(numpy.copy(numpy.frombuffer(buf, dtype=numpy.int8)))

r0 = dist.isend(size, dst=dst, group=group)
r1 = dist.isend(arr, dst=dst, group=group)
Expand Down Expand Up @@ -296,7 +296,7 @@ def broadcast_obj(self, obj, src, group=None):
assert obj is not None, "src party must provide obj for broadcast"
buf = pickle.dumps(obj)
size = torch.tensor(len(buf), dtype=torch.int32)
arr = torch.from_numpy(numpy.frombuffer(buf, dtype=numpy.int8))
arr = torch.from_numpy(numpy.copy(numpy.frombuffer(buf, dtype=numpy.int8)))

dist.broadcast(size, src, group=group)
dist.broadcast(arr, src, group=group)
Expand Down
22 changes: 21 additions & 1 deletion crypten/mpc/mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,25 @@ def __init__(self, *args):

class MPCTensor(CrypTensor):
def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs):
"""
Creates the shared tensor from the input `tensor` provided by party `src`.
The `ptype` defines the type of sharing used (default: arithmetic).
The other parties can specify a `tensor` or `size` to determine the size
of the shared tensor object to create. In this case, all parties must
specify the same (tensor) size to prevent the party's shares from varying
in size, which leads to undefined behavior.
Alternatively, the parties can set `broadcast_size` to `True` to have the
`src` party broadcast the correct size. The parties who do not know the
tensor size beforehand can provide an empty tensor as input. This is
guaranteed to produce correct behavior but requires an additional
communication round.
The parties can also set the `precision` and `device` for their share of
the tensor. If `device` is unspecified, it is set to `tensor.device`.
"""

# take required_grad from kwargs, input tensor, or set to False:
default = tensor.requires_grad if torch.is_tensor(tensor) else False
requires_grad = kwargs.pop("requires_grad", default)
Expand All @@ -139,12 +158,13 @@ def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs)
if tensor is None:
return # TODO: Can we remove this and do staticmethods differently?

# if device is unspecified, try and get it from tensor:
if device is None and hasattr(tensor, "device"):
device = tensor.device

# create the MPCTensor:
tensor_type = ptype.to_tensor()
self._tensor = tensor_type(tensor, device=device, *args, **kwargs)
self._tensor = tensor_type(tensor=tensor, device=device, *args, **kwargs)
self.ptype = ptype

@staticmethod
Expand Down
60 changes: 50 additions & 10 deletions crypten/mpc/primitives/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,57 @@ class ArithmeticSharedTensor(object):
"""

# constructors:
def __init__(self, tensor=None, size=None, precision=None, src=0, device=None):
def __init__(
self,
tensor=None,
size=None,
broadcast_size=False,
precision=None,
src=0,
device=None,
):
"""
Creates the shared tensor from the input `tensor` provided by party `src`.
The other parties can specify a `tensor` or `size` to determine the size
of the shared tensor object to create. In this case, all parties must
specify the same (tensor) size to prevent the party's shares from varying
in size, which leads to undefined behavior.
Alternatively, the parties can set `broadcast_size` to `True` to have the
`src` party broadcast the correct size. The parties who do not know the
tensor size beforehand can provide an empty tensor as input. This is
guaranteed to produce correct behavior but requires an additional
communication round.
The parties can also set the `precision` and `device` for their share of
the tensor. If `device` is unspecified, it is set to `tensor.device`.
"""

# do nothing if source is sentinel:
if src == SENTINEL:
return

# assertions on inputs:
assert (
isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
), "invalid tensor source"
), "specified source party does not exist"
if self.rank == src:
assert tensor is not None, "source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "source of data tensor must match source of encryption"
if not broadcast_size:
assert (
tensor is not None or size is not None
), "must specify tensor or size, or set broadcast_size"

if device is None and hasattr(tensor, "device"):
# if device is unspecified, try and get it from tensor:
if device is None and tensor is not None and hasattr(tensor, "device"):
device = tensor.device

# encode the input tensor:
self.encoder = FixedPointEncoder(precision_bits=precision)
if tensor is not None:
if is_int_tensor(tensor) and precision != 0:
Expand All @@ -54,14 +95,13 @@ def __init__(self, tensor=None, size=None, precision=None, src=0, device=None):
tensor = tensor.to(device=device)
size = tensor.size()

# Generate psuedo-random sharing of zero (PRZS) and add source's tensor
# if other parties do not know tensor's size, broadcast the size:
if broadcast_size:
size = comm.get().broadcast_obj(size, src)

# generate pseudo-random zero sharing (PRZS) and add source's tensor:
self.share = ArithmeticSharedTensor.PRZS(size, device=device).share
if self.rank == src:
assert tensor is not None, "Source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "Source of data tensor must match source of encryption"
self.share += tensor

@property
Expand Down Expand Up @@ -246,7 +286,7 @@ def get_plain_text(self, dst=None):
def _arithmetic_function_(self, y, op, *args, **kwargs):
return self._arithmetic_function(y, op, inplace=True, *args, **kwargs)

def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):
def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C901
assert op in [
"add",
"sub",
Expand Down
55 changes: 44 additions & 11 deletions crypten/mpc/primitives/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,64 @@ class BinarySharedTensor(object):
where n is the number of parties present in the protocol (world_size).
"""

def __init__(self, tensor=None, size=None, src=0, device=None):
def __init__(
self, tensor=None, size=None, broadcast_size=False, src=0, device=None
):
"""
Creates the shared tensor from the input `tensor` provided by party `src`.
The other parties can specify a `tensor` or `size` to determine the size
of the shared tensor object to create. In this case, all parties must
specify the same (tensor) size to prevent the party's shares from varying
in size, which leads to undefined behavior.
Alternatively, the parties can set `broadcast_size` to `True` to have the
`src` party broadcast the correct size. The parties who do not know the
tensor size beforehand can provide an empty tensor as input. This is
guaranteed to produce correct behavior but requires an additional
communication round.
The parties can also set the `precision` and `device` for their share of
the tensor. If `device` is unspecified, it is set to `tensor.device`.
"""

# do nothing if source is sentinel:
if src == SENTINEL:
return

# assertions on inputs:
assert (
isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
), "invalid tensor source"

if device is None and hasattr(tensor, "device"):
), "specified source party does not exist"
if self.rank == src:
assert tensor is not None, "source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "source of data tensor must match source of encryption"
if not broadcast_size:
assert (
tensor is not None or size is not None
), "must specify tensor or size, or set broadcast_size"

# if device is unspecified, try and get it from tensor:
if device is None and tensor is not None and hasattr(tensor, "device"):
device = tensor.device

# Assume 0 bits of precision unless encoder is set outside of init
# assume zero bits of precision unless encoder is set outside of init:
self.encoder = FixedPointEncoder(precision_bits=0)
if tensor is not None:
tensor = self.encoder.encode(tensor)
tensor = tensor.to(device=device)
size = tensor.size()

# Generate Psuedo-random Sharing of Zero and add source's tensor
# if other parties do not know tensor's size, broadcast the size:
if broadcast_size:
size = comm.get().broadcast_obj(size, src)

# generate pseudo-random zero sharing (PRZS) and add source's tensor:
self.share = BinarySharedTensor.PRZS(size, device=device).share
if self.rank == src:
assert tensor is not None, "Source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "Source of data tensor must match source of encryption"
self.share ^= tensor

@staticmethod
Expand Down
17 changes: 16 additions & 1 deletion test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,30 @@ def test_encrypt_decrypt(self):
(5, 3, 32, 32),
]
for size in sizes:

# encryption and decryption without source:
reference = get_random_test_tensor(size=size, is_float=True)
encrypted_tensor = ArithmeticSharedTensor(reference)
self._check(encrypted_tensor, reference, "en/decryption failed")

for dst in range(self.world_size):
self._check(
encrypted_tensor, reference, "en/decryption failed", dst=dst
)

# encryption and decryption with source:
for src in range(self.world_size):
input_tensor = reference if src == self.rank else []
encrypted_tensor = ArithmeticSharedTensor(
input_tensor, src=src, broadcast_size=True
)
for dst in range(self.world_size):
self._check(
encrypted_tensor,
reference,
"en/decryption with broadcast_size failed",
dst=dst,
)

def test_arithmetic(self):
"""Tests arithmetic functions on encrypted tensor."""
arithmetic_functions = ["add", "add_", "sub", "sub_", "mul", "mul_"]
Expand Down
17 changes: 16 additions & 1 deletion test/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,30 @@ def test_encrypt_decrypt(self):
(5, 3, 32, 32),
]
for size in sizes:

# encryption and decryption without source:
reference = get_random_test_tensor(size=size, is_float=False)
encrypted_tensor = BinarySharedTensor(reference)
self._check(encrypted_tensor, reference, "en/decryption failed")

for dst in range(self.world_size):
self._check(
encrypted_tensor, reference, "en/decryption failed", dst=dst
)

# encryption and decryption with source:
for src in range(self.world_size):
input_tensor = reference if src == self.rank else []
encrypted_tensor = BinarySharedTensor(
input_tensor, src=src, broadcast_size=True
)
for dst in range(self.world_size):
self._check(
encrypted_tensor,
reference,
"en/decryption with broadcast_size failed",
dst=dst,
)

def test_transpose(self):
sizes = [
(1,),
Expand Down
16 changes: 15 additions & 1 deletion test/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def test_encrypt_decrypt(self):
(5, 3, 32, 32),
]
for size in sizes:

# encryption and decryption without source:
reference = self._get_random_test_tensor(size=size, is_float=True)
encrypted_tensor = MPCTensor(reference)
self._check(encrypted_tensor, reference, "en/decryption failed")
Expand All @@ -169,7 +171,19 @@ def test_encrypt_decrypt(self):
encrypted_tensor, reference, "en/decryption failed", dst=dst
)

# Test new()
# encryption and decryption with source:
for src in range(self.world_size):
input_tensor = reference if src == self.rank else []
encrypted_tensor = MPCTensor(input_tensor, src=src, broadcast_size=True)
for dst in range(self.world_size):
self._check(
encrypted_tensor,
reference,
"en/decryption with broadcast_size failed",
dst=dst,
)

# test creation via new() function:
encrypted_tensor2 = encrypted_tensor.new(reference)
self.assertIsInstance(
encrypted_tensor2, MPCTensor, "new() returns incorrect type"
Expand Down

0 comments on commit 0169987

Please sign in to comment.