Skip to content

Commit

Permalink
feat: add credentialProvider option when creating clients
Browse files Browse the repository at this point in the history
In some instances, credentials for the redis client will be short-lived and need to be
fetched on-demand when connecting to redis. This is the case when connecting in AWS using
IAM authentication or Entra ID in Azure.

This feature allows for a credentialProvider to be provided which is a callable function
returning a Promise that resolves to a username/password object.
  • Loading branch information
dhensby committed Dec 5, 2024
1 parent ffa7d25 commit 1b0ac32
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/client-configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
| scripts | | Script definitions (see [Lua Scripts](../README.md#lua-scripts)) |
| functions | | Function definitions (see [Functions](../README.md#functions)) |
| commandsQueueMaxLength | | Maximum length of the client's internal command queue |
| credentialSupplier | | A callable function that returns a Promise which resolves to an object with username and password properties |
| disableOfflineQueue | `false` | Disables offline queuing, see [FAQ](./FAQ.md#what-happens-when-the-network-goes-down) |
| readonly | `false` | Connect in [`READONLY`](https://redis.io/commands/readonly) mode |
| legacyMode | `false` | Maintain some backwards compatibility (see the [Migration Guide](./v3-to-v4.md)) |
Expand Down
16 changes: 16 additions & 0 deletions packages/client/lib/client/index.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { strict as assert } from 'node:assert';
import { setTimeout } from 'node:timers/promises';
import testUtils, { GLOBAL, waitTillBeenCalled } from '../test-utils';
import RedisClient, { RedisClientType } from '.';
import { AbortError, ClientClosedError, ClientOfflineError, ConnectionTimeoutError, DisconnectsClientError, ErrorReply, MultiErrorReply, SocketClosedUnexpectedlyError, WatchError } from '../errors';
Expand Down Expand Up @@ -103,6 +104,21 @@ describe('Client', () => {
},
minimumDockerVersion: [6, 2]
});

testUtils.testWithClient('should accept a credentialSupplier', async client => {
assert.equal(
await client.ping(),
'PONG'
);
}, {
...GLOBAL.SERVERS.PASSWORD,
clientOptions: {
// simulate a slight pause to fetch the credentials
credentialSupplier: () => setTimeout(50).then(() => Promise.resolve({
...GLOBAL.SERVERS.PASSWORD.clientOptions,
})),
}
});
});

testUtils.testWithClient('should set connection name', async client => {
Expand Down
44 changes: 34 additions & 10 deletions packages/client/lib/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ import { Command, CommandSignature, TypeMapping, CommanderConfig, RedisFunction,
import RedisClientMultiCommand, { RedisClientMultiCommandType } from './multi-command';
import { RedisMultiQueuedCommand } from '../multi-command';
import HELLO, { HelloOptions } from '../commands/HELLO';
import { AuthOptions } from '../commands/AUTH';
import { ScanOptions, ScanCommonOptions } from '../commands/SCAN';
import { RedisLegacyClient, RedisLegacyClientType } from './legacy-mode';
import { RedisPoolOptions, RedisClientPool } from './pool';
import { RedisVariadicArgument, parseArgs, pushVariadicArguments } from '../commands/generic-transformers';
import { BasicCommandParser, CommandParser } from './parser';

export type RedisCredentialSupplier = () => Promise<AuthOptions | undefined>;

export interface RedisClientOptions<
M extends RedisModules = RedisModules,
F extends RedisFunctions = RedisFunctions,
Expand All @@ -34,6 +37,10 @@ export interface RedisClientOptions<
* Socket connection properties
*/
socket?: SocketOptions;
/**
* Credential supplier callback function
*/
credentialSupplier?: RedisCredentialSupplier;
/**
* ACL username ([see ACL guide](https://redis.io/topics/acl))
*/
Expand Down Expand Up @@ -276,6 +283,7 @@ export default class RedisClient<
readonly #options?: RedisClientOptions<M, F, S, RESP, TYPE_MAPPING>;
readonly #socket: RedisSocket;
readonly #queue: RedisCommandsQueue;
#credentialSupplier: RedisCredentialSupplier;
#selectedDB = 0;
#monitorCallback?: MonitorCallback<TYPE_MAPPING>;
private _self = this;
Expand Down Expand Up @@ -313,6 +321,8 @@ export default class RedisClient<
this.#options = this.#initiateOptions(options);
this.#queue = this.#initiateQueue();
this.#socket = this.#initiateSocket();
this.#credentialSupplier = this.#initiateCredentialSupplier();

this.#epoch = 0;
}

Expand Down Expand Up @@ -345,16 +355,16 @@ export default class RedisClient<
);
}

#handshake(selectedDB: number) {
#handshake(selectedDB: number, credential?: AuthOptions) {
const commands = [];

if (this.#options?.RESP) {
const hello: HelloOptions = {};

if (this.#options.password) {
if (credential?.password) {
hello.AUTH = {
username: this.#options.username ?? 'default',
password: this.#options.password
username: credential?.username ?? 'default',
password: credential?.password
};
}

Expand All @@ -366,11 +376,11 @@ export default class RedisClient<
parseArgs(HELLO, this.#options.RESP, hello)
);
} else {
if (this.#options?.username || this.#options?.password) {
if (credential) {
commands.push(
parseArgs(COMMANDS.AUTH, {
username: this.#options.username,
password: this.#options.password ?? ''
username: credential.username,
password: credential.password ?? ''
})
);
}
Expand All @@ -396,7 +406,11 @@ export default class RedisClient<
}

#initiateSocket(): RedisSocket {
const socketInitiator = () => {
const socketInitiator = async () => {
// we have to call the credential fetch before pushing any commands into the queue,
// so fetch the credentials before doing anything else.
const credential: AuthOptions | undefined = await this.#credentialSupplier();

const promises = [],
chainId = Symbol('Socket Initiator');

Expand All @@ -418,7 +432,7 @@ export default class RedisClient<
);
}

const commands = this.#handshake(this.#selectedDB);
const commands = this.#handshake(this.#selectedDB, credential);
for (let i = commands.length - 1; i >= 0; --i) {
promises.push(
this.#queue.addCommand(commands[i], {
Expand Down Expand Up @@ -463,6 +477,15 @@ export default class RedisClient<
.on('end', () => this.emit('end'));
}

#initiateCredentialSupplier(): RedisCredentialSupplier {
// if a credential supplier has been provided, use it, otherwise create a provider from the
// supplier username and password (if provided)
return this.#options?.credentialSupplier ?? (() => Promise.resolve((this.#options?.username || this.#options?.password) ? {
username: this.#options?.username,
password: this.#options?.password ?? '',
} : undefined));
}

#pingTimer?: NodeJS.Timeout;

#setPingTimer(): void {
Expand Down Expand Up @@ -997,10 +1020,11 @@ export default class RedisClient<
* Reset the client to its default state (i.e. stop PubSub, stop monitoring, select default DB, etc.)
*/
async reset() {
const credential: AuthOptions | undefined = await this.#credentialSupplier?.();
const chainId = Symbol('Reset Chain'),
promises = [this._self.#queue.reset(chainId)],
selectedDB = this._self.#options?.database ?? 0;
for (const command of this._self.#handshake(selectedDB)) {
for (const command of this._self.#handshake(selectedDB, credential)) {
promises.push(
this._self.#queue.addCommand(command, {
chainId
Expand Down

0 comments on commit 1b0ac32

Please sign in to comment.