Skip to content

Commit

Permalink
add ability to use defaults for evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
jpopesculian committed Mar 11, 2024
1 parent db67132 commit 7cdd732
Show file tree
Hide file tree
Showing 10 changed files with 309 additions and 182 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

79 changes: 30 additions & 49 deletions runner/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
use split_stream_by::{Either, SplitStreamByMapExt};
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use std::{collections::HashMap, path::PathBuf};
use thiserror::Error;

#[derive(Debug, Clone)]
Expand All @@ -38,6 +38,7 @@ impl LayerFunction {
inspect
.getattr(intern!(py, "signature"))?
.call1((func,))?
.getattr(intern!(py, "parameters"))?
.call_method0(intern!(py, "values"))?,
)?;
for parameter in parameters {
Expand All @@ -63,6 +64,7 @@ impl LayerFunction {
takes_context_kwarg = true;
}
}

Ok(Self {
func: func.to_object(py),
takes_input_arg,
Expand Down Expand Up @@ -182,29 +184,36 @@ impl Layer {
input: &PyObject,
original_input: &PyObject,
context: &PyObject,
default: Option<&LayerEvaluation>,
) -> PyResult<LayerEvaluation> {
let context = if let Some(context_transform) = self.context.as_ref() {
context_transform
.call(input, original_input, context)
.await?
} else if let Some(default) = default.as_ref().map(|default| &default.context) {
default.clone()
} else {
input.clone()
context.clone()
};

let transform = if let Some(transform) = self.transform.as_ref() {
transform.call(input, original_input, &context).await?
} else if let Some(default) = default.as_ref().map(|default| &default.transform) {
default.clone()
} else {
input.clone()
};

let metric = if let Some(metric) = self.metric.as_ref() {
Some(metric.call(&transform, original_input, &context).await?)
} else {
None
default.as_ref().and_then(|default| default.metric.clone())
};

let branch = if let Some(branch) = self.branch.as_ref() {
Some(branch.call(&transform, original_input, &context).await?)
} else {
None
default.as_ref().and_then(|default| default.branch.clone())
};

Ok(LayerEvaluation {
Expand All @@ -214,10 +223,6 @@ impl Layer {
branch,
})
}

pub async fn assert_metric(&self, evaluation: &LayerEvaluation) -> PyResult<()> {
todo!()
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -248,7 +253,6 @@ impl PipelineConfig {

#[derive(Clone, Debug)]
pub struct Evaluator {
config: PipelineConfig,
layers: Vec<Layer>,
}

Expand All @@ -265,12 +269,21 @@ pub enum EvaluationError {
),
#[error("Layer not found: {0}")]
LayerNotFound(String),
#[error("{0}")]
Custom(String),
}

impl EvaluationError {
    /// Creates an [`EvaluationError::Custom`] carrying the string form of `err`.
    ///
    /// Accepts any value implementing [`ToString`]; the value is converted
    /// with `to_string()` and the resulting `String` is stored in the variant.
    pub fn custom(err: impl ToString) -> Self {
        let message = err.to_string();
        Self::Custom(message)
    }
}

impl Evaluator {
pub async fn evaluate(
&self,
mut input: PyObject,
defaults: Option<&EvaluationResult>,
) -> Result<EvaluationResult, (EvaluationResult, EvaluationError)> {
let mut out = EvaluationResult::new();
macro_rules! try_or_bail {
Expand All @@ -286,7 +299,14 @@ impl Evaluator {
let mut layer_index = 0;
while layer_index < self.layers.len() {
let layer = &self.layers[layer_index];
let result = try_or_bail!(layer.evaluate(&input, &original_input, &context).await);
let default = defaults
.and_then(|defaults| defaults.get(&layer.name))
.and_then(|defaults| defaults.get(out.get(&layer.name).map_or(0, |v| v.len())));
let result = try_or_bail!(
layer
.evaluate(&input, &original_input, &context, default)
.await
);
if let Some(branch) = try_or_bail!(result.branch_str()) {
layer_index = try_or_bail!(self
.layers
Expand All @@ -302,44 +322,6 @@ impl Evaluator {
}
Ok(out)
}

pub async fn assert_metric(&self, results: &EvaluationResult) -> Result<(), EvaluationError> {
let mut result_indexes = HashMap::<String, usize>::new();
let mut layer_index = 0;
while layer_index < self.layers.len() {
let layer = &self.layers[layer_index];
let layer_name = &layer.name;
let result_index = *result_indexes.entry(layer_name.clone()).or_insert(0);
let result = results
.get(layer_name)
.ok_or_else(|| EvaluationError::LayerNotFound(layer_name.clone()))?
.get(result_index)
.ok_or_else(|| EvaluationError::LayerNotFound(layer_name.clone()))?;
layer.assert_metric(result).await?;
result_indexes.insert(layer_name.clone(), result_index + 1);
if let Some(branch) = result.branch_str()? {
layer_index = self
.layers
.iter()
.position(|layer| layer.name == branch)
.ok_or_else(|| EvaluationError::LayerNotFound(branch))?;
} else {
layer_index += 1;
}
}
Ok(())
}

pub fn evaluate_all(
self,
inputs: impl Stream<Item = PyResult<PyObject>>,
) -> impl Stream<Item = Result<EvaluationResult, (EvaluationResult, EvaluationError)>> {
let this = Arc::new(self);
inputs
.map_err(|err| (EvaluationResult::new(), EvaluationError::Python(err)))
.map_ok(move |input| (input, this.clone()))
.and_then(|(input, evaluator)| async move { evaluator.evaluate(input).await })
}
}

pub struct Pipeline {
Expand Down Expand Up @@ -415,7 +397,6 @@ impl Pipeline {
pub fn evaluator(&self) -> Evaluator {
Evaluator {
layers: self.layers.clone(),
config: self.config.clone(),
}
}

Expand Down
25 changes: 7 additions & 18 deletions src/commands/test.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::{
commands::GlobalArgs,
dirs::{
init_venv, project_data_dir, project_last_run_dir, project_use_case_toml_path,
read_pyproject,
init_venv, project_data_dir, project_last_run_dir, project_last_run_result,
project_use_case_toml_path, read_pyproject,
},
error::{self, Result},
python::LastRunResult,
};
use aqora_config::{PyProject, Version};
use aqora_config::PyProject;
use aqora_runner::pipeline::{
EvaluateAllInfo, EvaluateInputInfo, EvaluationError, EvaluationResult, Evaluator, Pipeline,
PipelineConfig,
Expand All @@ -17,7 +18,6 @@ use indicatif::{MultiProgress, ProgressBar};
use owo_colors::{OwoColorize, Stream as OwoStream, Style};
use pyo3::prelude::*;
use pyo3::{exceptions::PyException, Python};
use serde::{Deserialize, Serialize};
use std::{
path::Path,
pin::Pin,
Expand All @@ -31,14 +31,6 @@ pub struct Test {
pub test: Vec<usize>,
}

#[derive(Serialize, Deserialize)]
pub struct LastRunResult {
#[serde(flatten)]
pub info: EvaluateAllInfo,
pub use_case_version: Option<Version>,
pub submission_version: Option<Version>,
}

fn evaluate(
evaluator: Evaluator,
inputs: impl Stream<Item = (usize, PyResult<PyObject>)>,
Expand All @@ -50,7 +42,7 @@ fn evaluate(
.map(move |input| (input, evaluator.clone()))
.then(|((index, result), evaluator)| async move {
match result {
Ok(input) => match evaluator.evaluate(input.clone()).await {
Ok(input) => match evaluator.evaluate(input.clone(), None).await {
Ok(result) => (
index,
EvaluateInputInfo {
Expand Down Expand Up @@ -100,10 +92,7 @@ fn evaluate(
.if_supports_color(OwoStream::Stdout, |text| text.red()),
err = err
));
return Err((
item.result,
EvaluationError::Python(PyException::new_err(err)),
));
return Err((item.result, EvaluationError::custom(err)));
}

let is_ok = item.error.is_none();
Expand Down Expand Up @@ -202,7 +191,7 @@ pub async fn test_submission(args: Test, global: GlobalArgs, project: PyProject)
})?;

let last_run_dir = project_last_run_dir(&global.project);
let last_run_result_file = last_run_dir.join("result.msgpack");
let last_run_result_file = project_last_run_result(&global.project);
if args.test.is_empty() {
if last_run_dir.exists() {
tokio::fs::remove_dir_all(&last_run_dir)
Expand Down
Loading

0 comments on commit 7cdd732

Please sign in to comment.