Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

chore: clean up chat stream #593

Merged
merged 1 commit into from
May 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
chore: clean up chat stream
  • Loading branch information
namchuai committed May 21, 2024
commit d341452057f00f080d3b1923729505c40ab97a84
1 change: 0 additions & 1 deletion cortex-js/src/command.module.ts
Original file line number Diff line number Diff line change
@@ -33,7 +33,6 @@ import { ModelUpdateCommand } from './infrastructure/commanders/models/model-upd
DatabaseModule,
ModelsModule,
CortexModule,
ChatModule,
ExtensionModule,
HttpModule,
CliUsecasesModule,
2 changes: 1 addition & 1 deletion cortex-js/src/command.ts
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ import { CommandFactory } from 'nest-commander';
import { CommandModule } from './command.module';

async function bootstrap() {
await CommandFactory.run(CommandModule);
await CommandFactory.run(CommandModule, ['warn', 'error']);
}

bootstrap();
5 changes: 4 additions & 1 deletion cortex-js/src/domain/abstracts/engine.abstract.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
/* eslint-disable no-unused-vars, @typescript-eslint/no-unused-vars */
import stream from 'stream';
import { Model, ModelSettingParams } from '../models/model.interface';
import { Extension } from './extension.abstract';

export abstract class EngineExtension extends Extension {
abstract provider: string;

abstract inference(completion: any, req: any, stream: any, res?: any): void;
abstract inference(dto: any, headers: Record<string, string>): Promise<any>;

abstract inferenceStream(dto: any, headers: any): Promise<stream.Readable>;

async loadModel(
model: Model,
151 changes: 34 additions & 117 deletions cortex-js/src/domain/abstracts/oai.abstract.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import { HttpService } from '@nestjs/axios';
import { EngineExtension } from './engine.abstract';
import { stdout } from 'process';

export type ChatStreamEvent = {
type: 'data' | 'error' | 'end';
data?: any;
error?: any;
};
import stream from 'stream';

export abstract class OAIEngineExtension extends EngineExtension {
abstract apiUrl: string;
@@ -15,120 +9,43 @@ export abstract class OAIEngineExtension extends EngineExtension {
super();
}

inference(
override async inferenceStream(
createChatDto: any,
headers: Record<string, string>,
writableStream: WritableStream<ChatStreamEvent>,
res?: any,
) {
if (createChatDto.stream === true) {
if (res) {
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Access-Control-Allow-Origin': '*',
});
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise()
.then((response) => {
response?.data.pipe(res);
});
} else {
const decoder = new TextDecoder('utf-8');
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.subscribe({
next: (response) => {
response.data.on('data', (chunk: any) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content !== '') {
defaultWriter.write({
type: 'data',
data: content,
});
}
}
} catch {
cachedLines = line;
}
}
});

response.data.on('error', (error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
): Promise<stream.Readable> {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise();

if (!response) {
throw new Error('No response');
}

response.data.on('end', () => {
// stdout.write('Stream end');
defaultWriter.write({
type: 'end',
});
});
},
return response.data;
}

error: (error) => {
stdout.write('Stream error: ' + error);
},
});
});
}
} else {
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise()
.then((response) => {
defaultWriter.write({
type: 'data',
data: response?.data,
});
})
.catch((error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
});
override async inference(
createChatDto: any,
headers: Record<string, string>,
): Promise<any> {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise();
if (!response) {
throw new Error('No response');
}

return response.data;
}
}
13 changes: 2 additions & 11 deletions cortex-js/src/infrastructure/commanders/chat.command.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { ChatCliUsecases } from './usecases/chat.cli.usecases';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { exit } from 'node:process';

type ChatOptions = {
@@ -10,10 +8,7 @@ type ChatOptions = {

@SubCommand({ name: 'chat', description: 'Start a chat with a model' })
export class ChatCommand extends CommandRunner {
constructor(
private readonly chatUsecases: ChatUsecases,
private readonly cortexUsecases: CortexUsecases,
) {
constructor(private readonly chatCliUsecases: ChatCliUsecases) {
super();
}

@@ -24,11 +19,7 @@ export class ChatCommand extends CommandRunner {
exit(1);
}

const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
this.cortexUsecases,
);
return chatCliUsecases.chat(modelId);
return this.chatCliUsecases.chat(modelId);
}

@Option({
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//// HF Chat template
export const OPEN_CHAT_3_5_JINJA = ``;
export const OPEN_CHAT_3_5_JINJA = `{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}`;

export const ZEPHYR_JINJA = `{% for message in messages %}
{% if message['role'] == 'user' %}
8 changes: 2 additions & 6 deletions cortex-js/src/infrastructure/commanders/serve.command.ts
Original file line number Diff line number Diff line change
@@ -13,10 +13,6 @@ type ServeOptions = {
description: 'Providing API endpoint for Cortex backend',
})
export class ServeCommand extends CommandRunner {
constructor() {
super();
}

async run(_input: string[], options?: ServeOptions): Promise<void> {
const host = options?.host || defaultCortexJsHost;
const port = options?.port || defaultCortexJsPort;
@@ -34,15 +30,15 @@ export class ServeCommand extends CommandRunner {
}

@Option({
flags: '--host <host>',
flags: '-h, --host <host>',
description: 'Host to serve the application',
})
parseHost(value: string) {
return value;
}

@Option({
flags: '--port <port>',
flags: '-p, --port <port>',
description: 'Port to serve the application',
})
parsePort(value: string) {
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@ import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { exit } from 'node:process';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCliUsecases } from '../usecases/chat.cli.usecases';
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

@@ -18,7 +17,7 @@ export class RunCommand extends CommandRunner {
constructor(
private readonly modelsUsecases: ModelsUsecases,
private readonly cortexUsecases: CortexUsecases,
private readonly chatUsecases: ChatUsecases,
private readonly chatCliUsecases: ChatCliUsecases,
) {
super();
}
@@ -36,11 +35,7 @@ export class RunCommand extends CommandRunner {
false,
);
await this.modelsUsecases.startModel(modelId);
const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
this.cortexUsecases,
);
await chatCliUsecases.chat(modelId);
await this.chatCliUsecases.chat(modelId);
}

@Option({
Original file line number Diff line number Diff line change
@@ -2,12 +2,12 @@ import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCompletionRole } from '@/domain/models/message.interface';
import { exit, stdin, stdout } from 'node:process';
import * as readline from 'node:readline/promises';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';
import { ChatCompletionMessage } from '@/infrastructure/dtos/chat/chat-completion-message.dto';
import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { Injectable } from '@nestjs/common';

// TODO: make this class injectable
@Injectable()
export class ChatCliUsecases {
private exitClause = 'exit()';
private userIndicator = '>> ';
@@ -59,26 +59,44 @@ export class ChatCliUsecases {
top_p: 0.7,
};

let llmFullResponse = '';
const writableStream = new WritableStream<ChatStreamEvent>({
write(chunk) {
if (chunk.type === 'data') {
stdout.write(chunk.data ?? '');
llmFullResponse += chunk.data ?? '';
} else if (chunk.type === 'error') {
console.log('Error!!');
} else {
messages.push({
content: llmFullResponse,
role: ChatCompletionRole.Assistant,
});
llmFullResponse = '';
console.log('\n');
const decoder = new TextDecoder('utf-8');
this.chatUsecases.inferenceStream(chatDto, {}).then((response) => {
response.on('error', (error) => {
console.error(error);
rl.prompt();
});

response.on('end', () => {
console.log('\n');
rl.prompt();
});

response.on('data', (chunk) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content.trim().length > 0) {
stdout.write(content);
}
}
} catch {
cachedLines = line;
}
}
},
});
});

this.chatUsecases.createChatCompletions(chatDto, {}, writableStream);
});
}
}
Original file line number Diff line number Diff line change
@@ -3,11 +3,13 @@ import { InitCliUsecases } from './init.cli.usecases';
import { HttpModule } from '@nestjs/axios';
import { ModelsCliUsecases } from './models.cli.usecases';
import { ModelsModule } from '@/usecases/models/models.module';
import { ChatCliUsecases } from './chat.cli.usecases';
import { ChatModule } from '@/usecases/chat/chat.module';
import { CortexModule } from '@/usecases/cortex/cortex.module';

@Module({
imports: [HttpModule, ModelsModule],
controllers: [],
providers: [InitCliUsecases, ModelsCliUsecases],
exports: [InitCliUsecases, ModelsCliUsecases],
imports: [HttpModule, ModelsModule, ChatModule, CortexModule],
providers: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases],
exports: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases],
})
export class CliUsecasesModule {}
Loading