From fc2ba3e9e12231078fcfb64884851db88d52b2a2 Mon Sep 17 00:00:00 2001 From: jel <25802745+jelni@users.noreply.github.com> Date: Tue, 27 Aug 2024 01:52:08 +0200 Subject: [PATCH] add system prompt for Gemini --- src/apis/makersuite.rs | 15 +++++++++++---- src/commands/makersuite.rs | 3 ++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/apis/makersuite.rs b/src/apis/makersuite.rs index a534ba2..b133211 100644 --- a/src/apis/makersuite.rs +++ b/src/apis/makersuite.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::time::Duration; use std::{env, fmt}; @@ -122,18 +123,19 @@ pub async fn upload_file( struct GenerateContentRequest<'a> { contents: &'a [Content<'a>], safety_settings: &'static [SafetySetting], + system_instruction: Content<'a>, generation_config: GenerationConfig, } #[derive(Serialize)] struct Content<'a> { - parts: &'a [Part], + parts: &'a [Part<'a>], } #[derive(Serialize)] #[serde(rename_all = "camelCase")] -pub enum Part { - Text(String), +pub enum Part<'a> { + Text(Cow<'a, str>), FileData(FileData), } @@ -225,7 +227,7 @@ pub async fn stream_generate_content( http_client: reqwest::Client, tx: mpsc::UnboundedSender>, model: &str, - parts: &[Part], + parts: &[Part<'_>], max_output_tokens: u16, ) { let url = format!( @@ -251,6 +253,11 @@ pub async fn stream_generate_content( }, SafetySetting { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" }, ], + system_instruction: Content { + parts: &[Part::Text(Cow::Borrowed( + "Be concise and precise. Don't be verbose. Answer in the user's language.", + ))], + }, generation_config: GenerationConfig { max_output_tokens }, }) .send() diff --git a/src/commands/makersuite.rs b/src/commands/makersuite.rs index ac6e870..a8a42af 100644 --- a/src/commands/makersuite.rs +++ b/src/commands/makersuite.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::Write; use std::time::Duration; @@ -42,7 +43,7 @@ impl CommandTrait for GoogleGemini { let mut parts = Vec::new(); if let Some(prompt) = prompt { - parts.push(Part::Text(prompt.0)); + parts.push(Part::Text(Cow::Owned(prompt.0))); } if let Some(message_image) =