diff --git a/src/encoder.ts b/src/encoder.ts index 60eb1c6..fd10fdd 100644 --- a/src/encoder.ts +++ b/src/encoder.ts @@ -25,6 +25,10 @@ export function encode1(message: MessageBuffer): bytes { return Buffer.concat([message.ne, message.ns, message.ciphertext]); } +export function encode2(message: MessageBuffer): bytes { + return Buffer.concat([message.ns, message.ciphertext]); +} + export function decode0(input: bytes): MessageBuffer { if (input.length < 32) { throw new Error("Cannot decode stage 0 MessageBuffer: length less than 32 bytes."); @@ -39,7 +43,7 @@ export function decode0(input: bytes): MessageBuffer { export function decode1(input: bytes): MessageBuffer { if (input.length < 80) { - throw new Error("Cannot decode stage 0 MessageBuffer: length less than 96 bytes."); + throw new Error("Cannot decode stage 1 MessageBuffer: length less than 80 bytes."); } return { @@ -48,3 +52,15 @@ export function decode1(input: bytes): MessageBuffer { ciphertext: input.slice(80, input.length), } } + +export function decode2(input: bytes): MessageBuffer { + if (input.length < 48) { + throw new Error("Cannot decode stage 2 MessageBuffer: length less than 48 bytes."); + } + + return { + ne: Buffer.alloc(0), + ns: input.slice(0, 48), + ciphertext: input.slice(48, input.length), + } +} diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index 2db646b..41360e5 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -11,7 +11,7 @@ import { verifySignedPayload, } from "./utils"; import { logger } from "./logger"; -import { decode0, decode1, encode0, encode1 } from "./encoder"; +import {decode0, decode1, decode2, encode0, encode1, encode2} from "./encoder"; import { WrappedConnection } from "./noise"; import PeerId from "peer-id"; @@ -99,11 +99,11 @@ export class XXHandshake implements IHandshake { if (this.isInitiator) { logger('Stage 2 - Initiator sending third handshake message.'); const messageBuffer = this.xx.sendMessage(this.session, this.payload); - this.connection.writeLP(encode1(messageBuffer)); + this.connection.writeLP(encode2(messageBuffer)); logger('Stage 2 - Initiator sent message with signed payload.'); } else { logger('Stage 2 - Responder waiting for third handshake message...'); - const receivedMessageBuffer = decode1((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode2((await this.connection.readLP()).slice()); const {plaintext, valid} = this.xx.recvMessage(this.session, receivedMessageBuffer); if(!valid) { throw new Error("xx handshake stage 2 validation fail"); diff --git a/test/noise.test.ts b/test/noise.test.ts index 1fceb50..a22dd0b 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -1,18 +1,13 @@ -import { expect, assert } from "chai"; +import {assert, expect} from "chai"; import DuplexPair from 'it-pair/duplex'; -import { Noise } from "../src"; +import {Noise} from "../src"; import {createPeerIdsFromFixtures} from "./fixtures/peer"; import Wrap from "it-pb-rpc"; -import { random } from "bcrypto"; +import {random} from "bcrypto"; import sinon from "sinon"; import {XXHandshake} from "../src/handshake-xx"; -import { - createHandshakePayload, - generateKeypair, - getHandshakePayload, getPayload, - signPayload -} from "../src/utils"; -import {decode0, decode1, encode1, uint16BEDecode, uint16BEEncode} from "../src/encoder"; +import {createHandshakePayload, generateKeypair, getHandshakePayload, getPayload, signPayload} from "../src/utils"; +import {decode0, decode2, encode1, uint16BEDecode, uint16BEEncode} from "../src/encoder"; import {XX} from "../src/handshakes/xx"; import {Buffer} from "buffer"; import {getKeyPairFromPeerId} from "./utils"; @@ -89,7 +84,7 @@ describe("Noise", () => { wrapped.writeLP(encode1(messageBuffer)); // Stage 2 - finish handshake - receivedMessageBuffer = decode1((await wrapped.readLP()).slice()); + receivedMessageBuffer = decode2((await wrapped.readLP()).slice()); xx.recvMessage(handshake.session, receivedMessageBuffer); return {wrapped, handshake}; })(),