Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/server should return better error codes when generating open ai images #582

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# CasualOS Changelog

## V3.3.16

#### Date: TBD

### :rocket: Features

### :bug: Bug Fixes

- Improved error handling for `ai.generateImage` requests with unacceptable parameters.
- The server now returns an `invalid_request` error code when the parameters provided are not accepted by the selected model (e.g., OpenAI, Google).
- This ensures that users receive clear and actionable feedback when their requests fail due to invalid parameters.

## V3.3.15

#### Date: 12/19/2024
Expand Down
6 changes: 6 additions & 0 deletions src/aux-records/AIController.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2765,6 +2765,7 @@ describe('AIController', () => {
it('should return the result from the generateImage interface', async () => {
generateImageInterface.generateImage.mockReturnValueOnce(
Promise.resolve({
success: true,
images: [
{
base64: 'base64',
Expand Down Expand Up @@ -2860,6 +2861,7 @@ describe('AIController', () => {

otherInterface.generateImage.mockReturnValueOnce(
Promise.resolve({
success: true,
images: [
{
base64: 'base64',
Expand Down Expand Up @@ -2979,6 +2981,7 @@ describe('AIController', () => {
it('should work when the controller is configured to allow all subscription tiers and the user does not have a subscription', async () => {
generateImageInterface.generateImage.mockReturnValueOnce(
Promise.resolve({
success: true,
images: [
{
base64: 'base64',
Expand Down Expand Up @@ -3090,6 +3093,7 @@ describe('AIController', () => {

generateImageInterface.generateImage.mockReturnValueOnce(
Promise.resolve({
success: true,
images: [
{
base64: 'base64',
Expand Down Expand Up @@ -3120,6 +3124,7 @@ describe('AIController', () => {
it('should reject the request if it would exceed the subscription request limits', async () => {
generateImageInterface.generateImage.mockReturnValueOnce(
Promise.resolve({
success: true,
images: [
{
base64: 'base64',
Expand Down Expand Up @@ -3150,6 +3155,7 @@ describe('AIController', () => {
it('should reject the request if it would exceed the subscription period limits', async () => {
generateImageInterface.generateImage.mockReturnValueOnce(
Promise.resolve({
success: true,
images: [
{
base64: 'base64',
Expand Down
8 changes: 6 additions & 2 deletions src/aux-records/AIController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import {
AIChatInterfaceResponse,
AIChatInterfaceStreamResponse,
AIChatMessage,
AIChatStreamMessage,
} from './AIChatInterface';
import {
AIGenerateSkyboxInterface,
Expand Down Expand Up @@ -1074,7 +1073,7 @@ export class AIController {
}

const result = await provider.generateImage({
model,
model: model,
prompt: request.prompt,
negativePrompt: request.negativePrompt,
width: width,
Expand All @@ -1092,6 +1091,10 @@ export class AIController {
userId: request.userId,
});

if (!result.success) {
return result;
}

await this._metrics.recordImageMetrics({
userId: request.userId,
createdAtMs: Date.now(),
Expand Down Expand Up @@ -1702,6 +1705,7 @@ export interface AIGenerateImageFailure {
| NotSupportedError
| SubscriptionLimitReached
| NotAuthorizedError
| 'invalid_request'
| 'invalid_model';
errorMessage: string;

Expand Down
35 changes: 31 additions & 4 deletions src/aux-records/AIImageInterface.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
import {
InvalidSubscriptionTierError,
NotAuthorizedError,
NotLoggedInError,
NotSubscribedError,
NotSupportedError,
ServerError,
SubscriptionLimitReached,
} from '@casual-simulation/aux-common/Errors';

/**
* Defines an interface that is able to generate images from text prompts.
*/
Expand Down Expand Up @@ -79,13 +89,30 @@ export interface AIGenerateImageInterfaceRequest {
userId?: string;
}

export interface AIGenerateImageInterfaceResponse {
/**
* The list of images that were generated.
*/
export type AIGenerateImageInterfaceResponse =
| AIGenerateImageInterfaceSuccess
| AIGenerateImageInterfaceFailure;

export interface AIGenerateImageInterfaceSuccess {
success: true;
images: AIGeneratedImage[];
}

export interface AIGenerateImageInterfaceFailure {
success: false;
errorCode:
| ServerError
| NotLoggedInError
| NotSubscribedError
| InvalidSubscriptionTierError
| NotSupportedError
| SubscriptionLimitReached
| NotAuthorizedError
| 'invalid_request'
| 'invalid_model';
errorMessage: string;
}

/**
* Defines an image that was generated by the AI.
*/
Expand Down
24 changes: 23 additions & 1 deletion src/aux-records/OpenAIImageInterface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ import {
} from './AIImageInterface';
import { handleAxiosErrors } from './Utils';
import { traced } from './tracing/TracingDecorators';
import { SpanKind, SpanOptions } from '@opentelemetry/api';
import {
SpanKind,
SpanOptions,
SpanStatusCode,
trace,
} from '@opentelemetry/api';

const TRACE_NAME = 'OpenAIImageInterface';
const SPAN_OPTIONS: SpanOptions = {
Expand Down Expand Up @@ -93,9 +98,26 @@ export class OpenAIImageInterface implements AIImageInterface {
);

return {
success: true,
images,
};
} catch (err) {
if (axios.isAxiosError(err)) {
if (err.response.status === 400) {
const span = trace.getActiveSpan();
span?.recordException(err);
span?.setStatus({ code: SpanStatusCode.ERROR });

console.error(
`[OpenAIChatInterface] [${request.userId}] [generateImage]: Bad request: ${err.response.data.error.message}`
);
return {
success: false,
errorCode: 'invalid_request',
errorMessage: err.response.data.error.message,
};
}
}
handleAxiosErrors(err);
}
}
Expand Down
1 change: 1 addition & 0 deletions src/aux-records/RecordsServer.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16239,6 +16239,7 @@ describe('RecordsServer', () => {
});

imageInterface.generateImage.mockResolvedValueOnce({
success: true,
images: [
{
base64: 'base64',
Expand Down
21 changes: 21 additions & 0 deletions src/aux-records/StabilityAIImageInterface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import axios from 'axios';
import { handleAxiosErrors } from './Utils';
import { traced } from './tracing/TracingDecorators';
import { z } from 'zod';
import { SpanStatusCode, trace } from '@opentelemetry/api';

const TRACE_NAME = 'StabilityAIImageInterface';

Expand Down Expand Up @@ -115,9 +116,26 @@ export class StabilityAIImageInterface implements AIImageInterface {
);

return {
success: true,
images,
};
} catch (err) {
if (axios.isAxiosError(err)) {
if (err.response.status === 400) {
const span = trace.getActiveSpan();
span?.recordException(err);
span?.setStatus({ code: SpanStatusCode.ERROR });

console.error(
`[StabilityAIChatIngerface] [${request.userId}] [generateImage]: Bad request: ${err.response.data.error.message}`
);
return {
success: false,
errorCode: 'invalid_request',
errorMessage: err.response.data.error.message,
};
}
}
handleAxiosErrors(err);
}
}
Expand Down Expand Up @@ -157,6 +175,7 @@ export class StabilityAIImageInterface implements AIImageInterface {
const data = schema.parse(result.data);

return {
success: true,
images: [
{
base64: data.image,
Expand Down Expand Up @@ -203,6 +222,7 @@ export class StabilityAIImageInterface implements AIImageInterface {
const data = schema.parse(result.data);

return {
success: true,
images: [
{
base64: data.image,
Expand Down Expand Up @@ -252,6 +272,7 @@ export class StabilityAIImageInterface implements AIImageInterface {
const data = schema.parse(result.data);

return {
success: true,
images: [
{
base64: data.image,
Expand Down
Loading