Improve rendering performance (#545)
- Cache token costs during conversation shrinking (see the sketch below)
- Lift hot closures to functions
- Limit `debug` spew
petersalas authored May 10, 2024
1 parent 07a6072 commit e7b3e2e
Showing 6 changed files with 136 additions and 95 deletions.
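
The core of the first change is a per-element memo of the (asynchronous) token-cost computation, so that `<ShrinkConversation>` does not recompute costs every time it remeasures the same elements. Below is a minimal, self-contained sketch of that pattern; the `MessageElement`/`Message` shapes and the `estimateTokens` helper are illustrative stand-ins, not the actual ai-jsx types or APIs.

```ts
// Sketch only: simplified stand-ins for the conversation types.
interface MessageElement {
  tag: string;
  props: Record<string, unknown>;
}

interface Message {
  type: 'user' | 'assistant' | 'system';
  element: MessageElement;
}

// Keyed by element identity, so remeasuring the same element during
// conversation shrinking reuses the first (possibly expensive) computation.
const cachedCosts = new WeakMap<MessageElement, Promise<number>>();

// Placeholder cost function; the real one counts tokens for a message.
async function estimateTokens(message: Message): Promise<number> {
  return JSON.stringify(message.element.props).length / 4;
}

function cachedCost(message: Message): Promise<number> {
  if (!cachedCosts.has(message.element)) {
    cachedCosts.set(message.element, estimateTokens(message));
  }
  return cachedCosts.get(message.element)!;
}

// Usage: repeated calls with the same element resolve from the cached promise.
async function demo() {
  const element: MessageElement = { tag: 'UserMessage', props: { children: 'hello' } };
  const message: Message = { type: 'user', element };
  const [a, b] = await Promise.all([cachedCost(message), cachedCost(message)]);
  console.log(a === b); // true: both reads come from the same cached promise
}
```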
2 changes: 1 addition & 1 deletion packages/ai-jsx/package.json
@@ -4,7 +4,7 @@
"repository": "fixie-ai/ai-jsx",
"bugs": "https://github.com/fixie-ai/ai-jsx/issues",
"homepage": "https://ai-jsx.com",
"version": "0.31.0",
"version": "0.32.0",
"volta": {
"extends": "../../package.json"
},
25 changes: 17 additions & 8 deletions packages/ai-jsx/src/core/conversation.tsx
@@ -241,7 +241,7 @@ function toConversationMessages(partialRendering: AI.PartiallyRendered[]): Conve
async function loggableMessage(
message: ConversationMessage,
render: AI.RenderContext['render'],
cost?: (message: ConversationMessage, render: AI.ComponentContext['render']) => Promise<number>
cost?: (message: ConversationMessage) => Promise<number>
) {
let textPromise: PromiseLike<string> | undefined = undefined;
switch (message.type) {
@@ -261,7 +261,7 @@ async function loggableMessage(
}
}

const costPromise = cost?.(message, render);
const costPromise = cost?.(message);

const { children, ...propsWithoutChildren } = {
children: undefined,
@@ -288,9 +288,18 @@ export async function renderToConversation(
cost?: (message: ConversationMessage, render: AI.ComponentContext['render']) => Promise<number>,
budget?: number
) {
const cachedCosts = new WeakMap<AI.Element<any>, Promise<number>>();
function cachedCost(message: ConversationMessage): Promise<number> {
if (!cachedCosts.has(message.element)) {
cachedCosts.set(message.element, cost!(message, render));
}

return cachedCosts.get(message.element)!;
}

const conversationToUse =
cost && budget ? (
<ShrinkConversation cost={cost} budget={budget}>
<ShrinkConversation cost={cachedCost} budget={budget}>
{conversation}
</ShrinkConversation>
) : (
@@ -299,7 +308,7 @@
const messages = toConversationMessages(await render(conversationToUse, { stop: isConversationalComponent }));

if (logger && logType) {
const loggableMessages = await Promise.all(messages.map((m) => loggableMessage(m, render, cost)));
const loggableMessages = await Promise.all(messages.map((m) => loggableMessage(m, render, cost && cachedCost)));
logger.setAttribute(`ai.jsx.${logType}`, JSON.stringify(loggableMessages));
logger.info({ [logType]: { messages: loggableMessages } }, `Got ${logType} conversation`);
}
@@ -465,7 +474,7 @@ export async function ShrinkConversation(
budget,
children,
}: {
cost: (message: ConversationMessage, render: AI.RenderContext['render']) => Promise<number>;
cost: (message: ConversationMessage) => Promise<number>;
budget: number;
children: Node;
},
@@ -508,7 +517,7 @@ export async function ShrinkConversation(
return {
type: 'immutable',
element: value,
cost: await costFn(toConversationMessages([value])[0], render),
cost: await costFn(toConversationMessages([value])[0]),
};
})
);
@@ -597,9 +606,9 @@ export async function ShrinkConversation(

logger.debug(
{
node: debug(nodeToReplace.element.props.children, true),
node: debug(nodeToReplace.element.props.children, false),
importance: nodeToReplace.element.props.importance,
replacement: debug(nodeToReplace.element.props.replacement, true),
replacement: debug(nodeToReplace.element.props.replacement, false),
nodeCost: nodeToReplace.cost,
totalCost: aggregateCost(roots),
budget,
49 changes: 34 additions & 15 deletions packages/ai-jsx/src/core/debug.tsx
@@ -8,8 +8,6 @@ import * as AI from '../index.js';
import { Element, ElementPredicate, Node, RenderContext } from '../index.js';
import { memoizedIdSymbol } from './memoize.js';

const maxStringLength = 1000;

const debugRepresentationSymbol = Symbol('AI.JSX debug representation');

/**
@@ -32,14 +30,28 @@ export function debugRepresentation(fn: (element: Element<any>) => unknown) {
};
}

function isEmptyJSXValue(x: unknown) {
return x === undefined || x === null || typeof x === 'boolean' || (Array.isArray(x) && x.length == 0);
}

/**
* Used by {@link DebugTree} to render a tree of {@link Node}s.
* @hidden
*/
export function debug(value: unknown, expandJSXChildren: boolean = true): string {
export function debug(value: unknown, expandJSXChildren: boolean = true, maxStringLength = 2048): string {
const previouslyMemoizedElements = new Set<Element<any>>();
let remainingLength = maxStringLength;

function debugRec(value: unknown, indent: string, context: 'code' | 'children' | 'props'): string {
if (remainingLength <= 0) {
return '{...}';
}
const result = debugRecHelper(value, indent, context);
remainingLength -= result.length;
return result;
}

function debugRecHelper(value: unknown, indent: string, context: 'code' | 'children' | 'props'): string {
if (AI.isIndirectNode(value)) {
return debugRec(AI.getReferencedNode(value), indent, context);
}
@@ -98,16 +110,16 @@ export function debug(value: unknown, expandJSXChildren: boolean = true): string

if (value.props) {
for (const key of Object.keys(value.props)) {
if (remainingLength <= 0) {
results.push(' {...}');
break;
}

const propValue = value.props[key];
if (key === 'children' || propValue === undefined) {
continue;
} else {
const valueStr = debugRec(propValue, indent, 'props');
if (valueStr.length > maxStringLength) {
results.push(` ${key}=<omitted large object>`);
} else {
results.push(` ${key}=${valueStr}`);
}
results.push(` ${key}=${debugRec(propValue, indent, 'props')}`);
}
}
}
@@ -135,12 +147,19 @@ export function debug(value: unknown, expandJSXChildren: boolean = true): string
return `{${child}}`;
}
} else if (Array.isArray(value)) {
const filter =
context === 'children'
? (x: unknown) =>
x !== undefined && x !== null && typeof x !== 'boolean' && !(Array.isArray(x) && x.length == 0)
: () => true;
const values = value.filter(filter).map((v) => debugRec(v, indent, context === 'children' ? 'children' : 'code'));
const values: string[] = [];

for (const item of value) {
if (remainingLength <= 0) {
values.push('{...}');
break;
}
if (context === 'children' && isEmptyJSXValue(item)) {
continue;
}
values.push(debugRec(item, indent, context === 'children' ? 'children' : 'code'));
}

switch (context) {
case 'children':
return values.join(`\n${indent}`);
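
The `debug.tsx` changes above replace after-the-fact truncation with a running character budget that short-circuits recursion once the output is long enough. Here is a simplified, standalone sketch of that budgeting approach; it is not the actual ai-jsx `debug` implementation, and the names and value handling are illustrative only.

```ts
// Recursively format a value, but stop expanding once `maxStringLength`
// characters have been produced; deeper content collapses to a placeholder.
function dump(value: unknown, maxStringLength = 2048): string {
  let remaining = maxStringLength;

  function rec(v: unknown): string {
    // Once the budget is spent, emit a short placeholder instead of recursing.
    if (remaining <= 0) {
      return '{...}';
    }
    const result = recHelper(v);
    remaining -= result.length;
    return result;
  }

  function recHelper(v: unknown): string {
    if (Array.isArray(v)) {
      const parts: string[] = [];
      for (const item of v) {
        if (remaining <= 0) {
          parts.push('{...}');
          break;
        }
        parts.push(rec(item));
      }
      return `[${parts.join(', ')}]`;
    }
    if (v !== null && typeof v === 'object') {
      const parts: string[] = [];
      for (const [key, propValue] of Object.entries(v)) {
        if (remaining <= 0) {
          parts.push('...');
          break;
        }
        parts.push(`${key}: ${rec(propValue)}`);
      }
      return `{${parts.join(', ')}}`;
    }
    return JSON.stringify(v) ?? String(v);
  }

  return rec(value);
}

// Usage: only the first few items are expanded; the rest collapse to a placeholder.
console.log(dump(Array.from({ length: 100 }, () => 'x'.repeat(100)), 300));
```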
143 changes: 75 additions & 68 deletions packages/ai-jsx/src/core/render.ts
@@ -412,6 +412,77 @@ export function createRenderContext(opts?: { logger?: LogImplementation; enableO
);
}

async function* createRenderGenerator<TIntermediate, TFinal>(
context: RenderContext,
renderStream: StreamRenderer,
renderable: Renderable,
opts: RenderOpts<TIntermediate, TFinal> | undefined,
onComplete: (value: TFinal) => void
): AsyncGenerator<TIntermediate, TFinal> {
// eslint-disable-next-line @typescript-eslint/prefer-nullish-coalescing
const shouldStop: ElementPredicate = opts?.stop || (() => false);
const generatorToWrap = renderStream(context, renderable, shouldStop, Boolean(opts?.appendOnly));

let nextPromise = generatorToWrap.next();
while (true) {
let next = await nextPromise;

if (!next.done && opts?.batchFrames) {
// We use `setImmediate` or `setTimeout` to ensure that all (recursively) queued microtasks
// are completed. (Promise.then handlers are queued as microtasks.)
// See https://developer.mozilla.org/en-US/docs/Web/API/HTML_DOM_API/Microtask_guide
const nullPromise = new Promise<null>((resolve) => {
if ('setImmediate' in globalThis) {
setImmediate(() => resolve(null));
} else {
setTimeout(() => resolve(null), 0);
}
});

while (!next.done) {
nextPromise = generatorToWrap.next();

// Consume from the generator until the null promise resolves.
const nextOrNull = await Promise.race([nextPromise, nullPromise]);
if (nextOrNull === null) {
break;
}
next = nextOrNull;
}
}

const value = opts?.stop ? (next.value as TFinal) : (next.value.join('') as TFinal);
if (next.done) {
onComplete(value);
return value;
}

if (opts?.map) {
// If there's a mapper provided, use it.
yield opts.map(value);
} else if (opts?.stop) {
// If we're doing partial rendering, exclude any elements we stopped on (to avoid accidentally leaking elements up).
yield (value as PartiallyRendered[]).filter((e) => !isElement(e)).join('') as unknown as TIntermediate;
} else {
// Otherwise yield the (string) value as-is.
yield value as unknown as TIntermediate;
}

if (!opts?.batchFrames) {
nextPromise = generatorToWrap.next();
}
}
}

async function flushGenerator<T>(generator: AsyncGenerator<unknown, T>): Promise<T> {
while (true) {
const next = await generator.next();
if (next.done) {
return next.value;
}
}
}

function createRenderContextInternal(
renderStream: StreamRenderer,
userContext: Record<symbol, any>,
@@ -425,64 +496,9 @@
let promiseResult = null as Promise<any> | null;
let hasReturnedGenerator = false;

// Construct the generator that handles the provided options
const generator = (async function* () {
// eslint-disable-next-line @typescript-eslint/prefer-nullish-coalescing
const shouldStop = (opts?.stop || (() => false)) as ElementPredicate;
const generatorToWrap = renderStream(context, renderable, shouldStop, Boolean(opts?.appendOnly));

let nextPromise = generatorToWrap.next();
while (true) {
let next = await nextPromise;

if (!next.done && opts?.batchFrames) {
// We use `setImmediate` or `setTimeout` to ensure that all (recursively) queued microtasks
// are completed. (Promise.then handlers are queued as microtasks.)
// See https://developer.mozilla.org/en-US/docs/Web/API/HTML_DOM_API/Microtask_guide
const nullPromise = new Promise<null>((resolve) => {
if ('setImmediate' in globalThis) {
setImmediate(() => resolve(null));
} else {
setTimeout(() => resolve(null), 0);
}
});

while (!next.done) {
nextPromise = generatorToWrap.next();

// Consume from the generator until the null promise resolves.
const nextOrNull = await Promise.race([nextPromise, nullPromise]);
if (nextOrNull === null) {
break;
}
next = nextOrNull;
}
}

const value = opts?.stop ? (next.value as TFinal) : (next.value.join('') as TFinal);
if (next.done) {
if (promiseResult === null) {
promiseResult = Promise.resolve(value);
}
return value;
}

if (opts?.map) {
// If there's a mapper provided, use it.
yield opts.map(value);
} else if (opts?.stop) {
// If we're doing partial rendering, exclude any elements we stopped on (to avoid accidentally leaking elements up).
yield (value as PartiallyRendered[]).filter((e) => !isElement(e)).join('');
} else {
// Otherwise yield the (string) value as-is.
yield value;
}

if (!opts?.batchFrames) {
nextPromise = generatorToWrap.next();
}
}
})() as AsyncGenerator<TIntermediate, TFinal>;
const generator = createRenderGenerator(context, renderStream, renderable, opts, (value) => {
promiseResult ||= Promise.resolve(value);
});

return {
then: (onFulfilled?, onRejected?) => {
@@ -495,16 +511,7 @@
);
}

const flush = async () => {
while (true) {
const next = await generator.next();
if (next.done) {
return next.value;
}
}
};

promiseResult = flush();
promiseResult = flushGenerator(generator);
}

return promiseResult.then(onFulfilled, onRejected);
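
The `render.ts` diff above is the "lift hot closures to functions" item: the large async-generator body that was previously re-created inside `createRenderContextInternal` on every render is now a module-level function, with the captured state passed in explicitly via a small callback. A heavily simplified schematic of that refactor shape follows; the types and names (`makeStream`, `renderAfter`) are illustrative, not the ai-jsx signatures.

```ts
// Before (shape only): a large async-generator closure rebuilt on every call,
// capturing `result` from the enclosing scope.
//
//   function renderBefore(items: string[]) {
//     let result: string | null = null;
//     const gen = (async function* () {
//       for (const item of items) yield item;
//       result = items.join('');
//       return result;
//     })();
//     return gen;
//   }

// After: the generator body lives at module level and receives its dependencies
// as parameters. Only a tiny `onComplete` callback closure remains per call.
async function* makeStream(
  items: string[],
  onComplete: (value: string) => void
): AsyncGenerator<string, string> {
  for (const item of items) {
    yield item;
  }
  const value = items.join('');
  onComplete(value);
  return value;
}

function renderAfter(items: string[]) {
  let result: string | null = null;
  return {
    stream: makeStream(items, (value) => (result ||= value)),
    getResult: () => result,
  };
}

// Usage
async function demoRender() {
  const { stream, getResult } = renderAfter(['a', 'b', 'c']);
  for await (const frame of stream) {
    console.log('frame:', frame);
  }
  console.log('final:', getResult()); // 'abc'
}
```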
8 changes: 7 additions & 1 deletion packages/docs/docs/changelog.md
@@ -1,6 +1,12 @@
# Changelog

## 0.31.0
## 0.32.0

- Improve rendering performance:
- Change `<ShrinkConversation>` to cache token costs when remeasuring the same elements
- Reduce performance impact of debug logging

## [0.31.0](https://github.com/fixie-ai/ai-jsx/tree/07a6072ef77ffb8403786cf02538d2afc11f61f5)

- Fix incorrect unwrapping of JSX array `children`.

4 changes: 2 additions & 2 deletions packages/examples/test/core/completion.tsx
@@ -86,7 +86,7 @@ describe('OpenTelemetry', () => {
{"hello"}
</UserMessage>]",
"ai.jsx.tag": "ShrinkConversation",
"ai.jsx.tree": "<ShrinkConversation cost={tokenCountForConversationMessage} budget={16381}>
"ai.jsx.tree": "<ShrinkConversation cost={cachedCost} budget={16381}>
<UserMessage>
{"hello"}
</UserMessage>
@@ -171,7 +171,7 @@ describe('OpenTelemetry', () => {
{"hello"}
</UserMessage>]",
"ai.jsx.tag": "ShrinkConversation",
"ai.jsx.tree": "<ShrinkConversation cost={tokenCountForConversationMessage} budget={16381}>
"ai.jsx.tree": "<ShrinkConversation cost={cachedCost} budget={16381}>
<UserMessage>
{"hello"}
</UserMessage>