Skip to content

Commit

Permalink
chore: fix where which use boolean params (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored Oct 28, 2024
1 parent 664cbd4 commit b307ca5
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
21 changes: 21 additions & 0 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,27 @@ def predict(x, weights, bias):
return outputs.squeeze()


class WhereNet(torch.nn.Module):
"""Simple network with a where operation for testing."""

def __init__(self, n_hidden):
super().__init__()
self.n_hidden = n_hidden
self.fc_tot = torch.rand(1, n_hidden) > 0.5

def forward(self, x):
"""Forward pass.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying the where operation.
"""
y = torch.where(self.fc_tot, x, 0.0)
return y


class AddNet(nn.Module):
"""Torch model that performs a simple addition between two inputs."""

Expand Down
10 changes: 9 additions & 1 deletion src/concrete/ml/quantization/post_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,15 @@ def _process_initializer(
return values.view(RawOpOutput)
if not isinstance(values, (numpy.ndarray, Tracer)):
values = numpy.array(values)
is_signed = is_symmetric = self._check_distribution_is_symmetric_around_zero(values)
if not numpy.issubdtype(values.dtype, numpy.bool_):
is_signed = is_symmetric = self._check_distribution_is_symmetric_around_zero(values)
# Boolean parameters are quantized to 1 bit
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4593
# We should not quantize boolean parameters in the future
else:
is_signed = is_symmetric = False
n_bits = 1
values = values.astype(numpy.float64)

return QuantizedArray(
n_bits,
Expand Down
32 changes: 32 additions & 0 deletions tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
TorchDivide,
TorchMultiply,
UnivariateModule,
WhereNet,
)
from concrete.ml.quantization import QuantizedModule

Expand Down Expand Up @@ -775,6 +776,37 @@ def test_dump_torch_network(
)


def test_compile_where_net(default_configuration, check_is_good_execution_for_cml_vs_circuit):
"""Test compilation and execution of PTQSimpleNet."""
n_feat = 32
n_examples = 100

torch_model = WhereNet(n_feat)

# Create random input
inputset = numpy.random.uniform(-100, 100, size=(n_examples, n_feat))

# Compile the model
quantized_module = compile_torch_model(
torch_model,
inputset,
n_bits=16,
configuration=default_configuration,
)

# Test execution
x_test = inputset[:10] # Use first 10 samples for testing

# Check if FHE simulation and quantized module forward give the same output
check_is_good_execution_for_cml_vs_circuit(x_test, model=quantized_module, simulate=True)

# Compare with PyTorch model
torch_output = torch_model(torch.from_numpy(x_test).float()).detach().numpy()
quantized_output = quantized_module.forward(x_test, fhe="disable")

numpy.testing.assert_allclose(torch_output, quantized_output, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("verbose", [True, False], ids=["with_verbose", "without_verbose"])
# pylint: disable-next=too-many-locals
def test_pretrained_mnist_qat(
Expand Down

0 comments on commit b307ca5

Please sign in to comment.