Allow sharing of tensors whose size is unknown to some parties #184

Closed
4 changes: 2 additions & 2 deletions crypten/communicator/distributed_communicator.py
@@ -264,7 +264,7 @@ def send_obj(self, obj, dst, group=None):

buf = pickle.dumps(obj)
size = torch.tensor(len(buf), dtype=torch.int32)
arr = torch.from_numpy(numpy.frombuffer(buf, dtype=numpy.int8))
arr = torch.from_numpy(numpy.copy(numpy.frombuffer(buf, dtype=numpy.int8)))
Contributor:
From what I have read here [1], `arr` shares the same memory as the numpy array if `numpy.copy` is not used.

Q: Is it necessary to create a copy of the data? (By using this I think we are playing it safe, but if `buf` becomes bigger, the copy might introduce overhead.)

[1] https://pytorch.org/docs/stable/generated/torch.from_numpy.html

Member Author:

Without this copy, I am seeing this warning. I tried altering the writeable flag of the numpy array directly but that does not appear to be allowed here. The copy makes the warning go away.

Contributor:

Yep, you are right!

I have read here that PyTorch does not support read-only tensors at the moment!
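For context, a minimal standalone sketch (illustrative only, not part of this PR) of the behavior discussed in this thread: `pickle.dumps` returns an immutable `bytes` buffer, `numpy.frombuffer` exposes it as a read-only array, and `torch.from_numpy` warns on non-writable arrays because the resulting tensor would share that memory.

```python
import pickle

import numpy
import torch

buf = pickle.dumps({"key": "value"})
ro_arr = numpy.frombuffer(buf, dtype=numpy.int8)  # read-only view of buf
print(ro_arr.flags.writeable)                     # False: bytes are immutable

# torch.from_numpy(ro_arr) emits a UserWarning about non-writable arrays,
# since the resulting tensor would share (and could try to mutate) the buffer.
arr = torch.from_numpy(numpy.copy(ro_arr))        # the copy is writable
```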


r0 = dist.isend(size, dst=dst, group=group)
r1 = dist.isend(arr, dst=dst, group=group)
@@ -296,7 +296,7 @@ def broadcast_obj(self, obj, src, group=None):
assert obj is not None, "src party must provide obj for broadcast"
buf = pickle.dumps(obj)
size = torch.tensor(len(buf), dtype=torch.int32)
arr = torch.from_numpy(numpy.frombuffer(buf, dtype=numpy.int8))
arr = torch.from_numpy(numpy.copy(numpy.frombuffer(buf, dtype=numpy.int8)))

dist.broadcast(size, src, group=group)
dist.broadcast(arr, src, group=group)
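To make the change easier to follow, here is a hedged, self-contained sketch of the size-then-payload pattern that `broadcast_obj` relies on (and which the new `broadcast_size` option builds on). The function name and setup are illustrative; it assumes `torch.distributed` has already been initialized, e.g. with the gloo backend.

```python
import pickle

import numpy
import torch
import torch.distributed as dist

def broadcast_obj_sketch(obj, src, rank):
    # phase 1: broadcast the payload size so receivers can allocate a buffer
    if rank == src:
        buf = pickle.dumps(obj)
        size = torch.tensor(len(buf), dtype=torch.int32)
        arr = torch.from_numpy(numpy.copy(numpy.frombuffer(buf, dtype=numpy.int8)))
    else:
        size = torch.tensor(0, dtype=torch.int32)
    dist.broadcast(size, src)

    # phase 2: broadcast the pickled bytes into a buffer of exactly that size
    if rank != src:
        arr = torch.empty(size.item(), dtype=torch.int8)
    dist.broadcast(arr, src)
    return pickle.loads(arr.numpy().tobytes())
```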
22 changes: 21 additions & 1 deletion crypten/mpc/mpc.py
@@ -130,6 +130,25 @@ def __init__(self, *args):

class MPCTensor(CrypTensor):
def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs):
"""
Creates the shared tensor from the input `tensor` provided by party `src`.
The `ptype` defines the type of sharing used (default: arithmetic).

The other parties can specify a `tensor` or `size` to determine the size
of the shared tensor object to create. In this case, all parties must
specify the same (tensor) size to prevent the parties' shares from varying
in size, which leads to undefined behavior.

Alternatively, the parties can set `broadcast_size` to `True` to have the
`src` party broadcast the correct size. The parties who do not know the
tensor size beforehand can provide an empty tensor as input. This is
guaranteed to produce correct behavior but requires an additional
communication round.

The parties can also set the `precision` and `device` for their share of
the tensor. If `device` is unspecified, it is set to `tensor.device`.
"""

# take required_grad from kwargs, input tensor, or set to False:
default = tensor.requires_grad if torch.is_tensor(tensor) else False
requires_grad = kwargs.pop("requires_grad", default)
@@ -139,12 +158,13 @@ def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs)
if tensor is None:
return # TODO: Can we remove this and do staticmethods differently?

# if device is unspecified, try and get it from tensor:
if device is None and hasattr(tensor, "device"):
device = tensor.device

# create the MPCTensor:
tensor_type = ptype.to_tensor()
self._tensor = tensor_type(tensor, device=device, *args, **kwargs)
self._tensor = tensor_type(tensor=tensor, device=device, *args, **kwargs)
self.ptype = ptype

@staticmethod
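Putting the new docstring together, a minimal usage sketch of the `broadcast_size` flow (assumes a standard multi-process CrypTen launch where every party runs this code; the size and ranks are illustrative):

```python
import torch
import crypten
import crypten.communicator as comm
from crypten.mpc import MPCTensor

crypten.init()
rank = comm.get().get_rank()

if rank == 0:
    x = torch.randn(5, 3)  # only party 0 knows the plaintext (and its size)
else:
    x = torch.empty(0)     # the other parties pass an empty tensor
x_enc = MPCTensor(x, src=0, broadcast_size=True)  # src broadcasts the size

assert x_enc.size() == torch.Size([5, 3])  # all parties now agree on the size
```

This costs one extra communication round (the size broadcast) but removes the requirement that every party knows the tensor size up front.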
60 changes: 50 additions & 10 deletions crypten/mpc/primitives/arithmetic.py
@@ -36,16 +36,57 @@ class ArithmeticSharedTensor(object):
"""

# constructors:
def __init__(self, tensor=None, size=None, precision=None, src=0, device=None):
def __init__(
self,
tensor=None,
size=None,
broadcast_size=False,
precision=None,
src=0,
device=None,
):
"""
Creates the shared tensor from the input `tensor` provided by party `src`.

The other parties can specify a `tensor` or `size` to determine the size
of the shared tensor object to create. In this case, all parties must
specify the same (tensor) size to prevent the parties' shares from varying
in size, which leads to undefined behavior.

Alternatively, the parties can set `broadcast_size` to `True` to have the
`src` party broadcast the correct size. The parties who do not know the
tensor size beforehand can provide an empty tensor as input. This is
guaranteed to produce correct behavior but requires an additional
communication round.

The parties can also set the `precision` and `device` for their share of
the tensor. If `device` is unspecified, it is set to `tensor.device`.
"""

# do nothing if source is sentinel:
if src == SENTINEL:
return

# assertions on inputs:
assert (
isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
), "invalid tensor source"
), "specified source party does not exist"
if self.rank == src:
assert tensor is not None, "source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "source of data tensor must match source of encryption"
if not broadcast_size:
assert (
tensor is not None or size is not None
), "must specify tensor or size, or set broadcast_size"

if device is None and hasattr(tensor, "device"):
# if device is unspecified, try and get it from tensor:
if device is None and tensor is not None and hasattr(tensor, "device"):
device = tensor.device

# encode the input tensor:
self.encoder = FixedPointEncoder(precision_bits=precision)
if tensor is not None:
if is_int_tensor(tensor) and precision != 0:
@@ -54,14 +95,13 @@ def __init__(self, tensor=None, size=None, precision=None, src=0, device=None):
tensor = tensor.to(device=device)
size = tensor.size()

# Generate psuedo-random sharing of zero (PRZS) and add source's tensor
# if other parties do not know tensor's size, broadcast the size:
if broadcast_size:
size = comm.get().broadcast_obj(size, src)

# generate pseudo-random zero sharing (PRZS) and add source's tensor:
self.share = ArithmeticSharedTensor.PRZS(size, device=device).share
if self.rank == src:
assert tensor is not None, "Source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "Source of data tensor must match source of encryption"
self.share += tensor

@property
@@ -246,7 +286,7 @@ def get_plain_text(self, dst=None):
def _arithmetic_function_(self, y, op, *args, **kwargs):
return self._arithmetic_function(y, op, inplace=True, *args, **kwargs)

def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):
def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C901
assert op in [
"add",
"sub",
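As a sanity check on the scheme the constructor implements, here is a toy sketch (not CrypTen's PRG-based PRZS implementation) of additive zero sharing: the shares sum to zero, and the source folds in the plaintext, so summing all shares reconstructs it.

```python
import torch

world_size, src = 3, 0
plaintext = torch.tensor([10, 20, 30])

# stand-in for PRZS: random shares constructed to sum to zero
shares = [torch.randint(-100, 100, plaintext.shape) for _ in range(world_size - 1)]
shares.append(-torch.stack(shares).sum(dim=0))
shares[src] = shares[src] + plaintext  # mirrors `self.share += tensor`

reconstructed = torch.stack(shares).sum(dim=0)
assert torch.equal(reconstructed, plaintext)
```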
55 changes: 44 additions & 11 deletions crypten/mpc/primitives/binary.py
@@ -31,31 +31,64 @@ class BinarySharedTensor(object):
where n is the number of parties present in the protocol (world_size).
"""

def __init__(self, tensor=None, size=None, src=0, device=None):
def __init__(
self, tensor=None, size=None, broadcast_size=False, src=0, device=None
):
"""
Creates the shared tensor from the input `tensor` provided by party `src`.

The other parties can specify a `tensor` or `size` to determine the size
of the shared tensor object to create. In this case, all parties must
specify the same (tensor) size to prevent the parties' shares from varying
in size, which leads to undefined behavior.

Alternatively, the parties can set `broadcast_size` to `True` to have the
`src` party broadcast the correct size. The parties who do not know the
tensor size beforehand can provide an empty tensor as input. This is
guaranteed to produce correct behavior but requires an additional
communication round.

The parties can also set the `device` for their share of the tensor. If
`device` is unspecified, it is set to `tensor.device`.
"""

# do nothing if source is sentinel:
if src == SENTINEL:
return

# assertions on inputs:
assert (
isinstance(src, int) and src >= 0 and src < comm.get().get_world_size()
), "invalid tensor source"

if device is None and hasattr(tensor, "device"):
), "specified source party does not exist"
if self.rank == src:
assert tensor is not None, "source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "source of data tensor must match source of encryption"
if not broadcast_size:
assert (
tensor is not None or size is not None
), "must specify tensor or size, or set broadcast_size"

# if device is unspecified, try and get it from tensor:
if device is None and tensor is not None and hasattr(tensor, "device"):
device = tensor.device

# Assume 0 bits of precision unless encoder is set outside of init
# assume zero bits of precision unless encoder is set outside of init:
self.encoder = FixedPointEncoder(precision_bits=0)
if tensor is not None:
tensor = self.encoder.encode(tensor)
tensor = tensor.to(device=device)
size = tensor.size()

# Generate Psuedo-random Sharing of Zero and add source's tensor
# if other parties do not know tensor's size, broadcast the size:
if broadcast_size:
size = comm.get().broadcast_obj(size, src)

# generate pseudo-random zero sharing (PRZS) and add source's tensor:
self.share = BinarySharedTensor.PRZS(size, device=device).share
if self.rank == src:
assert tensor is not None, "Source must provide a data tensor"
if hasattr(tensor, "src"):
assert (
tensor.src == src
), "Source of data tensor must match source of encryption"
self.share ^= tensor

@staticmethod
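The binary variant works the same way, with XOR in place of addition; again a toy sketch, not the PRG-based PRZS implementation:

```python
from functools import reduce

import torch

world_size, src = 3, 0
plaintext = torch.tensor([0b1010, 0b0111])

# stand-in for PRZS: random shares constructed to XOR to zero
shares = [torch.randint(0, 2**16, plaintext.shape) for _ in range(world_size - 1)]
shares.append(reduce(torch.bitwise_xor, shares))
shares[src] = shares[src] ^ plaintext  # mirrors `self.share ^= tensor`

reconstructed = reduce(torch.bitwise_xor, shares)
assert torch.equal(reconstructed, plaintext)
```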
17 changes: 16 additions & 1 deletion test/test_arithmetic.py
@@ -91,15 +91,30 @@ def test_encrypt_decrypt(self):
(5, 3, 32, 32),
]
for size in sizes:

# encryption and decryption without source:
reference = get_random_test_tensor(size=size, is_float=True)
encrypted_tensor = ArithmeticSharedTensor(reference)
self._check(encrypted_tensor, reference, "en/decryption failed")

for dst in range(self.world_size):
self._check(
encrypted_tensor, reference, "en/decryption failed", dst=dst
)

# encryption and decryption with source:
for src in range(self.world_size):
input_tensor = reference if src == self.rank else []
encrypted_tensor = ArithmeticSharedTensor(
input_tensor, src=src, broadcast_size=True
)
for dst in range(self.world_size):
self._check(
encrypted_tensor,
reference,
"en/decryption with broadcast_size failed",
dst=dst,
)

def test_arithmetic(self):
"""Tests arithmetic functions on encrypted tensor."""
arithmetic_functions = ["add", "add_", "sub", "sub_", "mul", "mul_"]
17 changes: 16 additions & 1 deletion test/test_binary.py
@@ -67,15 +67,30 @@ def test_encrypt_decrypt(self):
(5, 3, 32, 32),
]
for size in sizes:

# encryption and decryption without source:
reference = get_random_test_tensor(size=size, is_float=False)
encrypted_tensor = BinarySharedTensor(reference)
self._check(encrypted_tensor, reference, "en/decryption failed")

for dst in range(self.world_size):
self._check(
encrypted_tensor, reference, "en/decryption failed", dst=dst
)

# encryption and decryption with source:
for src in range(self.world_size):
input_tensor = reference if src == self.rank else []
encrypted_tensor = BinarySharedTensor(
input_tensor, src=src, broadcast_size=True
)
for dst in range(self.world_size):
self._check(
encrypted_tensor,
reference,
"en/decryption with broadcast_size failed",
dst=dst,
)

def test_transpose(self):
sizes = [
(1,),
16 changes: 15 additions & 1 deletion test/test_mpc.py
@@ -161,6 +161,8 @@ def test_encrypt_decrypt(self):
(5, 3, 32, 32),
]
for size in sizes:

# encryption and decryption without source:
reference = self._get_random_test_tensor(size=size, is_float=True)
encrypted_tensor = MPCTensor(reference)
self._check(encrypted_tensor, reference, "en/decryption failed")
@@ -169,7 +171,19 @@
encrypted_tensor, reference, "en/decryption failed", dst=dst
)

# Test new()
# encryption and decryption with source:
for src in range(self.world_size):
input_tensor = reference if src == self.rank else []
encrypted_tensor = MPCTensor(input_tensor, src=src, broadcast_size=True)
for dst in range(self.world_size):
self._check(
encrypted_tensor,
reference,
"en/decryption with broadcast_size failed",
dst=dst,
)

# test creation via new() function:
encrypted_tensor2 = encrypted_tensor.new(reference)
self.assertIsInstance(
encrypted_tensor2, MPCTensor, "new() returns incorrect type"