Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CiphertextMatrix/extractDenseRow #66

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 155 additions & 6 deletions Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@ import HomomorphicEncryption

/// Stores a matrix of scalars as ciphertexts.
struct CiphertextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendable {
typealias Packing = PlaintextMatrixPacking
typealias Dimensions = MatrixDimensions

/// Dimensions of the scalars.
@usableFromInline let dimensions: Dimensions
@usableFromInline let dimensions: MatrixDimensions

/// Dimensions of the scalar matrix in a SIMD-encoded plaintext.
@usableFromInline let simdDimensions: SimdEncodingDimensions

/// Plaintext packing with which the data is stored.
@usableFromInline let packing: Packing
@usableFromInline let packing: MatrixPacking

/// Encrypted data.
@usableFromInline let ciphertexts: [Ciphertext<Scheme, Format>]
Expand Down Expand Up @@ -59,7 +56,7 @@ struct CiphertextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendab
/// - ciphertexts: ciphertexts encrypting the data; must not be empty.
/// - Throws: Error upon failure to initialize the ciphertext matrix.
@inlinable
init(dimensions: Dimensions, packing: Packing, ciphertexts: [Ciphertext<Scheme, Format>]) throws {
init(dimensions: MatrixDimensions, packing: MatrixPacking, ciphertexts: [Ciphertext<Scheme, Format>]) throws {
guard let context = ciphertexts.first?.context else {
throw PnnsError.emptyCiphertextArray
}
Expand Down Expand Up @@ -128,3 +125,155 @@ extension CiphertextMatrix {
ciphertexts: coeffCiphertexts)
}
}

extension CiphertextMatrix {
/// Computes the evaluation key configuration for calling `extractDenseRow`.
/// - Parameters:
/// - encryptionParams: Encryption parameters; must support `.simd` encoding.
/// - dimensions: Dimensions of the matrix to call `extractDenseRow` on.
/// - Returns: The evaluation key configuration.
/// - Throws: Error upon failure to generate the evaluation key configuration.
@inlinable
static func extractDenseRowConfig(for encryptionParams: EncryptionParameters<Scheme>,
dimensions: MatrixDimensions) throws -> EvaluationKeyConfiguration
{
if dimensions.rowCount == 1 {
// extractDenseRow is a No-op, so no evaluation key required
return EvaluationKeyConfiguration()
}
guard let simdDimensions = encryptionParams.simdDimensions else {
throw PnnsError.simdEncodingNotSupported(for: encryptionParams)
}
let degree = encryptionParams.polyDegree
var galoisElements = [GaloisElement.swappingRows(degree: degree)]
let columnCountPowerOfTwo = dimensions.columnCount.nextPowerOfTwo
if columnCountPowerOfTwo != simdDimensions.columnCount {
try galoisElements.append(GaloisElement.rotatingColumns(by: columnCountPowerOfTwo, degree: degree))
}
return EvaluationKeyConfiguration(galoisElements: galoisElements)
}

/// Extracts a ciphertext matrix with a single row and `.denseRow` packing.
/// - Parameters:
/// - rowIndex: Row index to extract
/// - evaluationKey: Evaluation key; must have `CiphertextMatrix/extractDenseRow` configuration
/// - Returns: A ciphertext matrix in `.denseRow` format with 1 row
/// - Throws: Error upon failure to extract the row.
@inlinable
func extractDenseRow(rowIndex: Int, evaluationKey: EvaluationKey<Scheme>) throws -> Self
where Format == Scheme.CanonicalCiphertextFormat
{
precondition((0..<dimensions.rowCount).contains(rowIndex))
guard packing == .denseRow else {
throw PnnsError.wrongMatrixPacking(got: packing, expected: .denseRow)
}
precondition(simdDimensions.rowCount == 2, "SIMD row count must be 2")

let columnCountPowerOfTwo = dimensions.columnCount.nextPowerOfTwo
let degree = context.degree.nextPowerOfTwo
let rowsPerSimdRow = simdDimensions.columnCount / columnCount
let rowsPerCiphertext = rowsPerSimdRow * simdDimensions.rowCount
let ciphertextIndex = rowIndex / rowsPerCiphertext
if rowCount == 1 {
return self
}

// Suppose, e.g., N=16, columnCount = 2, and the ciphertext data encrypts 2 rows: [1, 2] and [3, 4].
// These rows are packed in the ciphertext SIMD simd rows as
// [[1, 2, 3, 4, 1, 2, 3, 4],
// [1, 2, 3, 4, 1, 2, 3, 4]].
// Suppose ciphertextRowIndex == 1, i.e., we want to return an encryption of
// [[3, 4, 3, 4, 3, 4, 3, 4], [3, 4, 3, 4, 3, 4, 3, 4]]

// Returns the SIMD slot indices for the `rowIndex`'th row of the ciphertext matrix.
func simdSlotIndices(rowIndex: Int) -> Range<Int> {
precondition((0..<dimensions.rowCount).contains(rowIndex))
let ciphertextRowIndex = rowIndex % rowsPerCiphertext
let batchStart = ciphertextRowIndex * columnCountPowerOfTwo
var batchIndices = (batchStart..<batchStart + columnCountPowerOfTwo)
// Ensure no repeated values span multiple SIMD rows
let overflowsSimdRow = batchIndices.contains(simdDimensions.columnCount)
if overflowsSimdRow {
batchIndices = (simdDimensions.columnCount..<simdDimensions.columnCount + columnCountPowerOfTwo)
} else if batchIndices.upperBound > simdDimensions.columnCount {
let padding = simdColumnCount % columnCountPowerOfTwo
batchIndices = (batchIndices.startIndex + padding..<batchIndices.endIndex + padding)
}
// The last ciphertext pads until the end of the ciphertext.
if ciphertextIndex == ciphertexts.indices.last {
let upperBound = batchIndices.endIndex.nextMultiple(of: simdDimensions.columnCount, variableTime: true)
batchIndices = batchIndices.startIndex..<upperBound
}
return batchIndices
}
let batchIndices = simdSlotIndices(rowIndex: rowIndex)

// The number of rows in covered by `batchIndices`
let rowCountInBatch = {
var lastRowIndexInBatch = rowIndex + 1
while lastRowIndexInBatch < dimensions.rowCount,
simdSlotIndices(rowIndex: lastRowIndexInBatch).upperBound == batchIndices.upperBound
{
lastRowIndexInBatch += 1
}
var firstRowIndexInBatch = rowIndex > 0 ? rowIndex - 1 : 0
while firstRowIndexInBatch > 0,
simdSlotIndices(rowIndex: firstRowIndexInBatch).upperBound == batchIndices.upperBound
{
firstRowIndexInBatch -= 1
}
return lastRowIndexInBatch - firstRowIndexInBatch
}()

// First, we mask out just the ciphertext data row vector e.g.,
// plaintextMask = [[0, 0, 1, 1, 0, 0, 1, 1],
// [0, 0, 0, 0, 0, 0, 0, 0]]
let (plaintextMask, copiesInMask) = try {
var repeatMask = Array(repeating: Scheme.Scalar(1), count: columnCountPowerOfTwo)
repeatMask += Array(repeating: 0, count: columnCountPowerOfTwo * (rowCountInBatch - 1))
// pad to next power of two
repeatMask += Array(repeating: 0, count: repeatMask.count.nextPowerOfTwo - repeatMask.count)

var mask = Array(repeating: Scheme.Scalar(0), count: batchIndices.lowerBound)
var repeatCountInMask = 0
while mask.count < batchIndices.upperBound {
mask += repeatMask
repeatCountInMask += 1
}
mask = Array(mask.prefix(degree))
let plaintext: Plaintext<Scheme, Eval> = try context.encode(values: mask, format: .simd)
return (plaintext, repeatCountInMask)
}()

var ciphertextEval = try ciphertexts[ciphertextIndex].convertToEvalFormat()
try ciphertextEval *= plaintextMask
var ciphertext = try ciphertextEval.convertToCanonicalFormat()
// e.g., `ciphertext` now encrypts
// [[0, 0, 3, 4, 0, 0, 3, 4],
// [0, 0, 0, 0, 0, 0, 0, 0]]

// Replicate the values across one SIMD row by rotating and adding.
let rotateCount = simdColumnCount / (copiesInMask * columnCountPowerOfTwo) - 1
var ciphertextCopyRight = ciphertext
for _ in 0..<rotateCount {
try ciphertextCopyRight.rotateColumns(by: columnCountPowerOfTwo, using: evaluationKey)
try ciphertext += ciphertextCopyRight
}
// e.g., `ciphertext` now encrypts
// [[3, 4, 3, 4, 3, 4, 3, 4],
// [0, 0, 0, 0, 0, 0, 0, 0]]

// Duplicate values to both SIMD rows
var ciphertextCopy = ciphertext
try ciphertextCopy.swapRows(using: evaluationKey)
try ciphertext += ciphertextCopy
// e.g., `ciphertext` now encrypts
// [[3, 4, 3, 4, 3, 4, 3, 4],
// [3, 4, 3, 4, 3, 4, 3, 4]]

return try CiphertextMatrix(
dimensions: MatrixDimensions(rowCount: 1, columnCount: columnCount),
packing: packing,
ciphertexts: [ciphertext])
}
}
8 changes: 3 additions & 5 deletions Sources/PrivateNearestNeighborsSearch/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ enum PnnsError: Error, Equatable {
case wrongCiphertextCount(got: Int, expected: Int)
case wrongContext(gotDescription: String, expectedDescription: String)
case wrongEncodingValuesCount(got: Int, expected: Int)
case wrongMatrixPacking(got: MatrixPacking, expected: MatrixPacking)
case wrongPlaintextCount(got: Int, expected: Int)
case wrongPlaintextMatrixPacking(
got: PlaintextMatrixPacking,
expected: PlaintextMatrixPacking)
}

extension PnnsError {
Expand Down Expand Up @@ -59,10 +57,10 @@ extension PnnsError: LocalizedError {
"Wrong context: got \(gotDescription), expected \(expectedDescription)"
case let .wrongEncodingValuesCount(got, expected):
"Wrong encoding values count \(got), expected \(expected)"
case let .wrongMatrixPacking(got: got, expected: expected):
"Wrong matrix packing \(got), expected \(expected)"
case let .wrongPlaintextCount(got, expected):
"Wrong plaintext count \(got), expected \(expected)"
case let .wrongPlaintextMatrixPacking(got: got, expected: expected):
"Wrong plaintext matrix packing \(got), expected \(expected)"
}
}
}
42 changes: 21 additions & 21 deletions Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import Algorithms
import HomomorphicEncryption

/// Different algorithms for packing a matrix of scalar values into plaintexts.
enum PlaintextMatrixPacking: Codable, Equatable, Hashable, Sendable {
/// Different algorithms for packing a matrix of scalar values into plaintexts / ciphertexts.
enum MatrixPacking: Codable, Equatable, Hashable, Sendable {
/// As many columns of data are packed sequentially into each plaintext SIMD row as possible, such that no SIMD row
/// contains data from multiple columns.
case denseColumn
Expand Down Expand Up @@ -64,17 +64,14 @@ struct MatrixDimensions: Equatable, Sendable {

/// Stores a matrix of scalars as plaintexts.
struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendable {
typealias Packing = PlaintextMatrixPacking
typealias Dimensions = MatrixDimensions

/// Dimensions of the scalars.
@usableFromInline let dimensions: Dimensions
@usableFromInline let dimensions: MatrixDimensions

/// Dimensions of the scalar matrix in a SIMD-encoded plaintext.
let simdDimensions: SimdEncodingDimensions

/// Plaintext packing with which the data is stored.
@usableFromInline let packing: Packing
@usableFromInline let packing: MatrixPacking

/// Plaintexts encoding the scalars.
let plaintexts: [Plaintext<Scheme, Format>]
Expand Down Expand Up @@ -107,7 +104,7 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
/// - plaintexts: Plaintexts encoding the data; must not be empty.
/// - Throws: Error upon failure to initialize the plaintext matrix.
@inlinable
init(dimensions: Dimensions, packing: Packing, plaintexts: [Plaintext<Scheme, Format>]) throws {
init(dimensions: MatrixDimensions, packing: MatrixPacking, plaintexts: [Plaintext<Scheme, Format>]) throws {
guard !plaintexts.isEmpty else {
throw PnnsError.emptyPlaintextArray
}
Expand Down Expand Up @@ -143,7 +140,11 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
/// - values: The data values to store in the plaintext matrix; stored in row-major format.
/// - Throws: Error upon failure to create the plaitnext matrix.
@inlinable
init(context: Context<Scheme>, dimensions: Dimensions, packing: Packing, values: [some ScalarType]) throws
init(
context: Context<Scheme>,
dimensions: MatrixDimensions,
packing: MatrixPacking,
values: [some ScalarType]) throws
where Format == Coeff
{
guard values.count == dimensions.count, !values.isEmpty else {
Expand Down Expand Up @@ -182,8 +183,8 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
@inlinable
static func plaintextCount(
encryptionParameters: EncryptionParameters<Scheme>,
dimensions: PlaintextMatrix.Dimensions,
packing: PlaintextMatrix.Packing) throws -> Int
dimensions: MatrixDimensions,
packing: MatrixPacking) throws -> Int
{
guard let simdDimensions = encryptionParameters.simdDimensions else {
throw PnnsError.simdEncodingNotSupported(for: encryptionParameters)
Expand Down Expand Up @@ -219,7 +220,7 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
/// - Returns: The plaintexts for `denseColumn` packing.
/// - Throws: Error upon plaintext to compute the plaintexts.
@inlinable
static func denseColumnPlaintexts<V: ScalarType>(context: Context<Scheme>, dimensions: Dimensions,
static func denseColumnPlaintexts<V: ScalarType>(context: Context<Scheme>, dimensions: MatrixDimensions,
values: [V]) throws -> [Scheme.CoeffPlaintext]
{
let degree = context.degree
Expand Down Expand Up @@ -277,14 +278,14 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
@inlinable
static func denseRowPlaintexts<V: ScalarType>(
context: Context<Scheme>,
dimensions: Dimensions,
dimensions: MatrixDimensions,
values: [V]) throws -> [Plaintext<Scheme, Coeff>]
{
let encryptionParameters = context.encryptionParameters
guard let simdDimensions = context.simdDimensions else {
throw PnnsError.simdEncodingNotSupported(for: encryptionParameters)
}
precondition(simdDimensions.rowCount == 2, "simdRowCount must be 2")
precondition(simdDimensions.rowCount == 2, "SIMD row count must be 2")
let simdColumnCount = simdDimensions.columnCount
guard dimensions.columnCount <= simdColumnCount else {
throw PnnsError.invalidMatrixDimensions(dimensions)
Expand Down Expand Up @@ -351,8 +352,8 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
@inlinable
static func diagonalPlaintexts<V: ScalarType>(
context: Context<Scheme>,
dimensions: Dimensions,
packing: PlaintextMatrixPacking,
dimensions: MatrixDimensions,
packing: MatrixPacking,
values: [V]) throws -> [Scheme.CoeffPlaintext]
{
let encryptionParameters = context.encryptionParameters
Expand All @@ -367,8 +368,7 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
}
guard case let .diagonal(bsgs) = packing else {
let expectedBsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount)
throw PnnsError
.wrongPlaintextMatrixPacking(got: packing, expected: .diagonal(babyStepGiantStep: expectedBsgs))
throw PnnsError.wrongMatrixPacking(got: packing, expected: .diagonal(babyStepGiantStep: expectedBsgs))
}

let data = Array2d(data: values, rowCount: dimensions.rowCount, columnCount: dimensions.columnCount)
Expand Down Expand Up @@ -439,11 +439,11 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
@inlinable
func unpackDenseColumn<V: ScalarType>() throws -> [V] where Format == Coeff {
guard case packing = .denseColumn else {
throw PnnsError.wrongPlaintextMatrixPacking(got: packing, expected: .denseColumn)
throw PnnsError.wrongMatrixPacking(got: packing, expected: .denseColumn)
}
let simdColumnCount = simdDimensions.columnCount
let simdRowCount = simdDimensions.rowCount
precondition(simdRowCount == 2, "simdRowCount must be 2")
precondition(simdRowCount == 2, "SIMD row count must be 2")
let columnsPerPlaintextCount = simdRowCount * (simdColumnCount / rowCount)

var valuesColumnMajor: [V] = []
Expand Down Expand Up @@ -483,7 +483,7 @@ struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendabl
@inlinable
func unpackDenseRow<V: ScalarType>() throws -> [V] where Format == Coeff {
guard case packing = .denseRow else {
throw PnnsError.wrongPlaintextMatrixPacking(got: packing, expected: Packing.denseRow)
throw PnnsError.wrongMatrixPacking(got: packing, expected: MatrixPacking.denseRow)
}
let simdColumnCount = simdDimensions.columnCount

Expand Down
Loading
Loading