Skip to content

Commit

Permalink
[Search] [Playground] [Bug] Remove token clipping (#199055)
Browse files Browse the repository at this point in the history
- Remove token pruning functionality as this has a large cost, causing
OOMs on serverless.
- make the default model for openai gpt-4o
- when the context is over the model limit, show a better error to the
user for this
- Update tests

---------

Co-authored-by: Joseph McElroy <[email protected]>
  • Loading branch information
yansavitski and joemcelroy authored Nov 6, 2024
1 parent b585ca6 commit c35934e
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 146 deletions.
16 changes: 8 additions & 8 deletions x-pack/plugins/search_playground/common/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@
import { ModelProvider, LLMs } from './types';

export const MODELS: ModelProvider[] = [
{
name: 'OpenAI GPT-3.5 Turbo',
model: 'gpt-3.5-turbo',
promptTokenLimit: 16385,
provider: LLMs.openai,
},
{
name: 'OpenAI GPT-4o',
model: 'gpt-4o',
Expand All @@ -26,6 +20,12 @@ export const MODELS: ModelProvider[] = [
promptTokenLimit: 128000,
provider: LLMs.openai,
},
{
name: 'OpenAI GPT-3.5 Turbo',
model: 'gpt-3.5-turbo',
promptTokenLimit: 16385,
provider: LLMs.openai,
},
{
name: 'Anthropic Claude 3 Haiku',
model: 'anthropic.claude-3-haiku-20240307-v1:0',
Expand All @@ -40,13 +40,13 @@ export const MODELS: ModelProvider[] = [
},
{
name: 'Google Gemini 1.5 Pro',
model: 'gemini-1.5-pro-001',
model: 'gemini-1.5-pro-002',
promptTokenLimit: 2097152,
provider: LLMs.gemini,
},
{
name: 'Google Gemini 1.5 Flash',
model: 'gemini-1.5-flash-001',
model: 'gemini-1.5-flash-002',
promptTokenLimit: 2097152,
provider: LLMs.gemini,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ describe('useLLMsModels Hook', () => {
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
id: 'connectorId1OpenAI GPT-3.5 Turbo ',
name: 'OpenAI GPT-3.5 Turbo ',
id: 'connectorId1OpenAI GPT-4o ',
name: 'OpenAI GPT-4o ',
showConnectorName: false,
value: 'gpt-3.5-turbo',
promptTokenLimit: 16385,
value: 'gpt-4o',
promptTokenLimit: 128000,
},
{
connectorId: 'connectorId1',
connectorName: 'OpenAI Connector',
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
id: 'connectorId1OpenAI GPT-4o ',
name: 'OpenAI GPT-4o ',
id: 'connectorId1OpenAI GPT-4 Turbo ',
name: 'OpenAI GPT-4 Turbo ',
showConnectorName: false,
value: 'gpt-4o',
value: 'gpt-4-turbo',
promptTokenLimit: 128000,
},
{
Expand All @@ -65,11 +65,11 @@ describe('useLLMsModels Hook', () => {
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
id: 'connectorId1OpenAI GPT-4 Turbo ',
name: 'OpenAI GPT-4 Turbo ',
id: 'connectorId1OpenAI GPT-3.5 Turbo ',
name: 'OpenAI GPT-3.5 Turbo ',
showConnectorName: false,
value: 'gpt-4-turbo',
promptTokenLimit: 128000,
value: 'gpt-3.5-turbo',
promptTokenLimit: 16385,
},
{
connectorId: 'connectorId2',
Expand Down
169 changes: 67 additions & 102 deletions x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import type { Client } from '@elastic/elasticsearch';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { FakeListChatModel, FakeStreamingLLM } from '@langchain/core/utils/testing';
import { experimental_StreamData } from 'ai';
import { createAssist as Assist } from '../utils/assist';
import { ConversationalChain, clipContext } from './conversational_chain';
import { ConversationalChain, contextLimitCheck } from './conversational_chain';
import { ChatMessage, MessageRole } from '../types';

describe('conversational chain', () => {
Expand All @@ -30,16 +29,20 @@ describe('conversational chain', () => {
}: {
responses: string[];
chat: ChatMessage[];
expectedFinalAnswer: string;
expectedDocs: any;
expectedTokens: any;
expectedSearchRequest: any;
expectedFinalAnswer?: string;
expectedDocs?: any;
expectedTokens?: any;
expectedSearchRequest?: any;
contentField?: Record<string, string>;
isChatModel?: boolean;
docs?: any;
expectedHasClipped?: boolean;
modelLimit?: number;
}) => {
if (expectedHasClipped) {
expect.assertions(1);
}

const searchMock = jest.fn().mockImplementation(() => {
return {
hits: {
Expand Down Expand Up @@ -101,44 +104,52 @@ describe('conversational chain', () => {
questionRewritePrompt: 'rewrite question {question} using {context}"',
});

const stream = await conversationalChain.stream(aiClient, chat);
try {
const stream = await conversationalChain.stream(aiClient, chat);

const streamToValue: string[] = await new Promise((resolve, reject) => {
const reader = stream.getReader();
const textDecoder = new TextDecoder();
const chunks: string[] = [];
const streamToValue: string[] = await new Promise((resolve, reject) => {
const reader = stream.getReader();
const textDecoder = new TextDecoder();
const chunks: string[] = [];

const read = () => {
reader.read().then(({ done, value }) => {
if (done) {
resolve(chunks);
} else {
chunks.push(textDecoder.decode(value));
read();
}
}, reject);
};
read();
});
const read = () => {
reader.read().then(({ done, value }) => {
if (done) {
resolve(chunks);
} else {
chunks.push(textDecoder.decode(value));
read();
}
}, reject);
};
read();
});

const textValue = streamToValue
.filter((v) => v[0] === '0')
.reduce((acc, v) => acc + v.replace(/0:"(.*)"\n/, '$1'), '');
expect(textValue).toEqual(expectedFinalAnswer);
const textValue = streamToValue
.filter((v) => v[0] === '0')
.reduce((acc, v) => acc + v.replace(/0:"(.*)"\n/, '$1'), '');
expect(textValue).toEqual(expectedFinalAnswer);

const annotations = streamToValue
.filter((v) => v[0] === '8')
.map((entry) => entry.replace(/8:(.*)\n/, '$1'), '')
.map((entry) => JSON.parse(entry))
.reduce((acc, v) => acc.concat(v), []);
const annotations = streamToValue
.filter((v) => v[0] === '8')
.map((entry) => entry.replace(/8:(.*)\n/, '$1'), '')
.map((entry) => JSON.parse(entry))
.reduce((acc, v) => acc.concat(v), []);

const docValues = annotations.filter((v: { type: string }) => v.type === 'retrieved_docs');
const tokens = annotations.filter((v: { type: string }) => v.type.endsWith('_token_count'));
const hasClipped = !!annotations.some((v: { type: string }) => v.type === 'context_clipped');
expect(docValues).toEqual(expectedDocs);
expect(tokens).toEqual(expectedTokens);
expect(hasClipped).toEqual(expectedHasClipped);
expect(searchMock.mock.calls[0]).toEqual(expectedSearchRequest);
const docValues = annotations.filter((v: { type: string }) => v.type === 'retrieved_docs');
const tokens = annotations.filter((v: { type: string }) => v.type.endsWith('_token_count'));
const hasClipped = !!annotations.some((v: { type: string }) => v.type === 'context_clipped');
expect(docValues).toEqual(expectedDocs);
expect(tokens).toEqual(expectedTokens);
expect(hasClipped).toEqual(expectedHasClipped);
expect(searchMock.mock.calls[0]).toEqual(expectedSearchRequest);
} catch (error) {
if (expectedHasClipped) {
expect(error).toMatchInlineSnapshot(`[ContextLimitError: Context exceeds the model limit]`);
} else {
throw error;
}
}
};

it('should be able to create a conversational chain', async () => {
Expand Down Expand Up @@ -470,102 +481,56 @@ describe('conversational chain', () => {
},
],
modelLimit: 100,
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{
metadata: { _id: '1', _index: 'index' },
pageContent: expect.any(String),
},
{
metadata: { _id: '1', _index: 'website' },
pageContent: expect.any(String),
},
],
type: 'retrieved_docs',
},
],
// Even with body_content of 1000, the token count should be below or equal to model limit of 100
expectedTokens: [
{ type: 'context_token_count', count: 63 },
{ type: 'prompt_token_count', count: 97 },
],
expectedHasClipped: true,
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 },
},
],
isChatModel: false,
});
}, 10000);

describe('clipContext', () => {
describe('contextLimitCheck', () => {
const prompt = ChatPromptTemplate.fromTemplate(
'you are a QA bot {question} {chat_history} {context}'
);

afterEach(() => {
jest.clearAllMocks();
});

it('should return the input as is if modelLimit is undefined', async () => {
const input = {
context: 'This is a test context.',
question: 'This is a test question.',
chat_history: 'This is a test chat history.',
};
jest.spyOn(prompt, 'format');
const result = await contextLimitCheck(undefined, prompt)(input);

const data = new experimental_StreamData();
const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');

const result = await clipContext(undefined, prompt, data)(input);
expect(result).toEqual(input);
expect(appendMessageAnnotationSpy).not.toHaveBeenCalled();
expect(result).toBe(input);
expect(prompt.format).not.toHaveBeenCalled();
});

it('should not clip context if within modelLimit', async () => {
it('should return the input if within modelLimit', async () => {
const input = {
context: 'This is a test context.',
question: 'This is a test question.',
chat_history: 'This is a test chat history.',
};
const data = new experimental_StreamData();
const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');
const result = await clipContext(10000, prompt, data)(input);
jest.spyOn(prompt, 'format');
const result = await contextLimitCheck(10000, prompt)(input);
expect(result).toEqual(input);
expect(appendMessageAnnotationSpy).not.toHaveBeenCalled();
expect(prompt.format).toHaveBeenCalledWith(input);
});

it('should clip context if exceeds modelLimit', async () => {
expect.assertions(1);
const input = {
context: 'This is a test context.\nThis is another line.\nAnd another one.',
question: 'This is a test question.',
chat_history: 'This is a test chat history.',
};
const data = new experimental_StreamData();
const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');
const result = await clipContext(33, prompt, data)(input);
expect(result.context).toBe('This is a test context.\nThis is another line.');
expect(appendMessageAnnotationSpy).toHaveBeenCalledWith({
type: 'context_clipped',
count: 4,
});
});

it('exit when context becomes empty', async () => {
const input = {
context: 'This is a test context.\nThis is another line.\nAnd another one.',
question: 'This is a test question.',
chat_history: 'This is a test chat history.',
};
const data = new experimental_StreamData();
const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');
const result = await clipContext(1, prompt, data)(input);
expect(result.context).toBe('');
expect(appendMessageAnnotationSpy).toHaveBeenCalledWith({
type: 'context_clipped',
count: 15,
});
await expect(contextLimitCheck(33, prompt)(input)).rejects.toMatchInlineSnapshot(
`[ContextLimitError: Context exceeds the model limit]`
);
});
});
});
38 changes: 14 additions & 24 deletions x-pack/plugins/search_playground/server/lib/conversational_chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import { renderTemplate } from '../utils/render_template';
import { AssistClient } from '../utils/assist';
import { getCitations } from '../utils/get_citations';
import { getTokenEstimate, getTokenEstimateFromMessages } from './token_tracking';
import { ContextLimitError } from './errors';

interface RAGOptions {
index: string;
Expand Down Expand Up @@ -88,37 +89,26 @@ position: ${i + 1}
return serializedDocs.join('\n');
};

export function clipContext(
export function contextLimitCheck(
modelLimit: number | undefined,
prompt: ChatPromptTemplate,
data: experimental_StreamData
prompt: ChatPromptTemplate
): (input: ContextInputs) => Promise<ContextInputs> {
return async (input) => {
if (!modelLimit) return input;
let context = input.context;
const clippedContext = [];

while (
getTokenEstimate(await prompt.format({ ...input, context })) > modelLimit &&
context.length > 0
) {
// remove the last paragraph
const lines = context.split('\n');
clippedContext.push(lines.pop());
context = lines.join('\n');
}
const stringPrompt = await prompt.format(input);
const approxPromptTokens = getTokenEstimate(stringPrompt);
const aboveContextLimit = approxPromptTokens > modelLimit;

if (clippedContext.length > 0) {
data.appendMessageAnnotation({
type: 'context_clipped',
count: getTokenEstimate(clippedContext.join('\n')),
});
if (aboveContextLimit) {
throw new ContextLimitError(
'Context exceeds the model limit',
modelLimit,
approxPromptTokens
);
}

return {
...input,
context,
};
return input;
};
}

Expand Down Expand Up @@ -205,7 +195,7 @@ class ConversationalChainFn {
});
return inputs;
}),
RunnableLambda.from(clipContext(this.options?.rag?.inputTokensLimit, prompt, data)),
RunnableLambda.from(contextLimitCheck(this.options?.rag?.inputTokensLimit, prompt)),
RunnableLambda.from(registerContextTokenCounts(data)),
prompt,
this.options.model.withConfig({ metadata: { type: 'question_answer_qa' } }),
Expand Down
Loading

0 comments on commit c35934e

Please sign in to comment.