From 1ebd657c1c6ff94ab147b55d62bfec4a2c6c031f Mon Sep 17 00:00:00 2001 From: Alissa Renz Date: Wed, 16 Oct 2024 20:16:37 -0700 Subject: [PATCH] Add support for Agents/Assistants (#2286) --- src/App.ts | 12 + src/Assistant.spec.ts | 492 ++++++++++++++++++++++++ src/Assistant.ts | 412 ++++++++++++++++++++ src/AssistantThreadContextStore.spec.ts | 171 ++++++++ src/AssistantThreadContextStore.ts | 97 +++++ src/errors.ts | 12 + src/index.ts | 8 + 7 files changed, 1204 insertions(+) create mode 100644 src/Assistant.spec.ts create mode 100644 src/Assistant.ts create mode 100644 src/AssistantThreadContextStore.spec.ts create mode 100644 src/AssistantThreadContextStore.ts diff --git a/src/App.ts b/src/App.ts index 056e9e776..9b09975e7 100644 --- a/src/App.ts +++ b/src/App.ts @@ -63,6 +63,7 @@ import { StringIndexed } from './types/helpers'; // eslint-disable-next-line import/order import allSettled = require('promise.allsettled'); // eslint-disable-line @typescript-eslint/no-require-imports import { FunctionCompleteFn, FunctionFailFn, CustomFunction, CustomFunctionMiddleware } from './CustomFunction'; +import { Assistant } from './Assistant'; // eslint-disable-next-line @typescript-eslint/no-require-imports, import/no-commonjs const packageJson = require('../package.json'); // eslint-disable-line @typescript-eslint/no-var-requires @@ -519,6 +520,17 @@ export default class App return this; } + /** + * Register Assistant middleware + * + * @param assistant global assistant middleware function + */ + public assistant(assistant: Assistant): this { + const m = assistant.getMiddleware(); + this.middleware.push(m); + return this; + } + /** * Register WorkflowStep middleware * diff --git a/src/Assistant.spec.ts b/src/Assistant.spec.ts new file mode 100644 index 000000000..bd4cd0b36 --- /dev/null +++ b/src/Assistant.spec.ts @@ -0,0 +1,492 @@ +import 'mocha'; +import { assert } from 'chai'; +import sinon from 'sinon'; +import rewiremock from 'rewiremock'; +import { WebClient } from '@slack/web-api'; +import { + Assistant, + AssistantMiddlewareArgs, + AllAssistantMiddlewareArgs, + AssistantMiddleware, + AssistantConfig, + AssistantThreadStartedMiddlewareArgs, + AssistantThreadContextChangedMiddlewareArgs, + AssistantUserMessageMiddlewareArgs, +} from './Assistant'; +import { Override } from './test-helpers'; +import { AllMiddlewareArgs, AnyMiddlewareArgs, AssistantThreadStartedEvent, Middleware, SlackEventMiddlewareArgs } from './types'; +import { AssistantInitializationError, AssistantMissingPropertyError } from './errors'; +import { AssistantThreadContextStore, AssistantThreadContext } from './AssistantThreadContextStore'; + +async function importAssistant(overrides: Override = {}): Promise { + return rewiremock.module(() => import('./Assistant'), overrides); +} + +const MOCK_FN = async () => { }; + +const MOCK_CONFIG_SINGLE = { + threadStarted: MOCK_FN, + threadContextChanged: MOCK_FN, + userMessage: MOCK_FN, +}; + +const MOCK_CONFIG_MULTIPLE = { + threadStarted: [MOCK_FN, MOCK_FN], + threadContextChanged: [MOCK_FN], + userMessage: [MOCK_FN, MOCK_FN, MOCK_FN], +}; + +describe('Assistant class', () => { + describe('constructor', () => { + it('should accept config as single functions', async () => { + const assistant = new Assistant(MOCK_CONFIG_SINGLE); + assert.isNotNull(assistant); + }); + + it('should accept config as multiple functions', async () => { + const assistant = new Assistant(MOCK_CONFIG_MULTIPLE); + assert.isNotNull(assistant); + }); + + describe('validate', () => { + it('should throw an error if config is not an object', async () => { + const { validate } = await importAssistant(); + + // intentionally casting to AssistantConfig to trigger failure + const badConfig = '' as unknown as AssistantConfig; + + const validationFn = () => validate(badConfig); + const expectedMsg = 'Assistant expects a configuration object as the argument'; + assert.throws(validationFn, AssistantInitializationError, expectedMsg); + }); + + it('should throw an error if required keys are missing', async () => { + const { validate } = await importAssistant(); + + // intentionally casting to AssistantConfig to trigger failure + const badConfig = { + threadStarted: async () => { }, + } as unknown as AssistantConfig; + + const validationFn = () => validate(badConfig); + const expectedMsg = 'Assistant is missing required keys: userMessage'; + assert.throws(validationFn, AssistantInitializationError, expectedMsg); + }); + + it('should throw an error if props are not a single callback or an array of callbacks', async () => { + const { validate } = await importAssistant(); + + // intentionally casting to AssistantConfig to trigger failure + const badConfig = { + threadStarted: async () => { }, + threadContextChanged: {}, + userMessage: async () => { }, + } as unknown as AssistantConfig; + + const validationFn = () => validate(badConfig); + const expectedMsg = 'Assistant threadContextChanged property must be a function or an array of functions'; + assert.throws(validationFn, AssistantInitializationError, expectedMsg); + }); + }); + }); + + describe('getMiddleware', () => { + it('should call next if not an assistant event', async () => { + const assistant = new Assistant(MOCK_CONFIG_SINGLE); + const middleware = assistant.getMiddleware(); + const fakeMessageArgs = createGenericEvent() as unknown as AnyMiddlewareArgs & AllMiddlewareArgs; + fakeMessageArgs.payload.type = 'app_mention'; + + const fakeNext = sinon.spy(); + fakeMessageArgs.next = fakeNext; + + await middleware(fakeMessageArgs); + + assert(fakeNext.called); + }); + + it('should not call next if a assistant event', async () => { + const assistant = new Assistant(MOCK_CONFIG_SINGLE); + const middleware = assistant.getMiddleware(); + const mockThreadStartedArgs = createMockThreadStartedEvent() as + unknown as AnyMiddlewareArgs & AllMiddlewareArgs; + + const fakeNext = sinon.spy(); + mockThreadStartedArgs.next = fakeNext; + + await middleware(mockThreadStartedArgs); + + assert(fakeNext.notCalled); + }); + + describe('isAssistantEvent', () => { + it('should return true if recognized assistant event', async () => { + const mockThreadStartedArgs = createMockThreadStartedEvent() as + unknown as AnyMiddlewareArgs; + const mockThreadContextChangedArgs = createMockThreadContextChangedEvent() as + unknown as AnyMiddlewareArgs; + const mockUserMessageArgs = createMockUserMessageEvent() as + unknown as AnyMiddlewareArgs; + + const { isAssistantEvent } = await importAssistant(); + + const threadStartedIsAssistantEvent = isAssistantEvent(mockThreadStartedArgs); + const threadContextChangedIsAssistantEvent = isAssistantEvent(mockThreadContextChangedArgs); + const userMessageIsAssistantEvent = isAssistantEvent(mockUserMessageArgs); + + assert.isTrue(threadStartedIsAssistantEvent); + assert.isTrue(threadContextChangedIsAssistantEvent); + assert.isTrue(userMessageIsAssistantEvent); + }); + + it('should return false if not a recognized assistant event', async () => { + const fakeEventArgs = createGenericEvent() as unknown as SlackEventMiddlewareArgs; + fakeEventArgs.payload.type = 'function_executed'; + + const { isAssistantEvent } = await importAssistant(); + const messageIsAssistantEvent = isAssistantEvent(fakeEventArgs as AnyMiddlewareArgs); + + assert.isFalse(messageIsAssistantEvent); + }); + }); + + describe('matchesConstraints', () => { + it('should return true if recognized assistant message', async () => { + const mockUserMessageArgs = createMockUserMessageEvent() as unknown as AssistantMiddlewareArgs; + const { matchesConstraints } = await importAssistant(); + const eventMatchesConstraints = matchesConstraints(mockUserMessageArgs); + + assert.isTrue(eventMatchesConstraints); + }); + + it('should return false if not supported message subtype', async () => { + const mockAppMentionArgs = createGenericEvent() as unknown as any; + mockAppMentionArgs.payload.type = 'message'; + mockAppMentionArgs.payload.subtype = 'bot_message'; + + const { matchesConstraints } = await importAssistant(); + const eventMatchesConstraints = matchesConstraints(mockAppMentionArgs); + + assert.isFalse(eventMatchesConstraints); + }); + + it('should return true if not message event', async () => { + const assistantThreadStartedArgs = createGenericEvent() as unknown as any; + assistantThreadStartedArgs.payload.type = 'assistant_thread_started'; + + const { matchesConstraints } = await importAssistant(); + const eventMatchesConstraints = matchesConstraints(assistantThreadStartedArgs); + + assert.isTrue(eventMatchesConstraints); + }); + + describe('isAssistantMessage', () => { + it('should return true if assistant message event', async () => { + const mockUserMessageArgs = createMockUserMessageEvent() as unknown as any; + const { isAssistantMessage } = await importAssistant(); + const userMessageIsAssistantEvent = isAssistantMessage(mockUserMessageArgs.payload); + + assert.isTrue(userMessageIsAssistantEvent); + }); + + it('should return false if not correct subtype', async () => { + const mockAppMentionArgs = createGenericEvent() as unknown as any; + mockAppMentionArgs.payload.type = 'message'; + mockAppMentionArgs.payload.subtype = 'app_mention'; + + const { isAssistantMessage } = await importAssistant(); + const userMessageIsAssistantEvent = isAssistantMessage(mockAppMentionArgs.payload); + + assert.isFalse(userMessageIsAssistantEvent); + }); + + it('should return false if thread_ts is missing', async () => { + const mockUnsupportedMessageArgs = createMockUserMessageEvent() as unknown as any; + delete mockUnsupportedMessageArgs.payload.thread_ts; + + const { isAssistantMessage } = await importAssistant(); + const userMessageIsAssistantEvent = isAssistantMessage(mockUnsupportedMessageArgs.payload); + + assert.isFalse(userMessageIsAssistantEvent); + }); + + it('should return false if channel_type is incorrect', async () => { + const mockUnsupportedMessageArgs = createMockUserMessageEvent() as unknown as any; + mockUnsupportedMessageArgs.payload.channel_type = 'mpim'; + + const { isAssistantMessage } = await importAssistant(); + const userMessageIsAssistantEvent = isAssistantMessage(mockUnsupportedMessageArgs.payload); + + assert.isFalse(userMessageIsAssistantEvent); + }); + }); + }); + }); + + describe('processEvent', () => { + describe('enrichAssistantArgs', () => { + it('should remove next() from all original event args', async () => { + const mockThreadStartedArgs = createMockThreadStartedEvent() as + unknown as AssistantThreadStartedMiddlewareArgs & AllMiddlewareArgs; + const mockThreadContextChangedArgs = createMockThreadContextChangedEvent() as + unknown as AssistantThreadContextChangedMiddlewareArgs & AllMiddlewareArgs; + const mockUserMessageArgs = createMockUserMessageEvent() as + unknown as AssistantUserMessageMiddlewareArgs & AllMiddlewareArgs; + const mockThreadContextStore = createMockThreadContextStore(); + + const { enrichAssistantArgs } = await importAssistant(); + + const threadStartedArgs = enrichAssistantArgs(mockThreadContextStore, mockThreadStartedArgs); + const threadContextChangedArgs = enrichAssistantArgs(mockThreadContextStore, mockThreadContextChangedArgs); + const userMessageArgs = enrichAssistantArgs(mockThreadContextStore, mockUserMessageArgs); + + assert.notExists(threadStartedArgs.next); + assert.notExists(threadContextChangedArgs.next); + assert.notExists(userMessageArgs.next); + }); + + it('should augment assistant_thread_started args with utilities', async () => { + const mockArgs = createMockThreadStartedEvent(); + const mockThreadContextStore = createMockThreadContextStore(); + const { enrichAssistantArgs } = await importAssistant(); + const assistantArgs = enrichAssistantArgs(mockThreadContextStore, mockArgs as any); + + assert.exists(assistantArgs.say); + assert.exists(assistantArgs.setStatus); + assert.exists(assistantArgs.setSuggestedPrompts); + assert.exists(assistantArgs.setTitle); + }); + + it('should augment assistant_thread_context_changed args with utilities', async () => { + const mockArgs = createMockThreadContextChangedEvent(); + const mockThreadContextStore = createMockThreadContextStore(); + const { enrichAssistantArgs } = await importAssistant(); + const assistantArgs = enrichAssistantArgs(mockThreadContextStore, mockArgs as any); + + assert.exists(assistantArgs.say); + assert.exists(assistantArgs.setStatus); + assert.exists(assistantArgs.setSuggestedPrompts); + assert.exists(assistantArgs.setTitle); + }); + + it('should augment message args with utilities', async () => { + const mockArgs = createMockUserMessageEvent(); + const mockThreadContextStore = createMockThreadContextStore(); + const { enrichAssistantArgs } = await importAssistant(); + const assistantArgs = enrichAssistantArgs(mockThreadContextStore, mockArgs as any); + + assert.exists(assistantArgs.say); + assert.exists(assistantArgs.setStatus); + assert.exists(assistantArgs.setSuggestedPrompts); + assert.exists(assistantArgs.setTitle); + }); + + describe('extractThreadInfo', () => { + it('should return expected channelId, threadTs, and context for `assistant_thread_started` event', async () => { + const mockThreadStartedEvent = createMockThreadStartedEvent() as unknown as AssistantThreadStartedMiddlewareArgs; // eslint-disable-line max-len + const { payload } = mockThreadStartedEvent; + const { extractThreadInfo } = await importAssistant(); + const { channelId, threadTs, context } = extractThreadInfo(payload); + + assert.equal(payload.assistant_thread.channel_id, channelId); + assert.equal(payload.assistant_thread.thread_ts, threadTs); + assert.deepEqual(payload.assistant_thread.context, context); + }); + + it('should return expected channelId, threadTs, and context for `assistant_thread_context_changed` event', async () => { + const mockThreadChangedEvent = createMockThreadContextChangedEvent() as unknown as AssistantThreadContextChangedMiddlewareArgs; // eslint-disable-line max-len + const { payload } = mockThreadChangedEvent; + const { extractThreadInfo } = await importAssistant(); + const { channelId, threadTs, context } = extractThreadInfo(payload); + + assert.equal(payload.assistant_thread.channel_id, channelId); + assert.equal(payload.assistant_thread.thread_ts, threadTs); + assert.deepEqual(payload.assistant_thread.context, context); + }); + + it('should return expected channelId and threadTs for `message` event', async () => { + const mockUserMessageEvent = createMockUserMessageEvent(); + const { payload } = mockUserMessageEvent as any; + const { extractThreadInfo } = await importAssistant(); + const { channelId, threadTs, context } = extractThreadInfo(payload); + + assert.equal(payload.channel, channelId); + assert.equal(payload.thread_ts, threadTs); + assert.isEmpty(context); + }); + + it('should throw error if `channel_id` or `thread_ts` are missing', async () => { + const { payload } = createMockThreadStartedEvent() as unknown as AssistantThreadStartedMiddlewareArgs; // eslint-disable-line max-len + payload.assistant_thread.channel_id = ''; + const { extractThreadInfo } = await importAssistant(); + + const extractThreadInfoFn = () => extractThreadInfo(payload); + const expectedMsg = 'Assistant message event is missing required properties: channel_id'; + assert.throws(extractThreadInfoFn, AssistantMissingPropertyError, expectedMsg); + }); + }); + + describe('assistant args/utilities', () => { + it('say should call chat.postMessage', async () => { + const mockThreadStartedArgs = createMockThreadStartedEvent() as unknown as AssistantMiddlewareArgs & AllMiddlewareArgs; // eslint-disable-line max-len + + const fakeClient = { chat: { postMessage: sinon.spy() } }; + mockThreadStartedArgs.client = fakeClient as unknown as WebClient; + const mockThreadContextStore = createMockThreadContextStore(); + + const { enrichAssistantArgs } = await importAssistant(); + const threadStartedArgs = enrichAssistantArgs(mockThreadContextStore, mockThreadStartedArgs); + + await threadStartedArgs.say('Say called!'); + + assert(fakeClient.chat.postMessage.called); + }); + + it('setStatus should call assistant.threads.setStatus', async () => { + const mockThreadStartedArgs = createMockThreadStartedEvent() as unknown as AssistantMiddlewareArgs & AllMiddlewareArgs; // eslint-disable-line max-len + + const fakeClient = { assistant: { threads: { setStatus: sinon.spy() } } }; + mockThreadStartedArgs.client = fakeClient as unknown as WebClient; + const mockThreadContextStore = createMockThreadContextStore(); + + const { enrichAssistantArgs } = await importAssistant(); + const threadStartedArgs = enrichAssistantArgs(mockThreadContextStore, mockThreadStartedArgs); + + await threadStartedArgs.setStatus('Status set!'); + + assert(fakeClient.assistant.threads.setStatus.called); + }); + + it('setSuggestedPrompts should call assistant.threads.setSuggestedPrompts', async () => { + const mockThreadStartedArgs = createMockThreadStartedEvent() as unknown as AssistantMiddlewareArgs & AllMiddlewareArgs; // eslint-disable-line max-len + + const fakeClient = { assistant: { threads: { setSuggestedPrompts: sinon.spy() } } }; + mockThreadStartedArgs.client = fakeClient as unknown as WebClient; + const mockThreadContextStore = createMockThreadContextStore(); + + const { enrichAssistantArgs } = await importAssistant(); + const threadStartedArgs = enrichAssistantArgs(mockThreadContextStore, mockThreadStartedArgs); + + await threadStartedArgs.setSuggestedPrompts({ prompts: [{ title: '', message: '' }] }); + + assert(fakeClient.assistant.threads.setSuggestedPrompts.called); + }); + + it('setTitle should call assistant.threads.setTitle', async () => { + const mockThreadStartedArgs = createMockThreadStartedEvent() as unknown as AssistantMiddlewareArgs & AllMiddlewareArgs; // eslint-disable-line max-len + + const fakeClient = { assistant: { threads: { setTitle: sinon.spy() } } }; + mockThreadStartedArgs.client = fakeClient as unknown as WebClient; + const mockThreadContextStore = createMockThreadContextStore(); + + const { enrichAssistantArgs } = await importAssistant(); + const threadStartedArgs = enrichAssistantArgs(mockThreadContextStore, mockThreadStartedArgs); + + await threadStartedArgs.setTitle('Title set!'); + + assert(fakeClient.assistant.threads.setTitle.called); + }); + }); + }); + + describe('processAssistantMiddleware', () => { + it('should call each callback in user-provided middleware', async () => { + const { ...mockArgs } = createMockThreadContextChangedEvent() as unknown as AllAssistantMiddlewareArgs; + const { processAssistantMiddleware } = await importAssistant(); + + const fn1 = sinon.spy((async ({ next: continuation }) => { + await continuation(); + }) as Middleware); + const fn2 = sinon.spy(async () => { }); + const fakeMiddleware = [fn1, fn2] as AssistantMiddleware; + + await processAssistantMiddleware(mockArgs, fakeMiddleware); + + assert(fn1.called); + assert(fn2.called); + }); + }); + }); +}); + +function createMockThreadStartedEvent() { + return { + payload: { + type: 'assistant_thread_started', + assistant_thread: { + user_id: '', + context: { + channel_id: '', + team_id: '', + enterprise_id: '', + }, + channel_id: 'D01234567AR', + thread_ts: '1234567890.123456', + }, + event_ts: '', + }, + context: {}, + }; +} + +function createMockThreadContextChangedEvent() { + return { + payload: { + type: 'assistant_thread_context_changed', + assistant_thread: { + user_id: '', + context: { + channel_id: '', + team_id: '', + enterprise_id: '', + }, + channel_id: 'D01234567AR', + thread_ts: '1234567890.123456', + }, + event_ts: '', + }, + context: {}, + }; +} + +function createMockUserMessageEvent() { + return { + payload: { + user: '', + type: 'message', + ts: '', + text: 'test', + team: '', + user_team: '', + source_team: '', + user_profile: {}, + thread_ts: '1234567890.123456', + parent_user_id: '', + blocks: [], + channel: 'D01234567AR', + event_ts: '', + channel_type: 'im', + }, + context: {}, + }; +} + +function createGenericEvent() { + return { + payload: { + type: '', + }, + context: {}, + }; +} + +function createMockThreadContextStore(): AssistantThreadContextStore { + return { + async get(_: AllAssistantMiddlewareArgs): Promise { + return {}; + }, + // eslint-disable-next-line @typescript-eslint/no-empty-function + async save(_: AllAssistantMiddlewareArgs): Promise { + }, + }; +} diff --git a/src/Assistant.ts b/src/Assistant.ts new file mode 100644 index 000000000..3ae6d4e75 --- /dev/null +++ b/src/Assistant.ts @@ -0,0 +1,412 @@ +import { + AssistantThreadsSetStatusResponse, + AssistantThreadsSetSuggestedPromptsResponse, + AssistantThreadsSetTitleResponse, + ChatPostMessageArguments, +} from '@slack/web-api'; +import processMiddleware from './middleware/process'; +import { + AllMiddlewareArgs, + AnyMiddlewareArgs, + Middleware, + SayFn, + SlackEventMiddlewareArgs, +} from './types'; +import { AssistantInitializationError, AssistantMissingPropertyError } from './errors'; +import { + AssistantThreadContext, + AssistantThreadContextStore, + DefaultThreadContextStore, + GetThreadContextFn, + SaveThreadContextFn, +} from './AssistantThreadContextStore'; + +/** + * Configuration object used to instantiate the Assistant + */ +export interface AssistantConfig { + threadContextStore?: AssistantThreadContextStore; + threadStarted: AssistantThreadStartedMiddleware | AssistantThreadStartedMiddleware[]; + threadContextChanged?: AssistantThreadContextChangedMiddleware | AssistantThreadContextChangedMiddleware[]; + userMessage: AssistantUserMessageMiddleware | AssistantUserMessageMiddleware[]; +} + +/** + * Callback utilities + */ +interface AssistantUtilityArgs { + getThreadContext: GetThreadContextFn; + saveThreadContext: SaveThreadContextFn; + say: SayFn; + setStatus: SetStatusFn; + setSuggestedPrompts: SetSuggestedPromptsFn; + setTitle: SetTitleFn; +} + +interface SetStatusFn { + (status: string): Promise; +} + +interface SetSuggestedPromptsFn { + (params: SetSuggestedPromptsArguments): Promise; +} + +interface SetSuggestedPromptsArguments { + prompts: [AssistantPrompt, ...AssistantPrompt[]]; +} + +interface AssistantPrompt { + title: string; + message: string; +} + +interface SetTitleFn { + (title: string): Promise; +} + +/** + * Middleware + */ +export type AssistantThreadStartedMiddleware = Middleware; +export type AssistantThreadContextChangedMiddleware = Middleware; +export type AssistantUserMessageMiddleware = Middleware; + +export type AssistantMiddleware = + | AssistantThreadStartedMiddleware[] + | AssistantThreadContextChangedMiddleware[] + | AssistantUserMessageMiddleware[]; + +export type AssistantMiddlewareArgs = + | AssistantThreadStartedMiddlewareArgs + | AssistantThreadContextChangedMiddlewareArgs + | AssistantUserMessageMiddlewareArgs; + +// TODO : revisit Omit of `say`, as it's added on as part of the enrichment step +export interface AssistantThreadStartedMiddlewareArgs extends + Omit, 'say'>, AssistantUtilityArgs {} +export interface AssistantThreadContextChangedMiddlewareArgs extends + Omit, 'say'>, AssistantUtilityArgs {} +export interface AssistantUserMessageMiddlewareArgs extends + Omit, AssistantUtilityArgs {} + +export type AllAssistantMiddlewareArgs = +T & AllMiddlewareArgs; + +/** Constants */ +const ASSISTANT_PAYLOAD_TYPES = new Set(['assistant_thread_started', 'assistant_thread_context_changed', 'message']); + +export class Assistant { + private threadContextStore: AssistantThreadContextStore; + + /** 'assistant_thread_started' */ + private threadStarted: AssistantThreadStartedMiddleware[]; + + /** 'assistant_thread_context_changed' */ + private threadContextChanged: AssistantThreadContextChangedMiddleware[]; + + /** 'message' */ + private userMessage: AssistantUserMessageMiddleware[]; + + public constructor(config: AssistantConfig) { + validate(config); + + const { + threadContextStore = new DefaultThreadContextStore(), + threadStarted, + // When `threadContextChanged` method is not provided, fallback to + // AssistantContextStore's save method. If a custom store has also not + // been provided, the default save context-via-metadata approach is used. + // See DefaultThreadContextStore for details of this implementation. + threadContextChanged = (args) => threadContextStore.save(args), + userMessage, + } = config; + + this.threadContextStore = threadContextStore; + this.threadStarted = Array.isArray(threadStarted) ? threadStarted : [threadStarted]; + this.threadContextChanged = Array.isArray(threadContextChanged) ? threadContextChanged : [threadContextChanged]; + this.userMessage = Array.isArray(userMessage) ? userMessage : [userMessage]; + } + + public getMiddleware(): Middleware { + return async (args): Promise => { + if (isAssistantEvent(args) && matchesConstraints(args)) { + return this.processEvent(args); + } + return args.next(); + }; + } + + private async processEvent(args: AllAssistantMiddlewareArgs): Promise { + const { payload } = args; + const assistantArgs = enrichAssistantArgs(this.threadContextStore, args); + const assistantMiddleware = this.getAssistantMiddleware(payload); + return processAssistantMiddleware(assistantArgs, assistantMiddleware); + } + + /** + * `getAssistantMiddleware()` returns the Assistant instance's middleware + */ + private getAssistantMiddleware(payload: AllAssistantMiddlewareArgs['payload']): AssistantMiddleware { + switch (payload.type) { + case 'assistant_thread_started': + return this.threadStarted; + case 'assistant_thread_context_changed': + return this.threadContextChanged; + case 'message': + return this.userMessage; + default: + return []; + } + } +} + +/** + * `enrichAssistantArgs()` takes the event arguments and: + * 1. Removes the next() passed in from App-level middleware processing, thus preventing + * events from continuing down the global middleware chain to subsequent listeners + * 2. Adds assistant-specific utilities (i.e., helper methods) + * */ +export function enrichAssistantArgs( + threadContextStore: AssistantThreadContextStore, + args: AllAssistantMiddlewareArgs, +): AllAssistantMiddlewareArgs { + const { next: _next, ...assistantArgs } = args; + const preparedArgs = { ...assistantArgs as Exclude, 'next'> }; + + // Do not pass preparedArgs (ie, do not add utilities to get/save) + preparedArgs.getThreadContext = () => threadContextStore.get(args); + preparedArgs.saveThreadContext = () => threadContextStore.save(args); + + preparedArgs.say = createSay(preparedArgs); + preparedArgs.setStatus = createSetStatus(preparedArgs); + preparedArgs.setSuggestedPrompts = createSetSuggestedPrompts(preparedArgs); + preparedArgs.setTitle = createSetTitle(preparedArgs); + return preparedArgs; +} + +/** + * `isAssistantEvent()` determines if incoming event is a supported + * Assistant event type. + */ +export function isAssistantEvent(args: AnyMiddlewareArgs): args is AllAssistantMiddlewareArgs { + return ASSISTANT_PAYLOAD_TYPES.has(args.payload.type); +} + +/** + * `matchesConstraints()` determines if the incoming event payload + * is related to the Assistant. + */ +export function matchesConstraints(args: AssistantMiddlewareArgs): args is AssistantMiddlewareArgs { + return args.payload.type === 'message' ? isAssistantMessage(args.payload) : true; +} + +/** + * `isAssistantMessage()` evaluates if the message payload is associated + * with the Assistant container. + */ +export function isAssistantMessage(payload: AnyMiddlewareArgs['payload']): boolean { + const isThreadMessage = 'channel' in payload && 'thread_ts' in payload; + const inAssistantContainer = ('channel_type' in payload && payload.channel_type === 'im') && + (!('subtype' in payload) || payload.subtype === 'file_share'); + return isThreadMessage && inAssistantContainer; +} + +/** + * `validate()` determines if the provided AssistantConfig is a valid configuration. + */ +export function validate(config: AssistantConfig): void { + // Ensure assistant config object is passed in + if (typeof config !== 'object') { + const errorMsg = 'Assistant expects a configuration object as the argument'; + throw new AssistantInitializationError(errorMsg); + } + + // Check for missing required keys + const requiredKeys: (keyof AssistantConfig)[] = ['threadStarted', 'userMessage']; + const missingKeys: (keyof AssistantConfig)[] = []; + requiredKeys.forEach((key) => { if (config[key] === undefined) missingKeys.push(key); }); + + if (missingKeys.length > 0) { + const errorMsg = `Assistant is missing required keys: ${missingKeys.join(', ')}`; + throw new AssistantInitializationError(errorMsg); + } + + // Ensure a callback or an array of callbacks is present + const requiredFns: (keyof AssistantConfig)[] = ['threadStarted', 'userMessage']; + if ('threadContextChanged' in config) requiredFns.push('threadContextChanged'); + requiredFns.forEach((fn) => { + if (typeof config[fn] !== 'function' && !Array.isArray(config[fn])) { + const errorMsg = `Assistant ${fn} property must be a function or an array of functions`; + throw new AssistantInitializationError(errorMsg); + } + }); + + // Validate threadContextStore + if (config.threadContextStore) { + // Ensure assistant config object is passed in + if (typeof config.threadContextStore !== 'object') { + const errorMsg = 'Assistant expects threadContextStore to be a configuration object'; + throw new AssistantInitializationError(errorMsg); + } + + // Check for missing required keys + const requiredContextKeys: (keyof AssistantThreadContextStore)[] = ['get', 'save']; + const missingContextKeys: (keyof AssistantThreadContextStore)[] = []; + requiredContextKeys.forEach((k) => { + if (config.threadContextStore && config.threadContextStore[k] === undefined) { + missingContextKeys.push(k); + } + }); + + if (missingContextKeys.length > 0) { + const errorMsg = `threadContextStore is missing required keys: ${missingContextKeys.join(', ')}`; + throw new AssistantInitializationError(errorMsg); + } + + // Ensure properties of context store are functions + const requiredStoreFns: (keyof AssistantThreadContextStore)[] = ['get', 'save']; + requiredStoreFns.forEach((fn) => { + if (config.threadContextStore && typeof config.threadContextStore[fn] !== 'function') { + const errorMsg = `threadContextStore ${fn} property must be a function`; + throw new AssistantInitializationError(errorMsg); + } + }); + } +} + +/** + * `processAssistantMiddleware()` invokes each callback for the given event + */ +export async function processAssistantMiddleware( + args: AllAssistantMiddlewareArgs, + middleware: AssistantMiddleware, +): Promise { + const { context, client, logger } = args; + const callbacks = [...middleware] as Middleware[]; + const lastCallback = callbacks.pop(); + + if (lastCallback !== undefined) { + await processMiddleware( + callbacks, args, context, client, logger, + async () => lastCallback({ ...args, context, client, logger }), + ); + } +} + +/** + * Utility functions + */ + +/** + * Creates utility `say()` to easily respond to wherever the message + * was received. Alias for `postMessage()`. + * https://api.slack.com/methods/chat.postMessage + */ +function createSay(args: AllAssistantMiddlewareArgs): SayFn { + const { + client, + payload, + } = args; + const { channelId: channel, threadTs: thread_ts } = extractThreadInfo(payload); + + return (message: Parameters[0]) => { + const postMessageArgument: ChatPostMessageArguments = typeof message === 'string' ? + { text: message, channel, thread_ts } : + { ...message, channel, thread_ts }; + + return client.chat.postMessage(postMessageArgument); + }; +} + +/** + * Creates utility `setStatus()` to set the status and indicate active processing. + * https://api.slack.com/methods/assistant.threads.setStatus + */ +function createSetStatus(args: AllAssistantMiddlewareArgs): SetStatusFn { + const { + client, + payload, + } = args; + const { channelId: channel_id, threadTs: thread_ts } = extractThreadInfo(payload); + + return (status: Parameters[0]) => client.assistant.threads.setStatus({ + channel_id, + thread_ts, + status, + }); +} + +/** + * Creates utility `setSuggestedPrompts()` to provides prompts for the user to select from. + * https://api.slack.com/methods/assistant.threads.setSuggestedPrompts + */ +function createSetSuggestedPrompts(args: AllAssistantMiddlewareArgs): SetSuggestedPromptsFn { + const { + client, + payload, + } = args; + const { channelId: channel_id, threadTs: thread_ts } = extractThreadInfo(payload); + + return (params: Parameters[0]) => { + const { prompts } = params; + return client.assistant.threads.setSuggestedPrompts({ + channel_id, + thread_ts, + prompts, + }); + }; +} + +/** + * Creates utility `setTitle()` to set the title of the Assistant thread + * https://api.slack.com/methods/assistant.threads.setTitle + */ +function createSetTitle(args: AllAssistantMiddlewareArgs): SetTitleFn { + const { + client, + payload, + } = args; + const { channelId: channel_id, threadTs: thread_ts } = extractThreadInfo(payload); + + return (title: Parameters[0]) => client.assistant.threads.setTitle({ + channel_id, + thread_ts, + title, + }); +} + +/** + * `extractThreadInfo()` parses an incoming payload and returns relevant + * details about the thread +*/ +export function extractThreadInfo(payload: AllAssistantMiddlewareArgs['payload']): { channelId: string, threadTs: string, context: AssistantThreadContext } { + let channelId: string = ''; + let threadTs: string = ''; + let context: AssistantThreadContext = {}; + + // assistant_thread_started, asssistant_thread_context_changed + if ('assistant_thread' in payload) { + channelId = payload.assistant_thread.channel_id; + threadTs = payload.assistant_thread.thread_ts; + context = payload.assistant_thread.context; + } + + // user message in thread + if ('channel' in payload && 'thread_ts' in payload && payload.thread_ts !== undefined) { + channelId = payload.channel; + threadTs = payload.thread_ts; + } + + // throw error if `channel` or `thread_ts` are missing + if (!channelId || !threadTs) { + const missingProps: string[] = []; + if (!channelId) missingProps.push('channel_id'); + if (!threadTs) missingProps.push('thread_ts'); + if (missingProps.length > 0) { + const errorMsg = `Assistant message event is missing required properties: ${missingProps.join(', ')}`; + throw new AssistantMissingPropertyError(errorMsg); + } + } + + return { channelId, threadTs, context }; +} diff --git a/src/AssistantThreadContextStore.spec.ts b/src/AssistantThreadContextStore.spec.ts new file mode 100644 index 000000000..158154ac2 --- /dev/null +++ b/src/AssistantThreadContextStore.spec.ts @@ -0,0 +1,171 @@ +import 'mocha'; +import { assert } from 'chai'; +import sinon from 'sinon'; +import { WebClient } from '@slack/web-api'; +import { DefaultThreadContextStore } from './AssistantThreadContextStore'; +import { AllAssistantMiddlewareArgs, extractThreadInfo } from './Assistant'; + +describe('DefaultThreadContextStore class', () => { + describe('get', () => { + it('should retrieve message metadata if context not already saved to instance', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as unknown as AllAssistantMiddlewareArgs; + const mockThreadContext = { channel_id: '123', thread_ts: '123', enterprise_id: null }; + const fakeClient = { + conversations: { + replies: sinon.fake.returns({ + messages: [{ + user: 'U12345', + ts: '12345', + metadata: { event_payload: mockThreadContext }, + }], + }), + }, + }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + const threadContext = await mockContextStore.get(mockAssistantMiddlewareArgs); + + assert(fakeClient.conversations.replies.called); + assert.equal(threadContext, mockThreadContext); + }); + + it('should return an empty object if no message history exists', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as unknown as AllAssistantMiddlewareArgs; + const fakeClient = { conversations: { replies: sinon.fake.returns([]) } }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + const threadContext = await mockContextStore.get(mockAssistantMiddlewareArgs); + + assert.isEmpty(threadContext); + }); + + it('should return an empty object if no message metadata exists', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as unknown as AllAssistantMiddlewareArgs; + const fakeClient = { + conversations: { + replies: sinon.fake.returns({ + messages: [{ + user: 'U12345', + ts: '12345', + metadata: {}, + }], + }), + }, + }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + const threadContext = await mockContextStore.get(mockAssistantMiddlewareArgs); + + assert.isEmpty(threadContext); + }); + + it('should retrieve instance context if it has been saved previously', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as any; + const fakeClient = { + conversations: { replies: sinon.fake.returns({ messages: [{ user: 'U12345', ts: '12345' }] }) }, + chat: { update: sinon.fake() }, + }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + + await mockContextStore.save(mockAssistantMiddlewareArgs); + const threadContext = await mockContextStore.get(mockAssistantMiddlewareArgs); + + assert(fakeClient.conversations.replies.calledOnce); + assert.equal(threadContext, mockAssistantMiddlewareArgs.payload.assistant_thread.context); + }); + }); + + describe('save', () => { + it('should update instance context with threadContext', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as any; + const fakeClient = { + conversations: { replies: sinon.fake.returns({ messages: [] }) }, + chat: { update: sinon.fake() }, + }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + + await mockContextStore.save(mockAssistantMiddlewareArgs); + const instanceContext = await mockContextStore.get(mockAssistantMiddlewareArgs); + + assert(fakeClient.conversations.replies.calledOnce); + assert.deepEqual(instanceContext, mockAssistantMiddlewareArgs.payload.assistant_thread.context); + }); + + it('should retrieve message history', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as any; + const fakeClient = { + conversations: { replies: sinon.fake.returns({}) }, + chat: { update: sinon.fake() }, + }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + + await mockContextStore.save(mockAssistantMiddlewareArgs); + assert(fakeClient.conversations.replies.calledOnce); + }); + + it('should return early if no message history exists', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as any; + const fakeClient = { + conversations: { replies: sinon.fake.returns({}) }, + chat: { update: sinon.fake() }, + }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + + await mockContextStore.save(mockAssistantMiddlewareArgs); + assert(fakeClient.chat.update.notCalled); + }); + + it('should update first bot message metadata with threadContext', async () => { + const mockContextStore = new DefaultThreadContextStore(); + const mockAssistantMiddlewareArgs = createMockAssistantMiddlewareArgs() as any; + const fakeClient = { + conversations: { replies: sinon.fake.returns({ messages: [{ user: 'U12345', ts: '12345', text: 'foo' }] }) }, + chat: { update: sinon.fake() }, + }; + mockAssistantMiddlewareArgs.client = fakeClient as unknown as WebClient; + const { channelId, context } = extractThreadInfo(mockAssistantMiddlewareArgs.payload); + const mockParams = { + channel: channelId, + ts: '12345', + text: 'foo', + metadata: { + event_type: 'assistant_thread_context', + event_payload: context, + }, + }; + + await mockContextStore.save(mockAssistantMiddlewareArgs); + assert(fakeClient.chat.update.calledWith(mockParams)); + }); + }); +}); + +function createMockAssistantMiddlewareArgs() { + return { + client: {}, + logger: { + debug: sinon.fake(), + }, + payload: { + type: 'assistant_thread_started', + assistant_thread: { + user_id: '', + context: { + channel_id: 'D01234567AR', + team_id: 'T123', + enterprise_id: 'E12345678', + }, + channel_id: 'D01234567AR', + thread_ts: '1234567890.123456', + }, + event_ts: '', + }, + context: { + botUserId: 'U12345', + }, + }; +} diff --git a/src/AssistantThreadContextStore.ts b/src/AssistantThreadContextStore.ts new file mode 100644 index 000000000..521408d11 --- /dev/null +++ b/src/AssistantThreadContextStore.ts @@ -0,0 +1,97 @@ +import { ChatUpdateArguments } from '@slack/web-api'; +import { Block, KnownBlock, MessageMetadataEventPayloadObject } from '@slack/types'; +import { AllAssistantMiddlewareArgs, extractThreadInfo } from './Assistant'; + +export interface AssistantThreadContextStore { + get: GetThreadContextFn; + save: SaveThreadContextFn; +} + +export interface GetThreadContextFn { + (args: AllAssistantMiddlewareArgs): Promise; +} + +export interface SaveThreadContextFn { + (args: AllAssistantMiddlewareArgs): Promise; +} + +export interface AssistantThreadContext { + channel_id?: string; + team_id?: string; + enterprise_id?: string | null; +} + +export class DefaultThreadContextStore implements AssistantThreadContextStore { + private context: AssistantThreadContext = {}; + + public async get(args: AllAssistantMiddlewareArgs): Promise { + const { context, client, payload, logger } = args; + + logger.debug('DefaultAssistantThreadStore: get method called'); + + if (this.context.channel_id) { + return this.context; + } + + const { channelId: channel, threadTs: thread_ts } = extractThreadInfo(payload); + + // Retrieve the current thread history + const thread = await client.conversations.replies({ + channel, + ts: thread_ts, + oldest: thread_ts, + include_all_metadata: true, + limit: 4, + }); + + if (!thread.messages) return {}; + + // Find the first message in the thread that holds the current context using metadata. + // See createSaveThreadContext below for a description and explanation for this approach. + const initialMsg = thread.messages.find((m) => !('subtype' in m) && m.user === context.botUserId); + const threadContext = initialMsg && initialMsg.metadata ? initialMsg.metadata.event_payload : null; + + return threadContext || {}; + } + + public async save(args: AllAssistantMiddlewareArgs): Promise { + const { context, client, payload, logger } = args; + const { channelId: channel, threadTs: thread_ts, context: threadContext } = extractThreadInfo(payload); + + logger.debug('DefaultAssistantThreadStore: save method called'); + + // Retrieve first several messages from the current Assistant thread + const thread = await client.conversations.replies({ + channel, + ts: thread_ts, + oldest: thread_ts, + include_all_metadata: true, + limit: 4, + }); + + if (!thread.messages) return; + + // Find and update the initial Assistant message with the new context to ensure the + // thread always contains the most recent context that user is sending messages from. + const initialMsg = thread.messages.find((m) => !('subtype' in m) && m.user === context.botUserId); + if (initialMsg && initialMsg.ts) { + const params: ChatUpdateArguments = { + channel, + ts: initialMsg.ts, + text: initialMsg.text, + metadata: { + event_type: 'assistant_thread_context', + event_payload: threadContext as MessageMetadataEventPayloadObject, + }, + }; + + if (initialMsg.blocks) { + params.blocks = initialMsg.blocks as (KnownBlock | Block)[]; + } + + await client.chat.update(params); + } + + this.context = threadContext; + } +} diff --git a/src/errors.ts b/src/errors.ts index 09a5a8b7f..a006b266e 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -17,6 +17,10 @@ export function isCodedError(err: any): err is CodedError { export enum ErrorCode { AppInitializationError = 'slack_bolt_app_initialization_error', + + AssistantInitializationError = 'slack_bolt_assistant_initialization_error', + AssistantMissingPropertyError = 'slack_bolt_assistant_missing_property_error', + AuthorizationError = 'slack_bolt_authorization_error', ContextMissingPropertyError = 'slack_bolt_context_missing_property_error', @@ -69,6 +73,14 @@ export class AppInitializationError extends Error implements CodedError { public code = ErrorCode.AppInitializationError; } +export class AssistantInitializationError extends Error implements CodedError { + public code = ErrorCode.AssistantInitializationError; +} + +export class AssistantMissingPropertyError extends Error implements CodedError { + public code = ErrorCode.AssistantMissingPropertyError; +} + export class AuthorizationError extends Error implements CodedError { public code = ErrorCode.AuthorizationError; diff --git a/src/index.ts b/src/index.ts index 121fe98ce..b03d566fa 100644 --- a/src/index.ts +++ b/src/index.ts @@ -46,6 +46,14 @@ export { buildReceiverRoutes, } from './receivers/custom-routes'; +export { + Assistant, + AssistantConfig, + AssistantThreadContextChangedMiddleware, + AssistantThreadStartedMiddleware, + AssistantUserMessageMiddleware, +} from './Assistant'; + export { WorkflowStep, WorkflowStepConfig,