diff --git a/config/src/lib.rs b/config/src/lib.rs index 7cd61f3..194435b 100644 --- a/config/src/lib.rs +++ b/config/src/lib.rs @@ -204,22 +204,22 @@ impl AqoraUseCaseConfig { for layer in self.layers.iter_mut() { if let Some(transform) = layer.transform.as_mut() { if transform.path.has_ref() { - layer.transform = None; + layer.transform = Some(FunctionDef::ignored()); } } if let Some(context) = layer.context.as_mut() { if context.path.has_ref() { - layer.context = None; + layer.context = Some(FunctionDef::ignored()); } } if let Some(metric) = layer.metric.as_mut() { if metric.path.has_ref() { - layer.metric = None; + layer.metric = Some(FunctionDef::ignored()); } } if let Some(branch) = layer.branch.as_mut() { if branch.path.has_ref() { - layer.branch = None; + layer.branch = Some(FunctionDef::ignored()); } } } @@ -240,6 +240,14 @@ pub struct FunctionDef { pub path: PathStr<'static>, } +impl FunctionDef { + pub fn ignored() -> Self { + Self { + path: PathStr::ignored(), + } + } +} + impl<'de> Deserialize<'de> for FunctionDef { fn deserialize(deserializer: D) -> Result where @@ -293,12 +301,22 @@ impl<'de> Deserialize<'de> for FunctionDef { #[derive(Clone)] pub struct PathStr<'a>(Cow<'a, [String]>); +const IGNORED_PATH: &[&str] = &["$$"]; + #[derive(Error, Debug)] pub enum PathStrReplaceError { #[error("Ref not found: {0}")] RefNotFound(String), } +impl PathStr<'static> { + pub fn ignored() -> Self { + Self(Cow::Owned( + IGNORED_PATH.iter().map(|s| s.to_string()).collect(), + )) + } +} + impl<'a> PathStr<'a> { pub fn module<'b: 'a>(&'b self) -> PathStr<'b> { Self(Cow::Borrowed(&self.0[..self.0.len() - 1])) @@ -327,6 +345,9 @@ impl<'a> PathStr<'a> { pub fn has_ref(&self) -> bool { self.0.iter().any(|part| part.starts_with('$')) } + pub fn is_ignored(&self) -> bool { + self.0 == IGNORED_PATH + } } impl fmt::Display for PathStr<'_> { diff --git a/runner/src/pipeline.rs b/runner/src/pipeline.rs index 5a9d430..622bfae 100644 --- a/runner/src/pipeline.rs +++ b/runner/src/pipeline.rs @@ -2,9 +2,10 @@ use crate::python::{ async_generator, async_python_run, deepcopy, format_err, serde_pickle, serde_pickle_opt, AsyncIterator, PyEnv, }; -use aqora_config::AqoraUseCaseConfig; +use aqora_config::{AqoraUseCaseConfig, FunctionDef}; use futures::prelude::*; use pyo3::{ + exceptions::PyValueError, intern, prelude::*, types::{PyDict, PyIterator, PyNone, PyTuple}, @@ -101,26 +102,33 @@ impl LayerFunction { } } +#[derive(Debug, Clone)] +pub enum LayerFunctionDef { + None, + Some(LayerFunction), + Ignored, +} + #[derive(Debug, Clone)] pub struct Layer { name: String, - transform: Option, - context: Option, - metric: Option, - branch: Option, + transform: LayerFunctionDef, + context: LayerFunctionDef, + metric: LayerFunctionDef, + branch: LayerFunctionDef, } #[derive(Debug, Clone, Serialize, Deserialize)] #[pyclass] pub struct LayerEvaluation { #[serde(with = "serde_pickle")] - transform: PyObject, + pub transform: PyObject, #[serde(with = "serde_pickle")] - context: PyObject, + pub context: PyObject, #[serde(with = "serde_pickle_opt")] - metric: Option, + pub metric: Option, #[serde(with = "serde_pickle_opt")] - branch: Option, + pub branch: Option, } impl LayerEvaluation { @@ -186,34 +194,62 @@ impl Layer { context: &PyObject, default: Option<&LayerEvaluation>, ) -> PyResult { - let context = if let Some(context_transform) = self.context.as_ref() { - context_transform - .call(input, original_input, context) - .await? - } else if let Some(default) = default.as_ref().map(|default| &default.context) { - default.clone() - } else { - context.clone() + let context = match &self.context { + LayerFunctionDef::Some(func) => func.call(input, original_input, context).await?, + LayerFunctionDef::Ignored => { + if let Some(default) = default { + default.context.clone() + } else { + return Err(PyErr::new::( + "Context function is ignored but no default is provided", + )); + } + } + LayerFunctionDef::None => context.clone(), }; - - let transform = if let Some(transform) = self.transform.as_ref() { - transform.call(input, original_input, &context).await? - } else if let Some(default) = default.as_ref().map(|default| &default.transform) { - default.clone() - } else { - input.clone() + let transform = match &self.transform { + LayerFunctionDef::Some(func) => func.call(input, original_input, &context).await?, + LayerFunctionDef::Ignored => { + if let Some(default) = default { + default.transform.clone() + } else { + return Err(PyErr::new::( + "Transform function is ignored but no default is provided", + )); + } + } + LayerFunctionDef::None => input.clone(), }; - - let metric = if let Some(metric) = self.metric.as_ref() { - Some(metric.call(&transform, original_input, &context).await?) - } else { - default.as_ref().and_then(|default| default.metric.clone()) + let metric = match &self.metric { + LayerFunctionDef::Some(func) => { + Some(func.call(&transform, original_input, &context).await?) + } + LayerFunctionDef::Ignored => { + if let Some(metric) = default.as_ref().and_then(|default| default.metric.as_ref()) { + Some(metric.clone()) + } else { + return Err(PyErr::new::( + "Metric function is ignored but no default is provided", + )); + } + } + LayerFunctionDef::None => None, }; - let branch = if let Some(branch) = self.branch.as_ref() { - Some(branch.call(&transform, original_input, &context).await?) - } else { - default.as_ref().and_then(|default| default.branch.clone()) + let branch = match &self.branch { + LayerFunctionDef::Some(func) => { + Some(func.call(&transform, original_input, &context).await?) + } + LayerFunctionDef::Ignored => { + if let Some(branch) = default.as_ref().and_then(|default| default.branch.as_ref()) { + Some(branch.clone()) + } else { + return Err(PyErr::new::( + "Branch function is ignored but no default is provided", + )); + } + } + LayerFunctionDef::None => None, }; Ok(LayerEvaluation { @@ -344,32 +380,12 @@ impl Pipeline { .layers .iter() .map(|layer| { - let transform = layer - .transform - .as_ref() - .map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?)) - .transpose()?; - let context = layer - .context - .as_ref() - .map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?)) - .transpose()?; - let metric = layer - .metric - .as_ref() - .map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?)) - .transpose()?; - let branch = layer - .branch - .as_ref() - .map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?)) - .transpose()?; Ok(Layer { name: layer.name.clone(), - transform, - context, - metric, - branch, + transform: Self::import_function_def(py, env, layer.transform.as_ref())?, + context: Self::import_function_def(py, env, layer.context.as_ref())?, + metric: Self::import_function_def(py, env, layer.metric.as_ref())?, + branch: Self::import_function_def(py, env, layer.branch.as_ref())?, }) }) .collect::>>()?; @@ -382,6 +398,23 @@ impl Pipeline { }) } + fn import_function_def( + py: Python, + env: &PyEnv, + def: Option<&FunctionDef>, + ) -> PyResult { + Ok(match def { + Some(FunctionDef { path }) => { + if path.is_ignored() { + LayerFunctionDef::Ignored + } else { + LayerFunctionDef::Some(LayerFunction::new(py, env.import_path(py, path)?)?) + } + } + None => LayerFunctionDef::None, + }) + } + pub fn generator(&self) -> PyResult>> { let generator = Python::with_gil(|py| { PyResult::Ok(