Skip to content

Commit

Permalink
use context to generate image
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmatthis committed Nov 27, 2024
1 parent bec5878 commit fe6fb16
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 12 deletions.
19 changes: 19 additions & 0 deletions src/core/ai/openai/dto/text-generation.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { IsNumber, IsString } from 'class-validator';

/**
 * Request parameters for a single text-generation call.
 *
 * Every field except `prompt` has a default, so callers only need to supply
 * the values they want to override.
 */
export class TextGenerationDto {
  /** Prompt text to generate from. Has no default; callers must provide it. */
  @IsString()
  prompt: string;

  /** Model identifier. Defaults to 'gpt-4o'. */
  @IsString()
  model: string = 'gpt-4o';

  /** Sampling temperature. Defaults to 0.7. */
  @IsNumber()
  temperature: number = 0.7;

  /** Token limit for the generated text. Defaults to 300. */
  @IsNumber()
  max_tokens: number = 300;

  /**
   * Builds a DTO from a partial set of values, keeping the declared defaults
   * for anything that is omitted.
   *
   * @param partial - Values to copy onto the DTO. Properties that are
   *   explicitly `undefined` are ignored so they cannot erase a default.
   */
  constructor(partial: Partial<TextGenerationDto> = {}) {
    // Field initializers have already run at this point, so only defined
    // values are copied over them. A plain Object.assign would also copy
    // `undefined` and silently wipe out the defaults above.
    const definedValues = Object.fromEntries(
      Object.entries(partial ?? {}).filter(([, value]) => value !== undefined),
    );
    Object.assign(this, definedValues);
  }
}
45 changes: 45 additions & 0 deletions src/core/ai/openai/openai-text.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { OpenAI } from 'openai';
import { TextGenerationDto } from './dto/text-generation.dto';
import { OpenaiSecretsService } from './openai-secrets.service';
import { ChatCompletion } from 'openai/src/resources/chat/completions';

/**
 * Generates text through the OpenAI chat-completions API.
 *
 * The OpenAI client is created in `onModuleInit` using the API key obtained
 * from `OpenaiSecretsService`, so `generateText` must not be called before
 * module initialization has completed.
 */
@Injectable()
export class OpenaiTextGenerationService implements OnModuleInit {
  private openai: OpenAI;

  constructor(
    private readonly _openAiSecrets: OpenaiSecretsService,
    private readonly _logger: Logger,
  ) {}

  /**
   * Fetches the API key and constructs the OpenAI client.
   *
   * @throws Rethrows any error raised while fetching the key or creating the
   *   client, after logging it.
   */
  async onModuleInit(): Promise<void> {
    try {
      const apiKey = await this._openAiSecrets.getOpenaiApiKey();
      this.openai = new OpenAI({ apiKey });
    } catch (error) {
      this._logger.error('Failed to initialize OpenAI service.', error);
      throw error;
    }
  }

  /**
   * Requests a chat completion for the given prompt and returns its text.
   *
   * @param dto - Prompt, model, temperature and token limit for the request.
   * @returns The message content of the first completion choice.
   * @throws If the request fails, or if the response carries no text content
   *   (no choices, or a null message content). The error is logged and
   *   rethrown.
   */
  public async generateText(dto: TextGenerationDto): Promise<string> {
    try {
      const { prompt, temperature, max_tokens, model } = dto;
      this._logger.log(`Generating text...`);
      const chatCompletionResponse: ChatCompletion =
        await this.openai.chat.completions.create({
          model,
          temperature,
          max_tokens,
          // The prompt is sent as a single system message.
          messages: [{ role: 'system', content: prompt }],
        });

      // `choices` may be empty and `content` is nullable in the SDK types;
      // fail loudly instead of returning null from a Promise<string>.
      const content = chatCompletionResponse.choices[0]?.message?.content;
      if (content == null) {
        throw new Error('OpenAI returned no text content for the completion.');
      }

      this._logger.log(`Text generation complete.`);
      return content;
    } catch (error) {
      this._logger.error('Failed to generate text.', error);
      throw error;
    }
  }
}
3 changes: 3 additions & 0 deletions src/core/ai/openai/openai.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { GcpModule } from '../../gcp/gcp.module';
import { OpenaiChatService } from './openai-chat.service';
import { OpenaiAudioService } from './openai-audio.service';
import { OpenaiImageService } from './openai-image.service';
import { OpenaiTextGenerationService } from './openai-text.service';

@Module({
imports: [GcpModule],
Expand All @@ -12,13 +13,15 @@ import { OpenaiImageService } from './openai-image.service';
OpenaiChatService,
OpenaiAudioService,
OpenaiImageService,
OpenaiTextGenerationService,
Logger,
],
exports: [
OpenaiSecretsService,
OpenaiChatService,
OpenaiAudioService,
OpenaiImageService,
OpenaiTextGenerationService,
],
})
export class OpenaiModule {}
35 changes: 23 additions & 12 deletions src/interfaces/discord/commands/discord-image.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { OpenaiImageService } from '../../../core/ai/openai/openai-image.service
import OpenAI from 'openai';
import { AttachmentBuilder } from 'discord.js';
import ImagesResponse = OpenAI.ImagesResponse;
import { OpenaiTextGenerationService } from '../../../core/ai/openai/openai-text.service';

export class ImagePromptDto {
@StringOption({
Expand All @@ -20,18 +21,21 @@ export class ImagePromptDto {
})
prompt: string = 'Generate a new image';

// @BooleanOption({
// name: 'use_context',
// description:
// 'Whether to include text from this Thread/Channel in the image generation prompt',
// required: false,
// })
// useContext: boolean;
@BooleanOption({
name: 'use_context',
description:
'Whether to include text from this Thread/Channel in the image generation prompt',
required: false,
})
useContext: boolean;
}

@Injectable()
export class DiscordImageCommand {
constructor(private readonly _openaiImageService: OpenaiImageService) {}
constructor(
private readonly _openaiImageService: OpenaiImageService,
private readonly _openaiTextService: OpenaiTextGenerationService,
) {}

@SlashCommand({
name: 'image',
Expand All @@ -42,22 +46,29 @@ export class DiscordImageCommand {
@Context() [interaction]: SlashCommandContext,
@Options({ required: false }) imagePromptDto?: ImagePromptDto,
) {
await interaction.deferReply();
let promptText = '';
if (!imagePromptDto || !imagePromptDto.prompt) {
promptText = 'Generate a new image';
} else {
promptText = imagePromptDto.prompt;
}
if (imagePromptDto.useContext) {
if (imagePromptDto && imagePromptDto.useContext) {
const context = interaction.channel;
const messages = await context.messages.fetch();
const contextText = messages
.map((message) => message.content)
.join(' \n ');
promptText = `${promptText} \n\n ${contextText}`;

promptText = await this._openaiTextService.generateText({
prompt: `Condense the provided INPUT TEXT into a 200 word (or less) prompt that will be used to generate an image. Do not generate any text other than the image generation prompt.\n\n--------BEGIN INPUT TEXT\n\n ${contextText} \n\n ---------------END OF INPUT TEXT\n\nREMEMBER! Condense the provided INPUT TEXT into a 200 word (or less) prompt that will be used to generate an image. Do not generate any text other than the image generation prompt.`,
model: 'gpt-4o',
temperature: 0.5,
max_tokens: 300,
});
}
await interaction.reply({
content: `Generating image from prompt:\n > ${promptText} \n Generating image...`,
await interaction.editReply({
content: `Generating image from prompt:\n > ${promptText} \n Please wait...`,
});
// generate image
const response: ImagesResponse =
Expand Down

0 comments on commit fe6fb16

Please sign in to comment.