Skip to content

Commit

Permalink
[CELEBORN-1700] Flink supports fallback to vanilla Flink built-in shu…
Browse files Browse the repository at this point in the history
…ffle implementation
  • Loading branch information
SteNicholas committed Nov 27, 2024
1 parent 903c13a commit 6f68761
Show file tree
Hide file tree
Showing 21 changed files with 770 additions and 511 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,20 @@
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
import org.apache.flink.runtime.io.network.NettyShuffleServiceFactory;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider;
import org.apache.flink.runtime.io.network.partition.ResultPartitionFactory;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateFactory;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
import org.apache.flink.runtime.shuffle.ShuffleIOOwnerContext;

import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.reflect.DynFields;
import org.apache.celeborn.plugin.flink.netty.NettyShuffleEnvironmentWrapper;

/**
* The implementation of {@link ShuffleEnvironment} based on the remote shuffle service, providing
Expand All @@ -57,56 +53,32 @@ public class RemoteShuffleEnvironment extends AbstractRemoteShuffleEnvironment

private final RemoteShuffleInputGateFactory inputGateFactory;

private final NettyShuffleServiceFactory nettyShuffleServiceFactory;

private final ShuffleEnvironmentContext shuffleEnvironmentContext;
private final NettyShuffleEnvironmentWrapper shuffleEnvironmentWrapper;

private final ConcurrentHashMap.KeySetView<IntermediateDataSetID, Boolean> nettyResultIds =
ConcurrentHashMap.newKeySet();

private final ConcurrentHashMap.KeySetView<IntermediateResultPartitionID, Boolean>
nettyResultPartitionIds = ConcurrentHashMap.newKeySet();

private volatile NettyShuffleEnvironment nettyShuffleEnvironment;

private volatile ResultPartitionFactory nettyResultPartitionFactory;

private volatile SingleInputGateFactory nettyInputGateFactory;

private static final DynFields.UnboundField<ResultPartitionFactory>
RESULT_PARTITION_FACTORY_FIELD =
DynFields.builder()
.hiddenImpl(NettyShuffleEnvironment.class, "resultPartitionFactory")
.defaultAlwaysNull()
.build();

private static final DynFields.UnboundField<SingleInputGateFactory> INPUT_GATE_FACTORY_FIELD =
DynFields.builder()
.hiddenImpl(NettyShuffleEnvironment.class, "singleInputGateFactory")
.defaultAlwaysNull()
.build();

/**
* @param networkBufferPool Network buffer pool for shuffle read and shuffle write.
* @param resultPartitionManager A trivial {@link ResultPartitionManager}.
* @param resultPartitionFactory Factory class to create {@link RemoteShuffleResultPartition}.
* @param inputGateFactory Factory class to create {@link RemoteShuffleInputGate}.
* @param nettyShuffleServiceFactory Factory class to create {@link NettyShuffleEnvironment}.
* @param shuffleEnvironmentContext Environment context of shuffle.
* @param shuffleEnvironmentWrapper Wrapper class to create {@link NettyShuffleEnvironment}.
*/
public RemoteShuffleEnvironment(
NetworkBufferPool networkBufferPool,
ResultPartitionManager resultPartitionManager,
RemoteShuffleResultPartitionFactory resultPartitionFactory,
RemoteShuffleInputGateFactory inputGateFactory,
CelebornConf conf,
NettyShuffleServiceFactory nettyShuffleServiceFactory,
ShuffleEnvironmentContext shuffleEnvironmentContext) {
NettyShuffleEnvironmentWrapper shuffleEnvironmentWrapper) {
super(networkBufferPool, resultPartitionManager, conf);
this.resultPartitionFactory = resultPartitionFactory;
this.inputGateFactory = inputGateFactory;
this.nettyShuffleServiceFactory = nettyShuffleServiceFactory;
this.shuffleEnvironmentContext = shuffleEnvironmentContext;
this.shuffleEnvironmentWrapper = shuffleEnvironmentWrapper;
}

@Override
Expand All @@ -122,7 +94,8 @@ public ResultPartitionWriter createResultPartitionWriterInternal(
} else {
nettyResultIds.add(resultPartitionDeploymentDescriptor.getResultId());
nettyResultPartitionIds.add(resultPartitionDeploymentDescriptor.getPartitionId());
return nettyResultPartitionFactory()
return shuffleEnvironmentWrapper
.nettyResultPartitionFactory()
.create(ownerContext.getOwnerName(), index, resultPartitionDeploymentDescriptor);
}
}
Expand All @@ -134,7 +107,8 @@ IndexedInputGate createInputGateInternal(
int gateIndex,
InputGateDeploymentDescriptor igdd) {
return nettyResultIds.contains(igdd.getConsumedResultId())
? nettyInputGateFactory()
? shuffleEnvironmentWrapper
.nettyInputGateFactory()
.create(
ownerContext.getOwnerName(),
gateIndex,
Expand All @@ -151,47 +125,14 @@ public void releasePartitionsLocally(Collection<ResultPartitionID> partitionIds)
.filter(partitionId -> nettyResultPartitionIds.contains(partitionId.getPartitionId()))
.collect(Collectors.toList());
if (!resultPartitionIds.isEmpty()) {
nettyShuffleEnvironment().releasePartitionsLocally(resultPartitionIds);
shuffleEnvironmentWrapper
.nettyShuffleEnvironment()
.releasePartitionsLocally(resultPartitionIds);
}
}

@VisibleForTesting
RemoteShuffleResultPartitionFactory getResultPartitionFactory() {
return resultPartitionFactory;
}

private NettyShuffleEnvironment nettyShuffleEnvironment() {
if (nettyShuffleEnvironment == null) {
synchronized (this) {
if (nettyShuffleEnvironment == null) {
nettyShuffleEnvironment =
nettyShuffleServiceFactory.createShuffleEnvironment(shuffleEnvironmentContext);
}
}
}
return nettyShuffleEnvironment;
}

private ResultPartitionFactory nettyResultPartitionFactory() {
if (nettyResultPartitionFactory == null) {
synchronized (this) {
if (nettyResultPartitionFactory == null) {
nettyResultPartitionFactory =
RESULT_PARTITION_FACTORY_FIELD.bind(nettyShuffleEnvironment()).get();
}
}
}
return nettyResultPartitionFactory;
}

private SingleInputGateFactory nettyInputGateFactory() {
if (nettyInputGateFactory == null) {
synchronized (this) {
if (nettyInputGateFactory == null) {
nettyInputGateFactory = INPUT_GATE_FACTORY_FIELD.bind(nettyShuffleEnvironment()).get();
}
}
}
return nettyInputGateFactory;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
import org.apache.flink.runtime.shuffle.*;

import org.apache.celeborn.plugin.flink.netty.NettyShuffleEnvironmentWrapper;

public class RemoteShuffleServiceFactory extends AbstractRemoteShuffleServiceFactory
implements ShuffleServiceFactory<ShuffleDescriptor, ResultPartitionWriter, IndexedInputGate> {

Expand Down Expand Up @@ -56,7 +58,6 @@ public ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> createShuffle
resultPartitionFactory,
inputGateFactory,
parameters.celebornConf,
nettyShuffleServiceFactory,
shuffleEnvironmentContext);
new NettyShuffleEnvironmentWrapper(nettyShuffleServiceFactory, shuffleEnvironmentContext));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.plugin.flink.netty;

import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
import org.apache.flink.runtime.io.network.NettyShuffleServiceFactory;
import org.apache.flink.runtime.io.network.partition.ResultPartitionFactory;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateFactory;
import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;

import org.apache.celeborn.reflect.DynFields;

/**
* The wrapper of {@link NettyShuffleEnvironment} to generate {@link ResultPartitionFactory} and
* {@link SingleInputGateFactory}.
*/
public class NettyShuffleEnvironmentWrapper {

private final NettyShuffleServiceFactory nettyShuffleServiceFactory;
private final ShuffleEnvironmentContext shuffleEnvironmentContext;

private volatile NettyShuffleEnvironment nettyShuffleEnvironment;
private volatile ResultPartitionFactory nettyResultPartitionFactory;
private volatile SingleInputGateFactory nettyInputGateFactory;

private static final DynFields.UnboundField<ResultPartitionFactory>
RESULT_PARTITION_FACTORY_FIELD =
DynFields.builder()
.hiddenImpl(NettyShuffleEnvironment.class, "resultPartitionFactory")
.defaultAlwaysNull()
.build();

private static final DynFields.UnboundField<SingleInputGateFactory> INPUT_GATE_FACTORY_FIELD =
DynFields.builder()
.hiddenImpl(NettyShuffleEnvironment.class, "singleInputGateFactory")
.defaultAlwaysNull()
.build();

public NettyShuffleEnvironmentWrapper(
NettyShuffleServiceFactory nettyShuffleServiceFactory,
ShuffleEnvironmentContext shuffleEnvironmentContext) {
this.nettyShuffleServiceFactory = nettyShuffleServiceFactory;
this.shuffleEnvironmentContext = shuffleEnvironmentContext;
}

public NettyShuffleEnvironment nettyShuffleEnvironment() {
if (nettyShuffleEnvironment == null) {
synchronized (this) {
if (nettyShuffleEnvironment == null) {
nettyShuffleEnvironment =
nettyShuffleServiceFactory.createShuffleEnvironment(shuffleEnvironmentContext);
}
}
}
return nettyShuffleEnvironment;
}

public ResultPartitionFactory nettyResultPartitionFactory() {
if (nettyResultPartitionFactory == null) {
synchronized (this) {
if (nettyResultPartitionFactory == null) {
nettyResultPartitionFactory =
RESULT_PARTITION_FACTORY_FIELD.bind(nettyShuffleEnvironment()).get();
}
}
}
return nettyResultPartitionFactory;
}

public SingleInputGateFactory nettyInputGateFactory() {
if (nettyInputGateFactory == null) {
synchronized (this) {
if (nettyInputGateFactory == null) {
nettyInputGateFactory = INPUT_GATE_FACTORY_FIELD.bind(nettyShuffleEnvironment()).get();
}
}
}
return nettyInputGateFactory;
}
}
Loading

0 comments on commit 6f68761

Please sign in to comment.