Skip to content

Commit

Permalink
Optimizing the preprocessing hooks to reduce SA overhead
Browse files Browse the repository at this point in the history
  • Loading branch information
IshikaDawda committed Nov 20, 2024
1 parent 8c421eb commit 8be8ec3
Show file tree
Hide file tree
Showing 184 changed files with 520 additions and 1,440 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,23 @@

package akka.http.scaladsl

import akka.Done
import akka.http.scaladsl.model.{HttpEntity, HttpResponse}
import akka.http.scaladsl.server.AkkaCoreUtils
import akka.stream.Materializer
import akka.stream.javadsl.Source
import akka.stream.scaladsl.Sink
import akka.util.ByteString
import com.newrelic.api.agent.NewRelic
import com.newrelic.api.agent.security.NewRelicSecurity
import com.newrelic.api.agent.security.instrumentation.helpers.GenericHelper
import com.newrelic.api.agent.security.schema.exceptions.NewRelicSecurityException
import com.newrelic.api.agent.security.utils.logging.LogLevel
import com.newrelic.api.agent.{NewRelic, Token}

import java.lang
import scala.concurrent.{ExecutionContext, Future}
import scala.runtime.AbstractFunction1

class AkkaResponseHelper extends AbstractFunction1[HttpResponse, HttpResponse] {

override def apply(httpResponse: HttpResponse): HttpResponse = {
try {
val stringResponse = new lang.StringBuilder()
val isLockAquired = AkkaCoreUtils.acquireServletLockIfPossible()
val isLockAquired = GenericHelper.acquireLockIfPossible(AkkaCoreUtils.NR_SEC_CUSTOM_ATTRIB_NAME);
stringResponse.append(httpResponse.entity.asInstanceOf[HttpEntity.Strict].getData().decodeString("utf-8"))
AkkaCoreUtils.postProcessHttpRequest(isLockAquired, stringResponse, httpResponse.entity.contentType.toString(), this.getClass.getName, "apply", NewRelic.getAgent.getTransaction.getToken())
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.newrelic.api.agent.security.schema.AgentMetaData;
import com.newrelic.api.agent.security.schema.SecurityMetaData;
import com.newrelic.api.agent.security.schema.StringUtils;
import com.newrelic.api.agent.security.schema.VulnerabilityCaseType;
import com.newrelic.api.agent.security.schema.exceptions.NewRelicSecurityException;
import com.newrelic.api.agent.security.schema.operation.RXSSOperation;
import com.newrelic.api.agent.security.schema.policy.AgentPolicy;
Expand All @@ -26,37 +27,10 @@ public class AkkaCoreUtils {
public static final String AKKA_HTTP_10_0_0 = "AKKA_HTTP_10.0.0";
private static final String X_FORWARDED_FOR = "x-forwarded-for";
private static final String EMPTY = "";
public static final String QUESTION_MARK = "?";

public static boolean isServletLockAcquired() {
try {
return NewRelicSecurity.isHookProcessingActive() &&
Boolean.TRUE.equals(NewRelicSecurity.getAgent().getSecurityMetaData().getCustomAttribute(getNrSecCustomAttribName(), Boolean.class));
} catch (Throwable ignored) {}
return false;
}

public static void releaseServletLock() {
try {
if(NewRelicSecurity.isHookProcessingActive()) {
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(getNrSecCustomAttribName(), null);
}
} catch (Throwable ignored){}
}

private static String getNrSecCustomAttribName() {
return NR_SEC_CUSTOM_ATTRIB_NAME;
}
private static final String QUESTION_MARK = "?";

public static boolean acquireServletLockIfPossible() {
try {
if (NewRelicSecurity.isHookProcessingActive() &&
!isServletLockAcquired()) {
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(getNrSecCustomAttribName(), true);
return true;
}
} catch (Throwable ignored){}
return false;
return GenericHelper.acquireLockIfPossible(VulnerabilityCaseType.REFLECTED_XSS, NR_SEC_CUSTOM_ATTRIB_NAME);
}

public static void postProcessHttpRequest(Boolean isServletLockAcquired, StringBuilder responseBody, String contentType, String className, String methodName, Token token) {
Expand Down Expand Up @@ -87,7 +61,7 @@ public static void postProcessHttpRequest(Boolean isServletLockAcquired, StringB
}
} finally {
if(isServletLockAcquired){
releaseServletLock();
GenericHelper.releaseLock(NR_SEC_CUSTOM_ATTRIB_NAME);
}
}
}
Expand All @@ -105,16 +79,12 @@ public static void preProcessHttpRequest (Boolean isServletLockAcquired, HttpReq
SecurityMetaData securityMetaData = NewRelicSecurity.getAgent().getSecurityMetaData();

com.newrelic.api.agent.security.schema.HttpRequest securityRequest = securityMetaData.getRequest();
if (securityRequest.isRequestParsed()) {
return;
}

AgentMetaData securityAgentMetaData = securityMetaData.getMetaData();

securityRequest.setMethod(httpRequest.method().value());
//TODO Client IP and PORT extraction is pending

// securityRequest.setClientIP();
securityRequest.setServerPort(httpRequest.getUri().port());

processHttpRequestHeader(httpRequest, securityRequest);
Expand Down Expand Up @@ -144,8 +114,8 @@ public static void preProcessHttpRequest (Boolean isServletLockAcquired, HttpReq
NewRelicSecurity.getAgent().log(LogLevel.WARNING, String.format(GenericHelper.ERROR_GENERATING_HTTP_REQUEST, AKKA_HTTP_10_0_0, ignored.getMessage()), ignored, AkkaCoreUtils.class.getName());
}
finally {
if(isServletLockAcquired()){
releaseServletLock();
if(GenericHelper.isLockAcquired(AkkaCoreUtils.NR_SEC_CUSTOM_ATTRIB_NAME)){
GenericHelper.releaseLock(NR_SEC_CUSTOM_ATTRIB_NAME);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,19 @@ import akka.http.scaladsl.model.HttpEntity
import akka.stream.javadsl.Source
import akka.stream.scaladsl.Sink
import akka.util.ByteString

import java.util.concurrent.atomic.AtomicBoolean
import java.util.logging.Level
import com.newrelic.api.agent.{NewRelic, Trace}
import com.newrelic.api.agent.security.NewRelicSecurity
import com.newrelic.api.agent.security.utils.logging.LogLevel
import com.newrelic.api.agent.{NewRelic, Trace}

import java.lang
import scala.collection.mutable
import java.util.concurrent.atomic.AtomicBoolean
import java.util.logging.Level
import scala.concurrent.Future
import scala.runtime.AbstractFunction1

object CsecAkkaHttpContextFunction {

final val retransformed = new AtomicBoolean(false)
private final val retransformed = new AtomicBoolean(false)

def contextWrapper(original: Function1[RequestContext, Future[RouteResult]]): Function1[RequestContext, Future[RouteResult]] = {
if (retransformed.compareAndSet(false, true)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.newrelic.api.agent.security.schema.AgentMetaData;
import com.newrelic.api.agent.security.schema.SecurityMetaData;
import com.newrelic.api.agent.security.schema.StringUtils;
import com.newrelic.api.agent.security.schema.VulnerabilityCaseType;
import com.newrelic.api.agent.security.schema.exceptions.NewRelicSecurityException;
import com.newrelic.api.agent.security.schema.operation.RXSSOperation;
import com.newrelic.api.agent.security.schema.policy.AgentPolicy;
Expand All @@ -28,37 +29,10 @@ public class AkkaCoreUtils {

private static final String X_FORWARDED_FOR = "x-forwarded-for";
private static final String EMPTY = "";
public static final String QUESTION_MARK = "?";

public static boolean isServletLockAcquired() {
try {
return NewRelicSecurity.isHookProcessingActive() &&
Boolean.TRUE.equals(NewRelicSecurity.getAgent().getSecurityMetaData().getCustomAttribute(getNrSecCustomAttribName(), Boolean.class));
} catch (Throwable ignored) {}
return false;
}

public static void releaseServletLock() {
try {
if(NewRelicSecurity.isHookProcessingActive()) {
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(getNrSecCustomAttribName(), null);
}
} catch (Throwable ignored){}
}

private static String getNrSecCustomAttribName() {
return NR_SEC_CUSTOM_ATTRIB_NAME;
}
private static final String QUESTION_MARK = "?";

public static boolean acquireServletLockIfPossible() {
try {
if (NewRelicSecurity.isHookProcessingActive() &&
!isServletLockAcquired()) {
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(getNrSecCustomAttribName(), true);
return true;
}
} catch (Throwable ignored){}
return false;
return GenericHelper.acquireLockIfPossible(VulnerabilityCaseType.REFLECTED_XSS, NR_SEC_CUSTOM_ATTRIB_NAME);
}

public static void postProcessHttpRequest(Boolean isServletLockAcquired, StringBuilder response, String contentType, int responseCode, String className, String methodName, Token token) {
Expand Down Expand Up @@ -94,7 +68,7 @@ public static void postProcessHttpRequest(Boolean isServletLockAcquired, StringB
NewRelicSecurity.getAgent().reportIncident(LogLevel.SEVERE, String.format(GenericHelper.REGISTER_OPERATION_EXCEPTION_MESSAGE, AKKA_HTTP_CORE_10_0, e.getMessage()), e, AkkaCoreUtils.class.getName());
} finally {
if(isServletLockAcquired){
releaseServletLock();
GenericHelper.releaseLock(NR_SEC_CUSTOM_ATTRIB_NAME);
}
}
}
Expand All @@ -121,7 +95,6 @@ public static void preProcessHttpRequest (Boolean isServletLockAcquired, HttpReq
securityRequest.setMethod(httpRequest.method().value());
//TODO Client IP and PORT extraction is pending

// securityRequest.setClientIP();
securityRequest.setServerPort(httpRequest.getUri().port());

processHttpRequestHeader(httpRequest, securityRequest);
Expand Down Expand Up @@ -150,13 +123,13 @@ public static void preProcessHttpRequest (Boolean isServletLockAcquired, HttpReq
NewRelicSecurity.getAgent().log(LogLevel.WARNING, String.format(GenericHelper.ERROR_GENERATING_HTTP_REQUEST, AKKA_HTTP_CORE_10_0, ignored.getMessage()), ignored, AkkaCoreUtils.class.getName());
}
finally {
if(isServletLockAcquired()){
releaseServletLock();
if(GenericHelper.isLockAcquired(NR_SEC_CUSTOM_ATTRIB_NAME)){
GenericHelper.releaseLock(NR_SEC_CUSTOM_ATTRIB_NAME);
}
}
}

public static String getTraceHeader(Map<String, String> headers) {
private static String getTraceHeader(Map<String, String> headers) {
String data = EMPTY;
if (headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER) || headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER.toLowerCase())) {
data = headers.get(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER);
Expand All @@ -167,7 +140,7 @@ public static String getTraceHeader(Map<String, String> headers) {
return data;
}

public static void processHttpRequestHeader(HttpRequest request, com.newrelic.api.agent.security.schema.HttpRequest securityRequest){
private static void processHttpRequestHeader(HttpRequest request, com.newrelic.api.agent.security.schema.HttpRequest securityRequest){
Iterator<HttpHeader> headers = request.getHeaders().iterator();
while (headers.hasNext()) {
boolean takeNextValue = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,6 @@ private static void registerExitOperation(boolean isProcessingAllowed, AbstractO

private AbstractOperation preprocessSecurityHook(HttpRequest httpRequest, String methodName) {
try {
SecurityMetaData securityMetaData = NewRelicSecurity.getAgent().getSecurityMetaData();
if (!NewRelicSecurity.isHookProcessingActive() || securityMetaData.getRequest().isEmpty()) {
return null;
}

// Generate required URL
URI methodURI = null;
String uri = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ public InetSocketAddress localAddress() {

@WeaveAllConstructors
public ServerBinding() {
// AgentBridge.getAgent().getLogger().log(Level.FINE, "Setting akka-http port to: {0,number,#}", localAddress().getPort());
// AgentBridge.publicApi.setAppServerPort(localAddress().getPort());
// AgentBridge.publicApi.setServerInfo("Akka HTTP", ManifestUtils.getVersionFromManifest(getClass(), "akka-http-core", "10.2.0"));

NewRelicSecurity.getAgent().setApplicationConnectionConfig(localAddress().getPort(), "http");
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ object ResponseFutureHelper {
try {
val stringResponse: lang.StringBuilder = new lang.StringBuilder();
val dataBytes: Source[ByteString, _] = response.entity.getDataBytes()
val isLockAquired = AkkaCoreUtils.acquireServletLockIfPossible();
val isLockAquired = GenericHelper.acquireLockIfPossible(AkkaCoreUtils.NR_SEC_CUSTOM_ATTRIB_NAME);
val sink: Sink[ByteString, Future[Done]] = Sink.foreach[ByteString] { byteString =>
val chunk = byteString.utf8String
stringResponse.append(chunk)
Expand Down Expand Up @@ -61,7 +61,7 @@ object ResponseFutureHelper {
try {
val stringResponse: lang.StringBuilder = new lang.StringBuilder();
val dataBytes: Source[ByteString, _] = httpResponse.entity.getDataBytes()
val isLockAquired = AkkaCoreUtils.acquireServletLockIfPossible();
val isLockAquired = GenericHelper.acquireLockIfPossible(AkkaCoreUtils.NR_SEC_CUSTOM_ATTRIB_NAME);
val sink: Sink[ByteString, Future[Done]] = Sink.foreach[ByteString] { byteString =>
val chunk = byteString.utf8String
stringResponse.append(chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.newrelic.api.agent.security.schema.AgentMetaData;
import com.newrelic.api.agent.security.schema.SecurityMetaData;
import com.newrelic.api.agent.security.schema.StringUtils;
import com.newrelic.api.agent.security.schema.VulnerabilityCaseType;
import com.newrelic.api.agent.security.schema.exceptions.NewRelicSecurityException;
import com.newrelic.api.agent.security.schema.operation.RXSSOperation;
import com.newrelic.api.agent.security.schema.policy.AgentPolicy;
Expand All @@ -28,37 +29,10 @@ public class AkkaCoreUtils {

private static final String X_FORWARDED_FOR = "x-forwarded-for";
private static final String EMPTY = "";
public static final String QUESTION_MARK = "?";

public static boolean isServletLockAcquired() {
try {
return NewRelicSecurity.isHookProcessingActive() &&
Boolean.TRUE.equals(NewRelicSecurity.getAgent().getSecurityMetaData().getCustomAttribute(getNrSecCustomAttribName(), Boolean.class));
} catch (Throwable ignored) {}
return false;
}

public static void releaseServletLock() {
try {
if(NewRelicSecurity.isHookProcessingActive()) {
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(getNrSecCustomAttribName(), null);
}
} catch (Throwable ignored){}
}

private static String getNrSecCustomAttribName() {
return NR_SEC_CUSTOM_ATTRIB_NAME;
}
private static final String QUESTION_MARK = "?";

public static boolean acquireServletLockIfPossible() {
try {
if (NewRelicSecurity.isHookProcessingActive() &&
!isServletLockAcquired()) {
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(getNrSecCustomAttribName(), true);
return true;
}
} catch (Throwable ignored){}
return false;
return GenericHelper.acquireLockIfPossible(VulnerabilityCaseType.REFLECTED_XSS, NR_SEC_CUSTOM_ATTRIB_NAME);
}

public static void postProcessHttpRequest(Boolean isServletLockAcquired, StringBuilder responseBody, String contentType, int responseCode, String className, String methodName, Token token) {
Expand Down Expand Up @@ -94,7 +68,7 @@ public static void postProcessHttpRequest(Boolean isServletLockAcquired, StringB
NewRelicSecurity.getAgent().reportIncident(LogLevel.SEVERE, String.format(GenericHelper.REGISTER_OPERATION_EXCEPTION_MESSAGE, AKKA_HTTP_CORE_10_0_11, e.getMessage()), e, AkkaCoreUtils.class.getName());
} finally {
if(isServletLockAcquired){
releaseServletLock();
GenericHelper.releaseLock(NR_SEC_CUSTOM_ATTRIB_NAME);
}
}
}
Expand All @@ -121,7 +95,6 @@ public static void preProcessHttpRequest (Boolean isServletLockAcquired, HttpReq
securityRequest.setMethod(httpRequest.method().value());
//TODO Client IP and PORT extraction is pending

// securityRequest.setClientIP();
securityRequest.setServerPort(httpRequest.getUri().getPort());

processHttpRequestHeader(httpRequest, securityRequest);
Expand Down Expand Up @@ -150,13 +123,13 @@ public static void preProcessHttpRequest (Boolean isServletLockAcquired, HttpReq
NewRelicSecurity.getAgent().log(LogLevel.WARNING, String.format(GenericHelper.ERROR_GENERATING_HTTP_REQUEST, AKKA_HTTP_CORE_10_0_11, ignored.getMessage()), ignored, AkkaCoreUtils.class.getName());
}
finally {
if(isServletLockAcquired()){
releaseServletLock();
if(GenericHelper.isLockAcquired(NR_SEC_CUSTOM_ATTRIB_NAME)){
GenericHelper.releaseLock(NR_SEC_CUSTOM_ATTRIB_NAME);
}
}
}

public static String getTraceHeader(Map<String, String> headers) {
private static String getTraceHeader(Map<String, String> headers) {
String data = EMPTY;
if (headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER) || headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER.toLowerCase())) {
data = headers.get(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER);
Expand All @@ -167,7 +140,7 @@ public static String getTraceHeader(Map<String, String> headers) {
return data;
}

public static void processHttpRequestHeader(HttpRequest request, com.newrelic.api.agent.security.schema.HttpRequest securityRequest){
private static void processHttpRequestHeader(HttpRequest request, com.newrelic.api.agent.security.schema.HttpRequest securityRequest){
Iterator<HttpHeader> headers = request.getHeaders().iterator();
while (headers.hasNext()) {
boolean takeNextValue = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,6 @@ private static void registerExitOperation(boolean isProcessingAllowed, AbstractO

private AbstractOperation preprocessSecurityHook(HttpRequest httpRequest, String methodName) {
try {
SecurityMetaData securityMetaData = NewRelicSecurity.getAgent().getSecurityMetaData();
if (!NewRelicSecurity.isHookProcessingActive() || securityMetaData.getRequest().isEmpty()) {
return null;
}

// Generate required URL
URI methodURI = null;
String uri = null;
Expand Down
Loading

0 comments on commit 8be8ec3

Please sign in to comment.