diff --git a/config/src/lib.rs b/config/src/lib.rs index a5afd4c..894ae7b 100644 --- a/config/src/lib.rs +++ b/config/src/lib.rs @@ -373,3 +373,35 @@ impl<'de> de::Deserialize<'de> for PathStr<'static> { deserializer.deserialize_str(PathStrVisitor) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_str_no_ref() { + let path_str: PathStr = "foo.bar.baz".parse().unwrap(); + assert_eq!(path_str.module().to_string(), "foo.bar"); + assert_eq!(path_str.name(), "baz"); + assert_eq!(path_str.to_string(), "foo.bar.baz"); + assert!(!path_str.has_ref()); + } + + #[test] + fn test_path_str_with_ref() { + let path_str: PathStr = "foo.$bar.baz".parse().unwrap(); + assert_eq!(path_str.module().to_string(), "foo.$bar"); + assert_eq!(path_str.name(), "baz"); + assert_eq!(path_str.to_string(), "foo.$bar.baz"); + assert!(path_str.has_ref()); + + let refs: RefMap = vec![("bar".to_string(), "qux.quux".parse().unwrap())] + .into_iter() + .collect(); + let replaced = path_str.replace_refs(&refs).unwrap(); + assert_eq!(replaced.module().to_string(), "foo.qux.quux"); + assert_eq!(replaced.name(), "baz"); + assert_eq!(replaced.to_string(), "foo.qux.quux.baz"); + assert!(!replaced.has_ref()); + } +} diff --git a/runner/src/pipeline.rs b/runner/src/pipeline.rs index 1a551f6..202eb2d 100644 --- a/runner/src/pipeline.rs +++ b/runner/src/pipeline.rs @@ -403,16 +403,18 @@ impl Pipeline { env: &PyEnv, def: Option<&FunctionDef>, ) -> PyResult { - Ok(match def { - Some(FunctionDef { path }) => { - if path.has_ref() { - LayerFunctionDef::UseDefault - } else { - LayerFunctionDef::Some(LayerFunction::new(py, env.import_path(py, path)?)?) - } + if let Some(FunctionDef { path }) = def { + if path.has_ref() { + Ok(LayerFunctionDef::UseDefault) + } else { + Ok(LayerFunctionDef::Some(LayerFunction::new( + py, + env.import_path(py, path)?, + )?)) } - None => LayerFunctionDef::None, - }) + } else { + Ok(LayerFunctionDef::None) + } } pub fn generator(&self) -> PyResult>> {