feat: parser improvs (#75)
* add body extractor circuit

* add tests

* add failure test

* support request data extraction

* feat(parser): change `parsing_start, parsing_header` to counter

* add parser updates

* add header field name match and value extraction

* fix: tests
lonerapier authored Sep 3, 2024
1 parent a0ee612 commit 1d814b3
Showing 5 changed files with 247 additions and 25 deletions.
113 changes: 104 additions & 9 deletions circuits/http/extractor.circom
@@ -1,7 +1,11 @@
pragma circom 2.1.9;

include "interpreter.circom";
include "parser/machine.circom";
include "../utils/bytes.circom";
include "../utils/search.circom";
include "circomlib/circuits/mux1.circom";
include "circomlib/circuits/gates.circom";
include "@zk-email/circuits/utils/array.circom";

// TODO:
@@ -24,6 +28,8 @@ template ExtractResponse(DATA_BYTES, maxContentLength) {
State[0].byte <== data[0];
State[0].parsing_start <== 1;
State[0].parsing_header <== 0;
State[0].parsing_field_name <== 0;
State[0].parsing_field_value <== 0;
State[0].parsing_body <== 0;
State[0].line_status <== 0;

@@ -35,25 +41,31 @@ template ExtractResponse(DATA_BYTES, maxContentLength) {
State[data_idx].byte <== data[data_idx];
State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start;
State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header;
State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name;
State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value;
State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body;
State[data_idx].line_status <== State[data_idx - 1].next_line_status;

// apply body mask to data
dataMask[data_idx] <== data[data_idx] * State[data_idx].next_parsing_body;

// Debugging
log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start);
log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header);
log("State[", data_idx, "].parsing_field_name ", "= ", State[data_idx].parsing_field_name);
log("State[", data_idx, "].parsing_field_value", "= ", State[data_idx].parsing_field_value);
log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body);
log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}

// Debugging
log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start);
log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header);
log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name);
log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value);
log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body);
log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");

signal valueStartingIndex[DATA_BYTES];
@@ -68,4 +80,87 @@
}

response <== SelectSubArray(DATA_BYTES, maxContentLength)(dataMask, valueStartingIndex[DATA_BYTES-1]+1, DATA_BYTES - valueStartingIndex[DATA_BYTES-1]);
}

template ExtractHeaderValue(DATA_BYTES, headerNameLength, maxValueLength) {
signal input data[DATA_BYTES];
signal input header[headerNameLength];

signal output value[maxValueLength];

//--------------------------------------------------------------------------------------------//
//-CONSTRAINTS--------------------------------------------------------------------------------//
//--------------------------------------------------------------------------------------------//
component dataASCII = ASCII(DATA_BYTES);
dataASCII.in <== data;
//--------------------------------------------------------------------------------------------//

// Initialize the parser
component State[DATA_BYTES];
State[0] = StateUpdate();
State[0].byte <== data[0];
State[0].parsing_start <== 1;
State[0].parsing_header <== 0;
State[0].parsing_field_name <== 0;
State[0].parsing_field_value <== 0;
State[0].parsing_body <== 0;
State[0].line_status <== 0;

signal headerMatch[DATA_BYTES];
headerMatch[0] <== 0;
signal isHeaderNameMatch[DATA_BYTES];
isHeaderNameMatch[0] <== 0;
signal readCRLF[DATA_BYTES];
readCRLF[0] <== 0;
signal valueMask[DATA_BYTES];
valueMask[0] <== 0;

for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) {
State[data_idx] = StateUpdate();
State[data_idx].byte <== data[data_idx];
State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start;
State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header;
State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name;
State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value;
State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body;
State[data_idx].line_status <== State[data_idx - 1].next_line_status;

// apply value mask to data
// TODO: change r
headerMatch[data_idx] <== HeaderFieldNameMatch(DATA_BYTES, headerNameLength)(data, header, 100, data_idx);
readCRLF[data_idx] <== IsEqual()([State[data_idx].line_status, 2]);
isHeaderNameMatch[data_idx] <== Mux1()([isHeaderNameMatch[data_idx-1] * (1-readCRLF[data_idx]), 1], headerMatch[data_idx]);
valueMask[data_idx] <== MultiAND(3)([data[data_idx], isHeaderNameMatch[data_idx], State[data_idx].parsing_field_value]);

// Debugging
log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start);
log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header);
log("State[", data_idx, "].parsing_field_name ", "= ", State[data_idx].parsing_field_name);
log("State[", data_idx, "].parsing_field_value", "= ", State[data_idx].parsing_field_value);
log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body);
log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}

// Debugging
log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start);
log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header);
log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name);
log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value);
log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body);
log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");

signal valueStartingIndex[DATA_BYTES];
signal isZeroMask[DATA_BYTES];
signal isPrevStartingIndex[DATA_BYTES];
valueStartingIndex[0] <== 0;
isZeroMask[0] <== IsZero()(valueMask[0]);
for (var i=1 ; i<DATA_BYTES ; i++) {
isZeroMask[i] <== IsZero()(valueMask[i]);
isPrevStartingIndex[i] <== IsZero()(valueStartingIndex[i-1]);
valueStartingIndex[i] <== valueStartingIndex[i-1] + i * (1-isZeroMask[i]) * isPrevStartingIndex[i];
}

value <== SelectSubArray(DATA_BYTES, maxValueLength)(valueMask, valueStartingIndex[DATA_BYTES-1]+1, maxValueLength);
}
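
For intuition, both extractors above follow the same pattern: zero every byte outside the region of interest, find the first nonzero index of that mask, and read a fixed-length window from there. A minimal TypeScript sketch of the pattern (illustrative only; `applyMask`, `firstNonZeroIndex` and `selectSubArray` are hypothetical stand-ins for the circuit's mask signals, the `valueStartingIndex` accumulation and `SelectSubArray`):

```typescript
// Illustrative sketch of the mask + starting-index pattern used by
// ExtractResponse and ExtractHeaderValue above. Helper names are hypothetical.

// Zero out every byte whose flag is 0, keep bytes whose flag is 1.
function applyMask(data: number[], flags: number[]): number[] {
  return data.map((b, i) => b * flags[i]);
}

// Mirrors the in-circuit accumulation
//   start[i] = start[i-1] + i * (mask[i] != 0) * (start[i-1] == 0),
// i.e. the index of the first nonzero masked byte (0 if none).
function firstNonZeroIndex(mask: number[]): number {
  let start = 0;
  for (let i = 1; i < mask.length; i++) {
    const isNonZero = mask[i] !== 0 ? 1 : 0;
    const isPrevZero = start === 0 ? 1 : 0;
    start += i * isNonZero * isPrevZero;
  }
  return start;
}

// Fixed-length, zero-padded window starting at `start`, like SelectSubArray.
function selectSubArray(data: number[], start: number, maxLen: number): number[] {
  const out = data.slice(start, start + maxLen);
  while (out.length < maxLen) out.push(0);
  return out;
}

// Roughly what the circuit computes for a header value:
// selectSubArray(valueMask, firstNonZeroIndex(valueMask) + 1, maxValueLength)
```
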
54 changes: 52 additions & 2 deletions circuits/http/interpreter.circom
@@ -1,12 +1,13 @@
pragma circom 2.1.9;

include "parser/language.circom";
include "../utils/search.circom";
include "../utils/array.circom";

/* TODO:
Notes --
- This is a pretty efficient way to simply check what the method used in a request is by checking
the first `DATA_LENGTH` number of bytes.
- Could probably change this to a template that checks if it is one of the given methods
so we don't check them all in one
*/
@@ -32,4 +33,53 @@ template YieldMethod(DATA_LENGTH) {
signal TagPost <== IsPost.out * RequestMethodTag.POST;

MethodTag <== TagGet + TagPost;
}

// https://www.rfc-editor.org/rfc/rfc9112.html#name-field-syntax
template HeaderFieldNameValueMatch(dataLen, nameLen, valueLen) {
signal input data[dataLen];
signal input headerName[nameLen];
signal input headerValue[valueLen];
signal input r;
signal input index;

component syntax = Syntax();

signal output value[valueLen];

// check if the name matches
signal headerNameMatch <== SubstringMatchWithIndex(dataLen, nameLen)(data, headerName, r, index);

// next byte after the name should be COLON
signal endOfHeaderName <== IndexSelector(dataLen)(data, index + nameLen);
signal isNextByteColon <== IsEqual()([endOfHeaderName, syntax.COLON]);

signal headerNameMatchAndNextByteColon <== headerNameMatch * isNextByteColon;

// field-name: SP field-value
signal headerValueMatch <== SubstringMatchWithIndex(dataLen, valueLen)(data, headerValue, r, index + nameLen + 2);

// header name matches + header value matches
signal output out <== headerNameMatchAndNextByteColon * headerValueMatch;
}

// https://www.rfc-editor.org/rfc/rfc9112.html#name-field-syntax
template HeaderFieldNameMatch(dataLen, nameLen) {
signal input data[dataLen];
signal input headerName[nameLen];
signal input r;
signal input index;

component syntax = Syntax();

// check if the name matches
signal headerNameMatch <== SubstringMatchWithIndex(dataLen, nameLen)(data, headerName, r, index);

// next byte after the name should be COLON
signal endOfHeaderName <== IndexSelector(dataLen)(data, index + nameLen);
signal isNextByteColon <== IsEqual()([endOfHeaderName, syntax.COLON]);

// header name matches
signal output out;
out <== headerNameMatch * isNextByteColon;
}
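
Roughly, `HeaderFieldNameMatch` asserts that the header name occurs at the given index with a colon immediately after it, and `HeaderFieldNameValueMatch` additionally checks the value two bytes past the name (skipping the colon and one space). A plain TypeScript sketch of those checks, assuming `SubstringMatchWithIndex` acts as an at-index substring equality test:

```typescript
// Plain TypeScript stand-in for the RFC 9112 field-syntax checks above.
// The circuits express the same logic with SubstringMatchWithIndex, IndexSelector and IsEqual.

const COLON = 0x3a; // ':'

// headerName must occur in data at `index`, immediately followed by a colon.
function headerFieldNameMatch(data: number[], headerName: number[], index: number): boolean {
  const nameMatches = headerName.every((b, i) => data[index + i] === b);
  const nextIsColon = data[index + headerName.length] === COLON;
  return nameMatches && nextIsColon;
}

// Name match plus a value match at `index + nameLen + 2`, i.e. the circuit
// assumes exactly one space after the colon ("name: value").
function headerFieldNameValueMatch(
  data: number[],
  headerName: number[],
  headerValue: number[],
  index: number,
): boolean {
  if (!headerFieldNameMatch(data, headerName, index)) return false;
  const valueStart = index + headerName.length + 2;
  return headerValue.every((b, i) => data[valueStart + i] === b);
}
```
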
64 changes: 53 additions & 11 deletions circuits/http/parser/machine.circom
@@ -4,27 +4,37 @@ include "language.circom";
include "../../utils/array.circom";

template StateUpdate() {
signal input parsing_start; // Bool flag for if we are in the start line
signal input parsing_start; // flag that counts up to 3 for each value in the start line
signal input parsing_header; // Flag + Counter for what header line we are in
signal input parsing_body;
signal input parsing_field_name; // flag that tells if parsing header field name
signal input parsing_field_value; // flag that tells if parsing header field value
signal input parsing_body; // Flag when we are inside body
signal input line_status; // Flag that counts up to 4 to read a double CRLF
signal input byte;

signal output next_parsing_start;
signal output next_parsing_header;
signal output next_parsing_field_name;
signal output next_parsing_field_value;
signal output next_parsing_body;
signal output next_line_status;

component Syntax = Syntax();

//---------------------------------------------------------------------------------//
//---------------------------------------------------------------------------------//
// check if we read space or colon
component readSP = IsEqual();
readSP.in <== [byte, Syntax.SPACE];
component readColon = IsEqual();
readColon.in <== [byte, Syntax.COLON];

// Check if what we just read is a CR / LF
component readCR = IsEqual();
readCR.in <== [byte, Syntax.CR];
component readLF = IsEqual();
readLF.in <== [byte, Syntax.LF];

signal notCRAndLF <== (1 - readCR.out) * (1 - readLF.out);
//---------------------------------------------------------------------------------//

//---------------------------------------------------------------------------------//
@@ -42,32 +52,64 @@

//---------------------------------------------------------------------------------//
// Take current state and CRLF info to update state
signal state[3] <== [parsing_start, parsing_header, parsing_body];
signal state[5] <== [parsing_start, parsing_header, parsing_field_name, parsing_field_value, parsing_body];
component stateChange = StateChange();
stateChange.readCRLF <== readCRLF;
stateChange.readCRLFCRLF <== readCRLFCRLF;
stateChange.readSP <== readSP.out;
stateChange.readColon <== readColon.out;
stateChange.state <== state;

component nextState = ArrayAdd(3);
component nextState = ArrayAdd(5);
nextState.lhs <== state;
nextState.rhs <== stateChange.out;
//---------------------------------------------------------------------------------//

next_parsing_start <== nextState.out[0];
next_parsing_header <== nextState.out[1];
next_parsing_body <== nextState.out[2];
next_parsing_field_name <== nextState.out[2];
next_parsing_field_value <== nextState.out[3];
next_parsing_body <== nextState.out[4];
next_line_status <== line_status + readCR.out + readCRLF + readCRLFCRLF - line_status * notCRAndLF;

}

// TODO:
// - multiple space between start line values
// - handle incrementParsingHeader being incremented for header -> body CRLF
// - header value parsing doesn't handle SPACE between colon and actual value
template StateChange() {
signal input readCRLF;
signal input readCRLFCRLF;
signal input state[3];
signal output out[3];
signal input readSP;
signal input readColon;
signal input state[5];
signal output out[5];

// GreaterEqThan(2) because start line can have at most 3 values for request or response
signal isParsingStart <== GreaterEqThan(2)([state[0], 1]);
// increment parsing start counter on reading SP
signal incrementParsingStart <== readSP * isParsingStart;
// disable parsing start on reading CRLF
signal disableParsingStart <== readCRLF * state[0];

// enable parsing header on reading CRLF
signal enableParsingHeader <== readCRLF * isParsingStart;
// check if we are parsing header
signal isParsingHeader <== GreaterEqThan(10)([state[1], 1]);
// increment parsing header counter on CRLF and parsing header
signal incrementParsingHeader <== readCRLF * isParsingHeader;
// disable parsing header on reading CRLF-CRLF
signal disableParsingHeader <== readCRLFCRLF * state[1];
// parsing field value when parsing header and read Colon `:`
signal isParsingFieldValue <== isParsingHeader * readColon;

// parsing body when reading CRLF-CRLF and parsing header
signal enableParsingBody <== readCRLFCRLF * isParsingHeader;

out <== [-disableParsingStart, disableParsingStart - disableParsingHeader, disableParsingHeader];
// parsing_start = out[0] = enable header (default 1) + increment start - disable start
// parsing_header = out[1] = enable header + increment header - disable header
// parsing_field_name = out[2] = enable header + increment header - parsing field value - parsing body
// parsing_field_value = out[3] = parsing field value - increment parsing header (zeroed every time new header starts)
// parsing_body = out[4] = enable body
out <== [incrementParsingStart - disableParsingStart, enableParsingHeader + incrementParsingHeader - disableParsingHeader, enableParsingHeader + incrementParsingHeader - isParsingFieldValue - enableParsingBody, isParsingFieldValue - incrementParsingHeader, enableParsingBody];
}
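
For intuition on the widened state, here is a rough TypeScript sketch (illustrative only, not the circuit's arithmetic) of how one byte updates the counters: `parsing_start` counts start-line values separated by spaces, `parsing_header` counts header lines, the field name/value flags flip at each colon and reset at each CRLF, and everything hands off to `parsing_body` at the CRLF-CRLF boundary.

```typescript
// Toy per-byte state update approximating StateUpdate/StateChange.
// Illustrative only; the circuit encodes these transitions as arithmetic constraints.

interface ParserState {
  parsingStart: number;      // counts start-line values (e.g. method, target, version)
  parsingHeader: number;     // counts which header line we are on
  parsingFieldName: number;  // 1 while reading a header field name
  parsingFieldValue: number; // 1 while reading a header field value
  parsingBody: number;       // 1 once inside the body
  lineStatus: number;        // 1 = CR, 2 = CRLF, 3 = CRLF+CR, 4 = CRLF-CRLF
}

const CR = 0x0d, LF = 0x0a, SP = 0x20, COLON = 0x3a;

function stateUpdate(s: ParserState, byte: number): ParserState {
  const readCRLF = byte === LF && s.lineStatus === 1;     // LF completing a CRLF
  const readCRLFCRLF = byte === LF && s.lineStatus === 3; // LF completing CRLF-CRLF
  const next = { ...s };

  if (s.parsingStart >= 1) {
    if (byte === SP) next.parsingStart += 1;  // move to the next start-line value
    if (readCRLF) {                           // start line done: begin headers
      next.parsingStart = 0;
      next.parsingHeader = 1;
      next.parsingFieldName = 1;
    }
  } else if (s.parsingHeader >= 1) {
    if (byte === COLON) {                     // switch from field name to field value
      next.parsingFieldName = 0;
      next.parsingFieldValue = 1;
    }
    if (readCRLF) {                           // header line done: start the next one
      next.parsingHeader += 1;
      next.parsingFieldName = 1;
      next.parsingFieldValue = 0;
    }
    if (readCRLFCRLF) {                       // blank line: headers done, body begins
      next.parsingHeader = 0;
      next.parsingFieldName = 0;
      next.parsingFieldValue = 0;
      next.parsingBody = 1;
    }
  }

  // line_status counts CR / CRLF / CRLF-CRLF; any byte other than CR or LF resets it.
  if (byte === CR || readCRLF || readCRLFCRLF) next.lineStatus = s.lineStatus + 1;
  else if (byte !== LF) next.lineStatus = 0;

  return next;
}
```

Like the circuit, this sketch does not handle multiple spaces between start-line values or optional whitespace around the field value (see the TODOs above).
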
2 changes: 1 addition & 1 deletion circuits/test/common/index.ts
@@ -57,7 +57,7 @@ export function readJSONInputFile(filename: string, key: any[]): [number[], numb
return [input, keyUnicode, output];
}

function toByte(data: string): number[] {
export function toByte(data: string): number[] {
const byteArray = [];
for (let i = 0; i < data.length; i++) {
byteArray.push(data.charCodeAt(i));
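
For reference, the newly exported `toByte` helper is what the header-extractor test below uses to turn header names and values into circuit inputs; a small usage sketch (import path assumed to match the tests):

```typescript
// toByte turns a string into the byte array the circuits consume, e.g.:
import { toByte } from "../common"; // same relative path the tests use

// [67, 111, 110, 116, 101, 110, 116, 45, 76, 101, 110, 103, 116, 104]
const headerName = toByte("Content-Length");
```
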
39 changes: 37 additions & 2 deletions circuits/test/http/extractor.test.ts
@@ -1,6 +1,6 @@
import { circomkit, WitnessTester, generateDescription, readHTTPInputFile } from "../common";
import { circomkit, WitnessTester, generateDescription, readHTTPInputFile, toByte } from "../common";

describe("HTTP :: Extractor", async () => {
describe("HTTP :: body Extractor", async () => {
let circuit: WitnessTester<["data"], ["response"]>;


@@ -50,4 +50,39 @@
output3.pop();
generatePassCase(parsedHttp.input, output3, "output length less than actual length");
});
});

describe("HTTP :: header Extractor", async () => {
let circuit: WitnessTester<["data", "header"], ["value"]>;

function generatePassCase(input: number[], headerName: number[], headerValue: number[], desc: string) {
const description = generateDescription(input);

it(`(valid) witness: ${description} ${desc}`, async () => {
circuit = await circomkit.WitnessTester(`ExtractHeaderValue`, {
file: "circuits/http/extractor",
template: "ExtractHeaderValue",
params: [input.length, headerName.length, headerValue.length],
});
console.log("#constraints:", await circuit.getConstraintCount());

await circuit.expectPass({ data: input, header: headerName }, { value: headerValue });
});
}

describe("response", async () => {

let parsedHttp = readHTTPInputFile("get_response.http");

generatePassCase(parsedHttp.input, toByte("Content-Length"), toByte(parsedHttp.headers["Content-Length"]), "");

// let output2 = parsedHttp.bodyBytes.slice(0);
// output2.push(0, 0, 0, 0);
// generatePassCase(parsedHttp.input, output2, "output length more than actual length");

// let output3 = parsedHttp.bodyBytes.slice(0);
// output3.pop();
// // output3.pop(); // TODO: fails due to shift subarray bug
// generatePassCase(parsedHttp.input, output3, "output length less than actual length");
});
});
