feat: add model settings and prompt template from hf #588

Merged: 1 commit, May 21, 2024
2 changes: 2 additions & 0 deletions cortex-js/src/command.module.ts
@@ -21,6 +21,7 @@ import { ModelRemoveCommand } from './infrastructure/commanders/models/model-rem
import { RunCommand } from './infrastructure/commanders/shortcuts/run.command';
import { InitCudaQuestions } from './infrastructure/commanders/questions/cuda.questions';
import { CliUsecasesModule } from './infrastructure/commanders/usecases/cli.usecases.module';
import { ModelUpdateCommand } from './infrastructure/commanders/models/model-update.command';

@Module({
imports: [
@@ -55,6 +56,7 @@ import { CliUsecasesModule } from './infrastructure/commanders/usecases/cli.usec
ModelGetCommand,
ModelRemoveCommand,
ModelPullCommand,
ModelUpdateCommand,

// Shortcuts
RunCommand,
2 changes: 1 addition & 1 deletion cortex-js/src/infrastructure/commanders/chat.command.ts
@@ -32,7 +32,7 @@ export class ChatCommand extends CommandRunner {
}

@Option({
flags: '--model <model_id>',
flags: '-m, --model <model_id>',
description: 'Model Id to start chat with',
})
parseModelId(value: string) {
2 changes: 2 additions & 0 deletions cortex-js/src/infrastructure/commanders/models.command.ts
@@ -5,6 +5,7 @@ import { ModelListCommand } from './models/model-list.command';
import { ModelStopCommand } from './models/model-stop.command';
import { ModelPullCommand } from './models/model-pull.command';
import { ModelRemoveCommand } from './models/model-remove.command';
import { ModelUpdateCommand } from './models/model-update.command';

@SubCommand({
name: 'models',
@@ -15,6 +16,7 @@ import { ModelRemoveCommand } from './models/model-remove.command';
ModelListCommand,
ModelGetCommand,
ModelRemoveCommand,
ModelUpdateCommand,
],
description: 'Subcommands for managing models',
})
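For `cortex models update` to resolve, ModelUpdateCommand has to be registered twice: in the CLI module's providers (command.module.ts above) so Nest can instantiate it, and in ModelsCommand's subCommands list (models.command.ts above). Below is a self-contained sketch of that wiring, assuming nest-commander's standard CommandFactory bootstrap; the file layout and top-level @Command are illustrative, since in this PR ModelsCommand is itself a @SubCommand of the root cortex command.

import { Module } from '@nestjs/common';
import {
  Command,
  CommandFactory,
  CommandRunner,
  SubCommand,
} from 'nest-commander';

@SubCommand({ name: 'update', description: 'Update configuration of a model.' })
class ModelUpdateCommand extends CommandRunner {
  async run(): Promise<void> {
    console.log('models update invoked');
  }
}

// Parent command: `subCommands` is what routes `models update` to the class above.
@Command({
  name: 'models',
  description: 'Subcommands for managing models',
  subCommands: [ModelUpdateCommand],
})
class ModelsCommand extends CommandRunner {
  async run(): Promise<void> {}
}

// Both classes must also be providers so the DI container can construct them.
@Module({ providers: [ModelsCommand, ModelUpdateCommand] })
class CliModule {}

async function bootstrap() {
  await CommandFactory.run(CliModule);
}
bootstrap();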
90 changes: 90 additions & 0 deletions cortex-js/src/infrastructure/commanders/models/model-update.command.ts
@@ -0,0 +1,90 @@
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
import { exit } from 'node:process';
import { ModelParameterParser } from '../utils/model-parameter.parser';
import {
  ModelRuntimeParams,
  ModelSettingParams,
} from '@/domain/models/model.interface';

type UpdateOptions = {
  model?: string;
  options?: string[];
};

@SubCommand({ name: 'update', description: 'Update configuration of a model.' })
export class ModelUpdateCommand extends CommandRunner {
  constructor(private readonly modelsCliUsecases: ModelsCliUsecases) {
    super();
  }

  async run(_input: string[], option: UpdateOptions): Promise<void> {
    const modelId = option.model;
    if (!modelId) {
      console.error('Model Id is required');
      exit(1);
    }

    const options = option.options;
    if (!options || options.length === 0) {
      console.log('Nothing to update');
      exit(0);
    }

    const parser = new ModelParameterParser();
    const settingParams: ModelSettingParams = {};
    const runtimeParams: ModelRuntimeParams = {};

    options.forEach((option) => {
      const [key, stringValue] = option.split('=');
      if (parser.isModelSettingParam(key)) {
        const value = parser.parse(key, stringValue);
        // @ts-expect-error did the check so it's safe
        settingParams[key] = value;
      } else if (parser.isModelRuntimeParam(key)) {
        const value = parser.parse(key, stringValue);
        // @ts-expect-error did the check so it's safe
        runtimeParams[key] = value;
      }
    });

    if (Object.keys(settingParams).length > 0) {
      const updatedSettingParams =
        await this.modelsCliUsecases.updateModelSettingParams(
          modelId,
          settingParams,
        );
      console.log(
        'Updated setting params! New setting params:',
        updatedSettingParams,
      );
    }

    if (Object.keys(runtimeParams).length > 0) {
      await this.modelsCliUsecases.updateModelRuntimeParams(
        modelId,
        runtimeParams,
      );
      console.log('Updated runtime params! New runtime params:', runtimeParams);
    }
  }

  @Option({
    flags: '-m, --model <model_id>',
    required: true,
    description: 'Model Id to update',
  })
  parseModelId(value: string) {
    return value;
  }

  @Option({
    flags: '-c, --options <options...>',
    description:
      'Specify the options to update the model. Syntax: -c option1=value1 option2=value2. For example: cortex models update -c max_tokens=100 temperature=0.5',
  })
  parseOptions(option: string, optionsAccumulator: string[] = []): string[] {
    optionsAccumulator.push(option);
    return optionsAccumulator;
  }
}
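The `-c/--options` flag uses nest-commander's variadic `<options...>` syntax, so parseOptions runs once per token and accumulates them into an array; run() then splits each token on `=` and routes the key to either setting or runtime params via ModelParameterParser. A standalone sketch of that routing follows; the parser's real key lists are not shown in this diff, so the two sets below are illustrative assumptions only.

// Illustrative routing of `-c key=value` tokens, mirroring ModelUpdateCommand.run().
// The authoritative setting/runtime key lists live in ModelParameterParser;
// these two sets are assumptions for the sketch.
const settingKeys = new Set(['ctx_len', 'ngl', 'prompt_template']);
const runtimeKeys = new Set(['temperature', 'max_tokens', 'top_p']);

function routeOptions(tokens: string[]): {
  settings: Record<string, string>;
  runtime: Record<string, string>;
} {
  const settings: Record<string, string> = {};
  const runtime: Record<string, string> = {};
  for (const token of tokens) {
    const [key, value] = token.split('='); // same split the command performs
    if (settingKeys.has(key)) settings[key] = value;
    else if (runtimeKeys.has(key)) runtime[key] = value;
    // unrecognized keys are silently ignored, as in the command itself
  }
  return { settings, runtime };
}

// routeOptions(['max_tokens=100', 'temperature=0.5'])
// => { settings: {}, runtime: { max_tokens: '100', temperature: '0.5' } }

One caveat: `split('=')` keeps only the segment between the first and second `=`, so a value that itself contains `=` gets truncated; splitting on the first `=` only would be more robust.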
37 changes: 37 additions & 0 deletions cortex-js/src/infrastructure/commanders/prompt-constants.ts
@@ -0,0 +1,37 @@
//// HF chat templates
export const OPEN_CHAT_3_5_JINJA = ``;

export const ZEPHYR_JINJA = `{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>
' + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}`;

//// Corresponding prompt templates
export const OPEN_CHAT_3_5 = `GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:`;

export const ZEPHYR = `<|system|>
{system_message}</s>
<|user|>
{prompt}</s>
<|assistant|>
`;

export const COMMAND_R = `<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{response}
`;

// taken from https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGUF
export const LLAMA_2 = `[INST] <<SYS>>
{system_message}
<</SYS>>
{prompt}[/INST]`;
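These flat templates are hand-written counterparts of the Hugging Face Jinja templates above: where the Jinja version loops over a messages array, the flat form keeps {system_message} and {prompt} placeholders to be substituted at inference time. A minimal sketch of that substitution; applyPromptTemplate is a hypothetical helper, not part of this PR.

// Hypothetical helper showing how a flat template such as ZEPHYR would be filled.
function applyPromptTemplate(
  template: string,
  systemMessage: string,
  prompt: string,
): string {
  return template
    .replace('{system_message}', systemMessage)
    .replace('{prompt}', prompt);
}

// applyPromptTemplate(ZEPHYR, 'You are a helpful assistant.', 'Hi!') yields:
// <|system|>
// You are a helpful assistant.</s>
// <|user|>
// Hi!</s>
// <|assistant|>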
cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts
@@ -4,6 +4,7 @@ import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { exit } from 'node:process';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCliUsecases } from '../usecases/chat.cli.usecases';
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

type RunOptions = {
model?: string;
@@ -29,7 +30,11 @@ export class RunCommand extends CommandRunner {
exit(1);
}

await this.cortexUsecases.startCortex();
await this.cortexUsecases.startCortex(
defaultCortexCppHost,
defaultCortexCppPort,
false,
);
await this.modelsUsecases.startModel(modelId);
const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
@@ -39,7 +44,7 @@
}

@Option({
flags: '--model <model_id>',
flags: '-m, --model <model_id>',
description: 'Model Id to start chat with',
})
parseModelId(value: string) {
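The shortcut previously called startCortex() with no arguments; it now passes the host and port defaults imported from `constant`, plus a trailing boolean. CortexUsecases is not shown in this diff, so the signature below is an assumption, and the parameter name for the boolean is hypothetical, purely for illustration.

// Assumed shape of CortexUsecases.startCortex (not shown in this PR).
// `attach` is a hypothetical name for the trailing boolean that RunCommand
// now passes as `false`.
declare function startCortex(
  host: string, // defaultCortexCppHost from 'constant'
  port: number, // defaultCortexCppPort from 'constant'
  attach: boolean,
): Promise<void>;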
cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts
@@ -1,12 +1,24 @@
import { exit } from 'node:process';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { Model, ModelFormat } from '@/domain/models/model.interface';
import {
Model,
ModelFormat,
ModelRuntimeParams,
ModelSettingParams,
} from '@/domain/models/model.interface';
import { CreateModelDto } from '@/infrastructure/dtos/models/create-model.dto';
import { HuggingFaceRepoData } from '@/domain/models/huggingface.interface';
import { gguf } from '@huggingface/gguf';
import { InquirerService } from 'nest-commander';
import { Inject, Injectable } from '@nestjs/common';
import { Presets, SingleBar } from 'cli-progress';
import {
LLAMA_2,
OPEN_CHAT_3_5,
OPEN_CHAT_3_5_JINJA,
ZEPHYR,
ZEPHYR_JINJA,
} from '../prompt-constants';

const AllQuantizations = [
'Q3_K_S',
@@ -49,6 +61,20 @@ export class ModelsCliUsecases {
await this.modelsUsecases.stopModel(modelId);
}

async updateModelSettingParams(
modelId: string,
settingParams: ModelSettingParams,
): Promise<ModelSettingParams> {
return this.modelsUsecases.updateModelSettingParams(modelId, settingParams);
}

async updateModelRuntimeParams(
modelId: string,
runtimeParams: ModelRuntimeParams,
): Promise<ModelRuntimeParams> {
return this.modelsUsecases.updateModelRuntimeParams(modelId, runtimeParams);
}

private async getModelOrStop(modelId: string): Promise<Model> {
const model = await this.modelsUsecases.findOne(modelId);
if (!model) {
@@ -103,10 +129,16 @@
if (!sibling) throw 'No expected quantization found';

let stopWord = '';
let promptTemplate = LLAMA_2;

try {
const { metadata } = await gguf(sibling.downloadUrl!);
// @ts-expect-error "tokenizer.ggml.eos_token_id"
const index = metadata['tokenizer.ggml.eos_token_id'];
// @ts-expect-error "tokenizer.ggml.eos_token_id"
const hfChatTemplate = metadata['tokenizer.chat_template'];
promptTemplate = this.guessPromptTemplateFromHuggingFace(hfChatTemplate);

// @ts-expect-error "tokenizer.ggml.tokens"
stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
} catch (err) {
@@ -129,7 +161,9 @@
version: '',
format: ModelFormat.GGUF,
description: '',
settings: {},
settings: {
prompt_template: promptTemplate,
},
parameters: {
stop: stopWords,
},
@@ -144,6 +178,37 @@
await this.modelsUsecases.create(model);
}

// TODO: move this to somewhere else, should be reused by API as well. Maybe in a separate service / provider?
private guessPromptTemplateFromHuggingFace(jinjaCode?: string): string {
if (!jinjaCode) {
console.log('No jinja code provided. Returning default LLAMA_2');
return LLAMA_2;
}

if (typeof jinjaCode !== 'string') {
console.log(
`Invalid jinja code provided (type is ${typeof jinjaCode}). Returning default LLAMA_2`,
);
return LLAMA_2;
}

switch (jinjaCode) {
case ZEPHYR_JINJA:
return ZEPHYR;

case OPEN_CHAT_3_5_JINJA:
return OPEN_CHAT_3_5;

default:
console.log(
'Unknown jinja code:',
jinjaCode,
'Returning default LLAMA_2',
);
return LLAMA_2;
}
}

private async fetchHuggingFaceRepoData(repoId: string) {
const sanitizedUrl = this.toHuggingFaceUrl(repoId);

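Taken together, the pull flow now probes the GGUF header remotely and guesses a prompt template before the model is downloaded. `gguf(url)` from @huggingface/gguf fetches only the header (via HTTP range requests), and the embedded tokenizer.chat_template is matched against the known Jinja strings by exact equality. A standalone sketch, importing the constants from prompt-constants.ts above; the download URL is a placeholder.

import { gguf } from '@huggingface/gguf';
import { LLAMA_2, ZEPHYR, ZEPHYR_JINJA } from './prompt-constants';

async function probePromptTemplate(downloadUrl: string): Promise<string> {
  // Reads just the GGUF header, not the model weights.
  const { metadata } = await gguf(downloadUrl);
  const meta = metadata as Record<string, unknown>;
  const chatTemplate = meta['tokenizer.chat_template'] as string | undefined;
  // Exact string comparison, falling back to LLAMA_2 as the PR does.
  return chatTemplate === ZEPHYR_JINJA ? ZEPHYR : LLAMA_2;
}

Exact equality makes the guess brittle: a template differing only in whitespace falls through to LLAMA_2, and because OPEN_CHAT_3_5_JINJA is currently an empty string, the `!jinjaCode` guard in guessPromptTemplateFromHuggingFace returns LLAMA_2 before the OpenChat case can ever match.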