diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 6736810035920..f4ec0d66f3d00 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -132,8 +132,7 @@ "--es-module-specifier-resolution=node" ], "files": [ - "tests/**/*.spec.ts", - "tests/**/*.e2e.ts" + "tests/**/workspace-invite.e2e.ts" ], "require": [ "./src/prelude.ts" diff --git a/packages/backend/server/src/app.module.ts b/packages/backend/server/src/app.module.ts index 1b632847e13e4..36cdde286d3eb 100644 --- a/packages/backend/server/src/app.module.ts +++ b/packages/backend/server/src/app.module.ts @@ -31,6 +31,7 @@ import { EventModule } from './fundamentals/event'; import { GqlModule } from './fundamentals/graphql'; import { MailModule } from './fundamentals/mailer'; import { MetricsModule } from './fundamentals/metrics'; +import { MutexModule } from './fundamentals/mutex'; import { PrismaModule } from './fundamentals/prisma'; import { SessionModule } from './fundamentals/session'; import { StorageProviderModule } from './fundamentals/storage'; @@ -43,9 +44,12 @@ export const FunctionalityModules = [ ScheduleModule.forRoot(), EventModule, CacheModule, + MutexModule, PrismaModule, ClsModule.forRoot({ - interceptor: { mount: true }, + global: true, + middleware: { mount: true }, + interceptor: { mount: true, generateId: true }, plugins: [ new ClsPluginTransactional({ imports: [PrismaModule], diff --git a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts index b5514a2533ede..2a1f4d5f101a2 100644 --- a/packages/backend/server/src/core/workspaces/resolvers/workspace.ts +++ b/packages/backend/server/src/core/workspaces/resolvers/workspace.ts @@ -31,6 +31,7 @@ import { EventEmitter, type FileUpload, MailService, + MutexService, Throttle, } from '../../../fundamentals'; import { Auth, CurrentUser, Public } from '../../auth'; @@ -69,7 +70,8 @@ export class WorkspaceResolver { private readonly quota: QuotaManagementService, private readonly users: UsersService, private readonly event: EventEmitter, - private readonly blobStorage: WorkspaceBlobStorage + private readonly blobStorage: WorkspaceBlobStorage, + private readonly mutex: MutexService ) {} @ResolveField(() => Permission, { @@ -346,6 +348,12 @@ export class WorkspaceResolver { throw new ForbiddenException('Cannot change owner'); } + const lockFlag = `invite:${workspaceId}`; + if (!(await this.mutex.lock(lockFlag))) { + throw new ForbiddenException('Failed to acquire lock'); + } + console.error('invite flag log: ', lockFlag); + // member limit check const [memberCount, quota] = await Promise.all([ this.prisma.workspaceUserPermission.count({ @@ -354,6 +362,7 @@ export class WorkspaceResolver { this.quota.getWorkspaceUsage(workspaceId), ]); if (memberCount >= quota.memberLimit) { + await this.mutex.unlock(lockFlag); throw new PayloadTooLargeException('Workspace member limit reached.'); } @@ -366,7 +375,10 @@ export class WorkspaceResolver { }, }); // only invite if the user is not already in the workspace - if (originRecord) return originRecord.id; + if (originRecord) { + await this.mutex.unlock(lockFlag); + return originRecord.id; + } } else { target = await this.auth.createAnonymousUser(email); } @@ -406,11 +418,13 @@ export class WorkspaceResolver { `failed to send ${workspaceId} invite email to ${email}, but successfully revoked permission: ${e}` ); } + await this.mutex.unlock(lockFlag); return new InternalServerErrorException( 'Failed to send invite email. Please try again.' ); } } + await this.mutex.unlock(lockFlag); return inviteId; } diff --git a/packages/backend/server/src/fundamentals/index.ts b/packages/backend/server/src/fundamentals/index.ts index 527c7e67ca7d1..b51ff876299e5 100644 --- a/packages/backend/server/src/fundamentals/index.ts +++ b/packages/backend/server/src/fundamentals/index.ts @@ -16,6 +16,7 @@ export * from './error'; export { EventEmitter, type EventPayload, OnEvent } from './event'; export { MailService } from './mailer'; export { CallCounter, CallTimer, metrics } from './metrics'; +export { MUTEX_RETRY, MUTEX_WAIT, MutexService } from './mutex'; export { getOptionalModuleMetadata, GlobalExceptionFilter, @@ -32,3 +33,4 @@ export { getRequestResponseFromHost, } from './utils/request'; export type * from './utils/types'; +export { sleep } from './utils/utils'; diff --git a/packages/backend/server/src/fundamentals/mutex/index.ts b/packages/backend/server/src/fundamentals/mutex/index.ts new file mode 100644 index 0000000000000..8a351833e8c54 --- /dev/null +++ b/packages/backend/server/src/fundamentals/mutex/index.ts @@ -0,0 +1,66 @@ +import { randomUUID } from 'node:crypto'; + +import { Global, Injectable, Logger, Module } from '@nestjs/common'; +import { ClsService } from 'nestjs-cls'; + +import { sleep } from '../utils/utils'; + +export const MUTEX_RETRY = 3; +export const MUTEX_WAIT = 100; + +@Injectable() +export class MutexService { + private readonly logger = new Logger(MutexService.name); + private readonly bucket = new Map(); + + constructor(private readonly als: ClsService) {} + + private getId() { + let id = this.als.get('asyncId'); + + if (!id) { + id = randomUUID(); + this.als.set('asyncId', id); + } + + return id; + } + + async lock(key: string): Promise { + const id = this.getId(); + const fetchLock = async (retry: number): Promise => { + if (retry === 0) { + this.logger.error( + `Failed to fetch lock ${key} after ${MUTEX_RETRY} retry` + ); + return false; + } + const current = this.bucket.get(key); + if (current && current !== id) { + this.logger.warn( + `Failed to fetch lock ${key}, retrying in ${MUTEX_WAIT} ms` + ); + await sleep(MUTEX_WAIT * (MUTEX_RETRY - retry + 1)); + return fetchLock(retry - 1); + } + this.bucket.set(key, id); + console.error('success lock', key); + return true; + }; + + return fetchLock(MUTEX_RETRY); + } + + async unlock(key: string): Promise { + if (this.bucket.get(key) === this.getId()) { + this.bucket.delete(key); + } + } +} + +@Global() +@Module({ + providers: [MutexService], + exports: [MutexService], +}) +export class MutexModule {} diff --git a/packages/backend/server/src/fundamentals/utils/utils.ts b/packages/backend/server/src/fundamentals/utils/utils.ts new file mode 100644 index 0000000000000..5785ba7805ffc --- /dev/null +++ b/packages/backend/server/src/fundamentals/utils/utils.ts @@ -0,0 +1,3 @@ +export function sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)); +} diff --git a/packages/backend/server/src/plugins/redis/index.ts b/packages/backend/server/src/plugins/redis/index.ts index 46b44fe7fdb5a..14e9d63ddc693 100644 --- a/packages/backend/server/src/plugins/redis/index.ts +++ b/packages/backend/server/src/plugins/redis/index.ts @@ -1,17 +1,25 @@ import { Global, Provider, Type } from '@nestjs/common'; import { Redis, type RedisOptions } from 'ioredis'; +import { ClsService } from 'nestjs-cls'; import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis'; -import { Cache, OptionalModule, SessionCache } from '../../fundamentals'; +import { + Cache, + MutexService, + OptionalModule, + SessionCache, +} from '../../fundamentals'; import { ThrottlerStorage } from '../../fundamentals/throttler'; import { SocketIoAdapterImpl } from '../../fundamentals/websocket'; import { RedisCache } from './cache'; import { CacheRedis, + MutexRedis, SessionRedis, SocketIoRedis, ThrottlerRedis, } from './instances'; +import { MutexRedisService } from './mutex'; import { createSockerIoAdapterImpl } from './ws-adapter'; function makeProvider(token: Type, impl: Type): Provider { @@ -46,14 +54,30 @@ const socketIoRedisAdapterProvider: Provider = { inject: [SocketIoRedis], }; +// mutex +const mutexRedisAdapterProvider: Provider = { + provide: MutexService, + useFactory: (redis: Redis, cls: ClsService) => { + return new MutexRedisService(redis, cls); + }, + inject: [MutexRedis, ClsService], +}; + @Global() @OptionalModule({ - providers: [CacheRedis, SessionRedis, ThrottlerRedis, SocketIoRedis], + providers: [ + CacheRedis, + SessionRedis, + ThrottlerRedis, + SocketIoRedis, + MutexRedis, + ], overrides: [ cacheProvider, sessionCacheProvider, socketIoRedisAdapterProvider, throttlerStorageProvider, + mutexRedisAdapterProvider, ], requires: ['plugins.redis.host'], }) diff --git a/packages/backend/server/src/plugins/redis/instances.ts b/packages/backend/server/src/plugins/redis/instances.ts index 1e85dec622f76..8fbd13b0c685a 100644 --- a/packages/backend/server/src/plugins/redis/instances.ts +++ b/packages/backend/server/src/plugins/redis/instances.ts @@ -54,3 +54,10 @@ export class SocketIoRedis extends Redis { super({ ...config.plugins.redis, db: (config.plugins.redis?.db ?? 0) + 3 }); } } + +@Injectable() +export class MutexRedis extends Redis { + constructor(config: Config) { + super({ ...config.plugins.redis, db: (config.plugins.redis?.db ?? 0) + 4 }); + } +} diff --git a/packages/backend/server/src/plugins/redis/mutex.ts b/packages/backend/server/src/plugins/redis/mutex.ts new file mode 100644 index 0000000000000..e3622e0d86046 --- /dev/null +++ b/packages/backend/server/src/plugins/redis/mutex.ts @@ -0,0 +1,98 @@ +import { randomUUID } from 'node:crypto'; + +import { Injectable, Logger } from '@nestjs/common'; +import Redis, { Command } from 'ioredis'; +import { ClsService } from 'nestjs-cls'; + +import { MUTEX_RETRY, MUTEX_WAIT, sleep } from '../../fundamentals'; + +const lockScript = `local key = KEYS[1] +local clientId = ARGV[1] +local releaseTime = ARGV[2] + +if redis.call("get", key) == clientId or redis.call("set", key, clientId, "NX", "PX", releaseTime) then + return 1 +else + return 0 +end`; +const unlockScript = `local key = KEYS[1] +local clientId = ARGV[1] + +if redis.call("get", key) == clientId then + return redis.call("del", key) +else + return 0 +end`; + +@Injectable() +export class MutexRedisService { + private readonly logger = new Logger(MutexRedisService.name); + + constructor( + private readonly redis: Redis, + private readonly cls: ClsService + ) {} + + private getId() { + let id = this.cls.get('asyncId'); + + if (!id) { + id = randomUUID(); + this.cls.set('asyncId', id); + } + + return id; + } + + async lock(key: string, timeout: number = 100): Promise { + const clientId = this.getId(); + console.error('lock', key, clientId); + this.logger.debug(`Client ID is ${clientId}`); + const timeoutStr = timeout.toString(); + + const fetchLock = async (retry: number): Promise => { + if (retry === 0) { + this.logger.error( + `Failed to fetch lock ${key} after ${MUTEX_RETRY} retry` + ); + return false; + } + try { + const success = await this.redis.sendCommand( + new Command('EVAL', [lockScript, '1', key, clientId, timeoutStr]) + ); + if (success === 1) { + console.error('success lock', key); + return true; + } else { + this.logger.warn( + `Failed to fetch lock ${key}, retrying in ${MUTEX_WAIT} ms` + ); + await sleep(MUTEX_WAIT * (MUTEX_RETRY - retry + 1)); + return fetchLock(retry - 1); + } + } catch (error: any) { + this.logger.error( + `Unexpected error when fetch lock ${key}: ${error.message}` + ); + return false; + } + }; + + return fetchLock(MUTEX_RETRY); + } + + async unlock(key: string, ignoreUnlockFail = false): Promise { + const clientId = this.getId(); + const result = await this.redis.sendCommand( + new Command('EVAL', [unlockScript, '1', key, clientId]) + ); + if (result === 0) { + if (!ignoreUnlockFail) { + throw new Error(`Failed to release lock ${key}`); + } else { + this.logger.warn(`Failed to release lock ${key}`); + } + } + } +}