Skip to content

Commit

Permalink
Gemini: choose a content filtering threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
enricoros committed Dec 20, 2023
1 parent 6b62a67 commit fdb66da
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 21 deletions.
23 changes: 14 additions & 9 deletions src/modules/llms/server/gemini/gemini.router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.ty

import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';

import { GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes';
import { GeminiBlockSafetyLevel, geminiBlockSafetyLevelSchema, GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes';


// Default hosts
Expand Down Expand Up @@ -49,7 +49,7 @@ export function geminiAccess(access: GeminiAccessSchema, modelRefId: string | nu
* - System messages = [User, Model'Ok']
* - User and Assistant messages are coalesced into a single message (e.g. [User, User, Assistant, Assistant, User] -> [User[2], Assistant[2], User[1]])
*/
export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, n: number): GeminiGenerateContentRequest => {
export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, safety: GeminiBlockSafetyLevel, n: number): GeminiGenerateContentRequest => {

// convert the history to a Gemini format
const contents: GeminiContentSchema[] = [];
Expand Down Expand Up @@ -82,12 +82,12 @@ export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, histo
...(model.maxTokens && { maxOutputTokens: model.maxTokens }),
temperature: model.temperature,
},
// safetySettings: [
// { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: 'BLOCK_NONE' },
// { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_NONE' },
// { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE' },
// { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' },
// ],
safetySettings: safety !== 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' ? [
{ category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: safety },
{ category: 'HARM_CATEGORY_HATE_SPEECH', threshold: safety },
{ category: 'HARM_CATEGORY_HARASSMENT', threshold: safety },
{ category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: safety },
] : undefined,
};
};

Expand All @@ -108,6 +108,7 @@ async function geminiPOST<TOut extends object, TPostBody extends object>(access:
// tRPC access schema for the Gemini dialect: the API key plus the
// user-selected content-filtering threshold, both forwarded to the server routes.
export const geminiAccessSchema = z.object({
dialect: z.enum(['gemini']),
geminiKey: z.string(),
// minimum harm-probability threshold applied to every safety category
minSafetyLevel: geminiBlockSafetyLevelSchema,
});
export type GeminiAccessSchema = z.infer<typeof geminiAccessSchema>;

Expand All @@ -123,6 +124,10 @@ const chatGenerateInputSchema = z.object({
});


/**
* See https://github.com/google/generative-ai-js/tree/main/packages/main/src for
* the official Google implementation.
*/
export const llmGeminiRouter = createTRPCRouter({

/* [Gemini] models.list = /v1beta/models */
Expand Down Expand Up @@ -184,7 +189,7 @@ export const llmGeminiRouter = createTRPCRouter({
.mutation(async ({ input: { access, history, model } }) => {

// generate the content
const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentTextPayload(model, history, 1), geminiModelsGenerateContentPath);
const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1), geminiModelsGenerateContentPath);
const generation = geminiGeneratedContentResponseSchema.parse(wireGeneration);

// only use the first result (and there should be only one)
Expand Down
21 changes: 12 additions & 9 deletions src/modules/llms/server/gemini/gemini.wiretypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,19 @@ const geminiHarmCategorySchema = z.enum([
'HARM_CATEGORY_DANGEROUS_CONTENT',
]);

/**
 * Block thresholds accepted by the Gemini API for a safety category,
 * from 'unspecified' (API default) down to fully permissive.
 * See https://ai.google.dev/docs/safety_setting_gemini
 */
const geminiBlockSafetyLevels = [
'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
'BLOCK_LOW_AND_ABOVE',
'BLOCK_MEDIUM_AND_ABOVE',
'BLOCK_ONLY_HIGH',
'BLOCK_NONE',
] as const;

export const geminiBlockSafetyLevelSchema = z.enum(geminiBlockSafetyLevels);

export type GeminiBlockSafetyLevel = z.infer<typeof geminiBlockSafetyLevelSchema>;

// A per-category safety setting: pairs a harm category with the minimum
// probability threshold at which Gemini blocks content.
// NOTE: the scraped diff retained both the old inline z.enum([...]) threshold
// and the new shared-schema line, producing a duplicate 'threshold' property;
// this keeps the deduplicated form that reuses geminiBlockSafetyLevelSchema.
const geminiSafetySettingSchema = z.object({
category: geminiHarmCategorySchema,
threshold: geminiBlockSafetyLevelSchema,
});

const geminiGenerationConfigSchema = z.object({
Expand Down Expand Up @@ -176,10 +179,10 @@ export const geminiGeneratedContentResponseSchema = z.object({
}).optional(),
tokenCount: z.number().optional(),
// groundingAttributions: z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls.
})),
})).optional(),
// NOTE: promptFeedback is only send in the first chunk in a streaming response
promptFeedback: z.object({
blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(),
safetyRatings: z.array(geminiSafetyRatingSchema),
safetyRatings: z.array(geminiSafetyRatingSchema).optional(),
}).optional(),
});
2 changes: 1 addition & 1 deletion src/modules/llms/server/llm.server.streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise<Respon

case 'gemini':
requestAccess = geminiAccess(access, model.id, geminiModelsStreamGenerateContentPath);
body = geminiGenerateContentTextPayload(model, history, 1);
body = geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1);
vendorStreamParser = createStreamParserGemini(model.id.replace('models/', ''));
break;

Expand Down
45 changes: 43 additions & 2 deletions src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import * as React from 'react';

import { FormControl, FormHelperText, Option, Select } from '@mui/joy';
import HealthAndSafetyIcon from '@mui/icons-material/HealthAndSafety';

import { FormInputKey } from '~/common/components/forms/FormInputKey';
import { FormLabelStart } from '~/common/components/forms/FormLabelStart';
import { InlineError } from '~/common/components/InlineError';
import { Link } from '~/common/components/Link';
import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton';

import { DModelSourceId } from '../../store-llms';
import type { DModelSourceId } from '../../store-llms';
import type { GeminiBlockSafetyLevel } from '../../server/gemini/gemini.wiretypes';
import { useLlmUpdateModels } from '../useLlmUpdateModels';
import { useSourceSetup } from '../useSourceSetup';

Expand All @@ -14,6 +19,14 @@ import { ModelVendorGemini } from './gemini.vendor';

const GEMINI_API_KEY_LINK = 'https://makersuite.google.com/app/apikey';

// UI choices for the Gemini content-filtering threshold, mapped to the
// wire-level GeminiBlockSafetyLevel values ('Default' defers to the API).
const SAFETY_OPTIONS: { value: GeminiBlockSafetyLevel, label: string }[] = [
{ value: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', label: 'Default' },
{ value: 'BLOCK_LOW_AND_ABOVE', label: 'Low and above' },
{ value: 'BLOCK_MEDIUM_AND_ABOVE', label: 'Medium and above' },
{ value: 'BLOCK_ONLY_HIGH', label: 'Only high' },
{ value: 'BLOCK_NONE', label: 'None' },
];


export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) {

Expand All @@ -22,7 +35,7 @@ export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) {
useSourceSetup(props.sourceId, ModelVendorGemini);

// derived state
const { geminiKey } = access;
const { geminiKey, minSafetyLevel } = access;

const needsUserKey = !ModelVendorGemini.hasBackendCap?.();
const shallFetchSucceed = !needsUserKey || (!!geminiKey && sourceSetupValid);
Expand All @@ -45,6 +58,34 @@ export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) {
placeholder='...'
/>

<FormControl orientation='horizontal' sx={{ justifyContent: 'space-between', alignItems: 'center' }}>
<FormLabelStart title='Safety Settings'
description='Threshold' />
<Select
variant='outlined'
value={minSafetyLevel} onChange={(_event, value) => value && updateSetup({ minSafetyLevel: value })}
startDecorator={<HealthAndSafetyIcon sx={{ display: { xs: 'none', sm: 'inherit' } }} />}
// indicator={<KeyboardArrowDownIcon />}
slotProps={{
root: { sx: { width: '100%' } },
indicator: { sx: { opacity: 0.5 } },
button: { sx: { whiteSpace: 'inherit' } },
}}
>
{SAFETY_OPTIONS.map(option => (
<Option key={'gemini-safety-' + option.value} value={option.value}>{option.label}</Option>
))}
</Select>
</FormControl>

<FormHelperText sx={{ display: 'block' }}>
Gemini has <Link href='https://ai.google.dev/docs/safety_setting_gemini' target='_blank' noLinkStyle>
adjustable safety settings</Link> on four categories: Harassment, Hate speech,
Sexually explicit, and Dangerous content, in addition to non-adjustable built-in filters.
By default, the model will block content with <em>medium and above</em> probability
of being unsafe.
</FormHelperText>

<SetupFormRefetchButton
refetch={refetch} disabled={!shallFetchSucceed || isFetching} error={isError}
/>
Expand Down
5 changes: 5 additions & 0 deletions src/modules/llms/vendors/gemini/gemini.vendor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { backendCaps } from '~/modules/backend/state-backend';
import { apiAsync, apiQuery } from '~/common/util/trpc.client';

import type { GeminiAccessSchema } from '../../server/gemini/gemini.router';
import type { GeminiBlockSafetyLevel } from '../../server/gemini/gemini.wiretypes';
import type { IModelVendor } from '../IModelVendor';
import type { VChatMessageOut } from '../../llm.client';
import { unifiedStreamingClient } from '../unifiedStreamingClient';
Expand All @@ -16,6 +17,7 @@ import { GeminiSourceSetup } from './GeminiSourceSetup';

// Persisted per-source configuration for the Gemini vendor.
export interface SourceSetupGemini {
// Gemini (makersuite) API key entered by the user
geminiKey: string;
// minimum harm-probability threshold at which content is blocked
minSafetyLevel: GeminiBlockSafetyLevel;
}

export interface LLMOptionsGemini {
Expand Down Expand Up @@ -45,13 +47,15 @@ export const ModelVendorGemini: IModelVendor<SourceSetupGemini, GeminiAccessSche
// functions
// default setup: empty key, safety threshold left to the API default
initializeSetup: () => ({
geminiKey: '',
minSafetyLevel: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
}),
// the source is considered configured once a non-empty API key is present
validateSetup: (setup) => {
return setup.geminiKey?.length > 0;
},
// map the (possibly partial) stored setup to the wire access schema,
// falling back to the API-default safety threshold when unset
getTransportAccess: (partialSetup): GeminiAccessSchema => ({
dialect: 'gemini',
geminiKey: partialSetup?.geminiKey || '',
minSafetyLevel: partialSetup?.minSafetyLevel || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
}),

// List Models
Expand Down Expand Up @@ -89,4 +93,5 @@ export const ModelVendorGemini: IModelVendor<SourceSetupGemini, GeminiAccessSche

// Chat Generate (streaming) with Functions
streamingChatGenerateOrThrow: unifiedStreamingClient,

};

0 comments on commit fdb66da

Please sign in to comment.