Skip to content

Commit

Permalink
apply black 20.8b1 formatting update
Browse files Browse the repository at this point in the history
Summary:
allow-large-files

black_any_style

Reviewed By: zertosh

Differential Revision: D24325133

fbshipit-source-id: b4afe80d1e8b2bc993f4b8e3822c02964df47462
  • Loading branch information
amyreese authored and facebook-github-bot committed Oct 15, 2020
1 parent 85db37f commit 1cd3894
Show file tree
Hide file tree
Showing 18 changed files with 68 additions and 70 deletions.
20 changes: 10 additions & 10 deletions crypten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,16 @@ def is_encrypted_tensor(obj):

def _setup_przs(device=None):
"""
Generate shared random seeds to generate pseudo-random sharings of
zero. The random seeds are shared such that each process shares
one seed with the previous rank process and one with the next rank.
This allows for the generation of `n` random values, each known to
exactly two of the `n` parties.
For arithmetic sharing, one of these parties will add the number
while the other subtracts it, allowing for the generation of a
pseudo-random sharing of zero. (This can be done for binary
sharing using bitwise-xor rather than addition / subtraction)
Generate shared random seeds to generate pseudo-random sharings of
zero. The random seeds are shared such that each process shares
one seed with the previous rank process and one with the next rank.
This allows for the generation of `n` random values, each known to
exactly two of the `n` parties.
For arithmetic sharing, one of these parties will add the number
while the other subtracts it, allowing for the generation of a
pseudo-random sharing of zero. (This can be done for binary
sharing using bitwise-xor rather than addition / subtraction)
"""
# Initialize RNG Generators
comm.get().g0 = torch.Generator()
Expand Down
8 changes: 4 additions & 4 deletions crypten/communicator/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def _log_communication_time(self, comm_time):

def get_generator(self, idx, device=None):
"""
Get the corresponding RNG generator, as specified by its index and device
Get the corresponding RNG generator, as specified by its index and device
Args:
idx: The index of the generator, can be either 0 or 1
device: The device that the generator lives in.
Args:
idx: The index of the generator, can be either 0 or 1
device: The device that the generator lives in.
"""

if device is None:
Expand Down
9 changes: 4 additions & 5 deletions crypten/cryptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,24 +787,23 @@ def where(self, condition, y):

def sigmoid(self, reciprocal_method="log"):
"""Computes the sigmoid function on the input value
sigmoid(x) = (1 + exp(-x))^{-1}
sigmoid(x) = (1 + exp(-x))^{-1}
"""
raise NotImplementedError("sigmoid is not implemented")

def tanh(self, reciprocal_method="log"):
"""Computes tanh from the sigmoid function:
tanh(x) = 2 * sigmoid(2 * x) - 1
tanh(x) = 2 * sigmoid(2 * x) - 1
"""
raise NotImplementedError("tanh is not implemented")

def softmax(self, dim, **kwargs):
"""Compute the softmax of a tensor's elements along a given dimension
"""
"""Compute the softmax of a tensor's elements along a given dimension"""
raise NotImplementedError("softmax is not implemented")

def log_softmax(self, dim, **kwargs):
"""Applies a softmax of a tensor's elements along a given dimension,
followed by a logarithm.
followed by a logarithm.
"""
raise NotImplementedError("log_softmax is not implemented")

Expand Down
10 changes: 5 additions & 5 deletions crypten/cuda/cuda_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def decorator(func):

class CUDALongTensor(object):
"""
A wrapper class for `torch.cuda.LongTensor`. When performing operations that are
currently not supported for `torch.cuda.LongTensor` (e.g `matmul`, `conv2d`), it will
convert the underlying LongTensor into DoubleTensor and convert the computed
result back to a LongTensor. The computed result will be the same as the original
expected result.
A wrapper class for `torch.cuda.LongTensor`. When performing operations that are
currently not supported for `torch.cuda.LongTensor` (e.g `matmul`, `conv2d`), it will
convert the underlying LongTensor into DoubleTensor and convert the computed
result back to a LongTensor. The computed result will be the same as the original
expected result.
"""

__BITS = torch.iinfo(torch.long).bits
Expand Down
3 changes: 1 addition & 2 deletions crypten/mpc/mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,8 +982,7 @@ def _truncate_tanh(self):

@mode(Ptype.arithmetic)
def softmax(self, dim, **kwargs):
"""Compute the softmax of a tensor's elements along a given dimension
"""
"""Compute the softmax of a tensor's elements along a given dimension"""
# 0-d case
if self.dim() == 0:
assert dim == 0, "Improper dim argument"
Expand Down
18 changes: 9 additions & 9 deletions crypten/mpc/primitives/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
# MPC tensor where shares additive-sharings.
class ArithmeticSharedTensor(object):
"""
Encrypted tensor object that uses additive sharing to perform computations.
Encrypted tensor object that uses additive sharing to perform computations.
Additive shares are computed by splitting each value of the input tensor
into n separate random values that add to the input tensor, where n is
the number of parties present in the protocol (world_size).
Additive shares are computed by splitting each value of the input tensor
into n separate random values that add to the input tensor, where n is
the number of parties present in the protocol (world_size).
"""

# constructors:
Expand Down Expand Up @@ -165,7 +165,7 @@ def __setitem__(self, index, value):

def pad(self, pad, mode="constant", value=0):
"""
Pads the input tensor with values provided in `value`.
Pads the input tensor with values provided in `value`.
"""
assert mode == "constant", (
"Padding with mode %s is currently unsupported" % mode
Expand Down Expand Up @@ -473,13 +473,13 @@ def conv_transpose2d(self, kernel, **kwargs):

def index_add(self, dim, index, tensor):
"""Perform out-of-place index_add: Accumulate the elements of tensor into the
self tensor by adding to the indices in the order given in index. """
self tensor by adding to the indices in the order given in index."""
result = self.clone()
return result.index_add_(dim, index, tensor)

def index_add_(self, dim, index, tensor):
"""Perform in-place index_add: Accumulate the elements of tensor into the
self tensor by adding to the indices in the order given in index. """
self tensor by adding to the indices in the order given in index."""
public = isinstance(tensor, (int, float)) or is_tensor(tensor)
private = isinstance(tensor, ArithmeticSharedTensor)
if public:
Expand Down Expand Up @@ -543,8 +543,8 @@ def _sum_pool2d(self, kernel_size, stride=None, padding=0):

def take(self, index, dimension=None):
"""Take entries of tensor along a dimension according to the index.
This function is identical to torch.take() when dimension=None,
otherwise, it is identical to ONNX gather() function.
This function is identical to torch.take() when dimension=None,
otherwise, it is identical to ONNX gather() function.
"""
result = self.shallow_copy()
index = index.long()
Expand Down
8 changes: 4 additions & 4 deletions crypten/mpc/primitives/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
# MPC tensor where shares are XOR-sharings.
class BinarySharedTensor(object):
"""
Encrypted tensor object that uses binary sharing to perform computations.
Encrypted tensor object that uses binary sharing to perform computations.
Binary shares are computed by splitting each value of the input tensor
into n separate random values that xor together to the input tensor value,
where n is the number of parties present in the protocol (world_size).
Binary shares are computed by splitting each value of the input tensor
into n separate random values that xor together to the input tensor value,
where n is the number of parties present in the protocol (world_size).
"""

def __init__(self, tensor=None, size=None, src=0, device=None):
Expand Down
6 changes: 3 additions & 3 deletions crypten/mpc/primitives/ot/baseOT.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

class BaseOT:
"""
hardcoded public parameter
log2(__prime) > 128
hardcoded public parameter
log2(__prime) > 128
__generator is a primitive root of __prime
__generator is a primitive root of __prime
"""

__prime = 631276824160446938136046282957027762913
Expand Down
12 changes: 6 additions & 6 deletions crypten/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def forward_function(*args, **kwargs):

def __getattr__(self, name):
"""Redefine __getattr__ so that any parameters, modules or buffers
inside the Module object can be accessed as attributes
inside the Module object can be accessed as attributes
"""
if "_parameters" in self.__dict__:
parameters = self.__dict__["_parameters"]
Expand All @@ -482,8 +482,8 @@ def __getattr__(self, name):

def __setattr__(self, name, value):
"""Redefine __setattr__ so that any submodules created
inside the Module object are registered with _modules
OrderedDict.
inside the Module object are registered with _modules
OrderedDict.
"""

def remove_from(*dicts):
Expand Down Expand Up @@ -2056,9 +2056,9 @@ def _identify_bool_attributes_with_defaults(
attributes, attr_name, attr_value, default=True
):
"""For boolean attributes that have default values in the ONNX specification
checks to see if they are present in `attributes`, and assigns the
default if not present and appropriate value if present. Note `attr_value`
must be the value `attributes[attr_name]` if the default is to be kept.
checks to see if they are present in `attributes`, and assigns the
default if not present and appropriate value if present. Note `attr_value`
must be the value `attributes[attr_name]` if the default is to be kept.
"""
output = default
if attr_name in attributes and attributes[attr_name] != attr_value:
Expand Down
6 changes: 3 additions & 3 deletions test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TestArithmetic(MultiProcessTestCase):
"""
This class tests all functions of the ArithmeticSharedTensor.
This class tests all functions of the ArithmeticSharedTensor.
"""

def setUp(self):
Expand Down Expand Up @@ -72,8 +72,8 @@ def test_share_attr(self):

def test_encrypt_decrypt(self):
"""
Tests tensor encryption and decryption for both positive
and negative values.
Tests tensor encryption and decryption for both positive
and negative values.
"""
sizes = [
(),
Expand Down
6 changes: 3 additions & 3 deletions test/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

class TestBinary(MultiProcessTestCase):
"""
This class tests all functions of BinarySharedTensor.
This class tests all functions of BinarySharedTensor.
"""

def setUp(self):
Expand Down Expand Up @@ -48,8 +48,8 @@ def _check(self, encrypted_tensor, reference, msg, dst=None, tolerance=None):

def test_encrypt_decrypt(self):
"""
Tests tensor encryption and decryption for both positive
and negative values.
Tests tensor encryption and decryption for both positive
and negative values.
"""
sizes = [
(),
Expand Down
2 changes: 1 addition & 1 deletion test/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class TestCommunicator:
"""
This class tests all member functions of crypten package
This class tests all member functions of crypten package
"""

def test_przs_generators(self):
Expand Down
2 changes: 1 addition & 1 deletion test/test_crypten.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

class TestCrypten(MultiProcessTestCase):
"""
This class tests all member functions of crypten package
This class tests all member functions of crypten package
"""

def setUp(self):
Expand Down
16 changes: 8 additions & 8 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def forward(self, x):
@unittest.skipIf(torch.cuda.is_available() is False, "requires CUDA")
class TestCUDA(TestMPC):
"""
This class tests all functions of CUDALongTensor as well as its integration with MPCTensor.
This class tests all functions of CUDALongTensor as well as its integration with MPCTensor.
"""

def _check_int(self, result, reference, msg):
Expand Down Expand Up @@ -476,11 +476,11 @@ def test_torch_gather(self):

@unittest.skip("torch.scatter behaves inconsistently on CUDA")
def test_torch_scatter(self):
""" Test scatter/scatter_add function of CUDALongTensor
"""Test scatter/scatter_add function of CUDALongTensor
This test will be skipped for now since torch.scatter provides
inconsistent result given the same input on CUDA. This is likely
due to a potential bug on pytorch's implementation of scatter
This test will be skipped for now since torch.scatter provides
inconsistent result given the same input on CUDA. This is likely
due to a potential bug on pytorch's implementation of scatter
"""

funcs = ["scatter", "scatter_add"]
Expand Down Expand Up @@ -537,9 +537,9 @@ def test_torch_nonzero(self):

@unittest.skip("torch.scatter behaves inconsistently on CUDA")
def test_scatter(self):
""" This test will be skipped for now since torch.scatter provides
inconsistent result given the same input on CUDA. This is likely
due to a potential bug on pytorch's implementation of scatter
"""This test will be skipped for now since torch.scatter provides
inconsistent result given the same input on CUDA. This is likely
due to a potential bug on pytorch's implementation of scatter
"""
pass

Expand Down
2 changes: 1 addition & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class TestDistributions(object):
"""
This class tests accuracy of distributions provided by random sampling in crypten.
This class tests accuracy of distributions provided by random sampling in crypten.
"""

def _check_distribution(
Expand Down
6 changes: 3 additions & 3 deletions test/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class TestMPC(object):
"""
This class tests all functions of MPCTensor.
This class tests all functions of MPCTensor.
"""

def _get_random_test_tensor(self, *args, **kwargs):
Expand Down Expand Up @@ -142,8 +142,8 @@ def test_share_attr(self):

def test_encrypt_decrypt(self):
"""
Tests tensor encryption and decryption for both positive
and negative values.
Tests tensor encryption and decryption for both positive
and negative values.
"""
sizes = [
(),
Expand Down
2 changes: 1 addition & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

class TestNN(object):
"""
This class tests the crypten.nn package.
This class tests the crypten.nn package.
"""

def _check(self, encrypted_tensor, reference, msg, tolerance=None):
Expand Down
2 changes: 1 addition & 1 deletion test/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class TestOptim(object):
"""
This class tests the crypten.optim package.
This class tests the crypten.optim package.
"""

def _check(self, encrypted_tensor, reference, msg, tolerance=None):
Expand Down

0 comments on commit 1cd3894

Please sign in to comment.