
Improved threading capabilities of S3+parquet #5451

Merged: 11 commits, May 8, 2024
Util/src/main/java/io/deephaven/util/thread/ThreadHelpers.java (13 changes: 11 additions & 2 deletions)
@@ -6,8 +6,17 @@
import io.deephaven.configuration.Configuration;

public class ThreadHelpers {
public static int getNumThreadsFromConfig(final String configKey) {
final int numThreads = Configuration.getInstance().getIntegerWithDefault(configKey, -1);
/**
* Get the number of threads to use for a given configuration key, defaulting to the number of available processors
* if the configuration key is set to a non-positive value, or if the configuration key is not set and the
* provided default is non-positive.
*
* @param configKey The configuration key to look up
* @param defaultValue The default value to use if the configuration key is not set
* @return The number of threads to use
*/
public static int getOrComputeThreadCountProperty(final String configKey, final int defaultValue) {
final int numThreads = Configuration.getInstance().getIntegerWithDefault(configKey, defaultValue);
if (numThreads <= 0) {
return Runtime.getRuntime().availableProcessors();
} else {
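As a usage sketch (not part of this PR): the snippet below shows how a caller might size a thread pool with the renamed helper. The class name `MyService` and the property key `MyService.threads` are hypothetical, and the example assumes the process has a usable Deephaven Configuration; the behaviour follows the javadoc above.

    // Hypothetical caller of the helper above. If MyService.threads is set to a
    // positive value, the pool gets that many threads; if it is unset, the -1
    // default is non-positive, so the helper falls back to
    // Runtime.getRuntime().availableProcessors().
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    import static io.deephaven.util.thread.ThreadHelpers.getOrComputeThreadCountProperty;

    public class MyService {
        private static final int NUM_THREADS =
                getOrComputeThreadCountProperty("MyService.threads", -1);

        private final ExecutorService executor = Executors.newFixedThreadPool(NUM_THREADS);
    }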
@@ -16,7 +16,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import static io.deephaven.util.thread.ThreadHelpers.getNumThreadsFromConfig;
import static io.deephaven.util.thread.ThreadHelpers.getOrComputeThreadCountProperty;

/**
* Implementation of OperationInitializer that delegates to a pool of threads.
@@ -26,8 +26,9 @@ public class OperationInitializationThreadPool implements OperationInitializer {
/**
* The number of threads that will be used for parallel initialization in this process
*/
private static final int NUM_THREADS = getNumThreadsFromConfig("OperationInitializationThreadPool.threads");
private static final ThreadLocal<Boolean> isInitializationThread = ThreadLocal.withInitial(() -> false);
private static final int NUM_THREADS =
getOrComputeThreadCountProperty("OperationInitializationThreadPool.threads", -1);
private final ThreadLocal<Boolean> isInitializationThread = ThreadLocal.withInitial(() -> false);

private final ThreadPoolExecutor executorService;

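For context on the `isInitializationThread` field above (now per instance rather than static): it is the standard `ThreadLocal.withInitial` per-thread marker idiom. A minimal, self-contained sketch of that idiom follows; the class and method names are illustrative, not the pool's actual wiring.

    // Illustrative per-thread flag: wrapped tasks set it while running on a worker
    // thread; every other thread sees the initial value (false).
    public class ThreadMarkerDemo {
        private final ThreadLocal<Boolean> isWorkerThread = ThreadLocal.withInitial(() -> false);

        public Runnable wrap(final Runnable task) {
            return () -> {
                isWorkerThread.set(true);    // mark the executing worker thread
                try {
                    task.run();
                } finally {
                    isWorkerThread.remove(); // reset so reused threads start clean
                }
            };
        }

        public boolean onWorkerThread() {
            return isWorkerThread.get();     // false on threads that never ran wrap()
        }

        public static void main(String[] args) throws InterruptedException {
            final ThreadMarkerDemo demo = new ThreadMarkerDemo();
            System.out.println("main marked? " + demo.onWorkerThread()); // false
            final Thread worker = new Thread(demo.wrap(
                    () -> System.out.println("worker marked? " + demo.onWorkerThread()))); // true
            worker.start();
            worker.join();
        }
    }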
@@ -3,7 +3,6 @@
//
package io.deephaven.extensions.s3;

import io.deephaven.configuration.Configuration;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import org.jetbrains.annotations.NotNull;
@@ -24,13 +23,14 @@
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;

import static io.deephaven.util.thread.ThreadHelpers.getNumThreadsFromConfig;
import static io.deephaven.util.thread.ThreadHelpers.getOrComputeThreadCountProperty;

class S3AsyncClientFactory {

private static final int NUM_FUTURE_COMPLETION_THREADS = getNumThreadsFromConfig("S3.numFutureCompletionThreads");
private static final int NUM_FUTURE_COMPLETION_THREADS =
getOrComputeThreadCountProperty("S3.numFutureCompletionThreads", -1);
private static final int NUM_SCHEDULED_EXECUTOR_THREADS =
Configuration.getInstance().getIntegerWithDefault("S3.numScheduledExecutorThreads", 5);
getOrComputeThreadCountProperty("S3.numScheduledExecutorThreads", 5);

private static final Logger log = LoggerFactory.getLogger(S3AsyncClientFactory.class);
private static final Map<HttpClientConfig, SdkAsyncHttpClient> httpClientCache = new ConcurrentHashMap<>();
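Putting the two properties above together, a small demo of how the counts now resolve. This is illustrative only: it assumes a Deephaven Configuration can be initialized in the running process, and the processor-dependent values vary by host.

    // Resolution after this change:
    //   S3.numFutureCompletionThreads  unset -> default -1 (non-positive) -> availableProcessors()
    //   S3.numScheduledExecutorThreads unset -> default 5                 -> 5
    //   either key set to a value <= 0                                    -> availableProcessors()
    import static io.deephaven.util.thread.ThreadHelpers.getOrComputeThreadCountProperty;

    public final class S3ThreadCountDemo {
        public static void main(String[] args) {
            System.out.println("future completion threads: "
                    + getOrComputeThreadCountProperty("S3.numFutureCompletionThreads", -1));
            System.out.println("scheduled executor threads: "
                    + getOrComputeThreadCountProperty("S3.numScheduledExecutorThreads", 5));
        }
    }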
@@ -27,14 +27,15 @@
import java.nio.channels.SeekableByteChannel;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

@@ -62,56 +63,26 @@ final class S3SeekableChannelProvider implements SeekableChannelsProvider {
private final S3AsyncClient s3AsyncClient;
private final S3Instructions s3Instructions;

private SoftReference<KeyedObjectHashMap<URI, FileSizeInfo>> uriToFileSizeSoftRef;
@SuppressWarnings("rawtypes")
private static final AtomicReferenceFieldUpdater<S3SeekableChannelProvider, SoftReference> FILE_SIZE_CACHE_REF_UPDATER =
AtomicReferenceFieldUpdater.newUpdater(S3SeekableChannelProvider.class, SoftReference.class,
"fileSizeCacheRef");

private static final class FileSizeInfo {
private final URI uri;
private final long size;

FileSizeInfo(@NotNull final URI uri, final long size) {
this.uri = Require.neqNull(uri, "uri");
this.size = size;
}

@Override
public int hashCode() {
return Objects.hash(uri, size);
}

@Override
public boolean equals(final Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
final FileSizeInfo that = (FileSizeInfo) other;
return size == that.size && uri.equals(that.uri);
}
}

private static final KeyedObjectKey<URI, FileSizeInfo> URI_MATCH_KEY = new KeyedObjectKey.Basic<>() {
@Override
public URI getKey(@NotNull final FileSizeInfo value) {
return value.uri;
}
};
private volatile SoftReference<Map<URI, FileSizeInfo>> fileSizeCacheRef;

S3SeekableChannelProvider(@NotNull final S3Instructions s3Instructions) {
this.s3AsyncClient = S3AsyncClientFactory.getAsyncClient(s3Instructions);
this.s3Instructions = s3Instructions;
this.uriToFileSizeSoftRef = new SoftReference<>(new KeyedObjectHashMap<>(URI_MATCH_KEY));
}

@Override
public SeekableByteChannel getReadChannel(@NotNull final SeekableChannelContext channelContext,
@NotNull final URI uri) {
final S3Uri s3Uri = s3AsyncClient.utilities().parseUri(uri);
// context is unused here, will be set before reading from the channel
final KeyedObjectHashMap<URI, FileSizeInfo> uriToFileSize = uriToFileSizeSoftRef.get();
if (uriToFileSize != null && uriToFileSize.containsKey(uri)) {
return new S3SeekableByteChannel(s3Uri, uriToFileSize.get(uri).size);
final Map<URI, FileSizeInfo> fileSizeCache = fileSizeCacheRef.get();
if (fileSizeCache != null && fileSizeCache.containsKey(uri)) {
return new S3SeekableByteChannel(s3Uri, fileSizeCache.get(uri).size);
}
return new S3SeekableByteChannel(s3Uri);
}
@@ -236,7 +207,7 @@ private void fetchNextBatch() throws IOException {
+ s3Object.key() + " and bucket " + bucketName + " inside directory "
+ directory, e);
}
updateFileSizeCache(uri, s3Object.size());
updateFileSizeCache(getFileSizeCache(), uri, s3Object.size());
return uri;
}).iterator();
// The following token is null when the last batch is fetched.
@@ -248,25 +219,53 @@ private void fetchNextBatch() throws IOException {
}

/**
* Update the file size cache with the given URI and size.
* Get a strong reference to the file size cache, creating it if necessary.
*/
private void updateFileSizeCache(@NotNull final URI uri, final long size) {
KeyedObjectHashMap<URI, FileSizeInfo> uriToFileSize = uriToFileSizeSoftRef.get();
if (uriToFileSize != null) {
uriToFileSize.compute(uri, (key, existingInfo) -> {
if (existingInfo == null) {
return new FileSizeInfo(uri, size);
} else if (existingInfo.size != size) {
throw new IllegalStateException("Existing size " + existingInfo.size + " does not match "
+ " the new size " + size + " for key " + key);
}
return existingInfo;
});
} else {
uriToFileSize = new KeyedObjectHashMap<>(URI_MATCH_KEY);
uriToFileSize.put(uri, new FileSizeInfo(uri, size));
uriToFileSizeSoftRef = new SoftReference<>(uriToFileSize);
private Map<URI, FileSizeInfo> getFileSizeCache() {
SoftReference<Map<URI, FileSizeInfo>> cacheRef;
Map<URI, FileSizeInfo> cache;
while ((cacheRef = fileSizeCacheRef) == null || (cache = cacheRef.get()) == null) {
if (FILE_SIZE_CACHE_REF_UPDATER.compareAndSet(this, cacheRef,
new SoftReference<>(cache = new KeyedObjectHashMap<>(FileSizeInfo.URI_MATCH_KEY)))) {
return cache;
}
}
return cache;
}

/**
* Update the given file size cache with the given URI and size.
*/
private static void updateFileSizeCache(
@NotNull final Map<URI, FileSizeInfo> fileSizeCache,
@NotNull final URI uri,
final long size) {
fileSizeCache.compute(uri, (key, existingInfo) -> {
if (existingInfo == null) {
return new FileSizeInfo(uri, size);
} else if (existingInfo.size != size) {
throw new IllegalStateException("Existing size " + existingInfo.size + " does not match "
+ " the new size " + size + " for key " + key);
}
return existingInfo;
});
}

private static final class FileSizeInfo {
private final URI uri;
private final long size;

FileSizeInfo(@NotNull final URI uri, final long size) {
this.uri = Require.neqNull(uri, "uri");
this.size = size;
}

private static final KeyedObjectKey<URI, FileSizeInfo> URI_MATCH_KEY = new KeyedObjectKey.Basic<>() {
@Override
public URI getKey(@NotNull final FileSizeInfo value) {
return value.uri;
}
};
}

@Override
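The `getFileSizeCache` change above relies on a common idiom: a softly referenced map that is lazily (re)created through a compare-and-set on a `volatile` field, so the cache can be reclaimed under memory pressure and rebuilt on demand. Below is a minimal, self-contained sketch of that idiom using plain JDK types instead of Deephaven's `KeyedObjectHashMap`; the `SoftCacheHolder` class is illustrative, not Deephaven API.

    import java.lang.ref.SoftReference;
    import java.net.URI;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;

    final class SoftCacheHolder {
        @SuppressWarnings("rawtypes")
        private static final AtomicReferenceFieldUpdater<SoftCacheHolder, SoftReference> REF_UPDATER =
                AtomicReferenceFieldUpdater.newUpdater(SoftCacheHolder.class, SoftReference.class, "cacheRef");

        // Softly referenced so the GC may reclaim the whole cache under memory pressure.
        private volatile SoftReference<Map<URI, Long>> cacheRef;

        Map<URI, Long> getCache() {
            SoftReference<Map<URI, Long>> ref;
            Map<URI, Long> cache;
            // Loop until a live cache is observed, or this thread wins the race to install a fresh one.
            while ((ref = cacheRef) == null || (cache = ref.get()) == null) {
                final Map<URI, Long> fresh = new ConcurrentHashMap<>();
                if (REF_UPDATER.compareAndSet(this, ref, new SoftReference<>(fresh))) {
                    return fresh;
                }
            }
            return cache;
        }

        public static void main(String[] args) {
            final SoftCacheHolder holder = new SoftCacheHolder();
            final Map<URI, Long> cache = holder.getCache();
            cache.put(URI.create("s3://bucket/key.parquet"), 1024L);
            System.out.println(cache.get(URI.create("s3://bucket/key.parquet"))); // 1024
        }
    }

Using a static field updater rather than a per-instance AtomicReference keeps the footprint to a single reference field per holder, which is the usual reason to prefer this pattern.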