From 7cdd732ae86e729bdaafc31c9e07cd67702617ef Mon Sep 17 00:00:00 2001 From: Julian Popescu Date: Mon, 11 Mar 2024 17:40:23 +0100 Subject: [PATCH] add ability to use defaults for evaluate --- Cargo.lock | 6 +- runner/src/pipeline.rs | 79 ++--- src/commands/test.rs | 25 +- src/commands/upload.rs | 328 ++++++++++++------ src/dirs.rs | 4 + src/graphql/get_entity_id_by_username.graphql | 6 - src/graphql/get_viewer_id.graphql | 5 - src/graphql/schema.graphql | 6 + src/graphql/submission_upload_info.graphql | 17 + src/python.rs | 15 +- 10 files changed, 309 insertions(+), 182 deletions(-) delete mode 100644 src/graphql/get_entity_id_by_username.graphql delete mode 100644 src/graphql/get_viewer_id.graphql create mode 100644 src/graphql/submission_upload_info.graphql diff --git a/Cargo.lock b/Cargo.lock index 88e3105..e9825dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,7 +91,7 @@ dependencies = [ [[package]] name = "aqora" -version = "0.1.1" +version = "0.1.2" dependencies = [ "aqora-config", "aqora-runner", @@ -135,7 +135,7 @@ dependencies = [ [[package]] name = "aqora-config" -version = "0.1.0" +version = "0.1.1" dependencies = [ "pep440_rs", "pyproject-toml", @@ -146,7 +146,7 @@ dependencies = [ [[package]] name = "aqora-runner" -version = "0.1.0" +version = "0.1.1" dependencies = [ "aqora-config", "futures", diff --git a/runner/src/pipeline.rs b/runner/src/pipeline.rs index 29fc3bd..5a9d430 100644 --- a/runner/src/pipeline.rs +++ b/runner/src/pipeline.rs @@ -11,7 +11,7 @@ use pyo3::{ }; use serde::{Deserialize, Serialize}; use split_stream_by::{Either, SplitStreamByMapExt}; -use std::{collections::HashMap, path::PathBuf, sync::Arc}; +use std::{collections::HashMap, path::PathBuf}; use thiserror::Error; #[derive(Debug, Clone)] @@ -38,6 +38,7 @@ impl LayerFunction { inspect .getattr(intern!(py, "signature"))? .call1((func,))? + .getattr(intern!(py, "parameters"))? .call_method0(intern!(py, "values"))?, )?; for parameter in parameters { @@ -63,6 +64,7 @@ impl LayerFunction { takes_context_kwarg = true; } } + Ok(Self { func: func.to_object(py), takes_input_arg, @@ -182,29 +184,36 @@ impl Layer { input: &PyObject, original_input: &PyObject, context: &PyObject, + default: Option<&LayerEvaluation>, ) -> PyResult { let context = if let Some(context_transform) = self.context.as_ref() { context_transform .call(input, original_input, context) .await? + } else if let Some(default) = default.as_ref().map(|default| &default.context) { + default.clone() } else { - input.clone() + context.clone() }; + let transform = if let Some(transform) = self.transform.as_ref() { transform.call(input, original_input, &context).await? + } else if let Some(default) = default.as_ref().map(|default| &default.transform) { + default.clone() } else { input.clone() }; + let metric = if let Some(metric) = self.metric.as_ref() { Some(metric.call(&transform, original_input, &context).await?) } else { - None + default.as_ref().and_then(|default| default.metric.clone()) }; let branch = if let Some(branch) = self.branch.as_ref() { Some(branch.call(&transform, original_input, &context).await?) } else { - None + default.as_ref().and_then(|default| default.branch.clone()) }; Ok(LayerEvaluation { @@ -214,10 +223,6 @@ impl Layer { branch, }) } - - pub async fn assert_metric(&self, evaluation: &LayerEvaluation) -> PyResult<()> { - todo!() - } } #[derive(Clone, Debug)] @@ -248,7 +253,6 @@ impl PipelineConfig { #[derive(Clone, Debug)] pub struct Evaluator { - config: PipelineConfig, layers: Vec, } @@ -265,12 +269,21 @@ pub enum EvaluationError { ), #[error("Layer not found: {0}")] LayerNotFound(String), + #[error("{0}")] + Custom(String), +} + +impl EvaluationError { + pub fn custom(err: impl ToString) -> Self { + Self::Custom(err.to_string()) + } } impl Evaluator { pub async fn evaluate( &self, mut input: PyObject, + defaults: Option<&EvaluationResult>, ) -> Result { let mut out = EvaluationResult::new(); macro_rules! try_or_bail { @@ -286,7 +299,14 @@ impl Evaluator { let mut layer_index = 0; while layer_index < self.layers.len() { let layer = &self.layers[layer_index]; - let result = try_or_bail!(layer.evaluate(&input, &original_input, &context).await); + let default = defaults + .and_then(|defaults| defaults.get(&layer.name)) + .and_then(|defaults| defaults.get(out.get(&layer.name).map_or(0, |v| v.len()))); + let result = try_or_bail!( + layer + .evaluate(&input, &original_input, &context, default) + .await + ); if let Some(branch) = try_or_bail!(result.branch_str()) { layer_index = try_or_bail!(self .layers @@ -302,44 +322,6 @@ impl Evaluator { } Ok(out) } - - pub async fn assert_metric(&self, results: &EvaluationResult) -> Result<(), EvaluationError> { - let mut result_indexes = HashMap::::new(); - let mut layer_index = 0; - while layer_index < self.layers.len() { - let layer = &self.layers[layer_index]; - let layer_name = &layer.name; - let result_index = *result_indexes.entry(layer_name.clone()).or_insert(0); - let result = results - .get(layer_name) - .ok_or_else(|| EvaluationError::LayerNotFound(layer_name.clone()))? - .get(result_index) - .ok_or_else(|| EvaluationError::LayerNotFound(layer_name.clone()))?; - layer.assert_metric(result).await?; - result_indexes.insert(layer_name.clone(), result_index + 1); - if let Some(branch) = result.branch_str()? { - layer_index = self - .layers - .iter() - .position(|layer| layer.name == branch) - .ok_or_else(|| EvaluationError::LayerNotFound(branch))?; - } else { - layer_index += 1; - } - } - Ok(()) - } - - pub fn evaluate_all( - self, - inputs: impl Stream>, - ) -> impl Stream> { - let this = Arc::new(self); - inputs - .map_err(|err| (EvaluationResult::new(), EvaluationError::Python(err))) - .map_ok(move |input| (input, this.clone())) - .and_then(|(input, evaluator)| async move { evaluator.evaluate(input).await }) - } } pub struct Pipeline { @@ -415,7 +397,6 @@ impl Pipeline { pub fn evaluator(&self) -> Evaluator { Evaluator { layers: self.layers.clone(), - config: self.config.clone(), } } diff --git a/src/commands/test.rs b/src/commands/test.rs index 7704170..2d71eb6 100644 --- a/src/commands/test.rs +++ b/src/commands/test.rs @@ -1,12 +1,13 @@ use crate::{ commands::GlobalArgs, dirs::{ - init_venv, project_data_dir, project_last_run_dir, project_use_case_toml_path, - read_pyproject, + init_venv, project_data_dir, project_last_run_dir, project_last_run_result, + project_use_case_toml_path, read_pyproject, }, error::{self, Result}, + python::LastRunResult, }; -use aqora_config::{PyProject, Version}; +use aqora_config::PyProject; use aqora_runner::pipeline::{ EvaluateAllInfo, EvaluateInputInfo, EvaluationError, EvaluationResult, Evaluator, Pipeline, PipelineConfig, @@ -17,7 +18,6 @@ use indicatif::{MultiProgress, ProgressBar}; use owo_colors::{OwoColorize, Stream as OwoStream, Style}; use pyo3::prelude::*; use pyo3::{exceptions::PyException, Python}; -use serde::{Deserialize, Serialize}; use std::{ path::Path, pin::Pin, @@ -31,14 +31,6 @@ pub struct Test { pub test: Vec, } -#[derive(Serialize, Deserialize)] -pub struct LastRunResult { - #[serde(flatten)] - pub info: EvaluateAllInfo, - pub use_case_version: Option, - pub submission_version: Option, -} - fn evaluate( evaluator: Evaluator, inputs: impl Stream)>, @@ -50,7 +42,7 @@ fn evaluate( .map(move |input| (input, evaluator.clone())) .then(|((index, result), evaluator)| async move { match result { - Ok(input) => match evaluator.evaluate(input.clone()).await { + Ok(input) => match evaluator.evaluate(input.clone(), None).await { Ok(result) => ( index, EvaluateInputInfo { @@ -100,10 +92,7 @@ fn evaluate( .if_supports_color(OwoStream::Stdout, |text| text.red()), err = err )); - return Err(( - item.result, - EvaluationError::Python(PyException::new_err(err)), - )); + return Err((item.result, EvaluationError::custom(err))); } let is_ok = item.error.is_none(); @@ -202,7 +191,7 @@ pub async fn test_submission(args: Test, global: GlobalArgs, project: PyProject) })?; let last_run_dir = project_last_run_dir(&global.project); - let last_run_result_file = last_run_dir.join("result.msgpack"); + let last_run_result_file = project_last_run_result(&global.project); if args.test.is_empty() { if last_run_dir.exists() { tokio::fs::remove_dir_all(&last_run_dir) diff --git a/src/commands/upload.rs b/src/commands/upload.rs index ca33948..27d7d34 100644 --- a/src/commands/upload.rs +++ b/src/commands/upload.rs @@ -1,15 +1,17 @@ use crate::{ commands::GlobalArgs, compress::compress, - dirs::{init_venv, pyproject_path, read_pyproject}, + dirs::{ + init_venv, project_last_run_dir, project_last_run_result, pyproject_path, read_pyproject, + }, error::{self, Result}, - graphql_client::GraphQLClient, + graphql_client::{custom_scalars::*, GraphQLClient}, id::Id, - python::build_package, + python::{build_package, LastRunResult}, readme::read_readme, revert_file::RevertFile, }; -use aqora_config::PyProject; +use aqora_config::{PyProject, Version}; use clap::Args; use futures::prelude::*; use graphql_client::GraphQLQuery; @@ -58,55 +60,84 @@ pub async fn get_competition_id_by_slug( #[derive(GraphQLQuery)] #[graphql( - query_path = "src/graphql/get_viewer_id.graphql", + query_path = "src/graphql/submission_upload_info.graphql", schema_path = "src/graphql/schema.graphql", response_derives = "Debug" )] -pub struct GetViewerId; +pub struct SubmissionUploadInfo; -pub async fn get_viewer_id(client: &GraphQLClient) -> Result { - let viewer = client - .send::(get_viewer_id::Variables {}) - .await? - .viewer; - Id::parse_node_id(viewer.id).map_err(|err| { - error::system( - &format!("Could not parse viewer ID: {}", err), - "This is a bug, please report it", - ) - }) +pub struct SubmissionUploadInfoResponse { + competition_id: Id, + use_case_version: Version, + entity_id: Id, } -#[derive(GraphQLQuery)] -#[graphql( - query_path = "src/graphql/get_entity_id_by_username.graphql", - schema_path = "src/graphql/schema.graphql", - response_derives = "Debug" -)] -pub struct GetEntityIdByUsername; - -pub async fn get_entity_id_by_username( +pub async fn get_submission_upload_info( client: &GraphQLClient, - username: impl Into, -) -> Result { - let username = username.into(); - let entity = client - .send::(get_entity_id_by_username::Variables { - username: username.clone(), + slug: impl Into, + username: Option>, +) -> Result { + let slug = slug.into(); + let username = username.map(|u| u.into()); + let response = client + .send::(submission_upload_info::Variables { + slug: slug.clone(), + username: username.clone().unwrap_or_default(), + use_username: username.is_some(), }) - .await? - .entity_by_username + .await?; + let competition = response.competition_by_slug.ok_or_else(|| { + error::user( + &format!("Competition '{}' not found", slug), + "Please make sure the competition is correct", + ) + })?; + let competition_id = Id::parse_node_id(competition.id).map_err(|err| { + error::system( + &format!("Could not parse competition ID: {}", err), + "This is a bug, please report it", + ) + })?; + let use_case_version = competition + .use_case + .latest .ok_or_else(|| { - error::user( - &format!("User '{}' not found", username), - "Please make sure the username is correct", + error::system( + "No use case version found", + "Please contact the competition organizer", + ) + })? + .version + .parse() + .map_err(|err| { + error::system( + &format!("Invalid use case version found: {err}"), + "Please contact the competition organizer", ) })?; - Id::parse_node_id(entity.id).map_err(|err| { + let entity_id = if let Some(username) = username { + response + .entity_by_username + .ok_or_else(|| { + error::user( + &format!("User '{}' not found", username), + "Please make sure the username is correct", + ) + })? + .id + } else { + response.viewer.id + }; + let entity_id = Id::parse_node_id(entity_id).map_err(|err| { error::system( &format!("Could not parse entity ID: {}", err), "This is a bug, please report it", ) + })?; + Ok(SubmissionUploadInfoResponse { + competition_id, + use_case_version, + entity_id, }) } @@ -458,6 +489,50 @@ pub async fn upload_submission(args: Upload, global: GlobalArgs, project: PyProj ) })?; + let readme = read_readme( + &global.project, + project.project.as_ref().and_then(|p| p.readme.as_ref()), + ) + .await + .map_err(|err| { + error::user( + &format!("Could not read readme: {}", err), + "Please make sure the readme is valid", + ) + })?; + + let version = project.version().ok_or_else(|| { + error::user( + "Could not get project version", + "Please make sure the project is valid", + ) + })?; + + let evaluation_path = project_last_run_dir(&global.project); + if !evaluation_path.exists() { + return Err(error::user( + "No last run result found", + "Please make sure you have run `aqora test`", + )); + } + let last_run_result: LastRunResult = + std::fs::File::open(project_last_run_result(&global.project)) + .map_err(rmp_serde::decode::Error::InvalidDataRead) + .and_then(rmp_serde::from_read) + .map_err(|err| { + error::user( + &format!("Could not read last run result: {}", err), + "Please make sure your last call to `aqora test` was successful", + ) + })?; + + if last_run_result.submission_version.as_ref() != Some(&version) { + return Err(error::user( + "Submission version does not match last run result", + "Please re-run `aqora test`", + )); + } + let slug = args .competition .as_ref() @@ -468,36 +543,25 @@ pub async fn upload_submission(args: Upload, global: GlobalArgs, project: PyProj "Please specify a competition in either the pyproject.toml or the command line", ) })?; - let competition_id = get_competition_id_by_slug(&client, slug).await?; - let entity_id = if let Some(username) = &config.entity { - get_entity_id_by_username(&client, username).await? - } else { - get_viewer_id(&client).await? - }; - let version = project.version().ok_or_else(|| { - error::user( - "Could not get project version", - "Please make sure the project is valid", - ) - })?; + let SubmissionUploadInfoResponse { + entity_id, + competition_id, + use_case_version, + } = get_submission_upload_info(&client, slug, config.entity.as_ref()).await?; let package_name = format!( "submission-{}-{}", competition_id.to_package_id(), entity_id.to_package_id() ); - let readme = read_readme( - &global.project, - project.project.as_ref().and_then(|p| p.readme.as_ref()), - ) - .await - .map_err(|err| { - error::user( - &format!("Could not read readme: {}", err), - "Please make sure the readme is valid", - ) - })?; + if last_run_result.use_case_version.as_ref() != Some(&use_case_version) { + return Err(error::user( + "Use case version does not match last run result", + "Please install the latest version with `aqora install --upgrade` and re-run `aqora test`", + )); + } + let project_version = client .send::(update_submission_mutation::Variables { competition_id: competition_id.to_node_id(), @@ -512,52 +576,116 @@ pub async fn upload_submission(args: Upload, global: GlobalArgs, project: PyProj let s3_client = reqwest::Client::new(); - let upload_url = if let Some(url) = project_version - .files - .iter() - .find(|f| { - matches!( - f.kind, - update_submission_mutation::ProjectVersionFileKind::PACKAGE - ) + let evaluation_fut = { + let upload_url = if let Some(url) = project_version + .files + .iter() + .find(|f| { + matches!( + f.kind, + update_submission_mutation::ProjectVersionFileKind::SUBMISSION_EVALUATION + ) + }) + .and_then(|f| f.upload_url.as_ref()) + { + url + } else { + return Err(error::system( + "No upload URL found", + "This is a bug, please report it", + )); + }; + let evaluation_tar_file = tempdir + .path() + .join(format!("{package_name}-{version}.evaluation.tar.gz")); + let mut evaluation_pb = ProgressBar::new_spinner().with_message("Compressing evaluation"); + evaluation_pb.enable_steady_tick(std::time::Duration::from_millis(100)); + evaluation_pb = m.add(evaluation_pb); + + let evaluation_pb_cloned = evaluation_pb.clone(); + let client = s3_client.clone(); + async move { + compress(evaluation_path, &evaluation_tar_file) + .await + .map_err(|err| { + error::system( + &format!("Could not compress evaluation: {}", err), + "Please make sure the evaluation directory is valid", + ) + })?; + evaluation_pb_cloned.set_message("Uploading evaluation"); + upload_file(&client, evaluation_tar_file, upload_url, "application/gzip").await + } + .map(move |res| { + if res.is_ok() { + evaluation_pb.finish_with_message("Evaluation uploaded"); + } else { + evaluation_pb.finish_with_message("An error occurred while processing evaluation"); + } + res }) - .and_then(|f| f.upload_url.as_ref()) - { - url - } else { - return Err(error::system( - "No upload URL found", - "This is a bug, please report it", - )); + .boxed() }; - let package_tar_file = tempdir - .path() - .join(format!("{package_name}-{version}.tar.gz")); - let mut package_pb = ProgressBar::new_spinner().with_message("Initializing Python environment"); - package_pb.enable_steady_tick(std::time::Duration::from_millis(100)); - package_pb = m.add(package_pb); - - let env = init_venv( - &global.project, - global.uv.as_ref(), - &package_pb, - global.color, - ) - .await?; - let project_file = RevertFile::save(pyproject_path(&global.project))?; - let mut new_project = project.clone(); - new_project.set_name(package_name); - std::fs::write(&project_file, new_project.toml()?)?; + let package_fut = { + let upload_url = if let Some(url) = project_version + .files + .iter() + .find(|f| { + matches!( + f.kind, + update_submission_mutation::ProjectVersionFileKind::PACKAGE + ) + }) + .and_then(|f| f.upload_url.as_ref()) + { + url + } else { + return Err(error::system( + "No upload URL found", + "This is a bug, please report it", + )); + }; + let package_tar_file = tempdir + .path() + .join(format!("{package_name}-{version}.tar.gz")); + let mut package_pb = ProgressBar::new_spinner().with_message("Building package"); + package_pb.enable_steady_tick(std::time::Duration::from_millis(100)); + package_pb = m.add(package_pb); - build_package(&env, &global.project, tempdir.path(), &package_pb).await?; - project_file.revert()?; + let package_pb_cloned = package_pb.clone(); + let client = s3_client.clone(); + async move { + let env = init_venv( + &global.project, + global.uv.as_ref(), + &package_pb_cloned, + global.color, + ) + .await?; - package_pb.set_message("Uploading package"); + let project_file = RevertFile::save(pyproject_path(&global.project))?; + let mut new_project = project.clone(); + new_project.set_name(package_name); + std::fs::write(&project_file, new_project.toml()?)?; + build_package(&env, &global.project, tempdir.path(), &package_pb_cloned).await?; + project_file.revert()?; - upload_file(&s3_client, package_tar_file, upload_url, "application/gzip").await?; + package_pb_cloned.set_message("Uploading package"); + upload_file(&client, package_tar_file, upload_url, "application/gzip").await + } + .map(move |res| { + if res.is_ok() { + package_pb.finish_with_message("Package uploaded"); + } else { + package_pb.finish_with_message("An error occurred while processing package"); + } + res + }) + .boxed() + }; - package_pb.finish_with_message("Package uploaded"); + futures::future::try_join_all([evaluation_fut, package_fut]).await?; let mut validate_pb = ProgressBar::new_spinner().with_message("Validating submission"); validate_pb.enable_steady_tick(std::time::Duration::from_millis(100)); diff --git a/src/dirs.rs b/src/dirs.rs index cfb1443..491795f 100644 --- a/src/dirs.rs +++ b/src/dirs.rs @@ -50,6 +50,10 @@ pub fn project_last_run_dir(project_dir: impl AsRef) -> PathBuf { project_config_dir(project_dir).join(LAST_RUN_DIRNAME) } +pub fn project_last_run_result(project_dir: impl AsRef) -> PathBuf { + project_last_run_dir(project_dir).join("result.msgpack") +} + pub fn project_data_dir(project_dir: impl AsRef, kind: impl ToString) -> PathBuf { project_config_dir(project_dir) .join(DATA_DIRNAME) diff --git a/src/graphql/get_entity_id_by_username.graphql b/src/graphql/get_entity_id_by_username.graphql deleted file mode 100644 index f32e9e7..0000000 --- a/src/graphql/get_entity_id_by_username.graphql +++ /dev/null @@ -1,6 +0,0 @@ -query GetEntityIdByUsername($username: String!) { - entityByUsername(username: $username) { - id - __typename - } -} diff --git a/src/graphql/get_viewer_id.graphql b/src/graphql/get_viewer_id.graphql deleted file mode 100644 index 0214ec9..0000000 --- a/src/graphql/get_viewer_id.graphql +++ /dev/null @@ -1,5 +0,0 @@ -query GetViewerId { - viewer { - id - } -} diff --git a/src/graphql/schema.graphql b/src/graphql/schema.graphql index 9d7114e..6ffa64a 100644 --- a/src/graphql/schema.graphql +++ b/src/graphql/schema.graphql @@ -164,6 +164,7 @@ interface Entity { bio: String createdAt: DateTime! viewerCan(action: Action!, asOrganization: ID): Boolean! + submissions(after: String, before: String, first: Int, last: Int, competitionId: ID): SubmissionConnection! } type EntityConnection { @@ -294,6 +295,7 @@ type Organization implements Entity & Node { image: Url imageThumbnail: Url users(after: String, before: String, first: Int, last: Int): OrganizationMembershipConnection! + submissions(after: String, before: String, first: Int, last: Int, competitionId: ID): SubmissionConnection! viewerCan(action: Action!, asOrganization: ID): Boolean! } @@ -418,6 +420,7 @@ type ProjectVersionEvaluation implements Node { score: Float error: String latest: Boolean! + max: Boolean! finalizedAt: DateTime createdAt: DateTime! id: ID! @@ -468,6 +471,7 @@ enum ProjectVersionFileKind { DATA PACKAGE TEMPLATE + SUBMISSION_EVALUATION } type Query { @@ -667,6 +671,8 @@ type User implements Entity & Node { can(action: Action!, on: ID, asOrganization: ID): Boolean! organizations(after: String, before: String, first: Int, last: Int): OrganizationMembershipConnection! submissions(after: String, before: String, first: Int, last: Int, competitionId: ID): SubmissionConnection! + topics(after: String, before: String, first: Int, last: Int): TopicConnection! + comments(after: String, before: String, first: Int, last: Int): CommentConnection! viewerCan(action: Action!, asOrganization: ID): Boolean! } diff --git a/src/graphql/submission_upload_info.graphql b/src/graphql/submission_upload_info.graphql new file mode 100644 index 0000000..2bec12c --- /dev/null +++ b/src/graphql/submission_upload_info.graphql @@ -0,0 +1,17 @@ +query SubmissionUploadInfo($slug: String!, $username: String!, $use_username: Boolean!) { + competitionBySlug(slug: $slug) { + id + useCase { + latest { + version + } + } + } + viewer @skip(if: $use_username) { + id + } + entityByUsername(username: $username) @include(if: $use_username) { + id + __typename + } +} diff --git a/src/python.rs b/src/python.rs index c851ff1..e8247d2 100644 --- a/src/python.rs +++ b/src/python.rs @@ -2,10 +2,23 @@ use crate::{ error::{self, Result}, process::run_command, }; -use aqora_runner::python::{PipOptions, PipPackage, PyEnv}; +use aqora_config::Version; +use aqora_runner::{ + pipeline::EvaluateAllInfo, + python::{PipOptions, PipPackage, PyEnv}, +}; use indicatif::ProgressBar; +use serde::{Deserialize, Serialize}; use std::path::Path; +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct LastRunResult { + #[serde(flatten)] + pub info: EvaluateAllInfo, + pub use_case_version: Option, + pub submission_version: Option, +} + pub async fn build_package( env: &PyEnv, input: impl AsRef,