Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StateWatcher watches and reports changed Pipeline State #32040

Merged
merged 2 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ description = "Apache Beam :: Runners :: Prism :: Java"
ext.summary = "Support for executing a pipeline on Prism."

dependencies {
implementation project(path: ":model:job-management", configuration: "shadow")
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation project(":runners:portability:java")

implementation library.java.joda_time
implementation library.java.slf4j_api
implementation library.java.vendored_grpc_1_60_1
implementation library.java.vendored_guava_32_1_2_jre

testImplementation library.java.junit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.PipelineResult;

/**
 * Listens for {@link PipelineResult.State} changes reported by the {@link StateWatcher}.
 *
 * <p>The interface declares a single abstract method, so it may be implemented with a lambda or
 * method reference.
 */
@FunctionalInterface
interface StateListener {

  /**
   * Callback invoked when {@link StateWatcher} discovers a {@link PipelineResult.State} change.
   *
   * <p>{@link StateWatcher} calls this synchronously from within {@link StateWatcher#watch}, and
   * only when the observed state differs from the one it reported last.
   *
   * @param state the newly observed state
   */
  void onStateChanged(PipelineResult.State state);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ChannelCredentials;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.InsecureChannelCredentials;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.netty.NettyChannelBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;

/**
* {@link StateWatcher} {@link #watch}es for and reports {@link PipelineResult.State} changes to
* {@link StateListener}s.
*/
@AutoValue
abstract class StateWatcher implements AutoCloseable {

private Optional<PipelineResult.State> latestState = Optional.empty();

/**
* Instantiates a {@link StateWatcher} with {@link InsecureChannelCredentials}. {@link
* StateWatcher} will report to each {@link StateListener} of {@param listeners} of any changed
* {@link PipelineResult.State}.
*/
static StateWatcher insecure(String endpoint, StateListener... listeners) {
return StateWatcher.builder()
.setEndpoint(HostAndPort.fromString(endpoint))
.setCredentials(InsecureChannelCredentials.create())
.setListeners(Arrays.asList(listeners))
.build();
}

/**
* Watch for a Job's {@link PipelineResult.State} change. A {@link
* org.apache.beam.model.jobmanagement.v1.JobApi.GetJobStateRequest} identifies a Job to watch via
* its {@link JobApi.GetJobStateRequest#getJobId()}. The method is blocking until the {@link
* JobApi.JobStateEvent} {@link StreamObserver#onCompleted()}.
*/
void watch(String jobId) {
JobApi.GetJobStateRequest request =
JobApi.GetJobStateRequest.newBuilder().setJobId(jobId).build();
Iterator<JobApi.JobStateEvent> iterator = getJobServiceBlockingStub().getStateStream(request);
while (iterator.hasNext()) {
JobApi.JobStateEvent event = iterator.next();
PipelineResult.State state = PipelineResult.State.valueOf(event.getState().name());
publish(state);
}
}

private void publish(PipelineResult.State state) {
if (latestState.isPresent() && latestState.get().equals(state)) {
return;
}
latestState = Optional.of(state);
for (StateListener listener : getListeners()) {
listener.onStateChanged(state);
}
}

static Builder builder() {
return new AutoValue_StateWatcher.Builder();
}

abstract HostAndPort getEndpoint();

abstract ChannelCredentials getCredentials();

abstract List<StateListener> getListeners();

abstract ManagedChannel getManagedChannel();

abstract JobServiceGrpc.JobServiceBlockingStub getJobServiceBlockingStub();

@Override
public void close() {
getManagedChannel().shutdown();
try {
getManagedChannel().awaitTermination(3000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException ignored) {
}
}

@AutoValue.Builder
abstract static class Builder {

abstract Builder setEndpoint(HostAndPort endpoint);

abstract Optional<HostAndPort> getEndpoint();

abstract Builder setCredentials(ChannelCredentials credentials);

abstract Optional<ChannelCredentials> getCredentials();

abstract Builder setListeners(List<StateListener> listeners);

abstract Builder setManagedChannel(ManagedChannel managedChannel);

abstract Builder setJobServiceBlockingStub(
JobServiceGrpc.JobServiceBlockingStub jobServiceBlockingStub);

abstract StateWatcher autoBuild();

final StateWatcher build() {
if (!getEndpoint().isPresent()) {
throw new IllegalStateException("missing endpoint");
}
if (!getCredentials().isPresent()) {
throw new IllegalStateException("missing credentials");
}
HostAndPort endpoint = getEndpoint().get();
ManagedChannel channel =
NettyChannelBuilder.forAddress(
endpoint.getHost(), endpoint.getPort(), getCredentials().get())
.build();
setManagedChannel(channel);
setJobServiceBlockingStub(JobServiceGrpc.newBlockingStub(channel));

return autoBuild();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import static com.google.common.truth.Truth.assertThat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Grpc;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.InsecureServerCredentials;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Tests for {@link StateWatcher} against an in-process gRPC server that replays a fixed sequence
 * of Job states.
 */
@RunWith(JUnit4.class)
public class StateWatcherTest {

  @Test
  public void givenSingleListener_watches() {
    Server server = serverOf(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
    TestStateListener listener = new TestStateListener();
    try (StateWatcher underTest = StateWatcher.insecure("0.0.0.0:" + server.getPort(), listener)) {
      underTest.watch("job-001");
      assertThat(listener.states)
          .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE)
          .inOrder();
    } finally {
      // Runs even when an assertion fails, so the server is never left behind.
      shutdown(server);
    }
  }

  @Test
  public void givenMultipleListeners_watches() {
    Server server = serverOf(PipelineResult.State.RUNNING, PipelineResult.State.DONE);
    TestStateListener listenerA = new TestStateListener();
    TestStateListener listenerB = new TestStateListener();
    try (StateWatcher underTest =
        StateWatcher.insecure("0.0.0.0:" + server.getPort(), listenerA, listenerB)) {
      underTest.watch("job-001");
      assertThat(listenerA.states)
          .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE)
          .inOrder();
      assertThat(listenerB.states)
          .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE)
          .inOrder();
    } finally {
      shutdown(server);
    }
  }

  @Test
  public void publishesOnlyChangedState() {
    Server server =
        serverOf(
            PipelineResult.State.RUNNING,
            PipelineResult.State.RUNNING,
            PipelineResult.State.RUNNING,
            PipelineResult.State.RUNNING,
            PipelineResult.State.RUNNING,
            PipelineResult.State.RUNNING,
            PipelineResult.State.RUNNING,
            PipelineResult.State.DONE);
    TestStateListener listener = new TestStateListener();
    try (StateWatcher underTest = StateWatcher.insecure("0.0.0.0:" + server.getPort(), listener)) {
      underTest.watch("job-001");
      assertThat(listener.states)
          .containsExactly(PipelineResult.State.RUNNING, PipelineResult.State.DONE)
          .inOrder();
    } finally {
      shutdown(server);
    }
  }

  /** Records every state it is notified of, in arrival order. */
  private static class TestStateListener implements StateListener {
    private final List<PipelineResult.State> states = new ArrayList<>();

    @Override
    public void onStateChanged(PipelineResult.State state) {
      states.add(state);
    }
  }

  /** Job service fake whose state stream emits the configured states and then completes. */
  private static class TestJobServiceStateStream extends JobServiceGrpc.JobServiceImplBase {
    private final List<PipelineResult.State> states;

    TestJobServiceStateStream(PipelineResult.State... states) {
      this.states = Arrays.asList(states);
    }

    @Override
    public void getStateStream(
        JobApi.GetJobStateRequest request, StreamObserver<JobApi.JobStateEvent> responseObserver) {
      for (PipelineResult.State state : states) {
        responseObserver.onNext(
            JobApi.JobStateEvent.newBuilder()
                .setState(JobApi.JobState.Enum.valueOf(state.name()))
                .build());
      }
      responseObserver.onCompleted();
    }
  }

  /** Starts a server on an ephemeral port (port 0) that streams {@code states} to any caller. */
  private static Server serverOf(PipelineResult.State... states) {
    try {
      return Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())
          .addService(new TestJobServiceStateStream(states))
          .build()
          .start();
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /** Stops {@code server} immediately and waits for it to terminate. */
  private static void shutdown(Server server) {
    server.shutdownNow();
    try {
      server.awaitTermination();
    } catch (InterruptedException e) {
      // Keep the interrupt visible to the test runner instead of discarding it.
      Thread.currentThread().interrupt();
    }
  }
}
Loading