diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index e12f0694..9f462588 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -11,7 +11,7 @@ jobs: runs-on: ${{ matrix.runner[0] }} strategy: matrix: - runner: [[macos-13, x86_64], [macos-latest, x86_64], [macos-14, aarch64]] + runner: [[macos-latest, x86_64], [macos-14, aarch64]] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/Readme.md b/Readme.md index 7c5ffcb7..9107d6a3 100644 --- a/Readme.md +++ b/Readme.md @@ -12,6 +12,7 @@ Symbolica website Zulip Chat Symbolica website + Codecov

# Symbolica diff --git a/src/api/cpp.rs b/src/api/cpp.rs index d14d3183..f7c577eb 100644 --- a/src/api/cpp.rs +++ b/src/api/cpp.rs @@ -417,3 +417,76 @@ unsafe extern "C" fn simplify_factorized( unsafe extern "C" fn drop(symbolica: *mut Symbolica) { let _ = Box::from_raw(symbolica); } + +#[cfg(test)] +mod test { + use std::ffi::{c_char, CStr}; + + use super::{drop, init}; + + #[test] + fn simplify() { + let symbolica = unsafe { init() }; + + unsafe { super::set_vars(symbolica, b"d,y\0".as_ptr() as *const c_char) }; + + let input = "-(4096-4096*y^2)/(-3072+1024*d)*(1536-512*d)-(-8192+8192*y^2)/(2)*((-6+d)/2)-(-8192+8192*y^2)/(-2)*((-13+3*d)/2)-(-8192+8192*y^2)/(-4)*(-8+2*d)\0"; + let result = unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, 0, true) }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + assert_eq!(result, "[32768-32768*y^2-8192*d+8192*d*y^2]"); + + let result = unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, 0, false) }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + assert_eq!(result, "32768-32768*y^2-8192*d+8192*d*y^2"); + + let result = + unsafe { super::simplify_factorized(symbolica, input.as_ptr() as *const i8, 0, true) }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + assert_eq!(result, "[8192]*[4-4*y^2-d+d*y^2]"); + + let result = + unsafe { super::simplify_factorized(symbolica, input.as_ptr() as *const i8, 0, false) }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + unsafe { drop(symbolica) }; + assert_eq!(result, "8192*(4-4*y^2-d+d*y^2)"); + } + + #[test] + fn simplify_ff() { + let symbolica = unsafe { init() }; + + unsafe { super::set_vars(symbolica, b"d,y\0".as_ptr() as *const c_char) }; + + let input = "-(4096-4096*y^2)/(-3072+1024*d)*(1536-512*d)-(-8192+8192*y^2)/(2)*((-6+d)/2)-(-8192+8192*y^2)/(-2)*((-13+3*d)/2)-(-8192+8192*y^2)/(-4)*(-8+2*d)\0"; + let result = + unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, 4293491017, true) }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + assert_eq!(result, "[32768+4293458249*y^2+4293482825*d+8192*d*y^2]"); + + let result = + unsafe { super::simplify(symbolica, input.as_ptr() as *const i8, 4293491017, false) }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + assert_eq!(result, "32768+4293458249*y^2+4293482825*d+8192*d*y^2"); + + let result = unsafe { + super::simplify_factorized(symbolica, input.as_ptr() as *const i8, 4293491017, true) + }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + assert_eq!(result, "[32768+4293458249*y^2+4293482825*d+8192*d*y^2]"); + + let result = unsafe { + super::simplify_factorized(symbolica, input.as_ptr() as *const i8, 4293491017, false) + }; + let result = unsafe { CStr::from_ptr(result).to_str().unwrap() }.to_owned(); + + unsafe { drop(symbolica) }; + assert_eq!(result, "32768+4293458249*y^2+4293482825*d+8192*d*y^2"); + } +} diff --git a/src/domains/float.rs b/src/domains/float.rs index b8706e43..337c26ba 100644 --- a/src/domains/float.rs +++ b/src/domains/float.rs @@ -892,3 +892,40 @@ impl<'a, T: Real + From<&'a Rational>> From<&'a Rational> for Complex { Complex::new(value.into(), T::zero()) } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn double() { + let a = 5.; + let b = 7.; + + let r = a.sqrt() + b.log() + b.sin() - a.cos() + b.tan() - 0.3.asin() + 0.5.acos() + - a.atan2(b) + + b.sinh() + - a.cosh() + + b.tanh() + - 0.7.asinh() + + b.acosh() / 0.4.atanh() + + b.powf(a); + assert_eq!(r, 17293.219725825093); + } + + #[test] + fn complex() { + let a = Complex::new(1., 2.); + let b: Complex = Complex::new(3., 4.); + + let r = a.sqrt() + b.log() - a.exp() + b.sin() - a.cos() + b.tan() - a.asin() + b.acos() + - a.atan2(&b) + + b.sinh() + - a.cosh() + + b.tanh() + - a.asinh() + + b.acosh() / a.atanh() + + b.powf(a); + assert_eq!(r, Complex::new(0.1924131450685842, -39.83285329561913)); + } +} diff --git a/src/evaluate.rs b/src/evaluate.rs index 8d234e9c..f7c05038 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -146,3 +146,47 @@ impl<'a> AtomView<'a> { } } } + +#[cfg(test)] +mod test { + use ahash::HashMap; + + use crate::{evaluate::EvaluationFn, representations::Atom, state::State}; + + #[test] + fn evaluate() { + let x = State::get_symbol("v1"); + let f = State::get_symbol("f1"); + let g = State::get_symbol("f2"); + let p0 = Atom::parse("v2(0)").unwrap(); + let a = Atom::parse("v1*cos(v1) + f1(v1, 1)^2 + f2(f2(v1)) + v2(0)").unwrap(); + + let mut const_map = HashMap::default(); + let mut fn_map: HashMap<_, EvaluationFn<_>> = HashMap::default(); + let mut cache = HashMap::default(); + + // x = 6 and p(0) = 7 + let v = Atom::new_var(x); + const_map.insert(v.as_view(), 6.); + const_map.insert(p0.as_view(), 7.); + + // f(x, y) = x^2 + y + fn_map.insert( + f, + EvaluationFn::new(Box::new(|args: &[f64], _, _, _| { + args[0] * args[0] + args[1] + })), + ); + + // g(x) = f(x, 3) + fn_map.insert( + g, + EvaluationFn::new(Box::new(move |args: &[f64], var_map, fn_map, cache| { + fn_map.get(&f).unwrap().get()(&[args[0], 3.], var_map, fn_map, cache) + })), + ); + + let r = a.evaluate::(&const_map, &fn_map, &mut cache); + assert_eq!(r, 2905.761021719902); + } +} diff --git a/src/parser.rs b/src/parser.rs index d85ced28..a3dcb223 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -117,7 +117,7 @@ impl Operator { } #[inline] - pub fn right_associative(&self) -> bool { + pub fn left_associative(&self) -> bool { match self { Operator::Mul => true, Operator::Add => true, @@ -127,6 +127,18 @@ impl Operator { Operator::Inv => true, } } + + #[inline] + pub fn right_associative(&self) -> bool { + match self { + Operator::Mul => true, + Operator::Add => true, + Operator::Pow => true, + Operator::Argument => true, + Operator::Neg => true, + Operator::Inv => true, + } + } } pub struct Position { @@ -244,7 +256,7 @@ impl Token { if let Token::Op(ml, mr, o2, mut args2) = other { debug_assert!(!ml && !mr); - if *o1 == o2 { + if *o1 == o2 && o2.left_associative() { // add from the left by swapping and then extending from the right std::mem::swap(args, &mut args2); args.append(&mut args2); @@ -1148,3 +1160,58 @@ impl Token { } } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use crate::{domains::integer::Z, parser::Token, representations::Atom, state::State}; + + #[test] + fn pow() { + let input = Atom::parse("v1^v2^v3^3").unwrap(); + assert_eq!(format!("{}", input), "v1^v2^v3^3"); + + let input = Atom::parse("(v1^v2)^3").unwrap(); + assert_eq!(format!("{}", input), "(v1^v2)^3"); + } + + #[test] + fn unary() { + let input = Atom::parse("-x^z").unwrap(); + assert_eq!(format!("{}", input), "-x^z"); + + let input = Atom::parse("(-x)^z").unwrap(); + assert_eq!(format!("{}", input), "(-x)^z"); + } + + #[test] + fn liberal() { + let input = Atom::parse( + "89233_21837281 x + ^2 / y + 5", + ) + .unwrap(); + let res = Atom::parse("8923321837281*x^2*y^-1+5").unwrap(); + assert_eq!(input, res); + } + + #[test] + fn poly() { + let var_names = ["v1".into(), "v2".into()]; + let var_map = Arc::new(vec![ + State::get_symbol("v1").into(), + State::get_symbol("v2").into(), + ]); + let (rest, input) = + Token::parse_polynomial::<_, u8>("#ABC*v1^2*v2+5".as_bytes(), &var_map, &var_names, &Z); + + assert!(rest.is_empty()); + assert_eq!( + input, + Atom::parse("5+2748*v1^2*v2") + .unwrap() + .to_polynomial(&Z, var_map.clone().into()) + ); + } +} diff --git a/src/poly/evaluate.rs b/src/poly/evaluate.rs index 29bc42f9..03982604 100644 --- a/src/poly/evaluate.rs +++ b/src/poly/evaluate.rs @@ -1889,3 +1889,138 @@ auto 𝑖 = 1i;\n", f.write_str("}") } } + +#[cfg(test)] +mod test { + use crate::{ + domains::{float::Complex, rational::Q}, + poly::{ + evaluate::{BorrowedHornerScheme, InstructionSetPrinter}, + polynomial::MultivariatePolynomial, + }, + representations::Atom, + }; + + use wide::f64x4; + + const RES_53: &str = "-a5^3*b0^5+a4*a5^2*b0^4*b1-a4^2*a5*b0^4*b2+a4^3*b0^4*b3-a3*a5^2* +b0^3*b1^2+2*a3*a5^2*b0^4*b2+a3*a4*a5*b0^3*b1*b2-3*a3*a4*a5*b0^4* +b3-a3*a4^2*b0^3*b1*b3-a3^2*a5*b0^3*b2^2+2*a3^2*a5*b0^3*b1*b3+a3^2 +*a4*b0^3*b2*b3-a3^3*b0^3*b3^2+a2*a5^2*b0^2*b1^3-3*a2*a5^2*b0^3*b1 +*b2+3*a2*a5^2*b0^4*b3-a2*a4*a5*b0^2*b1^2*b2+2*a2*a4*a5*b0^3*b2^2+ +a2*a4*a5*b0^3*b1*b3+a2*a4^2*b0^2*b1^2*b3-2*a2*a4^2*b0^3*b2*b3+a2* +a3*a5*b0^2*b1*b2^2-2*a2*a3*a5*b0^2*b1^2*b3-a2*a3*a5*b0^3*b2*b3-a2 +*a3*a4*b0^2*b1*b2*b3+3*a2*a3*a4*b0^3*b3^2+a2*a3^2*b0^2*b1*b3^2- +a2^2*a5*b0^2*b2^3+3*a2^2*a5*b0^2*b1*b2*b3-3*a2^2*a5*b0^3*b3^2+ +a2^2*a4*b0^2*b2^2*b3-2*a2^2*a4*b0^2*b1*b3^2-a2^2*a3*b0^2*b2*b3^2+ +a2^3*b0^2*b3^3-a1*a5^2*b0*b1^4+4*a1*a5^2*b0^2*b1^2*b2-2*a1*a5^2* +b0^3*b2^2-4*a1*a5^2*b0^3*b1*b3+a1*a4*a5*b0*b1^3*b2-3*a1*a4*a5* +b0^2*b1*b2^2-a1*a4*a5*b0^2*b1^2*b3+5*a1*a4*a5*b0^3*b2*b3-a1*a4^2* +b0*b1^3*b3+3*a1*a4^2*b0^2*b1*b2*b3-3*a1*a4^2*b0^3*b3^2-a1*a3*a5* +b0*b1^2*b2^2+2*a1*a3*a5*b0*b1^3*b3+2*a1*a3*a5*b0^2*b2^3-4*a1*a3* +a5*b0^2*b1*b2*b3+3*a1*a3*a5*b0^3*b3^2+a1*a3*a4*b0*b1^2*b2*b3-2*a1 +*a3*a4*b0^2*b2^2*b3-a1*a3*a4*b0^2*b1*b3^2-a1*a3^2*b0*b1^2*b3^2+2* +a1*a3^2*b0^2*b2*b3^2+a1*a2*a5*b0*b1*b2^3-3*a1*a2*a5*b0*b1^2*b2*b3 +-a1*a2*a5*b0^2*b2^2*b3+5*a1*a2*a5*b0^2*b1*b3^2-a1*a2*a4*b0*b1* +b2^2*b3+2*a1*a2*a4*b0*b1^2*b3^2+a1*a2*a4*b0^2*b2*b3^2+a1*a2*a3*b0 +*b1*b2*b3^2-3*a1*a2*a3*b0^2*b3^3-a1*a2^2*b0*b1*b3^3-a1^2*a5*b0* +b2^4+4*a1^2*a5*b0*b1*b2^2*b3-2*a1^2*a5*b0*b1^2*b3^2-4*a1^2*a5* +b0^2*b2*b3^2+a1^2*a4*b0*b2^3*b3-3*a1^2*a4*b0*b1*b2*b3^2+3*a1^2*a4 +*b0^2*b3^3-a1^2*a3*b0*b2^2*b3^2+2*a1^2*a3*b0*b1*b3^3+a1^2*a2*b0* +b2*b3^3-a1^3*b0*b3^4+a0*a5^2*b1^5-5*a0*a5^2*b0*b1^3*b2+5*a0*a5^2* +b0^2*b1*b2^2+5*a0*a5^2*b0^2*b1^2*b3-5*a0*a5^2*b0^3*b2*b3-a0*a4*a5 +*b1^4*b2+4*a0*a4*a5*b0*b1^2*b2^2+a0*a4*a5*b0*b1^3*b3-2*a0*a4*a5* +b0^2*b2^3-7*a0*a4*a5*b0^2*b1*b2*b3+3*a0*a4*a5*b0^3*b3^2+a0*a4^2* +b1^4*b3-4*a0*a4^2*b0*b1^2*b2*b3+2*a0*a4^2*b0^2*b2^2*b3+4*a0*a4^2* +b0^2*b1*b3^2+a0*a3*a5*b1^3*b2^2-2*a0*a3*a5*b1^4*b3-3*a0*a3*a5*b0* +b1*b2^3+6*a0*a3*a5*b0*b1^2*b2*b3+3*a0*a3*a5*b0^2*b2^2*b3-7*a0*a3* +a5*b0^2*b1*b3^2-a0*a3*a4*b1^3*b2*b3+3*a0*a3*a4*b0*b1*b2^2*b3+a0* +a3*a4*b0*b1^2*b3^2-5*a0*a3*a4*b0^2*b2*b3^2+a0*a3^2*b1^3*b3^2-3*a0 +*a3^2*b0*b1*b2*b3^2+3*a0*a3^2*b0^2*b3^3-a0*a2*a5*b1^2*b2^3+3*a0* +a2*a5*b1^3*b2*b3+2*a0*a2*a5*b0*b2^4-6*a0*a2*a5*b0*b1*b2^2*b3-3*a0 +*a2*a5*b0*b1^2*b3^2+7*a0*a2*a5*b0^2*b2*b3^2+a0*a2*a4*b1^2*b2^2*b3 +-2*a0*a2*a4*b1^3*b3^2-2*a0*a2*a4*b0*b2^3*b3+4*a0*a2*a4*b0*b1*b2* +b3^2-3*a0*a2*a4*b0^2*b3^3-a0*a2*a3*b1^2*b2*b3^2+2*a0*a2*a3*b0* +b2^2*b3^2+a0*a2*a3*b0*b1*b3^3+a0*a2^2*b1^2*b3^3-2*a0*a2^2*b0*b2* +b3^3+a0*a1*a5*b1*b2^4-4*a0*a1*a5*b1^2*b2^2*b3+2*a0*a1*a5*b1^3* +b3^2-a0*a1*a5*b0*b2^3*b3+7*a0*a1*a5*b0*b1*b2*b3^2-3*a0*a1*a5*b0^2 +*b3^3-a0*a1*a4*b1*b2^3*b3+3*a0*a1*a4*b1^2*b2*b3^2+a0*a1*a4*b0* +b2^2*b3^2-5*a0*a1*a4*b0*b1*b3^3+a0*a1*a3*b1*b2^2*b3^2-2*a0*a1*a3* +b1^2*b3^3-a0*a1*a3*b0*b2*b3^3-a0*a1*a2*b1*b2*b3^3+3*a0*a1*a2*b0* +b3^4+a0*a1^2*b1*b3^4-a0^2*a5*b2^5+5*a0^2*a5*b1*b2^3*b3-5*a0^2*a5* +b1^2*b2*b3^2-5*a0^2*a5*b0*b2^2*b3^2+5*a0^2*a5*b0*b1*b3^3+a0^2*a4* +b2^4*b3-4*a0^2*a4*b1*b2^2*b3^2+2*a0^2*a4*b1^2*b3^3+4*a0^2*a4*b0* +b2*b3^3-a0^2*a3*b2^3*b3^2+3*a0^2*a3*b1*b2*b3^3-3*a0^2*a3*b0*b3^4+ +a0^2*a2*b2^2*b3^3-2*a0^2*a2*b1*b3^4-a0^2*a1*b2*b3^4+a0^3*b3^5"; + + #[test] + fn res_53() { + let poly: MultivariatePolynomial<_, u8> = + Atom::parse(RES_53).unwrap().to_polynomial(&Q, None); + + let (h, _ops, scheme) = poly.optimize_horner_scheme(1000); + let mut i = h.to_instr(poly.nvars()); + + println!( + "Number of operations={}, with scheme={:?}", + BorrowedHornerScheme::from(&h).op_count_cse(), + scheme, + ); + + i.fuse_operations(); + + for _ in 0..100_000 { + if !i.common_pair_elimination() { + break; + } + i.fuse_operations(); + } + + let o = i.to_output(poly.variables.as_ref().to_vec(), true); + let o_f64 = o.convert::(); + + let _ = format!( + "{}", + InstructionSetPrinter { + name: "sigma".to_string(), + instr: &o, + mode: crate::poly::evaluate::InstructionSetMode::CPP( + crate::poly::evaluate::InstructionSetModeCPPSettings { + write_header_and_test: true, + always_pass_output_array: false, + } + ) + } + ); + + let mut evaluator = o_f64.evaluator(); + + let res = evaluator + .evaluate_with_input(&(0..poly.nvars()).map(|x| x as f64 + 1.).collect::>())[0]; + + assert_eq!(res, 280944.); + + // evaluate with simd + let o_f64x4 = o.convert::(); + let mut evaluator = o_f64x4.evaluator(); + + let res = evaluator.evaluate_with_input( + &(0..poly.nvars()) + .map(|x| f64x4::new([x as f64 + 1., x as f64 + 2., x as f64 + 3., x as f64 + 4.])) + .collect::>(), + )[0]; + + assert_eq!(res, f64x4::new([280944.0, 645000.0, 1774950.0, 4985154.0])); + + // evaluate with complex numbers + let mut complex_evaluator = o.convert::>().evaluator(); + let res = complex_evaluator.evaluate_with_input( + &(0..poly.nvars()) + .map(|x| Complex::new(x as f64 + 0.1, x as f64 + 2.)) + .collect::>(), + )[0]; + assert!( + (res.re - 3230756.634848104).abs() < 1e-6 && (res.im - 2522437.0904901037).abs() < 1e-6 + ); + } +} diff --git a/src/poly/factor.rs b/src/poly/factor.rs index 9ca18f70..4b5ede98 100644 --- a/src/poly/factor.rs +++ b/src/poly/factor.rs @@ -348,8 +348,8 @@ where FiniteField: Field + PolynomialGCD + FiniteFieldCore, { fn square_free_factorization(&self) -> Vec<(Self, usize)> { - let c = self.content(); - let stripped = self.clone().div_coeff(&c); + let c = self.lcoeff(); + let stripped = self.clone().make_monic(); let mut factors = vec![]; let fs = stripped.factor_separable(); @@ -383,7 +383,10 @@ where } match var_count { - 0 | 1 => { + 0 => { + factors.push((f, p)); + } + 1 => { for (d2, f2) in f.distinct_degree_factorization() { debug!("DDF {} {}", f2, d2); for f3 in f2.equal_degree_factorization(d2) { @@ -3130,7 +3133,7 @@ mod test { #[test] fn factor_ff_bivariate() { - let field = Zp::new(17); + let field = Zp::new(997); let poly = Atom::parse("((v2+1)*v1^2+v1*v2+1)*((v2^2+2)*v1^2+v2+1)") .unwrap() .to_polynomial::<_, u8>(&field, None); diff --git a/src/poly/resultant.rs b/src/poly/resultant.rs index 865c57c2..2f0bf93f 100644 --- a/src/poly/resultant.rs +++ b/src/poly/resultant.rs @@ -97,3 +97,63 @@ impl UnivariatePolynomial { res } } + +#[cfg(test)] +mod test { + use crate::domains::integer::Z; + use crate::domains::rational::Q; + use crate::poly::polynomial::MultivariatePolynomial; + use crate::representations::Atom; + + #[test] + fn resultant() { + let a = Atom::parse("9v1^6-27v1^4-27v1^3+72v1^2+18v1-451") + .unwrap() + .to_polynomial::<_, u8>(&Q, None) + .to_univariate_from_univariate(0); + let b = Atom::parse("3v1^4-4v1^2-9v1+21") + .unwrap() + .to_polynomial::<_, u8>(&Q, None) + .to_univariate_from_univariate(0); + let r = a.resultant(&b); + assert_eq!(r, 11149673028381u64.into()); + } + + #[test] + fn resultant_prs_large() { + let system = [ + "-272853213601 + 114339252960*v2 - 4841413740*v2^2 + 296664007680*v4 - 25123011840*v2*v4 - + 32592015360*v4^2 - 4907531205*v5 + 6155208630*v5^2 - 3860046090*v5^3 + 1210353435*v5^4 - + 151807041*v5^5 + 312814245280*v6 - 97612876080*v2*v6 + 1518070410*v2^2*v6 - + 253265840640*v4*v6 + 7877554560*v2*v4*v6 + 10219530240*v4^2*v6 - 146051082720*v6^2 + + 29048482440*v2*v6^2 + 75369035520*v4*v6^2 + 35852138640*v6^3 - 3036140820*v2*v6^3 - + 7877554560*v4*v6^3 - 4841413740*v6^4 + 303614082*v6^5", + "-121828703201 - 1128406464*v1 + 303614082*v1^2 + 24547177584*v2 - 303614082*v2^2 - + 2927757312*v3 + 1575510912*v1*v3 + 2043906048*v3^2 + 123022775808*v4 - 6600113280*v2*v4 - + 15080712192*v4^2 + 1480577211*v5 + 146055798*v5^2 - 1347744906*v5^3 + 816475707*v5^4 - + 151807041*v5^5 + 171636450272*v6 - 32717479104*v2*v6 + 303614082*v2^2*v6 - + 135541762560*v4*v6 + 3151021824*v2*v4*v6 + 6131718144*v4^2*v6 - 95005523376*v6^2 + + 13441077468*v2*v6^2 + 49947954048*v4*v6^2 + 26959925088*v6^3 - 1821684492*v2*v6^3 - + 6302043648*v4*v6^3 - 4176745074*v6^4 + 303614082*v6^5", + ]; + + let mut system = system + .iter() + .map(|s| Atom::parse(s).unwrap().to_polynomial::<_, u16>(&Z, None)) + .collect::>(); + MultivariatePolynomial::unify_variables_list(&mut system); + + let var = 0; + let a = system[0].to_univariate(var); + let b = system[1].to_univariate(var); + + let r = a.resultant_prs(&b); + + let res = "-351386377558921617913117495604303232443676-13790107017999999428952788718610765086720*v3+9827971852963339087984510765471845089280*v3^2-280524240668642743539521434896511795200*v3^3+97918838723960202933606538595952230400*v3^4-5314937079854166446575553985297899043840*v1+7575728303325907213654727048384547256320*v1*v3-324356153273118172217571659099091763200*v1*v3^2+150958209699438646189310080335426355200*v1*v3^3+1459905975120096702631379691615772127520*v1^2-125012267407347628875522410277774950400*v1^2*v3+87272714982487967328194890193918361600*v1^2*v3^2-16060603798860632876369198542630809600*v1^3+22424239266333713827383409285937356800*v1^3*v3+2160668887641529717742672248905422400*v1^4+1354294280612676780452085325751267257668192*v6+39734613593311178569710537839540105625600*v6*v3-27865256722167384570188475202759258521600*v6*v3^2+175921981436267483236649035443575193600*v6*v3^3-61406729369263178110905795390681907200*v6*v3^4+15314382322422016740409269792322749043200*v6*v1-21479468723337358939520282968793595110400*v6*v1*v3+203409791035684277492375447231633817600*v6*v1*v3^2-94668707777614066254313101227301273600*v6*v1*v3^3-4139272618559803545636721197111265724400*v6*v1^2+78397523628336648616853036953858867200*v6*v1^2*v3-54730346683933132053274761647033548800*v6*v1^2*v3^2+10071904077251583329248480441988812800*v6*v1^3-14062658522955040874799765145418342400*v6*v1^3*v3-1354995743097230500957269037449163200*v6*v1^4-2367896305717791221956898715505617501445056*v6^2-46292687892477418837577147366397047930880*v6^2*v3+32337290781902786933956629706880358481920*v6^2*v3^2-27580988615008037626084806404289331200*v6^2*v3^3+9627326214672616907218281480742502400*v6^2*v3^4-17841973458559005176982858880798862223360*v6^2*v1+24926661644383398261591568732386942996480*v6^2*v1*v3-31890518086103043505160557404959539200*v6^2*v1*v3^2+14842127914286951065294850616144691200*v6^2*v1*v3^3+4803575421053050706660875224470400473280*v6^2*v1^2-12291137179018881350947298166494822400*v6^2*v1^2*v3+8580605200447143584623585512458649600*v6^2*v1^2*v3^2-1579069707026731284670312611667737600*v6^2*v1^3+2204738836226002171049115721951180800*v6^2*v1^3*v3+212435773282192917522961671125504400*v6^2*v1^4+2494961992653296325028324212640453124889600*v6^3+28849190833763405479699812009600299827200*v6^3*v3-20140001148098981183941378195381341388800*v6^3*v3^2+11118958967179645861967635878700115558400*v6^3*v1-15524584218326297995954812358939783987200*v6^3*v1*v3-2991716750406630342970458631670687539200*v6^3*v1^2-1772535279912235533967029650358928182132480*v6^4-10505535408790438577935793180820583219200*v6^4*v3+7334053021231060894030648069629463756800*v6^4*v3^2-4049008438804648201912753621774599782400*v6^4*v1+5653332537198942772481957887006044979200*v6^4*v1*v3+1089444291022712930113710634475123251200*v6^4*v1^2+898690122182833295589205662503794559796864*v6^5+2251068354128239337582288282696747581440*v6^5*v3-1571500549108393499821597480373201141760*v6^5*v3^2+867599261486925578026506942289371463680*v6^5*v1-1211365006604386656112481391121009213440*v6^5*v1*v3-233440131481053678521676101413944483840*v6^5*v1^2-334972523815164547408286541369481963062528*v6^6-263982621934964082348273502963276185600*v6^6*v3+184289754935729642394077728483796582400*v6^6*v3^2-101743302204100740071730412600429363200*v6^6*v1+142056686096291599345434915706259865600*v6^6*v1*v3+27375507216472860290526520214227161600*v6^6*v1^2+92887008215136370072837287898141019689344*v6^7+13110546324286806774343784710927810560*v6^7*v3-9152645547143619823598491213289226240*v6^7*v3^2+5053023062485540110945000357336760320*v6^7*v1-7055164275923206947357170310243778560*v6^7*v1*v3-1359588949006034672146954695203228160*v6^7*v1^2-19087257764871901513773677515006244966400*v6^8+2841567690972631378229632354823913000960*v6^9-291195946735118945124677882620928308224*v6^10+18431616670849378149970607435511871488*v6^11-543835579602413868858781878081291264*v6^12+15597477966285265118914372191523231041000*v5-245447710139298529858941866813755392000*v5*v3+171350288210453690656242435322810368000*v5*v3^2-94599638282854641716467177834468224000*v5*v1+132082513828891386547520210561332992000*v5*v1*v3+25453401102442610949261707243590212000*v5*v1^2-45411984028105044019699380521763730985280*v5*v6+182919978541731072291801887104756531200*v5*v6*v3-127698852944227352354654147601433804800*v5*v6*v3^2+70500408396292184112465310654958246400*v5*v6*v1-98434532477841917440045905442771891200*v5*v6*v1*v3-18969154696250786173342179694700833200*v5*v6*v1^2+55003612820966057503024025353699032766080*v5*v6^2-33223980977628423890341541470219468800*v5*v6^2*v3+23194099927778333659295038384870195200*v5*v6^2*v3^2-12805076001794288374402469108313753600*v5*v6^2*v1+17878785360995798862373258755004108800*v5*v6^2*v1*v3+3445390928941898739103180072578916800*v5*v6^2*v1^2-36212231181655357463687198453157452313600*v5*v6^3+14069264595301729477604555666412809299200*v5*v6^4-3237479989051682809580844453154845667200*v5*v6^5+409563107761162882925994444729920136960*v5*v6^6-22013254475726077349297074842098808960*v5*v6^7-10053352936060966782937950767819809690860*v5^2+32932755432170760569169811376981606400*v5^2*v3-22990791528119210208665717376383385600*v5^2*v3^2+12692832822815813969367531468211660800*v5^2*v1-17722068469591891202513157144295526400*v5^2*v1*v3-3415190277994270700484306324681950400*v5^2*v1^2+29707896766397290357555785282575595878880*v5^2*v6-57019534921064997757748321171863142400*v5^2*v6*v3+39806090416592545604465809119979929600*v5^2*v6*v3^2-21976279084160467885798832118322252800*v5^2*v6*v1+30683861362790087236775727863317862400*v5^2*v6*v1*v3+5913035783454339727920322556993546400*v5^2*v6*v1^2-37302062378389232891674984124380580041920*v5^2*v6^2+14641076363022695273709831834334003200*v5^2*v6^2*v3-10221128781732825002401203356044492800*v5^2*v6^2*v3^2+5642914848248330470075664352816230400*v5^2*v6^2*v1-7878786769252385939350927586950963200*v5^2*v6^2*v1*v3-1518307866991345207062418337068675200*v5^2*v6^2*v1^2+25681302707666851532860725046888387123200*v5^2*v6^3-10458741720579739914368041771228643980800*v5^2*v6^4+2520479886656951321446241477235442417920*v5^2*v6^5-333083286426653512111002147322570191360*v5^2*v6^6+18640148890139899619012458968935427840*v5^2*v6^7+414911263816261790113586737614444084960*v5^3+151752777978897395607235689418029465600*v5^3*v3-105940618589041578065428688839001702400*v5^3*v3^2+58488049846033371223622088629865523200*v5^3*v1-81662560162386216425434614313397145600*v5^3*v1*v3-15737055864626510456984795466644241600*v5^3*v1^2-1585672703184495673894588405081928144640*v5^3*v6-72360704332631397795065899642766131200*v5^3*v6*v3+50515963402025692800329024278912204800*v5^3*v6*v3^2-27889021461535017900181648820649446400*v5^3*v6*v1+38939388455728138200253622881661491200*v5^3*v6*v1*v3+7503944650322609965673875242820183200*v5^3*v6*v1^2+3527447631833347104006026820522109140480*v5^3*v6^2+7769123963168496709778228887319347200*v5^3*v6^2*v3-5423728049759139212486688091147468800*v5^3*v6^2*v3^2+2994349860804524773560359050320998400*v5^3*v6^2*v1-4180790371689336476291822070259507200*v5^3*v6^2*v1*v3-805673144544299216785403211456259200*v5^3*v6^2*v1^2-3727514692990837419389364030442284902400*v5^3*v6^3+2051485173833671116475035035491521907200*v5^3*v6^4-614821883114260826429743802606140235520*v5^3*v6^5+95621327057944630975369889554673978880*v5^3*v6^6-6064521488024361377257398718961660160*v5^3*v6^7+1627462659741295232635314702377463845640*v5^4-101642857058676788342485563311434137600*v5^4*v3+70958220965491342805131430991001190400*v5^4*v3^2-39174851158031678840332977526281907200*v5^4*v1+54696961994232910078955478055563417600*v5^4*v1*v3+10540560384305300379798711916957533600*v5^4*v1^2-4785395055990532927423173887678038866240*v5^4*v6+56591005231715981715543689963314790400*v5^4*v6*v3-39506928180631911763681443936653721600*v5^4*v6*v3^2+21811116599723867952865797173360908800*v5^4*v6*v1-30453257139237098651171113034503910400*v5^4*v6*v1*v3-5868596427873815885902766574357524400*v5^4*v6*v1^2+5114113403851989582240098526003298972800*v5^4*v6^2-7751167590372267518615413258149888000*v5^4*v6^2*v3+5411192468750450909222080953802752000*v5^4*v6^2*v3^2-2987429175455978106133023859911936000*v5^4*v6^2*v1+4171127527995139242525354068556288000*v5^4*v6^2*v1*v3+803811034040729958194990106961368000*v5^4*v6^2*v1^2-2712349074575972935281868605613212595200*v5^4*v6^3+770070834029855766768147977571966240000*v5^4*v6^4-110058118810184370995065087641038839680*v5^4*v6^5+5531213039802125716963085591624820480*v5^4*v6^6+137796177264125135690569732621948800*v5^4*v6^7-373744237967747318097655035394806229392*v5^5+19528752507418994002306184430398177280*v5^5*v3-13633280052349109020477902338202501120*v5^5*v3^2+7526706695567737271722175249215964160*v5^5*v1-10508986707019104869951716385697761280*v5^5*v1*v3-2025169313331806667646945345160506080*v5^5*v1^2+1215536820450847752474367554142920011904*v5^5*v6-11349923971616534580829712270862336000*v5^5*v6*v3+7923531829241731688503761396639744000*v5^5*v6*v3^2-4374449864060539369694784937728192000*v5^5*v6*v1+6107722451707168176554982743243136000*v5^5*v6*v1*v3+1177009014131068867356949799479146000*v5^5*v6*v1^2-1413634176130610954389489058318893777152*v5^5*v6^2+1638818290535850846792973088865976320*v5^5*v6^2*v3-1144080693392952477949811401661153280*v5^5*v6^2*v3^2+631627882810692513868125044667095040*v5^5*v6^2*v1-881895534490400868419646288780472320*v5^5*v6^2*v1*v3-169948618625754334018369336900403520*v5^5*v6^2*v1^2+826980745012980525654275093992819891200*v5^5*v6^3-273850667214219945816540968220232769280*v5^5*v6^4+51787378668025078969009698679958550144*v5^5*v6^5-5148065182587715069399685210756007168*v5^5*v6^6+203938342350905200822043204280484224*v5^5*v6^7+8396010529199487178908025111075470600*v5^6+12821686359815483795642849669657085120*v5^6*v6-3486741933219625604409869636436712320*v5^6*v6^2-36785147870484641738753696782202499360*v5^7+12905301212905230238157985715741313280*v5^7*v6-916923342881865840823092794954641920*v5^7*v6^2+20230288344237735053451621915739901460*v5^8-9430829758183594873828448918573310240*v5^8*v6+1082631046775166944466178953329743680*v5^8*v6^2-4939771095356636533922616686506369560*v5^9+2523408013736821291155145004510228160*v5^9*v6-321524413616291983277996042784547200*v5^9*v6^2+474541792488342833563034227065152484*v5^10-254004286756870666789616873799792288*v5^10*v6+33989723725150866803673867380080704*v5^10*v6^2+340054807036742119278407339941916972482560*v4-1168465430599048731112810012610311926251520*v4*v6+1803251539160284741968818546040707238297600*v4*v6^2-1645584197198932009018040684997750258401280*v4*v6^3+983225719045982886492157957057230189711360*v4*v6^4-401872324263573241960647370747903148544000*v4*v6^5+113789782501275194186732116722172129443840*v4*v6^6-22040058529475774799832268879481077760000*v4*v6^7+2794927291759340041335283859047296860160*v4*v6^8-209557447006475795544473788404267909120*v4*v6^9+7055164275923206947357170310243778560*v4*v6^10-4148924309292177805612699194701222707200*v4*v5+7776797270776577512297113724354201190400*v4*v5*v6-5792702378959601123657932251494387712000*v4*v5*v6^2+2140476011255620382325631426462128537600*v4*v5*v6^3-392197080479229244438225955932453171200*v4*v5*v6^4+28509414494560868456216818014736281600*v4*v5*v6^5+5203735574366460298565080345896448819200*v4*v5^2-9753949119279097218813329078003574374400*v4*v5^2*v6+7265423322762889544926898078145503232000*v4*v5^2*v6^2-2684664827676540818510113992511822233600*v4*v5^2*v6^3+491908202634965493024215605745788723200*v4*v5^2*v6^4-35757570721991597724746517510008217600*v4*v5^2*v6^5-3263359597484051373676406318613027225600*v4*v5^3+6116883345988586391459206370951394099200*v4*v5^3*v6-4556282422749608697666020828667518976000*v4*v5^3*v6^2+1683603366509017123472444368185380044800*v4*v5^3*v6^3-308484805042266495625355549366003097600*v4*v5^3*v6^4+22424239266333713827383409285937356800*v4*v5^3*v6^5+1023256822939914413779890116853237350400*v4*v5^4-1918005794928624546474496912925437132800*v4*v5^4*v6+1428664827472334930624091276785577984000*v4*v5^4*v6^2-527909530176556216682037640871686963200*v4*v5^4*v6^3+96728286326812375746933519716458598400*v4*v5^4*v6^4-7031329261477520437399882572709171200*v4*v5^4*v6^5-128340686267040112914765879062948413440*v4*v5^5+240563438685963078710360629756749742080*v4*v5^5*v6-179188469886360652315563990647682662400*v4*v5^5*v6^2+66212381750957898363509805804245483520*v4*v5^5*v6^3-12132022352854433568259458405115146240*v4*v5^5*v6^4+881895534490400868419646288780472320*v4*v5^5*v6^5-49164550414950667847480579268710887587840*v4^2+123028102530480819459927412746971678638080*v4^2*v6-134877064351018867955458522266590727045120*v4^2*v6^2+84554433522522008212768621789075249397760*v4^2*v6^3-33138726988704698123377646193993999974400*v4^2*v6^4+8312765414115754783491341757137477959680*v4^2*v6^5-1303272543738486978767710360229180866560*v4^2*v6^6+116758072925724015046986158180338237440*v4^2*v6^7-4576322773571809911799245606644613120*v4^2*v6^8+599844478451881128522317955860417740800*v4^2*v5-564260483967447502254044856783952281600*v4^2*v5*v6+176929134803352182910166607635646054400*v4^2*v5*v6^2-18492593185661103863491990063612723200*v4^2*v5*v6^3-752347311956596669672059809045269708800*v4^2*v5^2+707716539213408731640666430542584217600*v4^2*v5^2*v6-221911118227933246361903880763352678400*v4^2*v5^2*v6^2+23194099927778333659295038384870195200*v4^2*v5^2*v6^3+471811026142272487760444287028389478400*v4^2*v5^3-443822236455866492723807761526705356800*v4^2*v5^3*v6+139164599566670001955770230309221171200*v4^2*v5^3*v6^2-14545452497081327888032481698986393600*v4^2*v5^3*v6^3-147940745485288830907935920508901785600*v4^2*v5^4+139164599566670001955770230309221171200*v4^2*v5^4*v6-43636357491243983664097445096959180800*v4^2*v5^4*v6^2+4560862223661094337772896803919462400*v4^2*v5^4*v6^3+18555279942222666927436030707896156160*v4^2*v5^5-17454542996497593465638978038783672320*v4^2*v5^5*v6+5473034668393313205327476164703354880*v4^2*v5^5*v6^2-572040346696476238974905700830576640*v4^2*v5^5*v6^3"; + let res = Atom::parse(res) + .unwrap() + .to_polynomial::<_, u16>(&Z, system[0].variables.clone().into()); + + assert_eq!(r, res); + } +} diff --git a/src/printer.rs b/src/printer.rs index a71dd110..65e7ebeb 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -105,6 +105,7 @@ impl Default for PrintOptions { #[derive(Debug, Copy, Clone)] pub struct PrintState { pub level: usize, + pub top_level_add_child: bool, pub explicit_sign: bool, pub superscript: bool, } @@ -160,6 +161,7 @@ impl<'a> fmt::Display for AtomPrinter<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let print_state = PrintState { level: 0, + top_level_add_child: false, explicit_sign: false, superscript: false, }; @@ -216,7 +218,7 @@ impl<'a> FormattedPrintVar for VarView<'a> { print_state: PrintState, ) -> fmt::Result { if print_state.explicit_sign { - if print_state.level == 1 && opts.color_top_level_sum { + if print_state.top_level_add_child && opts.color_top_level_sum { f.write_fmt(format_args!("{}", "+".yellow()))?; } else { f.write_char('+')?; @@ -301,7 +303,7 @@ impl<'a> FormattedPrintNum for NumView<'a> { }; if is_negative { - if print_state.level == 1 && opts.color_top_level_sum { + if print_state.top_level_add_child && opts.color_top_level_sum { f.write_fmt(format_args!("{}", "-".yellow()))?; } else if print_state.superscript { f.write_char('⁻')?; @@ -309,7 +311,7 @@ impl<'a> FormattedPrintNum for NumView<'a> { f.write_char('-')?; } } else if print_state.explicit_sign { - if print_state.level == 1 && opts.color_top_level_sum { + if print_state.top_level_add_child && opts.color_top_level_sum { f.write_fmt(format_args!("{}", "+".yellow()))?; } else { f.write_char('+')?; @@ -395,7 +397,7 @@ impl<'a> FormattedPrintMul for MulView<'a> { if let Some(AtomView::Num(n)) = self.iter().last() { // write -1*x as -x if n.get_coeff_view() == CoefficientView::Natural(-1, 1) { - if print_state.level == 1 && opts.color_top_level_sum { + if print_state.top_level_add_child && opts.color_top_level_sum { f.write_fmt(format_args!("{}", "-".yellow()))?; } else { f.write_char('-')?; @@ -409,13 +411,14 @@ impl<'a> FormattedPrintMul for MulView<'a> { skip_num = true; } else if print_state.explicit_sign { - if print_state.level == 1 && opts.color_top_level_sum { + if print_state.top_level_add_child && opts.color_top_level_sum { f.write_fmt(format_args!("{}", "+".yellow()))?; } else { f.write_char('+')?; } } + print_state.top_level_add_child = false; print_state.level += 1; print_state.explicit_sign = false; for x in self.iter().take(if skip_num { @@ -460,7 +463,7 @@ impl<'a> FormattedPrintFn for FunView<'a> { mut print_state: PrintState, ) -> fmt::Result { if print_state.explicit_sign { - if print_state.level == 1 && opts.color_top_level_sum { + if print_state.top_level_add_child && opts.color_top_level_sum { f.write_fmt(format_args!("{}", "+".yellow()))?; } else { f.write_char('+')?; @@ -492,6 +495,7 @@ impl<'a> FormattedPrintFn for FunView<'a> { } } + print_state.top_level_add_child = false; print_state.level += 1; print_state.explicit_sign = false; let mut first = true; @@ -526,7 +530,7 @@ impl<'a> FormattedPrintPow for PowView<'a> { mut print_state: PrintState, ) -> fmt::Result { if print_state.explicit_sign { - if print_state.level == 1 && opts.color_top_level_sum { + if print_state.top_level_add_child && opts.color_top_level_sum { f.write_fmt(format_args!("{}", "+".yellow()))?; } else { f.write_char('+')?; @@ -536,6 +540,7 @@ impl<'a> FormattedPrintPow for PowView<'a> { let b = self.get_base(); let e = self.get_exp(); + print_state.top_level_add_child = false; print_state.level += 1; print_state.explicit_sign = false; @@ -619,10 +624,11 @@ impl<'a> FormattedPrintAdd for AddView<'a> { mut print_state: PrintState, ) -> fmt::Result { let mut first = true; + print_state.top_level_add_child = print_state.level == 0; print_state.level += 1; for x in self.iter() { - if !first && print_state.level == 1 && opts.terms_on_new_line { + if !first && print_state.top_level_add_child && opts.terms_on_new_line { f.write_char('\n')?; f.write_char('\t')?; } @@ -1291,3 +1297,120 @@ impl<'a, F: Ring + Display> Display for MatrixPrinter<'a, F> { } } } + +#[cfg(test)] +mod test { + use colored::control::ShouldColorize; + + use crate::{ + domains::{finite_field::Zp, integer::Z}, + printer::{AtomPrinter, PolynomialPrinter, PrintOptions}, + representations::Atom, + }; + + #[test] + fn atoms() { + let a = Atom::parse("f(x,y^2)^(x+z)/5+3").unwrap(); + + if ShouldColorize::from_env().should_colorize() { + assert_eq!(format!("{}", a), "1/5*f(x,y^2)^(x+z)\u{1b}[33m+\u{1b}[0m3"); + } else { + assert_eq!(format!("{}", a), "1/5*f(x,y^2)^(x+z)+3"); + } + + assert_eq!( + format!( + "{}", + AtomPrinter::new_with_options(a.as_view(), PrintOptions::latex()) + ), + "\\frac{1}{5} f\\!\\left(x,y^{2}\\right)^{x+z}+3" + ); + + assert_eq!( + format!( + "{}", + AtomPrinter::new_with_options(a.as_view(), PrintOptions::mathematica()) + ), + "1/5 f[x,y^2]^(x+z)+3" + ); + + let a = Atom::parse("8127389217 x^2").unwrap(); + assert_eq!( + format!( + "{}", + AtomPrinter::new_with_options( + a.as_view(), + PrintOptions { + terms_on_new_line: true, + color_top_level_sum: false, + color_builtin_symbols: false, + print_finite_field: true, + symmetric_representation_for_finite_field: false, + explicit_rational_polynomial: false, + number_thousands_separator: Some('_'), + multiplication_operator: ' ', + square_brackets_for_function: false, + num_exp_as_superscript: true, + latex: false + } + ) + ), + "812_738_921_7 x²" + ); + } + + #[test] + fn polynomials() { + let a = Atom::parse("15 x^2") + .unwrap() + .to_polynomial::<_, u8>(&Zp::new(17), None); + assert_eq!( + format!( + "{}", + PolynomialPrinter::new_with_options( + &a, + PrintOptions { + terms_on_new_line: true, + color_top_level_sum: false, + color_builtin_symbols: false, + print_finite_field: true, + symmetric_representation_for_finite_field: true, + explicit_rational_polynomial: false, + number_thousands_separator: Some('_'), + multiplication_operator: ' ', + square_brackets_for_function: false, + num_exp_as_superscript: false, + latex: false + } + ) + ), + "-2*x^2 % 17" + ); + } + + #[test] + fn rational_polynomials() { + let a = Atom::parse("15 x^2 / (1+x)") + .unwrap() + .to_rational_polynomial::<_, _, u8>(&Z, &Z, None); + assert_eq!(format!("{}", a), "15*x^2/(1+x)"); + + let a = Atom::parse("(15 x^2 + 6) / (1+x)") + .unwrap() + .to_rational_polynomial::<_, _, u8>(&Z, &Z, None); + assert_eq!(format!("{}", a), "(6+15*x^2)/(1+x)"); + } + + #[test] + fn factorized_rational_polynomials() { + let a = Atom::parse("15 x^2 / ((1+x)(x+2))") + .unwrap() + .to_factorized_rational_polynomial::<_, _, u8>(&Z, &Z, None); + assert_eq!(format!("{}", a), "15*x^2/((1+x)(2+x))"); + + let a = Atom::parse("(15 x^2 + 6) / ((1+x)(x+2))") + .unwrap() + .to_factorized_rational_polynomial::<_, _, u8>(&Z, &Z, None); + assert_eq!(format!("{}", a), "3*(2+5*x^2)/((1+x)(2+x))"); + } +} diff --git a/src/representations.rs b/src/representations.rs index ce62dc95..492445dc 100644 --- a/src/representations.rs +++ b/src/representations.rs @@ -1037,3 +1037,46 @@ impl> std::ops::Div for Atom { self } } + +#[cfg(test)] +mod test { + use crate::{ + fun, + representations::{Atom, FunctionBuilder}, + state::State, + }; + + #[test] + fn debug() { + let x = Atom::parse("v1+f1(v2)").unwrap(); + assert_eq!( + format!("{:?}", x), + "AddView { data: [5, 15, 0, 0, 0, 1, 2, 2, 1, 12, 3, 5, 0, 0, 0, 1, 42, 2, 1, 13] }" + ); + assert_eq!( + x.get_all_symbols(true), + [ + State::get_symbol("v1"), + State::get_symbol("v2"), + State::get_symbol("f1") + ] + .into_iter() + .collect(), + ); + assert_eq!(x.as_view().get_byte_size(), 20); + } + + #[test] + fn composition() { + let v1 = Atom::parse("v1").unwrap(); + let v2 = Atom::parse("v2").unwrap(); + let f1_id = State::get_symbol("f1"); + + let f1 = fun!(f1_id, v1, v2, Atom::new_num(2)); + + let r = (-(&v2 + &v1 + 2) * &v2 * 6).npow(5) / &v2.pow(&v1) * &f1 / 4; + + let res = Atom::parse("1/4*(v2^v1)^-1*(-6*v2*(v1+v2+2))^5*f1(v1,v2,2)").unwrap(); + assert_eq!(res, r); + } +} diff --git a/src/solve.rs b/src/solve.rs index 48bc3195..b201fd4a 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -87,3 +87,111 @@ impl<'a> AtomView<'a> { Ok(result) } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use crate::{ + domains::{ + integer::Z, + rational::Q, + rational_polynomial::{RationalPolynomial, RationalPolynomialField}, + }, + poly::Variable, + representations::{Atom, AtomView}, + state::State, + tensors::matrix::Matrix, + }; + + #[test] + fn solve() { + let x = State::get_symbol("v1"); + let y = State::get_symbol("v2"); + let z = State::get_symbol("v3"); + let eqs = [ + "v4*v1 + f1(v4)*v2 + v3 - 1", + "v1 + v4*v2 + v3/v4 - 2", + "(v4-1)v1 + v4*v3", + ]; + + let atoms: Vec<_> = eqs.iter().map(|e| Atom::parse(e).unwrap()).collect(); + let system: Vec<_> = atoms.iter().map(|x| x.as_view()).collect(); + + let sol = AtomView::solve_linear_system::(&system, &[x, y, z]).unwrap(); + + let res = [ + "(v4^3-2*v4^2*f1(v4))*(v4^2-v4^3+v4^4-f1(v4)+v4*f1(v4)-v4^2*f1(v4))^-1", + "(v4^2-f1(v4))^-1*(2*v4-1)", + "(v4^2-v4^3-2*v4*f1(v4)+2*v4^2*f1(v4))*(v4^2-v4^3+v4^4-f1(v4)+v4*f1(v4)-v4^2*f1(v4))^-1", + ]; + let res = res + .iter() + .map(|x| Atom::parse(x).unwrap()) + .collect::>(); + + assert_eq!(sol, res); + } + + #[test] + fn solve_from_matrix() { + let system = [ + ["v4", "v4+1", "v4^2+5"], + ["1", "v4", "v4+1"], + ["v4-1", "-1", "v4"], + ]; + let rhs = ["1", "2", "-1"]; + + let var_map = Arc::new(vec![Variable::Symbol(State::get_symbol("v4"))]); + + let system_rat: Vec> = system + .iter() + .flatten() + .map(|s| { + Atom::parse(s) + .unwrap() + .to_rational_polynomial(&Q, &Z, Some(var_map.clone())) + }) + .collect(); + + let rhs_rat: Vec> = rhs + .iter() + .map(|s| { + Atom::parse(s) + .unwrap() + .to_rational_polynomial(&Q, &Z, Some(var_map.clone())) + }) + .collect(); + + let field = RationalPolynomialField::new_from_poly(&rhs_rat[0].numerator); + let m = Matrix::from_linear( + system_rat, + system.len() as u32, + system.len() as u32, + field.clone(), + ) + .unwrap(); + let b = Matrix::new_vec(rhs_rat, field); + + let sol = m.solve(&b).unwrap(); + + let res = [ + "(10-2*v4+4*v4^2-v4^3)/(6-4*v4+5*v4^2-3*v4^3+v4^4)", + "(-4+10*v4-5*v4^2+2*v4^3)/(6-4*v4+5*v4^2-3*v4^3+v4^4)", + "(2-4*v4)/(6-4*v4+5*v4^2-3*v4^3+v4^4)", + ]; + + let res = res + .iter() + .map(|x| { + Atom::parse(x).unwrap().to_rational_polynomial( + &Z, + &Z, + m.data[0].get_variables().clone().into(), + ) + }) + .collect::>(); + + assert_eq!(sol.data, res); + } +} diff --git a/src/streaming.rs b/src/streaming.rs index 38f26fb5..ce9751e2 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -89,12 +89,17 @@ impl TermOutputStream { } else if self.mem_buf.len() == 1 { self.mem_buf.pop().unwrap() } else { - let mut out = Atom::default(); - let add = out.to_add(); - for x in self.mem_buf.drain(..) { - add.extend(x.as_view()); - } - out + Workspace::get_local().with(|ws| { + let mut a = ws.new_atom(); + let add = a.to_add(); + for x in self.mem_buf.drain(..) { + add.extend(x.as_view()); + } + + let mut out = Atom::new(); + a.as_view().normalize(ws, &mut out); + out + }) } } } @@ -173,3 +178,28 @@ where self.exp_out.to_expression() } } + +#[cfg(test)] +mod test { + use crate::{id::Pattern, representations::Atom, streaming::TermStreamer}; + + #[test] + fn main() { + let input = Atom::parse("v1 + f1(v1) + 2*f1(v2) + 7*f1(v3)").unwrap(); + let pattern = Pattern::parse("f(x_)").unwrap(); + let rhs = Pattern::parse("f1(v1) + v1").unwrap(); + + let mut stream = TermStreamer::new_from(input); + + // map every term in the expression + stream = stream.map(|workspace, x| { + let mut out1 = workspace.new_atom(); + pattern.replace_all_into(x.as_view(), &rhs, None, None, &mut out1); + out1.expand() + }); + + let r = stream.to_expression(); + let res = Atom::parse("v1+f1(v1)+2*f1(v2)+7*f1(v3)").unwrap(); + assert_eq!(r, res); + } +} diff --git a/src/tensors/matrix.rs b/src/tensors/matrix.rs index 96424048..cbaeb38e 100644 --- a/src/tensors/matrix.rs +++ b/src/tensors/matrix.rs @@ -108,7 +108,7 @@ impl Matrix { } Ok(Matrix { - nrows: data.len() as u32, + nrows: (data.len() / cols) as u32, ncols: cols as u32, data, field, @@ -755,3 +755,120 @@ impl Matrix { Ok(result) } } + +#[cfg(test)] +mod test { + use crate::{ + domains::{integer::Z, rational::Q}, + tensors::matrix::Matrix, + }; + + #[test] + fn basics() { + let a = Matrix::from_linear( + vec![ + 1u64.into(), + 2u64.into(), + 3u64.into(), + 4u64.into(), + 5u64.into(), + 6u64.into(), + ], + 2, + 3, + Z, + ) + .unwrap(); + + assert_eq!( + a.transpose().data, + vec![1.into(), 4.into(), 2.into(), 5.into(), 3.into(), 6.into()] + ); + + assert_eq!( + a.clone().into_transposed().data, + vec![1.into(), 4.into(), 2.into(), 5.into(), 3.into(), 6.into()] + ); + + assert_eq!( + (-a.clone()).data, + vec![ + (-1).into(), + (-2).into(), + (-3).into(), + (-4).into(), + (-5).into(), + (-6).into() + ] + ); + + assert_eq!( + (&a - &a).data, + vec![0.into(), 0.into(), 0.into(), 0.into(), 0.into(), 0.into()] + ); + + let b = Matrix::from_nested_vec( + vec![ + vec![7u64.into(), 8u64.into()], + vec![9u64.into(), 10u64.into()], + vec![11u64.into(), 12u64.into()], + ], + Z, + ) + .unwrap(); + + let c = &a * &b; + + assert_eq!(c.data, vec![58.into(), 64.into(), 139.into(), 154.into()]); + assert_eq!(&c[1], &[139.into(), 154.into()]); + assert_eq!(c[(0, 1)], 64.into()); + + let c_m = c.map(|x| x * &2u64.into(), Z); + assert_eq!( + c_m.data, + vec![116.into(), 128.into(), 278.into(), 308.into()] + ); + } + + #[test] + fn solve() { + let a = Matrix::from_linear( + vec![ + 1u64.into(), + 2u64.into(), + 3u64.into(), + 4u64.into(), + 5u64.into(), + 16u64.into(), + 7u64.into(), + 8u64.into(), + 9u64.into(), + ], + 3, + 3, + Q, + ) + .unwrap(); + + assert_eq!( + a.inv().unwrap().data, + vec![ + (-83, 60).into(), + (1, 10).into(), + (17, 60).into(), + (19, 15).into(), + (-1, 5).into(), + (-1, 15).into(), + (-1, 20).into(), + (1, 10).into(), + (-1, 20).into() + ] + ); + assert_eq!(a.det().unwrap(), 60.into()); + + let b = Matrix::from_linear(vec![1u64.into(), 2u64.into(), 3u64.into()], 3, 1, Q).unwrap(); + + let r = a.solve(&b).unwrap(); + assert_eq!(r.data, vec![(-1, 3).into(), (2, 3).into(), 0.into()]); + } +}