Merge branch 'main' into fix-log-message-format-bugs
joegallo committed Dec 10, 2024
2 parents 6a21cc9 + 85f37ac commit 5684bec
Showing 7 changed files with 74 additions and 14 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/118177.yaml
@@ -0,0 +1,6 @@
pr: 118177
summary: Fixing bedrock event executor terminated cache issue
area: Machine Learning
type: bug
issues:
- 117916
JdkVectorLibrary.java
@@ -9,6 +9,8 @@

package org.elasticsearch.nativeaccess.jdk;

import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.elasticsearch.nativeaccess.lib.LoaderHelper;
import org.elasticsearch.nativeaccess.lib.VectorLibrary;
@@ -25,6 +27,8 @@

public final class JdkVectorLibrary implements VectorLibrary {

static final Logger logger = LogManager.getLogger(JdkVectorLibrary.class);

static final MethodHandle dot7u$mh;
static final MethodHandle sqr7u$mh;

@@ -36,6 +40,7 @@ public final class JdkVectorLibrary implements VectorLibrary {

try {
int caps = (int) vecCaps$mh.invokeExact();
logger.info("vec_caps=" + caps);
if (caps != 0) {
if (caps == 2) {
dot7u$mh = downcallHandle(
DocumentMapper.java
@@ -30,6 +30,7 @@ public class DocumentMapper {
private final MapperMetrics mapperMetrics;
private final IndexVersion indexVersion;
private final Logger logger;
private final String indexName;

/**
* Create a new {@link DocumentMapper} that holds empty mappings.
@@ -67,16 +68,17 @@ public static DocumentMapper createEmpty(MapperService mapperService) {
this.mapperMetrics = mapperMetrics;
this.indexVersion = version;
this.logger = Loggers.getLogger(getClass(), indexName);
this.indexName = indexName;

assert mapping.toCompressedXContent().equals(source) || isSyntheticSourceMalformed(source, version)
: "provided source [" + source + "] differs from mapping [" + mapping.toCompressedXContent() + "]";
}

private void maybeLog(Exception ex) {
if (logger.isDebugEnabled()) {
logger.debug("Error while parsing document: " + ex.getMessage(), ex);
logger.debug("Error while parsing document for index [" + indexName + "]: " + ex.getMessage(), ex);
} else if (IntervalThrottler.DOCUMENT_PARSING_FAILURE.accept()) {
logger.info("Error while parsing document: " + ex.getMessage(), ex);
logger.info("Error while parsing document for index [" + indexName + "]: " + ex.getMessage(), ex);
}
}

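Note on the DocumentMapper change above: maybeLog() logs parse failures at DEBUG whenever debug logging is enabled, and otherwise falls back to a rate-limited INFO message gated by IntervalThrottler.DOCUMENT_PARSING_FAILURE.accept(); this commit adds the index name to both messages. The snippet below is a minimal, self-contained sketch of that throttling pattern, not Elasticsearch's IntervalThrottler — the class name, millisecond clock, and interval handling are assumptions for illustration only.

    import java.util.concurrent.atomic.AtomicLong;

    // Sketch of an interval-based throttler: accept() returns true at most once per interval.
    final class SimpleIntervalThrottler {
        private final long intervalMillis;
        private final AtomicLong nextAllowedMillis = new AtomicLong(0);

        SimpleIntervalThrottler(long intervalMillis) {
            this.intervalMillis = intervalMillis;
        }

        boolean accept() {
            long now = System.currentTimeMillis();
            long next = nextAllowedMillis.get();
            // Only one caller wins the CAS per interval; everyone else skips logging.
            return now >= next && nextAllowedMillis.compareAndSet(next, now + intervalMillis);
        }
    }

A caller would then mirror maybeLog(): if (logger.isDebugEnabled()) { logger.debug(...); } else if (throttler.accept()) { logger.info(...); }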
AmazonBedrockBaseClient.java
@@ -12,11 +12,13 @@
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;

import java.time.Clock;
import java.time.Instant;
import java.util.Objects;

public abstract class AmazonBedrockBaseClient implements AmazonBedrockClient {
protected final Integer modelKeysAndRegionHashcode;
protected Clock clock = Clock.systemUTC();
protected volatile Instant expiryTimestamp;

protected AmazonBedrockBaseClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
Objects.requireNonNull(model);
@@ -33,5 +35,10 @@ public final void setClock(Clock clock) {
this.clock = clock;
}

// used for testing
Instant getExpiryTimestamp() {
return this.expiryTimestamp;
}

abstract void close();
}
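The hunk above moves the expiry bookkeeping onto the base client: expiryTimestamp is now a volatile field of AmazonBedrockBaseClient and is exposed to tests via getExpiryTimestamp(). resetExpiration() itself is not shown in this diff; the sketch below is only a plausible reading of how such a method would use the injected clock (which is why the cache calls setClock(clock) before resetExpiration() in the hunk further down). The class name and the 15-minute window are assumptions.

    import java.time.Clock;
    import java.time.Duration;
    import java.time.Instant;

    // Hypothetical sketch, not the actual AmazonBedrockBaseClient implementation.
    abstract class ExpiringClientSketch {
        // Assumed expiry window; the real value is defined elsewhere in the cache code.
        private static final Duration CLIENT_CACHE_EXPIRY = Duration.ofMinutes(15);

        protected Clock clock = Clock.systemUTC();
        protected volatile Instant expiryTimestamp;

        // Tests can inject a fixed clock before the expiration is (re)computed.
        final void setClock(Clock clock) {
            this.clock = clock;
        }

        // Pushes the expiry forward from "now" as seen by the injected clock.
        void resetExpiration() {
            this.expiryTimestamp = clock.instant().plus(CLIENT_CACHE_EXPIRY);
        }

        Instant getExpiryTimestamp() {
            return this.expiryTimestamp;
        }
    }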
AmazonBedrockInferenceClient.java
@@ -61,7 +61,6 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient {

private final BedrockRuntimeAsyncClient internalClient;
private final ThreadPool threadPool;
private volatile Instant expiryTimestamp;

public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) {
try {
AmazonBedrockInferenceClientCache.java
@@ -35,20 +35,24 @@ public AmazonBedrockInferenceClientCache(BiFunction<AmazonBedrockModel, TimeValu
}

public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
var returnClient = internalGetOrCreateClient(model, timeout);
flushExpiredClients();
return returnClient;
return internalGetOrCreateClient(model, timeout);
}

private AmazonBedrockBaseClient internalGetOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
final Integer modelHash = AmazonBedrockInferenceClient.getModelKeysAndRegionHashcode(model, timeout);
cacheLock.readLock().lock();
try {
return clientsCache.computeIfAbsent(modelHash, hashKey -> {
final AmazonBedrockBaseClient builtClient = creator.apply(model, timeout);
builtClient.setClock(clock);
builtClient.resetExpiration();
return builtClient;
return clientsCache.compute(modelHash, (hashKey, client) -> {
AmazonBedrockBaseClient clientToUse = client;
if (clientToUse == null) {
clientToUse = creator.apply(model, timeout);
}

// for testing - would be nice to refactor client factory in the future to take clock as parameter
clientToUse.setClock(clock);
clientToUse.resetExpiration();
return clientToUse;
});
} finally {
cacheLock.readLock().unlock();
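The core of the cache fix above is the switch from computeIfAbsent to compute in internalGetOrCreateClient: computeIfAbsent only runs its mapping function when the key is missing, so an existing client never had setClock()/resetExpiration() applied again and could be evicted while still in use; compute runs its remapping function for existing keys as well, so every lookup refreshes the client's expiration. The standalone snippet below illustrates that difference using plain java.util Map APIs; it demonstrates the Map semantics only and is not the Elasticsearch cache code.

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.atomic.AtomicInteger;

    public class ComputeVsComputeIfAbsent {
        public static void main(String[] args) {
            Map<String, AtomicInteger> cache = new ConcurrentHashMap<>();

            // First access creates the entry; the counter stands in for "expiration resets".
            cache.computeIfAbsent("client", key -> new AtomicInteger(0)).incrementAndGet();

            // Second access via computeIfAbsent: the mapping function does NOT run again,
            // so nothing about the existing entry is refreshed.
            cache.computeIfAbsent("client", key -> new AtomicInteger(0));

            // compute runs for existing keys too, so per-access refresh logic can live here.
            cache.compute("client", (key, existing) -> {
                AtomicInteger value = (existing == null) ? new AtomicInteger(0) : existing;
                value.incrementAndGet(); // analogous to resetExpiration() on every lookup
                return value;
            });

            System.out.println(cache.get("client").get()); // prints 2
        }
    }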
AmazonBedrockInferenceClientCacheTests.java
@@ -60,6 +60,36 @@ public void testCache_ReturnsSameObject() throws IOException {
assertThat(cacheInstance.clientCount(), is(0));
}

public void testCache_ItUpdatesExpirationForExistingClients() throws IOException {
var clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
AmazonBedrockInferenceClientCache cacheInstance;
try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, clock)) {
cacheInstance = cache;

var model = AmazonBedrockEmbeddingsModelTests.createModel(
"inferenceId",
"testregion",
"model",
AmazonBedrockProvider.AMAZONTITAN,
"access_key",
"secret_key"
);

var client = cache.getOrCreateClient(model, null);
var expiryTimestamp = client.getExpiryTimestamp();
assertThat(cache.clientCount(), is(1));

// set clock to clock + 1 minute so the cache hasn't expired
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(1)), ZoneId.systemDefault()));

var regetClient = cache.getOrCreateClient(model, null);

assertThat(client, sameInstance(regetClient));
assertNotEquals(expiryTimestamp, regetClient.getExpiryTimestamp());
}
assertThat(cacheInstance.clientCount(), is(0));
}

public void testCache_ItEvictsExpiredClients() throws IOException {
var clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
AmazonBedrockInferenceClientCache cacheInstance;
@@ -76,6 +106,10 @@ public void testCache_ItEvictsExpiredClients() throws IOException {
);

var client = cache.getOrCreateClient(model, null);
assertThat(cache.clientCount(), is(1));

// set clock to clock + 2 minutes
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(2)), ZoneId.systemDefault()));

var secondModel = AmazonBedrockEmbeddingsModelTests.createModel(
"inferenceId_two",
@@ -86,22 +120,25 @@
"other_secret_key"
);

assertThat(cache.clientCount(), is(1));

var secondClient = cache.getOrCreateClient(secondModel, null);
assertThat(client, not(sameInstance(secondClient)));

assertThat(cache.clientCount(), is(2));

// set clock to after expiry
// set clock to after expiry of first client but not after expiry of second client
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(CLIENT_CACHE_EXPIRY_MINUTES + 1)), ZoneId.systemDefault()));

// get another client, this will ensure flushExpiredClients is called
// retrieve the second client, this will ensure flushExpiredClients is called
var regetSecondClient = cache.getOrCreateClient(secondModel, null);
assertThat(secondClient, sameInstance(regetSecondClient));

// expired first client should have been flushed
assertThat(cache.clientCount(), is(1));

var regetFirstClient = cache.getOrCreateClient(model, null);
assertThat(client, not(sameInstance(regetFirstClient)));

assertThat(cache.clientCount(), is(2));
}
assertThat(cacheInstance.clientCount(), is(0));
}
