Skip to content

Commit

Permalink
chore: monkey patch transformers in the env or not
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Jun 28, 2024
1 parent eadea72 commit 6482845
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 116 deletions.
1 change: 1 addition & 0 deletions .gitleaksignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ a99389ee01cbb972e46a892d3d0e9c7f8ee23f59:use_case_examples/training/analyze.ipyn
f41de03048a9ed27946b875e81b34138bb4bb17b:use_case_examples/training/analyze.ipynb:aws-access-token:6404
e2904473898ddd325f245f4faca526a0e9520f49:builders/Dockerfile.zamalang-env:generic-api-key:5
7d5e885816f1f1e432dd94da38c5c8267292056a:docs/advanced_examples/XGBRegressor.ipynb:aws-access-token:1026
25c5e7abaa7382520af3fb7a64266e193b1f6a59:poetry.lock:square-access-token:6401
21 changes: 12 additions & 9 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,15 +1390,18 @@ def from_sklearn_model(
# Get the expected number of ONNX outputs in the sklearn model.
expected_number_of_outputs = 1 if is_regressor_or_partial_regressor(model) else 2

onnx_model, lsbs_to_remove_for_trees, input_quantizers, output_quantizers = (
onnx_fp32_model_to_quantized_model(
onnx_model,
n_bits,
framework,
expected_number_of_outputs,
n_features,
X,
)
(
onnx_model,
lsbs_to_remove_for_trees,
input_quantizers,
output_quantizers,
) = onnx_fp32_model_to_quantized_model(
onnx_model,
n_bits,
framework,
expected_number_of_outputs,
n_features,
X,
)

model.input_quantizers = input_quantizers
Expand Down
48 changes: 27 additions & 21 deletions tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,30 +633,36 @@ def check_client_server_training(
fhe_server.load()

# Client-server training
q_weights_deployment, q_bias_deployment, weights_deployment, bias_deployment = (
get_fitted_weights(
x_train,
y_train,
weights,
bias,
batch_size=batch_size,
max_iter=max_iter,
fhe_client=fhe_client,
fhe_server=fhe_server,
)
(
q_weights_deployment,
q_bias_deployment,
weights_deployment,
bias_deployment,
) = get_fitted_weights(
x_train,
y_train,
weights,
bias,
batch_size=batch_size,
max_iter=max_iter,
fhe_client=fhe_client,
fhe_server=fhe_server,
)

# Quantized module (development) training
q_weights_development, q_bias_development, weights_development, bias_development = (
get_fitted_weights(
x_train,
y_train,
weights,
bias,
batch_size=batch_size,
max_iter=max_iter,
quantized_module=model.training_quantized_module,
)
(
q_weights_development,
q_bias_development,
weights_development,
bias_development,
) = get_fitted_weights(
x_train,
y_train,
weights,
bias,
batch_size=batch_size,
max_iter=max_iter,
quantized_module=model.training_quantized_module,
)

# Check that both quantized outputs from the quantized module (development) are matching the
Expand Down
152 changes: 86 additions & 66 deletions tests/sklearn/test_fhe_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,34 +439,42 @@ def test_encrypted_fit_coherence(
x, y = get_blob_data(scale_input=True, parameters_range=parameters_range)
y = y + label_offset

weights_disable, bias_disable, y_pred_proba_disable, y_pred_class_disable, _ = (
check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="disable",
)
(
weights_disable,
bias_disable,
y_pred_proba_disable,
y_pred_class_disable,
_,
) = check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="disable",
)

weights_simulated, bias_simulated, y_pred_proba_simulated, y_pred_class_simulated, _ = (
check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="simulate",
)
(
weights_simulated,
bias_simulated,
y_pred_proba_simulated,
y_pred_class_simulated,
_,
) = check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="simulate",
)

# Make sure weight, bias and prediction values are identical between clear and
Expand All @@ -476,19 +484,23 @@ def test_encrypted_fit_coherence(
assert array_allclose_and_same_shape(y_pred_proba_simulated, y_pred_proba_disable)
assert array_allclose_and_same_shape(y_pred_class_simulated, y_pred_class_disable)

weights_partial, bias_partial, y_pred_proba_partial, y_pred_class_partial, _ = (
check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
partial_fit=True,
)
(
weights_partial,
bias_partial,
y_pred_proba_partial,
y_pred_class_partial,
_,
) = check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
partial_fit=True,
)

# Make sure weight, bias and prediction values are identical between clear and partial fitting
Expand Down Expand Up @@ -547,21 +559,25 @@ def test_encrypted_fit_coherence(

# Fit the model for the remaining iterations starting at the previous weight/bias values. It is
# necessary to provide the RNG object as well in order to keep data shuffle consistent
weights_coef_init, bias_coef_init, y_pred_proba_coef_init, y_pred_class_coef_init, _ = (
check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
last_iterations,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="simulate",
random_number_generator=rng_coef_init,
fit_kwargs=coef_init_fit_kwargs,
)
(
weights_coef_init,
bias_coef_init,
y_pred_proba_coef_init,
y_pred_class_coef_init,
_,
) = check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
last_iterations,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="simulate",
random_number_generator=rng_coef_init,
fit_kwargs=coef_init_fit_kwargs,
)

# Make sure weight, bias and prediction values are identical between clear fitting with and
Expand Down Expand Up @@ -607,18 +623,22 @@ def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max, default_confi
# Avoid checking the accuracy. Since this test is mostly here to make sure that FHE execution
# properly matches the quantized clear one, some parameters (for example, the number of
# features) were set to make it quicker, without considering the model's accuracy
weights_disable, bias_disable, y_pred_proba_disable, y_pred_class_disable, _ = (
check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
default_configuration,
fhe="disable",
)
(
weights_disable,
bias_disable,
y_pred_proba_disable,
y_pred_class_disable,
_,
) = check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
default_configuration,
fhe="disable",
)

# Same, avoid checking the accuracy
Expand Down
59 changes: 48 additions & 11 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the hybrid model converter."""

import sys
import tempfile
from pathlib import Path
from typing import List, Union
Expand Down Expand Up @@ -32,12 +33,16 @@ def test_tuple_serialization(tup):
assert tup == underscore_str_to_tuple(tuple_to_underscore_str(tup))


# pylint: disable=too-many-locals
def run_hybrid_llm_test(
model: torch.nn.Module,
inputs: torch.Tensor,
module_names: Union[str, List],
expected_accuracy,
has_pbs: bool,
has_pbs_reshape: bool,
monkeypatch,
transformers_installed,
):
"""Run the test for any model with its private module names."""

Expand All @@ -47,11 +52,28 @@ def run_hybrid_llm_test(
compress_input_ciphertexts=True,
)

# Create a hybrid model
hybrid_model = HybridFHEModel(model, module_names)
hybrid_model.compile_model(
inputs, p_error=0.1, n_bits=9, rounding_threshold_bits=8, configuration=configuration
)
with monkeypatch.context() as m:
if not transformers_installed:
m.setitem(sys.modules, "transformers", None)
if has_pbs_reshape:
has_pbs = True
# Create a hybrid model
hybrid_model = HybridFHEModel(model, module_names)
try:
hybrid_model.compile_model(
inputs,
p_error=0.1,
n_bits=9,
rounding_threshold_bits=8,
configuration=configuration,
)
except RuntimeError as error:
# When reshaping adds PBSs we sometimes encounter NoParametersFound
# when compiling. In this case we skip the rest since we can't simulate
# without compilation.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4183
assert "NoParametersFound" in error.args[0]
pytest.skip(error.args[0])

if has_pbs:
# Check for non-zero programmable bootstrapping
Expand Down Expand Up @@ -124,14 +146,22 @@ def run_hybrid_llm_test(
# 'from_pretrained' method
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"list_or_str_private_modules_names, expected_accuracy, has_pbs",
"list_or_str_private_modules_names, expected_accuracy, has_pbs, has_pbs_reshape",
[
("transformer.h.0.mlp", 0.95, True),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True),
("transformer.h.0.mlp.c_fc", 1.0, False),
("transformer.h.0.mlp", 0.95, True, False),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True, False),
("transformer.h.0.mlp.c_fc", 1.0, False, True),
],
)
def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy, has_pbs):
@pytest.mark.parametrize("transformers_installed", [True, False])
def test_gpt2_hybrid_mlp(
list_or_str_private_modules_names,
expected_accuracy,
has_pbs,
has_pbs_reshape,
transformers_installed,
monkeypatch,
):
"""Test GPT2 hybrid."""

# Get GPT2 from Hugging Face
Expand All @@ -144,7 +174,14 @@ def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy, h
# Run the test with using a single module in FHE
assert isinstance(model, torch.nn.Module)
run_hybrid_llm_test(
model, input_ids, list_or_str_private_modules_names, expected_accuracy, has_pbs
model,
input_ids,
list_or_str_private_modules_names,
expected_accuracy,
has_pbs,
has_pbs_reshape,
monkeypatch,
transformers_installed,
)


Expand Down
Loading

0 comments on commit 6482845

Please sign in to comment.