From fe6fb165fc5e9ca155e78f6ff7b453b832129fab Mon Sep 17 00:00:00 2001 From: jonmatthis Date: Wed, 27 Nov 2024 13:00:41 -0500 Subject: [PATCH] use context to generate image --- src/core/ai/openai/dto/text-generation.dto.ts | 19 ++++++++ src/core/ai/openai/openai-text.service.ts | 45 +++++++++++++++++++ src/core/ai/openai/openai.module.ts | 3 ++ .../discord/commands/discord-image.command.ts | 35 ++++++++++----- 4 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 src/core/ai/openai/dto/text-generation.dto.ts create mode 100644 src/core/ai/openai/openai-text.service.ts diff --git a/src/core/ai/openai/dto/text-generation.dto.ts b/src/core/ai/openai/dto/text-generation.dto.ts new file mode 100644 index 0000000..ba31477 --- /dev/null +++ b/src/core/ai/openai/dto/text-generation.dto.ts @@ -0,0 +1,19 @@ +import { IsNumber, IsString } from 'class-validator'; + +export class TextGenerationDto { + @IsString() + prompt: string; + + @IsString() + model: string = 'gpt-4o'; + + @IsNumber() + temperature: number = 0.7; + + @IsNumber() + max_tokens: number = 300; + + constructor(partial: Partial) { + Object.assign(this, partial); + } +} diff --git a/src/core/ai/openai/openai-text.service.ts b/src/core/ai/openai/openai-text.service.ts new file mode 100644 index 0000000..68acd77 --- /dev/null +++ b/src/core/ai/openai/openai-text.service.ts @@ -0,0 +1,45 @@ +import { Injectable, Logger, OnModuleInit } from '@nestjs/common'; +import { OpenAI } from 'openai'; +import { TextGenerationDto } from './dto/text-generation.dto'; +import { OpenaiSecretsService } from './openai-secrets.service'; +import { ChatCompletion } from 'openai/src/resources/chat/completions'; + +@Injectable() +export class OpenaiTextGenerationService implements OnModuleInit { + private openai: OpenAI; + + constructor( + private readonly _openAiSecrets: OpenaiSecretsService, + private readonly _logger: Logger, + ) {} + + async onModuleInit() { + try { + const apiKey = await this._openAiSecrets.getOpenaiApiKey(); + this.openai = new OpenAI({ apiKey: apiKey }); + } catch (error) { + this._logger.error('Failed to initialize OpenAI service.', error); + throw error; + } + } + + public async generateText(dto: TextGenerationDto): Promise { + try { + const { prompt, temperature, max_tokens, model } = dto; + this._logger.log(`Generating text...`); + const chatCompletionResponse: ChatCompletion = + await this.openai.chat.completions.create({ + model, + temperature, + max_tokens, + messages: [{ role: 'system', content: prompt }], + }); + + this._logger.log(`Text generation complete.`); + return chatCompletionResponse.choices[0].message.content; + } catch (error) { + this._logger.error('Failed to generate image.', error); + throw error; + } + } +} diff --git a/src/core/ai/openai/openai.module.ts b/src/core/ai/openai/openai.module.ts index 4479990..036646b 100644 --- a/src/core/ai/openai/openai.module.ts +++ b/src/core/ai/openai/openai.module.ts @@ -4,6 +4,7 @@ import { GcpModule } from '../../gcp/gcp.module'; import { OpenaiChatService } from './openai-chat.service'; import { OpenaiAudioService } from './openai-audio.service'; import { OpenaiImageService } from './openai-image.service'; +import { OpenaiTextGenerationService } from './openai-text.service'; @Module({ imports: [GcpModule], @@ -12,6 +13,7 @@ import { OpenaiImageService } from './openai-image.service'; OpenaiChatService, OpenaiAudioService, OpenaiImageService, + OpenaiTextGenerationService, Logger, ], exports: [ @@ -19,6 +21,7 @@ import { OpenaiImageService } from './openai-image.service'; OpenaiChatService, OpenaiAudioService, OpenaiImageService, + OpenaiTextGenerationService, ], }) export class OpenaiModule {} diff --git a/src/interfaces/discord/commands/discord-image.command.ts b/src/interfaces/discord/commands/discord-image.command.ts index e624ef2..26d40c7 100644 --- a/src/interfaces/discord/commands/discord-image.command.ts +++ b/src/interfaces/discord/commands/discord-image.command.ts @@ -11,6 +11,7 @@ import { OpenaiImageService } from '../../../core/ai/openai/openai-image.service import OpenAI from 'openai'; import { AttachmentBuilder } from 'discord.js'; import ImagesResponse = OpenAI.ImagesResponse; +import { OpenaiTextGenerationService } from '../../../core/ai/openai/openai-text.service'; export class ImagePromptDto { @StringOption({ @@ -20,18 +21,21 @@ export class ImagePromptDto { }) prompt: string = 'Generate a new image'; - // @BooleanOption({ - // name: 'use_context', - // description: - // 'Whether to include text from this Thread/Channel in the image generation prompt', - // required: false, - // }) - // useContext: boolean; + @BooleanOption({ + name: 'use_context', + description: + 'Whether to include text from this Thread/Channel in the image generation prompt', + required: false, + }) + useContext: boolean; } @Injectable() export class DiscordImageCommand { - constructor(private readonly _openaiImageService: OpenaiImageService) {} + constructor( + private readonly _openaiImageService: OpenaiImageService, + private readonly _openaiTextService: OpenaiTextGenerationService, + ) {} @SlashCommand({ name: 'image', @@ -42,22 +46,29 @@ export class DiscordImageCommand { @Context() [interaction]: SlashCommandContext, @Options({ required: false }) imagePromptDto?: ImagePromptDto, ) { + await interaction.deferReply(); let promptText = ''; if (!imagePromptDto || !imagePromptDto.prompt) { promptText = 'Generate a new image'; } else { promptText = imagePromptDto.prompt; } - if (imagePromptDto.useContext) { + if (imagePromptDto && imagePromptDto.useContext) { const context = interaction.channel; const messages = await context.messages.fetch(); const contextText = messages .map((message) => message.content) .join(' \n '); - promptText = `${promptText} \n\n ${contextText}`; + + promptText = await this._openaiTextService.generateText({ + prompt: `Condense the provided INPUT TEXT into a 200 word (or less) prompt that will be used to generate an image. Do not generate any text other than the image generation prompt.\n\n--------BEGIN INPUT TEXT\n\n ${contextText} \n\n ---------------END OF INPUT TEXT\n\nREMEMBER! Condense the provided INPUT TEXT into a 200 word (or less) prompt that will be used to generate an image. Do not generate any text other than the image generation prompt.`, + model: 'gpt-4o', + temperature: 0.5, + max_tokens: 300, + }); } - await interaction.reply({ - content: `Generating image from prompt:\n > ${promptText} \n Generating image...`, + await interaction.editReply({ + content: `Generating image from prompt:\n > ${promptText} \n Please wait...`, }); // generate image const response: ImagesResponse =