diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index 93d151f3e058..23f4a024569b 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -26,11 +26,13 @@ description = "Apache Beam :: Runners :: Prism :: Java" ext.summary = "Support for executing a pipeline on Prism." dependencies { + implementation project(path: ":model:job-management", configuration: "shadow") implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(":runners:portability:java") implementation library.java.joda_time implementation library.java.slf4j_api + implementation library.java.vendored_grpc_1_60_1 implementation library.java.vendored_guava_32_1_2_jre testImplementation library.java.junit diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/StateListener.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/StateListener.java new file mode 100644 index 000000000000..89f537e4f812 --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/StateListener.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import org.apache.beam.sdk.PipelineResult; + +/** Listens for {@link PipelineResult.State} changes reported by the {@link StateWatcher}. */ +interface StateListener { + + /** Callback invoked when {@link StateWatcher} discovers a {@link PipelineResult.State} change. */ + void onStateChanged(PipelineResult.State state); +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/StateWatcher.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/StateWatcher.java new file mode 100644 index 000000000000..fe9eb84a72b5 --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/StateWatcher.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import org.apache.beam.model.jobmanagement.v1.JobApi; +import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ChannelCredentials; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.InsecureChannelCredentials; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.netty.NettyChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; + +/** + * {@link StateWatcher} {@link #watch}es for and reports {@link PipelineResult.State} changes to + * {@link StateListener}s. + */ +@AutoValue +abstract class StateWatcher implements AutoCloseable { + + private Optional latestState = Optional.empty(); + + /** + * Instantiates a {@link StateWatcher} with {@link InsecureChannelCredentials}. {@link + * StateWatcher} will report to each {@link StateListener} of {@param listeners} of any changed + * {@link PipelineResult.State}. + */ + static StateWatcher insecure(String endpoint, StateListener... listeners) { + return StateWatcher.builder() + .setEndpoint(HostAndPort.fromString(endpoint)) + .setCredentials(InsecureChannelCredentials.create()) + .setListeners(Arrays.asList(listeners)) + .build(); + } + + /** + * Watch for a Job's {@link PipelineResult.State} change. A {@link + * org.apache.beam.model.jobmanagement.v1.JobApi.GetJobStateRequest} identifies a Job to watch via + * its {@link JobApi.GetJobStateRequest#getJobId()}. The method is blocking until the {@link + * JobApi.JobStateEvent} {@link StreamObserver#onCompleted()}. + */ + void watch(String jobId) { + JobApi.GetJobStateRequest request = + JobApi.GetJobStateRequest.newBuilder().setJobId(jobId).build(); + Iterator iterator = getJobServiceBlockingStub().getStateStream(request); + while (iterator.hasNext()) { + JobApi.JobStateEvent event = iterator.next(); + PipelineResult.State state = PipelineResult.State.valueOf(event.getState().name()); + publish(state); + } + } + + private void publish(PipelineResult.State state) { + if (latestState.isPresent() && latestState.get().equals(state)) { + return; + } + latestState = Optional.of(state); + for (StateListener listener : getListeners()) { + listener.onStateChanged(state); + } + } + + static Builder builder() { + return new AutoValue_StateWatcher.Builder(); + } + + abstract HostAndPort getEndpoint(); + + abstract ChannelCredentials getCredentials(); + + abstract List getListeners(); + + abstract ManagedChannel getManagedChannel(); + + abstract JobServiceGrpc.JobServiceBlockingStub getJobServiceBlockingStub(); + + @Override + public void close() { + getManagedChannel().shutdown(); + try { + getManagedChannel().awaitTermination(3000L, TimeUnit.MILLISECONDS); + } catch (InterruptedException ignored) { + } + } + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setEndpoint(HostAndPort endpoint); + + abstract Optional getEndpoint(); + + abstract Builder setCredentials(ChannelCredentials credentials); + + abstract Optional getCredentials(); + + abstract Builder setListeners(List listeners); + + abstract Builder setManagedChannel(ManagedChannel managedChannel); + + abstract Builder setJobServiceBlockingStub( + JobServiceGrpc.JobServiceBlockingStub jobServiceBlockingStub); + + abstract StateWatcher autoBuild(); + + final StateWatcher build() { + if (!getEndpoint().isPresent()) { + throw new IllegalStateException("missing endpoint"); + } + if (!getCredentials().isPresent()) { + throw new IllegalStateException("missing credentials"); + } + HostAndPort endpoint = getEndpoint().get(); + ManagedChannel channel = + NettyChannelBuilder.forAddress( + endpoint.getHost(), endpoint.getPort(), getCredentials().get()) + .build(); + setManagedChannel(channel); + setJobServiceBlockingStub(JobServiceGrpc.newBlockingStub(channel)); + + return autoBuild(); + } + } +} diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/StateWatcherTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/StateWatcherTest.java new file mode 100644 index 000000000000..cfc420046206 --- /dev/null +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/StateWatcherTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.model.jobmanagement.v1.JobApi; +import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Grpc; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.InsecureServerCredentials; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class StateWatcherTest { + + @Test + public void givenSingleListener_watches() { + Server server = serverOf(PipelineResult.State.RUNNING, PipelineResult.State.DONE); + TestStateListener listener = new TestStateListener(); + try (StateWatcher underTest = StateWatcher.insecure("0.0.0.0:" + server.getPort(), listener)) { + underTest.watch("job-001"); + assertThat(listener.states) + .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE); + shutdown(server); + } + } + + @Test + public void givenMultipleListeners_watches() { + Server server = serverOf(PipelineResult.State.RUNNING, PipelineResult.State.DONE); + TestStateListener listenerA = new TestStateListener(); + TestStateListener listenerB = new TestStateListener(); + try (StateWatcher underTest = + StateWatcher.insecure("0.0.0.0:" + server.getPort(), listenerA, listenerB)) { + underTest.watch("job-001"); + assertThat(listenerA.states) + .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE); + assertThat(listenerB.states) + .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE); + shutdown(server); + } + } + + @Test + public void publishesOnlyChangedState() { + Server server = + serverOf( + PipelineResult.State.RUNNING, + PipelineResult.State.RUNNING, + PipelineResult.State.RUNNING, + PipelineResult.State.RUNNING, + PipelineResult.State.RUNNING, + PipelineResult.State.RUNNING, + PipelineResult.State.RUNNING, + PipelineResult.State.DONE); + TestStateListener listener = new TestStateListener(); + try (StateWatcher underTest = StateWatcher.insecure("0.0.0.0:" + server.getPort(), listener)) { + underTest.watch("job-001"); + assertThat(listener.states) + .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE); + shutdown(server); + } + } + + private static class TestStateListener implements StateListener { + private final List states = new ArrayList<>(); + + @Override + public void onStateChanged(PipelineResult.State state) { + states.add(state); + } + } + + private static class TestJobServiceStateStream extends JobServiceGrpc.JobServiceImplBase { + private final List states; + + TestJobServiceStateStream(PipelineResult.State... states) { + this.states = Arrays.asList(states); + } + + @Override + public void getStateStream( + JobApi.GetJobStateRequest request, StreamObserver responseObserver) { + for (PipelineResult.State state : states) { + responseObserver.onNext( + JobApi.JobStateEvent.newBuilder() + .setState(JobApi.JobState.Enum.valueOf(state.name())) + .build()); + } + responseObserver.onCompleted(); + } + } + + private static Server serverOf(PipelineResult.State... states) { + try { + return Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(new TestJobServiceStateStream(states)) + .build() + .start(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void shutdown(Server server) { + server.shutdownNow(); + try { + server.awaitTermination(); + } catch (InterruptedException ignored) { + } + } +}