Merge pull request #777 from ai16z/shaw/refactor-image-interface
Fix: Refactor image interface and update to move llama cloud -> together provider
lalalune authored Dec 2, 2024
2 parents dadef5b + 45d3c8f commit 1ae26c3
Showing 7 changed files with 139 additions and 74 deletions.
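In practical terms, this commit promotes Together to a first-class model provider while keeping the legacy llama_cloud identifier working for backwards compatibility. A minimal sketch of a character opting into the new provider (field names follow the existing character schema; the LLAMACLOUD_API_KEY fallback is taken from the agent/src/index.ts hunk below, and everything else here is illustrative):

    import { ModelProviderName, type Character } from "@ai16z/eliza";

    const exampleCharacter: Partial<Character> = {
        name: "sample-agent",
        modelProvider: ModelProviderName.TOGETHER, // serialized as "together"
        settings: {
            secrets: {
                // Per the getTokenForProvider hunk below, the TOGETHER case
                // still accepts the LLAMACLOUD_API_KEY secret for compatibility.
                LLAMACLOUD_API_KEY: "<together-api-key>",
            },
        },
    };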
1 change: 1 addition & 0 deletions agent/src/index.ts
@@ -197,6 +197,7 @@ export function getTokenForProvider(
settings.ETERNALAI_API_KEY
);
case ModelProviderName.LLAMACLOUD:
case ModelProviderName.TOGETHER:
return (
character.settings?.secrets?.LLAMACLOUD_API_KEY ||
settings.LLAMACLOUD_API_KEY ||
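The case fall-through above means a "together" character resolves its API key through the same LLAMACLOUD_* lookups as before; a compressed sketch of the pattern (the tail of the return expression is truncated in the hunk, so any additional fallbacks are not reproduced here):

    // Sketch only: TOGETHER shares LLAMACLOUD's key resolution.
    switch (provider) {
        case ModelProviderName.LLAMACLOUD:
        case ModelProviderName.TOGETHER:
            return (
                character.settings?.secrets?.LLAMACLOUD_API_KEY ||
                settings.LLAMACLOUD_API_KEY
                // ...further fallbacks elided (truncated in the hunk above)
            );
    }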
53 changes: 30 additions & 23 deletions packages/core/src/generation.ts
@@ -78,17 +78,25 @@ export async function generateText({

// if runtime.getSetting("LLAMACLOUD_MODEL_LARGE") is true and modelProvider is LLAMACLOUD, then use the large model
if (
runtime.getSetting("LLAMACLOUD_MODEL_LARGE") &&
provider === ModelProviderName.LLAMACLOUD
(runtime.getSetting("LLAMACLOUD_MODEL_LARGE") &&
provider === ModelProviderName.LLAMACLOUD) ||
(runtime.getSetting("TOGETHER_MODEL_LARGE") &&
provider === ModelProviderName.TOGETHER)
) {
model = runtime.getSetting("LLAMACLOUD_MODEL_LARGE");
model =
runtime.getSetting("LLAMACLOUD_MODEL_LARGE") ||
runtime.getSetting("TOGETHER_MODEL_LARGE");
}

if (
runtime.getSetting("LLAMACLOUD_MODEL_SMALL") &&
provider === ModelProviderName.LLAMACLOUD
(runtime.getSetting("LLAMACLOUD_MODEL_SMALL") &&
provider === ModelProviderName.LLAMACLOUD) ||
(runtime.getSetting("TOGETHER_MODEL_SMALL") &&
provider === ModelProviderName.TOGETHER)
) {
model = runtime.getSetting("LLAMACLOUD_MODEL_SMALL");
model =
runtime.getSetting("LLAMACLOUD_MODEL_SMALL") ||
runtime.getSetting("TOGETHER_MODEL_SMALL");
}

elizaLogger.info("Selected model:", model);
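The widened conditions mean either the legacy LLAMACLOUD_MODEL_* settings or the new TOGETHER_MODEL_* settings can pin a specific model for their provider; a sketch of the intended effect (the setting source and model id below are illustrative):

    // e.g. in .env or character secrets, an operator pins the large model:
    // TOGETHER_MODEL_LARGE=meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo

    // With provider === ModelProviderName.TOGETHER, generateText then uses
    // the override instead of the registry default from models.ts:
    const override = runtime.getSetting("TOGETHER_MODEL_LARGE");
    if (override && provider === ModelProviderName.TOGETHER) {
        model = override;
    }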
@@ -120,7 +128,8 @@ export async function generateText({
case ModelProviderName.ETERNALAI:
case ModelProviderName.ALI_BAILIAN:
case ModelProviderName.VOLENGINE:
case ModelProviderName.LLAMACLOUD: {
case ModelProviderName.LLAMACLOUD:
case ModelProviderName.TOGETHER: {
elizaLogger.debug("Initializing OpenAI model.");
const openai = createOpenAI({ apiKey, baseURL: endpoint });
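Folding TOGETHER into this case works because Together exposes an OpenAI-compatible API; endpoint here resolves to the provider's base URL from the registry in models.ts (shown below), so one client path covers both providers. A sketch with the concrete value substituted:

    // For provider === ModelProviderName.TOGETHER, the registry supplies
    // endpoint === "https://api.together.ai/v1", so this is equivalent to:
    const together = createOpenAI({
        apiKey,                                // resolved by getTokenForProvider
        baseURL: "https://api.together.ai/v1", // models.ts endpoint for TOGETHER
    });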

@@ -806,12 +815,6 @@ export const generateImage = async (
data?: string[];
error?: any;
}> => {
const { prompt, width, height } = data;
let { count } = data;
if (!count) {
count = 1;
}

const model = getModel(runtime.imageModelProvider, ModelClass.IMAGE);
const modelSettings = models[runtime.imageModelProvider].imageSettings;

@@ -866,16 +869,19 @@ export const generateImage = async (
const imageURL = await response.json();
return { success: true, data: [imageURL] };
} else if (
runtime.imageModelProvider === ModelProviderName.TOGETHER ||
// for backwards compat
runtime.imageModelProvider === ModelProviderName.LLAMACLOUD
) {
const together = new Together({ apiKey: apiKey as string });
// Fix: steps 4 is for schnell; 28 is for dev.
const response = await together.images.create({
model: "black-forest-labs/FLUX.1-schnell",
prompt,
width,
height,
prompt: data.prompt,
width: data.width,
height: data.height,
steps: modelSettings?.steps ?? 4,
n: count,
n: data.count,
});
const urls: string[] = [];
for (let i = 0; i < response.data.length; i++) {
@@ -902,11 +908,11 @@ export const generateImage = async (

// Prepare the input parameters according to their schema
const input = {
prompt: prompt,
prompt: data.prompt,
image_size: "square" as const,
num_inference_steps: modelSettings?.steps ?? 50,
guidance_scale: 3.5,
num_images: count,
guidance_scale: data.guidanceScale || 3.5,
num_images: data.count,
enable_safety_checker: true,
output_format: "png" as const,
seed: data.seed ?? 6252023,
@@ -945,7 +951,7 @@ export const generateImage = async (
const base64s = await Promise.all(base64Promises);
return { success: true, data: base64s };
} else {
let targetSize = `${width}x${height}`;
let targetSize = `${data.width}x${data.height}`;
if (
targetSize !== "1024x1024" &&
targetSize !== "1792x1024" &&
@@ -956,9 +962,9 @@
const openai = new OpenAI({ apiKey: apiKey as string });
const response = await openai.images.generate({
model,
prompt,
prompt: data.prompt,
size: targetSize as "1024x1024" | "1792x1024" | "1024x1792",
n: count,
n: data.count,
response_format: "b64_json",
});
const base64s = response.data.map(
@@ -1157,6 +1163,7 @@ export async function handleProvider(
case ModelProviderName.ALI_BAILIAN:
case ModelProviderName.VOLENGINE:
case ModelProviderName.LLAMACLOUD:
case ModelProviderName.TOGETHER:
return await handleOpenAI(options);
case ModelProviderName.ANTHROPIC:
return await handleAnthropic(options);
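This is the interface refactor from the commit title: generateImage no longer destructures prompt/width/height/count up front, and every provider branch reads fields straight off the data argument, so optional fields such as seed and guidanceScale can flow through. A hedged sketch of a call under the new shape (field names are taken from the hunks above and the plugin diff below; the defaults mentioned in comments are the ones visible in the fal.ai branch):

    const result = await generateImage(
        {
            prompt: "a lighthouse at dusk, oil painting",
            width: 1024,
            height: 1024,
            count: 1,           // optional
            seed: 42,           // optional; fal.ai branch falls back to 6252023
            guidanceScale: 3.5, // optional; fal.ai branch falls back to 3.5
        },
        runtime
    );

    if (result.success && result.data) {
        // data holds URLs or base64 strings, depending on the provider branch
        console.log(`generated ${result.data.length} image(s)`);
    }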
21 changes: 21 additions & 0 deletions packages/core/src/models.ts
@@ -128,6 +128,27 @@ export const models: Models = {
[ModelClass.IMAGE]: "black-forest-labs/FLUX.1-schnell",
},
},
[ModelProviderName.TOGETHER]: {
settings: {
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
repetition_penalty: 0.4,
temperature: 0.7,
},
imageSettings: {
steps: 4,
},
endpoint: "https://api.together.ai/v1",
model: {
[ModelClass.SMALL]: "meta-llama/Llama-3.2-3B-Instruct-Turbo",
[ModelClass.MEDIUM]: "meta-llama-3.1-8b-instruct",
[ModelClass.LARGE]: "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
[ModelClass.EMBEDDING]:
"togethercomputer/m2-bert-80M-32k-retrieval",
[ModelClass.IMAGE]: "black-forest-labs/FLUX.1-schnell",
},
},
[ModelProviderName.LLAMALOCAL]: {
settings: {
stop: ["<|eot_id|>", "<|eom_id|>"],
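The registry entry above is what the rest of core reads once the Together provider is selected; a sketch of the lookups involved (import paths follow the package layout already used by the test file below):

    import { models } from "./models";
    import { ModelClass, ModelProviderName } from "./types";

    // Default text model ids come straight from the registry entry above:
    const small = models[ModelProviderName.TOGETHER].model[ModelClass.SMALL];
    // -> "meta-llama/Llama-3.2-3B-Instruct-Turbo"

    // generateImage consults imageSettings for the FLUX.1-schnell step count:
    const steps = models[ModelProviderName.TOGETHER].imageSettings?.steps ?? 4;
    // -> 4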
102 changes: 56 additions & 46 deletions packages/core/src/tests/generation.test.ts
@@ -1,7 +1,12 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { ModelProviderName, IAgentRuntime } from "../types";
import { models } from "../models";
import { generateText, generateTrueOrFalse, splitChunks, trimTokens } from "../generation";
import {
generateText,
generateTrueOrFalse,
splitChunks,
trimTokens,
} from "../generation";
import type { TiktokenModel } from "js-tiktoken";

// Mock the elizaLogger
@@ -42,6 +47,8 @@ describe("Generation", () => {
getSetting: vi.fn().mockImplementation((key: string) => {
if (key === "LLAMACLOUD_MODEL_LARGE") return false;
if (key === "LLAMACLOUD_MODEL_SMALL") return false;
if (key === "TOGETHER_MODEL_LARGE") return false;
if (key === "TOGETHER_MODEL_SMALL") return false;
return undefined;
}),
} as unknown as IAgentRuntime;
@@ -122,53 +129,56 @@ describe("Generation", () => {
});
});

describe("trimTokens", () => {
const model = "gpt-4" as TiktokenModel;

it("should return empty string for empty input", () => {
const result = trimTokens("", 100, model);
expect(result).toBe("");
});

it("should throw error for negative maxTokens", () => {
expect(() => trimTokens("test", -1, model)).toThrow("maxTokens must be positive");
});

it("should return unchanged text if within token limit", () => {
const shortText = "This is a short text";
const result = trimTokens(shortText, 10, model);
expect(result).toBe(shortText);
});

it("should truncate text to specified token limit", () => {
// Using a longer text that we know will exceed the token limit
const longText = "This is a much longer text that will definitely exceed our very small token limit and need to be truncated to fit within the specified constraints."
const result = trimTokens(longText, 5, model);

// The exact result will depend on the tokenizer, but we can verify:
// 1. Result is shorter than original
expect(result.length).toBeLessThan(longText.length);
// 2. Result is not empty
expect(result.length).toBeGreaterThan(0);
// 3. Result is a proper substring of the original text
expect(longText.includes(result)).toBe(true);
});

it("should handle non-ASCII characters", () => {
const unicodeText = "Hello 👋 World 🌍";
const result = trimTokens(unicodeText, 5, model);
expect(result.length).toBeGreaterThan(0);
});

it("should handle multiline text", () => {
const multilineText = `Line 1
describe("trimTokens", () => {
const model = "gpt-4" as TiktokenModel;

it("should return empty string for empty input", () => {
const result = trimTokens("", 100, model);
expect(result).toBe("");
});

it("should throw error for negative maxTokens", () => {
expect(() => trimTokens("test", -1, model)).toThrow(
"maxTokens must be positive"
);
});

it("should return unchanged text if within token limit", () => {
const shortText = "This is a short text";
const result = trimTokens(shortText, 10, model);
expect(result).toBe(shortText);
});

it("should truncate text to specified token limit", () => {
// Using a longer text that we know will exceed the token limit
const longText =
"This is a much longer text that will definitely exceed our very small token limit and need to be truncated to fit within the specified constraints.";
const result = trimTokens(longText, 5, model);

// The exact result will depend on the tokenizer, but we can verify:
// 1. Result is shorter than original
expect(result.length).toBeLessThan(longText.length);
// 2. Result is not empty
expect(result.length).toBeGreaterThan(0);
// 3. Result is a proper substring of the original text
expect(longText.includes(result)).toBe(true);
});

it("should handle non-ASCII characters", () => {
const unicodeText = "Hello 👋 World 🌍";
const result = trimTokens(unicodeText, 5, model);
expect(result.length).toBeGreaterThan(0);
});

it("should handle multiline text", () => {
const multilineText = `Line 1
Line 2
Line 3
Line 4
Line 5`;
const result = trimTokens(multilineText, 5, model);
expect(result.length).toBeGreaterThan(0);
expect(result.length).toBeLessThan(multilineText.length);
});
});
const result = trimTokens(multilineText, 5, model);
expect(result.length).toBeGreaterThan(0);
expect(result.length).toBeLessThan(multilineText.length);
});
});
});
2 changes: 2 additions & 0 deletions packages/core/src/types.ts
@@ -192,6 +192,7 @@ export type Models = {
[ModelProviderName.GROK]: Model;
[ModelProviderName.GROQ]: Model;
[ModelProviderName.LLAMACLOUD]: Model;
[ModelProviderName.TOGETHER]: Model;
[ModelProviderName.LLAMALOCAL]: Model;
[ModelProviderName.GOOGLE]: Model;
[ModelProviderName.CLAUDE_VERTEX]: Model;
@@ -216,6 +217,7 @@ export enum ModelProviderName {
GROK = "grok",
GROQ = "groq",
LLAMACLOUD = "llama_cloud",
TOGETHER = "together",
LLAMALOCAL = "llama_local",
GOOGLE = "google",
CLAUDE_VERTEX = "claude_vertex",
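Because the Models type maps every ModelProviderName to a Model, adding TOGETHER to the enum forces the matching entry in models.ts above, so the two hunks go together. The member's string value is what character files use; a one-line illustration:

    const provider: ModelProviderName = ModelProviderName.TOGETHER;
    console.log(provider); // prints "together", the value used in character JSON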
34 changes: 29 additions & 5 deletions packages/plugin-image-generation/src/index.ts
@@ -11,7 +11,7 @@ import { generateImage } from "@ai16z/eliza";

import fs from "fs";
import path from "path";
import { validateImageGenConfig } from "./enviroment";
import { validateImageGenConfig } from "./environment";

export function saveBase64Image(base64Data: string, filename: string): string {
// Create generatedImages directory if it doesn't exist
@@ -97,7 +97,17 @@ const imageGeneration: Action = {
runtime: IAgentRuntime,
message: Memory,
state: State,
options: any,
options: {
width?: number;
height?: number;
count?: number;
negativePrompt?: string;
numIterations?: number;
guidanceScale?: number;
seed?: number;
modelId?: string;
jobId?: string;
},
callback: HandlerCallback
) => {
elizaLogger.log("Composing state for message:", message);
@@ -116,9 +126,23 @@ const imageGeneration: Action = {
const images = await generateImage(
{
prompt: imagePrompt,
width: 1024,
height: 1024,
count: 1,
width: options.width || 1024,
height: options.height || 1024,
...(options.count != null ? { count: options.count || 1 } : {}),
...(options.negativePrompt != null
? { negativePrompt: options.negativePrompt }
: {}),
...(options.numIterations != null
? { numIterations: options.numIterations }
: {}),
...(options.guidanceScale != null
? { guidanceScale: options.guidanceScale }
: {}),
...(options.seed != null ? { seed: options.seed } : {}),
...(options.modelId != null
? { modelId: options.modelId }
: {}),
...(options.jobId != null ? { jobId: options.jobId } : {}),
},
runtime
);
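The conditional spreads above forward an option to generateImage only when the caller actually supplied it, so an absent field keeps whatever default the provider branch applies instead of arriving as an explicit undefined. A compact sketch of the pattern with hypothetical values:

    // Hypothetical handler options: the caller only provided `count`.
    const options: { count?: number; seed?: number } = { count: 2 };

    const request = {
        prompt: "a watercolor fox",
        width: 1024,
        height: 1024,
        ...(options.count != null ? { count: options.count } : {}),
        ...(options.seed != null ? { seed: options.seed } : {}),
    };
    // request has a `count` key but no `seed` key at all, so generateImage's own
    // fallback (e.g. `data.seed ?? 6252023` in the fal.ai branch) still applies.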
