From 4233305a4d5c738966b73badde664f33a3d42aef Mon Sep 17 00:00:00 2001 From: "Stephen A. Crowley" Date: Fri, 19 Jan 2024 21:47:37 -0600 Subject: [PATCH] generated code checks for null variables and throws an AssertionError istead of carrying along thinking its ok to pass 0 to C code --- P.java | 38 +++++++++++-------- src/main/java/arb/expressions/Compiler.java | 28 ++++++++++++++ src/main/java/arb/expressions/Expression.java | 28 ++++++++++++-- 3 files changed, 76 insertions(+), 18 deletions(-) diff --git a/P.java b/P.java index 2bda2e3b9..a43fba93d 100644 --- a/P.java +++ b/P.java @@ -4,9 +4,9 @@ import arb.functions.Function; public class P implements Function { - public Integer _c0 = new Integer("0"); - public Integer _c1 = new Integer("1"); - public Integer _c2 = new Integer("2"); + public Integer const1 = new Integer("0"); + public Integer const2 = new Integer("1"); + public Integer const3 = new Integer("2"); public Real α; public Real β; public Real r1 = new Real(); @@ -27,33 +27,33 @@ public class P implements Function { public Real r5 = new Real(); public Real r6 = new Real(); public Function P; - public Function A = new A(); - public Function B = new B(); - public Function C = new C(); - public Function E = new E(); + public Function A; + public Function B; + public Function C; + public Function E; public RealPolynomial evaluate(Integer in, int order, int bits, RealPolynomial result) { return switch(in.getSignedValue()) { - case 0 -> result.set(this._c1); - case 1 -> ((Real)this.C.evaluate(this.r1.set(this._c1), order, bits, this.r2)) + case 0 -> result.set(this.const2); + case 1 -> ((Real)this.C.evaluate(this.r1.set(this.const2), order, bits, this.r2)) .mul(result.identity(), bits, this.rp1) .sub(this.β, bits, this.rp2) .add(this.α, bits, this.rp3) - .div(this._c2, bits, result); + .div(this.const3, bits, result); default -> { RealPolynomial var5 = (RealPolynomial)this.A.evaluate(in, order, bits, this.rp4); if (this.P == null) { this.P = new P(this); } - var5 = var5.mul((RealPolynomial)this.P.evaluate(in.sub(this._c1, bits, this.i1), order, bits, this.rp5), bits, this.rp6); + var5 = var5.mul((RealPolynomial)this.P.evaluate(in.sub(this.const2, bits, this.i1), order, bits, this.rp5), bits, this.rp6); Real var10001 = (Real)this.B.evaluate(this.r3.set(in), order, bits, this.r4); if (this.P == null) { this.P = new P(this); } yield var5.sub( - var10001.mul((RealPolynomial)this.P.evaluate(in.sub(this._c2, bits, this.i2), order, bits, this.rp7), bits, this.rp8), bits, this.rp9 + var10001.mul((RealPolynomial)this.P.evaluate(in.sub(this.const3, bits, this.i2), order, bits, this.rp7), bits, this.rp8), bits, this.rp9 ) .div((Real)this.E.evaluate(this.r5.set(in), order, bits, this.r6), bits, result); } @@ -61,6 +61,14 @@ public RealPolynomial evaluate(Integer in, int order, int bits, RealPolynomial r } public P() { + this.initializeContextualFunctions(); + } + + public void initializeContextualFunctions() { + this.A = new A(); + this.B = new B(); + this.C = new C(); + this.E = new E(); } public P(P var1) { @@ -70,9 +78,9 @@ public P(P var1) { } public void close() { - this._c0.close(); - this._c1.close(); - this._c2.close(); + this.const1.close(); + this.const2.close(); + this.const3.close(); this.r1.close(); this.r2.close(); this.rp1.close(); diff --git a/src/main/java/arb/expressions/Compiler.java b/src/main/java/arb/expressions/Compiler.java index 014a9649d..1cb39777e 100644 --- a/src/main/java/arb/expressions/Compiler.java +++ b/src/main/java/arb/expressions/Compiler.java @@ -53,6 +53,34 @@ public class Compiler { public static final String objectDesc = Type.getInternalName(Object.class); + public static void addNullCheckForField(MethodVisitor mv, String className, String fieldName, String fieldDesc) + { + Label notNullLabel = new Label(); + + // Load 'this' onto the stack + mv.visitVarInsn(Opcodes.ALOAD, 0); + + // Get the field value + mv.visitFieldInsn(Opcodes.GETFIELD, className, fieldName, fieldDesc); + + // Check if the field value is null + mv.visitJumpInsn(Opcodes.IFNONNULL, notNullLabel); + + // If null, throw AssertionError + mv.visitTypeInsn(Opcodes.NEW, Type.getInternalName(AssertionError.class)); + mv.visitInsn(Opcodes.DUP); + mv.visitLdcInsn(fieldName + " is null"); + mv.visitMethodInsn(Opcodes.INVOKESPECIAL, + Type.getInternalName(AssertionError.class), + "", + "(Ljava/lang/Object;)V", + false); + mv.visitInsn(Opcodes.ATHROW); + + // Label for not null case + mv.visitLabel(notNullLabel); + } + public static > Expression compile(String expression, Context context, Class domainClass, diff --git a/src/main/java/arb/expressions/Expression.java b/src/main/java/arb/expressions/Expression.java index 1ef7fd6b6..257b69cbb 100644 --- a/src/main/java/arb/expressions/Expression.java +++ b/src/main/java/arb/expressions/Expression.java @@ -76,7 +76,7 @@ public class Expression> implements { private static final String contextualFunctionInitializationMethod = "initializeContextualFunctions"; - public static final String evaluationMethodDescriptor = "(Ljava/lang/Object;IILjava/lang/Object;)Ljava/lang/Object;"; + public static final String evaluationMethodDescriptor = "(Ljava/lang/Object;IILjava/lang/Object;)Ljava/lang/Object;"; public static > F instantiate(String expression, Context context, @@ -429,6 +429,7 @@ public ClassVisitor generateEvaluationMethod(ClassVisitor classVisitor) throws E methodVisitor.visitCode(); methodVisitor.visitLabel(startLabel); + Node rootNode = parseRootNode(); if (position < expression.length()) @@ -440,6 +441,8 @@ public ClassVisitor generateEvaluationMethod(ClassVisitor classVisitor) throws E expression.length())); } + addChecksForNullVariableReferences(methodVisitor); + rootNode.generate(methodVisitor, rangeType); methodVisitor.visitInsn(Opcodes.ARETURN); @@ -454,6 +457,17 @@ public ClassVisitor generateEvaluationMethod(ClassVisitor classVisitor) throws E return classVisitor; } + public void addChecksForNullVariableReferences(MethodVisitor methodVisitor) + { + for (var variable : referencedVariables.keySet()) + { + addNullCheckForField(methodVisitor, + className, + variable, + context.variables.map.get(variable).getClass().descriptorString()); + } + } + public MethodVisitor declareLocalVariables(MethodVisitor methodVisitor, Label startLabel, Label endLabel) { // String objectClassDescriptor = Object.class.descriptorString(); @@ -1191,7 +1205,11 @@ public ClassVisitor generateDefaultConstructor(ClassVisitor classVisitor) initializeIntermediateVariables(methodVisitor); - loadThisOntoStack(methodVisitor).visitMethodInsn(Opcodes.INVOKEVIRTUAL, className, contextualFunctionInitializationMethod, "()V", false); + loadThisOntoStack(methodVisitor).visitMethodInsn(Opcodes.INVOKEVIRTUAL, + className, + contextualFunctionInitializationMethod, + "()V", + false); methodVisitor.visitInsn(RETURN); methodVisitor.visitMaxs(0, 0); @@ -1201,7 +1219,11 @@ public ClassVisitor generateDefaultConstructor(ClassVisitor classVisitor) public ClassVisitor generateInitializationMethod(ClassVisitor classVisitor) { - MethodVisitor methodVisitor = classVisitor.visitMethod(Opcodes.ACC_PUBLIC, contextualFunctionInitializationMethod, "()V", null, null); + MethodVisitor methodVisitor = classVisitor.visitMethod(Opcodes.ACC_PUBLIC, + contextualFunctionInitializationMethod, + "()V", + null, + null); try { methodVisitor.visitCode();