Skip to content

Commit

Permalink
[CELEBORN-1348] Update infrastructure for SSL communication
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Update infrastructure for SSL support.
Please see #2416 for the consolidated PR with all the changes for reference.

### Why are the changes needed?

At a high level, the changes are:
* `ManagedBuffer.convertToNettyForSsl`, to support SSL encryption.
* Add `EncryptedMessageWithHeader`, which is used to wrap the message and body, for use with SSL.
* `SslMessageEncoder`  is an encoder for SSL

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

The overall PR #2416 (and this PR as well) passes all tests, and this PR includes relevant subset of tests.

Closes #2427 from mridulm/update-infra-for-ssl.

Authored-by: Mridul Muralidharan <mridulatgmail.com>
Signed-off-by: SteNicholas <[email protected]>
  • Loading branch information
Mridul Muralidharan authored and SteNicholas committed Apr 1, 2024
1 parent df2cb1b commit 3ff8812
Show file tree
Hide file tree
Showing 10 changed files with 456 additions and 0 deletions.
3 changes: 3 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,14 @@ Apache License 2.0
Apache Spark
./client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
./client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
./common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
./common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
./common/src/main/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManager.java
./common/src/main/java/org/apache/celeborn/common/network/util/NettyLogger.java
./common/src/main/java/org/apache/celeborn/common/unsafe/Platform.java
./common/src/main/java/org/apache/celeborn/common/util/JavaUtils.java
./common/src/main/scala/org/apache/celeborn/common/util/SignalUtils.scala
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.celeborn.plugin.flink.buffer;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;

Expand Down Expand Up @@ -64,4 +65,9 @@ public ManagedBuffer release() {
public Object convertToNetty() {
return buf.duplicate().retain();
}

@Override
public Object convertToNettyForSsl() throws IOException {
return buf.duplicate().retain();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import com.google.common.io.ByteStreams;
import io.netty.channel.DefaultFileRegion;
import io.netty.handler.stream.ChunkedStream;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;

Expand Down Expand Up @@ -132,6 +133,12 @@ public Object convertToNetty() throws IOException {
}
}

@Override
public Object convertToNettyForSsl() throws IOException {
// Cannot use zero-copy with SSL
return new ChunkedStream(createInputStream(), conf.maxSslEncryptedBlockSize());
}

public File getFile() {
return file;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,15 @@ public abstract class ManagedBuffer {
* the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNetty() throws IOException;

/**
* Convert the buffer into a Netty object, used to write the data out with SSL encryption, which
* cannot use {@link io.netty.channel.FileRegion}. The return value is either a {@link
* io.netty.buffer.ByteBuf}, a {@link io.netty.handler.stream.ChunkedStream}, or a {@link
* java.io.InputStream}.
*
* <p>If this method returns a ByteBuf, then that buffer's reference count will be incremented and
* the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNettyForSsl() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ public Object convertToNetty() throws IOException {
return buf.duplicate().retain();
}

@Override
public Object convertToNettyForSsl() throws IOException {
return buf.duplicate().retain();
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public Object convertToNetty() throws IOException {
return Unpooled.wrappedBuffer(buf);
}

@Override
public Object convertToNettyForSsl() throws IOException {
return Unpooled.wrappedBuffer(buf);
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.common.network.protocol;

import java.io.EOFException;
import java.io.InputStream;

import javax.annotation.Nullable;

import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.stream.ChunkedInput;
import io.netty.handler.stream.ChunkedStream;

import org.apache.celeborn.common.network.buffer.ManagedBuffer;

/**
* A wrapper message that holds two separate pieces (a header and a body).
*
* <p>The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream Based on
* common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
*/
public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {

@Nullable private final ManagedBuffer managedBuffer;
private final ByteBuf header;
private final int headerLength;
private final Object body;
private final long bodyLength;
private long totalBytesTransferred;

/**
* Construct a new EncryptedMessageWithHeader.
*
* @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
* be passed in so that the buffer can be freed when this message is deallocated. Ownership of
* the caller's reference to this buffer is transferred to this class, so if the caller wants
* to continue to use the ManagedBuffer in other messages then they will need to call retain()
* on it before passing it to this constructor.
* @param header the message header.
* @param body the message body.
* @param bodyLength the length of the message body, in bytes.
*/
public EncryptedMessageWithHeader(
@Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) {
Preconditions.checkArgument(
body instanceof InputStream || body instanceof ChunkedStream,
"Body must be an InputStream or a ChunkedStream.");
this.managedBuffer = managedBuffer;
this.header = header;
this.headerLength = header.readableBytes();
this.body = body;
this.bodyLength = bodyLength;
this.totalBytesTransferred = 0;
}

@Override
public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {
return readChunk(ctx.alloc());
}

@Override
public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
if (isEndOfInput()) {
return null;
}

if (totalBytesTransferred < headerLength) {
totalBytesTransferred += headerLength;
return header.retain();
} else if (body instanceof InputStream) {
InputStream stream = (InputStream) body;
int available = stream.available();
if (available <= 0) {
available = (int) (length() - totalBytesTransferred);
} else {
available = (int) Math.min(available, length() - totalBytesTransferred);
}
ByteBuf buffer = allocator.buffer(available);
int toRead = Math.min(available, buffer.writableBytes());
int read = buffer.writeBytes(stream, toRead);
if (read >= 0) {
totalBytesTransferred += read;
return buffer;
} else {
throw new EOFException("Unable to read bytes from InputStream");
}
} else if (body instanceof ChunkedStream) {
ChunkedStream stream = (ChunkedStream) body;
long old = stream.transferredBytes();
ByteBuf buffer = stream.readChunk(allocator);
long read = stream.transferredBytes() - old;
if (read >= 0) {
totalBytesTransferred += read;
assert (totalBytesTransferred <= length());
return buffer;
} else {
throw new EOFException("Unable to read bytes from ChunkedStream");
}
} else {
return null;
}
}

@Override
public long length() {
return headerLength + bodyLength;
}

@Override
public long progress() {
return totalBytesTransferred;
}

@Override
public boolean isEndOfInput() throws Exception {
return (headerLength + bodyLength) == totalBytesTransferred;
}

@Override
public void close() throws Exception {
header.release();
if (managedBuffer != null) {
managedBuffer.release();
}
if (body instanceof InputStream) {
((InputStream) body).close();
} else if (body instanceof ChunkedStream) {
((ChunkedStream) body).close();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.common.network.protocol;

import java.io.InputStream;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.handler.stream.ChunkedStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Encoder used by the server side to encode secure (SSL) server-to-client responses. This encoder
* is stateless so it is safe to be shared by multiple threads. Based on
* common/network-common/org.apache.spark.network.protocol.SslMessageEncoder
*/
@ChannelHandler.Sharable
public final class SslMessageEncoder extends MessageToMessageEncoder<Message> {

private static final Logger logger = LoggerFactory.getLogger(SslMessageEncoder.class);
public static final SslMessageEncoder INSTANCE = new SslMessageEncoder();

private SslMessageEncoder() {}

/**
* Encodes a Message by invoking its encode() method. For non-data messages, we will add one
* ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
* In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the
* data to 'out'.
*/
@Override
public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
Object body = null;
int bodyLength = 0;

// If the message has a body, take it out...
// For SSL, zero-copy transfer will not work, so we will check if
// the body is an InputStream, and if so, use an EncryptedMessageWithHeader
// to wrap the header+body appropriately (for thread safety).
if (in.body() != null) {
try {
bodyLength = (int) in.body().size();
body = in.body().convertToNettyForSsl();
} catch (Exception e) {
in.body().release();
if (in instanceof ResponseMessage) {
ResponseMessage resp = (ResponseMessage) in;
// Re-encode this message as a failure response.
String error = e.getMessage() != null ? e.getMessage() : "null";
logger.error(
String.format("Error processing %s for client %s", in, ctx.channel().remoteAddress()),
e);
encode(ctx, resp.createFailureResponse(error), out);
} else {
throw e;
}
return;
}
}

Message.Type msgType = in.type();
// message size, message type size, body size, message encoded length
int headerLength = 4 + msgType.encodedLength() + 4 + in.encodedLength();
ByteBuf header = ctx.alloc().heapBuffer(headerLength);
header.writeInt(in.encodedLength());
msgType.encode(header);
header.writeInt(bodyLength);
in.encode(header);
assert header.writableBytes() == 0;

if (body != null && bodyLength > 0) {
if (body instanceof ByteBuf) {
out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body));
} else if (body instanceof InputStream || body instanceof ChunkedStream) {
// For now, assume the InputStream is doing proper chunking.
out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
} else {
throw new IllegalArgumentException(
"Body must be a ByteBuf, ChunkedStream or an InputStream");
}
} else {
out.add(header);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ public Object convertToNetty() throws IOException {
return underlying.convertToNetty();
}

@Override
public Object convertToNettyForSsl() throws IOException {
return underlying.convertToNettyForSsl();
}

@Override
public int hashCode() {
return underlying.hashCode();
Expand Down
Loading

0 comments on commit 3ff8812

Please sign in to comment.