Skip to content

Commit

Permalink
generated code checks for null variables and throws an AssertionError
Browse files Browse the repository at this point in the history
istead of carrying along thinking its ok to pass 0 to C code
  • Loading branch information
crowlogic committed Jan 20, 2024
1 parent ba6e8ef commit 4233305
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 18 deletions.
38 changes: 23 additions & 15 deletions P.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import arb.functions.Function;

public class P implements Function<Integer, RealPolynomial> {
public Integer _c0 = new Integer("0");
public Integer _c1 = new Integer("1");
public Integer _c2 = new Integer("2");
public Integer const1 = new Integer("0");
public Integer const2 = new Integer("1");
public Integer const3 = new Integer("2");
public Real α;
public Real β;
public Real r1 = new Real();
Expand All @@ -27,40 +27,48 @@ public class P implements Function<Integer, RealPolynomial> {
public Real r5 = new Real();
public Real r6 = new Real();
public Function<Integer, RealPolynomial> P;
public Function<Integer, RealPolynomial> A = new A();
public Function<Real, Real> B = new B();
public Function<Real, Real> C = new C();
public Function<Real, Real> E = new E();
public Function<Integer, RealPolynomial> A;
public Function<Real, Real> B;
public Function<Real, Real> C;
public Function<Real, Real> E;

public RealPolynomial evaluate(Integer in, int order, int bits, RealPolynomial result) {
return switch(in.getSignedValue()) {
case 0 -> result.set(this._c1);
case 1 -> ((Real)this.C.evaluate(this.r1.set(this._c1), order, bits, this.r2))
case 0 -> result.set(this.const2);
case 1 -> ((Real)this.C.evaluate(this.r1.set(this.const2), order, bits, this.r2))
.mul(result.identity(), bits, this.rp1)
.sub(this.β, bits, this.rp2)
.add(this.α, bits, this.rp3)
.div(this._c2, bits, result);
.div(this.const3, bits, result);
default -> {
RealPolynomial var5 = (RealPolynomial)this.A.evaluate(in, order, bits, this.rp4);
if (this.P == null) {
this.P = new P(this);
}

var5 = var5.mul((RealPolynomial)this.P.evaluate(in.sub(this._c1, bits, this.i1), order, bits, this.rp5), bits, this.rp6);
var5 = var5.mul((RealPolynomial)this.P.evaluate(in.sub(this.const2, bits, this.i1), order, bits, this.rp5), bits, this.rp6);
Real var10001 = (Real)this.B.evaluate(this.r3.set(in), order, bits, this.r4);
if (this.P == null) {
this.P = new P(this);
}

yield var5.sub(
var10001.mul((RealPolynomial)this.P.evaluate(in.sub(this._c2, bits, this.i2), order, bits, this.rp7), bits, this.rp8), bits, this.rp9
var10001.mul((RealPolynomial)this.P.evaluate(in.sub(this.const3, bits, this.i2), order, bits, this.rp7), bits, this.rp8), bits, this.rp9
)
.div((Real)this.E.evaluate(this.r5.set(in), order, bits, this.r6), bits, result);
}
};
}

public P() {
this.initializeContextualFunctions();
}

public void initializeContextualFunctions() {
this.A = new A();
this.B = new B();
this.C = new C();
this.E = new E();
}

public P(P var1) {
Expand All @@ -70,9 +78,9 @@ public P(P var1) {
}

public void close() {
this._c0.close();
this._c1.close();
this._c2.close();
this.const1.close();
this.const2.close();
this.const3.close();
this.r1.close();
this.r2.close();
this.rp1.close();
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/arb/expressions/Compiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,34 @@ public class Compiler
{
public static final String objectDesc = Type.getInternalName(Object.class);

public static void addNullCheckForField(MethodVisitor mv, String className, String fieldName, String fieldDesc)
{
Label notNullLabel = new Label();

// Load 'this' onto the stack
mv.visitVarInsn(Opcodes.ALOAD, 0);

// Get the field value
mv.visitFieldInsn(Opcodes.GETFIELD, className, fieldName, fieldDesc);

// Check if the field value is null
mv.visitJumpInsn(Opcodes.IFNONNULL, notNullLabel);

// If null, throw AssertionError
mv.visitTypeInsn(Opcodes.NEW, Type.getInternalName(AssertionError.class));
mv.visitInsn(Opcodes.DUP);
mv.visitLdcInsn(fieldName + " is null");
mv.visitMethodInsn(Opcodes.INVOKESPECIAL,
Type.getInternalName(AssertionError.class),
"<init>",
"(Ljava/lang/Object;)V",
false);
mv.visitInsn(Opcodes.ATHROW);

// Label for not null case
mv.visitLabel(notNullLabel);
}

public static <D, R, F extends Function<D, R>> Expression<D, R, F> compile(String expression,
Context context,
Class<? extends D> domainClass,
Expand Down
28 changes: 25 additions & 3 deletions src/main/java/arb/expressions/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public class Expression<D, R, F extends Function<D, R>> implements
{
private static final String contextualFunctionInitializationMethod = "initializeContextualFunctions";

public static final String evaluationMethodDescriptor = "(Ljava/lang/Object;IILjava/lang/Object;)Ljava/lang/Object;";
public static final String evaluationMethodDescriptor = "(Ljava/lang/Object;IILjava/lang/Object;)Ljava/lang/Object;";

public static <D, R, F extends Function<D, R>> F instantiate(String expression,
Context context,
Expand Down Expand Up @@ -429,6 +429,7 @@ public ClassVisitor generateEvaluationMethod(ClassVisitor classVisitor) throws E
methodVisitor.visitCode();
methodVisitor.visitLabel(startLabel);


Node<D, R, F> rootNode = parseRootNode();

if (position < expression.length())
Expand All @@ -440,6 +441,8 @@ public ClassVisitor generateEvaluationMethod(ClassVisitor classVisitor) throws E
expression.length()));
}

addChecksForNullVariableReferences(methodVisitor);

rootNode.generate(methodVisitor, rangeType);

methodVisitor.visitInsn(Opcodes.ARETURN);
Expand All @@ -454,6 +457,17 @@ public ClassVisitor generateEvaluationMethod(ClassVisitor classVisitor) throws E
return classVisitor;
}

public void addChecksForNullVariableReferences(MethodVisitor methodVisitor)
{
for (var variable : referencedVariables.keySet())
{
addNullCheckForField(methodVisitor,
className,
variable,
context.variables.map.get(variable).getClass().descriptorString());
}
}

public MethodVisitor declareLocalVariables(MethodVisitor methodVisitor, Label startLabel, Label endLabel)
{
// String objectClassDescriptor = Object.class.descriptorString();
Expand Down Expand Up @@ -1191,7 +1205,11 @@ public ClassVisitor generateDefaultConstructor(ClassVisitor classVisitor)

initializeIntermediateVariables(methodVisitor);

loadThisOntoStack(methodVisitor).visitMethodInsn(Opcodes.INVOKEVIRTUAL, className, contextualFunctionInitializationMethod, "()V", false);
loadThisOntoStack(methodVisitor).visitMethodInsn(Opcodes.INVOKEVIRTUAL,
className,
contextualFunctionInitializationMethod,
"()V",
false);

methodVisitor.visitInsn(RETURN);
methodVisitor.visitMaxs(0, 0);
Expand All @@ -1201,7 +1219,11 @@ public ClassVisitor generateDefaultConstructor(ClassVisitor classVisitor)

public ClassVisitor generateInitializationMethod(ClassVisitor classVisitor)
{
MethodVisitor methodVisitor = classVisitor.visitMethod(Opcodes.ACC_PUBLIC, contextualFunctionInitializationMethod, "()V", null, null);
MethodVisitor methodVisitor = classVisitor.visitMethod(Opcodes.ACC_PUBLIC,
contextualFunctionInitializationMethod,
"()V",
null,
null);
try
{
methodVisitor.visitCode();
Expand Down

0 comments on commit 4233305

Please sign in to comment.