From 819112e856b38ea1ba2057da9697d3d92e87b68d Mon Sep 17 00:00:00 2001 From: Laurens van der Maaten Date: Sat, 17 Oct 2020 10:04:32 -0700 Subject: [PATCH] Allow sharing of tensors whose size is unknown to some parties Summary: When a party wants to secret-share a tensor, the CrypTen API currently assumes that the other parties know what the size of that tensor will be. This may be problematic, for example, if a party wants to share a data tensor whose size is unknown a priori. To make matters worse, if a party makes an incorrect assumption about the size of the tensor they will receive, the PRZS initialization does not work correctly and the parties may end up with secret shares of a tensor that do not match in size (which, somewhat surprisingly, does not always appear to be caught by Gloo during computations). The diff adds an `broadcast_size` option in the construction of secret-shared tensors. If set to `True`, the `src` party will broadcast the size of the tensor to all other parties so that they can execute the PRZS initialization correctly. We could also add an additional check on the size of the secret shares for all initializations, but I did not do that because it would add a communication round (somewhat defeating the purpose of doing PRZS initialization). This means that the failure case described above still exists; I added documentation describing it. Differential Revision: D24381193 fbshipit-source-id: 3baf2373d40fc880f7d5741178e182c0661b46b2 --- .../communicator/distributed_communicator.py | 4 +- crypten/mpc/mpc.py | 22 ++++++- crypten/mpc/primitives/arithmetic.py | 60 +++++++++++++++---- crypten/mpc/primitives/binary.py | 55 +++++++++++++---- test/test_arithmetic.py | 17 +++++- test/test_binary.py | 17 +++++- test/test_mpc.py | 16 ++++- 7 files changed, 164 insertions(+), 27 deletions(-) diff --git a/crypten/communicator/distributed_communicator.py b/crypten/communicator/distributed_communicator.py index f9e38db2..33e09feb 100644 --- a/crypten/communicator/distributed_communicator.py +++ b/crypten/communicator/distributed_communicator.py @@ -264,7 +264,7 @@ def send_obj(self, obj, dst, group=None): buf = pickle.dumps(obj) size = torch.tensor(len(buf), dtype=torch.int32) - arr = torch.from_numpy(numpy.frombuffer(buf, dtype=numpy.int8)) + arr = torch.from_numpy(numpy.copy(numpy.frombuffer(buf, dtype=numpy.int8))) r0 = dist.isend(size, dst=dst, group=group) r1 = dist.isend(arr, dst=dst, group=group) @@ -296,7 +296,7 @@ def broadcast_obj(self, obj, src, group=None): assert obj is not None, "src party must provide obj for broadcast" buf = pickle.dumps(obj) size = torch.tensor(len(buf), dtype=torch.int32) - arr = torch.from_numpy(numpy.frombuffer(buf, dtype=numpy.int8)) + arr = torch.from_numpy(numpy.copy(numpy.frombuffer(buf, dtype=numpy.int8))) dist.broadcast(size, src, group=group) dist.broadcast(arr, src, group=group) diff --git a/crypten/mpc/mpc.py b/crypten/mpc/mpc.py index 70663700..207616e5 100644 --- a/crypten/mpc/mpc.py +++ b/crypten/mpc/mpc.py @@ -130,6 +130,25 @@ def __init__(self, *args): class MPCTensor(CrypTensor): def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs): + """ + Creates the shared tensor from the input `tensor` provided by party `src`. + The `ptype` defines the type of sharing used (default: arithmetic). + + The other parties can specify a `tensor` or `size` to determine the size + of the shared tensor object to create. In this case, all parties must + specify the same (tensor) size to prevent the party's shares from varying + in size, which leads to undefined behavior. + + Alternatively, the parties can set `broadcast_size` to `True` to have the + `src` party broadcast the correct size. The parties who do not know the + tensor size beforehand can provide an empty tensor as input. This is + guaranteed to produce correct behavior but requires an additional + communication round. + + The parties can also set the `precision` and `device` for their share of + the tensor. If `device` is unspecified, it is set to `tensor.device`. + """ + # take required_grad from kwargs, input tensor, or set to False: default = tensor.requires_grad if torch.is_tensor(tensor) else False requires_grad = kwargs.pop("requires_grad", default) @@ -139,12 +158,13 @@ def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs) if tensor is None: return # TODO: Can we remove this and do staticmethods differently? + # if device is unspecified, try and get it from tensor: if device is None and hasattr(tensor, "device"): device = tensor.device # create the MPCTensor: tensor_type = ptype.to_tensor() - self._tensor = tensor_type(tensor, device=device, *args, **kwargs) + self._tensor = tensor_type(tensor=tensor, device=device, *args, **kwargs) self.ptype = ptype @staticmethod diff --git a/crypten/mpc/primitives/arithmetic.py b/crypten/mpc/primitives/arithmetic.py index 0c53057d..7b12470a 100644 --- a/crypten/mpc/primitives/arithmetic.py +++ b/crypten/mpc/primitives/arithmetic.py @@ -36,16 +36,57 @@ class ArithmeticSharedTensor(object): """ # constructors: - def __init__(self, tensor=None, size=None, precision=None, src=0, device=None): + def __init__( + self, + tensor=None, + size=None, + broadcast_size=False, + precision=None, + src=0, + device=None, + ): + """ + Creates the shared tensor from the input `tensor` provided by party `src`. + + The other parties can specify a `tensor` or `size` to determine the size + of the shared tensor object to create. In this case, all parties must + specify the same (tensor) size to prevent the party's shares from varying + in size, which leads to undefined behavior. + + Alternatively, the parties can set `broadcast_size` to `True` to have the + `src` party broadcast the correct size. The parties who do not know the + tensor size beforehand can provide an empty tensor as input. This is + guaranteed to produce correct behavior but requires an additional + communication round. + + The parties can also set the `precision` and `device` for their share of + the tensor. If `device` is unspecified, it is set to `tensor.device`. + """ + + # do nothing if source is sentinel: if src == SENTINEL: return + + # assertions on inputs: assert ( isinstance(src, int) and src >= 0 and src < comm.get().get_world_size() - ), "invalid tensor source" + ), "specified source party does not exist" + if self.rank == src: + assert tensor is not None, "source must provide a data tensor" + if hasattr(tensor, "src"): + assert ( + tensor.src == src + ), "source of data tensor must match source of encryption" + if not broadcast_size: + assert ( + tensor is not None or size is not None + ), "must specify tensor or size, or set broadcast_size" - if device is None and hasattr(tensor, "device"): + # if device is unspecified, try and get it from tensor: + if device is None and tensor is not None and hasattr(tensor, "device"): device = tensor.device + # encode the input tensor: self.encoder = FixedPointEncoder(precision_bits=precision) if tensor is not None: if is_int_tensor(tensor) and precision != 0: @@ -54,14 +95,13 @@ def __init__(self, tensor=None, size=None, precision=None, src=0, device=None): tensor = tensor.to(device=device) size = tensor.size() - # Generate psuedo-random sharing of zero (PRZS) and add source's tensor + # if other parties do not know tensor's size, broadcast the size: + if broadcast_size: + size = comm.get().broadcast_obj(size, src) + + # generate pseudo-random zero sharing (PRZS) and add source's tensor: self.share = ArithmeticSharedTensor.PRZS(size, device=device).share if self.rank == src: - assert tensor is not None, "Source must provide a data tensor" - if hasattr(tensor, "src"): - assert ( - tensor.src == src - ), "Source of data tensor must match source of encryption" self.share += tensor @property @@ -246,7 +286,7 @@ def get_plain_text(self, dst=None): def _arithmetic_function_(self, y, op, *args, **kwargs): return self._arithmetic_function(y, op, inplace=True, *args, **kwargs) - def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): + def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C901 assert op in [ "add", "sub", diff --git a/crypten/mpc/primitives/binary.py b/crypten/mpc/primitives/binary.py index b747ed32..9078841c 100644 --- a/crypten/mpc/primitives/binary.py +++ b/crypten/mpc/primitives/binary.py @@ -31,31 +31,64 @@ class BinarySharedTensor(object): where n is the number of parties present in the protocol (world_size). """ - def __init__(self, tensor=None, size=None, src=0, device=None): + def __init__( + self, tensor=None, size=None, broadcast_size=False, src=0, device=None + ): + """ + Creates the shared tensor from the input `tensor` provided by party `src`. + + The other parties can specify a `tensor` or `size` to determine the size + of the shared tensor object to create. In this case, all parties must + specify the same (tensor) size to prevent the party's shares from varying + in size, which leads to undefined behavior. + + Alternatively, the parties can set `broadcast_size` to `True` to have the + `src` party broadcast the correct size. The parties who do not know the + tensor size beforehand can provide an empty tensor as input. This is + guaranteed to produce correct behavior but requires an additional + communication round. + + The parties can also set the `precision` and `device` for their share of + the tensor. If `device` is unspecified, it is set to `tensor.device`. + """ + + # do nothing if source is sentinel: if src == SENTINEL: return + + # assertions on inputs: assert ( isinstance(src, int) and src >= 0 and src < comm.get().get_world_size() - ), "invalid tensor source" - - if device is None and hasattr(tensor, "device"): + ), "specified source party does not exist" + if self.rank == src: + assert tensor is not None, "source must provide a data tensor" + if hasattr(tensor, "src"): + assert ( + tensor.src == src + ), "source of data tensor must match source of encryption" + if not broadcast_size: + assert ( + tensor is not None or size is not None + ), "must specify tensor or size, or set broadcast_size" + + # if device is unspecified, try and get it from tensor: + if device is None and tensor is not None and hasattr(tensor, "device"): device = tensor.device - # Assume 0 bits of precision unless encoder is set outside of init + # assume zero bits of precision unless encoder is set outside of init: self.encoder = FixedPointEncoder(precision_bits=0) if tensor is not None: tensor = self.encoder.encode(tensor) tensor = tensor.to(device=device) size = tensor.size() - # Generate Psuedo-random Sharing of Zero and add source's tensor + # if other parties do not know tensor's size, broadcast the size: + if broadcast_size: + size = comm.get().broadcast_obj(size, src) + + # generate pseudo-random zero sharing (PRZS) and add source's tensor: self.share = BinarySharedTensor.PRZS(size, device=device).share if self.rank == src: - assert tensor is not None, "Source must provide a data tensor" - if hasattr(tensor, "src"): - assert ( - tensor.src == src - ), "Source of data tensor must match source of encryption" self.share ^= tensor @staticmethod diff --git a/test/test_arithmetic.py b/test/test_arithmetic.py index da365d1a..f6ca337d 100644 --- a/test/test_arithmetic.py +++ b/test/test_arithmetic.py @@ -91,15 +91,30 @@ def test_encrypt_decrypt(self): (5, 3, 32, 32), ] for size in sizes: + + # encryption and decryption without source: reference = get_random_test_tensor(size=size, is_float=True) encrypted_tensor = ArithmeticSharedTensor(reference) self._check(encrypted_tensor, reference, "en/decryption failed") - for dst in range(self.world_size): self._check( encrypted_tensor, reference, "en/decryption failed", dst=dst ) + # encryption and decryption with source: + for src in range(self.world_size): + input_tensor = reference if src == self.rank else [] + encrypted_tensor = ArithmeticSharedTensor( + input_tensor, src=src, broadcast_size=True + ) + for dst in range(self.world_size): + self._check( + encrypted_tensor, + reference, + "en/decryption with broadcast_size failed", + dst=dst, + ) + def test_arithmetic(self): """Tests arithmetic functions on encrypted tensor.""" arithmetic_functions = ["add", "add_", "sub", "sub_", "mul", "mul_"] diff --git a/test/test_binary.py b/test/test_binary.py index ac0e9e8d..63c83749 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -67,15 +67,30 @@ def test_encrypt_decrypt(self): (5, 3, 32, 32), ] for size in sizes: + + # encryption and decryption without source: reference = get_random_test_tensor(size=size, is_float=False) encrypted_tensor = BinarySharedTensor(reference) self._check(encrypted_tensor, reference, "en/decryption failed") - for dst in range(self.world_size): self._check( encrypted_tensor, reference, "en/decryption failed", dst=dst ) + # encryption and decryption with source: + for src in range(self.world_size): + input_tensor = reference if src == self.rank else [] + encrypted_tensor = BinarySharedTensor( + input_tensor, src=src, broadcast_size=True + ) + for dst in range(self.world_size): + self._check( + encrypted_tensor, + reference, + "en/decryption with broadcast_size failed", + dst=dst, + ) + def test_transpose(self): sizes = [ (1,), diff --git a/test/test_mpc.py b/test/test_mpc.py index 01f812cd..7d5d91bd 100644 --- a/test/test_mpc.py +++ b/test/test_mpc.py @@ -161,6 +161,8 @@ def test_encrypt_decrypt(self): (5, 3, 32, 32), ] for size in sizes: + + # encryption and decryption without source: reference = self._get_random_test_tensor(size=size, is_float=True) encrypted_tensor = MPCTensor(reference) self._check(encrypted_tensor, reference, "en/decryption failed") @@ -169,7 +171,19 @@ def test_encrypt_decrypt(self): encrypted_tensor, reference, "en/decryption failed", dst=dst ) - # Test new() + # encryption and decryption with source: + for src in range(self.world_size): + input_tensor = reference if src == self.rank else [] + encrypted_tensor = MPCTensor(input_tensor, src=src, broadcast_size=True) + for dst in range(self.world_size): + self._check( + encrypted_tensor, + reference, + "en/decryption with broadcast_size failed", + dst=dst, + ) + + # test creation via new() function: encrypted_tensor2 = encrypted_tensor.new(reference) self.assertIsInstance( encrypted_tensor2, MPCTensor, "new() returns incorrect type"