diff --git a/doc/ws.md b/doc/ws.md index 20facffbb..d728ace82 100644 --- a/doc/ws.md +++ b/doc/ws.md @@ -270,6 +270,10 @@ This class represents a WebSocket. It extends the `EventEmitter`. - `options` {Object} - `followRedirects` {Boolean} Whether or not to follow redirects. Defaults to `false`. + - `generateMask` {Function} The function used to generate the masking key. It + takes a `Buffer` that must be filled synchronously and is called before a + message is sent, for each message. By default the buffer is filled with + cryptographically strong random bytes. - `handshakeTimeout` {Number} Timeout in milliseconds for the handshake request. This is reset after every redirection. - `maxPayload` {Number} The maximum allowed message size in bytes. diff --git a/lib/sender.js b/lib/sender.js index 4490a623b..82cc662d6 100644 --- a/lib/sender.js +++ b/lib/sender.js @@ -11,7 +11,7 @@ const { EMPTY_BUFFER } = require('./constants'); const { isValidStatusCode } = require('./validation'); const { mask: applyMask, toBuffer } = require('./buffer-util'); -const mask = Buffer.alloc(4); +const maskBuffer = Buffer.alloc(4); /** * HyBi Sender implementation. @@ -22,9 +22,17 @@ class Sender { * * @param {(net.Socket|tls.Socket)} socket The connection socket * @param {Object} [extensions] An object containing the negotiated extensions + * @param {Function} [generateMask] The function used to generate the masking + * key */ - constructor(socket, extensions) { + constructor(socket, extensions, generateMask) { this._extensions = extensions || {}; + + if (generateMask) { + this._generateMask = generateMask; + this._maskBuffer = Buffer.alloc(4); + } + this._socket = socket; this._firstFragment = true; @@ -42,8 +50,12 @@ class Sender { * @param {Object} options Options object * @param {Boolean} [options.fin=false] Specifies whether or not to set the * FIN bit + * @param {Function} [options.generateMask] The function used to generate the + * masking key * @param {Boolean} [options.mask=false] Specifies whether or not to mask * `data` + * @param {Buffer} [options.maskBuffer] The buffer used to store the masking + * key * @param {Number} options.opcode The opcode * @param {Boolean} [options.readOnly=false] Specifies whether `data` can be * modified @@ -81,7 +93,13 @@ class Sender { if (!options.mask) return [target, data]; - randomFillSync(mask, 0, 4); + const mask = options.maskBuffer ? options.maskBuffer : maskBuffer; + + if (options.generateMask) { + options.generateMask(mask); + } else { + randomFillSync(mask, 0, 4); + } target[1] |= 0x80; target[offset - 4] = mask[0]; @@ -156,6 +174,8 @@ class Sender { rsv1: false, opcode: 0x08, mask, + maskBuffer: this._maskBuffer, + generateMask: this._generateMask, readOnly: false }), cb @@ -200,6 +220,8 @@ class Sender { rsv1: false, opcode: 0x09, mask, + maskBuffer: this._maskBuffer, + generateMask: this._generateMask, readOnly }), cb @@ -244,6 +266,8 @@ class Sender { rsv1: false, opcode: 0x0a, mask, + maskBuffer: this._maskBuffer, + generateMask: this._generateMask, readOnly }), cb @@ -299,6 +323,8 @@ class Sender { rsv1, opcode, mask: options.mask, + maskBuffer: this._maskBuffer, + generateMask: this._generateMask, readOnly: toBuffer.readOnly }; @@ -314,6 +340,8 @@ class Sender { rsv1: false, opcode, mask: options.mask, + maskBuffer: this._maskBuffer, + generateMask: this._generateMask, readOnly: toBuffer.readOnly }), cb @@ -331,8 +359,12 @@ class Sender { * @param {Number} options.opcode The opcode * @param {Boolean} [options.fin=false] Specifies whether or not to set the * FIN bit + * @param {Function} [options.generateMask] The function used to generate the + * masking key * @param {Boolean} [options.mask=false] Specifies whether or not to mask * `data` + * @param {Buffer} [options.maskBuffer] The buffer used to store the masking + * key * @param {Boolean} [options.readOnly=false] Specifies whether `data` can be * modified * @param {Boolean} [options.rsv1=false] Specifies whether or not to set the diff --git a/lib/websocket.js b/lib/websocket.js index 130b3dc58..57710f4e1 100644 --- a/lib/websocket.js +++ b/lib/websocket.js @@ -192,6 +192,8 @@ class WebSocket extends EventEmitter { * server and client * @param {Buffer} head The first packet of the upgraded stream * @param {Object} options Options object + * @param {Function} [options.generateMask] The function used to generate the + * masking key * @param {Number} [options.maxPayload=0] The maximum allowed message size * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or * not to skip UTF-8 validation for text and close messages @@ -206,7 +208,7 @@ class WebSocket extends EventEmitter { skipUTF8Validation: options.skipUTF8Validation }); - this._sender = new Sender(socket, this._extensions); + this._sender = new Sender(socket, this._extensions, options.generateMask); this._receiver = receiver; this._socket = socket; @@ -613,6 +615,8 @@ module.exports = WebSocket; * @param {Object} [options] Connection options * @param {Boolean} [options.followRedirects=false] Whether or not to follow * redirects + * @param {Function} [options.generateMask] The function used to generate the + * masking key * @param {Number} [options.handshakeTimeout] Timeout in milliseconds for the * handshake request * @param {Number} [options.maxPayload=104857600] The maximum allowed message @@ -899,6 +903,7 @@ function initAsClient(websocket, address, protocols, options) { } websocket.setSocket(socket, head, { + generateMask: opts.generateMask, maxPayload: opts.maxPayload, skipUTF8Validation: opts.skipUTF8Validation }); diff --git a/test/websocket.test.js b/test/websocket.test.js index 0d48887de..5f9392d84 100644 --- a/test/websocket.test.js +++ b/test/websocket.test.js @@ -126,6 +126,41 @@ describe('WebSocket', () => { /^RangeError: Unsupported protocol version: 1000 \(supported versions: 8, 13\)$/ ); }); + + it('honors the `generateMask` option', (done) => { + const wss = new WebSocket.Server({ port: 0 }, () => { + const ws = new WebSocket(`ws://localhost:${wss.address().port}`, { + generateMask() {} + }); + + ws.on('open', () => { + ws.send('foo'); + }); + + ws.on('close', (code, reason) => { + assert.strictEqual(code, 1005); + assert.deepStrictEqual(reason, EMPTY_BUFFER); + + wss.close(done); + }); + }); + + wss.on('connection', (ws) => { + const chunks = []; + + ws._socket.prependListener('data', (chunk) => { + chunks.push(chunk); + }); + + ws.on('message', () => { + assert.ok( + Buffer.concat(chunks).slice(2, 6).equals(Buffer.alloc(4)) + ); + + ws.close(); + }); + }); + }); }); });