diff --git a/integration/websockets/e2e/ws-gateway.spec.ts b/integration/websockets/e2e/ws-gateway.spec.ts index e96b05150af..4130ab200e7 100644 --- a/integration/websockets/e2e/ws-gateway.spec.ts +++ b/integration/websockets/e2e/ws-gateway.spec.ts @@ -5,7 +5,10 @@ import { expect } from 'chai'; import * as WebSocket from 'ws'; import { ApplicationGateway } from '../src/app.gateway'; import { CoreGateway } from '../src/core.gateway'; +import { ExamplePathGateway } from '../src/example-path.gateway'; import { ServerGateway } from '../src/server.gateway'; +import { WsPathGateway } from '../src/ws-path.gateway'; +import { WsPathGateway2 } from '../src/ws-path2.gateway'; async function createNestApp(...gateways): Promise { const testingModule = await Test.createTestingModule({ @@ -65,7 +68,81 @@ describe('WebSocketGateway (WsAdapter)', () => { ); }); - it(`should support 2 different gateways`, async function () { + it(`should handle message on a different path`, async () => { + app = await createNestApp(WsPathGateway); + await app.listenAsync(3000); + try { + ws = new WebSocket('ws://localhost:3000/ws-path'); + await new Promise((resolve, reject) => { + ws.on('open', resolve); + ws.on('error', reject); + }); + + ws.send( + JSON.stringify({ + event: 'push', + data: { + test: 'test', + }, + }), + ); + await new Promise(resolve => + ws.on('message', data => { + expect(JSON.parse(data).data.test).to.be.eql('test'); + resolve(); + }), + ); + } catch (err) { + console.log(err); + } + }); + + it(`should support 2 different gateways running on different paths`, async function () { + this.retries(10); + + app = await createNestApp(ExamplePathGateway, WsPathGateway2); + await app.listenAsync(3000); + + // open websockets delay + await new Promise(resolve => setTimeout(resolve, 1000)); + + ws = new WebSocket('ws://localhost:8082/example'); + ws2 = new WebSocket('ws://localhost:8082/ws-path'); + + await new Promise(resolve => + ws.on('open', () => { + ws.on('message', data => { + expect(JSON.parse(data).data.test).to.be.eql('test'); + resolve(); + }); + ws.send( + JSON.stringify({ + event: 'push', + data: { + test: 'test', + }, + }), + ); + }), + ); + + await new Promise(resolve => { + ws2.on('message', data => { + expect(JSON.parse(data).data.test).to.be.eql('test'); + resolve(); + }); + ws2.send( + JSON.stringify({ + event: 'push', + data: { + test: 'test', + }, + }), + ); + }); + }); + + it(`should support 2 different gateways running on the same path (but different ports)`, async function () { this.retries(10); app = await createNestApp(ApplicationGateway, CoreGateway); diff --git a/integration/websockets/src/example-path.gateway.ts b/integration/websockets/src/example-path.gateway.ts new file mode 100644 index 00000000000..31c4b2c63b0 --- /dev/null +++ b/integration/websockets/src/example-path.gateway.ts @@ -0,0 +1,14 @@ +import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets'; + +@WebSocketGateway(8082, { + path: '/example', +}) +export class ExamplePathGateway { + @SubscribeMessage('push') + onPush(client, data) { + return { + event: 'pop', + data, + }; + } +} diff --git a/integration/websockets/src/ws-path.gateway.ts b/integration/websockets/src/ws-path.gateway.ts new file mode 100644 index 00000000000..d40d098480e --- /dev/null +++ b/integration/websockets/src/ws-path.gateway.ts @@ -0,0 +1,14 @@ +import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets'; + +@WebSocketGateway({ + path: '/ws-path', +}) +export class WsPathGateway { + @SubscribeMessage('push') + onPush(client, data) { + return { + event: 'pop', + data, + }; + } +} diff --git a/integration/websockets/src/ws-path2.gateway.ts b/integration/websockets/src/ws-path2.gateway.ts new file mode 100644 index 00000000000..0c4a9d69d21 --- /dev/null +++ b/integration/websockets/src/ws-path2.gateway.ts @@ -0,0 +1,14 @@ +import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets'; + +@WebSocketGateway(8082, { + path: '/ws-path', +}) +export class WsPathGateway2 { + @SubscribeMessage('push') + onPush(client, data) { + return { + event: 'pop', + data, + }; + } +} diff --git a/package.json b/package.json index 52f4520d6b9..faa61043801 100644 --- a/package.json +++ b/package.json @@ -25,7 +25,7 @@ "format": "prettier \"**/*.ts\" --ignore-path ./.prettierignore --write && git status", "postinstall": "opencollective", "test": "nyc --require ts-node/register mocha packages/**/*.spec.ts --reporter spec --retries 3 --require 'node_modules/reflect-metadata/Reflect.js' --exit", - "test:integration": "mocha \"integration/*/{,!(node_modules)/**/}/*.spec.ts\" --reporter spec --require ts-node/register --require 'node_modules/reflect-metadata/Reflect.js' --exit", + "test:integration": "mocha \"integration/**/{,!(node_modules)/**/}/*.spec.ts\" --reporter spec --require ts-node/register --require 'node_modules/reflect-metadata/Reflect.js' --exit", "test:docker:up": "docker-compose -f integration/docker-compose.yml up -d", "test:docker:down": "docker-compose -f integration/docker-compose.yml down", "lint": "concurrently 'npm run lint:packages' 'npm run lint:integration' 'npm run lint:spec'", diff --git a/packages/platform-ws/adapters/ws-adapter.ts b/packages/platform-ws/adapters/ws-adapter.ts index 1c123ab564f..b0047dc8926 100644 --- a/packages/platform-ws/adapters/ws-adapter.ts +++ b/packages/platform-ws/adapters/ws-adapter.ts @@ -7,6 +7,7 @@ import { ERROR_EVENT, } from '@nestjs/websockets/constants'; import { MessageMappingProperties } from '@nestjs/websockets/gateway-metadata-explorer'; +import * as http from 'http'; import { EMPTY as empty, fromEvent, Observable } from 'rxjs'; import { filter, first, mergeMap, share, takeUntil } from 'rxjs/operators'; @@ -19,8 +20,23 @@ enum READY_STATE { CLOSED_STATE = 3, } +type HttpServerRegistryKey = number; +type HttpServerRegistryEntry = any; +type WsServerRegistryKey = number; +type WsServerRegistryEntry = any[]; + +const UNDERLYING_HTTP_SERVER_PORT = 0; + export class WsAdapter extends AbstractWsAdapter { protected readonly logger = new Logger(WsAdapter.name); + protected readonly httpServersRegistry = new Map< + HttpServerRegistryKey, + HttpServerRegistryEntry + >(); + protected readonly wsServersRegistry = new Map< + WsServerRegistryKey, + WsServerRegistryEntry + >(); constructor(appOrHttpServer?: INestApplicationContext | any) { super(appOrHttpServer); @@ -39,7 +55,7 @@ export class WsAdapter extends AbstractWsAdapter { this.logger.error(error); throw error; } - if (port === 0 && this.httpServer) { + if (port === UNDERLYING_HTTP_SERVER_PORT && this.httpServer) { return this.bindErrorHandler( new wsPackage.Server({ server: this.httpServer, @@ -47,14 +63,33 @@ export class WsAdapter extends AbstractWsAdapter { }), ); } - return server - ? server - : this.bindErrorHandler( - new wsPackage.Server({ - port, - ...wsOptions, - }), - ); + + if (server) { + // When server exists already + return server; + } + if (options.path && port !== UNDERLYING_HTTP_SERVER_PORT) { + // Multiple servers with different paths + // sharing a single HTTP/S server running on different port + // than a regular HTTP application + this.ensureHttpServerExists(port); + + const wsServer = this.bindErrorHandler( + new wsPackage.Server({ + noServer: true, + ...wsOptions, + }), + ); + this.addWsServerToRegistry(wsServer, port, options.path); + return wsServer; + } + const wsServer = this.bindErrorHandler( + new wsPackage.Server({ + port, + ...wsOptions, + }), + ); + return wsServer; } public bindMessageHandlers( @@ -98,7 +133,7 @@ export class WsAdapter extends AbstractWsAdapter { } public bindErrorHandler(server: any) { - server.on(CONNECTION_EVENT, ws => + server.on(CONNECTION_EVENT, (ws: any) => ws.on(ERROR_EVENT, (err: any) => this.logger.error(err)), ); server.on(ERROR_EVENT, (err: any) => this.logger.error(err)); @@ -108,4 +143,55 @@ export class WsAdapter extends AbstractWsAdapter { public bindClientDisconnect(client: any, callback: Function) { client.on(CLOSE_EVENT, callback); } + + public async dispose() { + const closeEvents = Array.from(this.httpServersRegistry).map( + ([_, server]) => new Promise(resolve => server.close(resolve)), + ); + await Promise.all(closeEvents); + this.httpServersRegistry.clear(); + this.wsServersRegistry.clear(); + } + + protected ensureHttpServerExists(port: number) { + if (this.httpServersRegistry.has(port)) { + return; + } + const httpServer = http.createServer(); + this.httpServersRegistry.set(port, httpServer); + + httpServer.on('upgrade', (request, socket, head) => { + const baseUrl = 'ws://' + request.headers.host + '/'; + const pathname = new URL(request.url, baseUrl).pathname; + const wsServersCollection = this.wsServersRegistry.get(port); + + let isRequestDelegated = false; + for (const wsServer of wsServersCollection) { + if (pathname === wsServer.path) { + wsServer.handleUpgrade(request, socket, head, (ws: unknown) => { + wsServer.emit('connection', ws, request); + }); + isRequestDelegated = true; + break; + } + } + if (!isRequestDelegated) { + socket.destroy(); + } + }); + + httpServer.listen(port); + } + + protected addWsServerToRegistry = any>( + wsServer: T, + port: number, + path: string, + ) { + const entries = this.wsServersRegistry.get(port) ?? []; + entries.push(wsServer); + + wsServer.path = path; + this.wsServersRegistry.set(port, entries); + } } diff --git a/packages/websockets/adapters/ws-adapter.ts b/packages/websockets/adapters/ws-adapter.ts index 6d3eab7747d..be06709e8e3 100644 --- a/packages/websockets/adapters/ws-adapter.ts +++ b/packages/websockets/adapters/ws-adapter.ts @@ -33,11 +33,13 @@ export abstract class AbstractWsAdapter< client.on(DISCONNECT_EVENT, callback); } - public close(server: TServer) { + public async close(server: TServer) { const isCallable = server && isFunction(server.close); - isCallable && server.close(); + isCallable && (await new Promise(resolve => server.close(resolve))); } + public async dispose() {} + public abstract create(port: number, options?: TOptions): TServer; public abstract bindMessageHandlers( client: TClient, diff --git a/packages/websockets/socket-module.ts b/packages/websockets/socket-module.ts index 532c5180e63..17a6ebe0f8d 100644 --- a/packages/websockets/socket-module.ts +++ b/packages/websockets/socket-module.ts @@ -10,6 +10,7 @@ import { InterceptorsContextCreator } from '@nestjs/core/interceptors/intercepto import { PipesConsumer } from '@nestjs/core/pipes/pipes-consumer'; import { PipesContextCreator } from '@nestjs/core/pipes/pipes-context-creator'; import { iterate } from 'iterare'; +import { AbstractWsAdapter } from './adapters'; import { GATEWAY_METADATA } from './constants'; import { ExceptionFiltersContext } from './context/exception-filters-context'; import { WsContextCreator } from './context/ws-context-creator'; @@ -92,6 +93,8 @@ export class SocketModule { .filter(({ server }) => server) .map(async ({ server }) => adapter.close(server)), ); + await (adapter as AbstractWsAdapter)?.dispose(); + this.socketsContainer.clear(); } diff --git a/packages/websockets/socket-server-provider.ts b/packages/websockets/socket-server-provider.ts index f5e1c785e93..eab5a11d513 100644 --- a/packages/websockets/socket-server-provider.ts +++ b/packages/websockets/socket-server-provider.ts @@ -20,7 +20,11 @@ export class SocketServerProvider { path: options.path, }); if (serverAndStreamsHost && options.namespace) { - return this.decorateWithNamespace(options, port, serverAndStreamsHost); + return this.decorateWithNamespace( + options, + port, + serverAndStreamsHost.server, + ); } return serverAndStreamsHost ? serverAndStreamsHost @@ -65,7 +69,7 @@ export class SocketServerProvider { namespaceServer, ); this.socketsContainer.addOne( - { port, path: options.path }, + { port, path: options.path, namespace: options.namespace }, serverAndEventStreamsHost, ); return serverAndEventStreamsHost;