Skip to content

Commit

Permalink
feat(generative): speed up token streaming, avoid useStreamableValue …
Browse files Browse the repository at this point in the history
…if only last value is needed
  • Loading branch information
dqbd committed May 27, 2024
1 parent faf4c48 commit 62da2dc
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 30 deletions.
13 changes: 13 additions & 0 deletions app/generative_ui/ai/message.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"use client";

import { StreamableValue, useStreamableValue } from "ai/rsc";

export function AIMessage(props: { value: StreamableValue<string> }) {
const [data] = useStreamableValue(props.value);

return (
<div className="empty:hidden border border-gray-700 p-3 rounded-lg max-w-[50vw]">
{data}
</div>
);
}
11 changes: 3 additions & 8 deletions app/generative_ui/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { useState } from "react";
import type { EndpointsContext } from "./agent";
import { useActions } from "./utils/client";
import { LocalContext } from "./shared";
import { readStreamableValue } from "ai/rsc";

export default function GenerativeUIPage() {
const actions = useActions<typeof EndpointsContext>();
Expand Down Expand Up @@ -35,16 +34,12 @@ export default function GenerativeUIPage() {
// consume the value stream to obtain the final string value
// after which we can append to our chat history state
(async () => {
let finalValue: string | null = null;
for await (const value of readStreamableValue(element.value)) {
finalValue = value;
}

if (finalValue != null) {
let lastEvent = await element.lastEvent;
if (typeof lastEvent === "string") {
setHistory((prev) => [
...prev,
["user", input],
["assistant", finalValue as string],
["assistant", lastEvent],
]);
}
})();
Expand Down
57 changes: 35 additions & 22 deletions app/generative_ui/utils/server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
StreamEvent,
} from "@langchain/core/tracers/log_stream";
import { AIProvider } from "./client";
import { AIMessage } from "../ai/message";

/**
* Executes `streamEvents` method on a runnable
Expand All @@ -28,14 +29,14 @@ export function streamRunnableUI<RunInput, RunOutput>(
inputs: RunInput,
) {
const ui = createStreamableUI();
const value = createStreamableValue();
const [lastEvent, resolve] = withResolvers<string>();

(async () => {
let lastEvent: StreamEvent | null = null;
let lastEventValue: StreamEvent | null = null;

const streamableMap: Record<
const callbacks: Record<
string,
ReturnType<typeof createStreamableUI>
ReturnType<typeof createStreamableUI | typeof createStreamableValue>
> = {};

for await (const streamEvent of runnable.streamEvents(inputs, {
Expand All @@ -48,34 +49,30 @@ export function streamRunnableUI<RunInput, RunOutput>(
if (isValidElement(chunk)) {
ui.append(chunk);
} else if ("text" in chunk && typeof chunk.text === "string") {
if (!streamableMap[streamEvent.run_id]) {
streamableMap[streamEvent.run_id] = createStreamableUI();
const value = streamableMap[streamEvent.run_id].value;

// create an AI message
ui.append(
<div className="empty:hidden border border-gray-700 p-3 rounded-lg max-w-[50vw]">
{value}
</div>,
);
if (!callbacks[streamEvent.run_id]) {
// the createStreamableValue / useStreamableValue is preferred
// as the stream events are updated immediately in the UI
// rather than being batched by React via createStreamableUI
const textStream = createStreamableValue();
ui.append(<AIMessage value={textStream.value} />);
callbacks[streamEvent.run_id] = textStream;
}

streamableMap[streamEvent.run_id].append(chunk.text);
callbacks[streamEvent.run_id].append(chunk.text);
}
}
lastEvent = streamEvent;
lastEventValue = streamEvent;
}

value.done(lastEvent?.data.output);
// resolve the promise, which will be sent
// to the client thanks to RSC
resolve(lastEventValue?.data.output);

for (const ui of Object.values(streamableMap)) ui.done();
Object.values(callbacks).forEach((cb) => cb.done());
ui.done();
})();

return {
ui: ui.value,
value: value.value,
};
return { ui: ui.value, lastEvent };
}

/**
Expand Down Expand Up @@ -147,3 +144,19 @@ export function exposeEndpoints<T extends Record<string, unknown>>(
return <AIProvider actions={actions}>{props.children}</AIProvider>;
};
}

/**
* Polyfill to emulate the upcoming Promise.withResolvers
*/
export function withResolvers<T>() {
let resolve: (value: T) => void;
let reject: (reason?: any) => void;

const innerPromise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});

// @ts-expect-error
return [innerPromise, resolve, reject] as const;
}

0 comments on commit 62da2dc

Please sign in to comment.