Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp user class detection technique, use server level endpoints #217

Merged
merged 9 commits into from
May 14, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ public class APIEndpointTest {
public void setupEndpoints() {
expectedMappings.put("/servlet/*", HttpServletServer.class.getName()+"$1");
expectedMappings.put("/index.jsp", null);
expectedMappings.put("/*.jsp", "org.apache.jasper.servlet.JspServlet");
expectedMappings.put("/", DefaultServlet.class.getName());
expectedMappings.put("/*.jspx", "org.apache.jasper.servlet.JspServlet");
}

@Test
Expand All @@ -43,7 +40,7 @@ public void testAPIEndpoint() throws Exception {

Set<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings();
Assert.assertNotNull(mappings);
Assert.assertEquals(5, mappings.size());
Assert.assertEquals(2, mappings.size());
for (ApplicationURLMapping mapping : mappings) {
assertMappings(mapping);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ public class APIEndpointTest {
public void setupEndpoints() {
expectedMappings.put("/servlet/*", HttpServletServer.class.getName()+"$1");
expectedMappings.put("/index.jsp", null);
expectedMappings.put("/*.jsp", "org.apache.jasper.servlet.JspServlet");
expectedMappings.put("/", DefaultServlet.class.getName());
expectedMappings.put("/*.jspx", "org.apache.jasper.servlet.JspServlet");
}

@Test
Expand All @@ -43,7 +40,7 @@ public void testAPIEndpoint() throws Exception {

Set<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings();
Assert.assertNotNull(mappings);
Assert.assertEquals(5, mappings.size());
Assert.assertEquals(2, mappings.size());
for (ApplicationURLMapping mapping : mappings) {
assertMappings(mapping);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public class ServerTest {
@BeforeClass
public static void addMappings() {
actualMappings.put("/servlet/*", MyServlet.class.getName());
actualMappings.put("/", ServletHandler.Default404Servlet.class.getName());
}
@After
public void teardown() throws Exception {
Expand Down Expand Up @@ -201,7 +200,7 @@ public void testAPIEndpoint () throws Exception {
start();

Set<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings();
Assert.assertEquals(2, mappings.size());
Assert.assertEquals(1, mappings.size());
for (ApplicationURLMapping mapping : mappings) {
Assert.assertNotNull(mapping);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class ApiEndpointTest {

@Test
public void testURLMappings() {
String handler = DefaultServlet.class.getName();
String handler = MyServlet.class.getName();
String method = "*";
Iterator<ApplicationURLMapping> mappings = URLMappingsHelper.getApplicationURLMappings().iterator();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
Expand Down Expand Up @@ -64,7 +67,7 @@ public void onStartup(Set<Class<?>> c, ServletContext ctx){
}
}, Collections.emptySet());

Tomcat.addServlet(context, "servlet", new DefaultServlet());
Tomcat.addServlet(context, "servlet", new MyServlet());
context.addServletMappingDecoded("/*","servlet");
context.addServletMappingDecoded("/test","servlet");

Expand All @@ -88,4 +91,10 @@ private void stop() {
}
}
}
}
class MyServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.doGet(req, resp);
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package com.nr.agent.security.instrumentation.memcached.test;

import com.github.mwarc.embeddedmemcached.JMemcachedServer;
import com.newrelic.agent.security.instrumentation.spy.memcached.MemcachedHelper;
import com.newrelic.agent.security.introspec.InstrumentationTestConfig;
import com.newrelic.agent.security.introspec.SecurityInstrumentationTestRunner;
import com.newrelic.agent.security.introspec.SecurityIntrospector;
import com.newrelic.api.agent.security.schema.AbstractOperation;
import com.newrelic.api.agent.security.schema.VulnerabilityCaseType;
import com.newrelic.api.agent.security.schema.operation.MemcachedOperation;
import net.spy.memcached.MemcachedClient;
import net.spy.memcached.ops.StoreType;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
Expand Down Expand Up @@ -68,7 +70,7 @@ public void testSet() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore", "set");
}

@Test
Expand All @@ -80,7 +82,7 @@ public void testAdd() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore", "add");
}
@Test
public void testReplace() {
Expand All @@ -91,7 +93,7 @@ public void testReplace() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncStore", "replace");
}
@Test
public void testAppend() {
Expand All @@ -102,7 +104,7 @@ public void testAppend() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat");
verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat", "append");
}
@Test
public void testPrepend() {
Expand All @@ -113,7 +115,7 @@ public void testPrepend() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat");
verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat", "prepend");
}
@Test
public void testPrepend1() {
Expand All @@ -124,7 +126,7 @@ public void testPrepend1() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat");
verifier(operation, Arrays.asList(key, value), UPDATE, "asyncCat", "prepend");
}
@Test
public void testCas() {
Expand All @@ -135,7 +137,7 @@ public void testCas() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}
@Test
public void testCas1() {
Expand All @@ -146,7 +148,7 @@ public void testCas1() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}

@Test
Expand All @@ -158,7 +160,7 @@ public void testAsyncCAS() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}
@Test
public void testAsyncCAS1() {
Expand All @@ -169,14 +171,15 @@ public void testAsyncCAS1() {
Assert.assertEquals("No operations detected.", 1, operations.size());
MemcachedOperation operation = (MemcachedOperation) operations.get(0);

verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS");
verifier(operation, Arrays.asList(key, value), WRITE, "asyncCAS", "set");
}

private void verifier(MemcachedOperation operation, List<?> args, String type, String method) {
private void verifier(MemcachedOperation operation, List<?> args, String type, String method, String command) {
Assert.assertEquals("Incorrect executed parameters.", args, operation.getArguments());
Assert.assertEquals("Incorrect event case type.", VulnerabilityCaseType.CACHING_DATA_STORE, operation.getCaseType());
Assert.assertEquals("Incorrect event category.", MemcachedOperation.MEMCACHED, operation.getCategory());
Assert.assertEquals("Incorrect event category.", type, operation.getType());
Assert.assertEquals("Incorrect event category.", type, operation.getCommand());
Assert.assertEquals("Incorrect event category.", command, operation.getType());
Assert.assertEquals("Incorrect executed class-name.", memcachedClient.getClass().getName(), operation.getClassName());
Assert.assertEquals("Incorrect executed method-name.", method, operation.getMethodName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,23 @@ private UserClassEntity setUserClassEntity(AbstractOperation operation, Security

for (int i = 0; i < operation.getStackTrace().length; i++) {
StackTraceElement stackTraceElement = operation.getStackTrace()[i];

// Section for user class identification using API handlers
if( !securityMetaData.getMetaData().isFoundAnnotedUserLevelServiceMethod() && URLMappingsHelper.getHandlersHash().contains(stackTraceElement.getClassName().hashCode())){
//Found -> assign user class and return
userClassEntity.setUserClassElement(stackTraceElement);
securityMetaData.getMetaData().setUserLevelServiceMethodEncountered(true);
userClassEntity.setCalledByUserCode(true);
return userClassEntity;
}

//Fallback to old mechanism
if(userStackTraceElement != null){
if(StringUtils.equals(stackTraceElement.getClassName(), userStackTraceElement.getClassName())
&& StringUtils.equals(stackTraceElement.getMethodName(), userStackTraceElement.getMethodName())){
userClassEntity.setUserClassElement(stackTraceElement);
userClassEntity.setCalledByUserCode(securityMetaData.getMetaData().isUserLevelServiceMethodEncountered());
return userClassEntity;
userStackTraceElement = stackTraceElement;
}
}
// TODO: the `if` should be `else if` please check crypto case BenchmarkTest01978. service trace is being registered from doSomething()
Expand Down Expand Up @@ -711,4 +722,5 @@ public String decryptAndVerify(String encryptedData, String hashVerifier) {
return null;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,18 @@ public static Set<ApplicationURLMapping> getApplicationURLMappings() {
return mappings;
}

private static Set<Integer> handlers = ConcurrentHashMap.newKeySet();

public static Set<Integer> getHandlersHash() {
return handlers;
}

public static void addApplicationURLMapping(ApplicationURLMapping mapping) {
if (mapping.getHandler() == null || (mapping.getHandler() != null && !defaultHandlers.contains(mapping.getHandler()))) {
mappings.add(mapping);
}
if (mapping.getHandler() != null){
handlers.add(mapping.getHandler().hashCode());
}
}
}
Loading