Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Diagonal Unpacking for PlaintextMatrix #102

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
case .denseRow:
return try unpackDenseRow()
case .diagonal:
// TODO: Implement
preconditionFailure("Unpacking diagonal plaintext matrix not supported")
return try unpackDiagonal()
}
}

Expand Down Expand Up @@ -578,6 +577,59 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
return values
}

/// Unpacks a plaintext matrix with `diagonal` packing.
/// - Returns: The stored data values in row-major format.
/// - Throws: Error upon failure to unpack the matrix.
@inlinable
func unpackDiagonal() throws -> [Scheme.Scalar] where Format == Coeff {
guard case let .diagonal(babyStepGiantStep) = packing else {
let expectedBabyStepGiantStep = BabyStepGiantStep(vectorDimension: columnCount)
throw PnnsError.wrongMatrixPacking(
got: packing,
expected: .diagonal(babyStepGiantStep: expectedBabyStepGiantStep))
}
var packedValues = Array2d<Scheme.Scalar>.zero(rowCount: 0, columnCount: rowCount)
let expectedPlaintextCount = try PlaintextMatrix.plaintextCount(
encryptionParameters: context.encryptionParameters,
dimensions: dimensions,
packing: packing)
let plaintextsPerColumn = expectedPlaintextCount / columnCount.nextPowerOfTwo
let middle = context.degree / 2

for (chunkIndex, babyStepChunk) in plaintexts.chunks(ofCount: babyStepGiantStep.babyStep * plaintextsPerColumn)
.enumerated()
{
let rotationStep = chunkIndex * babyStepGiantStep.babyStep
let rotated: [[Scheme.Scalar]] = try babyStepChunk.map { plaintext in
var decodedValues: [Scheme.Scalar] = try plaintext.decode(format: .simd)
decodedValues[0..<middle].rotate(toStartAt: rotationStep)
decodedValues[middle...].rotate(toStartAt: middle + rotationStep)
return decodedValues
}
let diagonals = rotated.chunks(ofCount: plaintextsPerColumn).map { diagonalChunks in
diagonalChunks.flatMap { $0 }[0..<rowCount]
}
packedValues.append(rows: diagonals.flatMap { $0 })
}
var values = Array2d<Scheme.Scalar>.zero(rowCount: rowCount, columnCount: columnCount)
let columnNextPowerOfTwo = columnCount.nextPowerOfTwo
var valuesCount = 0
for rowIndex in 0..<packedValues.rowCount {
for columnIndex in 0..<packedValues.columnCount {
let valuesRowIndex = columnIndex
let valuesColumnIndex = (rowIndex + columnIndex) % columnNextPowerOfTwo
if valuesColumnIndex < columnCount {
values[valuesRowIndex, valuesColumnIndex] = packedValues[rowIndex, columnIndex]
valuesCount += 1
}
}
}
guard valuesCount == count else {
throw PnnsError.wrongEncodingValuesCount(got: valuesCount, expected: count)
}
return values.data
}

/// Symmetric secret key encryption of the plaintext matrix.
/// - Parameter secretKey: Secret key to encrypt with.
/// - Returns: A ciphertext encrypting the plaintext matrix.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,7 @@ final class PlaintextMatrixTests: XCTestCase {
XCTAssertEqual(plaintextMatrix.packing, packing)
XCTAssertEqual(plaintextMatrix.context, context)
// Test round-trip
switch packing {
case .diagonal: // TODO: test .diagonal once implemented
break
default:
XCTAssertEqual(try plaintextMatrix.unpack(), encodeValues.flatMap { $0 })
}
XCTAssertEqual(try plaintextMatrix.unpack(), encodeValues.flatMap { $0 })

// Test representation
XCTAssertEqual(plaintextMatrix.plaintexts.count, expected.count)
Expand All @@ -155,45 +150,40 @@ final class PlaintextMatrixTests: XCTestCase {
}

// Test signed encoding/decoding
switch packing {
case .diagonal: // TODO: test .diagonal once implemented
break
default:
let signedValues: [Scheme.SignedScalar] = try plaintextMatrix.unpack()
let signedMatrix = try PlaintextMatrix<Scheme, Coeff>(
context: context,
dimensions: dimensions,
packing: packing,
signedValues: signedValues)
let signedRoundtrip: [Scheme.SignedScalar] = try signedMatrix.unpack()
XCTAssertEqual(signedRoundtrip, signedValues)
let signedValues: [Scheme.SignedScalar] = try plaintextMatrix.unpack()
let signedMatrix = try PlaintextMatrix<Scheme, Coeff>(
context: context,
dimensions: dimensions,
packing: packing,
signedValues: signedValues)
let signedRoundtrip: [Scheme.SignedScalar] = try signedMatrix.unpack()
XCTAssertEqual(signedRoundtrip, signedValues)

// Test modular reduction
let largerValues = encodeValues.flatMap { $0 }.map { $0 + t }
let largerSignedValues = signedValues.enumerated().map { index, value in
if index.isMultiple(of: 2) {
value + Scheme.SignedScalar(t)
} else {
value - Scheme.SignedScalar(t)
}
// Test modular reduction
let largerValues = encodeValues.flatMap { $0 }.map { $0 + t }
let largerSignedValues = signedValues.enumerated().map { index, value in
if index.isMultiple(of: 2) {
value + Scheme.SignedScalar(t)
} else {
value - Scheme.SignedScalar(t)
}
}

let largerPlaintextMatrix = try PlaintextMatrix<Scheme, Coeff>(
context: context,
dimensions: dimensions,
packing: packing,
values: largerValues,
reduce: true)
XCTAssertEqual(largerPlaintextMatrix, plaintextMatrix)
let largerPlaintextMatrix = try PlaintextMatrix<Scheme, Coeff>(
context: context,
dimensions: dimensions,
packing: packing,
values: largerValues,
reduce: true)
XCTAssertEqual(largerPlaintextMatrix, plaintextMatrix)

let largerSignedMatrix = try PlaintextMatrix<Scheme, Coeff>(
context: context,
dimensions: dimensions,
packing: packing,
signedValues: largerSignedValues,
reduce: true)
XCTAssertEqual(largerSignedMatrix, signedMatrix)
}
let largerSignedMatrix = try PlaintextMatrix<Scheme, Coeff>(
context: context,
dimensions: dimensions,
packing: packing,
signedValues: largerSignedValues,
reduce: true)
XCTAssertEqual(largerSignedMatrix, signedMatrix)
}

func testPlaintextMatrixDenseColumn() throws {
Expand Down
Loading