Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster bytesToCoefficientInplace -> faster deserialization #96

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 70 additions & 40 deletions Sources/HomomorphicEncryption/CoefficientPacking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
public enum CoefficientPacking {}

extension CoefficientPacking {
/// Checks `bitsPerCoeff` and `skipLSBs` are valid.
/// - Parameters:
/// - bitsPerCoeff: Number of bits in each coefficient.
/// - skipLSBs: How many least-significant bits from each coefficient are omitted from serialization to bytes.
/// - Throws: Error upon invalid `bitsPerCoeff` or `skipLSBs` arguments.
@inlinable
static func validate(bitsPerCoeff: Int, skipLSBs: Int) throws {
guard bitsPerCoeff > 0, bitsPerCoeff > skipLSBs, skipLSBs >= 0 else {
throw HeError.invalidCoefficientPacking(bitsPerCoeff: bitsPerCoeff, skipLSBs: skipLSBs)
}
}

@inlinable
static func bytesToCoefficientsCoeffCount(byteCount: Int, bitsPerCoeff: Int, decode: Bool,
skipLSBs: Int = 0) -> Int
Expand All @@ -39,10 +51,11 @@ extension CoefficientPacking {
/// - skipLSBs: How many least-significant bits from each coefficient are assumed to be 0, and not present in
/// `bytes`.
/// - Returns: The deserialized coefficients.
/// - Throws: Error upon failure to convert bytes to coefficients.
/// - seealso: ``CoefficientPacking/coefficientsToBytes(coeffs:bitsPerCoeff:skipLSBs:)``
@inlinable
public static func bytesToCoefficients<T: ScalarType>(bytes: [UInt8], bitsPerCoeff: Int, decode: Bool,
skipLSBs: Int = 0) -> [T]
skipLSBs: Int = 0) throws -> [T]
{
var coeffs: [T] = .init(
repeating: 0,
Expand All @@ -51,56 +64,76 @@ extension CoefficientPacking {
bitsPerCoeff: bitsPerCoeff,
decode: decode,
skipLSBs: skipLSBs))
bytesToCoefficientsInplace(bytes: bytes, coeffs: &coeffs, bitsPerCoeff: bitsPerCoeff, skipLSBs: skipLSBs)
try bytesToCoefficientsInplace(bytes: bytes, coeffs: &coeffs, bitsPerCoeff: bitsPerCoeff, skipLSBs: skipLSBs)
return coeffs
}

/// Converts an sequence of bytes into coefficients, unused bits in the last coefficient will be set to zero.
@inlinable
static func bytesToCoefficientsInplace<T, C>(
bytes: some Sequence<UInt8>,
coeffs: inout C,
static func bytesToCoefficientsInplace<C, T, B>(
bytes: B,
coeffs coeffsCollection: inout C,
bitsPerCoeff: Int,
skipLSBs: Int = 0)
skipLSBs: Int = 0) throws
where T: ScalarType,
C: MutableCollection,
C.Element == T,
C.Index == Int
B: Collection, B.Element == UInt8, B.Index == Int,
C: MutableCollection, C.Element == T, C.Index == Int
{
precondition(bitsPerCoeff > 0)
precondition(bitsPerCoeff > skipLSBs)

typealias BufferType = UInt64
precondition(T.bitWidth <= BufferType.bitWidth)
try validate(bitsPerCoeff: bitsPerCoeff, skipLSBs: skipLSBs)
let serializedBitCount = bitsPerCoeff - skipLSBs
var coeffIndex = coeffs.startIndex
var coeff: T = 0
var remainingCoeffBits = serializedBitCount

// consume bytes and populate coefficients
for byte in bytes {
var remainingBits = UInt8.bitWidth
var byte = byte
repeat {
let shift = min(remainingBits, remainingCoeffBits)
coeff &<<= shift
coeff |= T(byte &>> (UInt8.bitWidth - shift))
byte = byte &<< shift
remainingCoeffBits &-= shift
remainingBits &-= shift
let foundContiguousBuffer: ()? = coeffsCollection.withContiguousMutableStorageIfAvailable { coeffs in
var coeffIndex = 0
var unusedBitCount = 0
var unusedBits: BufferType = 0
// Read bytes into BufferType.
// Bits from a coefficient will be in at most two buffers
let bytesPerBuffer = BufferType.bitWidth / UInt8.bitWidth
for byteChunkIndex in stride(from: bytes.startIndex, to: bytes.endIndex, by: bytesPerBuffer) {
let coeffStartIndex = byteChunkIndex
let endIndex = min(coeffStartIndex &+ bytesPerBuffer, bytes.endIndex)
let buffer = BufferType(bigEndianBytes: bytes[coeffStartIndex..<endIndex])
var newBitsCount = 0

if remainingCoeffBits == 0 {
remainingCoeffBits = serializedBitCount
coeffs[coeffIndex] = coeff &<< skipLSBs
// Deal with unused bits from previous round
if unusedBitCount != 0 {
newBitsCount = serializedBitCount &- unusedBitCount
var coeff = buffer &>> (BufferType.bitWidth &- newBitsCount)
coeff |= (unusedBits &<< newBitsCount)
coeffs[coeffIndex] = T(coeff &<< skipLSBs)
coeffIndex &+= 1
coeff = 0
}
} while remainingBits > 0

// Parse as many complete coefficients from the buffer as possible
let remainingCoeffs = coeffs.count &- coeffIndex
let coeffsPerFullCoeff = (BufferType.bitWidth &- newBitsCount) / serializedBitCount
for i in 0..<min(remainingCoeffs, coeffsPerFullCoeff) {
let msbsToClear = newBitsCount &+ i &* serializedBitCount
var coeff = buffer &<< msbsToClear
let lsbsToClear = BufferType.bitWidth &- serializedBitCount
coeff &>>= lsbsToClear
coeffs[coeffIndex] = T(coeff &<< skipLSBs)
coeffIndex &+= 1
}
unusedBitCount &+= (BufferType.bitWidth % serializedBitCount)
if unusedBitCount >= serializedBitCount {
unusedBitCount &-= serializedBitCount
}
unusedBits = buffer & ((1 &<< unusedBitCount) &- 1)
}
// Process unused bits from the last coefficient
if coeffIndex < coeffs.count {
let coeff = unusedBits &<< (serializedBitCount &- unusedBitCount &+ skipLSBs)
coeffs[coeffIndex] = T(coeff)
coeffIndex &+= 1
}
precondition(coeffIndex == coeffs.count)
}
if coeffIndex < coeffs.endIndex {
coeff &<<= (remainingCoeffBits &+ skipLSBs)
coeffs[coeffIndex] = coeff
coeffIndex &+= 1
guard foundContiguousBuffer != nil else {
throw HeError.serializationBufferNotContiguous
}
precondition(coeffIndex == coeffs.endIndex)
}

@inlinable
Expand Down Expand Up @@ -142,9 +175,7 @@ extension CoefficientPacking {
C.Element == UInt8,
C.Index == Int
{
precondition(bitsPerCoeff > 0)
precondition(bitsPerCoeff > skipLSBs)

try validate(bitsPerCoeff: bitsPerCoeff, skipLSBs: skipLSBs)
var byteIndex = 0
let bytesCount = bytes.count
let serializedBitCount = bitsPerCoeff - skipLSBs
Expand Down Expand Up @@ -174,7 +205,6 @@ extension CoefficientPacking {
remainingBits &-= shift
} while remainingCoeffBits > 0
}

if byteIndex < bytesCount {
byte &<<= remainingBits
bytesPtr[byteIndex] = byte
Expand Down
3 changes: 3 additions & 0 deletions Sources/HomomorphicEncryption/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public enum HeError: Error, Equatable {
case insecureEncryptionParameters(_ description: String)
case invalidCiphertext(_ description: String)
case invalidCoefficientIndex(index: Int, degree: Int)
case invalidCoefficientPacking(bitsPerCoeff: Int, skipLSBs: Int)
case invalidContext(_ description: String)
case invalidCorrectionFactor(_ description: String)
case invalidDegree(_ degree: Int)
Expand Down Expand Up @@ -203,6 +204,8 @@ extension HeError: LocalizedError {
"Insecure encryption parameters \(description)"
case let .invalidCoefficientIndex(index, degree):
"Invalid coefficient index \(index) for degree \(degree)"
case let .invalidCoefficientPacking(bitsPerCoeff, skipLSBs):
"Invalid coefficint packing: bitsPerCoeff \(bitsPerCoeff), skipLSBs \(skipLSBs)"
case let .invalidCiphertext(description):
"\(description)"
case let .invalidContext(description):
Expand Down
2 changes: 1 addition & 1 deletion Sources/HomomorphicEncryption/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ extension HeScheme {
}

let galoisElements = Array(galoisKey.keys.keys)
let steps = try GaloisElement.stepsFor(elements: galoisElements, degree: degree).values.compactMap { $0 }
let steps = GaloisElement.stepsFor(elements: galoisElements, degree: degree).values.compactMap { $0 }

let positiveStep = if step < 0 {
step + degree / 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ extension PolyRq {
}

let bytes = buffer[offset..<(offset &+ byteCount)]
CoefficientPacking.bytesToCoefficientsInplace(
try CoefficientPacking.bytesToCoefficientsInplace(
bytes: bytes,
coeffs: &data.data[polyIndices(rnsIndex: rnsIndex)],
bitsPerCoeff: bitsPerCoeff,
Expand Down
4 changes: 2 additions & 2 deletions Sources/PrivateInformationRetrieval/MulPir.swift
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ extension MulPirServer {
return nil
}
let bytes = Array(entry[startIndex..<endIndex])
let coefficients: [Scheme.Scalar] = CoefficientPacking.bytesToCoefficients(
let coefficients: [Scheme.Scalar] = try CoefficientPacking.bytesToCoefficients(
bytes: bytes,
bitsPerCoeff: context.plaintextModulus.log2,
decode: false)
Expand Down Expand Up @@ -502,7 +502,7 @@ extension MulPirServer {
.map { startIndex in
let endIndex = min(startIndex + bytesPerPlaintext, flatDatabase.count)
let values = Array(flatDatabase[startIndex..<endIndex])
let plaintextCoefficients: [Scheme.Scalar] = CoefficientPacking.bytesToCoefficients(
let plaintextCoefficients: [Scheme.Scalar] = try CoefficientPacking.bytesToCoefficients(
bytes: values,
bitsPerCoeff: context.plaintextModulus.log2,
decode: false)
Expand Down
10 changes: 5 additions & 5 deletions Sources/PrivateInformationRetrieval/PirUtil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ enum PirUtil<Scheme: HeScheme> {
return try plaintexts.map { plaintext in try plaintext.encrypt(using: secretKey) }
}

static func encodeDatabase<Scalar: ScalarType>(database: [[UInt8]], plaintextModulus: Scalar) -> [[Scalar]] {
database.map { entry in
CoefficientPacking.bytesToCoefficients(bytes: entry,
bitsPerCoeff: plaintextModulus.log2,
decode: false)
static func encodeDatabase<Scalar: ScalarType>(database: [[UInt8]], plaintextModulus: Scalar) throws -> [[Scalar]] {
try database.map { entry in
try CoefficientPacking.bytesToCoefficients(bytes: entry,
bitsPerCoeff: plaintextModulus.log2,
decode: false)
}
}
}
14 changes: 7 additions & 7 deletions Tests/HomomorphicEncryptionTests/CoefficientPackingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class CoefficientPackingTests: XCTestCase {
var rng = NistAes128Ctr()
rng.fill(&bytes)

let coeffs: [T] = CoefficientPacking.bytesToCoefficients(
let coeffs: [T] = try CoefficientPacking.bytesToCoefficients(
bytes: bytes,
bitsPerCoeff: log2t,
decode: false)
Expand Down Expand Up @@ -64,7 +64,7 @@ class CoefficientPackingTests: XCTestCase {
}

let bytes = try CoefficientPacking.coefficientsToBytes(coeffs: coeffs, bitsPerCoeff: log2t + 1)
let decodedCoeffs: [T] = CoefficientPacking.bytesToCoefficients(
let decodedCoeffs: [T] = try CoefficientPacking.bytesToCoefficients(
bytes: bytes,
bitsPerCoeff: log2t + 1,
decode: true)
Expand All @@ -76,7 +76,7 @@ class CoefficientPackingTests: XCTestCase {
try runTest(UInt64.self)
}

func testBytesToCoeffKAT() {
func testBytesToCoeffKAT() throws {
struct BytesToCoeffKAT<T: ScalarType> {
let bytes: [UInt8]
let bitsPerCoeff: Int
Expand All @@ -85,7 +85,7 @@ class CoefficientPackingTests: XCTestCase {
let expectedCoefficients: [T]
}

func runTest<T: ScalarType>(_: T.Type) {
func runTest<T: ScalarType>(_: T.Type) throws {
let kats: [BytesToCoeffKAT<T>] = [
BytesToCoeffKAT(
bytes: [3, 24, 95, 141, 179, 34, 113],
Expand Down Expand Up @@ -132,7 +132,7 @@ class CoefficientPackingTests: XCTestCase {
]

for kat in kats {
let coeffs: [T] = CoefficientPacking.bytesToCoefficients(
let coeffs: [T] = try CoefficientPacking.bytesToCoefficients(
bytes: kat.bytes,
bitsPerCoeff: kat.bitsPerCoeff,
decode: kat.decode,
Expand All @@ -141,8 +141,8 @@ class CoefficientPackingTests: XCTestCase {
}
}

runTest(UInt32.self)
runTest(UInt64.self)
try runTest(UInt32.self)
try runTest(UInt64.self)
}

func testCoeffsToBytesKAT() throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ final class GaloisTests: XCTestCase {
func testGaloisElementsToSteps() throws {
let galoisElements = [2, 3, 9, 11]
let degree = 8
let result = try GaloisElement.stepsFor(elements: galoisElements, degree: degree)
let result = GaloisElement.stepsFor(elements: galoisElements, degree: degree)
let expected = [2: nil, 3: 3, 9: 2, 11: 1]
XCTAssertEqual(result, expected)

Expand All @@ -119,7 +119,7 @@ final class GaloisTests: XCTestCase {
try elementToStep[GaloisElement.rotatingColumns(by: step, degree: degree)] = step
}

let result = try GaloisElement.stepsFor(elements: Array(elementToStep.keys), degree: degree)
let result = GaloisElement.stepsFor(elements: Array(elementToStep.keys), degree: degree)
XCTAssertEqual(result, elementToStep)
}
}
Expand Down
Loading