Skip to content

Commit

Permalink
chore: clean up chat stream
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed May 21, 2024
1 parent e65ad47 commit 47e7a9e
Show file tree
Hide file tree
Showing 13 changed files with 125 additions and 197 deletions.
1 change: 0 additions & 1 deletion cortex-js/src/command.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import { ModelUpdateCommand } from './infrastructure/commanders/models/model-upd
DatabaseModule,
ModelsModule,
CortexModule,
ChatModule,
ExtensionModule,
HttpModule,
CliUsecasesModule,
Expand Down
2 changes: 1 addition & 1 deletion cortex-js/src/command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { CommandFactory } from 'nest-commander';
import { CommandModule } from './command.module';

async function bootstrap() {
await CommandFactory.run(CommandModule);
await CommandFactory.run(CommandModule, ['warn', 'error']);
}

bootstrap();
6 changes: 5 additions & 1 deletion cortex-js/src/domain/abstracts/engine.abstract.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
/* eslint-disable no-unused-vars, @typescript-eslint/no-unused-vars */
import stream from 'stream';
import { Model, ModelSettingParams } from '../models/model.interface';
import { Extension } from './extension.abstract';

export abstract class EngineExtension extends Extension {
abstract provider: string;

abstract inference(completion: any, req: any, stream: any, res?: any): void;
// TODO: NamH check this, it's not working right now
abstract inference(dto: any, headers: Record<string, string>): Promise<any>;

abstract inferenceStream(dto: any, headers: any): Promise<stream.Readable>;

async loadModel(
model: Model,
Expand Down
148 changes: 30 additions & 118 deletions cortex-js/src/domain/abstracts/oai.abstract.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import { HttpService } from '@nestjs/axios';
import { EngineExtension } from './engine.abstract';
import { stdout } from 'process';

export type ChatStreamEvent = {
type: 'data' | 'error' | 'end';
data?: any;
error?: any;
};
import stream from 'stream';

export abstract class OAIEngineExtension extends EngineExtension {
abstract apiUrl: string;
Expand All @@ -15,120 +9,38 @@ export abstract class OAIEngineExtension extends EngineExtension {
super();
}

inference(
override async inferenceStream(
createChatDto: any,
headers: Record<string, string>,
writableStream: WritableStream<ChatStreamEvent>,
res?: any,
) {
if (createChatDto.stream === true) {
if (res) {
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Access-Control-Allow-Origin': '*',
});
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise()
.then((response) => {
response?.data.pipe(res);
});
} else {
const decoder = new TextDecoder('utf-8');
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.subscribe({
next: (response) => {
response.data.on('data', (chunk: any) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content !== '') {
defaultWriter.write({
type: 'data',
data: content,
});
}
}
} catch {
cachedLines = line;
}
}
});

response.data.on('error', (error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
): Promise<stream.Readable> {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise();

if (!response) {
throw new Error('No response');
}

response.data.on('end', () => {
// stdout.write('Stream end');
defaultWriter.write({
type: 'end',
});
});
},
return response.data;
}

error: (error) => {
stdout.write('Stream error: ' + error);
},
});
});
}
} else {
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise()
.then((response) => {
defaultWriter.write({
type: 'data',
data: response?.data,
});
})
.catch((error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
});
}
override async inference(
createChatDto: any,
headers: Record<string, string>,
) {
return this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise();
}
}
13 changes: 2 additions & 11 deletions cortex-js/src/infrastructure/commanders/chat.command.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { ChatCliUsecases } from './usecases/chat.cli.usecases';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { exit } from 'node:process';

type ChatOptions = {
Expand All @@ -10,10 +8,7 @@ type ChatOptions = {

@SubCommand({ name: 'chat', description: 'Start a chat with a model' })
export class ChatCommand extends CommandRunner {
constructor(
private readonly chatUsecases: ChatUsecases,
private readonly cortexUsecases: CortexUsecases,
) {
constructor(private readonly chatCliUsecases: ChatCliUsecases) {
super();
}

Expand All @@ -24,11 +19,7 @@ export class ChatCommand extends CommandRunner {
exit(1);
}

const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
this.cortexUsecases,
);
return chatCliUsecases.chat(modelId);
return this.chatCliUsecases.chat(modelId);
}

@Option({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//// HF Chat template
export const OPEN_CHAT_3_5_JINJA = ``;
export const OPEN_CHAT_3_5_JINJA = `{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}`;

export const ZEPHYR_JINJA = `{% for message in messages %}
{% if message['role'] == 'user' %}
Expand Down
8 changes: 2 additions & 6 deletions cortex-js/src/infrastructure/commanders/serve.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ type ServeOptions = {
description: 'Providing API endpoint for Cortex backend',
})
export class ServeCommand extends CommandRunner {
constructor() {
super();
}

async run(_input: string[], options?: ServeOptions): Promise<void> {
const host = options?.host || defaultCortexJsHost;
const port = options?.port || defaultCortexJsPort;
Expand All @@ -34,15 +30,15 @@ export class ServeCommand extends CommandRunner {
}

@Option({
flags: '--host <host>',
flags: '-h, --host <host>',
description: 'Host to serve the application',
})
parseHost(value: string) {
return value;
}

@Option({
flags: '--port <port>',
flags: '-p, --port <port>',
description: 'Port to serve the application',
})
parsePort(value: string) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { exit } from 'node:process';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCliUsecases } from '../usecases/chat.cli.usecases';
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

Expand All @@ -18,7 +17,7 @@ export class RunCommand extends CommandRunner {
constructor(
private readonly modelsUsecases: ModelsUsecases,
private readonly cortexUsecases: CortexUsecases,
private readonly chatUsecases: ChatUsecases,
private readonly chatCliUsecases: ChatCliUsecases,
) {
super();
}
Expand All @@ -36,11 +35,7 @@ export class RunCommand extends CommandRunner {
false,
);
await this.modelsUsecases.startModel(modelId);
const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
this.cortexUsecases,
);
await chatCliUsecases.chat(modelId);
await this.chatCliUsecases.chat(modelId);
}

@Option({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCompletionRole } from '@/domain/models/message.interface';
import { exit, stdin, stdout } from 'node:process';
import * as readline from 'node:readline/promises';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';
import { ChatCompletionMessage } from '@/infrastructure/dtos/chat/chat-completion-message.dto';
import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { Injectable } from '@nestjs/common';

// TODO: make this class injectable
@Injectable()
export class ChatCliUsecases {
private exitClause = 'exit()';
private userIndicator = '>> ';
Expand Down Expand Up @@ -59,26 +59,34 @@ export class ChatCliUsecases {
top_p: 0.7,
};

let llmFullResponse = '';
const writableStream = new WritableStream<ChatStreamEvent>({
write(chunk) {
if (chunk.type === 'data') {
stdout.write(chunk.data ?? '');
llmFullResponse += chunk.data ?? '';
} else if (chunk.type === 'error') {
console.log('Error!!');
} else {
messages.push({
content: llmFullResponse,
role: ChatCompletionRole.Assistant,
});
llmFullResponse = '';
console.log('\n');
const decoder = new TextDecoder('utf-8');
this.chatUsecases.inferenceStream(chatDto, {}).then((response) => {
response.on('data', (chunk) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content !== '') {
stdout.write(content);
}
}
} catch {
cachedLines = line;
}
}
},
});
});

this.chatUsecases.createChatCompletions(chatDto, {}, writableStream);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ import { InitCliUsecases } from './init.cli.usecases';
import { HttpModule } from '@nestjs/axios';
import { ModelsCliUsecases } from './models.cli.usecases';
import { ModelsModule } from '@/usecases/models/models.module';
import { ChatCliUsecases } from './chat.cli.usecases';
import { ChatModule } from '@/usecases/chat/chat.module';
import { CortexModule } from '@/usecases/cortex/cortex.module';

@Module({
imports: [HttpModule, ModelsModule],
controllers: [],
providers: [InitCliUsecases, ModelsCliUsecases],
exports: [InitCliUsecases, ModelsCliUsecases],
imports: [HttpModule, ModelsModule, ChatModule, CortexModule],
providers: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases],
exports: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases],
})
export class CliUsecasesModule {}
Loading

0 comments on commit 47e7a9e

Please sign in to comment.