Skip to content

Commit

Permalink
Merge pull request #14185 from sapenlei/fix/prevent-socketio-server-c…
Browse files Browse the repository at this point in the history
…lose

fix(websockets): Prevent HTTP server early close in Socket.IO shutdown
  • Loading branch information
kamilmysliwiec authored Nov 25, 2024
2 parents ae0517b + 3b5cb56 commit 95c8547
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 2 deletions.
75 changes: 75 additions & 0 deletions integration/websockets/e2e/gateway.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { INestApplication } from '@nestjs/common';
import { Test } from '@nestjs/testing';
import { expect } from 'chai';
import * as EventSource from 'eventsource';
import { io } from 'socket.io-client';
import { AppController as LongConnectionController } from '../../nest-application/sse/src/app.controller';
import { ApplicationGateway } from '../src/app.gateway';
import { NamespaceGateway } from '../src/namespace.gateway';
import { ServerGateway } from '../src/server.gateway';
Expand Down Expand Up @@ -98,5 +100,78 @@ describe('WebSocketGateway', () => {
);
});

describe('Shared Server for WebSocket and Long-Running Connections', () => {
afterEach(() => {});
it('should block application shutdown', function (done) {
let eventSource;

(async () => {
this.timeout(30000);

setTimeout(() => {
const instance = testingModule.get(ServerGateway);
expect(instance.onApplicationShutdown.called).to.be.false;
eventSource.close();
done();
}, 25000);

const testingModule = await Test.createTestingModule({
providers: [ServerGateway],
controllers: [LongConnectionController],
}).compile();
app = testingModule.createNestApplication();

await app.listen(3000);

ws = io(`http://localhost:3000`);
eventSource = new EventSource(`http://localhost:3000/sse`);

await new Promise((resolve, reject) => {
ws.on('connect', resolve);
ws.on('error', reject);
});

await new Promise((resolve, reject) => {
eventSource.onmessage = resolve;
eventSource.onerror = reject;
});

app.close();
})();
});

it('should shutdown application immediately when forceCloseConnections is true', async () => {
const testingModule = await Test.createTestingModule({
providers: [ServerGateway],
controllers: [LongConnectionController],
}).compile();

app = testingModule.createNestApplication({
forceCloseConnections: true,
});

await app.listen(3000);

ws = io(`http://localhost:3000`);
const eventSource = new EventSource(`http://localhost:3000/sse`);

await new Promise((resolve, reject) => {
ws.on('connect', resolve);
ws.on('error', reject);
});

await new Promise((resolve, reject) => {
eventSource.onmessage = resolve;
eventSource.onerror = reject;
});

await app.close();

const instance = testingModule.get(ServerGateway);
expect(instance.onApplicationShutdown.called).to.be.true;
eventSource.close();
});
});

afterEach(() => app.close());
});
7 changes: 5 additions & 2 deletions integration/websockets/src/server.gateway.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { UseInterceptors } from '@nestjs/common';
import { OnApplicationShutdown, UseInterceptors } from '@nestjs/common';
import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets';
import * as Sinon from 'sinon';
import { RequestInterceptor } from './request.interceptor';

@WebSocketGateway()
export class ServerGateway {
export class ServerGateway implements OnApplicationShutdown {
@SubscribeMessage('push')
onPush(client, data) {
return {
Expand All @@ -20,4 +21,6 @@ export class ServerGateway {
data: { ...data, path: client.pattern },
};
}

onApplicationShutdown = Sinon.spy();
}
8 changes: 8 additions & 0 deletions packages/websockets/adapters/ws-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export abstract class AbstractWsAdapter<
> implements WebSocketAdapter<TServer, TClient, TOptions>
{
protected readonly httpServer: any;
private _forceCloseConnections: boolean;

constructor(appOrHttpServer?: INestApplicationContext | any) {
if (appOrHttpServer && appOrHttpServer instanceof NestApplication) {
Expand All @@ -26,6 +27,10 @@ export abstract class AbstractWsAdapter<
}
}

public set forceCloseConnections(value: boolean) {
this._forceCloseConnections = value;
}

public bindClientConnect(server: TServer, callback: Function) {
server.on(CONNECTION_EVENT, callback);
}
Expand All @@ -35,6 +40,9 @@ export abstract class AbstractWsAdapter<
}

public async close(server: TServer) {
if (this._forceCloseConnections) {
return;
}
const isCallable = server && isFunction(server.close);
isCallable && (await new Promise(resolve => server.close(resolve)));
}
Expand Down
6 changes: 6 additions & 0 deletions packages/websockets/socket-module.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { NestApplicationOptions } from '@nestjs/common';
import { InjectionToken } from '@nestjs/common/interfaces';
import { Injectable } from '@nestjs/common/interfaces/injectable.interface';
import { NestApplicationContextOptions } from '@nestjs/common/interfaces/nest-application-context-options.interface';
Expand Down Expand Up @@ -113,8 +114,12 @@ export class SocketModule<
}

private initializeAdapter() {
const forceCloseConnections = (this.appOptions as NestApplicationOptions)
.forceCloseConnections;
const adapter = this.applicationConfig.getIoAdapter();
if (adapter) {
(adapter as AbstractWsAdapter).forceCloseConnections =
forceCloseConnections;
this.isAdapterInitialized = true;
return;
}
Expand All @@ -124,6 +129,7 @@ export class SocketModule<
() => require('@nestjs/platform-socket.io'),
);
const ioAdapter = new IoAdapter(this.httpServer);
ioAdapter.forceCloseConnections = forceCloseConnections;
this.applicationConfig.setIoAdapter(ioAdapter);

this.isAdapterInitialized = true;
Expand Down

0 comments on commit 95c8547

Please sign in to comment.