Skip to content

Commit

Permalink
Merge pull request #321 from newrelic/feature/iast-scan-rate-limit
Browse files Browse the repository at this point in the history
NR-304574: Rate limit the IAST replay requests
  • Loading branch information
IshikaDawda authored Sep 12, 2024
2 parents 120ff86 + 99f3406 commit bf9b27c
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 11 deletions.
4 changes: 2 additions & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# The agent version.
agentVersion=1.4.1
jsonVersion=1.2.6
agentVersion=1.4.2
jsonVersion=1.2.7
# Updated exposed NR APM API version.
nrAPIVersion=8.12.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public class AgentConfig {

private Map<String, String> noticeErrorCustomParams = new HashMap<>();

private String iastTestIdentifier;

private AgentConfig(){
}

Expand All @@ -88,6 +90,8 @@ public long instantiate() throws RestrictionModeException {
// Set required LogLevel
logLevel = applyRequiredLogLevel();

iastTestIdentifier = NewRelic.getAgent().getConfig().getValue(IUtilConstants.IAST_TEST_IDENTIFIER);

instantiateAgentMode(groupName);

return triggerIAST();
Expand Down Expand Up @@ -414,6 +418,10 @@ public String getSecurityHome() {
return NR_CSEC_HOME;
}

public String getIastTestIdentifier() {
return iastTestIdentifier;
}

public AgentMode getAgentMode() {
return agentMode;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.newrelic.agent.security.AgentInfo;
import com.newrelic.agent.security.instrumentator.utils.INRSettingsKey;
import com.newrelic.agent.security.intcodeagent.filelogging.FileLoggerThreadPool;
import com.newrelic.agent.security.util.IUtilConstants;
import com.newrelic.api.agent.security.utils.logging.LogLevel;
import com.newrelic.agent.security.intcodeagent.models.IASTDataTransferRequest;
import com.newrelic.agent.security.intcodeagent.websocket.JsonConverter;
Expand All @@ -28,8 +29,6 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import static com.newrelic.agent.security.instrumentator.utils.INRSettingsKey.SECURITY_POLICY_VULNERABILITY_SCAN_IAST_SCAN_PROBING_THRESHOLD;

public class IASTDataTransferRequestProcessor {
private static final FileLoggerThreadPool logger = FileLoggerThreadPool.getInstance();
public static final String UNABLE_TO_SEND_IAST_DATA_REQUEST_DUE_TO_ERROR_S_S = "Unable to send IAST data request due to error: %s : %s";
Expand All @@ -47,6 +46,10 @@ public class IASTDataTransferRequestProcessor {

private final AtomicLong lastFuzzCCTimestamp = new AtomicLong();

private int currentFetchThresholdPerMin = 3600;

private long controlCommandRequestedAtEpochMilli = 0;

private void task() {
IASTDataTransferRequest request = null;
try {
Expand All @@ -65,6 +68,11 @@ private void task() {
}
}
long currentTimestamp = Instant.now().toEpochMilli();
if(controlCommandRequestedAtEpochMilli <= 0){
AgentInfo.getInstance().getJaHealthCheck().setControlCommandRequestedTime(currentTimestamp);
controlCommandRequestedAtEpochMilli = currentTimestamp;
AgentInfo.getInstance().getJaHealthCheck().setScanActive(true);
}
// Sleep if under cooldown
long cooldownSleepTime = cooldownTillTimestamp.get() - currentTimestamp;
if(cooldownSleepTime > 0) {
Expand All @@ -75,8 +83,12 @@ private void task() {
return;
}

int currentFetchThreshold = NewRelic.getAgent().getConfig()
.getValue(SECURITY_POLICY_VULNERABILITY_SCAN_IAST_SCAN_PROBING_THRESHOLD, 300);
int currentFetchThreshold = Math.round((float) currentFetchThresholdPerMin/12);
if (currentFetchThreshold <= 0){
return;
}

int fetchRatio = 300/currentFetchThreshold;

int remainingRecordCapacityRest = RestRequestThreadPool.getInstance().getQueue().remainingCapacity();
int currentRecordBacklogRest = RestRequestThreadPool.getInstance().getQueue().size();
Expand All @@ -91,7 +103,7 @@ private void task() {
batchSize /= 2;
}

if (batchSize > 100 && remainingRecordCapacity > batchSize) {
if (batchSize > 100/fetchRatio && remainingRecordCapacity > batchSize) {
request = new IASTDataTransferRequest(NewRelicSecurity.getAgent().getAgentUUID());
if (AgentConfig.getInstance().getConfig().getCustomerInfo() != null) {
request.setAppAccountId(AgentConfig.getInstance().getConfig().getCustomerInfo().getAccountId());
Expand Down Expand Up @@ -163,6 +175,12 @@ public void startDataRequestSchedule(long delay, TimeUnit timeUnit){
if(initialDelay < 0){
initialDelay = 0;
}
// IAST Scan Rate per minute with range [12, 3600]; default 3600 replay requests will be replayed per minute
try {
currentFetchThresholdPerMin = Math.min(Math.max(NewRelic.getAgent().getConfig().getValue(IUtilConstants.SCAN_REQUEST_RATE_LIMIT, 3600), 12), 3600);
} catch (Exception e) {
logger.log(LogLevel.WARNING, String.format("Error while reading Configuration security.scan_request_rate_limit : %s, Using default value %s replay request per min.", e.getMessage(), currentFetchThresholdPerMin), e, this.getClass().getName());
}
logger.log(LogLevel.INFO, String.format("IAST data pull request is scheduled at %s, after delay of %s seconds", AgentConfig.getInstance().getAgentMode().getScanSchedule().getDataCollectionTime(), initialDelay), IASTDataTransferRequestProcessor.class.getName());
future = executorService.scheduleWithFixedDelay(this::task, initialDelay, delay, timeUnit);
} catch (Throwable ignored){}
Expand All @@ -185,4 +203,8 @@ public void setCooldownTillTimestamp(long timestamp) {
public void setLastFuzzCCTimestamp(long timestamp) {
lastFuzzCCTimestamp.set(timestamp);
}

public long getControlCommandRequestedAtEpochMilli() {
return controlCommandRequestedAtEpochMilli;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ public void run() {
iastReplayRequestMsgReceiveTime = Instant.now();
IASTDataTransferRequestProcessor.getInstance().setLastFuzzCCTimestamp(Instant.now().toEpochMilli());
RestRequestProcessor.processControlCommand(controlCommand);
if(ControlCommandProcessorThreadPool.getInstance().getScanStartTime() <= 0) {
ControlCommandProcessorThreadPool.getInstance().setScanStartTime(Instant.now().toEpochMilli());
AgentInfo.getInstance().getJaHealthCheck().setScanStartTime(ControlCommandProcessorThreadPool.getInstance().getScanStartTime());
}
break;

case IntCodeControlCommand.STARTUP_WELCOME_MSG:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public class ControlCommandProcessorThreadPool {
private final boolean allowCoreThreadTimeOut = false;
private static Object mutex = new Object();

private long scanStartTime = 0;

public ThreadPoolExecutor getExecutor() {
return executor;
}
Expand Down Expand Up @@ -157,4 +159,11 @@ public void shutDownThreadPoolExecutor() {
}
}

public long getScanStartTime() {
return scanStartTime;
}

public void setScanStartTime(long scanStartTime) {
this.scanStartTime = scanStartTime;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package com.newrelic.agent.security.intcodeagent.models.javaagent;

import com.newrelic.agent.security.AgentConfig;
import com.newrelic.agent.security.AgentInfo;
import com.newrelic.agent.security.intcodeagent.filelogging.FileLoggerThreadPool;
import com.newrelic.api.agent.security.utils.logging.LogLevel;
import com.newrelic.agent.security.intcodeagent.websocket.JsonConverter;

import java.lang.management.ManagementFactory;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -14,9 +17,19 @@ public class JAHealthCheck extends AgentBasicInfo {
private static final FileLoggerThreadPool logger = FileLoggerThreadPool.getInstance();
private static final String HC_CREATED = "Created Health Check: %s";

// private String protectedServer;
private long procStartTime;

// private Set protectedDB;
private long controlCommandRequestedTime;

private long scanStartTime;

private long trafficStartedTime;

private final long csecActivationTime;

private final long iastDataRequestTime;

private Boolean scanActive = false;

private AtomicInteger invokedHookCount;

Expand Down Expand Up @@ -45,6 +58,17 @@ public JAHealthCheck(String applicationUUID) {
this.serviceStatus = new HashMap<>();
this.eventStats = new EventStats();
this.setKind(AgentInfo.getInstance().getApplicationInfo().getIdentifier().getKind());
this.procStartTime = ManagementFactory.getRuntimeMXBean().getStartTime();
if(AgentConfig.getInstance().getAgentMode().getScanSchedule().getNextScanTime() != null) {
this.csecActivationTime = AgentConfig.getInstance().getAgentMode().getScanSchedule().getNextScanTime().getTime();
} else {
this.csecActivationTime = Instant.now().toEpochMilli();
}
if(AgentConfig.getInstance().getAgentMode().getScanSchedule().getDataCollectionTime() != null) {
this.iastDataRequestTime = AgentConfig.getInstance().getAgentMode().getScanSchedule().getDataCollectionTime().getTime();
} else {
this.iastDataRequestTime = Instant.now().toEpochMilli();
}
logger.log(LogLevel.INFO, String.format(HC_CREATED, JsonConverter.toJSON(this)), JAHealthCheck.class.getName());
}

Expand All @@ -59,6 +83,13 @@ public JAHealthCheck(JAHealthCheck jaHealthCheck) {
this.schedulerRuns = new SchedulerRuns(jaHealthCheck.schedulerRuns);
this.invokedHookCount = new AtomicInteger(jaHealthCheck.invokedHookCount.get());
this.webSocketConnectionStats = new WebSocketConnectionStats(jaHealthCheck.webSocketConnectionStats);
this.procStartTime = jaHealthCheck.getProcStartTime();
this.controlCommandRequestedTime = jaHealthCheck.getControlCommandRequestedTime();
this.scanStartTime = jaHealthCheck.getScanStartTime();
this.trafficStartedTime = jaHealthCheck.getTrafficStartedTime();
this.csecActivationTime = jaHealthCheck.getCsecActivationTime();
this.iastDataRequestTime = jaHealthCheck.getIastDataRequestTime();
this.scanActive = jaHealthCheck.getScanActive();
logger.log(LogLevel.INFO, String.format(HC_CREATED, JsonConverter.toJSON(this)), JAHealthCheck.class.getName());
}

Expand Down Expand Up @@ -133,6 +164,54 @@ public void setSchedulerRuns(SchedulerRuns schedulerRuns) {
this.schedulerRuns = schedulerRuns;
}

public long getProcStartTime() {
return procStartTime;
}

public void setProcStartTime(long procStartTime) {
this.procStartTime = procStartTime;
}

public long getControlCommandRequestedTime() {
return controlCommandRequestedTime;
}

public void setControlCommandRequestedTime(long controlCommandRequestedTime) {
this.controlCommandRequestedTime = controlCommandRequestedTime;
}

public long getScanStartTime() {
return scanStartTime;
}

public void setScanStartTime(long scanStartTime) {
this.scanStartTime = scanStartTime;
}

public long getTrafficStartedTime() {
return trafficStartedTime;
}

public void setTrafficStartedTime(long trafficStartedTime) {
this.trafficStartedTime = trafficStartedTime;
}

public long getCsecActivationTime() {
return csecActivationTime;
}

public Boolean getScanActive() {
return scanActive;
}

public void setScanActive(Boolean scanActive) {
this.scanActive = scanActive;
}

public long getIastDataRequestTime() {
return iastDataRequestTime;
}

public void reset(){
this.setInvokedHookCount(0);
this.stats.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.ManagementFactory;
import java.net.*;
import java.nio.file.Files;
import java.nio.file.Paths;
Expand Down Expand Up @@ -153,6 +154,10 @@ private WSClient() throws URISyntaxException {
this.addHeader("NR-ACCOUNT-ID", AgentConfig.getInstance().getConfig().getCustomerInfo().getAccountId());
this.addHeader("NR-CSEC-IAST-DATA-TRANSFER-MODE", "PULL");
this.addHeader("NR-CSEC-IGNORED-VUL-CATEGORIES", AgentConfig.getInstance().getAgentMode().getSkipScan().getIastDetectionCategory().getDisabledCategoriesCSV());
this.addHeader("NR-CSEC-PROCESS-START-TIME", String.valueOf(ManagementFactory.getRuntimeMXBean().getStartTime()));
if (AgentConfig.getInstance().getIastTestIdentifier() != null) {
this.addHeader("NR-CSEC-IAST-TEST-IDENTIFIER", AgentConfig.getInstance().getIastTestIdentifier());
}
Proxy proxy = proxyManager();
if(proxy != null) {
this.setProxy(proxy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public interface IUtilConstants {
String SCAN_TIME_SCHEDULE = "security.scan_schedule.schedule";
String SCAN_TIME_DURATION = "security.scan_schedule.duration";
String SCAN_TIME_COLLECT_SAMPLES = "security.scan_schedule.always_sample_traces";
String SCAN_REQUEST_RATE_LIMIT = "security.scan_controllers.iast_scan_request_rate_limit";

String SKIP_IAST_SCAN = "security.exclude_from_iast_scan";
String SKIP_IAST_SCAN_API = SKIP_IAST_SCAN + ".api";
Expand Down Expand Up @@ -61,6 +62,7 @@ public interface IUtilConstants {
String NR_SECURITY_ENABLED = "security.enabled";

String NR_SECURITY_HOME_APP = "security.is_home_app";
String IAST_TEST_IDENTIFIER = "security.iast_test_identifier";

String NR_SECURITY_CA_BUNDLE_PATH = "ca_bundle_path";
String NR_CSEC_DEBUG_LOGFILE_SIZE = "NR_CSEC_DEBUG_LOGFILE_SIZE";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import com.newrelic.agent.security.intcodeagent.exceptions.RestrictionModeException;
import com.newrelic.agent.security.intcodeagent.filelogging.FileLoggerThreadPool;
import com.newrelic.agent.security.intcodeagent.filelogging.LogFileHelper;
import com.newrelic.agent.security.intcodeagent.models.collectorconfig.AgentMode;
import com.newrelic.agent.security.intcodeagent.models.javaagent.*;
import com.newrelic.agent.security.intcodeagent.utils.*;
import com.newrelic.api.agent.security.instrumentation.helpers.*;
Expand Down Expand Up @@ -62,6 +61,7 @@ public class Agent implements SecurityAgent {

public static final String DROPPING_EVENT_AS_IT_WAS_GENERATED_BY_K_2_INTERNAL_API_CALL = "Dropping event as it was generated by agent internal API call : ";
private static final AtomicBoolean firstEventProcessed = new AtomicBoolean(false);
private long trafficStartedAt = 0;
public static final String ERROR_WHILE_GENERATING_TRACE_ID_FOR_CATEGORY_S = "Error while generating trace id for category : %s";
public static final String SKIPPING_THE_API_S_AS_IT_IS_PART_OF_THE_SKIP_SCAN_LIST = "Skipping the API %s as it is part of the skip scan list";
public static final String INVALID_CRON_EXPRESSION_PROVIDED_FOR_IAST_RESTRICTED_MODE = "Invalid cron expression provided for IAST Mode";
Expand Down Expand Up @@ -261,6 +261,7 @@ private void startSecurityServices() {
);
WSReconnectionST.getInstance().submitNewTaskSchedule(0);
EventSendPool.getInstance();
ControlCommandProcessorThreadPool.getInstance();
logger.logInit(
LogLevel.INFO,
String.format(STARTED_MODULE_LOG, AgentServices.EventWritePool.name()),
Expand All @@ -279,7 +280,8 @@ private void startSecurityServices() {
} else {
IASTDataTransferRequestProcessor.getInstance().stopDataRequestSchedule(true);
}

AgentInfo.getInstance().getJaHealthCheck().setControlCommandRequestedTime(IASTDataTransferRequestProcessor.getInstance().getControlCommandRequestedAtEpochMilli());
AgentInfo.getInstance().getJaHealthCheck().setScanStartTime(ControlCommandProcessorThreadPool.getInstance().getScanStartTime());
}

@Override
Expand Down Expand Up @@ -344,6 +346,7 @@ private void deactivateSecurityServices(){
*/
// InstrumentationUtils.shutdownLogic();
IASTDataTransferRequestProcessor.getInstance().stopDataRequestSchedule(true);
info.getJaHealthCheck().setScanActive(false);
if(!config.getAgentMode().getScanSchedule().isCollectSamples()) {
AgentInfo.getInstance().setAgentActive(false);
HealthCheckScheduleThread.getInstance().cancelTask(true);
Expand Down Expand Up @@ -468,6 +471,8 @@ public void registerOperation(AbstractOperation operation) {
String.format(EVENT_ZERO_PROCESSED, securityMetaData.getRequest()),
this.getClass().getName());
firstEventProcessed.set(true);
trafficStartedAt = Instant.now().toEpochMilli();
AgentInfo.getInstance().getJaHealthCheck().setTrafficStartedTime(trafficStartedAt);
}
}
}
Expand Down

0 comments on commit bf9b27c

Please sign in to comment.