Skip to content

Commit

Permalink
Use ASM for copying to output and for Pow
Browse files Browse the repository at this point in the history
- Always expand Pow
  • Loading branch information
benruijl committed Aug 10, 2024
1 parent 41e7c9c commit c92c835
Showing 1 changed file with 104 additions and 45 deletions.
149 changes: 104 additions & 45 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1192,11 +1192,14 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {

res += &format!(
"static const std::complex<double> CONSTANTS_complex[{}] = {{{}}};\n\n",
self.reserved_indices - self.param_count,
(self.param_count..self.reserved_indices)
.map(|i| format!("std::complex<double>({})", self.stack[i]))
.collect::<Vec<_>>()
.join(",")
self.reserved_indices - self.param_count + 1,
{
let mut nums = (self.param_count..self.reserved_indices)
.map(|i| format!("std::complex<double>({})", self.stack[i]))
.collect::<Vec<_>>();
nums.push("std::complex<double>(0, -0.)".to_string()); // used for inversion
nums.join(",")
}
);

res += &format!("extern \"C\" void {}_complex(const std::complex<double> *params, std::complex<double> *out)\n{{\n", function_name);
Expand All @@ -1206,29 +1209,18 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {

self.export_asm_complex_impl(&self.instructions, &mut res);

for (i, r) in &mut self.result_indices.iter().enumerate() {
if *r < self.param_count {
res += &format!("\tout[{}] = params[{}];\n", i, r);
} else if *r < self.reserved_indices {
res += &format!(
"\tout[{}] = CONSTANTS_complex[{}];\n",
i,
r - self.param_count
);
} else {
res += &format!("\tout[{}] = Z[{}];\n", i, r);
}
}

res += "\treturn;\n}\n\n";

res += &format!(
"static const double CONSTANTS_double[{}] = {{{}}};\n\n",
self.reserved_indices - self.param_count,
(self.param_count..self.reserved_indices)
.map(|i| format!("double({})", self.stack[i]))
.collect::<Vec<_>>()
.join(",")
self.reserved_indices - self.param_count + 1,
{
let mut nums = (self.param_count..self.reserved_indices)
.map(|i| format!("double({})", self.stack[i]))
.collect::<Vec<_>>();
nums.push("1".to_string()); // used for inversion
nums.join(",")
}
);

res += &format!(
Expand All @@ -1240,20 +1232,6 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {

self.export_asm_double_impl(&self.instructions, &mut res);

for (i, r) in &mut self.result_indices.iter().enumerate() {
if *r < self.param_count {
res += &format!("\tout[{}] = params[{}];\n", i, r);
} else if *r < self.reserved_indices {
res += &format!(
"\tout[{}] = CONSTANTS_double[{}];\n",
i,
r - self.param_count
);
} else {
res += &format!("\tout[{}] = Z[{}];\n", i, r);
}
}

res += "\treturn;\n}\n";

res
Expand Down Expand Up @@ -1326,10 +1304,27 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
*out += &format!("\t\t\"movsd QWORD {}, xmm0\\n\\t\"\n", format_addr!(*o));
}
Instr::Pow(o, b, e) => {
end_asm_block!(in_asm_block);
if *e == -1 {
if !in_asm_block {
*out += "\t__asm__(\n";
in_asm_block = true;
}

let base = get_input!(*b);
*out += format!("\tZ[{}] = pow({}, {});\n", o, base, e).as_str();
*out += &format!(
"\t\t\"movsd xmm0, QWORD PTR [%1+{}]\\n\\t\"
\t\t\"divsd xmm0, QWORD {}\\n\\t\"
\t\t\"movapd xmm2, xmm0\\n\\t\"
\t\t\"movsd QWORD {}, xmm0\\n\\t\"\n",
(self.reserved_indices - self.param_count) * 8,
format_addr!(*b),
format_addr!(*o)
);
} else {
end_asm_block!(in_asm_block);

let base = get_input!(*b);
*out += format!("\tZ[{}] = pow({}, {});\n", o, base, e).as_str();
}
}
Instr::Powf(o, b, e) => {
end_asm_block!(in_asm_block);
Expand Down Expand Up @@ -1366,6 +1361,24 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
}

end_asm_block!(in_asm_block);

*out += "\t__asm__(\n";
for (i, r) in &mut self.result_indices.iter().enumerate() {
if *r < self.param_count {
*out += &format!("\t\t\"movsd xmm0, QWORD PTR[%3+{}]\\n\\t\"\n", r * 8);
} else if *r < self.reserved_indices {
*out += &format!(
"\t\t\"movsd xmm0, QWORD PTR[%2+{}]\\n\\t\"\n",
(r - self.param_count) * 8
);
} else {
*out += &format!("\t\t\"movsd xmm0, QWORD PTR[%1+{}]\\n\\t\"\n", r * 8);
}

*out += &format!("\t\t\"movsd QWORD PTR[%0+{}], xmm0\\n\\t\"\n", i * 8);
}

*out += "\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"(CONSTANTS_double), \"r\"(params)\n\t\t: \"memory\", \"xmm0\");\n";
in_asm_block
}

Expand Down Expand Up @@ -1469,10 +1482,31 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
}
}
Instr::Pow(o, b, e) => {
end_asm_block!(in_asm_block);
if *e == -1 {
if !in_asm_block {
*out += "\t__asm__(\n";
in_asm_block = true;
}

let base = get_input!(*b);
*out += format!("\tZ[{}] = pow({}, {});\n", o, base, e).as_str();
*out += &format!(
"\t\t\"movupd xmm0, XMMWORD {}\\n\\t\"
\t\t\"movupd xmm1, XMMWORD PTR [%1+{}]\\n\\t\"
\t\t\"movapd xmm2, xmm0\\n\\t\"
\t\t\"xorpd xmm0, xmm1\\n\\t\"
\t\t\"mulpd xmm2, xmm2\\n\\t\"
\t\t\"haddpd xmm2, xmm2\\n\\t\"
\t\t\"divpd xmm0, xmm2\\n\\t\"
\t\t\"movupd XMMWORD {}, xmm0\\n\\t\"",
format_addr!(*b),
(self.reserved_indices - self.param_count) * 16,
format_addr!(*o)
);
} else {
end_asm_block!(in_asm_block);

let base = get_input!(*b);
*out += format!("\tZ[{}] = pow({}, {});\n", o, base, e).as_str();
}
}
Instr::Powf(o, b, e) => {
end_asm_block!(in_asm_block);
Expand Down Expand Up @@ -1508,6 +1542,24 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
}

end_asm_block!(in_asm_block);

*out += "\t__asm__(\n";
for (i, r) in &mut self.result_indices.iter().enumerate() {
if *r < self.param_count {
*out += &format!("\t\t\"movupd xmm0, XMMWORD PTR[%3+{}]\\n\\t\"\n", r * 16);
} else if *r < self.reserved_indices {
*out += &format!(
"\t\t\"movupd xmm0, XMMWORD PTR[%2+{}]\\n\\t\"\n",
(r - self.param_count) * 16
);
} else {
*out += &format!("\t\t\"movupd xmm0, XMMWORD PTR[%1+{}]\\n\\t\"\n", r * 16);
}

*out += &format!("\t\t\"movupd XMMWORD PTR[%0+{}], xmm0\\n\\t\"\n", i * 16);
}

*out += "\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"(CONSTANTS_complex), \"r\"(params)\n\t\t: \"memory\", \"xmm0\");\n";
in_asm_block
}
}
Expand Down Expand Up @@ -3316,8 +3368,15 @@ impl<'a> AtomView<'a> {
if den == 1 {
if num > 1 {
return Ok(Expression::Mul(vec![b_eval.clone(); num as usize]));
} else {
return Ok(Expression::Pow(Box::new((
Expression::Mul(vec![
b_eval.clone();
num.unsigned_abs() as usize
]),
-1,
))));
}
return Ok(Expression::Pow(Box::new((b_eval, num))));
}
}
}
Expand Down

0 comments on commit c92c835

Please sign in to comment.