Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: fix where which use boolean params #863

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,27 @@ def predict(x, weights, bias):
return outputs.squeeze()


class WhereNet(torch.nn.Module):
"""Simple network with a where operation for testing."""

def __init__(self, n_hidden):
super().__init__()
self.n_hidden = n_hidden
self.fc_tot = torch.rand(1, n_hidden) > 0.5

def forward(self, x):
"""Forward pass.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying the where operation.
"""
y = torch.where(self.fc_tot, x, 0.0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the condition here is not done on the encrypted variable. is this the case we should test ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand. self.fc_tot is the condition, a constant, not encrypted and x is the input encrypted tensor that will be 0 where condition matches?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should also test the case when the condition is on an encrypted variable

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this the case? x is encrypted and self.fc_tot isn't. Do you mean "is not on an encrypted variable?"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok sorry you are right the conditions are in clear we are just selecting a part of the encrypted tensor.

return y


class AddNet(nn.Module):
"""Torch model that performs a simple addition between two inputs."""

Expand Down
10 changes: 9 additions & 1 deletion src/concrete/ml/quantization/post_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,15 @@ def _process_initializer(
return values.view(RawOpOutput)
if not isinstance(values, (numpy.ndarray, Tracer)):
values = numpy.array(values)
is_signed = is_symmetric = self._check_distribution_is_symmetric_around_zero(values)
if not numpy.issubdtype(values.dtype, numpy.bool_):
is_signed = is_symmetric = self._check_distribution_is_symmetric_around_zero(values)
# Boolean parameters are quantized to 1 bit
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4593
# We should not quantize boolean parameters in the future
else:
is_signed = is_symmetric = False
n_bits = 1
values = values.astype(numpy.float64)

return QuantizedArray(
n_bits,
Expand Down
32 changes: 32 additions & 0 deletions tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
TorchDivide,
TorchMultiply,
UnivariateModule,
WhereNet,
)
from concrete.ml.quantization import QuantizedModule

Expand Down Expand Up @@ -764,6 +765,37 @@ def test_dump_torch_network(
)


def test_compile_where_net(default_configuration, check_is_good_execution_for_cml_vs_circuit):
"""Test compilation and execution of PTQSimpleNet."""
n_feat = 32
n_examples = 100

torch_model = WhereNet(n_feat)

# Create random input
inputset = numpy.random.uniform(-100, 100, size=(n_examples, n_feat))

# Compile the model
quantized_module = compile_torch_model(
torch_model,
inputset,
n_bits=16,
configuration=default_configuration,
)

# Test execution
x_test = inputset[:10] # Use first 10 samples for testing

# Check if FHE simulation and quantized module forward give the same output
check_is_good_execution_for_cml_vs_circuit(x_test, model=quantized_module, simulate=True)

# Compare with PyTorch model
torch_output = torch_model(torch.from_numpy(x_test).float()).detach().numpy()
quantized_output = quantized_module.forward(x_test, fhe="disable")

numpy.testing.assert_allclose(torch_output, quantized_output, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("verbose", [True, False], ids=["with_verbose", "without_verbose"])
# pylint: disable-next=too-many-locals
def test_pretrained_mnist_qat(
Expand Down
Loading