diff --git a/src/concrete/ml/pytest/torch_models.py b/src/concrete/ml/pytest/torch_models.py index fccc67eb1..8c43ed856 100644 --- a/src/concrete/ml/pytest/torch_models.py +++ b/src/concrete/ml/pytest/torch_models.py @@ -1494,6 +1494,27 @@ def predict(x, weights, bias): return outputs.squeeze() +class WhereNet(torch.nn.Module): + """Simple network with a where operation for testing.""" + + def __init__(self, n_hidden): + super().__init__() + self.n_hidden = n_hidden + self.fc_tot = torch.rand(1, n_hidden) > 0.5 + + def forward(self, x): + """Forward pass. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying the where operation. + """ + y = torch.where(self.fc_tot, x, 0.0) + return y + + class AddNet(nn.Module): """Torch model that performs a simple addition between two inputs.""" diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index 2d3f54f9f..77788a643 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -899,7 +899,15 @@ def _process_initializer( return values.view(RawOpOutput) if not isinstance(values, (numpy.ndarray, Tracer)): values = numpy.array(values) - is_signed = is_symmetric = self._check_distribution_is_symmetric_around_zero(values) + if not numpy.issubdtype(values.dtype, numpy.bool_): + is_signed = is_symmetric = self._check_distribution_is_symmetric_around_zero(values) + # Boolean parameters are quantized to 1 bit + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4593 + # We should not quantize boolean parameters in the future + else: + is_signed = is_symmetric = False + n_bits = 1 + values = values.astype(numpy.float64) return QuantizedArray( n_bits, diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index 80041144f..4d55ffc16 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -54,6 +54,7 @@ TorchDivide, TorchMultiply, UnivariateModule, + WhereNet, ) from concrete.ml.quantization import QuantizedModule @@ -764,6 +765,37 @@ def test_dump_torch_network( ) +def test_compile_where_net(default_configuration, check_is_good_execution_for_cml_vs_circuit): + """Test compilation and execution of PTQSimpleNet.""" + n_feat = 32 + n_examples = 100 + + torch_model = WhereNet(n_feat) + + # Create random input + inputset = numpy.random.uniform(-100, 100, size=(n_examples, n_feat)) + + # Compile the model + quantized_module = compile_torch_model( + torch_model, + inputset, + n_bits=16, + configuration=default_configuration, + ) + + # Test execution + x_test = inputset[:10] # Use first 10 samples for testing + + # Check if FHE simulation and quantized module forward give the same output + check_is_good_execution_for_cml_vs_circuit(x_test, model=quantized_module, simulate=True) + + # Compare with PyTorch model + torch_output = torch_model(torch.from_numpy(x_test).float()).detach().numpy() + quantized_output = quantized_module.forward(x_test, fhe="disable") + + numpy.testing.assert_allclose(torch_output, quantized_output, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("verbose", [True, False], ids=["with_verbose", "without_verbose"]) # pylint: disable-next=too-many-locals def test_pretrained_mnist_qat(