From 5cdf04d732e7fa2669aef745ca33a4e8b6f83dde Mon Sep 17 00:00:00 2001 From: Behrang Shafei <50267830+bertiqwerty@users.noreply.github.com> Date: Thu, 15 Feb 2024 11:41:47 +0100 Subject: [PATCH] allow returning scalar values as dataframes --- rormula/rormula/{_rormula.pyi => rormula.pyi} | 0 rormula/src/lib.rs | 4 +--- rormula/test/test_arithmetic.py | 5 +++++ 3 files changed, 6 insertions(+), 3 deletions(-) rename rormula/rormula/{_rormula.pyi => rormula.pyi} (100%) diff --git a/rormula/rormula/_rormula.pyi b/rormula/rormula/rormula.pyi similarity index 100% rename from rormula/rormula/_rormula.pyi rename to rormula/rormula/rormula.pyi diff --git a/rormula/src/lib.rs b/rormula/src/lib.rs index 901caef..23df80a 100644 --- a/rormula/src/lib.rs +++ b/rormula/src/lib.rs @@ -74,10 +74,8 @@ fn eval_arithmetic<'py>( let res = pya.into_pyarray(py); Ok(res) } + Value::Scalar(s) => Ok(Array2::::from_elem((1, 1), s).into_pyarray(py)), Value::Cats(_) => Err(PyValueError::new_err("result cannot be cat".to_string())), - Value::Scalar(s) => Err(PyValueError::new_err(format!( - "result cannot be skalar but got {s}" - ))), Value::Error(e) => Err(PyValueError::new_err(format!("computation failed, {e:?}"))), } } diff --git a/rormula/test/test_arithmetic.py b/rormula/test/test_arithmetic.py index c72b82a..9c0a993 100644 --- a/rormula/test/test_arithmetic.py +++ b/rormula/test/test_arithmetic.py @@ -72,6 +72,11 @@ def test_scalar_scalar(): res = rormula.eval_asdf(df) ref = df.eval(s) np.allclose(res[name].to_numpy(), ref.values) + s = "5/3" + rormula = Arithmetic(s, name) + res = rormula.eval_asdf(df) + ref = df.eval(s) + np.allclose(res[name].to_numpy(), ref) if __name__ == "__main__":