Skip to content

Commit

Permalink
Merge pull request #2 from NASA-AMMOS/1-memory-leak-and-sqs-throttling
Browse files Browse the repository at this point in the history
initial cut at thread throttling / memory management
  • Loading branch information
ztaylor54 authored Dec 16, 2020
2 parents 4ecbda5 + b827313 commit 1fd6f92
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,32 @@
/**
* @author ghollins, jwood, ztaylor
*/
public class S3DataManager {
private static final Logger log = LoggerFactory.getLogger(S3DataManager.class);
public class S3DataManager implements AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(S3DataManager.class);

public static final int OLD_SECONDS_BACK_THRESHOLD = 120;
public static final int OLD_SECONDS_BACK_THRESHOLD = 120;

private Region regionUsed;
private Region regionUsed;

private S3Client s3;
private S3Client s3;

/**
 * Creates an S3 client for the given AWS region.
 *
 * @param region AWS region name, e.g. "us-west-2"
 */
public S3DataManager(String region) {
    init(region); // sets defaults
}

/**
 * (Re)creates the underlying S3 client for the given region and records
 * the region used.
 *
 * NOTE(review): init() is public, non-final, and called from the
 * constructor — a subclass override would run before the subclass is
 * initialized. Consider making it private or final if no external
 * callers depend on it — TODO confirm.
 *
 * @param region AWS region name, e.g. "us-west-2"
 */
public void init(String region) {
    // Resolve the region once and reuse it for both the client and the record.
    regionUsed = Region.of(region);
    s3 = S3Client.builder().region(regionUsed).build();
}

/**
 * Releases the underlying S3 client, if one was created.
 *
 * Declared without a checked exception: AWS SDK v2's
 * SdkAutoCloseable.close() throws none, and narrowing AutoCloseable's
 * {@code throws Exception} in an override is source- and
 * binary-compatible — try-with-resources callers no longer have to
 * catch {@code Exception}.
 */
@Override
public void close() {
    if (s3 != null) {
        s3.close();
    }
}

public S3Client getClient() {
public S3Client getClient() {
return s3;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,49 +403,48 @@ private void scheduleIfNotAlready(String initiationTrigger, Map<String,String> p
* The ETag reflects changes only to the contents of an object,
* not its metadata
*/
protected synchronized boolean skipScheduling(Map<String,String> partners) {
S3DataManager s3 = new S3DataManager(aws_default_region);

// Get sorted list of s3ObjKey/etags for all partners
List<String> eTagList = new ArrayList<>();
for (String s3ObjKey : partners.values()) {
HeadObjectResponse metaData = s3.getObjectMetadata(s3BucketName, s3ObjKey);
if (metaData != null) {
eTagList.add(s3ObjKey + metaData.eTag());
}
else {
log.warn("Skipping scheduling process for inputs: " + partners +
", since they don't all exist for this initiator (" + initiatorId + ")");
return true; // not all objects exist, so skip scheduling
protected synchronized boolean skipScheduling(Map<String,String> partners) throws Exception {

try ( S3DataManager s3 = new S3DataManager(aws_default_region) ) {
// Get sorted list of s3ObjKey/etags for all partners
List<String> eTagList = new ArrayList<>();
for (String s3ObjKey : partners.values()) {
HeadObjectResponse metaData = s3.getObjectMetadata(s3BucketName, s3ObjKey);
if (metaData != null) {
eTagList.add(s3ObjKey + metaData.eTag());
} else {
log.warn("Skipping scheduling process for inputs: " + partners +
", since they don't all exist for this initiator (" + initiatorId + ")");
return true; // not all objects exist, so skip scheduling
}
}
}

// hashCode is different, depending on order, so sort
Collections.sort(eTagList);
// hashCode is different, depending on order, so sort
Collections.sort(eTagList);

// get hashcode
int hashCode = ( initiatorId + eTagList.toString() ).hashCode();
if (recentlyProcessedInputs.get(hashCode) != null) {
log.info("Skipping scheduling process for inputs: " + partners +
", since they have been recently scheduled (" + hashCode + ") for this initiator (" + initiatorId + ") " +
"within the last " + DUPLICATE_PREVENTION_PERIOD + " seconds.");
return true; // already processed this set of inputs
}
else {
recentlyProcessedInputs.put(hashCode, hashCode);
// also add in, for each partner, a hashcode into another (new) TTL map
// Then check this map in other parts of code .
// If none are found in mem in other part of code, then schedule immediately.
// this avoids the false positive of "old".
// Also, cleanup models , like XYZ to produce new RDR versions...
log.debug("added hash code: " + hashCode + ", to recentlyProcessedInputs. " +
recentlyProcessedInputs.size() + " (initiatorId = " + initiatorId + ")");

for (String partner : partners.values()) {
hashCode = (initiatorId + partner).hashCode();
individualProcessedInputs.put(hashCode, hashCode);
// get hashcode
int hashCode = (initiatorId + eTagList.toString()).hashCode();
if (recentlyProcessedInputs.get(hashCode) != null) {
log.info("Skipping scheduling process for inputs: " + partners +
", since they have been recently scheduled (" + hashCode + ") for this initiator (" + initiatorId + ") " +
"within the last " + DUPLICATE_PREVENTION_PERIOD + " seconds.");
return true; // already processed this set of inputs
} else {
recentlyProcessedInputs.put(hashCode, hashCode);
// also add in, for each partner, a hashcode into another (new) TTL map
// Then check this map in other parts of code .
// If none are found in mem in other part of code, then schedule immediately.
// this avoids the false positive of "old".
// Also, cleanup models , like XYZ to produce new RDR versions...
log.debug("added hash code: " + hashCode + ", to recentlyProcessedInputs. " +
recentlyProcessedInputs.size() + " (initiatorId = " + initiatorId + ")");

for (String partner : partners.values()) {
hashCode = (initiatorId + partner).hashCode();
individualProcessedInputs.put(hashCode, hashCode);
}
return false;
}
return false;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest;
Expand All @@ -21,6 +22,7 @@
import java.util.Map.Entry;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/**
* This thread subscribes to an AWS SQS URL, and when a new message arrives,
Expand Down Expand Up @@ -56,7 +58,7 @@ public class SQSDispatcherThread extends Thread implements InitializingBean {

private SqsClient sqs;
private long lastClientRefreshTime;
private static final int TOKEN_REFRESH_FREQUENCY = 60 * 30 * 1000; // 30 minutes in milliseconds
private static final int TOKEN_REFRESH_FREQUENCY = 60 * 10 * 1000; // 10 minutes in milliseconds
private static final Integer SQS_CLIENT_WAIT_TIME_SECONDS = 20;
private Gson gson;

Expand All @@ -67,9 +69,20 @@ public class SQSDispatcherThread extends Thread implements InitializingBean {
private Map<String,HashSet<String>> dispatcherMap;
static final Object dispatcherMapLock = new Object();

// Maximum number of simultaneous threads that may be dispatched before throttling occurs.
// If this value is met, then the SQS request rate will be throttled by the given amount until the
// system can catch up.
//
@Value("${aws.sqs.dispatcher.maxThreads}") private Integer maxThreads;

// number of threads in messageHandlerThreadExecutor running at a given moment
private AtomicInteger numberThreads = new AtomicInteger(0);

private ExecutorService messageDeleterThreadExecutor = Executors.newFixedThreadPool(10);
private ExecutorService messageHandlerThreadExecutor = Executors.newFixedThreadPool(20);

private long avgMsgHandleTimeMillis = 100;

// No-argument constructor; only logs construction. The class implements
// InitializingBean, so container-driven setup presumably happens after
// construction — TODO confirm against the Spring wiring.
public SQSDispatcherThread() {
    log.debug("SQSDispatcherThread ctor...........................................");
}
Expand All @@ -95,6 +108,9 @@ public void run() {
log.debug("SQSDispatcherThread STARTING...");
gson = new Gson();

// See: https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/java-dg-jvm-ttl.html
java.security.Security.setProperty("networkaddress.cache.ttl" , "60");

refreshAwsClient(true);

ReceiveMessageRequest receiveMessageRequest = ReceiveMessageRequest.builder()
Expand All @@ -121,14 +137,23 @@ public void run() {
}

try {
// Will throttle looping if max number of threads has been exceeded
if (numberThreads.get() > maxThreads) {
long actualThrottleMillis = (long)(1.1 * avgMsgHandleTimeMillis) * (numberThreads.get() - maxThreads);
log.warn("Throttling by {} ms ({}/{}) avgMsgHandleTime={}", actualThrottleMillis, numberThreads, maxThreads, avgMsgHandleTimeMillis);
Thread.sleep(actualThrottleMillis);
avgMsgHandleTimeMillis += 10; // backoff
continue;
}

log.trace("about to receive message...");
long t0 = System.currentTimeMillis();
//
// FIXME: This creates a new thread that doesn't get cleaned up!!
//
refreshAwsClient(false);
List<Message> messages = sqs.receiveMessage(receiveMessageRequest).messages();
long t1 = System.currentTimeMillis();
log.debug("bufferedSqs.receiveMessage (in " + (t1 - t0) + "ms) [" + messages.size() + " messages]");
log.debug("bufferedSqs.receiveMessage (in " + (t1 - t0) + "ms) [" +
messages.size() + " messages, " +
numberThreads.get() + " handlerThreads]");

if (messages.isEmpty()) {
log.trace("GOT " + messages.size() + " MESSAGE(S)");
Expand All @@ -140,6 +165,7 @@ public void run() {
// For each received message
//
for (Message msg : messages) {
numberThreads.incrementAndGet();
handleMessageOnSeparateThread(msg);
}
}
Expand Down Expand Up @@ -219,6 +245,7 @@ public void run() {
} catch (Exception e) {
log.error("Unable to parse message as JSON. Deleting this message from queue, and moving on to next message...", e);
deleteMessageFromQueueOnSeparateThread(msg);
numberThreads.decrementAndGet();
return;
}

Expand Down Expand Up @@ -262,16 +289,23 @@ public void run() {
}
} catch (Exception e) {
log.error("error while processing message", e);
numberThreads.decrementAndGet();
return;
}
finally {
deleteMessageFromQueueOnSeparateThread(msg);
}


if ((System.currentTimeMillis() - d0) > 100) {
log.debug("Handled message (in " + (System.currentTimeMillis() - d0) + " ms)");
int curThreads = numberThreads.decrementAndGet();
long handleDuration = (System.currentTimeMillis() - d0);
if (handleDuration > 100) {
log.debug("Handled message (in " + (System.currentTimeMillis() - d0) + " ms) " +
curThreads + " threads now active)");
}

// keep track of avg message handling duration...
if (avgMsgHandleTimeMillis > handleDuration) { avgMsgHandleTimeMillis--; } else { avgMsgHandleTimeMillis++; }
if (avgMsgHandleTimeMillis < 30) { avgMsgHandleTimeMillis = 30; } // floor
}
});

Expand Down Expand Up @@ -320,20 +354,16 @@ private void refreshAwsClient(boolean forceRefresh) {
lastClientRefreshTime == 0 ||
((System.currentTimeMillis() - lastClientRefreshTime) > TOKEN_REFRESH_FREQUENCY)) {

log.debug("About to refresh AWS SQS client...");
log.debug("About to refresh AWS SQS client...");

if (sqs != null) {
sqs.close();
}

sqs = SqsClient.builder()
.region(Region.of(aws_default_region))
.build();

// See: https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/java-dg-jvm-ttl.html
log.debug("networkaddress.cache.ttl = " + java.security.Security.getProperty("networkaddress.cache.ttl"));
java.security.Security.setProperty("networkaddress.cache.ttl" , "60");
log.debug("networkaddress.cache.ttl = " + java.security.Security.getProperty("networkaddress.cache.ttl"));

// Create the buffered SQS client
//bufferedSqs = new AmazonSQSBufferedAsyncClient(sqsAsync);

lastClientRefreshTime = System.currentTimeMillis(); // update timestamp
log.debug("AWS credentials / client refreshed.");
}
Expand Down
3 changes: 2 additions & 1 deletion install/cws-engine/cws-engine.properties
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ cws.db.username=__CWS_DB_USERNAME__
cws.db.password=__CWS_DB_PASSWORD__
aws.default.region=__AWS_DEFAULT_REGION__
aws.sqs.dispatcher.sqsUrl=__AWS_SQS_DISPATCHER_SQS_URL__
aws.sqs.dispatcher.msgFetchLimit=__AWS_SQS_DISPATCHER_MSG_FETCH_LIMIT__
aws.sqs.dispatcher.msgFetchLimit=__AWS_SQS_DISPATCHER_MSG_FETCH_LIMIT__
aws.sqs.dispatcher.maxThreads=100
1 change: 1 addition & 0 deletions install/cws-ui/cws-ui.properties
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ cws.metrics.publishing.interval=__CWS_METRICS_PUBLISHING_INTERVAL__
aws.default.region=__AWS_DEFAULT_REGION__
aws.sqs.dispatcher.sqsUrl=__AWS_SQS_DISPATCHER_SQS_URL__
aws.sqs.dispatcher.msgFetchLimit=__AWS_SQS_DISPATCHER_MSG_FETCH_LIMIT__
aws.sqs.dispatcher.maxThreads=100

0 comments on commit 1fd6f92

Please sign in to comment.