From 98cd68787c0563dc962f0dca1a18fb699377cae6 Mon Sep 17 00:00:00 2001 From: ProphetX10 Date: Mon, 2 Dec 2024 08:03:11 -0100 Subject: [PATCH] Update generation.ts to fix TOGETHER/LLAMACLOUD image generation --- packages/core/src/generation.ts | 60 +++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts index a460a728564..52fc5135aa5 100644 --- a/packages/core/src/generation.ts +++ b/packages/core/src/generation.ts @@ -874,7 +874,6 @@ export const generateImage = async ( runtime.imageModelProvider === ModelProviderName.LLAMACLOUD ) { const together = new Together({ apiKey: apiKey as string }); - // Fix: steps 4 is for schnell; 28 is for dev. const response = await together.images.create({ model: "black-forest-labs/FLUX.1-schnell", prompt: data.prompt, @@ -883,23 +882,41 @@ export const generateImage = async ( steps: modelSettings?.steps ?? 4, n: data.count, }); - const urls: string[] = []; - for (let i = 0; i < response.data.length; i++) { - const json = response.data[i].b64_json; - // decode base64 - const base64 = Buffer.from(json, "base64").toString("base64"); - urls.push(base64); + + // Add type assertion to handle the response properly + const togetherResponse = response as unknown as TogetherAIImageResponse; + + if (!togetherResponse.data || !Array.isArray(togetherResponse.data)) { + throw new Error("Invalid response format from Together AI"); } - const base64s = await Promise.all( - urls.map(async (url) => { - const response = await fetch(url); - const blob = await response.blob(); - const buffer = await blob.arrayBuffer(); - let base64 = Buffer.from(buffer).toString("base64"); - base64 = "data:image/jpeg;base64," + base64; - return base64; - }) - ); + + // Rest of the code remains the same... + const base64s = await Promise.all(togetherResponse.data.map(async (image) => { + if (!image.url) { + elizaLogger.error("Missing URL in image data:", image); + throw new Error("Missing URL in Together AI response"); + } + + // Fetch the image from the URL + const imageResponse = await fetch(image.url); + if (!imageResponse.ok) { + throw new Error(`Failed to fetch image: ${imageResponse.statusText}`); + } + + // Convert to blob and then to base64 + const blob = await imageResponse.blob(); + const arrayBuffer = await blob.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + + // Return with proper MIME type + return `data:image/jpeg;base64,${base64}`; + })); + + if (base64s.length === 0) { + throw new Error("No images generated by Together AI"); + } + + elizaLogger.debug(`Generated ${base64s.length} images`); return { success: true, data: base64s }; } else if (runtime.imageModelProvider === ModelProviderName.FAL) { fal.config({ @@ -1406,3 +1423,12 @@ async function handleOllama({ ...modelOptions, }); } + +// Add type definition for Together AI response +interface TogetherAIImageResponse { + data: Array<{ + url: string; + content_type?: string; + image_type?: string; + }>; +}