chore: clean up chat stream #593

Merged 1 commit on May 21, 2024
1 change: 0 additions & 1 deletion cortex-js/src/command.module.ts
@@ -33,7 +33,6 @@ import { ModelUpdateCommand } from './infrastructure/commanders/models/model-upd
DatabaseModule,
ModelsModule,
CortexModule,
ChatModule,
ExtensionModule,
HttpModule,
CliUsecasesModule,
2 changes: 1 addition & 1 deletion cortex-js/src/command.ts
@@ -3,7 +3,7 @@ import { CommandFactory } from 'nest-commander';
import { CommandModule } from './command.module';

async function bootstrap() {
await CommandFactory.run(CommandModule);
await CommandFactory.run(CommandModule, ['warn', 'error']);
}

bootstrap();
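
The second argument to CommandFactory.run sets the Nest logger levels for the CLI process, so ['warn', 'error'] silences informational framework output while keeping warnings and errors. A minimal sketch of the equivalent options-object form, assuming nest-commander's CommandFactoryRunOptions accepts a logger field (illustrative, not part of this PR):

import { CommandFactory } from 'nest-commander';
import { CommandModule } from './command.module';

async function bootstrap() {
  // Equivalent to passing the LogLevel array directly:
  // only 'warn' and 'error' messages from the Nest logger are printed.
  await CommandFactory.run(CommandModule, { logger: ['warn', 'error'] });
}

bootstrap();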
5 changes: 4 additions & 1 deletion cortex-js/src/domain/abstracts/engine.abstract.ts
@@ -1,11 +1,14 @@
/* eslint-disable no-unused-vars, @typescript-eslint/no-unused-vars */
import stream from 'stream';
import { Model, ModelSettingParams } from '../models/model.interface';
import { Extension } from './extension.abstract';

export abstract class EngineExtension extends Extension {
abstract provider: string;

abstract inference(completion: any, req: any, stream: any, res?: any): void;
abstract inference(dto: any, headers: Record<string, string>): Promise<any>;

abstract inferenceStream(dto: any, headers: any): Promise<stream.Readable>;

async loadModel(
model: Model,
151 changes: 34 additions & 117 deletions cortex-js/src/domain/abstracts/oai.abstract.ts
@@ -1,12 +1,6 @@
import { HttpService } from '@nestjs/axios';
import { EngineExtension } from './engine.abstract';
import { stdout } from 'process';

export type ChatStreamEvent = {
type: 'data' | 'error' | 'end';
data?: any;
error?: any;
};
import stream from 'stream';

export abstract class OAIEngineExtension extends EngineExtension {
abstract apiUrl: string;
@@ -15,120 +9,43 @@ export abstract class OAIEngineExtension extends EngineExtension {
super();
}

inference(
override async inferenceStream(
createChatDto: any,
headers: Record<string, string>,
writableStream: WritableStream<ChatStreamEvent>,
res?: any,
) {
if (createChatDto.stream === true) {
if (res) {
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Access-Control-Allow-Origin': '*',
});
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise()
.then((response) => {
response?.data.pipe(res);
});
} else {
const decoder = new TextDecoder('utf-8');
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.subscribe({
next: (response) => {
response.data.on('data', (chunk: any) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content !== '') {
defaultWriter.write({
type: 'data',
data: content,
});
}
}
} catch {
cachedLines = line;
}
}
});

response.data.on('error', (error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
): Promise<stream.Readable> {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise();

if (!response) {
throw new Error('No response');
}

response.data.on('end', () => {
// stdout.write('Stream end');
defaultWriter.write({
type: 'end',
});
});
},
return response.data;
}

error: (error) => {
stdout.write('Stream error: ' + error);
},
});
});
}
} else {
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise()
.then((response) => {
defaultWriter.write({
type: 'data',
data: response?.data,
});
})
.catch((error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
});
override async inference(
createChatDto: any,
headers: Record<string, string>,
): Promise<any> {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise();
if (!response) {
throw new Error('No response');
}

return response.data;
}
}
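
With this change inferenceStream no longer writes to a WritableStream or an HTTP response itself; it returns the upstream body as a Node stream.Readable and leaves delivery to the caller. A hedged usage sketch for the server-side path (handler and variable names are illustrative and not part of this PR; the SSE headers mirror the ones removed above):

import { Response } from 'express';
// assumes OAIEngineExtension is imported from './oai.abstract'

async function streamChatCompletion(
  engine: OAIEngineExtension,
  dto: any,
  headers: Record<string, string>,
  res: Response,
) {
  const readable = await engine.inferenceStream(dto, headers);
  res.writeHead(200, {
    'Content-Type': 'text/event-stream',
    'Cache-Control': 'no-cache',
    Connection: 'keep-alive',
  });
  readable.pipe(res); // forward the upstream SSE chunks untouched
}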
13 changes: 2 additions & 11 deletions cortex-js/src/infrastructure/commanders/chat.command.ts
@@ -1,7 +1,5 @@
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { ChatCliUsecases } from './usecases/chat.cli.usecases';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { exit } from 'node:process';

type ChatOptions = {
@@ -10,10 +8,7 @@

@SubCommand({ name: 'chat', description: 'Start a chat with a model' })
export class ChatCommand extends CommandRunner {
constructor(
private readonly chatUsecases: ChatUsecases,
private readonly cortexUsecases: CortexUsecases,
) {
constructor(private readonly chatCliUsecases: ChatCliUsecases) {
super();
}

@@ -24,11 +19,7 @@ export class ChatCommand extends CommandRunner {
exit(1);
}

const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
this.cortexUsecases,
);
return chatCliUsecases.chat(modelId);
return this.chatCliUsecases.chat(modelId);
}

@Option({
@@ -1,5 +1,5 @@
//// HF Chat template
export const OPEN_CHAT_3_5_JINJA = ``;
export const OPEN_CHAT_3_5_JINJA = `{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}`;

export const ZEPHYR_JINJA = `{% for message in messages %}
{% if message['role'] == 'user' %}
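
For reference, the restored OPEN_CHAT_3_5_JINJA template above renders each message as a 'GPT4 Correct <Role>:' turn terminated by <|end_of_turn|>. Assuming a bos_token of <s>, a single user message "Hello", and add_generation_prompt enabled, the rendered prompt would look roughly like:

<s>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: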
8 changes: 2 additions & 6 deletions cortex-js/src/infrastructure/commanders/serve.command.ts
@@ -13,10 +13,6 @@ type ServeOptions = {
description: 'Providing API endpoint for Cortex backend',
})
export class ServeCommand extends CommandRunner {
constructor() {
super();
}

async run(_input: string[], options?: ServeOptions): Promise<void> {
const host = options?.host || defaultCortexJsHost;
const port = options?.port || defaultCortexJsPort;
@@ -34,15 +30,15 @@
}

@Option({
flags: '--host <host>',
flags: '-h, --host <host>',
description: 'Host to serve the application',
})
parseHost(value: string) {
return value;
}

@Option({
flags: '--port <port>',
flags: '-p, --port <port>',
description: 'Port to serve the application',
})
parsePort(value: string) {
@@ -2,7 +2,6 @@ import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { exit } from 'node:process';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCliUsecases } from '../usecases/chat.cli.usecases';
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

@@ -18,7 +17,7 @@ export class RunCommand extends CommandRunner {
constructor(
private readonly modelsUsecases: ModelsUsecases,
private readonly cortexUsecases: CortexUsecases,
private readonly chatUsecases: ChatUsecases,
private readonly chatCliUsecases: ChatCliUsecases,
) {
super();
}
@@ -36,11 +35,7 @@
false,
);
await this.modelsUsecases.startModel(modelId);
const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
this.cortexUsecases,
);
await chatCliUsecases.chat(modelId);
await this.chatCliUsecases.chat(modelId);
}

@Option({
@@ -2,12 +2,12 @@ import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCompletionRole } from '@/domain/models/message.interface';
import { exit, stdin, stdout } from 'node:process';
import * as readline from 'node:readline/promises';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';
import { ChatCompletionMessage } from '@/infrastructure/dtos/chat/chat-completion-message.dto';
import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { Injectable } from '@nestjs/common';

// TODO: make this class injectable
@Injectable()
export class ChatCliUsecases {
private exitClause = 'exit()';
private userIndicator = '>> ';
@@ -59,26 +59,44 @@
top_p: 0.7,
};

let llmFullResponse = '';
const writableStream = new WritableStream<ChatStreamEvent>({
write(chunk) {
if (chunk.type === 'data') {
stdout.write(chunk.data ?? '');
llmFullResponse += chunk.data ?? '';
} else if (chunk.type === 'error') {
console.log('Error!!');
} else {
messages.push({
content: llmFullResponse,
role: ChatCompletionRole.Assistant,
});
llmFullResponse = '';
console.log('\n');
const decoder = new TextDecoder('utf-8');
this.chatUsecases.inferenceStream(chatDto, {}).then((response) => {
response.on('error', (error) => {
console.error(error);
rl.prompt();
});

response.on('end', () => {
console.log('\n');
rl.prompt();
});

response.on('data', (chunk) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content.trim().length > 0) {
stdout.write(content);
}
}
} catch {
cachedLines = line;
}
}
},
});
});

this.chatUsecases.createChatCompletions(chatDto, {}, writableStream);
});
}
}
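
The cachedLines variable above buffers an SSE line whose JSON fails to parse (for example when a payload is split across lines within a chunk) and prepends it to the next line before retrying. An isolated sketch of that parsing approach, with illustrative names rather than an exact extract of the PR:

function parseSseChunk(text: string, emit: (delta: string) => void): void {
  let cachedLines = '';
  for (const line of text.trim().split('\n')) {
    if (line.includes('data: [DONE]')) continue; // end-of-stream sentinel
    try {
      const data = JSON.parse((cachedLines + line).replace('data: ', ''));
      emit(data.choices[0]?.delta?.content ?? '');
      cachedLines = '';
    } catch {
      // Incomplete JSON fragment: keep it and retry with the next line.
      cachedLines = line;
    }
  }
}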
@@ -3,11 +3,13 @@ import { InitCliUsecases } from './init.cli.usecases';
import { HttpModule } from '@nestjs/axios';
import { ModelsCliUsecases } from './models.cli.usecases';
import { ModelsModule } from '@/usecases/models/models.module';
import { ChatCliUsecases } from './chat.cli.usecases';
import { ChatModule } from '@/usecases/chat/chat.module';
import { CortexModule } from '@/usecases/cortex/cortex.module';

@Module({
imports: [HttpModule, ModelsModule],
controllers: [],
providers: [InitCliUsecases, ModelsCliUsecases],
exports: [InitCliUsecases, ModelsCliUsecases],
imports: [HttpModule, ModelsModule, ChatModule, CortexModule],
providers: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases],
exports: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases],
})
export class CliUsecasesModule {}