Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring gated and ref-counted interfaces and their implementations #2396

Merged
merged 6 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@

package org.opensearch.common.util.concurrent;

import org.opensearch.common.concurrent.OneWayGate;
import org.opensearch.test.OpenSearchTestCase;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -138,19 +138,19 @@ public void run() {

private final class MyRefCounted extends AbstractRefCounted {

private final AtomicBoolean closed = new AtomicBoolean(false);
private final OneWayGate gate = new OneWayGate();

MyRefCounted() {
super("test");
}

@Override
protected void closeInternal() {
this.closed.set(true);
gate.close();
}

public void ensureOpen() {
if (closed.get()) {
if (gate.isClosed()) {
assert this.refCount() == 0;
throw new IllegalStateException("closed");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,42 +33,15 @@
package org.opensearch.discovery.ec2;

import com.amazonaws.services.ec2.AmazonEC2;

import org.opensearch.common.lease.Releasable;
import org.opensearch.common.util.concurrent.AbstractRefCounted;
import org.opensearch.common.concurrent.RefCountedReleasable;

/**
* Handles the shutdown of the wrapped {@link AmazonEC2} using reference
* counting.
*/
public class AmazonEc2Reference extends AbstractRefCounted implements Releasable {

private final AmazonEC2 client;
public class AmazonEc2Reference extends RefCountedReleasable<AmazonEC2> {

AmazonEc2Reference(AmazonEC2 client) {
super("AWS_EC2_CLIENT");
this.client = client;
super("AWS_EC2_CLIENT", client, client::shutdown);
}

/**
* Call when the client is not needed anymore.
*/
@Override
public void close() {
decRef();
}

/**
* Returns the underlying `AmazonEC2` client. All method calls are permitted BUT
* NOT shutdown. Shutdown is called when reference count reaches 0.
*/
public AmazonEC2 client() {
return client;
}

@Override
protected void closeInternal() {
client.shutdown();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ protected List<TransportAddress> fetchDynamicNodes() {
// NOTE: we don't filter by security group during the describe instances request for two reasons:
// 1. differences in VPCs require different parameters during query (ID vs Name)
// 2. We want to use two different strategies: (all security groups vs. any security groups)
descInstances = SocketAccess.doPrivileged(() -> clientReference.client().describeInstances(buildDescribeInstancesRequest()));
descInstances = SocketAccess.doPrivileged(() -> clientReference.get().describeInstances(buildDescribeInstancesRequest()));
} catch (final AmazonClientException e) {
logger.info("Exception while retrieving instance list from AWS API: {}", e.getMessage());
logger.debug("Full exception:", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ public void testNodeAttributesErrorLenient() throws Exception {

public void testDefaultEndpoint() throws IOException {
try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(Settings.EMPTY)) {
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().client()).endpoint;
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().get()).endpoint;
assertThat(endpoint, is(""));
}
}

public void testSpecificEndpoint() throws IOException {
final Settings settings = Settings.builder().put(Ec2ClientSettings.ENDPOINT_SETTING.getKey(), "ec2.endpoint").build();
try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(settings)) {
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().client()).endpoint;
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().get()).endpoint;
assertThat(endpoint, is("ec2.endpoint"));
}
}
Expand Down Expand Up @@ -150,7 +150,7 @@ public void testClientSettingsReInit() throws IOException {
try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(settings1)) {
try (AmazonEc2Reference clientReference = plugin.ec2Service.client()) {
{
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials();
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials();
assertThat(credentials.getAWSAccessKeyId(), is("ec2_access_1"));
assertThat(credentials.getAWSSecretKey(), is("ec2_secret_1"));
if (mockSecure1HasSessionToken) {
Expand All @@ -159,32 +159,32 @@ public void testClientSettingsReInit() throws IOException {
} else {
assertThat(credentials, instanceOf(BasicAWSCredentials.class));
}
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_1"));
}
// reload secure settings2
plugin.reload(settings2);
// client is not released, it is still using the old settings
{
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials();
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials();
if (mockSecure1HasSessionToken) {
assertThat(credentials, instanceOf(BasicSessionCredentials.class));
assertThat(((BasicSessionCredentials) credentials).getSessionToken(), is("ec2_session_token_1"));
} else {
assertThat(credentials, instanceOf(BasicAWSCredentials.class));
}
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_1"));
}
}
try (AmazonEc2Reference clientReference = plugin.ec2Service.client()) {
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials();
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials();
assertThat(credentials.getAWSAccessKeyId(), is("ec2_access_2"));
assertThat(credentials.getAWSSecretKey(), is("ec2_secret_2"));
if (mockSecure2HasSessionToken) {
Expand All @@ -193,11 +193,11 @@ public void testClientSettingsReInit() throws IOException {
} else {
assertThat(credentials, instanceOf(BasicAWSCredentials.class));
}
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_2"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_2"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_2"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(882));
assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(882));
assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_2"));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,45 +32,17 @@

package org.opensearch.repositories.s3;

import org.opensearch.common.util.concurrent.AbstractRefCounted;

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;

import org.opensearch.common.lease.Releasable;
import org.opensearch.common.concurrent.RefCountedReleasable;

/**
* Handles the shutdown of the wrapped {@link AmazonS3Client} using reference
* counting.
*/
public class AmazonS3Reference extends AbstractRefCounted implements Releasable {

private final AmazonS3 client;
public class AmazonS3Reference extends RefCountedReleasable<AmazonS3> {

AmazonS3Reference(AmazonS3 client) {
super("AWS_S3_CLIENT");
this.client = client;
}

/**
* Call when the client is not needed anymore.
*/
@Override
public void close() {
decRef();
super("AWS_S3_CLIENT", client, client::shutdown);
}

/**
* Returns the underlying `AmazonS3` client. All method calls are permitted BUT
* NOT shutdown. Shutdown is called when reference count reaches 0.
*/
public AmazonS3 client() {
return client;
}

@Override
protected void closeInternal() {
client.shutdown();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class S3BlobContainer extends AbstractBlobContainer {
@Override
public boolean blobExists(String blobName) {
try (AmazonS3Reference clientReference = blobStore.clientReference()) {
return SocketAccess.doPrivileged(() -> clientReference.client().doesObjectExist(blobStore.bucket(), buildKey(blobName)));
return SocketAccess.doPrivileged(() -> clientReference.get().doesObjectExist(blobStore.bucket(), buildKey(blobName)));
} catch (final Exception e) {
throw new BlobStoreException("Failed to check if blob [" + blobName + "] exists", e);
}
Expand Down Expand Up @@ -169,13 +169,13 @@ public DeleteResult delete() throws IOException {
ObjectListing list;
if (prevListing != null) {
final ObjectListing finalPrevListing = prevListing;
list = SocketAccess.doPrivileged(() -> clientReference.client().listNextBatchOfObjects(finalPrevListing));
list = SocketAccess.doPrivileged(() -> clientReference.get().listNextBatchOfObjects(finalPrevListing));
} else {
final ListObjectsRequest listObjectsRequest = new ListObjectsRequest();
listObjectsRequest.setBucketName(blobStore.bucket());
listObjectsRequest.setPrefix(keyPath);
listObjectsRequest.setRequestMetricCollector(blobStore.listMetricCollector);
list = SocketAccess.doPrivileged(() -> clientReference.client().listObjects(listObjectsRequest));
list = SocketAccess.doPrivileged(() -> clientReference.get().listObjects(listObjectsRequest));
}
final List<String> blobsToDelete = new ArrayList<>();
list.getObjectSummaries().forEach(s3ObjectSummary -> {
Expand Down Expand Up @@ -236,7 +236,7 @@ private void doDeleteBlobs(List<String> blobNames, boolean relative) throws IOEx
.map(DeleteObjectsRequest.KeyVersion::getKey)
.collect(Collectors.toList());
try {
clientReference.client().deleteObjects(deleteRequest);
clientReference.get().deleteObjects(deleteRequest);
outstanding.removeAll(keysInRequest);
} catch (MultiObjectDeleteException e) {
// We are sending quiet mode requests so we can't use the deleted keys entry on the exception and instead
Expand Down Expand Up @@ -324,9 +324,9 @@ private static List<ObjectListing> executeListing(AmazonS3Reference clientRefere
ObjectListing list;
if (prevListing != null) {
final ObjectListing finalPrevListing = prevListing;
list = SocketAccess.doPrivileged(() -> clientReference.client().listNextBatchOfObjects(finalPrevListing));
list = SocketAccess.doPrivileged(() -> clientReference.get().listNextBatchOfObjects(finalPrevListing));
} else {
list = SocketAccess.doPrivileged(() -> clientReference.client().listObjects(listObjectsRequest));
list = SocketAccess.doPrivileged(() -> clientReference.get().listObjects(listObjectsRequest));
}
results.add(list);
if (list.isTruncated()) {
Expand Down Expand Up @@ -374,7 +374,7 @@ void executeSingleUpload(final S3BlobStore blobStore, final String blobName, fin
putRequest.setRequestMetricCollector(blobStore.putMetricCollector);

try (AmazonS3Reference clientReference = blobStore.clientReference()) {
SocketAccess.doPrivilegedVoid(() -> { clientReference.client().putObject(putRequest); });
SocketAccess.doPrivilegedVoid(() -> { clientReference.get().putObject(putRequest); });
} catch (final AmazonClientException e) {
throw new IOException("Unable to upload object [" + blobName + "] using a single upload", e);
}
Expand Down Expand Up @@ -413,7 +413,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
}
try (AmazonS3Reference clientReference = blobStore.clientReference()) {

uploadId.set(SocketAccess.doPrivileged(() -> clientReference.client().initiateMultipartUpload(initRequest).getUploadId()));
uploadId.set(SocketAccess.doPrivileged(() -> clientReference.get().initiateMultipartUpload(initRequest).getUploadId()));
if (Strings.isEmpty(uploadId.get())) {
throw new IOException("Failed to initialize multipart upload " + blobName);
}
Expand All @@ -439,7 +439,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
}
bytesCount += uploadRequest.getPartSize();

final UploadPartResult uploadResponse = SocketAccess.doPrivileged(() -> clientReference.client().uploadPart(uploadRequest));
final UploadPartResult uploadResponse = SocketAccess.doPrivileged(() -> clientReference.get().uploadPart(uploadRequest));
parts.add(uploadResponse.getPartETag());
}

Expand All @@ -456,7 +456,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
parts
);
complRequest.setRequestMetricCollector(blobStore.multiPartUploadMetricCollector);
SocketAccess.doPrivilegedVoid(() -> clientReference.client().completeMultipartUpload(complRequest));
SocketAccess.doPrivilegedVoid(() -> clientReference.get().completeMultipartUpload(complRequest));
success = true;

} catch (final AmazonClientException e) {
Expand All @@ -465,7 +465,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
if ((success == false) && Strings.hasLength(uploadId.get())) {
final AbortMultipartUploadRequest abortRequest = new AbortMultipartUploadRequest(bucketName, blobName, uploadId.get());
try (AmazonS3Reference clientReference = blobStore.clientReference()) {
SocketAccess.doPrivilegedVoid(() -> clientReference.client().abortMultipartUpload(abortRequest));
SocketAccess.doPrivilegedVoid(() -> clientReference.get().abortMultipartUpload(abortRequest));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private void openStream() throws IOException {
+ end;
getObjectRequest.setRange(Math.addExact(start, currentOffset), end);
}
final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest));
final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.get().getObject(getObjectRequest));
this.currentStreamLastOffset = Math.addExact(Math.addExact(start, currentOffset), getStreamLength(s3Object));
this.currentStream = s3Object.getObjectContent();
} catch (final AmazonClientException e) {
Expand Down
Loading