Skip to content

Commit

Permalink
fix: just check if has_ref instead of is_ignored
Browse files Browse the repository at this point in the history
  • Loading branch information
jpopesculian committed Mar 21, 2024
1 parent 1de239d commit 4668df9
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 55 deletions.
59 changes: 14 additions & 45 deletions config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ pub enum TestConfigError {
PathStrReplaceError(#[from] PathStrReplaceError),
}

#[derive(Error, Debug)]
pub enum UseCaseConfigValidationError {
#[error("Generator contains a reference")]
GeneratorContainsRef,
#[error("Aggregator contains a reference")]
AggregatorContainsRef,
}

impl AqoraUseCaseConfig {
pub fn replace_refs(&mut self, refs: &RefMap) -> Result<(), PathStrReplaceError> {
self.generator = self.generator.replace_refs(refs)?;
Expand Down Expand Up @@ -198,30 +206,12 @@ impl AqoraUseCaseConfig {
Ok(out)
}

pub fn ignore_refs(&mut self) -> Result<(), PathStrReplaceError> {
self.generator = self.generator.replace_refs(&HashMap::new())?;
self.aggregator = self.aggregator.replace_refs(&HashMap::new())?;
for layer in self.layers.iter_mut() {
if let Some(transform) = layer.transform.as_mut() {
if transform.path.has_ref() {
layer.transform = Some(FunctionDef::ignored());
}
}
if let Some(context) = layer.context.as_mut() {
if context.path.has_ref() {
layer.context = Some(FunctionDef::ignored());
}
}
if let Some(metric) = layer.metric.as_mut() {
if metric.path.has_ref() {
layer.metric = Some(FunctionDef::ignored());
}
}
if let Some(branch) = layer.branch.as_mut() {
if branch.path.has_ref() {
layer.branch = Some(FunctionDef::ignored());
}
}
pub fn validate(&self) -> Result<(), UseCaseConfigValidationError> {
if self.generator.has_ref() {
return Err(UseCaseConfigValidationError::GeneratorContainsRef);
}
if self.aggregator.has_ref() {
return Err(UseCaseConfigValidationError::AggregatorContainsRef);
}
Ok(())
}
Expand All @@ -240,14 +230,6 @@ pub struct FunctionDef {
pub path: PathStr<'static>,
}

impl FunctionDef {
pub fn ignored() -> Self {
Self {
path: PathStr::ignored(),
}
}
}

impl<'de> Deserialize<'de> for FunctionDef {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand Down Expand Up @@ -301,22 +283,12 @@ impl<'de> Deserialize<'de> for FunctionDef {
#[derive(Clone)]
pub struct PathStr<'a>(Cow<'a, [String]>);

const IGNORED_PATH: &[&str] = &["$$"];

#[derive(Error, Debug)]
pub enum PathStrReplaceError {
#[error("Ref not found: {0}")]
RefNotFound(String),
}

impl PathStr<'static> {
pub fn ignored() -> Self {
Self(Cow::Owned(
IGNORED_PATH.iter().map(|s| s.to_string()).collect(),
))
}
}

impl<'a> PathStr<'a> {
pub fn module<'b: 'a>(&'b self) -> PathStr<'b> {
Self(Cow::Borrowed(&self.0[..self.0.len() - 1]))
Expand Down Expand Up @@ -345,9 +317,6 @@ impl<'a> PathStr<'a> {
pub fn has_ref(&self) -> bool {
self.0.iter().any(|part| part.starts_with('$'))
}
pub fn is_ignored(&self) -> bool {
self.0 == IGNORED_PATH
}
}

impl fmt::Display for PathStr<'_> {
Expand Down
14 changes: 7 additions & 7 deletions runner/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl LayerFunction {
pub enum LayerFunctionDef {
None,
Some(LayerFunction),
Ignored,
UseDefault,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -196,7 +196,7 @@ impl Layer {
) -> PyResult<LayerEvaluation> {
let context = match &self.context {
LayerFunctionDef::Some(func) => func.call(input, original_input, context).await?,
LayerFunctionDef::Ignored => {
LayerFunctionDef::UseDefault => {
if let Some(default) = default {
default.context.clone()
} else {
Expand All @@ -209,7 +209,7 @@ impl Layer {
};
let transform = match &self.transform {
LayerFunctionDef::Some(func) => func.call(input, original_input, &context).await?,
LayerFunctionDef::Ignored => {
LayerFunctionDef::UseDefault => {
if let Some(default) = default {
default.transform.clone()
} else {
Expand All @@ -224,7 +224,7 @@ impl Layer {
LayerFunctionDef::Some(func) => {
Some(func.call(&transform, original_input, &context).await?)
}
LayerFunctionDef::Ignored => {
LayerFunctionDef::UseDefault => {
if let Some(metric) = default.as_ref().and_then(|default| default.metric.as_ref()) {
Some(metric.clone())
} else {
Expand All @@ -240,7 +240,7 @@ impl Layer {
LayerFunctionDef::Some(func) => {
Some(func.call(&transform, original_input, &context).await?)
}
LayerFunctionDef::Ignored => {
LayerFunctionDef::UseDefault => {
if let Some(branch) = default.as_ref().and_then(|default| default.branch.as_ref()) {
Some(branch.clone())
} else {
Expand Down Expand Up @@ -405,8 +405,8 @@ impl Pipeline {
) -> PyResult<LayerFunctionDef> {
Ok(match def {
Some(FunctionDef { path }) => {
if path.is_ignored() {
LayerFunctionDef::Ignored
if path.has_ref() {
LayerFunctionDef::UseDefault
} else {
LayerFunctionDef::Some(LayerFunction::new(py, env.import_path(py, path)?)?)
}
Expand Down
6 changes: 3 additions & 3 deletions src/commands/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,10 @@ pub async fn upload_use_case(args: Upload, global: GlobalArgs, project: PyProjec
"Please specify a competition in either the pyproject.toml or the command line",
)
})?;
if config.generator.has_ref() || config.aggregator.has_ref() {
if let Err(err) = config.validate() {
return Err(error::user(
"Generator and aggregator cannot include references to the submission",
"Please remove any `$` from the generator and aggregator paths in your pyproject.toml",
&format!("Invalid use case: {err}"),
"Please make sure the use case is valid",
));
}

Expand Down

0 comments on commit 4668df9

Please sign in to comment.