Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to configure ArtifactService gRPC Channel #25151

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.apache.beam.runners.core.construction;

import java.util.Iterator;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi;

public interface ArtifactServiceClient extends AutoCloseable {
ArtifactApi.ResolveArtifactsResponse resolveArtifacts(ArtifactApi.ResolveArtifactsRequest request);
Iterator<ArtifactApi.GetArtifactResponse> getArtifact(ArtifactApi.GetArtifactRequest request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,53 @@
*/
package org.apache.beam.runners.core.construction;

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.apache.beam.model.expansion.v1.ExpansionApi;
import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.vendor.grpc.v1p48p1.io.grpc.ManagedChannel;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
import org.apache.beam.model.jobmanagement.v1.ArtifactRetrievalServiceGrpc;

/** Default factory for ExpansionServiceClient used by External transform. */
public class DefaultExpansionServiceClientFactory implements ExpansionServiceClientFactory {
private Map<Endpoints.ApiServiceDescriptor, ExpansionServiceClient> expansionServiceMap;
private Function<Endpoints.ApiServiceDescriptor, ManagedChannel> channelFactory;
private Map<Endpoints.ApiServiceDescriptor, ArtifactServiceClient> artifactServiceMap;
private Function<Endpoints.ApiServiceDescriptor, ManagedChannel> expansionChannelFactory;
private Function<Endpoints.ApiServiceDescriptor, ManagedChannel> artifactChannelFactory;


private DefaultExpansionServiceClientFactory(
Function<Endpoints.ApiServiceDescriptor, ManagedChannel> channelFactory) {
Function<Endpoints.ApiServiceDescriptor, ManagedChannel> expansionChannelFactory,
Function<Endpoints.ApiServiceDescriptor, ManagedChannel> artifactChannelFactory) {
this.expansionServiceMap = new ConcurrentHashMap<>();
this.channelFactory = channelFactory;
this.artifactServiceMap = new ConcurrentHashMap<>();
this.expansionChannelFactory = expansionChannelFactory;
this.artifactChannelFactory = artifactChannelFactory;
}

public static DefaultExpansionServiceClientFactory create(
Function<Endpoints.ApiServiceDescriptor, ManagedChannel> channelFactory) {
return new DefaultExpansionServiceClientFactory(channelFactory);
return new DefaultExpansionServiceClientFactory(channelFactory, channelFactory);
}

public static DefaultExpansionServiceClientFactory create(
Function<Endpoints.ApiServiceDescriptor, ManagedChannel> expansionChannelFactory,
Function<Endpoints.ApiServiceDescriptor, ManagedChannel> artifactChannelFactory) {
return new DefaultExpansionServiceClientFactory(expansionChannelFactory, artifactChannelFactory);
}

@Override
public void close() throws Exception {
for (ExpansionServiceClient client : expansionServiceMap.values()) {
try (AutoCloseable closer = client) {}
}
for (ArtifactServiceClient client : artifactServiceMap.values()) {
try (AutoCloseable closer = client) {}
}
}

@Override
Expand All @@ -54,7 +72,7 @@ public ExpansionServiceClient getExpansionServiceClient(Endpoints.ApiServiceDesc
endpoint,
e ->
new ExpansionServiceClient() {
private final ManagedChannel channel = channelFactory.apply(endpoint);
private final ManagedChannel channel = expansionChannelFactory.apply(endpoint);
private final ExpansionServiceGrpc.ExpansionServiceBlockingStub service =
ExpansionServiceGrpc.newBlockingStub(channel);

Expand All @@ -69,4 +87,31 @@ public void close() throws Exception {
}
});
}

@Override
public ArtifactServiceClient getArtifactServiceClient(Endpoints.ApiServiceDescriptor endpoint) {
return artifactServiceMap.computeIfAbsent(
endpoint,
e ->
new ArtifactServiceClient() {
private final ManagedChannel channel = artifactChannelFactory.apply(endpoint);
private final ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceBlockingStub service =
ArtifactRetrievalServiceGrpc.newBlockingStub(channel);

@Override
public ArtifactApi.ResolveArtifactsResponse resolveArtifacts(ArtifactApi.ResolveArtifactsRequest request) {
return service.resolveArtifacts(request);
}

@Override
public Iterator<ArtifactApi.GetArtifactResponse> getArtifact(ArtifactApi.GetArtifactRequest request) {
return service.getArtifact(request);
}

@Override
public void close() throws Exception {
channel.shutdown();
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@
*/
public interface ExpansionServiceClientFactory extends AutoCloseable {
ExpansionServiceClient getExpansionServiceClient(Endpoints.ApiServiceDescriptor endpoint);

ArtifactServiceClient getArtifactServiceClient(Endpoints.ApiServiceDescriptor endpoint);
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public class External {

private static final ExpansionServiceClientFactory DEFAULT =
DefaultExpansionServiceClientFactory.create(
endPoint -> ManagedChannelBuilder.forTarget(endPoint.getUrl()).usePlaintext().build());
endPoint -> ManagedChannelBuilder.forTarget(endPoint.getUrl()).usePlaintext().maxInboundMessageSize(Integer.MAX_VALUE).build());

private static int getFreshNamespaceIndex() {
return namespaceCounter.getAndIncrement();
Expand Down Expand Up @@ -346,48 +346,38 @@ private Map<String, RunnerApi.Environment> resolveArtifacts(
if (environments.size() == 0) {
return environments;
}
ManagedChannel channel =
ManagedChannelBuilder.forTarget(endpoint.getUrl())
.usePlaintext()
.maxInboundMessageSize(Integer.MAX_VALUE)
.build();
try {
ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceBlockingStub retrievalStub =
ArtifactRetrievalServiceGrpc.newBlockingStub(channel);
return environments.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
kv -> {
try {
return resolveArtifacts(retrievalStub, kv.getValue());
} catch (IOException e) {
throw new RuntimeException(e);
}
}));
} finally {
channel.shutdown();
}

return environments.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
kv -> {
try {
return resolveArtifacts(clientFactory.getArtifactServiceClient(endpoint), kv.getValue());
} catch (IOException e) {
throw new RuntimeException(e);
}
}));
}

private RunnerApi.Environment resolveArtifacts(
ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceBlockingStub retrievalStub,
ArtifactServiceClient artifactServiceClient,
RunnerApi.Environment environment)
throws IOException {
return environment
.toBuilder()
.clearDependencies()
.addAllDependencies(resolveArtifacts(retrievalStub, environment.getDependenciesList()))
.addAllDependencies(resolveArtifacts(artifactServiceClient, environment.getDependenciesList()))
.build();
}

private List<RunnerApi.ArtifactInformation> resolveArtifacts(
ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceBlockingStub retrievalStub,
ArtifactServiceClient artifactServiceClient,
List<RunnerApi.ArtifactInformation> artifacts)
throws IOException {
List<RunnerApi.ArtifactInformation> resolved = new ArrayList<>();
for (RunnerApi.ArtifactInformation artifact :
retrievalStub
artifactServiceClient
.resolveArtifacts(
ArtifactApi.ResolveArtifactsRequest.newBuilder()
.addAllArtifacts(artifacts)
Expand All @@ -396,7 +386,7 @@ private List<RunnerApi.ArtifactInformation> resolveArtifacts(
Path path = Files.createTempFile("beam-artifact", "");
try (FileOutputStream fout = new FileOutputStream(path.toFile())) {
for (Iterator<ArtifactApi.GetArtifactResponse> it =
retrievalStub.getArtifact(
artifactServiceClient.getArtifact(
ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact).build());
it.hasNext(); ) {
it.next().getData().writeTo(fout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,21 @@
import static org.hamcrest.Matchers.hasItems;

import org.apache.beam.model.expansion.v1.ExpansionApi;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.util.Iterator;

/** Tests for {@link org.apache.beam.runners.core.construction.ExternalTranslation}. */
@RunWith(JUnit4.class)
public class ExternalTranslationTest {
Expand All @@ -56,6 +60,7 @@ public void testTranslation() {
.getTransformsMap()
.keySet()
.toArray(new String[0])));

}

static class TestExpansionServiceClientFactory implements ExpansionServiceClientFactory {
Expand Down Expand Up @@ -97,6 +102,26 @@ public void close() throws Exception {
};
}

@Override
public ArtifactServiceClient getArtifactServiceClient(Endpoints.ApiServiceDescriptor endpoint) {
return new ArtifactServiceClient() {
@Override
public ArtifactApi.ResolveArtifactsResponse resolveArtifacts(ArtifactApi.ResolveArtifactsRequest request) {
return ArtifactApi.ResolveArtifactsResponse.getDefaultInstance();
}

@Override
public Iterator<ArtifactApi.GetArtifactResponse> getArtifact(ArtifactApi.GetArtifactRequest request) {
return ImmutableList.of(ArtifactApi.GetArtifactResponse.newBuilder().setData(ByteString.EMPTY).build()).iterator();
}

@Override
public void close() throws Exception {

}
};
}

@Override
public void close() throws Exception {
// do nothing
Expand Down