lazy load llama #220

Merged: 1 commit, Nov 7, 2024
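In short, this change moves model initialization out of the constructor and behind an ensureInitialized() guard that each public method awaits, so the GGUF download and load only happen on first use. Below is a condensed sketch of that pattern as it appears in the diff; field names and bodies not visible in the diff are illustrative placeholders, not the full class.

```ts
// Sketch of the lazy-initialization guard introduced by this PR (condensed, not the full class).
class LlamaService {
    private static instance: LlamaService | null = null;
    private modelInitialized = false;

    private constructor() {
        // No model work happens here any more; only paths/URLs are recorded.
    }

    public static getInstance(): LlamaService {
        if (!LlamaService.instance) {
            LlamaService.instance = new LlamaService();
        }
        return LlamaService.instance;
    }

    // Every public entry point awaits this before touching the model,
    // so the download/load runs on first use instead of at construction.
    private async ensureInitialized(): Promise<void> {
        if (!this.modelInitialized) {
            await this.initializeModel();
        }
    }

    async initializeModel(): Promise<void> {
        // Download the model if missing, load it, create the context (omitted here),
        // then mark the service ready.
        this.modelInitialized = true;
    }

    async getEmbeddingResponse(input: string): Promise<number[] | undefined> {
        await this.ensureInitialized();
        // Tokenize `input` and compute embeddings (omitted in this sketch).
        return undefined;
    }
}
```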
55 changes: 19 additions & 36 deletions core/src/services/llama.ts
@@ -15,6 +15,7 @@ import fs from "fs";
import https from "https";
import si from "systeminformation";
import { wordsToPunish } from "./wordsToPunish.ts";
import { prettyConsole } from "../index.ts";

const __dirname = path.dirname(fileURLToPath(import.meta.url));

@@ -67,28 +68,25 @@ class LlamaService {
private modelInitialized: boolean = false;

private constructor() {
console.log("Constructing");
this.llama = undefined;
this.model = undefined;
this.modelUrl =
"https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B-GGUF/resolve/main/Hermes-3-Llama-3.1-8B.Q8_0.gguf?download=true";
const modelName = "model.gguf";
console.log("modelName", modelName);
this.modelPath = path.join(__dirname, modelName);
try {
this.initializeModel();
} catch (error) {
console.error("Error initializing model", error);

}
private async ensureInitialized() {
if (!this.modelInitialized) {
await this.initializeModel();
}
}

public static getInstance(): LlamaService {
if (!LlamaService.instance) {
LlamaService.instance = new LlamaService();
}
return LlamaService.instance;
}

async initializeModel() {
try {
await this.checkModel();
@@ -99,30 +97,26 @@ class LlamaService {
);

if (hasCUDA) {
console.log("**** CUDA detected");
console.log("**** LlamaService: CUDA detected");
} else {
console.log(
"**** No CUDA detected - local response will be slow"
console.warn(
"**** LlamaService: No CUDA detected - local response will be slow"
);
}

this.llama = await getLlama({
gpu: "cuda",
});
console.log("Creating grammar");
const grammar = new LlamaJsonSchemaGrammar(
this.llama,
jsonSchemaGrammar as GbnfJsonSchema
);
this.grammar = grammar;
console.log("Loading model");
console.log("this.modelPath", this.modelPath);

this.model = await this.llama.loadModel({
modelPath: this.modelPath,
});
console.log("Model GPU support", this.llama.getGpuDeviceNames());
console.log("Creating context");

this.ctx = await this.model.createContext({ contextSize: 8192 });
this.sequence = this.ctx.getSequence();

@@ -139,11 +133,7 @@
}

async checkModel() {
console.log("Checking model");
if (!fs.existsSync(this.modelPath)) {
console.log("this.modelPath", this.modelPath);
console.log("Model not found. Downloading...");

await new Promise<void>((resolve, reject) => {
const file = fs.createWriteStream(this.modelPath);
let downloadedSize = 0;
@@ -157,14 +147,9 @@
if (isRedirect) {
const redirectUrl = response.headers.location;
if (redirectUrl) {
console.log(
"Following redirect to:",
redirectUrl
);
downloadModel(redirectUrl);
return;
} else {
console.error("Redirect URL not found");
reject(new Error("Redirect URL not found"));
return;
}
@@ -191,7 +176,6 @@

response.on("end", () => {
file.end();
console.log("\nModel downloaded successfully.");
resolve();
});
})
@@ -211,14 +195,13 @@
});
});
} else {
console.log("Model already exists.");
prettyConsole.warn("Model already exists.");
}
}

async deleteModel() {
if (fs.existsSync(this.modelPath)) {
fs.unlinkSync(this.modelPath);
console.log("Model deleted.");
}
}

@@ -230,7 +213,7 @@
presence_penalty: number,
max_tokens: number
): Promise<any> {
console.log("Queueing message generateText");
await this.ensureInitialized();
return new Promise((resolve, reject) => {
this.messageQueue.push({
context,
@@ -255,13 +238,15 @@
presence_penalty: number,
max_tokens: number
): Promise<string> {
await this.ensureInitialized();

return new Promise((resolve, reject) => {
this.messageQueue.push({
context,
temperature,
stop,
frequency_penalty,
presence_penalty,
frequency_penalty: frequency_penalty ?? 1.0,
presence_penalty: presence_penalty ?? 1.0,
max_tokens,
useGrammar: false,
resolve,
@@ -286,7 +271,6 @@
const message = this.messageQueue.shift();
if (message) {
try {
console.log("Processing message");
const response = await this.getCompletionResponse(
message.context,
message.temperature,
@@ -334,7 +318,7 @@
};

const responseTokens: Token[] = [];
console.log("Evaluating tokens");

for await (const token of this.sequence.evaluate(tokens, {
temperature: Number(temperature),
repeatPenalty: repeatPenalty,
@@ -374,7 +358,6 @@
// try parsing response as JSON
try {
jsonString = JSON.stringify(JSON.parse(response));
console.log("parsedResponse", jsonString);
} catch {
throw new Error("JSON string not found");
}
@@ -384,20 +367,19 @@
if (!parsedResponse) {
throw new Error("Parsed response is undefined");
}
console.log("AI: " + parsedResponse.content);
await this.sequence.clearHistory();
return parsedResponse;
} catch (error) {
console.error("Error parsing JSON:", error);
}
} else {
console.log("AI: " + response);
await this.sequence.clearHistory();
return response;
}
}

async getEmbeddingResponse(input: string): Promise<number[] | undefined> {
await this.ensureInitialized();
if (!this.model) {
throw new Error("Model not initialized. Call initialize() first.");
}
@@ -409,3 +391,4 @@
}

export default LlamaService;
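For illustration, here is a hypothetical caller after this change; the import path and top-level await are assumptions, not part of the PR, while getInstance() and getEmbeddingResponse() come from the diff above.

```ts
import LlamaService from "./services/llama.ts"; // path assumed; the file in this PR is core/src/services/llama.ts

// Getting the singleton is now cheap: no model download or load happens here.
const llama = LlamaService.getInstance();

// The first call that needs the model awaits ensureInitialized(), which
// downloads the GGUF file if missing and loads it exactly once.
const embedding = await llama.getEmbeddingResponse("hello world");
console.log(embedding?.length);
```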