Skip to content

Commit

Permalink
use ignored paths instead of optional values
Browse files Browse the repository at this point in the history
  • Loading branch information
jpopesculian committed Mar 20, 2024
1 parent c611c3c commit 1de239d
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 61 deletions.
29 changes: 25 additions & 4 deletions config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,22 +204,22 @@ impl AqoraUseCaseConfig {
for layer in self.layers.iter_mut() {
if let Some(transform) = layer.transform.as_mut() {
if transform.path.has_ref() {
layer.transform = None;
layer.transform = Some(FunctionDef::ignored());
}
}
if let Some(context) = layer.context.as_mut() {
if context.path.has_ref() {
layer.context = None;
layer.context = Some(FunctionDef::ignored());
}
}
if let Some(metric) = layer.metric.as_mut() {
if metric.path.has_ref() {
layer.metric = None;
layer.metric = Some(FunctionDef::ignored());
}
}
if let Some(branch) = layer.branch.as_mut() {
if branch.path.has_ref() {
layer.branch = None;
layer.branch = Some(FunctionDef::ignored());
}
}
}
Expand All @@ -240,6 +240,14 @@ pub struct FunctionDef {
pub path: PathStr<'static>,
}

impl FunctionDef {
pub fn ignored() -> Self {
Self {
path: PathStr::ignored(),
}
}
}

impl<'de> Deserialize<'de> for FunctionDef {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand Down Expand Up @@ -293,12 +301,22 @@ impl<'de> Deserialize<'de> for FunctionDef {
#[derive(Clone)]
pub struct PathStr<'a>(Cow<'a, [String]>);

const IGNORED_PATH: &[&str] = &["$$"];

#[derive(Error, Debug)]
pub enum PathStrReplaceError {
#[error("Ref not found: {0}")]
RefNotFound(String),
}

impl PathStr<'static> {
pub fn ignored() -> Self {
Self(Cow::Owned(
IGNORED_PATH.iter().map(|s| s.to_string()).collect(),
))
}
}

impl<'a> PathStr<'a> {
pub fn module<'b: 'a>(&'b self) -> PathStr<'b> {
Self(Cow::Borrowed(&self.0[..self.0.len() - 1]))
Expand Down Expand Up @@ -327,6 +345,9 @@ impl<'a> PathStr<'a> {
pub fn has_ref(&self) -> bool {
self.0.iter().any(|part| part.starts_with('$'))
}
pub fn is_ignored(&self) -> bool {
self.0 == IGNORED_PATH
}
}

impl fmt::Display for PathStr<'_> {
Expand Down
147 changes: 90 additions & 57 deletions runner/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use crate::python::{
async_generator, async_python_run, deepcopy, format_err, serde_pickle, serde_pickle_opt,
AsyncIterator, PyEnv,
};
use aqora_config::AqoraUseCaseConfig;
use aqora_config::{AqoraUseCaseConfig, FunctionDef};
use futures::prelude::*;
use pyo3::{
exceptions::PyValueError,
intern,
prelude::*,
types::{PyDict, PyIterator, PyNone, PyTuple},
Expand Down Expand Up @@ -101,26 +102,33 @@ impl LayerFunction {
}
}

#[derive(Debug, Clone)]
pub enum LayerFunctionDef {
None,
Some(LayerFunction),
Ignored,
}

#[derive(Debug, Clone)]
pub struct Layer {
name: String,
transform: Option<LayerFunction>,
context: Option<LayerFunction>,
metric: Option<LayerFunction>,
branch: Option<LayerFunction>,
transform: LayerFunctionDef,
context: LayerFunctionDef,
metric: LayerFunctionDef,
branch: LayerFunctionDef,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct LayerEvaluation {
#[serde(with = "serde_pickle")]
transform: PyObject,
pub transform: PyObject,
#[serde(with = "serde_pickle")]
context: PyObject,
pub context: PyObject,
#[serde(with = "serde_pickle_opt")]
metric: Option<PyObject>,
pub metric: Option<PyObject>,
#[serde(with = "serde_pickle_opt")]
branch: Option<PyObject>,
pub branch: Option<PyObject>,
}

impl LayerEvaluation {
Expand Down Expand Up @@ -186,34 +194,62 @@ impl Layer {
context: &PyObject,
default: Option<&LayerEvaluation>,
) -> PyResult<LayerEvaluation> {
let context = if let Some(context_transform) = self.context.as_ref() {
context_transform
.call(input, original_input, context)
.await?
} else if let Some(default) = default.as_ref().map(|default| &default.context) {
default.clone()
} else {
context.clone()
let context = match &self.context {
LayerFunctionDef::Some(func) => func.call(input, original_input, context).await?,
LayerFunctionDef::Ignored => {
if let Some(default) = default {
default.context.clone()
} else {
return Err(PyErr::new::<PyValueError, _>(
"Context function is ignored but no default is provided",
));
}
}
LayerFunctionDef::None => context.clone(),
};

let transform = if let Some(transform) = self.transform.as_ref() {
transform.call(input, original_input, &context).await?
} else if let Some(default) = default.as_ref().map(|default| &default.transform) {
default.clone()
} else {
input.clone()
let transform = match &self.transform {
LayerFunctionDef::Some(func) => func.call(input, original_input, &context).await?,
LayerFunctionDef::Ignored => {
if let Some(default) = default {
default.transform.clone()
} else {
return Err(PyErr::new::<PyValueError, _>(
"Transform function is ignored but no default is provided",
));
}
}
LayerFunctionDef::None => input.clone(),
};

let metric = if let Some(metric) = self.metric.as_ref() {
Some(metric.call(&transform, original_input, &context).await?)
} else {
default.as_ref().and_then(|default| default.metric.clone())
let metric = match &self.metric {
LayerFunctionDef::Some(func) => {
Some(func.call(&transform, original_input, &context).await?)
}
LayerFunctionDef::Ignored => {
if let Some(metric) = default.as_ref().and_then(|default| default.metric.as_ref()) {
Some(metric.clone())
} else {
return Err(PyErr::new::<PyValueError, _>(
"Metric function is ignored but no default is provided",
));
}
}
LayerFunctionDef::None => None,
};

let branch = if let Some(branch) = self.branch.as_ref() {
Some(branch.call(&transform, original_input, &context).await?)
} else {
default.as_ref().and_then(|default| default.branch.clone())
let branch = match &self.branch {
LayerFunctionDef::Some(func) => {
Some(func.call(&transform, original_input, &context).await?)
}
LayerFunctionDef::Ignored => {
if let Some(branch) = default.as_ref().and_then(|default| default.branch.as_ref()) {
Some(branch.clone())
} else {
return Err(PyErr::new::<PyValueError, _>(
"Branch function is ignored but no default is provided",
));
}
}
LayerFunctionDef::None => None,
};

Ok(LayerEvaluation {
Expand Down Expand Up @@ -344,32 +380,12 @@ impl Pipeline {
.layers
.iter()
.map(|layer| {
let transform = layer
.transform
.as_ref()
.map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?))
.transpose()?;
let context = layer
.context
.as_ref()
.map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?))
.transpose()?;
let metric = layer
.metric
.as_ref()
.map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?))
.transpose()?;
let branch = layer
.branch
.as_ref()
.map(|def| LayerFunction::new(py, env.import_path(py, &def.path)?))
.transpose()?;
Ok(Layer {
name: layer.name.clone(),
transform,
context,
metric,
branch,
transform: Self::import_function_def(py, env, layer.transform.as_ref())?,
context: Self::import_function_def(py, env, layer.context.as_ref())?,
metric: Self::import_function_def(py, env, layer.metric.as_ref())?,
branch: Self::import_function_def(py, env, layer.branch.as_ref())?,
})
})
.collect::<PyResult<Vec<_>>>()?;
Expand All @@ -382,6 +398,23 @@ impl Pipeline {
})
}

fn import_function_def(
py: Python,
env: &PyEnv,
def: Option<&FunctionDef>,
) -> PyResult<LayerFunctionDef> {
Ok(match def {
Some(FunctionDef { path }) => {
if path.is_ignored() {
LayerFunctionDef::Ignored
} else {
LayerFunctionDef::Some(LayerFunction::new(py, env.import_path(py, path)?)?)
}
}
None => LayerFunctionDef::None,
})
}

pub fn generator(&self) -> PyResult<impl Stream<Item = PyResult<PyObject>>> {
let generator = Python::with_gil(|py| {
PyResult::Ok(
Expand Down

0 comments on commit 1de239d

Please sign in to comment.