diff --git a/tfhe/benches/integer/zk_pke.rs b/tfhe/benches/integer/zk_pke.rs index 03b13b5681..1df1f61e80 100644 --- a/tfhe/benches/integer/zk_pke.rs +++ b/tfhe/benches/integer/zk_pke.rs @@ -7,9 +7,7 @@ use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::Path; use tfhe::integer::key_switching_key::KeySwitchingKey; -use tfhe::integer::parameters::{ - IntegerCompactCiphertextListCastingMode, IntegerCompactCiphertextListUnpackingMode, -}; +use tfhe::integer::parameters::IntegerCompactCiphertextListExpansionMode; use tfhe::integer::{ClientKey, CompactPrivateKey, CompactPublicKey, ServerKey}; use tfhe::keycache::NamedParam; use tfhe::shortint::parameters::classic::tuniform::p_fail_2_minus_64::ks_pbs::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; @@ -247,8 +245,7 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) { public_params, &pk, &metadata, - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(&sks), - IntegerCompactCiphertextListCastingMode::CastIfNecessary( + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( casting_key.as_view(), ), ) diff --git a/tfhe/docs/guides/data_versioning.md b/tfhe/docs/guides/data_versioning.md index f5f344b559..4253cc3b7c 100644 --- a/tfhe/docs/guides/data_versioning.md +++ b/tfhe/docs/guides/data_versioning.md @@ -90,8 +90,8 @@ You will find below a list of breaking changes and how to upgrade them. ```rust use std::io::Cursor; use tfhe::integer::ciphertext::{ - CompactCiphertextList, DataKind, IntegerCompactCiphertextListCastingMode, - IntegerCompactCiphertextListUnpackingMode, SignedRadixCiphertext, + CompactCiphertextList, DataKind, IntegerCompactCiphertextListExpansionMode, + SignedRadixCiphertext, }; use tfhe::integer::{ClientKey, CompactPublicKey}; use tfhe::shortint::parameters::classic::compact_pk::PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS; @@ -130,10 +130,7 @@ pub fn main() { .reinterpret_data(&[DataKind::Signed(num_blocks)]) .unwrap(); let expander = compact_ct - .expand( - IntegerCompactCiphertextListUnpackingMode::NoUnpacking, - IntegerCompactCiphertextListCastingMode::NoCasting, - ) + .expand(IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking) .unwrap(); let expanded = expander.get::(0).unwrap().unwrap(); let decrypted: i8 = client_key.decrypt_signed_radix(&expanded); diff --git a/tfhe/src/high_level_api/compact_list.rs b/tfhe/src/high_level_api/compact_list.rs index 58837601bb..6e52c6e539 100644 --- a/tfhe/src/high_level_api/compact_list.rs +++ b/tfhe/src/high_level_api/compact_list.rs @@ -12,8 +12,7 @@ use crate::high_level_api::traits::Tagged; use crate::integer::ciphertext::{Compactable, DataKind, Expandable}; use crate::integer::encryption::KnowsMessageModulus; use crate::integer::parameters::{ - CompactCiphertextListConformanceParams, IntegerCompactCiphertextListCastingMode, - IntegerCompactCiphertextListUnpackingMode, + CompactCiphertextListConformanceParams, IntegerCompactCiphertextListExpansionMode, }; use crate::named::Named; use crate::prelude::CiphertextList; @@ -111,13 +110,7 @@ impl CompactCiphertextList { sks: &crate::ServerKey, ) -> crate::Result { self.inner - .expand( - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(sks.key.pbs_key()), - sks.cpk_casting_key().map_or( - IntegerCompactCiphertextListCastingMode::NoCasting, - IntegerCompactCiphertextListCastingMode::CastIfNecessary, - ), - ) + .expand(sks.integer_compact_ciphertext_list_expansion_mode()) .map(|inner| CompactCiphertextListExpander { inner, tag: self.tag.clone(), @@ -129,44 +122,22 @@ impl CompactCiphertextList { if !self.inner.is_packed() && !self.inner.needs_casting() { // No ServerKey required, short-circuit to avoid the global state call return Ok(CompactCiphertextListExpander { - inner: self.inner.expand( - IntegerCompactCiphertextListUnpackingMode::NoUnpacking, - IntegerCompactCiphertextListCastingMode::NoCasting, - )?, + inner: self + .inner + .expand(IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking)?, tag: self.tag.clone(), }); } global_state::try_with_internal_keys(|maybe_keys| match maybe_keys { None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), - Some(InternalServerKey::Cpu(cpu_key)) => { - let unpacking_mode = if self.inner.is_packed() { - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(cpu_key.pbs_key()) - } else { - IntegerCompactCiphertextListUnpackingMode::NoUnpacking - }; - - let casting_mode = if self.inner.needs_casting() { - IntegerCompactCiphertextListCastingMode::CastIfNecessary( - cpu_key.cpk_casting_key().ok_or_else(|| { - crate::Error::new( - "No casting key found in ServerKey, \ - required to expand this CompactCiphertextList" - .to_string(), - ) - })?, - ) - } else { - IntegerCompactCiphertextListCastingMode::NoCasting - }; - - self.inner - .expand(unpacking_mode, casting_mode) - .map(|inner| CompactCiphertextListExpander { - inner, - tag: self.tag.clone(), - }) - } + Some(InternalServerKey::Cpu(cpu_key)) => self + .inner + .expand(cpu_key.integer_compact_ciphertext_list_expansion_mode()) + .map(|inner| CompactCiphertextListExpander { + inner, + tag: self.tag.clone(), + }), #[cfg(feature = "gpu")] Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())), }) @@ -261,8 +232,7 @@ mod zk { public_params, &pk.key.key, metadata, - IntegerCompactCiphertextListUnpackingMode::NoUnpacking, - IntegerCompactCiphertextListCastingMode::NoCasting, + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking, )?, tag: self.tag.clone(), }); @@ -270,42 +240,18 @@ mod zk { global_state::try_with_internal_keys(|maybe_keys| match maybe_keys { None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), - Some(InternalServerKey::Cpu(cpu_key)) => { - let unpacking_mode = if self.inner.is_packed() { - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary( - cpu_key.pbs_key(), - ) - } else { - IntegerCompactCiphertextListUnpackingMode::NoUnpacking - }; - - let casting_mode = if self.inner.needs_casting() { - IntegerCompactCiphertextListCastingMode::CastIfNecessary( - cpu_key.cpk_casting_key().ok_or_else(|| { - crate::Error::new( - "No casting key found in ServerKey, \ - required to expand this CompactCiphertextList" - .to_string(), - ) - })?, - ) - } else { - IntegerCompactCiphertextListCastingMode::NoCasting - }; - - self.inner - .verify_and_expand( - public_params, - &pk.key.key, - metadata, - unpacking_mode, - casting_mode, - ) - .map(|expander| CompactCiphertextListExpander { - inner: expander, - tag: self.tag.clone(), - }) - } + Some(InternalServerKey::Cpu(cpu_key)) => self + .inner + .verify_and_expand( + public_params, + &pk.key.key, + metadata, + cpu_key.integer_compact_ciphertext_list_expansion_mode(), + ) + .map(|expander| CompactCiphertextListExpander { + inner: expander, + tag: self.tag.clone(), + }), #[cfg(feature = "gpu")] Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())), }) @@ -321,8 +267,7 @@ mod zk { // No ServerKey required, short circuit to avoid the global state call return Ok(CompactCiphertextListExpander { inner: self.inner.expand_without_verification( - IntegerCompactCiphertextListUnpackingMode::NoUnpacking, - IntegerCompactCiphertextListCastingMode::NoCasting, + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking, )?, tag: self.tag.clone(), }); @@ -330,36 +275,15 @@ mod zk { global_state::try_with_internal_keys(|maybe_keys| match maybe_keys { None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), - Some(InternalServerKey::Cpu(cpu_key)) => { - let unpacking_mode = if self.inner.is_packed() { - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary( - cpu_key.pbs_key(), - ) - } else { - IntegerCompactCiphertextListUnpackingMode::NoUnpacking - }; - - let casting_mode = if self.inner.needs_casting() { - IntegerCompactCiphertextListCastingMode::CastIfNecessary( - cpu_key.cpk_casting_key().ok_or_else(|| { - crate::Error::new( - "No casting key found in ServerKey, \ - required to expand this CompactCiphertextList" - .to_string(), - ) - })?, - ) - } else { - IntegerCompactCiphertextListCastingMode::NoCasting - }; - - self.inner - .expand_without_verification(unpacking_mode, casting_mode) - .map(|expander| CompactCiphertextListExpander { - inner: expander, - tag: self.tag.clone(), - }) - } + Some(InternalServerKey::Cpu(cpu_key)) => self + .inner + .expand_without_verification( + cpu_key.integer_compact_ciphertext_list_expansion_mode(), + ) + .map(|expander| CompactCiphertextListExpander { + inner: expander, + tag: self.tag.clone(), + }), #[cfg(feature = "gpu")] Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())), }) diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 7eaaac317c..f5e5467079 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -11,6 +11,7 @@ use crate::high_level_api::keys::{IntegerCompressedServerKey, IntegerServerKey}; use crate::integer::compression_keys::{ CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, DecompressionKey, }; +use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::named::Named; use crate::prelude::Tagged; use crate::shortint::MessageModulus; @@ -101,6 +102,19 @@ impl ServerKey { pub(in crate::high_level_api) fn message_modulus(&self) -> MessageModulus { self.key.message_modulus() } + + pub(in crate::high_level_api) fn integer_compact_ciphertext_list_expansion_mode( + &self, + ) -> IntegerCompactCiphertextListExpansionMode { + self.cpk_casting_key().map_or_else( + || { + IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary( + self.pbs_key(), + ) + }, + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary, + ) + } } impl Tagged for ServerKey { diff --git a/tfhe/src/integer/ciphertext/compact_list.rs b/tfhe/src/integer/ciphertext/compact_list.rs index dc0ffdfd92..5d076abb19 100644 --- a/tfhe/src/integer/ciphertext/compact_list.rs +++ b/tfhe/src/integer/ciphertext/compact_list.rs @@ -7,42 +7,102 @@ use crate::integer::backward_compatibility::ciphertext::ProvenCompactCiphertextL use crate::integer::block_decomposition::DecomposableInto; use crate::integer::encryption::{create_clear_radix_block_iterator, KnowsMessageModulus}; use crate::integer::parameters::CompactCiphertextListConformanceParams; -pub use crate::integer::parameters::{ - IntegerCompactCiphertextListCastingMode, IntegerCompactCiphertextListUnpackingMode, -}; +pub use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::integer::{CompactPublicKey, ServerKey}; #[cfg(feature = "zk-pok")] use crate::shortint::ciphertext::ProvenCompactCiphertextListConformanceParams; -use crate::shortint::parameters::CiphertextConformanceParams; -#[cfg(feature = "zk-pok")] -use crate::shortint::parameters::CompactCiphertextListExpansionKind; +use crate::shortint::parameters::{ + CastingFunctionsOwned, CiphertextConformanceParams, ShortintCompactCiphertextListCastingMode, +}; #[cfg(feature = "zk-pok")] use crate::shortint::parameters::{ - CarryModulus, CiphertextModulus, CompactPublicKeyEncryptionParameters, LweDimension, + CiphertextModulus, CompactCiphertextListExpansionKind, CompactPublicKeyEncryptionParameters, + LweDimension, }; -use crate::shortint::{Ciphertext, MessageModulus}; +use crate::shortint::{CarryModulus, Ciphertext, MessageModulus}; #[cfg(feature = "zk-pok")] -use crate::zk::CompactPkeCrs; +use crate::zk::{CompactPkeCrs, CompactPkePublicParams, ZkComputeLoad, ZkVerificationOutCome}; + use rayon::prelude::*; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; -#[cfg(feature = "zk-pok")] -use crate::zk::{CompactPkePublicParams, ZkComputeLoad, ZkVerificationOutCome}; +/// Unpack message and carries and additionally sanitizes blocks that correspond to boolean values +/// to make sure they encrypt a 0 or a 1. +fn unpack_and_sanitize_message_and_carries( + packed_blocks: Vec, + sks: &ServerKey, + infos: &[DataKind], +) -> Vec { + let IntegerUnpackingToShortintCastingModeHelper { + msg_extract, + carry_extract, + msg_extract_bool, + carry_extract_bool, + } = IntegerUnpackingToShortintCastingModeHelper::new( + sks.message_modulus(), + sks.carry_modulus(), + ); + let msg_extract = sks.key.generate_lookup_table(msg_extract); + let carry_extract = sks.key.generate_lookup_table(carry_extract); + let msg_extract_bool = sks.key.generate_lookup_table(msg_extract_bool); + let carry_extract_bool = sks.key.generate_lookup_table(carry_extract_bool); + + let block_count: usize = infos.iter().map(|x| x.num_blocks()).sum(); + let packed_block_count = block_count.div_ceil(2); + assert_eq!( + packed_block_count, + packed_blocks.len(), + "Internal error, invalid packed blocks count during unpacking of a compact ciphertext list." + ); + let mut functions = vec![[None; 2]; packed_block_count]; + + let mut overall_block_idx = 0; + + for data_kind in infos { + let block_count = data_kind.num_blocks(); + for _ in 0..block_count { + let is_in_msg_part = overall_block_idx % 2 == 0; + + let unpacking_function = if is_in_msg_part { + if matches!(data_kind, DataKind::Boolean) { + &msg_extract_bool + } else { + &msg_extract + } + } else if matches!(data_kind, DataKind::Boolean) { + &carry_extract_bool + } else { + &carry_extract + }; + + let packed_block_idx = overall_block_idx / 2; + let idx_in_packed_block = overall_block_idx % 2; + + functions[packed_block_idx][idx_in_packed_block] = Some(unpacking_function); + overall_block_idx += 1; + } + } -fn extract_message_and_carries(packed_blocks: Vec, sks: &ServerKey) -> Vec { packed_blocks .into_par_iter() - .flat_map(|block| { + .zip(functions.into_par_iter()) + .flat_map(|(block, extract_function)| { let mut low_block = block; let mut high_block = low_block.clone(); + let (msg_lut, carry_lut) = (extract_function[0], extract_function[1]); rayon::join( || { - sks.key.message_extract_assign(&mut low_block); + if let Some(msg_lut) = msg_lut { + sks.key.apply_lookup_table_assign(&mut low_block, msg_lut); + } }, || { - sks.key.carry_extract_assign(&mut high_block); + if let Some(carry_lut) = carry_lut { + sks.key + .apply_lookup_table_assign(&mut high_block, carry_lut); + } }, ); @@ -51,6 +111,54 @@ fn extract_message_and_carries(packed_blocks: Vec, sks: &ServerKey) .collect::>() } +/// This function sanitizes boolean blocks to make sure they encrypt a 0 or a 1 +fn sanitize_boolean_blocks( + packed_blocks: Vec, + sks: &ServerKey, + infos: &[DataKind], +) -> Vec { + let message_modulus = sks.message_modulus().0 as u64; + let msg_extract_bool = sks.key.generate_lookup_table(|x: u64| { + let tmp = x % message_modulus; + if tmp == 0 { + 0u64 + } else { + 1u64 + } + }); + + let block_count: usize = infos.iter().map(|x| x.num_blocks()).sum(); + let mut functions = vec![None; block_count]; + + let mut overall_block_idx = 0; + + for data_kind in infos { + let block_count = data_kind.num_blocks(); + for _ in 0..block_count { + let acc = if matches!(data_kind, DataKind::Boolean) { + Some(&msg_extract_bool) + } else { + None + }; + + functions[overall_block_idx] = acc; + overall_block_idx += 1; + } + } + + packed_blocks + .into_par_iter() + .zip(functions.into_par_iter()) + .map(|(mut block, sanitize_acc)| { + if let Some(sanitize_acc) = sanitize_acc { + sks.key.apply_lookup_table_assign(&mut block, sanitize_acc); + } + + block + }) + .collect::>() +} + pub trait Compactable { fn compact_into( self, @@ -337,9 +445,78 @@ impl ParameterSetConformant for CompactCiphertextList { pub const WRONG_UNPACKING_MODE_ERR_MSG: &str = "Cannot expand a CompactCiphertextList that requires unpacking without \ a server key, please provide a integer::ServerKey passing it with the \ - enum variant IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary \ + enum variant IntegerCompactCiphertextListExpansionMode::UnpackIfNecessary \ + or IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary \ as unpacking_mode."; +struct IntegerUnpackingToShortintCastingModeHelper { + msg_extract: Box u64 + Sync>, + carry_extract: Box u64 + Sync>, + msg_extract_bool: Box u64 + Sync>, + carry_extract_bool: Box u64 + Sync>, +} + +impl IntegerUnpackingToShortintCastingModeHelper { + pub fn new(message_modulus: MessageModulus, carry_modulus: CarryModulus) -> Self { + let message_modulus = message_modulus.0 as u64; + let carry_modulus = carry_modulus.0 as u64; + let msg_extract = Box::new(move |x: u64| x % message_modulus); + let carry_extract = Box::new(move |x: u64| (x / carry_modulus) % message_modulus); + let msg_extract_bool = Box::new(move |x: u64| { + let tmp = x % message_modulus; + u64::from(tmp != 0) + }); + let carry_extract_bool = Box::new(move |x: u64| { + let tmp = (x / carry_modulus) % message_modulus; + u64::from(tmp != 0) + }); + + Self { + msg_extract, + carry_extract, + msg_extract_bool, + carry_extract_bool, + } + } + + pub fn generate_function(&self, infos: &[DataKind]) -> CastingFunctionsOwned { + let block_count: usize = infos.iter().map(|x| x.num_blocks()).sum(); + let packed_block_count = block_count.div_ceil(2); + let mut functions = vec![Some(Vec::with_capacity(2)); packed_block_count]; + + let mut overall_block_idx = 0; + + for data_kind in infos { + let block_count = data_kind.num_blocks(); + for _ in 0..block_count { + let is_in_msg_part = overall_block_idx % 2 == 0; + + let unpacking_function: &(dyn Fn(u64) -> u64 + Sync) = if is_in_msg_part { + if matches!(data_kind, DataKind::Boolean) { + self.msg_extract_bool.as_ref() + } else { + self.msg_extract.as_ref() + } + } else if matches!(data_kind, DataKind::Boolean) { + self.carry_extract_bool.as_ref() + } else { + self.carry_extract.as_ref() + }; + + let packed_block_idx = overall_block_idx / 2; + + if let Some(block_fns) = functions[packed_block_idx].as_mut() { + block_fns.push(unpacking_function) + } + + overall_block_idx += 1; + } + } + + functions + } +} + impl CompactCiphertextList { pub fn is_packed(&self) -> bool { self.ct_list.degree.get() @@ -424,8 +601,8 @@ impl CompactCiphertextList { /// /// ```rust /// use tfhe::integer::ciphertext::{ - /// CompactCiphertextList, DataKind, IntegerCompactCiphertextListCastingMode, - /// IntegerCompactCiphertextListUnpackingMode, RadixCiphertext, SignedRadixCiphertext, + /// CompactCiphertextList, DataKind, IntegerCompactCiphertextListExpansionMode, + /// RadixCiphertext, SignedRadixCiphertext, /// }; /// use tfhe::integer::{ClientKey, CompactPublicKey}; /// use tfhe::shortint::parameters::classic::compact_pk::PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS; @@ -440,10 +617,7 @@ impl CompactCiphertextList { /// let mut compact_ct = CompactCiphertextList::builder(&pk).push(-1i8).build(); /// /// let sanity_check_expander = compact_ct - /// .expand( - /// IntegerCompactCiphertextListUnpackingMode::NoUnpacking, - /// IntegerCompactCiphertextListCastingMode::NoCasting, - /// ) + /// .expand(IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking) /// .unwrap(); /// let sanity_expanded = sanity_check_expander /// .get::(0) @@ -457,10 +631,7 @@ impl CompactCiphertextList { /// .unwrap(); /// /// let expander = compact_ct - /// .expand( - /// IntegerCompactCiphertextListUnpackingMode::NoUnpacking, - /// IntegerCompactCiphertextListCastingMode::NoCasting, - /// ) + /// .expand(IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking) /// .unwrap(); /// /// let expanded = expander.get::(0).unwrap().unwrap(); @@ -491,15 +662,14 @@ impl CompactCiphertextList { pub fn expand( &self, - unpacking_mode: IntegerCompactCiphertextListUnpackingMode<'_>, - casting_mode: IntegerCompactCiphertextListCastingMode<'_>, + expansion_mode: IntegerCompactCiphertextListExpansionMode<'_>, ) -> crate::Result { let is_packed = self.is_packed(); if is_packed && matches!( - unpacking_mode, - IntegerCompactCiphertextListUnpackingMode::NoUnpacking + expansion_mode, + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking ) { return Err(crate::Error::new(String::from( @@ -507,11 +677,35 @@ impl CompactCiphertextList { ))); } - let expanded_blocks = self.ct_list.expand(casting_mode.into())?; + let expanded_blocks = match expansion_mode { + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( + key_switching_key_view, + ) => { + let function_helper; + let functions; + let functions = if is_packed { + let dest_sks = &key_switching_key_view.key.dest_server_key; + function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); + functions = function_helper.generate_function(&self.info); + Some(functions.as_slice()) + } else { + None + }; + self.ct_list + .expand(ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key: key_switching_key_view.key, + functions, + })? + } + IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { + let expanded_blocks = self + .ct_list + .expand(ShortintCompactCiphertextListCastingMode::NoCasting)?; - let expanded_blocks = if is_packed { - match unpacking_mode { - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(sks) => { + if is_packed { let degree = self.ct_list.degree; let mut conformance_params = sks.key.conformance_params(); conformance_params.degree = degree; @@ -525,12 +719,14 @@ impl CompactCiphertextList { } } - extract_message_and_carries(expanded_blocks, sks) + unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) + } else { + sanitize_boolean_blocks(expanded_blocks, sks, &self.info) } - IntegerCompactCiphertextListUnpackingMode::NoUnpacking => unreachable!(), } - } else { - expanded_blocks + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self + .ct_list + .expand(ShortintCompactCiphertextListCastingMode::NoCasting)?, }; Ok(CompactCiphertextListExpander::new( @@ -597,15 +793,14 @@ impl ProvenCompactCiphertextList { public_params: &CompactPkePublicParams, public_key: &CompactPublicKey, metadata: &[u8], - unpacking_mode: IntegerCompactCiphertextListUnpackingMode<'_>, - casting_mode: IntegerCompactCiphertextListCastingMode<'_>, + expansion_mode: IntegerCompactCiphertextListExpansionMode<'_>, ) -> crate::Result { let is_packed = self.is_packed(); if is_packed && matches!( - unpacking_mode, - IntegerCompactCiphertextListUnpackingMode::NoUnpacking + expansion_mode, + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking ) { return Err(crate::Error::new(String::from( @@ -613,16 +808,42 @@ impl ProvenCompactCiphertextList { ))); } - let expanded_blocks = self.ct_list.verify_and_expand( - public_params, - &public_key.key, - metadata, - casting_mode.into(), - )?; - - let expanded_blocks = if is_packed { - match unpacking_mode { - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(sks) => { + let expanded_blocks = match expansion_mode { + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( + key_switching_key_view, + ) => { + let function_helper; + let functions; + let functions = if is_packed { + let dest_sks = &key_switching_key_view.key.dest_server_key; + function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); + functions = function_helper.generate_function(&self.info); + Some(functions.as_slice()) + } else { + None + }; + self.ct_list.verify_and_expand( + public_params, + &public_key.key, + metadata, + ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key: key_switching_key_view.key, + functions, + }, + )? + } + IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { + let expanded_blocks = self.ct_list.verify_and_expand( + public_params, + &public_key.key, + metadata, + ShortintCompactCiphertextListCastingMode::NoCasting, + )?; + + if is_packed { let degree = self.ct_list.proved_lists[0].0.degree; let mut conformance_params = sks.key.conformance_params(); conformance_params.degree = degree; @@ -636,12 +857,19 @@ impl ProvenCompactCiphertextList { } } - extract_message_and_carries(expanded_blocks, sks) + unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) + } else { + sanitize_boolean_blocks(expanded_blocks, sks, &self.info) } - IntegerCompactCiphertextListUnpackingMode::NoUnpacking => unreachable!(), } - } else { - expanded_blocks + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => { + self.ct_list.verify_and_expand( + public_params, + &public_key.key, + metadata, + ShortintCompactCiphertextListCastingMode::NoCasting, + )? + } }; Ok(CompactCiphertextListExpander::new( @@ -656,15 +884,14 @@ impl ProvenCompactCiphertextList { /// If you are here you were probably looking for it: use at your own risks. pub fn expand_without_verification( &self, - unpacking_mode: IntegerCompactCiphertextListUnpackingMode<'_>, - casting_mode: IntegerCompactCiphertextListCastingMode<'_>, + expansion_mode: IntegerCompactCiphertextListExpansionMode<'_>, ) -> crate::Result { let is_packed = self.is_packed(); if is_packed && matches!( - unpacking_mode, - IntegerCompactCiphertextListUnpackingMode::NoUnpacking + expansion_mode, + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking ) { return Err(crate::Error::new(String::from( @@ -672,13 +899,36 @@ impl ProvenCompactCiphertextList { ))); } - let expanded_blocks = self - .ct_list - .expand_without_verification(casting_mode.into())?; + let expanded_blocks = match expansion_mode { + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( + key_switching_key_view, + ) => { + let function_helper; + let functions; + let functions = if is_packed { + let dest_sks = &key_switching_key_view.key.dest_server_key; + function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); + functions = function_helper.generate_function(&self.info); + Some(functions.as_slice()) + } else { + None + }; + self.ct_list.expand_without_verification( + ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key: key_switching_key_view.key, + functions, + }, + )? + } + IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { + let expanded_blocks = self.ct_list.expand_without_verification( + ShortintCompactCiphertextListCastingMode::NoCasting, + )?; - let expanded_blocks = if is_packed { - match unpacking_mode { - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(sks) => { + if is_packed { let degree = self.ct_list.proved_lists[0].0.degree; let mut conformance_params = sks.key.conformance_params(); conformance_params.degree = degree; @@ -692,12 +942,14 @@ impl ProvenCompactCiphertextList { } } - extract_message_and_carries(expanded_blocks, sks) + unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) + } else { + sanitize_boolean_blocks(expanded_blocks, sks, &self.info) } - IntegerCompactCiphertextListUnpackingMode::NoUnpacking => unreachable!(), } - } else { - expanded_blocks + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self + .ct_list + .expand_without_verification(ShortintCompactCiphertextListCastingMode::NoCasting)?, }; Ok(CompactCiphertextListExpander::new( @@ -797,9 +1049,7 @@ impl ParameterSetConformant for ProvenCompactCiphertextList { mod tests { use crate::integer::ciphertext::CompactCiphertextList; use crate::integer::key_switching_key::KeySwitchingKey; - use crate::integer::parameters::{ - IntegerCompactCiphertextListCastingMode, IntegerCompactCiphertextListUnpackingMode, - }; + use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::integer::{ ClientKey, CompactPrivateKey, CompactPublicKey, RadixCiphertext, ServerKey, }; @@ -843,8 +1093,64 @@ mod tests { crs.public_params(), &pk, &metadata, - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(&sk), - IntegerCompactCiphertextListCastingMode::CastIfNecessary(ksk.as_view()), + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(ksk.as_view()), + ) + .unwrap(); + + for (idx, msg) in msgs.iter().copied().enumerate() { + let expanded = expander.get::(idx).unwrap().unwrap(); + let decrypted = cks.decrypt_radix::(&expanded); + assert_eq!(msg, decrypted); + } + + let unverified_expander = proven_ct + .expand_without_verification( + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(ksk.as_view()), + ) + .unwrap(); + + for (idx, msg) in msgs.iter().copied().enumerate() { + let expanded = unverified_expander + .get::(idx) + .unwrap() + .unwrap(); + let decrypted = cks.decrypt_radix::(&expanded); + assert_eq!(msg, decrypted); + } + } + + #[test] + fn test_several_proven_lists() { + let pke_params = PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + let ksk_params = PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + let fhe_params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + + let metadata = [b'i', b'n', b't', b'e', b'g', b'e', b'r']; + + let crs_blocks_for_64_bits = + 64 / ((pke_params.message_modulus.0 * pke_params.carry_modulus.0).ilog2() as usize); + let encryption_num_blocks = 64 / (pke_params.message_modulus.0.ilog2() as usize); + + let crs = CompactPkeCrs::from_shortint_params(pke_params, crs_blocks_for_64_bits).unwrap(); + let cks = ClientKey::new(fhe_params); + let sk = ServerKey::new_radix_server_key(&cks); + let compact_private_key = CompactPrivateKey::new(pke_params); + let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), ksk_params); + let pk = CompactPublicKey::new(&compact_private_key); + + let msgs = (0..2).map(|_| random::()).collect::>(); + + let proven_ct = CompactCiphertextList::builder(&pk) + .extend_with_num_blocks(msgs.iter().copied(), encryption_num_blocks) + .build_with_proof_packed(crs.public_params(), &metadata, ZkComputeLoad::Proof) + .unwrap(); + + let expander = proven_ct + .verify_and_expand( + crs.public_params(), + &pk, + &metadata, + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(ksk.as_view()), ) .unwrap(); @@ -856,8 +1162,7 @@ mod tests { let unverified_expander = proven_ct .expand_without_verification( - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(&sk), - IntegerCompactCiphertextListCastingMode::CastIfNecessary(ksk.as_view()), + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(ksk.as_view()), ) .unwrap(); diff --git a/tfhe/src/integer/ciphertext/utils.rs b/tfhe/src/integer/ciphertext/utils.rs index c11720f5b9..15c9660a6d 100644 --- a/tfhe/src/integer/ciphertext/utils.rs +++ b/tfhe/src/integer/ciphertext/utils.rs @@ -7,7 +7,9 @@ use tfhe_versionable::Versionize; #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize)] #[versionize(DataKindVersions)] pub enum DataKind { + /// The held value is a number of radix blocks. Unsigned(usize), + /// The held value is a number of radix blocks. Signed(usize), Boolean, } diff --git a/tfhe/src/integer/key_switching_key/test.rs b/tfhe/src/integer/key_switching_key/test.rs index fbc0d00357..fe41c6538b 100644 --- a/tfhe/src/integer/key_switching_key/test.rs +++ b/tfhe/src/integer/key_switching_key/test.rs @@ -1,8 +1,6 @@ use crate::integer::key_switching_key::KeySwitchingKey; use crate::integer::keycache::KEY_CACHE; -use crate::integer::parameters::{ - IntegerCompactCiphertextListCastingMode, IntegerCompactCiphertextListUnpackingMode, -}; +use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::integer::{ ClientKey, CompactPrivateKey, CompactPublicKey, CrtClientKey, IntegerCiphertext, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey, @@ -197,13 +195,15 @@ fn test_case_cpk_encrypt_cast_compute( // Encrypt a value and cast let ct1 = pk.encrypt_radix_compact(input_msg, num_block); let expander = ct1 - .expand( - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(&sks_fhe), - IntegerCompactCiphertextListCastingMode::CastIfNecessary(ksk.as_view()), - ) + .expand(IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(ksk.as_view())) .unwrap(); let mut ct1_extracted_and_cast = expander.get::(0).unwrap().unwrap(); + assert!(ct1_extracted_and_cast + .blocks() + .iter() + .all(|x| x.degree.get() == sks_fhe.message_modulus().0 - 1)); + let sanity_pbs: u64 = cks_fhe.decrypt_radix(&ct1_extracted_and_cast); assert_eq!(sanity_pbs, input_msg); diff --git a/tfhe/src/integer/parameters/mod.rs b/tfhe/src/integer/parameters/mod.rs index 787523f1ac..84bb683037 100644 --- a/tfhe/src/integer/parameters/mod.rs +++ b/tfhe/src/integer/parameters/mod.rs @@ -2,10 +2,8 @@ use crate::conformance::ListSizeConstraint; use crate::integer::key_switching_key::KeySwitchingKeyView; use crate::integer::server_key::ServerKey; -use crate::shortint::parameters::compact_public_key_only::CompactCiphertextListCastingMode; use crate::shortint::parameters::{ CarryModulus, CiphertextConformanceParams, EncryptionKeyChoice, MessageModulus, - ShortintCompactCiphertextListCastingMode, }; pub use crate::shortint::parameters::{ DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension, @@ -14,26 +12,13 @@ pub use crate::shortint::parameters::{ use crate::shortint::PBSParameters; pub use crate::shortint::{CiphertextModulus, ClassicPBSParameters, WopbsParameters}; -pub type IntegerCompactCiphertextListCastingMode<'key> = - CompactCiphertextListCastingMode>; - -impl<'key> From> - for ShortintCompactCiphertextListCastingMode<'key> -{ - fn from(value: IntegerCompactCiphertextListCastingMode<'key>) -> Self { - match value { - IntegerCompactCiphertextListCastingMode::CastIfNecessary(integer_key) => { - Self::CastIfNecessary(integer_key.key) - } - IntegerCompactCiphertextListCastingMode::NoCasting => Self::NoCasting, - } - } -} - #[derive(Clone, Copy)] -pub enum IntegerCompactCiphertextListUnpackingMode<'key> { - UnpackIfNecessary(&'key ServerKey), - NoUnpacking, +pub enum IntegerCompactCiphertextListExpansionMode<'key> { + /// The [`KeySwitchingKeyView`] has all the information to both cast and unpack. + CastAndUnpackIfNecessary(KeySwitchingKeyView<'key>), + /// This only allows to unpack. + UnpackAndSanitizeIfNecessary(&'key ServerKey), + NoCastingAndNoUnpacking, } pub const ALL_PARAMETER_VEC_INTEGER_16_BITS: [WopbsParameters; 2] = [ diff --git a/tfhe/src/integer/public_key/tests.rs b/tfhe/src/integer/public_key/tests.rs index 9e80255b9f..36233aa27e 100644 --- a/tfhe/src/integer/public_key/tests.rs +++ b/tfhe/src/integer/public_key/tests.rs @@ -1,7 +1,5 @@ use crate::integer::keycache::KEY_CACHE; -use crate::integer::parameters::{ - IntegerCompactCiphertextListCastingMode, IntegerCompactCiphertextListUnpackingMode, -}; +use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::integer::tests::create_parametrized_test; use crate::integer::{gen_keys, CompressedPublicKey, IntegerKeyKind, PublicKey, RadixCiphertext}; use crate::shortint::parameters::classic::compact_pk::*; @@ -129,16 +127,9 @@ fn radix_encrypt_decrypt_compact_128_bits_list(params: ClassicPBSParameters) { assert!(compact_lists[1].is_packed()); for compact_encrypted_list in compact_lists { - let unpacking_mode = if compact_encrypted_list.is_packed() { - IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(&sks) - } else { - IntegerCompactCiphertextListUnpackingMode::NoUnpacking - }; - let expander = compact_encrypted_list .expand( - unpacking_mode, - IntegerCompactCiphertextListCastingMode::NoCasting, + IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(&sks), ) .unwrap(); diff --git a/tfhe/src/shortint/ciphertext/compact_list.rs b/tfhe/src/shortint/ciphertext/compact_list.rs index d932837b7f..cd05d7613f 100644 --- a/tfhe/src/shortint/ciphertext/compact_list.rs +++ b/tfhe/src/shortint/ciphertext/compact_list.rs @@ -6,7 +6,6 @@ use crate::conformance::ParameterSetConformant; use crate::core_crypto::commons::traits::ContiguousEntityContainer; use crate::core_crypto::entities::*; use crate::shortint::backward_compatibility::ciphertext::CompactCiphertextListVersions; -use crate::shortint::parameters::compact_public_key_only::CompactCiphertextListCastingMode; pub use crate::shortint::parameters::ShortintCompactCiphertextListCastingMode; use crate::shortint::parameters::{ CarryModulus, CompactCiphertextListExpansionKind, MessageModulus, @@ -62,7 +61,8 @@ impl CompactCiphertextList { /// Expand a [`CompactCiphertextList`] to a `Vec` of [`Ciphertext`]. /// /// The function takes a [`ShortintCompactCiphertextListCastingMode`] to indicate whether a - /// keyswitch should be applied during expansion. + /// keyswitch should be applied during expansion, and if it does, functions can be applied as + /// well during casting, which can be more efficient if a refresh is required during casting. /// /// This is useful when using separate parameters for the public key used to encrypt the /// [`CompactCiphertextList`] allowing to keyswitch to the computation params during expansion. @@ -94,7 +94,7 @@ impl CompactCiphertextList { match (self.expansion_kind, casting_mode) { ( CompactCiphertextListExpansionKind::RequiresCasting, - CompactCiphertextListCastingMode::NoCasting, + ShortintCompactCiphertextListCastingMode::NoCasting, ) => Err(crate::Error::new(String::from( "Cannot expand a CompactCiphertextList that requires casting without casting, \ please provide a shortint::KeySwitchingKey passing it with the enum variant \ @@ -102,13 +102,32 @@ impl CompactCiphertextList { ))), ( CompactCiphertextListExpansionKind::RequiresCasting, - CompactCiphertextListCastingMode::CastIfNecessary(casting_key), + ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key, + functions, + }, ) => { + let functions = match functions { + Some(functions) => { + if functions.len() != output_lwe_ciphertext_list.lwe_ciphertext_count().0 { + return Err(crate::Error::new(format!( + "Cannot expand a CompactCiphertextList: got {} functions for casting, \ + expected {}", + functions.len(), + output_lwe_ciphertext_list.lwe_ciphertext_count().0 + ))); + } + functions + } + None => &vec![None; output_lwe_ciphertext_list.lwe_ciphertext_count().0], + }; + let pbs_order = casting_key.dest_server_key.pbs_order; let res = output_lwe_ciphertext_list .par_iter() - .map(|lwe_view| { + .zip(functions.par_iter()) + .flat_map(|(lwe_view, functions)| { let lwe_to_cast = LweCiphertext::from_container( lwe_view.as_ref().to_vec(), self.ct_list.ciphertext_modulus(), @@ -122,7 +141,8 @@ impl CompactCiphertextList { noise_level: self.noise_level, }; - casting_key.cast(&shortint_ct_to_cast) + casting_key + .cast_and_apply_functions(&shortint_ct_to_cast, functions.as_deref()) }) .collect::>(); Ok(res) diff --git a/tfhe/src/shortint/ciphertext/zk.rs b/tfhe/src/shortint/ciphertext/zk.rs index de33d1313f..83be4121ad 100644 --- a/tfhe/src/shortint/ciphertext/zk.rs +++ b/tfhe/src/shortint/ciphertext/zk.rs @@ -104,10 +104,58 @@ impl ProvenCompactCiphertextList { &self, casting_mode: ShortintCompactCiphertextListCastingMode<'_>, ) -> crate::Result> { + let per_list_casting_mode: Vec<_> = match casting_mode { + ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key, + functions, + } => match functions { + Some(functions) => { + // For how many ciphertexts we have functions + let functions_sets_count = functions.len(); + let total_ciphertext_count: usize = self + .proved_lists + .iter() + .map(|list| list.0.ct_list.lwe_ciphertext_count().0) + .sum(); + + if functions_sets_count != total_ciphertext_count { + return Err(crate::Error::new(format!( + "Cannot expand a CompactCiphertextList: got {functions_sets_count} \ + sets of functions for casting, expected {total_ciphertext_count}" + ))); + } + + let mut modes = vec![]; + let mut functions_used_so_far = 0; + for list in self.proved_lists.iter() { + let blocks_in_list = list.0.ct_list.lwe_ciphertext_count().0; + + let functions_to_use = &functions + [functions_used_so_far..functions_used_so_far + blocks_in_list]; + + modes.push(ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key, + functions: Some(functions_to_use), + }); + + functions_used_so_far += blocks_in_list; + } + modes + } + None => vec![ + ShortintCompactCiphertextListCastingMode::NoCasting; + self.proved_lists.len() + ], + }, + ShortintCompactCiphertextListCastingMode::NoCasting => { + vec![ShortintCompactCiphertextListCastingMode::NoCasting; self.proved_lists.len()] + } + }; let expanded = self .proved_lists .iter() - .map(|(ct_list, _proof)| ct_list.expand(casting_mode)) + .zip(per_list_casting_mode.into_iter()) + .map(|((ct_list, _proof), casting_mode)| ct_list.expand(casting_mode)) .collect::>, _>>()? .into_iter() .flatten() diff --git a/tfhe/src/shortint/key_switching_key/mod.rs b/tfhe/src/shortint/key_switching_key/mod.rs index dd9145afa4..3803c63c6f 100644 --- a/tfhe/src/shortint/key_switching_key/mod.rs +++ b/tfhe/src/shortint/key_switching_key/mod.rs @@ -16,6 +16,7 @@ use crate::shortint::parameters::{ use crate::shortint::server_key::apply_programmable_bootstrap; use crate::shortint::{Ciphertext, ClientKey, CompressedServerKey, ServerKey}; use core::cmp::Ordering; +use rayon::prelude::*; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; @@ -485,6 +486,21 @@ impl<'keys> KeySwitchingKeyView<'keys> { /// assert_eq!(ck2.decrypt(&cipher_2), cleartext); /// ``` pub fn cast(&self, input_ct: &Ciphertext) -> Ciphertext { + let res = self.cast_and_apply_functions(input_ct, None); + assert_eq!(res.len(), 1); + res.into_iter().next().unwrap() + } + + /// Cast a ciphertext from the source parameter set to the dest parameter set, + /// returning a new ciphertext. + /// + /// If None is provided then an identity function is used and tighter degrees are used where + /// applicable. + pub fn cast_and_apply_functions( + &self, + input_ct: &Ciphertext, + functions: Option<&[&(dyn Fn(u64) -> u64 + Sync)]>, + ) -> Vec { let output_lwe_size = match self.key_switching_key_material.destination_key { EncryptionKeyChoice::Big => self .dest_server_key @@ -538,100 +554,211 @@ impl<'keys> KeySwitchingKeyView<'keys> { ); keyswitched.degree = pre_processed.degree; + let degree_after_keyswitch = keyswitched.degree; + + enum CastCiphertext { + CorrectKey(Ciphertext), + WrongKeyRequiresPBS(Ciphertext), + } + // Manage the destination key adjustment - let mut res = { + let res = { let destination_pbs_order: PBSOrder = self.key_switching_key_material.destination_key.into(); if destination_pbs_order == self.dest_server_key.pbs_order { - keyswitched + CastCiphertext::CorrectKey(keyswitched) } else { - let wrong_key_ct = keyswitched; - let mut correct_key_ct = self.dest_server_key.create_trivial(0); - correct_key_ct.degree = wrong_key_ct.degree; - correct_key_ct.set_noise_level(wrong_key_ct.noise_level()); - // We are arriving under the wrong key for the dest_server_key match self.key_switching_key_material.destination_key { // Big to Small == keyswitch EncryptionKeyChoice::Big => { + let wrong_key_ct = keyswitched; + let mut correct_key_ct = self.dest_server_key.create_trivial(0); + correct_key_ct.degree = wrong_key_ct.degree; + correct_key_ct.set_noise_level(wrong_key_ct.noise_level()); + keyswitch_lwe_ciphertext( &self.dest_server_key.key_switching_key, &wrong_key_ct.ct, &mut correct_key_ct.ct, ); + + CastCiphertext::CorrectKey(correct_key_ct) } - // Small to Big == PBS - EncryptionKeyChoice::Small => { - ShortintEngine::with_thread_local_mut(|engine| { - let acc = self.dest_server_key.generate_lookup_table(|x| x); - let (_, buffers) = engine.get_buffers(self.dest_server_key); - apply_programmable_bootstrap( - &self.dest_server_key.bootstrapping_key, - &wrong_key_ct.ct, - &mut correct_key_ct.ct, - &acc.acc, - buffers, - ); - }); - // Degree does not need to be updated as we apply an Identity LUT and we - // apply only the bootstrap directly on the underlying ciphertext, we have - // to update the noise however. - correct_key_ct.set_noise_level(NoiseLevel::NOMINAL); - } + // Small to Big == PBS, we handle this in the last part of the function to apply + // the refresh and the user functions in similar ways and keep the code easier + // to maintain + EncryptionKeyChoice::Small => CastCiphertext::WrongKeyRequiresPBS(keyswitched), } - - correct_key_ct } }; - let degree_after_keyswitch = res.degree; + let output_ciphertext_count = functions.map_or_else(|| 1, |x| x.len()); + let mut output_cts = vec![self.dest_server_key.create_trivial(0); output_ciphertext_count]; + let identity_fn_array: &[&(dyn Fn(u64) -> u64 + Sync)] = &[&|x: u64| x]; + let functions_to_use = functions.map_or_else(|| identity_fn_array, |fns| fns); + let using_user_provided_functions = functions.is_some(); + let using_identity_lut = !using_user_provided_functions; + match cast_rshift.cmp(&0) { - // Same bit size: only key switch + // Same bit size Ordering::Equal => { - // Refresh if we haven't applied a PBS yet - if res.noise_level() == NoiseLevel::UNKNOWN { - let acc = self.dest_server_key.generate_lookup_table(|x| x); - self.dest_server_key - .apply_lookup_table_assign(&mut res, &acc); - // We apply an Identity LUT so we know a tighter bound than the worst case LUT - // value - res.degree = degree_after_keyswitch; + // Refresh or apply user functions if provided + match res { + CastCiphertext::CorrectKey(ciphertext) => { + output_cts + .par_iter_mut() + .zip(functions_to_use.par_iter()) + .for_each(|(correct_key_ct, function)| { + let acc = self.dest_server_key.generate_lookup_table(function); + *correct_key_ct = + self.dest_server_key.apply_lookup_table(&ciphertext, &acc); + // If we apply an Identity LUT we know a tighter bound than the + // worst case LUT value + if using_identity_lut { + correct_key_ct.degree = degree_after_keyswitch; + } + }); + } + CastCiphertext::WrongKeyRequiresPBS(wrong_key_ct) => { + output_cts + .par_iter_mut() + .zip(functions_to_use.par_iter()) + .for_each(|(correct_key_ct, function)| { + ShortintEngine::with_thread_local_mut(|engine| { + let (_, buffers) = engine.get_buffers(self.dest_server_key); + let acc = self.dest_server_key.generate_lookup_table(function); + apply_programmable_bootstrap( + &self.dest_server_key.bootstrapping_key, + &wrong_key_ct.ct, + &mut correct_key_ct.ct, + &acc.acc, + buffers, + ); + + // Update degree depending on the LUT used (as this is a PBS and + // not a full apply lookup table) + if using_user_provided_functions { + correct_key_ct.degree = acc.degree; + } else { + correct_key_ct.degree = degree_after_keyswitch; + } + // Update the noise as well + correct_key_ct.set_noise_level(NoiseLevel::NOMINAL); + }); + }); + } } } - // Cast to bigger bit length: keyswitch, then right shift + // Cast to bigger bit length: keyswitch, then right shift, combine this with user + // function for better efficiency Ordering::Greater => { - let acc = self - .dest_server_key - .generate_lookup_table(|n| n >> cast_rshift); - self.dest_server_key - .apply_lookup_table_assign(&mut res, &acc); - // degree and noise are updated by the apply lookup table + match res { + CastCiphertext::CorrectKey(ciphertext) => { + output_cts + .par_iter_mut() + .zip(functions_to_use.par_iter()) + .for_each(|(correct_key_ct, function)| { + let acc = self + .dest_server_key + .generate_lookup_table(|n| function(n >> cast_rshift)); + *correct_key_ct = + self.dest_server_key.apply_lookup_table(&ciphertext, &acc); + // degree and noise are updated by the apply lookup table + }); + } + CastCiphertext::WrongKeyRequiresPBS(wrong_key_ct) => { + output_cts + .par_iter_mut() + .zip(functions_to_use.par_iter()) + .for_each(|(correct_key_ct, function)| { + ShortintEngine::with_thread_local_mut(|engine| { + let (_, buffers) = engine.get_buffers(self.dest_server_key); + let acc = self.dest_server_key.generate_lookup_table(|n| { + // Call the function on the shifted arrival + // value + function(n >> cast_rshift) + }); + apply_programmable_bootstrap( + &self.dest_server_key.bootstrapping_key, + &wrong_key_ct.ct, + &mut correct_key_ct.ct, + &acc.acc, + buffers, + ); + // Update degree and noise as it's a raw PBS + correct_key_ct.degree = acc.degree; + correct_key_ct.set_noise_level(NoiseLevel::NOMINAL); + }); + }); + } + } } - // Cast to smaller bit length: left shift, then keyswitch + // Cast to smaller bit length: left shift, then keyswitch, then refresh or apply user + // function. Ordering::Less => { - // The degree is high in the source plaintext modulus, but smaller in the arriving - // one. - // - // src 4 bits: - // 0 | XX | 11 - // shifted: - // 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12 - // dst 2 bits : - // 0 | 11 -> 11 = 3 - let new_degree = Degree::new(degree_after_keyswitch.get() >> -cast_rshift); - // Refresh if we haven't applied a PBS yet - if res.noise_level() == NoiseLevel::UNKNOWN { - let acc = self.dest_server_key.generate_lookup_table(|x| x); - self.dest_server_key - .apply_lookup_table_assign(&mut res, &acc); + match res { + CastCiphertext::CorrectKey(ciphertext) => { + output_cts + .par_iter_mut() + .zip(functions_to_use.par_iter()) + .for_each(|(correct_key_ct, function)| { + let acc = self.dest_server_key.generate_lookup_table(function); + *correct_key_ct = + self.dest_server_key.apply_lookup_table(&ciphertext, &acc); + + if using_user_provided_functions { + correct_key_ct.degree = acc.degree; + } else { + // Note that this relies on the fact that the left shift degree + // is in degree_after_keyswitch. + // The degree is high in the source plaintext modulus, but + // smaller in the arriving one. + // + // src 4 bits: + // 0 | XX | 11 + // shifted: + // 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12 + // dst 2 bits : + // 0 | 11 -> 11 = 3 + let new_degree = + Degree::new(degree_after_keyswitch.get() >> -cast_rshift); + correct_key_ct.degree = new_degree; + } + }); + } + CastCiphertext::WrongKeyRequiresPBS(wrong_key_ct) => { + output_cts + .par_iter_mut() + .zip(functions_to_use.par_iter()) + .for_each(|(correct_key_ct, function)| { + ShortintEngine::with_thread_local_mut(|engine| { + let (_, buffers) = engine.get_buffers(self.dest_server_key); + let acc = self.dest_server_key.generate_lookup_table(function); + apply_programmable_bootstrap( + &self.dest_server_key.bootstrapping_key, + &wrong_key_ct.ct, + &mut correct_key_ct.ct, + &acc.acc, + buffers, + ); + if using_user_provided_functions { + correct_key_ct.degree = acc.degree; + } else { + let new_degree = Degree::new( + degree_after_keyswitch.get() >> -cast_rshift, + ); + correct_key_ct.degree = new_degree; + } + correct_key_ct.set_noise_level(NoiseLevel::NOMINAL); + }); + }); + } } - // Apply the degree correction, even if we bootstrapped as the Identity LUT would - // not change this correction - res.degree = new_degree; } } - res + output_cts } } diff --git a/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs b/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs index 8cc1f2917d..115bb1b8e3 100644 --- a/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs +++ b/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs @@ -21,15 +21,20 @@ pub enum CompactCiphertextListExpansionKind { NoCasting(PBSOrder), } -#[derive(Clone, Copy, Debug)] -pub enum CompactCiphertextListCastingMode { - CastIfNecessary(K), +pub type CastingFunctionsOwned<'functions> = + Vec u64 + Sync)>>>; +pub type CastingFunctionsView<'functions> = + &'functions [Option u64 + Sync)>>]; + +#[derive(Clone, Copy)] +pub enum ShortintCompactCiphertextListCastingMode<'a> { + CastIfNecessary { + casting_key: KeySwitchingKeyView<'a>, + functions: Option>, + }, NoCasting, } -pub type ShortintCompactCiphertextListCastingMode<'key> = - CompactCiphertextListCastingMode>; - impl From for CompactCiphertextListExpansionKind { fn from(value: PBSOrder) -> Self { Self::NoCasting(value) diff --git a/tfhe/src/shortint/parameters/mod.rs b/tfhe/src/shortint/parameters/mod.rs index 2a72d523c3..0c0fd9ea3d 100644 --- a/tfhe/src/shortint/parameters/mod.rs +++ b/tfhe/src/shortint/parameters/mod.rs @@ -44,8 +44,8 @@ pub use crate::shortint::parameters::classic::tuniform::p_fail_2_minus_64::ks_pb pub use crate::shortint::parameters::classic::tuniform::p_fail_2_minus_64::pbs_ks::*; pub use crate::shortint::parameters::list_compression::CompressionParameters; pub use compact_public_key_only::{ - CompactCiphertextListExpansionKind, CompactPublicKeyEncryptionParameters, - ShortintCompactCiphertextListCastingMode, + CastingFunctionsOwned, CastingFunctionsView, CompactCiphertextListExpansionKind, + CompactPublicKeyEncryptionParameters, ShortintCompactCiphertextListCastingMode, }; #[cfg(tarpaulin)] pub use coverage_parameters::*;