Skip to content

Commit

Permalink
chore: cortex chat enhancement
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-menlo committed May 30, 2024
1 parent dd551e0 commit d9da79d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 53 deletions.
48 changes: 40 additions & 8 deletions cortex-js/src/infrastructure/commanders/chat.command.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import {
CommandRunner,
SubCommand,
Option,
InquirerService,
} from 'nest-commander';
import { ChatCliUsecases } from './usecases/chat.cli.usecases';
import { exit } from 'node:process';
import { PSCliUsecases } from './usecases/ps.cli.usecases';
import { ModelsUsecases } from '@/usecases/models/models.usecases';

type ChatOptions = {
threadId?: string;
Expand All @@ -10,22 +17,47 @@ type ChatOptions = {

@SubCommand({ name: 'chat', description: 'Send a chat request to a model' })
export class ChatCommand extends CommandRunner {
constructor(private readonly chatCliUsecases: ChatCliUsecases) {
constructor(
private readonly inquirerService: InquirerService,
private readonly chatCliUsecases: ChatCliUsecases,
private readonly modelsUsecases: ModelsUsecases,
private readonly psCliUsecases: PSCliUsecases,
) {
super();
}

async run(_input: string[], options: ChatOptions): Promise<void> {
const modelId = _input[0];
if (!modelId) {
console.error('Model ID is required');
exit(1);
let modelId = _input[0];
let message = _input[1] ?? options.message;
if (!modelId || !(await this.modelsUsecases.findOne(modelId))) {
message = _input[0] ?? options.message;
// Check for running models
const models = await this.psCliUsecases.getModels();
if (models.length === 1) {
modelId = models[0].modelId;
} else if (models.length > 0) {
const { model } = await this.inquirerService.inquirer.prompt({
type: 'list',
name: 'model',
message: 'Select running model to chat with:',
choices: models.map((e) => ({
name: e.modelId,
value: e.modelId,
})),
});
modelId = model;
} else {
console.error('Model ID is required');
exit(1);
}
}

return this.chatCliUsecases.chat(
modelId,
options.threadId,
options.message,
message, // Accept both message from inputs or arguments
options.attach,
false, // Do not stop cortex session or loaded model
);
}

Expand All @@ -39,8 +71,8 @@ export class ChatCommand extends CommandRunner {

@Option({
flags: '-m, --message <message>',
defaultValue: undefined,
description: 'Message to send to the model',
required: true,
})
parseModelId(value: string) {
return value;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
import { CommandRunner, SubCommand } from 'nest-commander';
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { ModelsCliUsecases } from '../usecases/models.cli.usecases';

interface ModelListOptions {
format: 'table' | 'json';
}
@SubCommand({ name: 'list', description: 'List all models locally.' })
export class ModelListCommand extends CommandRunner {
constructor(private readonly modelsCliUsecases: ModelsCliUsecases) {
super();
}

async run(): Promise<void> {
async run(_input: string[], option: ModelListOptions): Promise<void> {
const models = await this.modelsCliUsecases.listAllModels();
console.log(models);
option.format === 'table' ? console.table(models) : console.log(models);
}

@Option({
flags: '-f, --format <format>',
defaultValue: 'json',
description: 'Print models list in table or json format',
})
parseModelId(value: string) {
return value;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export class ModelStopCommand extends CommandRunner {

await this.modelsCliUsecases
.stopModel(input[0])
.then(() => this.cortexUsecases.stopCortex())
.then(() => this.modelsCliUsecases.stopModel(input[0]))
.then(console.log);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,47 +36,12 @@ export class ChatCliUsecases {
private readonly messagesUsecases: MessagesUsecases,
) {}

private async getOrCreateNewThread(
modelId: string,
threadId?: string,
): Promise<Thread> {
if (threadId) {
const thread = await this.threadUsecases.findOne(threadId);
if (!thread) throw new Error(`Cannot find thread with id: ${threadId}`);
return thread;
}

const model = await this.modelsUsecases.findOne(modelId);
if (!model) throw new Error(`Cannot find model with id: ${modelId}`);

const assistant = await this.assistantUsecases.findOne('jan');
if (!assistant) throw new Error('No assistant available');

const createThreadModel: CreateThreadModelInfoDto = {
id: modelId,
settings: model.settings,
parameters: model.parameters,
};

const assistantDto: CreateThreadAssistantDto = {
assistant_id: assistant.id,
assistant_name: assistant.name,
model: createThreadModel,
};

const createThreadDto: CreateThreadDto = {
title: 'New Thread',
assistants: [assistantDto],
};

return this.threadUsecases.create(createThreadDto);
}

async chat(
modelId: string,
threadId?: string,
message?: string,
attach: boolean = true,
stopModel: boolean = true,
): Promise<void> {
if (attach) console.log(`Inorder to exit, type '${this.exitClause}'.`);
const thread = await this.getOrCreateNewThread(modelId, threadId);
Expand All @@ -95,11 +60,10 @@ export class ChatCliUsecases {
if (message) sendCompletionMessage.bind(this)(message);
if (attach) rl.prompt();

rl.on('close', () => {
this.cortexUsecases.stopCortex().then(() => {
if (attach) console.log(this.exitMessage);
exit(0);
});
rl.on('close', async () => {
if (stopModel) await this.modelsUsecases.stopModel(modelId);
if (attach) console.log(this.exitMessage);
exit(0);
});

rl.on('line', sendCompletionMessage.bind(this));
Expand Down Expand Up @@ -213,4 +177,40 @@ export class ChatCliUsecases {
});
}
}

private async getOrCreateNewThread(
modelId: string,
threadId?: string,
): Promise<Thread> {
if (threadId) {
const thread = await this.threadUsecases.findOne(threadId);
if (!thread) throw new Error(`Cannot find thread with id: ${threadId}`);
return thread;
}

const model = await this.modelsUsecases.findOne(modelId);
if (!model) throw new Error(`Cannot find model with id: ${modelId}`);

const assistant = await this.assistantUsecases.findOne('jan');
if (!assistant) throw new Error('No assistant available');

const createThreadModel: CreateThreadModelInfoDto = {
id: modelId,
settings: model.settings,
parameters: model.parameters,
};

const assistantDto: CreateThreadAssistantDto = {
assistant_id: assistant.id,
assistant_name: assistant.name,
model: createThreadModel,
};

const createThreadDto: CreateThreadDto = {
title: 'New Thread',
assistants: [assistantDto],
};

return this.threadUsecases.create(createThreadDto);
}
}

0 comments on commit d9da79d

Please sign in to comment.