Skip to content

Commit

Permalink
Merge pull request #217 from newrelic/enhancments/revamp-234880
Browse files Browse the repository at this point in the history
Revamp user class detection technique, use server level endpoints
  • Loading branch information
lovesh-ap authored May 14, 2024
2 parents 3de6420 + 729734c commit a8927a2
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ public class APIEndpointTest {
public void setupEndpoints() {
expectedMappings.put("/servlet/*", HttpServletServer.class.getName()+"$1");
expectedMappings.put("/index.jsp", null);
expectedMappings.put("/*.jsp", "org.apache.jasper.servlet.JspServlet");
expectedMappings.put("/", DefaultServlet.class.getName());
expectedMappings.put("/*.jspx", "org.apache.jasper.servlet.JspServlet");
}

@Test
Expand All @@ -43,7 +40,7 @@ public void testAPIEndpoint() throws Exception {

Set<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings();
Assert.assertNotNull(mappings);
Assert.assertEquals(5, mappings.size());
Assert.assertEquals(2, mappings.size());
for (ApplicationURLMapping mapping : mappings) {
assertMappings(mapping);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ public class APIEndpointTest {
public void setupEndpoints() {
expectedMappings.put("/servlet/*", HttpServletServer.class.getName()+"$1");
expectedMappings.put("/index.jsp", null);
expectedMappings.put("/*.jsp", "org.apache.jasper.servlet.JspServlet");
expectedMappings.put("/", DefaultServlet.class.getName());
expectedMappings.put("/*.jspx", "org.apache.jasper.servlet.JspServlet");
}

@Test
Expand All @@ -43,7 +40,7 @@ public void testAPIEndpoint() throws Exception {

Set<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings();
Assert.assertNotNull(mappings);
Assert.assertEquals(5, mappings.size());
Assert.assertEquals(2, mappings.size());
for (ApplicationURLMapping mapping : mappings) {
assertMappings(mapping);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public class ServerTest {
@BeforeClass
public static void addMappings() {
actualMappings.put("/servlet/*", MyServlet.class.getName());
actualMappings.put("/", ServletHandler.Default404Servlet.class.getName());
}
@After
public void teardown() throws Exception {
Expand Down Expand Up @@ -201,7 +200,7 @@ public void testAPIEndpoint () throws Exception {
start();

Set<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings();
Assert.assertEquals(2, mappings.size());
Assert.assertEquals(1, mappings.size());
for (ApplicationURLMapping mapping : mappings) {
Assert.assertNotNull(mapping);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class ApiEndpointTest {

@Test
public void testURLMappings() {
String handler = DefaultServlet.class.getName();
String handler = MyServlet.class.getName();
String method = "*";
Iterator<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings().iterator();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
Expand Down Expand Up @@ -64,7 +67,7 @@ public void onStartup(Set<Class<?>> c, ServletContext ctx){
}
}, Collections.emptySet());

Tomcat.addServlet(context, "servlet", new DefaultServlet());
Tomcat.addServlet(context, "servlet", new MyServlet());
context.addServletMappingDecoded("/*","servlet");
context.addServletMappingDecoded("/test","servlet");

Expand All @@ -88,4 +91,10 @@ private void stop() {
}
}
}
}
class MyServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.doGet(req, resp);
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package com.nr.agent.security.instrumentation.memcached.test;

import com.github.mwarc.embeddedmemcached.JMemcachedServer;
import com.newrelic.agent.security.instrumentation.spy.memcached.MemcachedHelper;
import com.newrelic.agent.security.introspec.InstrumentationTestConfig;
import com.newrelic.agent.security.introspec.SecurityInstrumentationTestRunner;
import com.newrelic.agent.security.introspec.SecurityIntrospector;
import com.newrelic.api.agent.security.schema.AbstractOperation;
import com.newrelic.api.agent.security.schema.VulnerabilityCaseType;
import com.newrelic.api.agent.security.schema.operation.MemcachedOperation;
import net.spy.memcached.MemcachedClient;
import net.spy.memcached.ops.StoreType;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
Expand Down Expand Up @@ -68,7 +70,7 @@ public void testSet() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore", "set");
}

@Test
Expand All @@ -80,7 +82,7 @@ public void testAdd() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore", "add");
}
@Test
public void testReplace() {
Expand All @@ -91,7 +93,7 @@ public void testReplace() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore", "replace");
}
@Test
public void testAppend() {
Expand All @@ -102,7 +104,7 @@ public void testAppend() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat");
verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat", "append");
}
@Test
public void testPrepend() {
Expand All @@ -113,7 +115,7 @@ public void testPrepend() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat");
verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat", "prepend");
}
@Test
public void testPrepend1() {
Expand All @@ -124,7 +126,7 @@ public void testPrepend1() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat");
verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat", "prepend");
}
@Test
public void testCas() {
Expand All @@ -135,7 +137,7 @@ public void testCas() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}
@Test
public void testCas1() {
Expand All @@ -146,7 +148,7 @@ public void testCas1() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}

@Test
Expand All @@ -158,7 +160,7 @@ public void testAsyncCAS() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}
@Test
public void testAsyncCAS1() {
Expand All @@ -169,14 +171,15 @@ public void testAsyncCAS1() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}

private void verifier(MemcachedOperation operation, List<?> args, String type, String method) {
private void verifier(MemcachedOperation operation, List<?> args, String type, String method, String command) {
Assert.assertEquals("Incorrect executed parameters.", args, operation.getArguments());
Assert.assertEquals("Incorrect event case type.", VulnerabilityCaseType.CACHING_DATA_STORE, operation.getCaseType());
Assert.assertEquals("Incorrect event category.", MemcachedOperation.MEMCACHED, operation.getCategory());
Assert.assertEquals("Incorrect event category.", type, operation.getType());
Assert.assertEquals("Incorrect event category.", type, operation.getCommand());
Assert.assertEquals("Incorrect event category.", command, operation.getType());
Assert.assertEquals("Incorrect executed class-name.", memcachedClient.getClass().getName(), operation.getClassName());
Assert.assertEquals("Incorrect executed method-name.", method, operation.getMethodName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,12 +391,23 @@ private UserClassEntity setUserClassEntity(AbstractOperation operation, Security

for (int i = 0; i < operation.getStackTrace().length; i++) {
StackTraceElement stackTraceElement = operation.getStackTrace()[i];

// Section for user class identification using API handlers
if( !securityMetaData.getMetaData().isFoundAnnotedUserLevelServiceMethod() && URLMappingsHelper.getHandlersHash().contains(stackTraceElement.getClassName().hashCode())){
//Found -> assign user class and return
userClassEntity.setUserClassElement(stackTraceElement);
securityMetaData.getMetaData().setUserLevelServiceMethodEncountered(true);
userClassEntity.setCalledByUserCode(true);
return userClassEntity;
}

//Fallback to old mechanism
if(userStackTraceElement != null){
if(StringUtils.equals(stackTraceElement.getClassName(), userStackTraceElement.getClassName())
&& StringUtils.equals(stackTraceElement.getMethodName(), userStackTraceElement.getMethodName())){
userClassEntity.setUserClassElement(stackTraceElement);
userClassEntity.setCalledByUserCode(securityMetaData.getMetaData().isUserLevelServiceMethodEncountered());
return userClassEntity;
userStackTraceElement = stackTraceElement;
}
}
// TODO: the `if` should be `else if` please check crypto case BenchmarkTest01978. service trace is being registered from doSomething()
Expand Down Expand Up @@ -725,4 +736,5 @@ public String decryptAndVerify(String encryptedData, String hashVerifier) {
return null;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@ public static Set<ApplicationURLMapping> getApplicationURLMappings() {
return mappings;
}

private static Set<Integer> handlers = ConcurrentHashMap.newKeySet();

public static Set<Integer> getHandlersHash() {
return handlers;
}

public static void addApplicationURLMapping(ApplicationURLMapping mapping) {
if (mapping.getHandler() == null || (mapping.getHandler() != null && !defaultHandlers.contains(mapping.getHandler()))) {
mappings.add(mapping);
}
if (mapping.getHandler() != null){
handlers.add(mapping.getHandler().hashCode());
}
}
}

0 comments on commit a8927a2

Please sign in to comment.