Skip to content

Commit

Permalink
Fix issue with call sites in super calls to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
manuel-alvarez-alvarez committed Nov 21, 2024
1 parent 9bd1251 commit 99bb102
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ private static void writeStackOperations(final AdviceSpecification advice, final
String mode = "COPY";
if (allArgsSpec != null) {
if (advice instanceof AfterSpecification) {
mode = advice.isConstructor() ? "PREPEND_ARRAY_CTOR" : "PREPEND_ARRAY";
mode = "PREPEND_ARRAY";
} else {
mode = "APPEND_ARRAY";
}
Expand Down Expand Up @@ -344,12 +344,10 @@ private static void writeAdviceMethodCall(
final MethodCallExpr invokeStatic =
new MethodCallExpr()
.setScope(new NameExpr("handler"))
.setName("method")
.addArgument(opCode("INVOKESTATIC"))
.setName("advice")
.addArgument(new StringLiteralExpr(method.getOwner().getInternalName()))
.addArgument(new StringLiteralExpr(method.getMethodName()))
.addArgument(new StringLiteralExpr(method.getMethodType().getDescriptor()))
.addArgument(new BooleanLiteralExpr(false));
.addArgument(new StringLiteralExpr(method.getMethodType().getDescriptor()));
body.addStatement(invokeStatic);
}
if (requiresCast(advice)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
pointcut('java/security/MessageDigest', 'getInstance', '(Ljava/lang/String;)Ljava/security/MessageDigest;')
statements(
'handler.dupParameters(descriptor, StackDupMode.COPY);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$BeforeAdvice", "before", "(Ljava/lang/String;)V", false);',
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$BeforeAdvice", "before", "(Ljava/lang/String;)V");',
'handler.method(opcode, owner, name, descriptor, isInterface);'
)
}
Expand Down Expand Up @@ -78,7 +78,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
advices(0) {
pointcut('java/lang/String', 'replaceAll', '(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;')
statements(
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AroundAdvice", "around", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;", false);'
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AroundAdvice", "around", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");'
)
}
}
Expand Down Expand Up @@ -110,7 +110,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
statements(
'handler.dupInvoke(owner, descriptor, StackDupMode.COPY);',
'handler.method(opcode, owner, name, descriptor, isInterface);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdvice", "after", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;", false);',
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdvice", "after", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");',
)
}
}
Expand Down Expand Up @@ -140,9 +140,9 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
advices(0) {
pointcut('java/net/URL', '<init>', '(Ljava/lang/String;)V')
statements(
'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY_CTOR);',
'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);',
'handler.method(opcode, owner, name, descriptor, isInterface);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdviceCtor", "after", "([Ljava/lang/Object;Ljava/net/URL;)Ljava/net/URL;", false);',
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AfterAdviceCtor", "after", "([Ljava/lang/Object;Ljava/net/URL;)Ljava/net/URL;");',
)
}
}
Expand Down Expand Up @@ -208,7 +208,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
statements(
'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);',
'handler.invokeDynamic(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicAfterAdvice", "after", "([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;", false);'
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicAfterAdvice", "after", "([Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/String;");'
)
}
}
Expand Down Expand Up @@ -297,7 +297,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);',
'handler.invokeDynamic(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);',
'handler.loadConstantArray(bootstrapMethodArguments);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicWithConstantsAdvice", "after", "([Ljava/lang/Object;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;", false);'
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$InvokeDynamicWithConstantsAdvice", "after", "([Ljava/lang/Object;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;");'
)
}
}
Expand Down Expand Up @@ -393,7 +393,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
statements(
'int[] parameterIndices = new int[] { 0 };',
'handler.dupParameters(descriptor, parameterIndices, owner);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "(Ljava/lang/String;)V", false);',
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "(Ljava/lang/String;)V");',
'handler.method(opcode, owner, name, descriptor, isInterface);',
)
}
Expand All @@ -402,7 +402,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
statements(
'int[] parameterIndices = new int[] { 1 };',
'handler.dupParameters(descriptor, parameterIndices, null);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "([Ljava/lang/Object;)V", false);',
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "([Ljava/lang/Object;)V");',
'handler.method(opcode, owner, name, descriptor, isInterface);',
)
}
Expand All @@ -411,7 +411,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
statements(
'int[] parameterIndices = new int[] { 0 };',
'handler.dupInvoke(owner, descriptor, parameterIndices);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "(Ljava/lang/String;I)V", false);',
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$PartialArgumentsBeforeAdvice", "before", "(Ljava/lang/String;I)V");',
'handler.method(opcode, owner, name, descriptor, isInterface);',
)
}
Expand Down Expand Up @@ -441,9 +441,9 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
advices(0) {
pointcut('java/lang/StringBuilder', '<init>', '(Ljava/lang/String;)V')
statements(
'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY_CTOR);',
'handler.dupParameters(descriptor, StackDupMode.PREPEND_ARRAY);',
'handler.method(opcode, owner, name, descriptor, isInterface);',
'handler.method(Opcodes.INVOKESTATIC, "datadog/trace/plugin/csi/impl/AdviceGeneratorTest$SuperTypeReturnAdvice", "after", "([Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", false);',
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$SuperTypeReturnAdvice", "after", "([Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");',
'handler.instruction(Opcodes.CHECKCAST, "java/lang/StringBuilder");'
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import datadog.trace.agent.tooling.csi.InvokeAdvice;
import datadog.trace.agent.tooling.csi.InvokeDynamicAdvice;
import java.security.ProtectionDomain;
import java.util.Deque;
import java.util.LinkedList;
import javax.annotation.Nonnull;
import net.bytebuddy.asm.AsmVisitorWrapper;
import net.bytebuddy.description.field.FieldDescription;
Expand Down Expand Up @@ -118,25 +120,48 @@ public MethodVisitor visitMethod(
private static class CallSiteMethodVisitor extends MethodVisitor
implements CallSiteAdvice.MethodHandler {
private final Advices advices;
private final Deque<String> newInvocations = new LinkedList<>();
private StackDupMode ctorDupMode = null;

private CallSiteMethodVisitor(
@Nonnull final Advices advices, @Nonnull final MethodVisitor delegated) {
super(ASM_API, delegated);
this.advices = advices;
}

@Override
public void visitTypeInsn(final int opcode, final String type) {
if (opcode == Opcodes.NEW) {
newInvocations.addLast(type);
}
super.visitTypeInsn(opcode, type);
}

@Override
public void visitMethodInsn(
final int opcode,
final String owner,
final String name,
final String descriptor,
final boolean isInterface) {
CallSiteAdvice advice = advices.findAdvice(owner, name, descriptor);
if (advice instanceof InvokeAdvice) {
((InvokeAdvice) advice).apply(this, opcode, owner, name, descriptor, isInterface);
} else {
mv.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
try {
if (opcode == Opcodes.INVOKESPECIAL && "<init>".equals(name)) {
if (owner.equals(newInvocations.peekLast())) {
newInvocations.removeLast();
ctorDupMode = StackDupMode.PREPEND_ARRAY_ON_NEW_CTOR;
} else {
ctorDupMode = StackDupMode.PREPEND_ARRAY_ON_SUPER_CTOR;
}
}
CallSiteAdvice advice = advices.findAdvice(owner, name, descriptor);
// we cannot instrument calls to super in ctor
if (advice instanceof InvokeAdvice) {
((InvokeAdvice) advice).apply(this, opcode, owner, name, descriptor, isInterface);
} else {
mv.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
}
} finally {
ctorDupMode = null;
}
}

Expand Down Expand Up @@ -197,6 +222,15 @@ public void method(
mv.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
}

@Override
public void advice(String owner, String name, String descriptor) {
if (ctorDupMode == StackDupMode.PREPEND_ARRAY_ON_SUPER_CTOR) {
// append this to the stack after super call
mv.visitIntInsn(Opcodes.ALOAD, 0);
}
mv.visitMethodInsn(Opcodes.INVOKESTATIC, owner, name, descriptor, false);
}

@Override
public void invokeDynamic(
final String name,
Expand All @@ -206,6 +240,11 @@ public void invokeDynamic(
mv.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
}

@Override
public void dupCtorParameters(final String methodDescriptor) {
dupParameters(methodDescriptor, ctorDupMode);
}

@Override
public void dupParameters(final String methodDescriptor, final StackDupMode mode) {
final Type method = Type.getMethodType(methodDescriptor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ public static void dup(final MethodVisitor mv, final Type[] parameters, final St
dup(mv, parameters);
break;
case PREPEND_ARRAY:
case PREPEND_ARRAY_CTOR:
case PREPEND_ARRAY_ON_NEW_CTOR:
case PREPEND_ARRAY_ON_SUPER_CTOR:
case APPEND_ARRAY:
dupN(mv, parameters, mode);
break;
Expand Down Expand Up @@ -279,12 +280,22 @@ private static void dupN(
loadArray(mv, arraySize, parameters);
mv.visitInsn(POP);
break;
case PREPEND_ARRAY_CTOR:
// move the array before the NEW and DUP opcodes
case PREPEND_ARRAY_ON_NEW_CTOR:
// move the array before the uninitialized entry created by NEW and DUP
// stack start = [uninitialized, uninitialized, arg_0, ..., arg_n]
// stack end = [array, uninitialized, uninitialized, arg_0, ..., arg_n]
mv.visitInsn(DUP_X2);
loadArray(mv, arraySize, parameters);
mv.visitInsn(POP);
break;
case PREPEND_ARRAY_ON_SUPER_CTOR:
// move the array before the uninitialized entry
// stack start = [uninitialized, arg_0, ..., arg_n]
// stack end = [array, uninitialized, arg_0, ..., arg_n]
mv.visitInsn(DUP_X1);
loadArray(mv, arraySize, parameters);
mv.visitInsn(POP);
break;
case APPEND_ARRAY:
loadArray(mv, arraySize, parameters);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ interface MethodHandler {
/** Performs a method invocation (static, special, virtual, interface...) */
void method(int opcode, String owner, String name, String descriptor, boolean isInterface);

/** Performs an advice invocation (always static) */
void advice(String owner, String name, String descriptor);

/** Performs a dynamic method invocation */
void invokeDynamic(
String name,
String descriptor,
Handle bootstrapMethodHandle,
Object... bootstrapMethodArguments);

/** Duplicates all the ctor parameters in the stack just before the method is invoked. */
void dupCtorParameters(String methodDescriptor);

/** Duplicates all the method parameters in the stack just before the method is invoked. */
void dupParameters(String methodDescriptor, StackDupMode mode);

Expand Down Expand Up @@ -62,11 +68,10 @@ enum StackDupMode {
COPY,
/** Copies the parameters in an array and prepends it */
PREPEND_ARRAY,
/**
* Copies the parameters in an array, prepends it and swaps the array with the uninitialized
* instance in a ctor
*/
PREPEND_ARRAY_CTOR,
/** Copies the parameters in an array and adds it between NEW and DUP opcodes */
PREPEND_ARRAY_ON_NEW_CTOR,
/** Copies the parameters in an array and adds it before the uninitialized instance in a ctor */
PREPEND_ARRAY_ON_SUPER_CTOR,
/** Copies the parameters in an array and appends it */
APPEND_ARRAY
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import net.bytebuddy.jar.asm.Type
import net.bytebuddy.matcher.ElementMatcher
import net.bytebuddy.utility.JavaModule
import net.bytebuddy.utility.nullability.MaybeNull

import java.lang.reflect.Constructor
import java.security.MessageDigest


Expand Down Expand Up @@ -81,6 +83,10 @@ class BaseCallSiteTest extends DDSpecification {
return buildPointcut(String.getDeclaredMethod('concat', String))
}

protected static Pointcut stringReaderPointcut() {
return buildPointcut(StringReader.getDeclaredConstructor(String))
}

protected static Pointcut messageDigestGetInstancePointcut() {
return buildPointcut(MessageDigest.getDeclaredMethod('getInstance', String))
}
Expand All @@ -100,6 +106,10 @@ class BaseCallSiteTest extends DDSpecification {
return buildPointcut(Type.getType(executable.getDeclaringClass()).internalName, executable.name, Type.getType(executable).descriptor)
}

protected static Pointcut buildPointcut(final Constructor<?> executable) {
return buildPointcut(Type.getType(executable.getDeclaringClass()).internalName, "<init>", Type.getType(executable).descriptor)
}

protected static Pointcut buildPointcut(final String type, final String method, final String descriptor) {
return new Pointcut(type: type, method: method, descriptor: descriptor)
}
Expand Down Expand Up @@ -157,6 +167,13 @@ class BaseCallSiteTest extends DDSpecification {
return clazz.getConstructor().newInstance()
}

protected static Class<?> loadClass(final Type type,
final byte[] data,
final ClassLoader loader = Thread.currentThread().contextClassLoader) {
final classLoader = new ByteArrayClassLoader(loader, [(type.className): data])
return classLoader.loadClass(type.className)
}

protected static byte[] transformType(final Type source,
final Type target,
final CallSiteTransformer transformer,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package datadog.trace.agent.tooling.csi

import datadog.trace.agent.tooling.bytebuddy.csi.CallSiteTransformer
import net.bytebuddy.asm.AsmVisitorWrapper
import net.bytebuddy.description.type.TypeDescription
import net.bytebuddy.dynamic.DynamicType
import net.bytebuddy.jar.asm.Type

import java.util.concurrent.atomic.AtomicInteger

class CallSiteInstrumentationTest extends BaseCallSiteTest {

Expand Down Expand Up @@ -60,6 +64,36 @@ class CallSiteInstrumentationTest extends BaseCallSiteTest {
0 * builder.visit(_ as AsmVisitorWrapper) >> builder
}

void 'test call site transformer with super call in ctor'() {
setup:
SuperInCtorExampleAdvice.CALLS.set(0)
final source = Type.getType(SuperInCtorExample)
final target = renameType(source, 'Test')
final pointcut = stringReaderPointcut()
final InvokeAdvice advice = new InvokeAdvice() {
@Override
void apply(CallSiteAdvice.MethodHandler handler, int opcode, String owner, String name, String descriptor, boolean isInterface) {
handler.dupCtorParameters(descriptor)
handler.method(opcode, owner, name, descriptor, isInterface)
handler.advice(
Type.getType(SuperInCtorExampleAdvice).internalName,
'onInvoke',
Type.getMethodType(Type.getType(StringReader), Type.getType(Object[]), Type.getType(StringReader)).getDescriptor(),
)
}
}
final callSiteTransformer = new CallSiteTransformer(mockAdvices([mockCallSites(advice, pointcut)]))

when:
final transformedClass = transformType(source, target, callSiteTransformer)
final transformed = loadClass(target, transformedClass)
final reader = transformed.newInstance("test")

then:
reader != null
SuperInCtorExampleAdvice.CALLS.get() == 1
}

static class StringCallSites implements CallSites, TestCallSites {

@Override
Expand All @@ -82,4 +116,14 @@ class CallSiteInstrumentationTest extends BaseCallSiteTest {
handler.method(opcode, owner, name, descriptor, isInterface)
}
}

static class SuperInCtorExampleAdvice {

private static final AtomicInteger CALLS = new AtomicInteger(0)

static StringReader onInvoke(Object[] args, StringReader result) {
CALLS.incrementAndGet()
return result
}
}
}
Loading

0 comments on commit 99bb102

Please sign in to comment.