Skip to content

Commit

Permalink
send message reply chains in /gemini and /llama3
Browse files Browse the repository at this point in the history
  • Loading branch information
jelni committed Dec 29, 2024
1 parent b5a1be6 commit 843f93e
Show file tree
Hide file tree
Showing 14 changed files with 736 additions and 379 deletions.
709 changes: 446 additions & 263 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ edition = "2021"
[lints.clippy]
pedantic = "warn"
nursery = "warn"
allow_attributes = "warn"

[dependencies]
async-signal = "0.2"
async-trait = "0.1"
base64 = "0.22"
bytes = "1.8"
bytes = "1.9"
charname = "1.15"
colored = "2.0"
colored = "2.2"
counter = "0.6"
dotenvy = "0.15"
futures-util = "0.3"
Expand All @@ -28,7 +29,7 @@ rmp-serde = "1.1"
serde = "1.0"
serde_json = "1.0"
tdlib = { git = "https://github.com/jelni/tdlib-rs-latest" }
tempfile = "3.13"
tempfile = "3.14"
time = { version = "0.3", features = ["macros", "serde", "serde-well-known"] }
tokio = { version = "1.41", features = ["macros", "rt-multi-thread", "signal", "time"] }
tokio = { version = "1.42", features = ["macros", "rt-multi-thread", "signal", "time"] }
url = "2.5"
2 changes: 1 addition & 1 deletion src/apis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ pub mod craiyon;
pub mod different_dimension_me;
pub mod fal;
pub mod google;
pub mod google_aistudio;
pub mod kiwifarms;
pub mod makersuite;
pub mod mathjs;
pub mod microlink;
pub mod moveit;
Expand Down
32 changes: 18 additions & 14 deletions src/apis/makersuite.rs → src/apis/google_aistudio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub struct Error {
}

pub async fn upload_file(
http_client: reqwest::Client,
http_client: &reqwest::Client,
file: tokio::fs::File,
size: u64,
mime_type: &str,
Expand Down Expand Up @@ -121,25 +121,26 @@ pub async fn upload_file(
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentRequest<'a> {
contents: &'a [Content<'a>],
contents: Cow<'a, [Content<'a>]>,
safety_settings: &'static [SafetySetting],
system_instruction: Option<Content<'a>>,
generation_config: GenerationConfig,
}

#[derive(Serialize)]
struct Content<'a> {
parts: &'a [Part<'a>],
#[derive(Clone, Serialize)]
pub struct Content<'a> {
pub parts: Cow<'a, [Part<'a>]>,
pub role: Option<&'static str>,
}

#[derive(Serialize)]
#[derive(Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Part<'a> {
Text(Cow<'a, str>),
FileData(FileData),
}

#[derive(Serialize)]
#[derive(Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FileData {
pub file_uri: String,
Expand Down Expand Up @@ -217,12 +218,12 @@ impl fmt::Display for Error {
}
}

pub async fn stream_generate_content(
pub async fn stream_generate_content<'a>(
http_client: reqwest::Client,
tx: mpsc::UnboundedSender<Result<GenerateContentResponse, GenerationError>>,
model: &str,
parts: &[Part<'_>],
system_instruction: Option<&[Part<'_>]>,
contents: Cow<'a, [Content<'a>]>,
system_instruction: Option<Content<'a>>,
max_output_tokens: u16,
) {
let url = format!(
Expand All @@ -235,8 +236,9 @@ pub async fn stream_generate_content(
.unwrap(),
)
.json(&GenerateContentRequest {
contents: &[Content { parts }],
contents,
safety_settings: &[
SafetySetting { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
SafetySetting { category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
SafetySetting {
category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Expand All @@ -246,10 +248,12 @@ pub async fn stream_generate_content(
category: "HARM_CATEGORY_DANGEROUS_CONTENT",
threshold: "BLOCK_NONE",
},
SafetySetting { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
SafetySetting {
category: "HARM_CATEGORY_CIVIC_INTEGRITY",
threshold: "BLOCK_NONE",
},
],
system_instruction: system_instruction
.map(|system_instruction| Content { parts: system_instruction }),
system_instruction,
generation_config: GenerationConfig { max_output_tokens },
})
.send()
Expand Down
4 changes: 3 additions & 1 deletion src/apis/openai.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use reqwest::StatusCode;
use serde::{Deserialize, Serialize};

Expand All @@ -14,7 +16,7 @@ struct Request<'a> {
#[derive(Serialize)]
pub struct Message<'a> {
pub role: &'static str,
pub content: &'a str,
pub content: Cow<'a, str>,
}

#[derive(Deserialize)]
Expand Down
7 changes: 6 additions & 1 deletion src/apis/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ pub async fn multiple(

let translations = match source_language {
Some(_) => response.json::<Vec<String>>().await?,
None => response.json::<Vec<(String, String)>>().await?.into_iter().map(|t| t.0).collect(),
None => response
.json::<Vec<(String, String)>>()
.await?
.into_iter()
.map(|translation| translation.0)
.collect(),
};

Ok(translations)
Expand Down
4 changes: 2 additions & 2 deletions src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use async_trait::async_trait;
use reqwest::StatusCode;
use tdlib::types::FormattedText;

use crate::apis::makersuite::GenerationError;
use crate::apis::google_aistudio::GenerationError;
use crate::bot::TdError;
use crate::utilities::api_utils::ServerError;
use crate::utilities::command_context::CommandContext;
Expand All @@ -22,10 +22,10 @@ pub mod delete;
pub mod dice_reply;
pub mod different_dimension_me;
pub mod fal;
pub mod gemini;
pub mod groq;
pub mod kebab;
pub mod kiwifarms;
pub mod makersuite;
pub mod markov_chain;
pub mod mevo;
pub mod moveit_joke;
Expand Down
119 changes: 73 additions & 46 deletions src/commands/makersuite.rs → src/commands/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,26 @@ use tokio::sync::mpsc;
use tokio::time::Instant;

use super::{CommandError, CommandResult, CommandTrait};
use crate::apis::makersuite::{
self, Candidate, CitationSource, FileData, GenerateContentResponse, Part, PartResponse,
use crate::apis::google_aistudio::{
self, Candidate, CitationSource, Content, FileData, GenerateContentResponse, Part, PartResponse,
};
use crate::utilities::command_context::CommandContext;
use crate::utilities::convert_argument::{ConvertArgument, StringGreedyOrReply};
use crate::utilities::convert_argument::{ConvertArgument, ReplyChain};
use crate::utilities::file_download::MEBIBYTE;
use crate::utilities::rate_limit::RateLimiter;
use crate::utilities::telegram_utils;

const SYSTEM_INSTRUCTION: &str =
pub const SYSTEM_INSTRUCTION: &str =
"Be concise and precise. Don't be verbose. Answer in the language the user wrote in.";

pub struct GoogleGemini {
pub struct Gemini {
command_names: &'static [&'static str],
description: &'static str,
model: &'static str,
}

impl GoogleGemini {
impl Gemini {
#[expect(clippy::self_named_constructors)]
pub const fn gemini() -> Self {
Self {
command_names: &["gemini", "g"],
Expand All @@ -47,7 +48,7 @@ impl GoogleGemini {
}

#[async_trait]
impl CommandTrait for GoogleGemini {
impl CommandTrait for Gemini {
fn command_names(&self) -> &[&str] {
self.command_names
}
Expand All @@ -57,81 +58,107 @@ impl CommandTrait for GoogleGemini {
}

fn rate_limit(&self) -> RateLimiter<i64> {
RateLimiter::new(3, 45)
RateLimiter::new(6, 60)
}

#[expect(clippy::too_many_lines)]
async fn execute(&self, ctx: &CommandContext, arguments: String) -> CommandResult {
let prompt = Option::<StringGreedyOrReply>::convert(ctx, &arguments).await?.0;
let ReplyChain(messages) = ConvertArgument::convert(ctx, &arguments).await?.0;
ctx.send_typing().await?;

let (model, system_instruction, parts) = if let Some(message_image) =
telegram_utils::get_message_or_reply_attachment(&ctx.message, true, ctx.client_id)
.await?
{
let file = message_image.file()?;
let mut contents = Vec::new();

for message in messages {
let mut parts = Vec::new();

if file.size > 64 * MEBIBYTE {
return Err(CommandError::Custom("the file cannot be larger than 64 MiB.".into()));
if let Some(text) = message.text {
parts.push(Part::Text(Cow::Owned(text)));
}

let File::File(file) =
functions::download_file(file.id, 1, 0, 0, true, ctx.client_id).await?;
if let Some(message_image) = telegram_utils::get_message_attachment(
Cow::Owned(message.content),
true,
ctx.client_id,
)
.await?
{
let file = message_image.file()?;

let open_file = tokio::fs::File::open(file.local.path).await.unwrap();
if file.size > 64 * MEBIBYTE {
return Err(CommandError::Custom("files cannot be larger than 64 MiB.".into()));
}

let file = makersuite::upload_file(
ctx.bot_state.http_client.clone(),
open_file,
file.size.try_into().unwrap(),
message_image.mime_type()?,
)
.await?;
let File::File(file) =
functions::download_file(file.id, 1, 0, 0, true, ctx.client_id).await?;

let mut parts = if let Some(prompt) = prompt {
vec![Part::Text(Cow::Owned(prompt.0))]
} else {
vec![Part::Text(Cow::Borrowed("Comment briefly on what you see."))]
};
let open_file = tokio::fs::File::open(file.local.path).await.unwrap();

parts.push(Part::FileData(FileData { file_uri: file.uri }));
let file = google_aistudio::upload_file(
&ctx.bot_state.http_client,
open_file,
file.size.try_into().unwrap(),
message_image.mime_type()?,
)
.await?;

(self.model, Some([Part::Text(Cow::Borrowed(SYSTEM_INSTRUCTION))].as_slice()), parts)
} else {
let mut parts = vec![Part::Text(Cow::Borrowed(SYSTEM_INSTRUCTION))];
parts.push(Part::FileData(FileData { file_uri: file.uri }));
}

if let Some(prompt) = prompt {
parts.push(Part::Text(Cow::Owned(prompt.0)));
} else {
return Err(CommandError::Custom("no prompt or file provided.".into()));
if !parts.is_empty() {
contents.push(Content {
parts: Cow::Owned(parts),
role: Some(if message.my { "model" } else { "user" }),
});
}
}

if contents.is_empty() {
return Err(CommandError::Custom("no prompt or file provided.".into()));
}

(self.model, None, parts)
let system_instruction = if contents
.iter()
.any(|content| content.parts.iter().any(|part| matches!(part, Part::Text(..))))
{
Some(Content {
parts: Cow::Borrowed([Part::Text(Cow::Borrowed(SYSTEM_INSTRUCTION))].as_slice()),
role: None,
})
} else {
contents.push(Content {
parts: Cow::Borrowed(
[Part::Text(Cow::Borrowed("Comment briefly on what you see."))].as_slice(),
),
role: Some("user"),
});

None
};

let http_client = ctx.bot_state.http_client.clone();
let (tx, mut rx) = mpsc::unbounded_channel();
let model = self.model;

tokio::spawn(async move {
makersuite::stream_generate_content(
google_aistudio::stream_generate_content(
http_client,
tx,
model,
&parts,
Cow::Owned(contents),
system_instruction,
512,
)
.await;
});

let mut next_update = Instant::now() + Duration::from_secs(5);
let mut last_update = Instant::now();
let mut changed_after_last_update = false;
let mut progress = Option::<GenerationProgress>::None;
let mut message = Option::<Message>::None;

loop {
let (update_message, finished) = if let Ok(response) =
tokio::time::timeout_at(next_update, rx.recv()).await
tokio::time::timeout_at(last_update + Duration::from_secs(5), rx.recv()).await
{
match response {
Some(response) => {
Expand All @@ -155,7 +182,7 @@ impl CommandTrait for GoogleGemini {
None => (true, true),
}
} else {
next_update = Instant::now() + Duration::from_secs(5);
last_update = Instant::now();
(true, false)
};

Expand Down Expand Up @@ -183,7 +210,7 @@ impl CommandTrait for GoogleGemini {
);
}

next_update = Instant::now() + Duration::from_secs(5);
last_update = Instant::now();
changed_after_last_update = false;
}

Expand Down
Loading

0 comments on commit 843f93e

Please sign in to comment.