From a7dc2323c587b928660a561035ac8baeb279f0e1 Mon Sep 17 00:00:00 2001 From: Stephen Crowley Date: Wed, 4 Dec 2024 23:33:45 -0600 Subject: [PATCH] https://github.com/crowlogic/arb4j/issues/253 --- README.md | 4 +- expressions/P.yaml | 16 ++ src/main/java/arb/expressions/nodes/Node.java | 36 +++- .../expressions/nodes/unary/FunctionNode.java | 6 +- .../arb/expressions/nodes/unary/WhenNode.java | 199 ++++++++++-------- .../expressions/nodes/DerivativeNodeTest.java | 8 + todoList.txt | 1 + 7 files changed, 165 insertions(+), 105 deletions(-) create mode 100644 expressions/P.yaml diff --git a/README.md b/README.md index 71ec152da..bb89f4995 100644 --- a/README.md +++ b/README.md @@ -48,8 +48,8 @@ try (Real x = new Real("25", 128)) { #### Expression Compiler - The [arb.expressions](https://github.com/crowlogic/arb4j/tree/master/src/main/java/arb/expressions) package in arb4j includes tools for compiling mathematical [expressions](https://github.com/crowlogic/arb4j/blob/master/src/main/java/arb/expressions/Expression.java) directly into Java bytecode, saving milleniums of development time, reducing the need to laborously and tediously write new code for each different formula to be evaluated whilst also ensuring efficiency and correctness; it would be challenging to write code manually that would significantly outperform the generated code -##### ExpressionAnalyzer -The [ExpressionAnalyzer](https://github.com/crowlogic/arb4j/tree/master/src/main/java/arb/viz/ExpressionAnalyzer.java) provides a tree-list view that shows the abstract-syntex-tree that constitutes +##### Expressor +The [Expressor](https://github.com/crowlogic/arb4j/tree/master/src/main/java/arb/expressions/viz/Expressor.java) program provides a tree-list view that shows the abstract-syntex-tree that constitutes a given expression and the intermediate values that combine to produce a given result. ![Screenshot from 2024-08-25 21-42-44](https://github.com/user-attachments/assets/cd1d71de-bcef-4be6-b25a-3c41293de158) diff --git a/expressions/P.yaml b/expressions/P.yaml new file mode 100644 index 000000000..b42459e79 --- /dev/null +++ b/expressions/P.yaml @@ -0,0 +1,16 @@ +--- !!arb.expressions.SerializedExpression +coDomain: arb.RationalFunction +context: + input: + - arb.Integer + - 2 + β: + - arb.Real + - '-0.5' + α: + - arb.Real + - '-0.5' +domain: arb.Integer +expression: n➔when(n=0,1,n=1,(C(1)*x-β+α)/2.0,else,(A(n)*P(n-1)-B(n)*P(n-2))/E(n)) +function: arb.functions.rational.RationalFunctionSequence +... diff --git a/src/main/java/arb/expressions/nodes/Node.java b/src/main/java/arb/expressions/nodes/Node.java index 212479611..9d6ead455 100644 --- a/src/main/java/arb/expressions/nodes/Node.java +++ b/src/main/java/arb/expressions/nodes/Node.java @@ -269,7 +269,12 @@ public Node mul(Node multiplicand) public Node pow(int i) { - return pow(LiteralConstantNode.of(expression, i)); + return pow(nodeOf(i)); + } + + public Node nodeOf(int i) + { + return LiteralConstantNode.of(expression, i); } public Node neg() @@ -280,12 +285,12 @@ public Node neg() public Node sub(int i) { - return sub(LiteralConstantNode.of(expression, i)); + return sub(nodeOf(i)); } public Node div(int i) { - return div(LiteralConstantNode.of(expression, i)); + return div(nodeOf(i)); } public Node pow(String exponent) @@ -296,23 +301,34 @@ public Node pow(String exponent) public Node cos() { - return new FunctionNode<>("cos", + return apply("cos"); + } + + public Node apply(String functionName) + { + return new FunctionNode<>(functionName, this, expression); } public Node sin() { - return new FunctionNode<>("sin", - this, - expression); + return apply("sin"); } public Node sqrt() { - return new FunctionNode<>("sqrt", - this, - expression); + return apply("sqrt"); + } + + public Node tan() + { + return apply("tan"); + } + + public Node tanh() + { + return apply("tanh"); } } diff --git a/src/main/java/arb/expressions/nodes/unary/FunctionNode.java b/src/main/java/arb/expressions/nodes/unary/FunctionNode.java index 9910de6a4..5606d6566 100644 --- a/src/main/java/arb/expressions/nodes/unary/FunctionNode.java +++ b/src/main/java/arb/expressions/nodes/unary/FunctionNode.java @@ -465,7 +465,7 @@ private Node differentiateBuiltinFunction() switch (functionName) { case "arcsin": - var one = expression.newLiteralConstant(1); + var one = nodeOf(1); return one.div(one.sub(arg.pow(2)).sqrt()); case "sin": return arg.cos(); @@ -474,7 +474,9 @@ private Node differentiateBuiltinFunction() case "exp": return this; case "log": - return expression.newLiteralConstant(1).div(arg); + return nodeOf(1).div(arg); + case "tanh": + return nodeOf(1).sub(arg.tanh().pow(2)); default: throw new UnsupportedOperationException("Derivative not implemented for function: " + functionName); } diff --git a/src/main/java/arb/expressions/nodes/unary/WhenNode.java b/src/main/java/arb/expressions/nodes/unary/WhenNode.java index 837a0881f..e22dc1722 100644 --- a/src/main/java/arb/expressions/nodes/unary/WhenNode.java +++ b/src/main/java/arb/expressions/nodes/unary/WhenNode.java @@ -9,6 +9,7 @@ import java.util.TreeMap; import java.util.function.Consumer; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.objectweb.asm.Label; import org.objectweb.asm.MethodVisitor; @@ -46,46 +47,8 @@ public class WhenNode> extend UnaryOperationNode { - private static final String INTEGER_CLASS_INTERNAL_NAME = Type.getInternalName(Integer.class); private static final String INT_METHOD_DESCRIPTOR = Type.getMethodDescriptor(Type.getType(int.class)); - - void evaluateCase(TreeMap> cases, VariableNode variable) - { - if (!variable.reference.equals(expression.independentVariable.reference)) - { - throw new CompilerException("condition of when statement must be the equality of the input variable which is " - + expression.independentVariable - + " not " - + variable); - } - - if (!expression.nextCharacterIs('=')) - { - throw new CompilerException(format("= expected in condition of when function at pos=%d expression=%s but got ch=%c and lastCh=%c", - expression.position, - expression, - expression.character, - expression.previousCharacter)); - } - - var constant = evaluateCondition(); - var value = expression.resolve(); - cases.put(new Integer(constant.value), value); - } - - public LiteralConstantNode evaluateCondition() - { - Node condition = expression.evaluate(); - if (!(condition instanceof LiteralConstantNode)) - { - throw new CompilerException("condition of when statement must be the equality of the input variable to an " - + "Integer LiteralConstant type, but got " - + condition); - } - LiteralConstantNode constant = (LiteralConstantNode) condition; - expression.require(','); - return constant; - } + private static final String INTEGER_CLASS_INTERNAL_NAME = Type.getInternalName(Integer.class); public static > Node @@ -103,9 +66,10 @@ public LiteralConstantNode evaluateCondition() } public TreeMap> cases; + private Label defaultLabel = new Label(); - private Label endSwitch = new Label(); + private Label endSwitch = new Label(); private Label[] labels = null; public WhenNode(Expression expression) @@ -129,6 +93,60 @@ public WhenNode(Expression expression) } } + public > WhenNode(Expression expression, + TreeMap> cases) + { + super(null, + expression); + this.cases = new TreeMap<>(); + + for (var entry : cases.entrySet()) + { + var value = entry.getValue().spliceInto(expression); + this.cases.put(entry.getKey(), value); + } + + } + + @Override + public void accept(Consumer> t) + { + cases.forEach((id, branch) -> branch.accept(t)); + arg.accept(t); + t.accept(this); + } + + @Override + public Node differentiate(VariableNode variable) + { + assert false : "TODO"; + return null; + } + + void evaluateCase(TreeMap> cases, VariableNode variable) + { + if (!variable.reference.equals(expression.independentVariable.reference)) + { + throw new CompilerException("condition of when statement must be the equality of the input variable which is " + + expression.independentVariable + + " not " + + variable); + } + + if (!expression.nextCharacterIs('=')) + { + throw new CompilerException(format("= expected in condition of when function at pos=%d expression=%s but got ch=%c and lastCh=%c", + expression.position, + expression, + expression.character, + expression.previousCharacter)); + } + + var constant = evaluateCondition(); + var value = expression.resolve(); + cases.put(new Integer(constant.value), value); + } + public void evaluateCases() { Node node = expression.evaluate(); @@ -147,6 +165,20 @@ else if (node instanceof VariableNode variable) } } + public LiteralConstantNode evaluateCondition() + { + Node condition = expression.evaluate(); + if (!condition.isLiteralConstant()) + { + throw new CompilerException("condition of when statement must be the equality of the input variable to an " + + "Integer LiteralConstant type, but got " + + condition); + } + var constant = condition.asLiteralConstant(); + expression.require(','); + return constant; + } + @Override public MethodVisitor generate(MethodVisitor mv, Class resultType) { @@ -174,7 +206,6 @@ public MethodVisitor generate(MethodVisitor mv, Class resultType) for (int i = 0; i < labels.length; i++) { mv.visitLabel(labels[i]); - branches.get(i).generate(mv, resultType); mv.visitJumpInsn(GOTO, endSwitch); } @@ -203,28 +234,23 @@ public void generateIndex(MethodVisitor mv) } @Override - public String toString() + public List> getBranches() { - return String.format("When[cases=%s,default=%s]", - cases.entrySet() - .stream() - .map(node -> node.getKey() + "=" + node.getValue().typeset()) - .collect(Collectors.toList()), - arg.typeset()); + // the default branch is stored in this.arg + return Stream.concat(cases.values().stream(), Stream.of(arg)).toList(); } @Override - public Class type() + public String getIntermediateValueFieldName() { - return expression.coDomainType; + return null; } @Override - public String typeset() + public Node integrate(VariableNode variable) { - return cases.entrySet().stream().map(entry -> entry.getValue().typeset()).collect(Collectors.joining(", ")) - + " \text{otherwise} " - + arg.typeset(); + assert false : "TODO: Auto-generated method stub"; + return null; } @Override @@ -234,27 +260,15 @@ public boolean isLeaf() } @Override - public List> getBranches() + public boolean isLiteralConstant() { - assert false : "TODO: Auto-generated method stub"; - return null; + return arg.isLiteralConstant() && cases.values().stream().allMatch(Node::isLiteralConstant); } @Override - public Node integrate(VariableNode variable) - { - assert false : "TODO: Auto-generated method stub"; - return null; - } - - public > Node substitute(String variable, - Node arg) + public boolean isScalar() { - cases.entrySet() - .stream() - .toList() - .forEach(event -> cases.put(event.getKey(), event.getValue().substitute(variable, arg))); - return this; + return arg.isScalar(); } @Override @@ -262,46 +276,49 @@ public Node integrate(VariableNode variable) Node spliceInto(Expression newExpression) { - return new WhenNode(newExpression); - } - - @Override - public void accept(Consumer> t) - { - cases.forEach((id, branch) -> branch.accept(t)); - arg.accept(t); - t.accept(this); + return new WhenNode(newExpression, + cases); } - @Override - public Node differentiate(VariableNode variable) + public > Node substitute(String variable, + Node arg) { - assert false : "TODO"; - return null; + cases.entrySet() + .stream() + .toList() + .forEach(event -> cases.put(event.getKey(), event.getValue().substitute(variable, arg))); + return this; } @Override - public boolean isScalar() + public char symbol() { - return arg.isScalar(); + return '≡'; } @Override - public char symbol() + public String toString() { - return '≡'; + return String.format("When[cases=%s,default=%s]", + cases.entrySet() + .stream() + .map(node -> node.getKey() + "=" + node.getValue()) + .collect(Collectors.toList()), + arg); } @Override - public boolean isLiteralConstant() + public Class type() { - return arg.isLiteralConstant() && cases.values().stream().allMatch(Node::isLiteralConstant); + return expression.coDomainType; } @Override - public String getIntermediateValueFieldName() + public String typeset() { - return null; + return cases.entrySet().stream().map(entry -> entry.getValue().typeset()).collect(Collectors.joining(", ")) + + " \text{otherwise} " + + arg.typeset(); } } diff --git a/src/test/java/arb/expressions/nodes/DerivativeNodeTest.java b/src/test/java/arb/expressions/nodes/DerivativeNodeTest.java index d73fafacb..6f82c1d8a 100644 --- a/src/test/java/arb/expressions/nodes/DerivativeNodeTest.java +++ b/src/test/java/arb/expressions/nodes/DerivativeNodeTest.java @@ -15,6 +15,14 @@ public class DerivativeNodeTest extends TestCase { + + public void testTanhDerivative() + { + var f = RealFunction.parse("∂tanh(x)/∂x"); + var df = RealFunction.parse("1-tanh(x)²"); + assertEquals(df.rootNode.toString(), f.rootNode.toString()); + } + public void testArcSinDerivative() { var f = RealFunction.parse("∂arcsin(x)/∂x"); diff --git a/todoList.txt b/todoList.txt index 1b399c6a8..34090bb78 100644 --- a/todoList.txt +++ b/todoList.txt @@ -6,3 +6,4 @@ expr compiler: fractional derivatices implement ∂/∂x[∏ᵢ₌₁ᵏfᵢ(x)] = ∑ᵢ₌₁ᵏ[(∂/∂x fᵢ(x))∏ⱼ₌₁,ⱼ≠ᵢᵏfⱼ(x)] java.lang.NoSuchMethodError: 'arb.Integer arb.Fraction.floor(int, arb.Integer)'at n-Σk-ℭ(n,2*k)WherekEq0To⌊n⁄2⌋.evaluate(Unknown Source)at arb4j/arb.functions.Function.evaluate(Function.java:242) RationalFunction RationalJacobiPolynomials.evaluate(Integer t, int order, int bits, RationalFunction res)jshell> P.evaluate(0,128);$2 ==> 1jshell> P.evaluate(1,128);$3 ==> (3*x)/2jshell> P.evaluate(2,128);$4 ==> (20*x^2-5)/8jshell> P.evaluate(3,128);| Exception arb.exceptions.ArbException: numeratorAddress=123494138668048denominatorAddress=123494138667984numerator=123494139639664denominator=123494139639696| at RationalFunction.assertPointerConsistency (RationalFunction.java:532)| at RationalFunction.set (RationalFunction.java:598)| at RationalFunction.div (RationalFunction.java:56)| at P.evaluate (Unknown Source)| at P.evaluate (Unknown Source)| at RationalJacobiPolynomials.evaluate (RationalJacobiPolynomials.java:115)| at RationalJacobiPolynomials.evaluate (RationalJacobiPolynomials.java:1)| at Function.evaluate (Function.java:242)| at Function.evaluate (Function.java:222)| at Sequence.evaluate (Sequence.java:50)| at (#5:1) +java.lang.AssertionError: TODO: Auto-generated method stubat arb4j/arb.expressions.nodes.unary.WhenNode.getBranches(WhenNode.java:239)at arb4j/arb.expressions.viz.NodeTreeItem.buildChildren(NodeTreeItem.java:47)at arb4j/arb.expressions.viz.NodeTreeItem.getChildren(NodeTreeItem.java:40)at javafx.controls@23.0.1/javafx.scene.control.TreeItem.updateExpandedDescendentCount(TreeItem.java:918)