Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes grpc tracing #293

Merged
merged 1 commit into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import io.dongtai.iast.core.bytecode.enhance.IastContext;
import io.dongtai.iast.core.bytecode.enhance.plugin.AbstractClassVisitor;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;
import org.objectweb.asm.*;

public class AbstractStubAdapter extends AbstractClassVisitor {
public AbstractStubAdapter(ClassVisitor classVisitor, IastContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public ClassVisitor dispatch(ClassVisitor classVisitor, IastContext context) {
case classOfByteString:
classVisitor = new ByteStringAdapter(classVisitor, null);
break;
default:
break;
}
return classVisitor;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ public ServerStreamListenerImplAdapter(ClassVisitor classVisitor, IastContext co
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
if (name.equals("closed")) {
mv = new ServerStreamListenerImplAdviceAdapter(mv, access, name, descriptor);
if ("messagesAvailable".equals(name)) {
mv = new ServerStreamListenerImplStartAdviceAdapter(mv, access, name, descriptor);
setTransformed();
} else if ("closed".equals(name)) {
mv = new ServerStreamListenerImplClosedAdviceAdapter(mv, access, name, descriptor);
setTransformed();
}
return mv;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.dongtai.iast.core.bytecode.enhance.plugin.framework.grpc;

import io.dongtai.iast.core.bytecode.enhance.asm.AsmMethods;
import io.dongtai.iast.core.bytecode.enhance.asm.AsmTypes;
import io.dongtai.iast.core.utils.AsmUtils;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.commons.AdviceAdapter;

public class ServerStreamListenerImplStartAdviceAdapter extends AdviceAdapter implements AsmTypes, AsmMethods {
protected ServerStreamListenerImplStartAdviceAdapter(MethodVisitor methodVisitor, int access, String name, String descriptor) {
super(AsmUtils.api, methodVisitor, access, name, descriptor);
}

@Override
protected void onMethodEnter() {
invokeStatic(ASM_TYPE_SPY_HANDLER, SPY_HANDLER$getDispatcher);
invokeInterface(ASM_TYPE_SPY_DISPATCHER, SPY$startGrpcCall);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.dongtai.iast.core.handler.context;

import io.dongtai.iast.core.EngineManager;

/**
* @author owefsad
*/
Expand Down Expand Up @@ -30,7 +32,10 @@ public static String getSpanId(String traceId, int agentId) {

public static String getSegmentId() {
TracingContext context = CONTEXT.get();
return context.createSegmentId();
if (context != null) {
return context.createSegmentId();
}
return getOrCreateGlobalTraceId(null, EngineManager.getAgentId());
}

public static String getHeaderKey() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.dongtai.iast.core.handler.context;

import java.lang.dongtai.TraceIdHandler;

public class TraceManager implements TraceIdHandler {
@Override
public String getTraceKey() {
return ContextManager.getHeaderKey();
}

@Override
public String getTraceId() {
return ContextManager.getSegmentId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,15 @@
import io.dongtai.iast.core.handler.hookpoint.graphy.GraphBuilder;
import io.dongtai.iast.core.handler.hookpoint.models.MethodEvent;
import io.dongtai.iast.core.service.ErrorLogReport;
import io.dongtai.iast.core.utils.HttpClientUtils;
import io.dongtai.iast.core.utils.StackUtils;
import io.dongtai.iast.core.utils.TaintPoolUtils;
import io.dongtai.iast.core.utils.*;
import io.dongtai.log.DongTaiLog;

import java.io.File;
import java.lang.dongtai.TraceIdHandler;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.*;

public class GrpcHandler {
private static IastClassLoader gRpcClassLoader;
Expand Down Expand Up @@ -56,7 +52,7 @@ private static void createClassLoader(Object channel) {

Class<?> classOfGrpcProxy = gRpcClassLoader.loadClass("io.dongtai.plugin.GrpcProxy");
methodOfInterceptChannel = classOfGrpcProxy
.getDeclaredMethod("interceptChannel", Object.class, String.class, String.class);
.getDeclaredMethod("interceptChannel", Object.class, TraceIdHandler.class);
methodOfInterceptService = classOfGrpcProxy
.getDeclaredMethod("interceptService", Object.class);
methodOfGetRequestMetadata = classOfGrpcProxy.getDeclaredMethod("getServerMeta");
Expand All @@ -66,6 +62,10 @@ private static void createClassLoader(Object channel) {
}
}

public static void setSharedTraceId(String traceId) {
sharedTraceId.set(traceId);
}

/**
* 拦截 Grpc client 的 channel,后续client调用Server端服务会经过该拦截器
*
Expand All @@ -77,10 +77,7 @@ public static Object interceptChannel(Object channel) {
createClassLoader(channel);
}
try {
// todo: 考虑测试并发场景
String traceId = ContextManager.getSegmentId();
sharedTraceId.set(traceId);
return methodOfInterceptChannel.invoke(null, channel, ContextManager.getHeaderKey(), traceId);
return methodOfInterceptChannel.invoke(null, channel, new GrpcTraceManager());
} catch (Exception e) {
DongTaiLog.error(e);
}
Expand Down Expand Up @@ -213,11 +210,11 @@ public static void sendMessage(Object message) {
MethodEvent event = new MethodEvent(
0,
0,
"io.grpc.stub.ClientCalls",
"io.grpc.stub.ClientCalls",
"blockingUnaryCall",
"io.grpc.stub.ClientCalls.blockingUnaryCall(io.grpc.Channel, io.grpc.MethodDescriptor<ReqT,RespT>, io.grpc.CallOptions, ReqT)",
"io.grpc.stub.ClientCalls.blockingUnaryCall(io.grpc.Channel, io.grpc.MethodDescriptor<ReqT,RespT>, io.grpc.CallOptions, ReqT)",
"io.grpc.internal.ServerCallImpl",
"io.grpc.internal.ServerCallImpl",
"sendMessage",
"io.grpc.internal.ServerCallImpl.sendMessage(RespT)",
"io.grpc.internal.ServerCallImpl.sendMessage(RespT)",
null,
new Object[]{message},
null,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.dongtai.iast.core.handler.hookpoint.framework.grpc;

import io.dongtai.iast.core.handler.context.TraceManager;

public class GrpcTraceManager extends TraceManager {
@Override
public String getTraceId() {
String traceId = super.getTraceId();
GrpcHandler.setSharedTraceId(traceId);
return traceId;
}
}
4 changes: 4 additions & 0 deletions dongtai-plugins/dongtai-grpc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
<version>${grpc-all.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>io.dongtai.iast</groupId>
<artifactId>dongtai-spy</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.dongtai.plugin;

import io.grpc.ClientCall;
import io.grpc.ForwardingClientCall;
import io.grpc.Metadata;
import io.grpc.*;

import java.lang.dongtai.TraceIdHandler;

public class DongTaiClientCall<REQUEST, RESPONSE> extends ForwardingClientCall.SimpleForwardingClientCall<REQUEST, RESPONSE> {
String serviceName;
Expand All @@ -12,14 +12,14 @@ public class DongTaiClientCall<REQUEST, RESPONSE> extends ForwardingClientCall.S
String traceKey;
String traceId;

protected DongTaiClientCall(ClientCall<REQUEST, RESPONSE> delegate, String serviceName, String serviceType, String targetService, String traceKey, String traceId) {
protected DongTaiClientCall(ClientCall<REQUEST, RESPONSE> delegate, String serviceName, String serviceType, String targetService, TraceIdHandler traceIdHandler) {
super(delegate);
this.serviceName = serviceName;
this.serviceType = serviceType;
this.targetService = targetService;
this.pluginName = "GRPC";
this.traceKey = traceKey;
this.traceId = traceId;
this.traceKey = traceIdHandler.getTraceKey();
this.traceId = traceIdHandler.getTraceId();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@

import io.grpc.*;

import java.lang.dongtai.TraceIdHandler;

public class DongTaiClientInterceptor implements ClientInterceptor {
private String traceId;
private String traceKey;
private TraceIdHandler traceIdHandler;

public DongTaiClientInterceptor(String traceKey, String traceId) {
this.traceId = traceId;
this.traceKey = traceKey;
public DongTaiClientInterceptor(TraceIdHandler traceIdHandler) {
this.traceIdHandler = traceIdHandler;
}

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel channel) {
String methodName = method.getFullMethodName();
String methodType = method.getType().toString();
String target = channel.toString();
return new DongTaiClientCall<ReqT, RespT>(channel.newCall(method, callOptions), methodName, methodType, target, traceKey, traceId);
return new DongTaiClientCall<ReqT, RespT>(channel.newCall(method, callOptions), methodName, methodType, target, traceIdHandler);
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
package io.dongtai.plugin;

import io.grpc.Channel;
import io.grpc.ClientInterceptors;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.*;

import java.lang.dongtai.TraceIdHandler;
import java.util.HashMap;
import java.util.Map;

public class GrpcProxy {
private static Map<String, Object> metadata;

public static Object interceptChannel(Object channel, String traceKey, String traceId) {
public static Object interceptChannel(Object channel, TraceIdHandler traceIdHandler) {
try {
Channel interceptedChannel = (Channel) channel;
return ClientInterceptors.intercept(interceptedChannel, new DongTaiClientInterceptor(traceKey, traceId));
return ClientInterceptors.intercept(interceptedChannel, new DongTaiClientInterceptor(traceIdHandler));
} catch (Exception e) {
// fixme: remove throw exception
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package java.lang.dongtai;

public interface TraceIdHandler {
String getTraceKey();
String getTraceId();
}