From 1b0ac32bd9f6384636637517fa32deae759b620e Mon Sep 17 00:00:00 2001 From: Daniel Hensby Date: Thu, 17 Oct 2024 09:35:31 +0200 Subject: [PATCH] feat: add credentialProvider option when creating clients In some instances, credentials for the redis client will be short-lived and need to be fetched on-demand when connecting to redis. This is the case when connecting in AWS using IAM authentication or Entra ID in Azure. This feature allows for a credentialProvider to be provided which is a callable function returning a Promise that resolves to a username/password object. --- docs/client-configuration.md | 1 + packages/client/lib/client/index.spec.ts | 16 +++++++++ packages/client/lib/client/index.ts | 44 ++++++++++++++++++------ 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/docs/client-configuration.md b/docs/client-configuration.md index deb68437e1..6a211cb23e 100644 --- a/docs/client-configuration.md +++ b/docs/client-configuration.md @@ -22,6 +22,7 @@ | scripts | | Script definitions (see [Lua Scripts](../README.md#lua-scripts)) | | functions | | Function definitions (see [Functions](../README.md#functions)) | | commandsQueueMaxLength | | Maximum length of the client's internal command queue | +| credentialSupplier | | A callable function that returns a Promise which resolves to an object with username and password properties | | disableOfflineQueue | `false` | Disables offline queuing, see [FAQ](./FAQ.md#what-happens-when-the-network-goes-down) | | readonly | `false` | Connect in [`READONLY`](https://redis.io/commands/readonly) mode | | legacyMode | `false` | Maintain some backwards compatibility (see the [Migration Guide](./v3-to-v4.md)) | diff --git a/packages/client/lib/client/index.spec.ts b/packages/client/lib/client/index.spec.ts index cd2040ec97..b6525c30fb 100644 --- a/packages/client/lib/client/index.spec.ts +++ b/packages/client/lib/client/index.spec.ts @@ -1,4 +1,5 @@ import { strict as assert } from 'node:assert'; +import { setTimeout } from 'node:timers/promises'; import testUtils, { GLOBAL, waitTillBeenCalled } from '../test-utils'; import RedisClient, { RedisClientType } from '.'; import { AbortError, ClientClosedError, ClientOfflineError, ConnectionTimeoutError, DisconnectsClientError, ErrorReply, MultiErrorReply, SocketClosedUnexpectedlyError, WatchError } from '../errors'; @@ -103,6 +104,21 @@ describe('Client', () => { }, minimumDockerVersion: [6, 2] }); + + testUtils.testWithClient('should accept a credentialSupplier', async client => { + assert.equal( + await client.ping(), + 'PONG' + ); + }, { + ...GLOBAL.SERVERS.PASSWORD, + clientOptions: { + // simulate a slight pause to fetch the credentials + credentialSupplier: () => setTimeout(50).then(() => Promise.resolve({ + ...GLOBAL.SERVERS.PASSWORD.clientOptions, + })), + } + }); }); testUtils.testWithClient('should set connection name', async client => { diff --git a/packages/client/lib/client/index.ts b/packages/client/lib/client/index.ts index 55355a133d..465ca28aae 100644 --- a/packages/client/lib/client/index.ts +++ b/packages/client/lib/client/index.ts @@ -11,12 +11,15 @@ import { Command, CommandSignature, TypeMapping, CommanderConfig, RedisFunction, import RedisClientMultiCommand, { RedisClientMultiCommandType } from './multi-command'; import { RedisMultiQueuedCommand } from '../multi-command'; import HELLO, { HelloOptions } from '../commands/HELLO'; +import { AuthOptions } from '../commands/AUTH'; import { ScanOptions, ScanCommonOptions } from '../commands/SCAN'; import { RedisLegacyClient, RedisLegacyClientType } from './legacy-mode'; import { RedisPoolOptions, RedisClientPool } from './pool'; import { RedisVariadicArgument, parseArgs, pushVariadicArguments } from '../commands/generic-transformers'; import { BasicCommandParser, CommandParser } from './parser'; +export type RedisCredentialSupplier = () => Promise; + export interface RedisClientOptions< M extends RedisModules = RedisModules, F extends RedisFunctions = RedisFunctions, @@ -34,6 +37,10 @@ export interface RedisClientOptions< * Socket connection properties */ socket?: SocketOptions; + /** + * Credential supplier callback function + */ + credentialSupplier?: RedisCredentialSupplier; /** * ACL username ([see ACL guide](https://redis.io/topics/acl)) */ @@ -276,6 +283,7 @@ export default class RedisClient< readonly #options?: RedisClientOptions; readonly #socket: RedisSocket; readonly #queue: RedisCommandsQueue; + #credentialSupplier: RedisCredentialSupplier; #selectedDB = 0; #monitorCallback?: MonitorCallback; private _self = this; @@ -313,6 +321,8 @@ export default class RedisClient< this.#options = this.#initiateOptions(options); this.#queue = this.#initiateQueue(); this.#socket = this.#initiateSocket(); + this.#credentialSupplier = this.#initiateCredentialSupplier(); + this.#epoch = 0; } @@ -345,16 +355,16 @@ export default class RedisClient< ); } - #handshake(selectedDB: number) { + #handshake(selectedDB: number, credential?: AuthOptions) { const commands = []; if (this.#options?.RESP) { const hello: HelloOptions = {}; - if (this.#options.password) { + if (credential?.password) { hello.AUTH = { - username: this.#options.username ?? 'default', - password: this.#options.password + username: credential?.username ?? 'default', + password: credential?.password }; } @@ -366,11 +376,11 @@ export default class RedisClient< parseArgs(HELLO, this.#options.RESP, hello) ); } else { - if (this.#options?.username || this.#options?.password) { + if (credential) { commands.push( parseArgs(COMMANDS.AUTH, { - username: this.#options.username, - password: this.#options.password ?? '' + username: credential.username, + password: credential.password ?? '' }) ); } @@ -396,7 +406,11 @@ export default class RedisClient< } #initiateSocket(): RedisSocket { - const socketInitiator = () => { + const socketInitiator = async () => { + // we have to call the credential fetch before pushing any commands into the queue, + // so fetch the credentials before doing anything else. + const credential: AuthOptions | undefined = await this.#credentialSupplier(); + const promises = [], chainId = Symbol('Socket Initiator'); @@ -418,7 +432,7 @@ export default class RedisClient< ); } - const commands = this.#handshake(this.#selectedDB); + const commands = this.#handshake(this.#selectedDB, credential); for (let i = commands.length - 1; i >= 0; --i) { promises.push( this.#queue.addCommand(commands[i], { @@ -463,6 +477,15 @@ export default class RedisClient< .on('end', () => this.emit('end')); } + #initiateCredentialSupplier(): RedisCredentialSupplier { + // if a credential supplier has been provided, use it, otherwise create a provider from the + // supplier username and password (if provided) + return this.#options?.credentialSupplier ?? (() => Promise.resolve((this.#options?.username || this.#options?.password) ? { + username: this.#options?.username, + password: this.#options?.password ?? '', + } : undefined)); + } + #pingTimer?: NodeJS.Timeout; #setPingTimer(): void { @@ -997,10 +1020,11 @@ export default class RedisClient< * Reset the client to its default state (i.e. stop PubSub, stop monitoring, select default DB, etc.) */ async reset() { + const credential: AuthOptions | undefined = await this.#credentialSupplier?.(); const chainId = Symbol('Reset Chain'), promises = [this._self.#queue.reset(chainId)], selectedDB = this._self.#options?.database ?? 0; - for (const command of this._self.#handshake(selectedDB)) { + for (const command of this._self.#handshake(selectedDB, credential)) { promises.push( this._self.#queue.addCommand(command, { chainId