From d9da79dce301ceeefa30a7a2cab5e192d6d968c9 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 30 May 2024 11:41:15 +0700 Subject: [PATCH] chore: cortex chat enhancement --- .../infrastructure/commanders/chat.command.ts | 48 +++++++++-- .../commanders/models/model-list.command.ts | 18 +++- .../commanders/models/model-stop.command.ts | 2 +- .../commanders/usecases/chat.cli.usecases.ts | 82 +++++++++---------- 4 files changed, 97 insertions(+), 53 deletions(-) diff --git a/cortex-js/src/infrastructure/commanders/chat.command.ts b/cortex-js/src/infrastructure/commanders/chat.command.ts index 8f3e5c4e7..99bf3bade 100644 --- a/cortex-js/src/infrastructure/commanders/chat.command.ts +++ b/cortex-js/src/infrastructure/commanders/chat.command.ts @@ -1,6 +1,13 @@ -import { CommandRunner, SubCommand, Option } from 'nest-commander'; +import { + CommandRunner, + SubCommand, + Option, + InquirerService, +} from 'nest-commander'; import { ChatCliUsecases } from './usecases/chat.cli.usecases'; import { exit } from 'node:process'; +import { PSCliUsecases } from './usecases/ps.cli.usecases'; +import { ModelsUsecases } from '@/usecases/models/models.usecases'; type ChatOptions = { threadId?: string; @@ -10,22 +17,47 @@ type ChatOptions = { @SubCommand({ name: 'chat', description: 'Send a chat request to a model' }) export class ChatCommand extends CommandRunner { - constructor(private readonly chatCliUsecases: ChatCliUsecases) { + constructor( + private readonly inquirerService: InquirerService, + private readonly chatCliUsecases: ChatCliUsecases, + private readonly modelsUsecases: ModelsUsecases, + private readonly psCliUsecases: PSCliUsecases, + ) { super(); } async run(_input: string[], options: ChatOptions): Promise { - const modelId = _input[0]; - if (!modelId) { - console.error('Model ID is required'); - exit(1); + let modelId = _input[0]; + let message = _input[1] ?? options.message; + if (!modelId || !(await this.modelsUsecases.findOne(modelId))) { + message = _input[0] ?? options.message; + // Check for running models + const models = await this.psCliUsecases.getModels(); + if (models.length === 1) { + modelId = models[0].modelId; + } else if (models.length > 0) { + const { model } = await this.inquirerService.inquirer.prompt({ + type: 'list', + name: 'model', + message: 'Select running model to chat with:', + choices: models.map((e) => ({ + name: e.modelId, + value: e.modelId, + })), + }); + modelId = model; + } else { + console.error('Model ID is required'); + exit(1); + } } return this.chatCliUsecases.chat( modelId, options.threadId, - options.message, + message, // Accept both message from inputs or arguments options.attach, + false, // Do not stop cortex session or loaded model ); } @@ -39,8 +71,8 @@ export class ChatCommand extends CommandRunner { @Option({ flags: '-m, --message ', + defaultValue: undefined, description: 'Message to send to the model', - required: true, }) parseModelId(value: string) { return value; diff --git a/cortex-js/src/infrastructure/commanders/models/model-list.command.ts b/cortex-js/src/infrastructure/commanders/models/model-list.command.ts index 6e491fc8d..052e8feec 100644 --- a/cortex-js/src/infrastructure/commanders/models/model-list.command.ts +++ b/cortex-js/src/infrastructure/commanders/models/model-list.command.ts @@ -1,14 +1,26 @@ -import { CommandRunner, SubCommand } from 'nest-commander'; +import { CommandRunner, SubCommand, Option } from 'nest-commander'; import { ModelsCliUsecases } from '../usecases/models.cli.usecases'; +interface ModelListOptions { + format: 'table' | 'json'; +} @SubCommand({ name: 'list', description: 'List all models locally.' }) export class ModelListCommand extends CommandRunner { constructor(private readonly modelsCliUsecases: ModelsCliUsecases) { super(); } - async run(): Promise { + async run(_input: string[], option: ModelListOptions): Promise { const models = await this.modelsCliUsecases.listAllModels(); - console.log(models); + option.format === 'table' ? console.table(models) : console.log(models); + } + + @Option({ + flags: '-f, --format ', + defaultValue: 'json', + description: 'Print models list in table or json format', + }) + parseModelId(value: string) { + return value; } } diff --git a/cortex-js/src/infrastructure/commanders/models/model-stop.command.ts b/cortex-js/src/infrastructure/commanders/models/model-stop.command.ts index f13f2021c..3168e9b7a 100644 --- a/cortex-js/src/infrastructure/commanders/models/model-stop.command.ts +++ b/cortex-js/src/infrastructure/commanders/models/model-stop.command.ts @@ -20,7 +20,7 @@ export class ModelStopCommand extends CommandRunner { await this.modelsCliUsecases .stopModel(input[0]) - .then(() => this.cortexUsecases.stopCortex()) + .then(() => this.modelsCliUsecases.stopModel(input[0])) .then(console.log); } } diff --git a/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts b/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts index 5147b2e1a..2ef890d49 100644 --- a/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts +++ b/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts @@ -36,47 +36,12 @@ export class ChatCliUsecases { private readonly messagesUsecases: MessagesUsecases, ) {} - private async getOrCreateNewThread( - modelId: string, - threadId?: string, - ): Promise { - if (threadId) { - const thread = await this.threadUsecases.findOne(threadId); - if (!thread) throw new Error(`Cannot find thread with id: ${threadId}`); - return thread; - } - - const model = await this.modelsUsecases.findOne(modelId); - if (!model) throw new Error(`Cannot find model with id: ${modelId}`); - - const assistant = await this.assistantUsecases.findOne('jan'); - if (!assistant) throw new Error('No assistant available'); - - const createThreadModel: CreateThreadModelInfoDto = { - id: modelId, - settings: model.settings, - parameters: model.parameters, - }; - - const assistantDto: CreateThreadAssistantDto = { - assistant_id: assistant.id, - assistant_name: assistant.name, - model: createThreadModel, - }; - - const createThreadDto: CreateThreadDto = { - title: 'New Thread', - assistants: [assistantDto], - }; - - return this.threadUsecases.create(createThreadDto); - } - async chat( modelId: string, threadId?: string, message?: string, attach: boolean = true, + stopModel: boolean = true, ): Promise { if (attach) console.log(`Inorder to exit, type '${this.exitClause}'.`); const thread = await this.getOrCreateNewThread(modelId, threadId); @@ -95,11 +60,10 @@ export class ChatCliUsecases { if (message) sendCompletionMessage.bind(this)(message); if (attach) rl.prompt(); - rl.on('close', () => { - this.cortexUsecases.stopCortex().then(() => { - if (attach) console.log(this.exitMessage); - exit(0); - }); + rl.on('close', async () => { + if (stopModel) await this.modelsUsecases.stopModel(modelId); + if (attach) console.log(this.exitMessage); + exit(0); }); rl.on('line', sendCompletionMessage.bind(this)); @@ -213,4 +177,40 @@ export class ChatCliUsecases { }); } } + + private async getOrCreateNewThread( + modelId: string, + threadId?: string, + ): Promise { + if (threadId) { + const thread = await this.threadUsecases.findOne(threadId); + if (!thread) throw new Error(`Cannot find thread with id: ${threadId}`); + return thread; + } + + const model = await this.modelsUsecases.findOne(modelId); + if (!model) throw new Error(`Cannot find model with id: ${modelId}`); + + const assistant = await this.assistantUsecases.findOne('jan'); + if (!assistant) throw new Error('No assistant available'); + + const createThreadModel: CreateThreadModelInfoDto = { + id: modelId, + settings: model.settings, + parameters: model.parameters, + }; + + const assistantDto: CreateThreadAssistantDto = { + assistant_id: assistant.id, + assistant_name: assistant.name, + model: createThreadModel, + }; + + const createThreadDto: CreateThreadDto = { + title: 'New Thread', + assistants: [assistantDto], + }; + + return this.threadUsecases.create(createThreadDto); + } }