From 6a7f2cdb87a1a77cb7c9516e0e7c894bdf7b2dcb Mon Sep 17 00:00:00 2001 From: Julian Popescu Date: Tue, 10 Dec 2024 01:16:31 +0100 Subject: [PATCH] feat: add organization prompt for template download --- Cargo.lock | 10 +++ Cargo.toml | 2 +- src/commands/template.rs | 92 +++++++++++++++++++- src/graphql/get_viewer_organizations.graphql | 13 +++ 4 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 src/graphql/get_viewer_organizations.graphql diff --git a/Cargo.lock b/Cargo.lock index adcd7cd..55c3b0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -763,6 +763,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" dependencies = [ "console", + "fuzzy-matcher", "shell-words", "tempfile", "thiserror", @@ -1017,6 +1018,15 @@ dependencies = [ "slab", ] +[[package]] +name = "fuzzy-matcher" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54614a3312934d066701a80f20f15fa3b56d67ac7722b39eea5b4c9dd1d66c94" +dependencies = [ + "thread_local", +] + [[package]] name = "generic-array" version = "0.14.7" diff --git a/Cargo.toml b/Cargo.toml index 641d1eb..f5f5248 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ base32 = "0.5" base64 = "0.22" chrono = { version = "0.4", features = ["serde"] } clap = { version = "4.4", features = ["derive", "cargo", "color", "env"] } -dialoguer = "0.11.0" +dialoguer = { version = "0.11.0", features = ["fuzzy-select"] } dirs = "5.0" dunce = "1.0" fs4 = { version = "0.8", features = ["tokio"] } diff --git a/src/commands/template.rs b/src/commands/template.rs index c865e55..eafeec6 100644 --- a/src/commands/template.rs +++ b/src/commands/template.rs @@ -1,9 +1,11 @@ use crate::{ + colors::ColorChoiceExt, commands::{ install::{install, Install}, login::check_login, GlobalArgs, }, + dirs::pyproject_path, download::download_archive, error::{self, Result}, git::init_repository, @@ -25,6 +27,14 @@ use url::Url; )] pub struct GetCompetitionTemplate; +#[derive(GraphQLQuery)] +#[graphql( + query_path = "src/graphql/get_viewer_organizations.graphql", + schema_path = "src/graphql/schema.graphql", + response_derives = "Debug" +)] +pub struct GetViewerOrganizations; + #[derive(Args, Debug, Serialize)] #[command(author, version, about)] pub struct Template { @@ -36,7 +46,7 @@ pub struct Template { pub async fn template(args: Template, global: GlobalArgs) -> Result<()> { let m = MultiProgress::new(); - check_login(global.clone(), &m).await?; + let logged_in = check_login(global.clone(), &m).await?; let client = GraphQLClient::new(global.url.parse()?).await?; @@ -101,6 +111,44 @@ pub async fn template(args: Template, global: GlobalArgs) -> Result<()> { })? .download_url; + let organizations = if logged_in { + client + .send::(get_viewer_organizations::Variables {}) + .await? + .viewer + .organizations + .nodes + .into_iter() + .map(|node| node.organization) + .collect::>() + } else { + vec![] + }; + + let organization = if !organizations.is_empty() { + m.suspend(|| -> Result<_> { + let items = organizations + .iter() + .map(|org| format!("@{} ({})", org.username.clone(), org.display_name.clone())) + .collect::>(); + Result::Ok( + dialoguer::FuzzySelect::with_theme(global.color.dialoguer().as_ref()) + .with_prompt("Would you like to submit with a team? (Press ESC to skip)") + .items(&items) + .interact_opt() + .map_err(|err| { + error::system( + &format!("Could not select organization: {err}"), + "Please try again", + ) + })? + .and_then(|index| organizations.into_iter().nth(index)), + ) + })? + } else { + None + }; + pb.set_message("Downloading competition template..."); match download_archive(download_url, &destination, &pb).await { Ok(_) => { @@ -123,6 +171,48 @@ pub async fn template(args: Template, global: GlobalArgs) -> Result<()> { } } + if let Some(organization) = organization { + let toml_path = pyproject_path(&destination); + let mut doc = tokio::fs::read_to_string(&toml_path) + .await + .map_err(|err| { + error::system( + &format!("Failed to read {}: {err}", toml_path.display()), + "Contact the competition organizer", + ) + })? + .parse::() + .map_err(|err| { + error::system( + &format!("Failed to parse {}: {err}", toml_path.display()), + "Contact the competition organizer", + ) + })?; + let aqora_config = doc + .get_mut("tool") + .and_then(|tool| tool.as_table_mut()) + .and_then(|tool| tool.get_mut("aqora")) + .and_then(|aqora| aqora.as_table_mut()) + .ok_or_else(|| { + error::system( + &format!( + "Failed to parse {}: Could not find tool.aqora", + toml_path.display() + ), + "Contact the competition organizer", + ) + })?; + aqora_config["entity"] = toml_edit::value(organization.username.clone()); + tokio::fs::write(&toml_path, doc.to_string()) + .await + .map_err(|err| { + error::system( + &format!("Failed to write {}: {err}", toml_path.display()), + "Check the permissions of the file", + ) + })?; + } + if !args.no_install { let install_global = GlobalArgs { project: destination.clone(), diff --git a/src/graphql/get_viewer_organizations.graphql b/src/graphql/get_viewer_organizations.graphql new file mode 100644 index 0000000..bbefcfe --- /dev/null +++ b/src/graphql/get_viewer_organizations.graphql @@ -0,0 +1,13 @@ +query GetViewerOrganizations { + viewer { + organizations { + nodes { + organization { + id + username + displayName + } + } + } + } +}