Skip to content

Commit

Permalink
Add SDF endpoints to update prompts on the server
Browse files Browse the repository at this point in the history
  • Loading branch information
jkeiser committed Nov 26, 2024
1 parent f027046 commit 9c64f82
Show file tree
Hide file tree
Showing 11 changed files with 406 additions and 45 deletions.
33 changes: 30 additions & 3 deletions lib/asset-sprayer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@
while_true
)]

use std::path::PathBuf;
use std::{borrow::Cow, path::PathBuf};

use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest};
use config::AssetSprayerConfig;
use prompt::Prompt;
use telemetry::prelude::*;
use thiserror::Error;

pub mod config;
pub mod prompt;
pub use prompt::Prompt;

#[remain::sorted]
#[derive(Debug, Error)]
Expand Down Expand Up @@ -77,7 +77,7 @@ impl AssetSprayer {
}

pub async fn prompt(&self, prompt: &Prompt) -> Result<CreateChatCompletionRequest> {
prompt.prompt(&self.prompts_dir).await
prompt.prompt(self).await
}

pub async fn run(&self, prompt: &Prompt) -> Result<String> {
Expand All @@ -95,6 +95,33 @@ impl AssetSprayer {
.ok_or(AssetSprayerError::EmptyChoice)?;
Ok(text)
}

pub async fn raw_prompt(
&self,
prompt: &(impl RawPromptYamlSource + std::fmt::Display),
) -> Result<CreateChatCompletionRequest> {
Ok(serde_yaml::from_str(&self.raw_prompt_yaml(prompt).await?)?)
}

pub async fn raw_prompt_yaml(
&self,
prompt: &(impl RawPromptYamlSource + std::fmt::Display),
) -> Result<Cow<'static, str>> {
if let Some(ref prompts_dir) = self.prompts_dir {
// Read from disk if prompts_dir is available (faster dev cycle)
let path = prompts_dir.join(prompt.raw_prompt_yaml_relative_path());
info!("Loading prompt for {} from disk at {:?}", prompt, path);
Ok(tokio::fs::read_to_string(path).await?.into())
} else {
info!("Loading embedded prompt for {}", prompt);
Ok(prompt.raw_prompt_yaml_embedded().into())
}
}
}

pub trait RawPromptYamlSource {
fn raw_prompt_yaml_relative_path(&self) -> &'static str;
fn raw_prompt_yaml_embedded(&self) -> &'static str;
}

#[ignore = "You must have OPENAI_API_KEY set to run this test"]
Expand Down
66 changes: 24 additions & 42 deletions lib/asset-sprayer/src/prompt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::{borrow::Cow, path::PathBuf};

use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
Expand All @@ -8,7 +6,7 @@ use async_openai::types::{
use serde::{Deserialize, Serialize};
use telemetry::prelude::*;

use crate::{AssetSprayerError, Result};
use crate::{AssetSprayer, AssetSprayerError, RawPromptYamlSource, Result};

#[derive(
Debug, Clone, PartialEq, Eq, Serialize, Deserialize, strum::Display, strum::EnumDiscriminants,
Expand All @@ -25,8 +23,11 @@ pub enum Prompt {
Eq,
Serialize,
Deserialize,
strum::AsRefStr,
strum::Display,
strum::EnumIter,
strum::EnumString,
strum::IntoStaticStr,
strum::VariantNames,
)]
pub enum AwsCliCommandPromptKind {
Expand All @@ -47,9 +48,9 @@ pub struct AwsCliCommand(pub String, pub String);
impl Prompt {
pub async fn prompt(
&self,
prompts_dir: &Option<PathBuf>,
asset_sprayer: &AssetSprayer,
) -> Result<CreateChatCompletionRequest> {
let raw_prompt = self.raw_prompt(prompts_dir).await?;
let raw_prompt = asset_sprayer.raw_prompt(self).await?;
self.replace_prompt(raw_prompt).await
}

Expand Down Expand Up @@ -131,43 +132,38 @@ impl Prompt {
Ok(response.error_for_status()?.text().await?)
}
}
}

pub async fn raw_prompt(
&self,
prompts_dir: &Option<PathBuf>,
) -> Result<CreateChatCompletionRequest> {
Ok(serde_yaml::from_str(
&self.raw_prompt_yaml(prompts_dir).await?,
)?)
impl AwsCliCommand {
pub fn new(command: impl Into<String>, subcommand: impl Into<String>) -> Self {
Self(command.into(), subcommand.into())
}

async fn raw_prompt_yaml(&self, prompts_dir: &Option<PathBuf>) -> Result<Cow<'static, str>> {
if let Some(ref prompts_dir) = prompts_dir {
// Read from disk if prompts_dir is available (faster dev cycle)
let path = prompts_dir.join(self.raw_prompt_yaml_relative_path());
info!("Loading prompt for {} from disk at {:?}", self, path);
Ok(tokio::fs::read_to_string(path).await?.into())
} else {
info!("Loading embedded prompt for {}", self);
Ok(self.raw_prompt_yaml_embedded().into())
}
pub fn command(&self) -> &str {
&self.0
}

pub fn subcommand(&self) -> &str {
&self.1
}
}

fn raw_prompt_yaml_relative_path(&self) -> &str {
impl RawPromptYamlSource for Prompt {
fn raw_prompt_yaml_relative_path(&self) -> &'static str {
match self {
Self::AwsCliCommandPrompt(kind, _) => kind.yaml_relative_path(),
Self::AwsCliCommandPrompt(kind, _) => kind.raw_prompt_yaml_relative_path(),
}
}

fn raw_prompt_yaml_embedded(&self) -> &'static str {
match self {
Self::AwsCliCommandPrompt(kind, _) => kind.yaml_embedded(),
Self::AwsCliCommandPrompt(kind, _) => kind.raw_prompt_yaml_embedded(),
}
}
}

impl AwsCliCommandPromptKind {
const fn yaml_relative_path(&self) -> &'static str {
impl RawPromptYamlSource for AwsCliCommandPromptKind {
fn raw_prompt_yaml_relative_path(&self) -> &'static str {
match self {
Self::AssetSchema => "aws/asset_schema.yaml",
Self::CreateAction => "aws/create_action.yaml",
Expand All @@ -177,7 +173,7 @@ impl AwsCliCommandPromptKind {
}
}

fn yaml_embedded(&self) -> &'static str {
fn raw_prompt_yaml_embedded(&self) -> &'static str {
match self {
Self::AssetSchema => include_str!("../prompts/aws/asset_schema.yaml"),
Self::CreateAction => include_str!("../prompts/aws/create_action.yaml"),
Expand All @@ -187,17 +183,3 @@ impl AwsCliCommandPromptKind {
}
}
}

impl AwsCliCommand {
pub fn new(command: impl Into<String>, subcommand: impl Into<String>) -> Self {
Self(command.into(), subcommand.into())
}

pub fn command(&self) -> &str {
&self.0
}

pub fn subcommand(&self) -> &str {
&self.1
}
}
1 change: 1 addition & 0 deletions lib/dal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub mod label_list;
pub mod layer_db_types;
pub mod management;
pub mod module;
pub mod prompt_override;
pub mod pkg;
pub mod prop;
pub mod property_editor;
Expand Down
5 changes: 5 additions & 0 deletions lib/dal/src/migrations/U3402__prompt_overrides.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE prompt_overrides
(
kind VARCHAR(255) NOT NULL PRIMARY KEY,
prompt_yaml TEXT NOT NULL
);
127 changes: 127 additions & 0 deletions lib/dal/src/prompt_override.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use crate::{DalContext, WsEvent, WsEventResult, WsPayload};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use telemetry::prelude::*;
use thiserror::Error;

#[remain::sorted]
#[derive(Error, Debug)]
pub enum PromptOverrideError {
#[error("pg error: {0}")]
Pg(#[from] si_data_pg::PgError),
#[error("transactions error: {0}")]
Transactions(#[from] crate::TransactionsError),
#[error("ws event error: {0}")]
WsEvent(#[from] crate::WsEventError),
}

pub type Result<T> = std::result::Result<T, PromptOverrideError>;

#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, Eq)]
#[serde(rename_all = "camelCase")]
pub struct PromptUpdatedPayload {
pub kind: String,
pub overridden: bool,
}

impl WsEvent {
pub async fn prompt_updated(
ctx: &DalContext,
kind: String,
overridden: bool,
) -> WsEventResult<Self> {
WsEvent::new(
ctx,
WsPayload::PromptUpdated(PromptUpdatedPayload { kind, overridden }),
)
.await
}
}

pub struct PromptOverride;

impl PromptOverride {
pub async fn list(ctx: &DalContext) -> Result<HashSet<String>> {
let rows = ctx
.txns()
.await?
.pg()
.query(
"
SELECT kind FROM prompt_overrides
",
&[],
)
.await?;
let mut result = HashSet::with_capacity(rows.len());
for row in rows {
result.insert(row.try_get(0)?);
}
Ok(result)
}

pub async fn get_opt(ctx: &DalContext, kind: &str) -> Result<Option<String>> {
match ctx
.txns()
.await?
.pg()
.query_opt(
"
SELECT prompt_yaml FROM prompt_overrides WHERE kind = $1
",
&[&kind],
)
.await?
{
Some(row) => Ok(Some(row.try_get(0)?)),
None => Ok(None),
}
}

pub async fn set(ctx: &DalContext, kind: &str, prompt: &str) -> Result<()> {
ctx.txns()
.await?
.pg()
.execute(
"
INSERT INTO prompt_overrides
(kind, prompt_yaml)
VALUES
($1, $2)
ON CONFLICT (kind) DO
UPDATE SET prompt_yaml = $2
",
&[&kind, &prompt],
)
.await?;

WsEvent::prompt_updated(ctx, kind.to_owned(), true)
.await?
.publish_immediately(ctx)
.await?;
Ok(())
}

pub async fn reset(ctx: &DalContext, kind: &str) -> Result<bool> {
let deleted = ctx
.txns()
.await?
.pg()
.execute(
"
DELETE FROM prompt_overrides WHERE kind = $1
",
&[&kind],
)
.await?;
if deleted > 0 {
WsEvent::prompt_updated(ctx, kind.to_owned(), false)
.await?
.publish_immediately(ctx)
.await?;
Ok(true)
} else {
Ok(false)
}
}
}
2 changes: 2 additions & 0 deletions lib/dal/src/ws_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::management::prototype::ManagementFuncExecutedPayload;
use crate::pkg::{
ImportWorkspaceVotePayload, WorkspaceActorPayload, WorkspaceImportApprovalActorPayload,
};
use crate::prompt_override::PromptUpdatedPayload;
use crate::qualification::QualificationCheckPayload;
use crate::schema::variant::{
SchemaVariantClonedPayload, SchemaVariantDeletedPayload, SchemaVariantReplacedPayload,
Expand Down Expand Up @@ -110,6 +111,7 @@ pub enum WsPayload {
ManagementFuncExecuted(ManagementFuncExecutedPayload),
ModuleImported(Vec<si_frontend_types::SchemaVariant>),
Online(OnlinePayload),
PromptUpdated(PromptUpdatedPayload),
ResourceRefreshed(ComponentUpdatedPayload),
SchemaVariantCloned(SchemaVariantClonedPayload),
SchemaVariantCreated(frontend_types::SchemaVariant),
Expand Down
1 change: 1 addition & 0 deletions lib/dal/tests/integration_test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod management;
mod module;
mod node_weight;
mod pkg;
mod prompt_overrides;
mod prop;
mod property_editor;
mod qualifications;
Expand Down
Loading

0 comments on commit 9c64f82

Please sign in to comment.