From 4668df9da4291fe44fef66bd701891d9115c0a23 Mon Sep 17 00:00:00 2001 From: Julian Popescu Date: Thu, 21 Mar 2024 11:20:15 +0100 Subject: [PATCH] fix: just check if has_ref instead of is_ignored --- config/src/lib.rs | 59 ++++++++++-------------------------------- runner/src/pipeline.rs | 14 +++++----- src/commands/upload.rs | 6 ++--- 3 files changed, 24 insertions(+), 55 deletions(-) diff --git a/config/src/lib.rs b/config/src/lib.rs index 194435b..a5afd4c 100644 --- a/config/src/lib.rs +++ b/config/src/lib.rs @@ -136,6 +136,14 @@ pub enum TestConfigError { PathStrReplaceError(#[from] PathStrReplaceError), } +#[derive(Error, Debug)] +pub enum UseCaseConfigValidationError { + #[error("Generator contains a reference")] + GeneratorContainsRef, + #[error("Aggregator contains a reference")] + AggregatorContainsRef, +} + impl AqoraUseCaseConfig { pub fn replace_refs(&mut self, refs: &RefMap) -> Result<(), PathStrReplaceError> { self.generator = self.generator.replace_refs(refs)?; @@ -198,30 +206,12 @@ impl AqoraUseCaseConfig { Ok(out) } - pub fn ignore_refs(&mut self) -> Result<(), PathStrReplaceError> { - self.generator = self.generator.replace_refs(&HashMap::new())?; - self.aggregator = self.aggregator.replace_refs(&HashMap::new())?; - for layer in self.layers.iter_mut() { - if let Some(transform) = layer.transform.as_mut() { - if transform.path.has_ref() { - layer.transform = Some(FunctionDef::ignored()); - } - } - if let Some(context) = layer.context.as_mut() { - if context.path.has_ref() { - layer.context = Some(FunctionDef::ignored()); - } - } - if let Some(metric) = layer.metric.as_mut() { - if metric.path.has_ref() { - layer.metric = Some(FunctionDef::ignored()); - } - } - if let Some(branch) = layer.branch.as_mut() { - if branch.path.has_ref() { - layer.branch = Some(FunctionDef::ignored()); - } - } + pub fn validate(&self) -> Result<(), UseCaseConfigValidationError> { + if self.generator.has_ref() { + return Err(UseCaseConfigValidationError::GeneratorContainsRef); + } + if self.aggregator.has_ref() { + return Err(UseCaseConfigValidationError::AggregatorContainsRef); } Ok(()) } @@ -240,14 +230,6 @@ pub struct FunctionDef { pub path: PathStr<'static>, } -impl FunctionDef { - pub fn ignored() -> Self { - Self { - path: PathStr::ignored(), - } - } -} - impl<'de> Deserialize<'de> for FunctionDef { fn deserialize(deserializer: D) -> Result where @@ -301,22 +283,12 @@ impl<'de> Deserialize<'de> for FunctionDef { #[derive(Clone)] pub struct PathStr<'a>(Cow<'a, [String]>); -const IGNORED_PATH: &[&str] = &["$$"]; - #[derive(Error, Debug)] pub enum PathStrReplaceError { #[error("Ref not found: {0}")] RefNotFound(String), } -impl PathStr<'static> { - pub fn ignored() -> Self { - Self(Cow::Owned( - IGNORED_PATH.iter().map(|s| s.to_string()).collect(), - )) - } -} - impl<'a> PathStr<'a> { pub fn module<'b: 'a>(&'b self) -> PathStr<'b> { Self(Cow::Borrowed(&self.0[..self.0.len() - 1])) @@ -345,9 +317,6 @@ impl<'a> PathStr<'a> { pub fn has_ref(&self) -> bool { self.0.iter().any(|part| part.starts_with('$')) } - pub fn is_ignored(&self) -> bool { - self.0 == IGNORED_PATH - } } impl fmt::Display for PathStr<'_> { diff --git a/runner/src/pipeline.rs b/runner/src/pipeline.rs index 622bfae..1a551f6 100644 --- a/runner/src/pipeline.rs +++ b/runner/src/pipeline.rs @@ -106,7 +106,7 @@ impl LayerFunction { pub enum LayerFunctionDef { None, Some(LayerFunction), - Ignored, + UseDefault, } #[derive(Debug, Clone)] @@ -196,7 +196,7 @@ impl Layer { ) -> PyResult { let context = match &self.context { LayerFunctionDef::Some(func) => func.call(input, original_input, context).await?, - LayerFunctionDef::Ignored => { + LayerFunctionDef::UseDefault => { if let Some(default) = default { default.context.clone() } else { @@ -209,7 +209,7 @@ impl Layer { }; let transform = match &self.transform { LayerFunctionDef::Some(func) => func.call(input, original_input, &context).await?, - LayerFunctionDef::Ignored => { + LayerFunctionDef::UseDefault => { if let Some(default) = default { default.transform.clone() } else { @@ -224,7 +224,7 @@ impl Layer { LayerFunctionDef::Some(func) => { Some(func.call(&transform, original_input, &context).await?) } - LayerFunctionDef::Ignored => { + LayerFunctionDef::UseDefault => { if let Some(metric) = default.as_ref().and_then(|default| default.metric.as_ref()) { Some(metric.clone()) } else { @@ -240,7 +240,7 @@ impl Layer { LayerFunctionDef::Some(func) => { Some(func.call(&transform, original_input, &context).await?) } - LayerFunctionDef::Ignored => { + LayerFunctionDef::UseDefault => { if let Some(branch) = default.as_ref().and_then(|default| default.branch.as_ref()) { Some(branch.clone()) } else { @@ -405,8 +405,8 @@ impl Pipeline { ) -> PyResult { Ok(match def { Some(FunctionDef { path }) => { - if path.is_ignored() { - LayerFunctionDef::Ignored + if path.has_ref() { + LayerFunctionDef::UseDefault } else { LayerFunctionDef::Some(LayerFunction::new(py, env.import_path(py, path)?)?) } diff --git a/src/commands/upload.rs b/src/commands/upload.rs index 949654b..e4dbddd 100644 --- a/src/commands/upload.rs +++ b/src/commands/upload.rs @@ -232,10 +232,10 @@ pub async fn upload_use_case(args: Upload, global: GlobalArgs, project: PyProjec "Please specify a competition in either the pyproject.toml or the command line", ) })?; - if config.generator.has_ref() || config.aggregator.has_ref() { + if let Err(err) = config.validate() { return Err(error::user( - "Generator and aggregator cannot include references to the submission", - "Please remove any `$` from the generator and aggregator paths in your pyproject.toml", + &format!("Invalid use case: {err}"), + "Please make sure the use case is valid", )); }