Skip to content

Commit

Permalink
Merge pull request #35 from pluto/feat/byte-pack-aes
Browse files Browse the repository at this point in the history
feat: byte pack NIVC
  • Loading branch information
devloper authored Nov 7, 2024
2 parents 7adbc34 + 6bf1a03 commit 8035a0d
Show file tree
Hide file tree
Showing 16 changed files with 274 additions and 149 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ all: build
build:
@for circuit in $(CIRCOM_FILES); do \
echo "Processing $${circuit}..."; \
circom "$${circuit}" --r1cs -o "$$(dirname $${circuit})/artifacts" -l node_modules; \
circom "$${circuit}" --r1cs --wasm -o "$$(dirname $${circuit})/artifacts" -l node_modules; \
build-circuit "$${circuit}" "$$(dirname $${circuit})/artifacts/$$(basename $${circuit} .circom).bin" -l node_modules; \
done

Expand Down
130 changes: 19 additions & 111 deletions circuits/aes-gcm/nivc/aes-gctr-nivc.circom
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ template AESGCTRFOLD(DATA_BYTES) {
// ------------------------------------------------------------------------------------------------------------------ //
// ~~ Set sizes at compile time ~~
assert(DATA_BYTES % 16 == 0);
// Value for accumulating both plaintext and ciphertext as well as counter
var TOTAL_BYTES_ACROSS_NIVC = 2 * DATA_BYTES + 4;
// Value for accumulating both packed plaintext and ciphertext as well as counter
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;
// ------------------------------------------------------------------------------------------------------------------ //


Expand All @@ -29,7 +29,7 @@ template AESGCTRFOLD(DATA_BYTES) {
// We extract the number from the 4 byte word counter
component last_counter_bits = BytesToBits(4);
for(var i = 0; i < 4; i ++) {
last_counter_bits.in[i] <== step_in[DATA_BYTES * 2 + i];
last_counter_bits.in[i] <== step_in[DATA_BYTES + i];
}
component last_counter_num = Bits2Num(32);
// pass in reverse order
Expand All @@ -46,124 +46,32 @@ template AESGCTRFOLD(DATA_BYTES) {
aes.plainText <== plainText;

for(var i = 0; i < 4; i++) {
aes.lastCounter[i] <== step_in[DATA_BYTES * 2 + i];
aes.lastCounter[i] <== step_in[DATA_BYTES + i];
}


// Write out the plaintext and ciphertext to our accumulation arrays, both at once.
signal prevAccumulatedPlaintext[DATA_BYTES];
for(var i = 0 ; i < DATA_BYTES ; i++) {
prevAccumulatedPlaintext[i] <== step_in[i];
signal textToPack[16][2];
for(var i = 0 ; i < 16 ; i++) {
textToPack[i][0] <== plainText[i];
textToPack[i][1] <== aes.cipherText[i];
}
signal prevAccumulatedCiphertext[DATA_BYTES];
signal nextPackedChunk[16] <== GenericBytePackArray(16,2)(textToPack);

signal prevAccumulatedPackedText[DATA_BYTES];
for(var i = 0 ; i < DATA_BYTES ; i++) {
prevAccumulatedCiphertext[i] <== step_in[DATA_BYTES + i];
prevAccumulatedPackedText[i] <== step_in[i];
}
component nextTexts = WriteToIndexForTwoArrays(DATA_BYTES, 16);
nextTexts.first_array_to_write_to <== prevAccumulatedPlaintext;
nextTexts.second_array_to_write_to <== prevAccumulatedCiphertext;
nextTexts.first_array_to_write_at_index <== plainText;
nextTexts.second_array_to_write_at_index <== aes.cipherText;
nextTexts.index <== index * 16;

component nextAccumulatedPackedText = WriteToIndex(DATA_BYTES, 16);
nextAccumulatedPackedText.array_to_write_to <== prevAccumulatedPackedText;
nextAccumulatedPackedText.array_to_write_at_index <== nextPackedChunk;
nextAccumulatedPackedText.index <== index * 16;

for(var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
if(i < DATA_BYTES) {
step_out[i] <== nextTexts.outFirst[i];
} else if(i < 2 * DATA_BYTES) {
step_out[i] <== nextTexts.outSecond[i - DATA_BYTES];
} else if(i < 2 * DATA_BYTES + 4) {
step_out[i] <== aes.counter[i - (2 * DATA_BYTES)];
step_out[i] <== nextAccumulatedPackedText.out[i];
} else {
step_out[i] <== aes.counter[i - DATA_BYTES];
}
}
}



template WriteToIndexForTwoArrays(m, n) {
signal input first_array_to_write_to[m];
signal input second_array_to_write_to[m];
signal input first_array_to_write_at_index[n];
signal input second_array_to_write_at_index[n];
signal input index;

signal output outFirst[m];
signal output outSecond[m];

assert(m >= n);

// Note: this is underconstrained, we need to constrain that index + n <= m
// Need to constrain that index + n <= m -- can't be an assertion, because uses a signal
// ------------------------- //

// Here, we get an array of ALL zeros, except at the `index` AND `index + n`
// beginning-------^^^^^ end---^^^^^^^^^
signal indexMatched[m];
component indexBegining[m];
component indexEnding[m];
for(var i = 0 ; i < m ; i++) {
indexBegining[i] = IsZero();
indexBegining[i].in <== i - index;
indexEnding[i] = IsZero();
indexEnding[i].in <== i - (index + n);
indexMatched[i] <== indexBegining[i].out + indexEnding[i].out;
}

// E.g., index == 31, m == 160, n == 16
// => indexMatch[31] == 1;
// => indexMatch[47] == 1;
// => otherwise, all 0.

signal accum[m];
accum[0] <== indexMatched[0];

component writeAt = IsZero();
writeAt.in <== accum[0] - 1;

component orFirst = OR();
orFirst.a <== (writeAt.out * first_array_to_write_at_index[0]);
orFirst.b <== (1 - writeAt.out) * first_array_to_write_to[0];
outFirst[0] <== orFirst.out;

component orSecond = OR();
orSecond.a <== (writeAt.out * second_array_to_write_at_index[0]);
orSecond.b <== (1 - writeAt.out) * second_array_to_write_to[0];
outSecond[0] <== orSecond.out;
// IF accum == 1 then { array_to_write_at } ELSE IF accum != 1 then { array to write_to }
signal accum_index[m];
accum_index[0] <== accum[0];

component writeSelector[m - 1];
component indexSelectorFirst[m - 1];
component indexSelectorSecond[m - 1];
component orsFirst[m-1];
component orsSecond[m-1];
for(var i = 1 ; i < m ; i++) {
// accum will be 1 at all indices where we want to write the new array
accum[i] <== accum[i-1] + indexMatched[i];
writeSelector[i-1] = IsZero();
writeSelector[i-1].in <== accum[i] - 1;
// IsZero(accum[i] - 1); --> tells us we are in the range where we want to write the new array

indexSelectorFirst[i-1] = IndexSelector(n);
indexSelectorFirst[i-1].index <== accum_index[i-1];
indexSelectorFirst[i-1].in <== first_array_to_write_at_index;

indexSelectorSecond[i-1] = IndexSelector(n);
indexSelectorSecond[i-1].index <== accum_index[i-1];
indexSelectorSecond[i-1].in <== second_array_to_write_at_index;
// When accum is not zero, out is array_to_write_at_index, otherwise it is array_to_write_to

orsFirst[i-1] = OR();
orsFirst[i-1].a <== (writeSelector[i-1].out * indexSelectorFirst[i-1].out);
orsFirst[i-1].b <== (1 - writeSelector[i-1].out) * first_array_to_write_to[i];
outFirst[i] <== orsFirst[i-1].out;

orsSecond[i-1] = OR();
orsSecond[i-1].a <== (writeSelector[i-1].out * indexSelectorSecond[i-1].out);
orsSecond[i-1].b <== (1 - writeSelector[i-1].out) * second_array_to_write_to[i];
outSecond[i] <== orsSecond[i-1].out;

accum_index[i] <== accum_index[i-1] + writeSelector[i-1].out;
}
}
2 changes: 1 addition & 1 deletion circuits/http/nivc/body_mask.circom
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ include "../parser/machine.circom";

template HTTPMaskBodyNIVC(DATA_BYTES) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
Expand Down
2 changes: 1 addition & 1 deletion circuits/http/nivc/lock_header.circom
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include "circomlib/circuits/comparators.circom";
// TODO: should use a MAX_HEADER_NAME_LENGTH and a MAX_HEADER_VALUE_LENGTH
template LockHeader(DATA_BYTES, MAX_HEADER_NAME_LENGTH, MAX_HEADER_VALUE_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes pt/ct + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes pt/ct + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
Expand Down
11 changes: 7 additions & 4 deletions circuits/http/nivc/parse_and_lock_start_line.circom
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ include "../../utils/bytes.circom";
template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENGTH, MAX_FINAL_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
// ~~ Set sizes at compile time ~~
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // AES ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // AES ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];

// Get the plaintext
signal data[DATA_BYTES];
signal packedData[DATA_BYTES];
for (var i = 0 ; i < DATA_BYTES ; i++) {
data[i] <== step_in[i];
packedData[i] <== step_in[i];
}
component unpackData = UnpackDoubleByteArray(DATA_BYTES);
unpackData.in <== packedData;
signal data[DATA_BYTES] <== unpackData.lower;

signal input beginning[MAX_BEGINNING_LENGTH];
signal input beginning_length;
Expand Down Expand Up @@ -100,7 +103,7 @@ template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENG
for (var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
// add plaintext http input to step_out and ignore the ciphertext
if(i < DATA_BYTES) {
step_out[i] <== step_in[i];
step_out[i] <== data[i]; // PASS OUT JUST THE PLAINTEXT DATA
} else {
step_out[i] <== 0;
}
Expand Down
2 changes: 1 addition & 1 deletion circuits/json/nivc/extractor.circom
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ include "@zk-email/circuits/utils/array.circom";

template MaskExtractFinal(DATA_BYTES, MAX_VALUE_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes pt/ct + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes pt/ct + ctr
// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
Expand Down
8 changes: 4 additions & 4 deletions circuits/json/nivc/masker.circom
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ include "../interpreter.circom";
template JsonMaskObjectNIVC(DATA_BYTES, MAX_STACK_HEIGHT, MAX_KEY_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
assert(MAX_STACK_HEIGHT >= 2); // TODO (autoparallel): idk if we need this now
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
Expand Down Expand Up @@ -87,15 +87,15 @@ template JsonMaskObjectNIVC(DATA_BYTES, MAX_STACK_HEIGHT, MAX_KEY_LENGTH) {
// mask = currently parsing value and all subsequent keys matched
step_out[data_idx] <== data[data_idx] * or[data_idx - 1];
}
for(var i = DATA_BYTES - MAX_KEY_LENGTH; i < 2 * DATA_BYTES + 4; i ++) {
for(var i = DATA_BYTES - MAX_KEY_LENGTH; i < DATA_BYTES + 4; i ++) {
step_out[i] <== 0;
}
}

template JsonMaskArrayIndexNIVC(DATA_BYTES, MAX_STACK_HEIGHT) {
// ------------------------------------------------------------------------------------------------------------------ //
assert(MAX_STACK_HEIGHT >= 2); // TODO (autoparallel): idk if we need this now
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
Expand Down Expand Up @@ -136,7 +136,7 @@ template JsonMaskArrayIndexNIVC(DATA_BYTES, MAX_STACK_HEIGHT) {
or[data_idx - 1] <== OR()(parsing_array[data_idx], parsing_array[data_idx - 1]);
step_out[data_idx] <== data[data_idx] * or[data_idx - 1];
}
for(var i = DATA_BYTES ; i < 2 * DATA_BYTES + 4; i++) {
for(var i = DATA_BYTES ; i < TOTAL_BYTES_ACROSS_NIVC; i++) {
step_out[i] <== 0;
}
}
35 changes: 19 additions & 16 deletions circuits/test/aes-gcm/nivc/aes-gctr-nivc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ describe("aes-gctr-nivc", () => {


const DATA_BYTES_0 = 16;
const TOTAL_BYTES_ACROSS_NIVC_0 = 2 * DATA_BYTES_0 + 4;
const TOTAL_BYTES_ACROSS_NIVC_0 = DATA_BYTES_0 + 4;

it("all correct for self generated single zero pt block case", async () => {
circuit_one_block = await circomkit.WitnessTester("aes-gcm-fold", {
Expand All @@ -25,12 +25,13 @@ describe("aes-gctr-nivc", () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const step_in = new Array(TOTAL_BYTES_ACROSS_NIVC_0).fill(0x00);
counter.forEach((value, index) => {
step_in[2 * DATA_BYTES_0 + index] = value;
step_in[DATA_BYTES_0 + index] = value;
});

let expected = plainText.concat(ct).concat([0x00, 0x00, 0x00, 0x02]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_0 - expected.length).fill(0));
const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText, aad: aad, step_in: step_in }, ["step_out"])

let packed = plainText.map((x, i) => x + (ct[i] * 256));
let expected = [...packed, 0x00, 0x00, 0x00, 0x02];
assert.deepEqual(witness.step_out, expected.map(BigInt));
});

Expand All @@ -50,18 +51,18 @@ describe("aes-gctr-nivc", () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const step_in = new Array(TOTAL_BYTES_ACROSS_NIVC_0).fill(0x00);
counter.forEach((value, index) => {
step_in[2 * DATA_BYTES_0 + index] = value;
step_in[DATA_BYTES_0 + index] = value;
});

let expected = plainText.concat(ct).concat([0x00, 0x00, 0x00, 0x02]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_0 - expected.length).fill(0));

const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText, aad: aad, step_in: step_in }, ["step_out"])

let packed = plainText.map((x, i) => x + (ct[i] * 256));
let expected = [...packed, 0x00, 0x00, 0x00, 0x02];
assert.deepEqual(witness.step_out, expected.map(BigInt));
});

const DATA_BYTES_1 = 32;
const TOTAL_BYTES_ACROSS_NIVC_1 = DATA_BYTES_1 * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC_1 = DATA_BYTES_1 + 4;


let zero_block = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
Expand All @@ -83,12 +84,13 @@ describe("aes-gctr-nivc", () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const step_in = new Array(TOTAL_BYTES_ACROSS_NIVC_1).fill(0x00);
counter.forEach((value, index) => {
step_in[2 * DATA_BYTES_1 + index] = value;
step_in[DATA_BYTES_1 + index] = value;
});
let expected = plainText1.concat(zero_block).concat(ct_part1).concat(zero_block).concat([0x00, 0x00, 0x00, 0x02]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_1 - expected.length).fill(0));

const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText1, aad: aad, step_in: step_in }, ["step_out"])

let packed1 = plainText1.map((x, i) => x + (ct_part1[i] * 256));
let expected = packed1.concat(zero_block).concat([0x00, 0x00, 0x00, 0x02]);
assert.deepEqual(witness.step_out, expected.map(BigInt));
});

Expand All @@ -99,12 +101,13 @@ describe("aes-gctr-nivc", () => {
params: [DATA_BYTES_1], // input len is 32 bytes
});

const counter = [0x00, 0x00, 0x00, 0x02];
let step_in = plainText1.concat(zero_block).concat(ct_part1).concat(zero_block).concat(counter);
let packed1 = plainText1.map((x, i) => x + (ct_part1[i] * 256));
let packed2 = plainText2.map((x, i) => x + (ct_part2[i] * 256));
let step_in = packed1.concat(zero_block).concat([0x00, 0x00, 0x00, 0x02]);
step_in = step_in.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_1 - step_in.length).fill(0));

let expected = plainText1.concat(plainText2).concat(ct_part1).concat(ct_part2).concat([0x00, 0x00, 0x00, 0x03]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_1 - expected.length).fill(0));

let expected = packed1.concat(packed2).concat([0x00, 0x00, 0x00, 0x03]);

const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText2, aad: aad, step_in: step_in }, ["step_out"])
assert.deepEqual(witness.step_out, expected.map(BigInt));
Expand Down
6 changes: 2 additions & 4 deletions circuits/test/full/full.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ describe("NIVC_FULL", async () => {

const DATA_BYTES = 320;
const MAX_STACK_HEIGHT = 5;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;

const MAX_HEADER_NAME_LENGTH = 20;
const MAX_HEADER_VALUE_LENGTH = 35;
Expand Down Expand Up @@ -132,7 +132,7 @@ describe("NIVC_FULL", async () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const init_nivc_input = new Array(TOTAL_BYTES_ACROSS_NIVC).fill(0x00);
counter.forEach((value, index) => {
init_nivc_input[2 * DATA_BYTES + index] = value;
init_nivc_input[DATA_BYTES + index] = value;
});
let pt = http_response_plaintext.slice(0, 16);
aes_gcm = await aesCircuit.compute({ key: Array(16).fill(0), iv: Array(12).fill(0), plainText: pt, aad: Array(16).fill(0), step_in: init_nivc_input }, ["step_out"]);
Expand All @@ -154,8 +154,6 @@ describe("NIVC_FULL", async () => {
let maskedInput = extendedJsonInput.fill(0, 0, idx);
maskedInput = maskedInput.fill(0, 320);



let key0 = [100, 97, 116, 97, 0, 0, 0, 0]; // "data"
let key0Len = 4;
let key1 = [105, 116, 101, 109, 115, 0, 0, 0]; // "items"
Expand Down
2 changes: 1 addition & 1 deletion circuits/test/http/nivc/body_mask.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ describe("NIVC_HTTP", async () => {
let bodyMaskCircuit: WitnessTester<["step_in"], ["step_out"]>;

const DATA_BYTES = 320;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;

const MAX_HEADER_NAME_LENGTH = 20;
const MAX_HEADER_VALUE_LENGTH = 35;
Expand Down
2 changes: 1 addition & 1 deletion circuits/test/http/nivc/lock_header.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ describe("HTTPLockHeader", async () => {
let lockHeaderCircuit: WitnessTester<["step_in", "header", "headerNameLength", "value", "headerValueLength"], ["step_out"]>;

const DATA_BYTES = 320;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;

const MAX_BEGINNING_LENGTH = 10;
const MAX_MIDDLE_LENGTH = 50;
Expand Down
Loading

0 comments on commit 8035a0d

Please sign in to comment.