From 082d425c21a77185f1ce303998ab343e3ce475c4 Mon Sep 17 00:00:00 2001 From: Dev Agarwal Date: Thu, 31 Aug 2023 22:29:22 +0530 Subject: [PATCH 1/4] Added sampler based on Blanket Probabilistic Sampling rate and Override for on demand (#9522) Signed-off-by: Dev Agarwal --- CHANGELOG.md | 1 + .../IntegrationTestOTelTelemetryPlugin.java | 6 +- .../telemetry/OTelTelemetryPlugin.java | 8 +- .../tracing/OTelResourceProvider.java | 8 +- .../tracing/sampler/ProbabilisticSampler.java | 82 +++++++++++++++++ .../tracing/sampler/RequestSampler.java | 67 ++++++++++++++ .../tracing/sampler/package-info.java | 12 +++ .../telemetry/OTelTelemetryPluginTests.java | 6 +- .../sampler/ProbabilisticSamplerTests.java | 64 +++++++++++++ .../tracing/sampler/RequestSamplerTests.java | 92 +++++++++++++++++++ .../common/settings/ClusterSettings.java | 2 +- .../opensearch/plugins/TelemetryPlugin.java | 2 +- .../telemetry/TelemetrySettings.java | 29 ++++++ 13 files changed, 367 insertions(+), 12 deletions(-) create mode 100644 plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSampler.java create mode 100644 plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/RequestSampler.java create mode 100644 plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/package-info.java create mode 100644 plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSamplerTests.java create mode 100644 plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/RequestSamplerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 5df33dc6e3ace..0f3f4be1d295b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -169,6 +169,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Remote Store] Rate limiter integration for remote store uploads and downloads([#9448](https://github.com/opensearch-project/OpenSearch/pull/9448/)) - [Remote Store] Implicitly use replication type SEGMENT for remote store clusters ([#9264](https://github.com/opensearch-project/OpenSearch/pull/9264)) - Use non-concurrent path for sort request on timeseries index and field([#9562](https://github.com/opensearch-project/OpenSearch/pull/9562)) +- Added sampler based on `Blanket Probabilistic Sampling rate` and `Override for on demand` ([#9621](https://github.com/opensearch-project/OpenSearch/issues/9621)) ### Deprecated diff --git a/plugins/telemetry-otel/src/internalClusterTest/java/org/opensearch/telemetry/tracing/IntegrationTestOTelTelemetryPlugin.java b/plugins/telemetry-otel/src/internalClusterTest/java/org/opensearch/telemetry/tracing/IntegrationTestOTelTelemetryPlugin.java index 57dbf4e001be4..ed4d13f3abb7d 100644 --- a/plugins/telemetry-otel/src/internalClusterTest/java/org/opensearch/telemetry/tracing/IntegrationTestOTelTelemetryPlugin.java +++ b/plugins/telemetry-otel/src/internalClusterTest/java/org/opensearch/telemetry/tracing/IntegrationTestOTelTelemetryPlugin.java @@ -32,10 +32,10 @@ public IntegrationTestOTelTelemetryPlugin(Settings settings) { /** * This method overrides getTelemetry() method in OTel plugin class, so we create only one instance of global OpenTelemetry * resetForTest() will set OpenTelemetry to null again. - * @param settings cluster settings + * @param telemetrySettings telemetry settings */ - public Optional getTelemetry(TelemetrySettings settings) { + public Optional getTelemetry(TelemetrySettings telemetrySettings) { GlobalOpenTelemetry.resetForTest(); - return super.getTelemetry(settings); + return super.getTelemetry(telemetrySettings); } } diff --git a/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/OTelTelemetryPlugin.java b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/OTelTelemetryPlugin.java index a1ca3adf4d2a2..1af88196e3727 100644 --- a/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/OTelTelemetryPlugin.java +++ b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/OTelTelemetryPlugin.java @@ -49,8 +49,8 @@ public List> getSettings() { } @Override - public Optional getTelemetry(TelemetrySettings settings) { - return Optional.of(telemetry()); + public Optional getTelemetry(TelemetrySettings telemetrySettings) { + return Optional.of(telemetry(telemetrySettings)); } @Override @@ -58,8 +58,8 @@ public String getName() { return OTEL_TRACER_NAME; } - private Telemetry telemetry() { - return new OTelTelemetry(new OTelTracingTelemetry(OTelResourceProvider.get(settings)), new MetricsTelemetry() { + private Telemetry telemetry(TelemetrySettings telemetrySettings) { + return new OTelTelemetry(new OTelTracingTelemetry(OTelResourceProvider.get(telemetrySettings, settings)), new MetricsTelemetry() { }); } diff --git a/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/OTelResourceProvider.java b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/OTelResourceProvider.java index 1ec4818b8b73e..b395a335a4d83 100644 --- a/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/OTelResourceProvider.java +++ b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/OTelResourceProvider.java @@ -9,7 +9,10 @@ package org.opensearch.telemetry.tracing; import org.opensearch.common.settings.Settings; +import org.opensearch.telemetry.TelemetrySettings; import org.opensearch.telemetry.tracing.exporter.OTelSpanExporterFactory; +import org.opensearch.telemetry.tracing.sampler.ProbabilisticSampler; +import org.opensearch.telemetry.tracing.sampler.RequestSampler; import java.util.concurrent.TimeUnit; @@ -37,15 +40,16 @@ private OTelResourceProvider() {} /** * Creates OpenTelemetry instance with default configuration + * @param telemetrySettings telemetry settings * @param settings cluster settings * @return OpenTelemetry instance */ - public static OpenTelemetry get(Settings settings) { + public static OpenTelemetry get(TelemetrySettings telemetrySettings, Settings settings) { return get( settings, OTelSpanExporterFactory.create(settings), ContextPropagators.create(W3CTraceContextPropagator.getInstance()), - Sampler.alwaysOn() + Sampler.parentBased(new RequestSampler(new ProbabilisticSampler(telemetrySettings))) ); } diff --git a/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSampler.java b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSampler.java new file mode 100644 index 0000000000000..cab7b1a4af2e6 --- /dev/null +++ b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSampler.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.telemetry.tracing.sampler; + +import org.opensearch.telemetry.TelemetrySettings; + +import java.util.List; +import java.util.Objects; + +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.context.Context; +import io.opentelemetry.sdk.trace.data.LinkData; +import io.opentelemetry.sdk.trace.samplers.Sampler; +import io.opentelemetry.sdk.trace.samplers.SamplingResult; + +/** + * ProbabilisticSampler implements a head-based sampling strategy based on provided settings. + */ +public class ProbabilisticSampler implements Sampler { + private Sampler defaultSampler; + private final TelemetrySettings telemetrySettings; + private double samplingRatio; + + /** + * Constructor + * + * @param telemetrySettings Telemetry settings. + */ + public ProbabilisticSampler(TelemetrySettings telemetrySettings) { + this.telemetrySettings = Objects.requireNonNull(telemetrySettings); + this.samplingRatio = telemetrySettings.getTracerHeadSamplerSamplingRatio(); + this.defaultSampler = Sampler.traceIdRatioBased(samplingRatio); + } + + Sampler getSampler() { + double newSamplingRatio = telemetrySettings.getTracerHeadSamplerSamplingRatio(); + if (isSamplingRatioChanged(newSamplingRatio)) { + synchronized (this) { + this.samplingRatio = newSamplingRatio; + defaultSampler = Sampler.traceIdRatioBased(samplingRatio); + } + } + return defaultSampler; + } + + private boolean isSamplingRatioChanged(double newSamplingRatio) { + return Double.compare(this.samplingRatio, newSamplingRatio) != 0; + } + + double getSamplingRatio() { + return samplingRatio; + } + + @Override + public SamplingResult shouldSample( + Context parentContext, + String traceId, + String name, + SpanKind spanKind, + Attributes attributes, + List parentLinks + ) { + return getSampler().shouldSample(parentContext, traceId, name, spanKind, attributes, parentLinks); + } + + @Override + public String getDescription() { + return "Probabilistic Sampler"; + } + + @Override + public String toString() { + return getDescription(); + } +} diff --git a/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/RequestSampler.java b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/RequestSampler.java new file mode 100644 index 0000000000000..9ea681370a3ec --- /dev/null +++ b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/RequestSampler.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.telemetry.tracing.sampler; + +import java.util.List; + +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.context.Context; +import io.opentelemetry.sdk.trace.data.LinkData; +import io.opentelemetry.sdk.trace.samplers.Sampler; +import io.opentelemetry.sdk.trace.samplers.SamplingResult; + +/** + * HeadBased sampler + */ +public class RequestSampler implements Sampler { + private final Sampler defaultSampler; + + // TODO: Pick value of TRACE from PR #9415. + private static final String TRACE = "trace"; + + /** + * Creates Head based sampler + * @param defaultSampler defaultSampler + */ + public RequestSampler(Sampler defaultSampler) { + this.defaultSampler = defaultSampler; + } + + @Override + public SamplingResult shouldSample( + Context parentContext, + String traceId, + String name, + SpanKind spanKind, + Attributes attributes, + List parentLinks + ) { + + final String trace = attributes.get(AttributeKey.stringKey(TRACE)); + + if (trace != null) { + return (Boolean.parseBoolean(trace) == true) ? SamplingResult.recordAndSample() : SamplingResult.drop(); + } else { + return defaultSampler.shouldSample(parentContext, traceId, name, spanKind, attributes, parentLinks); + } + + } + + @Override + public String getDescription() { + return "Request Sampler"; + } + + @Override + public String toString() { + return getDescription(); + } +} diff --git a/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/package-info.java b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/package-info.java new file mode 100644 index 0000000000000..6534b33f6177c --- /dev/null +++ b/plugins/telemetry-otel/src/main/java/org/opensearch/telemetry/tracing/sampler/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains classes needed for sampler. + */ +package org.opensearch.telemetry.tracing.sampler; diff --git a/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/OTelTelemetryPluginTests.java b/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/OTelTelemetryPluginTests.java index 611656942860f..8c2b5d14733e2 100644 --- a/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/OTelTelemetryPluginTests.java +++ b/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/OTelTelemetryPluginTests.java @@ -29,6 +29,8 @@ import static org.opensearch.telemetry.OTelTelemetrySettings.TRACER_EXPORTER_BATCH_SIZE_SETTING; import static org.opensearch.telemetry.OTelTelemetrySettings.TRACER_EXPORTER_DELAY_SETTING; import static org.opensearch.telemetry.OTelTelemetrySettings.TRACER_EXPORTER_MAX_QUEUE_SIZE_SETTING; +import static org.opensearch.telemetry.TelemetrySettings.TRACER_ENABLED_SETTING; +import static org.opensearch.telemetry.TelemetrySettings.TRACER_SAMPLER_PROBABILITY; public class OTelTelemetryPluginTests extends OpenSearchTestCase { @@ -42,7 +44,9 @@ public void setup() { // io.opentelemetry.sdk.OpenTelemetrySdk.close waits only for 10 seconds for shutdown to complete. Settings settings = Settings.builder().put(TRACER_EXPORTER_DELAY_SETTING.getKey(), "1s").build(); oTelTracerModulePlugin = new OTelTelemetryPlugin(settings); - telemetry = oTelTracerModulePlugin.getTelemetry(null); + telemetry = oTelTracerModulePlugin.getTelemetry( + new TelemetrySettings(Settings.EMPTY, new ClusterSettings(settings, Set.of(TRACER_ENABLED_SETTING, TRACER_SAMPLER_PROBABILITY))) + ); tracingTelemetry = telemetry.get().getTracingTelemetry(); } diff --git a/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSamplerTests.java b/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSamplerTests.java new file mode 100644 index 0000000000000..639dc341ef0db --- /dev/null +++ b/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/ProbabilisticSamplerTests.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.telemetry.tracing.sampler; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.telemetry.TelemetrySettings; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Set; + +import io.opentelemetry.sdk.trace.samplers.Sampler; + +import static org.opensearch.telemetry.OTelTelemetrySettings.TRACER_EXPORTER_DELAY_SETTING; +import static org.opensearch.telemetry.TelemetrySettings.TRACER_ENABLED_SETTING; +import static org.opensearch.telemetry.TelemetrySettings.TRACER_SAMPLER_PROBABILITY; + +public class ProbabilisticSamplerTests extends OpenSearchTestCase { + + // When ProbabilisticSampler is created with OTelTelemetrySettings as null + public void testProbabilisticSamplerWithNullSettings() { + // Verify that the constructor throws IllegalArgumentException when given null settings + assertThrows(NullPointerException.class, () -> { new ProbabilisticSampler(null); }); + } + + public void testDefaultGetSampler() { + Settings settings = Settings.builder().put(TRACER_EXPORTER_DELAY_SETTING.getKey(), "1s").build(); + TelemetrySettings telemetrySettings = new TelemetrySettings( + Settings.EMPTY, + new ClusterSettings(settings, Set.of(TRACER_SAMPLER_PROBABILITY, TRACER_ENABLED_SETTING)) + ); + + // Probabilistic Sampler + ProbabilisticSampler probabilisticSampler = new ProbabilisticSampler(telemetrySettings); + + assertNotNull(probabilisticSampler.getSampler()); + assertEquals(0.01, probabilisticSampler.getSamplingRatio(), 0.0d); + } + + public void testGetSamplerWithUpdatedSamplingRatio() { + Settings settings = Settings.builder().put(TRACER_EXPORTER_DELAY_SETTING.getKey(), "1s").build(); + TelemetrySettings telemetrySettings = new TelemetrySettings( + Settings.EMPTY, + new ClusterSettings(settings, Set.of(TRACER_SAMPLER_PROBABILITY, TRACER_ENABLED_SETTING)) + ); + + // Probabilistic Sampler + ProbabilisticSampler probabilisticSampler = new ProbabilisticSampler(telemetrySettings); + assertEquals(0.01d, probabilisticSampler.getSamplingRatio(), 0.0d); + + telemetrySettings.setSamplingProbability(0.02); + + // Need to call getSampler() to update the value of tracerHeadSamplerSamplingRatio + Sampler updatedProbabilisticSampler = probabilisticSampler.getSampler(); + assertEquals(0.02, probabilisticSampler.getSamplingRatio(), 0.0d); + } + +} diff --git a/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/RequestSamplerTests.java b/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/RequestSamplerTests.java new file mode 100644 index 0000000000000..facf04623ec46 --- /dev/null +++ b/plugins/telemetry-otel/src/test/java/org/opensearch/telemetry/tracing/sampler/RequestSamplerTests.java @@ -0,0 +1,92 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.telemetry.tracing.sampler; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; + +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.context.Context; +import io.opentelemetry.sdk.trace.samplers.Sampler; +import io.opentelemetry.sdk.trace.samplers.SamplingResult; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RequestSamplerTests extends OpenSearchTestCase { + + public void testShouldSampleWithTraceAttributeAsTrue() { + + // Create a mock default sampler + Sampler defaultSampler = mock(Sampler.class); + when(defaultSampler.shouldSample(any(), anyString(), anyString(), any(), any(), any())).thenReturn(SamplingResult.drop()); + + // Create an instance of HeadSampler with the mock default sampler + RequestSampler requestSampler = new RequestSampler(defaultSampler); + + // Create a mock Context and Attributes + Context parentContext = mock(Context.class); + Attributes attributes = Attributes.of(AttributeKey.stringKey("trace"), "true"); + + // Call shouldSample on HeadSampler + SamplingResult result = requestSampler.shouldSample( + parentContext, + "traceId", + "spanName", + SpanKind.INTERNAL, + attributes, + Collections.emptyList() + ); + + assertEquals(SamplingResult.recordAndSample(), result); + + // Verify that the default sampler's shouldSample method was not called + verify(defaultSampler, never()).shouldSample(any(), anyString(), anyString(), any(), any(), any()); + } + + public void testShouldSampleWithoutTraceAttribute() { + + // Create a mock default sampler + Sampler defaultSampler = mock(Sampler.class); + when(defaultSampler.shouldSample(any(), anyString(), anyString(), any(), any(), any())).thenReturn( + SamplingResult.recordAndSample() + ); + + // Create an instance of HeadSampler with the mock default sampler + RequestSampler requestSampler = new RequestSampler(defaultSampler); + + // Create a mock Context and Attributes + Context parentContext = mock(Context.class); + Attributes attributes = Attributes.empty(); + + // Call shouldSample on HeadSampler + SamplingResult result = requestSampler.shouldSample( + parentContext, + "traceId", + "spanName", + SpanKind.INTERNAL, + attributes, + Collections.emptyList() + ); + + // Verify that HeadSampler returned SamplingResult.recordAndSample() + assertEquals(SamplingResult.recordAndSample(), result); + + // Verify that the default sampler's shouldSample method was called + verify(defaultSampler).shouldSample(any(), anyString(), anyString(), any(), any(), any()); + } + +} diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index e00e7e3bf4ea7..8093eab696779 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -694,6 +694,6 @@ public void apply(Settings value, Settings current, Settings previous) { SearchService.CONCURRENT_SEGMENT_SEARCH_TARGET_MAX_SLICE_COUNT_SETTING ), List.of(FeatureFlags.TELEMETRY), - List.of(TelemetrySettings.TRACER_ENABLED_SETTING) + List.of(TelemetrySettings.TRACER_ENABLED_SETTING, TelemetrySettings.TRACER_SAMPLER_PROBABILITY) ); } diff --git a/server/src/main/java/org/opensearch/plugins/TelemetryPlugin.java b/server/src/main/java/org/opensearch/plugins/TelemetryPlugin.java index 33dc9b7a0c843..66033df394d9f 100644 --- a/server/src/main/java/org/opensearch/plugins/TelemetryPlugin.java +++ b/server/src/main/java/org/opensearch/plugins/TelemetryPlugin.java @@ -18,7 +18,7 @@ */ public interface TelemetryPlugin { - Optional getTelemetry(TelemetrySettings settings); + Optional getTelemetry(TelemetrySettings telemetrySettings); String getName(); diff --git a/server/src/main/java/org/opensearch/telemetry/TelemetrySettings.java b/server/src/main/java/org/opensearch/telemetry/TelemetrySettings.java index 7c9e0d5ac8097..aa11a2879e4d7 100644 --- a/server/src/main/java/org/opensearch/telemetry/TelemetrySettings.java +++ b/server/src/main/java/org/opensearch/telemetry/TelemetrySettings.java @@ -23,12 +23,27 @@ public class TelemetrySettings { Setting.Property.Dynamic ); + /** + * Probability of sampler + */ + public static final Setting TRACER_SAMPLER_PROBABILITY = Setting.doubleSetting( + "telemetry.tracer.sampler.probability", + 0.01d, + 0.00d, + 1.00d, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + private volatile boolean tracingEnabled; + private volatile double samplingProbability; public TelemetrySettings(Settings settings, ClusterSettings clusterSettings) { this.tracingEnabled = TRACER_ENABLED_SETTING.get(settings); + this.samplingProbability = TRACER_SAMPLER_PROBABILITY.get(settings); clusterSettings.addSettingsUpdateConsumer(TRACER_ENABLED_SETTING, this::setTracingEnabled); + clusterSettings.addSettingsUpdateConsumer(TRACER_SAMPLER_PROBABILITY, this::setSamplingProbability); } public void setTracingEnabled(boolean tracingEnabled) { @@ -39,4 +54,18 @@ public boolean isTracingEnabled() { return tracingEnabled; } + /** + * Set sampling ratio + * @param samplingProbability double + */ + public void setSamplingProbability(double samplingProbability) { + this.samplingProbability = samplingProbability; + } + + /** + * Get sampling ratio + */ + public double getTracerHeadSamplerSamplingRatio() { + return samplingProbability; + } } From 1126d2f7caac7ec660975e63527404b9d9f6e087 Mon Sep 17 00:00:00 2001 From: Russ Cam Date: Fri, 1 Sep 2023 04:44:07 +1000 Subject: [PATCH 2/4] Expose DelimitedTermFrequencyTokenFilter (#9479) * Expose DelimitedTermFrequencyTokenFilter Relates: #9413 This commit exposes Lucene's delimited term frequency token filter to be able to provide term frequencies along with terms. Signed-off-by: Russ Cam * fix format violations Signed-off-by: Russ Cam * fix test and add to changelog Signed-off-by: Russ Cam * Address PR feedback - Add unit tests for DelimitedTermFrequencyTokenFilterFactory - Remove IllegalArgumentException as caught exception - Add skip to yaml rest tests to skip for version < 2.10 Signed-off-by: Russ Cam * formatting Signed-off-by: Russ Cam * Rename filter Signed-off-by: Russ Cam * update naming in REST tests Signed-off-by: Russ Cam --------- Signed-off-by: Russ Cam --- CHANGELOG.md | 1 + .../common/CommonAnalysisModulePlugin.java | 9 ++ ...imitedTermFrequencyTokenFilterFactory.java | 45 ++++++++++ .../common/CommonAnalysisFactoryTests.java | 2 + ...dTermFrequencyTokenFilterFactoryTests.java | 89 +++++++++++++++++++ .../test/analysis-common/40_token_filters.yml | 40 +++++++++ .../analysis/AnalysisFactoryTestCase.java | 4 +- 7 files changed, 187 insertions(+), 3 deletions(-) create mode 100644 modules/analysis-common/src/main/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactory.java create mode 100644 modules/analysis-common/src/test/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactoryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f3f4be1d295b..d4a060cc16504 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -96,6 +96,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [BWC and API enforcement] Decorate the existing APIs with proper annotations (part 1) ([#9520](https://github.com/opensearch-project/OpenSearch/pull/9520)) - Add concurrent segment search related metrics to node and index stats ([#9622](https://github.com/opensearch-project/OpenSearch/issues/9622)) - Decouple replication lag from logic to fail stale replicas ([#9507](https://github.com/opensearch-project/OpenSearch/pull/9507)) +- Expose DelimitedTermFrequencyTokenFilter to allow providing term frequencies along with terms ([#9479](https://github.com/opensearch-project/OpenSearch/pull/9479)) ### Dependencies - Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307)) diff --git a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java index 46220f5369d16..b0d9c1765190a 100644 --- a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java +++ b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java @@ -89,6 +89,7 @@ import org.apache.lucene.analysis.lt.LithuanianAnalyzer; import org.apache.lucene.analysis.lv.LatvianAnalyzer; import org.apache.lucene.analysis.miscellaneous.ASCIIFoldingFilter; +import org.apache.lucene.analysis.miscellaneous.DelimitedTermFrequencyTokenFilter; import org.apache.lucene.analysis.miscellaneous.DisableGraphAttribute; import org.apache.lucene.analysis.miscellaneous.KeywordRepeatFilter; import org.apache.lucene.analysis.miscellaneous.LengthFilter; @@ -265,6 +266,7 @@ public Map> getTokenFilters() { ); filters.put("decimal_digit", DecimalDigitFilterFactory::new); filters.put("delimited_payload", DelimitedPayloadTokenFilterFactory::new); + filters.put("delimited_term_freq", DelimitedTermFrequencyTokenFilterFactory::new); filters.put("dictionary_decompounder", requiresAnalysisSettings(DictionaryCompoundWordTokenFilterFactory::new)); filters.put("dutch_stem", DutchStemTokenFilterFactory::new); filters.put("edge_ngram", EdgeNGramTokenFilterFactory::new); @@ -500,6 +502,13 @@ public List getPreConfiguredTokenFilters() { ) ) ); + filters.add( + PreConfiguredTokenFilter.singleton( + "delimited_term_freq", + false, + input -> new DelimitedTermFrequencyTokenFilter(input, DelimitedTermFrequencyTokenFilterFactory.DEFAULT_DELIMITER) + ) + ); filters.add(PreConfiguredTokenFilter.singleton("dutch_stem", false, input -> new SnowballFilter(input, new DutchStemmer()))); filters.add(PreConfiguredTokenFilter.singleton("edge_ngram", false, false, input -> new EdgeNGramTokenFilter(input, 1))); filters.add(PreConfiguredTokenFilter.openSearchVersion("edgeNGram", false, false, (reader, version) -> { diff --git a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactory.java new file mode 100644 index 0000000000000..8929a7c54ef4c --- /dev/null +++ b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactory.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.analysis.common; + +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.miscellaneous.DelimitedTermFrequencyTokenFilter; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.analysis.AbstractTokenFilterFactory; + +public class DelimitedTermFrequencyTokenFilterFactory extends AbstractTokenFilterFactory { + public static final char DEFAULT_DELIMITER = '|'; + private static final String DELIMITER = "delimiter"; + private final char delimiter; + + DelimitedTermFrequencyTokenFilterFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) { + super(indexSettings, name, settings); + delimiter = parseDelimiter(settings); + } + + @Override + public TokenStream create(TokenStream tokenStream) { + return new DelimitedTermFrequencyTokenFilter(tokenStream, delimiter); + } + + private static char parseDelimiter(Settings settings) { + String delimiter = settings.get(DELIMITER); + if (delimiter == null) { + return DEFAULT_DELIMITER; + } else if (delimiter.length() == 1) { + return delimiter.charAt(0); + } + + throw new IllegalArgumentException( + "Setting [" + DELIMITER + "] must be a single, non-null character. [" + delimiter + "] was provided." + ); + } +} diff --git a/modules/analysis-common/src/test/java/org/opensearch/analysis/common/CommonAnalysisFactoryTests.java b/modules/analysis-common/src/test/java/org/opensearch/analysis/common/CommonAnalysisFactoryTests.java index 1c4db089565ff..11713f52f5b18 100644 --- a/modules/analysis-common/src/test/java/org/opensearch/analysis/common/CommonAnalysisFactoryTests.java +++ b/modules/analysis-common/src/test/java/org/opensearch/analysis/common/CommonAnalysisFactoryTests.java @@ -145,6 +145,7 @@ protected Map> getTokenFilters() { filters.put("cjkwidth", CJKWidthFilterFactory.class); filters.put("cjkbigram", CJKBigramFilterFactory.class); filters.put("delimitedpayload", DelimitedPayloadTokenFilterFactory.class); + filters.put("delimitedtermfrequency", DelimitedTermFrequencyTokenFilterFactory.class); filters.put("keepword", KeepWordFilterFactory.class); filters.put("type", KeepTypesFilterFactory.class); filters.put("classic", ClassicFilterFactory.class); @@ -202,6 +203,7 @@ protected Map> getPreConfiguredTokenFilters() { filters.put("decimal_digit", null); filters.put("delimited_payload_filter", org.apache.lucene.analysis.payloads.DelimitedPayloadTokenFilterFactory.class); filters.put("delimited_payload", org.apache.lucene.analysis.payloads.DelimitedPayloadTokenFilterFactory.class); + filters.put("delimited_term_freq", org.apache.lucene.analysis.miscellaneous.DelimitedTermFrequencyTokenFilterFactory.class); filters.put("dutch_stem", SnowballPorterFilterFactory.class); filters.put("edge_ngram", null); filters.put("edgeNGram", null); diff --git a/modules/analysis-common/src/test/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactoryTests.java b/modules/analysis-common/src/test/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactoryTests.java new file mode 100644 index 0000000000000..fab83a75387de --- /dev/null +++ b/modules/analysis-common/src/test/java/org/opensearch/analysis/common/DelimitedTermFrequencyTokenFilterFactoryTests.java @@ -0,0 +1,89 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.analysis.common; + +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.Tokenizer; +import org.apache.lucene.analysis.core.WhitespaceTokenizer; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.index.analysis.AnalysisTestsHelper; +import org.opensearch.index.analysis.TokenFilterFactory; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.OpenSearchTokenStreamTestCase; + +import java.io.StringReader; + +public class DelimitedTermFrequencyTokenFilterFactoryTests extends OpenSearchTokenStreamTestCase { + + public void testDefault() throws Exception { + OpenSearchTestCase.TestAnalysis analysis = AnalysisTestsHelper.createTestAnalysisFromSettings( + Settings.builder() + .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()) + .put("index.analysis.filter.my_delimited_term_freq.type", "delimited_term_freq") + .build(), + new CommonAnalysisModulePlugin() + ); + doTest(analysis, "cat|4 dog|5"); + } + + public void testDelimiter() throws Exception { + OpenSearchTestCase.TestAnalysis analysis = AnalysisTestsHelper.createTestAnalysisFromSettings( + Settings.builder() + .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()) + .put("index.analysis.filter.my_delimited_term_freq.type", "delimited_term_freq") + .put("index.analysis.filter.my_delimited_term_freq.delimiter", ":") + .build(), + new CommonAnalysisModulePlugin() + ); + doTest(analysis, "cat:4 dog:5"); + } + + public void testDelimiterLongerThanOneCharThrows() { + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> AnalysisTestsHelper.createTestAnalysisFromSettings( + Settings.builder() + .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()) + .put("index.analysis.filter.my_delimited_term_freq.type", "delimited_term_freq") + .put("index.analysis.filter.my_delimited_term_freq.delimiter", "^^") + .build(), + new CommonAnalysisModulePlugin() + ) + ); + + assertEquals("Setting [delimiter] must be a single, non-null character. [^^] was provided.", ex.getMessage()); + } + + private void doTest(OpenSearchTestCase.TestAnalysis analysis, String source) throws Exception { + TokenFilterFactory tokenFilter = analysis.tokenFilter.get("my_delimited_term_freq"); + Tokenizer tokenizer = new WhitespaceTokenizer(); + tokenizer.setReader(new StringReader(source)); + + TokenStream stream = tokenFilter.create(tokenizer); + + CharTermAttribute termAtt = stream.getAttribute(CharTermAttribute.class); + TermFrequencyAttribute tfAtt = stream.getAttribute(TermFrequencyAttribute.class); + stream.reset(); + assertTermEquals("cat", stream, termAtt, tfAtt, 4); + assertTermEquals("dog", stream, termAtt, tfAtt, 5); + assertFalse(stream.incrementToken()); + stream.end(); + stream.close(); + } + + void assertTermEquals(String expected, TokenStream stream, CharTermAttribute termAtt, TermFrequencyAttribute tfAtt, int expectedTf) + throws Exception { + assertTrue(stream.incrementToken()); + assertEquals(expected, termAtt.toString()); + assertEquals(expectedTf, tfAtt.getTermFrequency()); + } +} diff --git a/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/40_token_filters.yml b/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/40_token_filters.yml index 40c82ff185661..e92cc0c4838c7 100644 --- a/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/40_token_filters.yml +++ b/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/40_token_filters.yml @@ -1198,6 +1198,46 @@ - match: { tokens.0.token: foo } --- +"delimited_term_freq": + - skip: + version: " - 2.9.99" + reason: "delimited_term_freq token filter was added in v2.10.0" + - do: + indices.create: + index: test + body: + settings: + analysis: + filter: + my_delimited_term_freq: + type: delimited_term_freq + delimiter: ^ + - do: + indices.analyze: + index: test + body: + text: foo^3 + tokenizer: keyword + filter: [my_delimited_term_freq] + attributes: termFrequency + explain: true + - length: { detail.tokenfilters: 1 } + - match: { detail.tokenfilters.0.tokens.0.token: foo } + - match: { detail.tokenfilters.0.tokens.0.termFrequency: 3 } + + # Test pre-configured token filter too: + - do: + indices.analyze: + body: + text: foo|100 + tokenizer: keyword + filter: [delimited_term_freq] + attributes: termFrequency + explain: true + - length: { detail.tokenfilters: 1 } + - match: { detail.tokenfilters.0.tokens.0.token: foo } + - match: { detail.tokenfilters.0.tokens.0.termFrequency: 100 } +--- "keep_filter": - do: indices.create: diff --git a/test/framework/src/main/java/org/opensearch/indices/analysis/AnalysisFactoryTestCase.java b/test/framework/src/main/java/org/opensearch/indices/analysis/AnalysisFactoryTestCase.java index b93cb64e32cfe..c412ae8317f24 100644 --- a/test/framework/src/main/java/org/opensearch/indices/analysis/AnalysisFactoryTestCase.java +++ b/test/framework/src/main/java/org/opensearch/indices/analysis/AnalysisFactoryTestCase.java @@ -98,6 +98,7 @@ public abstract class AnalysisFactoryTestCase extends OpenSearchTestCase { .put("czechstem", MovedToAnalysisCommon.class) .put("decimaldigit", MovedToAnalysisCommon.class) .put("delimitedpayload", MovedToAnalysisCommon.class) + .put("delimitedtermfrequency", MovedToAnalysisCommon.class) .put("dictionarycompoundword", MovedToAnalysisCommon.class) .put("edgengram", MovedToAnalysisCommon.class) .put("elision", MovedToAnalysisCommon.class) @@ -201,9 +202,6 @@ public abstract class AnalysisFactoryTestCase extends OpenSearchTestCase { .put("daterecognizer", Void.class) // for token filters that generate bad offsets, which are now rejected since Lucene 7 .put("fixbrokenoffsets", Void.class) - // should we expose it, or maybe think about higher level integration of the - // fake term frequency feature (LUCENE-7854) - .put("delimitedtermfrequency", Void.class) // LUCENE-8273: ProtectedTermFilterFactory allows analysis chains to skip // particular token filters based on the attributes of the current token. .put("protectedterm", Void.class) From 6765b1654658f3d73911ef3d26b179959d873a55 Mon Sep 17 00:00:00 2001 From: Kunal Kotwani Date: Thu, 31 Aug 2023 17:39:07 -0700 Subject: [PATCH 3/4] Add async blob read and download support using multiple streams (#9592) Signed-off-by: Kunal Kotwani --- CHANGELOG.md | 1 + .../repositories/s3/S3BlobContainer.java | 6 + .../s3/S3BlobStoreContainerTests.java | 12 ++ .../mocks/MockFsVerifyingBlobContainer.java | 24 +++ .../VerifyingMultiStreamBlobContainer.java | 26 +++ .../blobstore/stream/read/ReadContext.java | 46 +++++ .../read/listener/FileCompletionListener.java | 47 +++++ .../stream/read/listener/FilePartWriter.java | 90 ++++++++++ .../read/listener/ReadContextListener.java | 65 +++++++ .../stream/read/listener/package-info.java | 14 ++ .../blobstore/stream/read/package-info.java | 13 ++ .../listener/FileCompletionListenerTests.java | 58 +++++++ .../read/listener/FilePartWriterTests.java | 163 ++++++++++++++++++ .../read/listener/ListenerTestUtils.java | 56 ++++++ .../listener/ReadContextListenerTests.java | 124 +++++++++++++ 15 files changed, 745 insertions(+) create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index d4a060cc16504..80ab0fb87e609 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -97,6 +97,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add concurrent segment search related metrics to node and index stats ([#9622](https://github.com/opensearch-project/OpenSearch/issues/9622)) - Decouple replication lag from logic to fail stale replicas ([#9507](https://github.com/opensearch-project/OpenSearch/pull/9507)) - Expose DelimitedTermFrequencyTokenFilter to allow providing term frequencies along with terms ([#9479](https://github.com/opensearch-project/OpenSearch/pull/9479)) +- APIs for performing async blob reads and async downloads from the repository using multiple streams ([#9592](https://github.com/opensearch-project/OpenSearch/issues/9592)) ### Dependencies - Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307)) diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index a97a509adce47..183b5f8fe7ac1 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -69,6 +69,7 @@ import org.opensearch.common.blobstore.BlobStoreException; import org.opensearch.common.blobstore.DeleteResult; import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.blobstore.support.AbstractBlobContainer; @@ -211,6 +212,11 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp } } + @Override + public void readBlobAsync(String blobName, ActionListener listener) { + throw new UnsupportedOperationException(); + } + // package private for testing long getLargeBlobThresholdInBytes() { return blobStore.bufferSizeInBytes(); diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java index 2438acaf7c1f2..1c4936cae7eba 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java @@ -61,6 +61,7 @@ import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.blobstore.BlobContainer; import org.opensearch.common.blobstore.BlobMetadata; import org.opensearch.common.blobstore.BlobPath; @@ -881,6 +882,17 @@ public void onFailure(Exception e) {} } } + public void testAsyncBlobDownload() { + final S3BlobStore blobStore = mock(S3BlobStore.class); + final BlobPath blobPath = mock(BlobPath.class); + final String blobName = "test-blob"; + + final UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> { + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + blobContainer.readBlobAsync(blobName, new PlainActionFuture<>()); + }); + } + public void testListBlobsByPrefixInLexicographicOrderWithNegativeLimit() throws IOException { testListBlobsByPrefixInLexicographicOrder(-5, 0, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC); } diff --git a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java index d882220c9f4d7..887a4cc6ba9a8 100644 --- a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java +++ b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java @@ -14,6 +14,7 @@ import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer; import org.opensearch.common.blobstore.fs.FsBlobContainer; import org.opensearch.common.blobstore.fs.FsBlobStore; +import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.core.action.ActionListener; @@ -24,6 +25,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -114,6 +117,27 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp } + @Override + public void readBlobAsync(String blobName, ActionListener listener) { + new Thread(() -> { + try { + long contentLength = listBlobs().get(blobName).length(); + long partSize = contentLength / 10; + int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); + List blobPartStreams = new ArrayList<>(); + for (int partNumber = 0; partNumber < numberOfParts; partNumber++) { + long offset = partNumber * partSize; + InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset); + blobPartStreams.add(blobPartStream); + } + ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null); + listener.onResponse(blobReadContext); + } catch (Exception e) { + listener.onFailure(e); + } + }).start(); + } + private boolean isSegmentFile(String filename) { return !filename.endsWith(".tlog") && !filename.endsWith(".ckp"); } diff --git a/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java index d10445ba14d76..1764c9e634781 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java @@ -8,10 +8,15 @@ package org.opensearch.common.blobstore; +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.nio.file.Path; /** * An extension of {@link BlobContainer} that adds {@link VerifyingMultiStreamBlobContainer#asyncBlobUpload} to allow @@ -31,4 +36,25 @@ public interface VerifyingMultiStreamBlobContainer extends BlobContainer { * @throws IOException if any of the input streams could not be read, or the target blob could not be written to */ void asyncBlobUpload(WriteContext writeContext, ActionListener completionListener) throws IOException; + + /** + * Creates an async callback of a {@link ReadContext} containing the multipart streams for a specified blob within the container. + * @param blobName The name of the blob for which the {@link ReadContext} needs to be fetched. + * @param listener Async listener for {@link ReadContext} object which serves the input streams and other metadata for the blob + */ + @ExperimentalApi + void readBlobAsync(String blobName, ActionListener listener); + + /** + * Asynchronously downloads the blob to the specified location using an executor from the thread pool. + * @param blobName The name of the blob for which needs to be downloaded. + * @param fileLocation The path on local disk where the blob needs to be downloaded. + * @param threadPool The threadpool instance which will provide the executor for performing a multipart download. + * @param completionListener Listener which will be notified when the download is complete. + */ + @ExperimentalApi + default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { + ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener); + readBlobAsync(blobName, readContextListener); + } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java new file mode 100644 index 0000000000000..4ba17959f8040 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read; + +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.io.InputStreamContainer; + +import java.util.List; + +/** + * ReadContext is used to encapsulate all data needed by BlobContainer#readBlobAsync + */ +@ExperimentalApi +public class ReadContext { + private final long blobSize; + private final List partStreams; + private final String blobChecksum; + + public ReadContext(long blobSize, List partStreams, String blobChecksum) { + this.blobSize = blobSize; + this.partStreams = partStreams; + this.blobChecksum = blobChecksum; + } + + public String getBlobChecksum() { + return blobChecksum; + } + + public int getNumberOfParts() { + return partStreams.size(); + } + + public long getBlobSize() { + return blobSize; + } + + public List getPartStreams() { + return partStreams; + } +} diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java new file mode 100644 index 0000000000000..aadd6e2ab304e --- /dev/null +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.common.annotation.InternalApi; +import org.opensearch.core.action.ActionListener; + +import java.util.concurrent.atomic.AtomicInteger; + +/** + * FileCompletionListener listens for completion of fetch on all the streams for a file, where + * individual streams are handled using {@link FilePartWriter}. The {@link FilePartWriter}(s) + * hold a reference to the file completion listener to be notified. + */ +@InternalApi +class FileCompletionListener implements ActionListener { + + private final int numberOfParts; + private final String fileName; + private final AtomicInteger completedPartsCount; + private final ActionListener completionListener; + + public FileCompletionListener(int numberOfParts, String fileName, ActionListener completionListener) { + this.completedPartsCount = new AtomicInteger(); + this.numberOfParts = numberOfParts; + this.fileName = fileName; + this.completionListener = completionListener; + } + + @Override + public void onResponse(Integer unused) { + if (completedPartsCount.incrementAndGet() == numberOfParts) { + completionListener.onResponse(fileName); + } + } + + @Override + public void onFailure(Exception e) { + completionListener.onFailure(e); + } +} diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java new file mode 100644 index 0000000000000..84fd7ed9ffebf --- /dev/null +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.annotation.InternalApi; +import org.opensearch.common.io.Channels; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.core.action.ActionListener; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel} + * instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion. + */ +@InternalApi +class FilePartWriter implements Runnable { + + private final int partNumber; + private final InputStreamContainer blobPartStreamContainer; + private final Path fileLocation; + private final AtomicBoolean anyPartStreamFailed; + private final ActionListener fileCompletionListener; + private static final Logger logger = LogManager.getLogger(FilePartWriter.class); + + // 8 MB buffer for transfer + private static final int BUFFER_SIZE = 8 * 1024 * 2024; + + public FilePartWriter( + int partNumber, + InputStreamContainer blobPartStreamContainer, + Path fileLocation, + AtomicBoolean anyPartStreamFailed, + ActionListener fileCompletionListener + ) { + this.partNumber = partNumber; + this.blobPartStreamContainer = blobPartStreamContainer; + this.fileLocation = fileLocation; + this.anyPartStreamFailed = anyPartStreamFailed; + this.fileCompletionListener = fileCompletionListener; + } + + @Override + public void run() { + // Ensures no writes to the file if any stream fails. + if (anyPartStreamFailed.get() == false) { + try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { + try (InputStream inputStream = blobPartStreamContainer.getInputStream()) { + long streamOffset = blobPartStreamContainer.getOffset(); + final byte[] buffer = new byte[BUFFER_SIZE]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset); + streamOffset += bytesRead; + } + } + } catch (IOException e) { + processFailure(e); + return; + } + fileCompletionListener.onResponse(partNumber); + } + } + + void processFailure(Exception e) { + try { + Files.deleteIfExists(fileLocation); + } catch (IOException ex) { + // Die silently + logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex); + } + if (anyPartStreamFailed.getAndSet(true) == false) { + fileCompletionListener.onFailure(e); + } + } +} diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java new file mode 100644 index 0000000000000..4338bddb3fbe7 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.annotation.InternalApi; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; + +import java.nio.file.Path; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer} + * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are + * spread across a {@link ThreadPool} executor. + */ +@InternalApi +public class ReadContextListener implements ActionListener { + + private final String fileName; + private final Path fileLocation; + private final ThreadPool threadPool; + private final ActionListener completionListener; + private static final Logger logger = LogManager.getLogger(ReadContextListener.class); + + public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { + this.fileName = fileName; + this.fileLocation = fileLocation; + this.threadPool = threadPool; + this.completionListener = completionListener; + } + + @Override + public void onResponse(ReadContext readContext) { + logger.trace("Streams received for blob {}", fileName); + final int numParts = readContext.getNumberOfParts(); + final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(); + FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener); + + for (int partNumber = 0; partNumber < numParts; partNumber++) { + FilePartWriter filePartWriter = new FilePartWriter( + partNumber, + readContext.getPartStreams().get(partNumber), + fileLocation, + anyPartStreamFailed, + fileCompletionListener + ); + threadPool.executor(ThreadPool.Names.GENERIC).submit(filePartWriter); + } + } + + @Override + public void onFailure(Exception e) { + completionListener.onFailure(e); + } +} diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java new file mode 100644 index 0000000000000..fe670fe3eb25c --- /dev/null +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java @@ -0,0 +1,14 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Abstractions for stream based file reads from the blob store. + * Provides listeners for performing the necessary async read operations to perform + * multi stream reads for blobs from the container. + * */ +package org.opensearch.common.blobstore.stream.read.listener; diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java new file mode 100644 index 0000000000000..a9e2ca35c1fa6 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java @@ -0,0 +1,13 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Abstractions for stream based file reads from the blob store. + * Provides support for async reads from the blob container. + * */ +package org.opensearch.common.blobstore.stream.read; diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java new file mode 100644 index 0000000000000..fa13d90f42fa6 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; + +public class FileCompletionListenerTests extends OpenSearchTestCase { + + public void testFileCompletionListener() { + int numStreams = 10; + String fileName = "test_segment_file"; + CountingCompletionListener completionListener = new CountingCompletionListener(); + FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); + + for (int stream = 0; stream < numStreams; stream++) { + // Ensure completion listener called only when all streams are completed + assertEquals(0, completionListener.getResponseCount()); + fileCompletionListener.onResponse(null); + } + + assertEquals(1, completionListener.getResponseCount()); + assertEquals(fileName, completionListener.getResponse()); + } + + public void testFileCompletionListenerFailure() { + int numStreams = 10; + String fileName = "test_segment_file"; + CountingCompletionListener completionListener = new CountingCompletionListener(); + FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); + + // Fail the listener initially + IOException exception = new IOException(); + fileCompletionListener.onFailure(exception); + + for (int stream = 0; stream < numStreams - 1; stream++) { + assertEquals(0, completionListener.getResponseCount()); + fileCompletionListener.onResponse(null); + } + + assertEquals(1, completionListener.getFailureCount()); + assertEquals(exception, completionListener.getException()); + assertEquals(0, completionListener.getResponseCount()); + + fileCompletionListener.onFailure(exception); + assertEquals(2, completionListener.getFailureCount()); + assertEquals(exception, completionListener.getException()); + } +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java new file mode 100644 index 0000000000000..811566eb5767b --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java @@ -0,0 +1,163 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.Before; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; + +public class FilePartWriterTests extends OpenSearchTestCase { + + private Path path; + + @Before + public void init() throws Exception { + path = createTempDir("FilePartWriterTests"); + } + + public void testFilePartWriter() throws Exception { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 100; + int partNumber = 1; + InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); + + FilePartWriter filePartWriter = new FilePartWriter( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + filePartWriter.run(); + + assertTrue(Files.exists(segmentFilePath)); + assertEquals(contentLength, Files.size(segmentFilePath)); + assertEquals(1, fileCompletionListener.getResponseCount()); + assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); + } + + public void testFilePartWriterWithOffset() throws Exception { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 100; + int offset = 10; + int partNumber = 1; + InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), offset); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); + + FilePartWriter filePartWriter = new FilePartWriter( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + filePartWriter.run(); + + assertTrue(Files.exists(segmentFilePath)); + assertEquals(contentLength + offset, Files.size(segmentFilePath)); + assertEquals(1, fileCompletionListener.getResponseCount()); + assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); + } + + public void testFilePartWriterLargeInput() throws Exception { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 20 * 1024 * 1024; + int partNumber = 1; + InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); + + FilePartWriter filePartWriter = new FilePartWriter( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + filePartWriter.run(); + + assertTrue(Files.exists(segmentFilePath)); + assertEquals(contentLength, Files.size(segmentFilePath)); + + assertEquals(1, fileCompletionListener.getResponseCount()); + assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); + } + + public void testFilePartWriterException() throws Exception { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 100; + int partNumber = 1; + InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); + + IOException ioException = new IOException(); + FilePartWriter filePartWriter = new FilePartWriter( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + assertFalse(anyStreamFailed.get()); + filePartWriter.processFailure(ioException); + + assertTrue(anyStreamFailed.get()); + assertFalse(Files.exists(segmentFilePath)); + + // Fail stream again to simulate another stream failure for same file + filePartWriter.processFailure(ioException); + + assertTrue(anyStreamFailed.get()); + assertFalse(Files.exists(segmentFilePath)); + + assertEquals(0, fileCompletionListener.getResponseCount()); + assertEquals(1, fileCompletionListener.getFailureCount()); + assertEquals(ioException, fileCompletionListener.getException()); + } + + public void testFilePartWriterStreamFailed() throws Exception { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 100; + int partNumber = 1; + InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(true); + CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); + + FilePartWriter filePartWriter = new FilePartWriter( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + filePartWriter.run(); + + assertFalse(Files.exists(segmentFilePath)); + assertEquals(0, fileCompletionListener.getResponseCount()); + } +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java new file mode 100644 index 0000000000000..1e9450c83e3ab --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.core.action.ActionListener; + +/** + * Utility class containing common functionality for read listener based tests + */ +public class ListenerTestUtils { + + /** + * CountingCompletionListener acts as a verification instance for wrapping listener based calls. + * Keeps track of the last response, failure and count of response and failure invocations. + */ + static class CountingCompletionListener implements ActionListener { + private int responseCount; + private int failureCount; + private T response; + private Exception exception; + + @Override + public void onResponse(T response) { + this.response = response; + responseCount++; + } + + @Override + public void onFailure(Exception e) { + exception = e; + failureCount++; + } + + public int getResponseCount() { + return responseCount; + } + + public int getFailureCount() { + return failureCount; + } + + public T getResponse() { + return response; + } + + public Exception getException() { + return exception; + } + } +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java new file mode 100644 index 0000000000000..f785b5f1191b4 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; + +import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; + +public class ReadContextListenerTests extends OpenSearchTestCase { + + private Path path; + private static ThreadPool threadPool; + private static final int NUMBER_OF_PARTS = 5; + private static final int PART_SIZE = 10; + private static final String TEST_SEGMENT_FILE = "test_segment_file"; + + @BeforeClass + public static void setup() { + threadPool = new TestThreadPool(ReadContextListenerTests.class.getName()); + } + + @AfterClass + public static void cleanup() { + threadPool.shutdown(); + } + + @Before + public void init() throws Exception { + path = createTempDir("ReadContextListenerTests"); + } + + public void testReadContextListener() throws InterruptedException, IOException { + Path fileLocation = path.resolve(UUID.randomUUID().toString()); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContext readContext = new ReadContext((long) PART_SIZE * NUMBER_OF_PARTS, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + + assertTrue(Files.exists(fileLocation)); + assertEquals(NUMBER_OF_PARTS * PART_SIZE, Files.size(fileLocation)); + } + + public void testReadContextListenerFailure() throws InterruptedException { + Path fileLocation = path.resolve(UUID.randomUUID().toString()); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + InputStream badInputStream = new InputStream() { + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return read(); + } + + @Override + public int read() throws IOException { + throw new IOException(); + } + + @Override + public int available() { + return PART_SIZE; + } + }; + + blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS)); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + + assertFalse(Files.exists(fileLocation)); + } + + public void testReadContextListenerException() { + Path fileLocation = path.resolve(UUID.randomUUID().toString()); + CountingCompletionListener listener = new CountingCompletionListener(); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, listener); + IOException exception = new IOException(); + readContextListener.onFailure(exception); + assertEquals(1, listener.getFailureCount()); + assertEquals(exception, listener.getException()); + } + + private List initializeBlobPartStreams() { + List blobPartStreams = new ArrayList<>(); + for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) { + InputStream testStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)); + blobPartStreams.add(new InputStreamContainer(testStream, PART_SIZE, (long) partNumber * PART_SIZE)); + } + return blobPartStreams; + } +} From 04c90c7fed76e8f107bc5944eb973fa4c2034b67 Mon Sep 17 00:00:00 2001 From: Suraj Singh Date: Thu, 31 Aug 2023 19:56:42 -0700 Subject: [PATCH 4/4] Mute RemoteIndexShardTests primary promotion flaky tests (#9679) --- .../org/opensearch/index/shard/RemoteIndexShardTests.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/src/test/java/org/opensearch/index/shard/RemoteIndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/RemoteIndexShardTests.java index 9dcecbe1059b6..33159b85ec640 100644 --- a/server/src/test/java/org/opensearch/index/shard/RemoteIndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/RemoteIndexShardTests.java @@ -64,18 +64,22 @@ protected ReplicationGroup getReplicationGroup(int numberOfReplicas) throws IOEx return createGroup(numberOfReplicas, settings, indexMapping, new NRTReplicationEngineFactory(), createTempDir()); } + @AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/9624") public void testNRTReplicaWithRemoteStorePromotedAsPrimaryRefreshRefresh() throws Exception { testNRTReplicaWithRemoteStorePromotedAsPrimary(false, false); } + @AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/9624") public void testNRTReplicaWithRemoteStorePromotedAsPrimaryRefreshCommit() throws Exception { testNRTReplicaWithRemoteStorePromotedAsPrimary(false, true); } + @AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/9624") public void testNRTReplicaWithRemoteStorePromotedAsPrimaryCommitRefresh() throws Exception { testNRTReplicaWithRemoteStorePromotedAsPrimary(true, false); } + @AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/9624") public void testNRTReplicaWithRemoteStorePromotedAsPrimaryCommitCommit() throws Exception { testNRTReplicaWithRemoteStorePromotedAsPrimary(true, true); }