From 249b87437d9c9db475ba7d9b4c3b6d6a585f5d7f Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 2 Oct 2019 14:56:28 +0300 Subject: [PATCH] Decompose Noise Cipher from Handshake classes - avoid allocating max buffers for each message - plain text should be binary --- .../io/libp2p/security/noise/NoiseXXCodec.kt | 38 ++++++++++++++++++ .../security/noise/NoiseXXSecureChannel.kt | 39 +++---------------- .../security/noise/NoiseSecureChannelTest.kt | 8 +++- 3 files changed, 49 insertions(+), 36 deletions(-) create mode 100644 src/main/kotlin/io/libp2p/security/noise/NoiseXXCodec.kt diff --git a/src/main/kotlin/io/libp2p/security/noise/NoiseXXCodec.kt b/src/main/kotlin/io/libp2p/security/noise/NoiseXXCodec.kt new file mode 100644 index 000000000..8eae181de --- /dev/null +++ b/src/main/kotlin/io/libp2p/security/noise/NoiseXXCodec.kt @@ -0,0 +1,38 @@ +package io.libp2p.security.noise + +import com.southernstorm.noise.protocol.CipherState +import io.libp2p.etc.types.toByteArray +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.MessageToMessageCodec +import org.apache.logging.log4j.LogManager + +private val logger = LogManager.getLogger(NoiseXXSecureChannel::class.java.name) + +class NoiseXXCodec(val aliceCipher: CipherState, val bobCipher: CipherState) : MessageToMessageCodec() { + + override fun encode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList) { + val plainLength = msg.readableBytes() + val buf = ByteArray(plainLength + aliceCipher.macLength) + msg.readBytes(buf, 0, plainLength) + val length = aliceCipher.encryptWithAd(null, buf, 0, buf, 0, plainLength) + logger.debug("encrypt length: $length") + out += Unpooled.wrappedBuffer( + Unpooled.buffer().writeShort(length), + Unpooled.wrappedBuffer(buf, 0, length)) + logger.trace("channel outbound handler write: $msg") + } + + override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList) { + val length = msg.readShort().toInt() + val buf = msg.toByteArray() + logger.debug("decrypt length: $length") + val decryptLen = bobCipher.decryptWithAd(null, buf, 0, buf, 0, length) + out += Unpooled.wrappedBuffer(buf, 0, decryptLen) + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + logger.error(cause.message) + } +} \ No newline at end of file diff --git a/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt b/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt index b75eb4a4d..1998b6f18 100644 --- a/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt +++ b/src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt @@ -16,19 +16,15 @@ import io.libp2p.core.security.SecureChannel import io.libp2p.etc.SECURE_SESSION import io.libp2p.etc.events.SecureChannelFailed import io.libp2p.etc.events.SecureChannelInitialized -import io.libp2p.etc.types.toByteArray import io.libp2p.etc.types.toByteBuf import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter -import io.netty.channel.ChannelOutboundHandlerAdapter -import io.netty.channel.ChannelPromise import io.netty.channel.SimpleChannelInboundHandler import org.apache.logging.log4j.Level import org.apache.logging.log4j.LogManager import org.apache.logging.log4j.core.config.Configurator import spipe.pb.Spipe -import java.nio.charset.StandardCharsets import java.util.concurrent.CompletableFuture import java.util.concurrent.atomic.AtomicInteger @@ -71,40 +67,15 @@ open class NoiseXXSecureChannel(private val localKey: PrivKey, private val priva override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { when (evt) { is SecureChannelInitialized -> { - - ctx.channel().attr(SECURE_SESSION).set(evt.session) + val session = evt.session as NoiseSecureChannelSession + ctx.channel().attr(SECURE_SESSION).set(session) ctx.pipeline().remove(handshakeHandlerName) ctx.pipeline().remove(this) - ctx.pipeline().addFirst(object : SimpleChannelInboundHandler() { - override fun channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf) { - val get: NoiseSecureChannelSession = ctx.channel().attr(SECURE_SESSION).get() as NoiseSecureChannelSession - val additionalData = ByteArray(65535) - val plainText = ByteArray(65535) - val cipherText = msg.toByteArray() - val length = msg.getShort(0).toInt() - logger.debug("decrypt length:$length") - val l = get.bobCipher.decryptWithAd(additionalData, cipherText, 2, plainText, 0, length) - val rec2 = plainText.copyOf(l).toString(StandardCharsets.UTF_8) - ctx.pipeline().fireChannelRead(rec2) - } - }) - ctx.pipeline().addFirst(object : ChannelOutboundHandlerAdapter() { - override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) { - msg as ByteBuf - val get: NoiseSecureChannelSession = ctx.channel().attr(SECURE_SESSION).get() as NoiseSecureChannelSession - val additionalData = ByteArray(65535) - val cipherText = ByteArray(65535) - val plaintext = msg.toByteArray() - val length = get.aliceCipher.encryptWithAd(additionalData, plaintext, 0, cipherText, 2, plaintext.size) - logger.debug("encrypt length:$length") - ctx.write(cipherText.copyOf(length + 2).toByteBuf().setShort(0, length)) - logger.debug("channel outbound handler write: $msg") - } - }) - - ret.complete(evt.session) + ctx.pipeline().addLast(NoiseXXCodec(session.aliceCipher, session.bobCipher)) + + ret.complete(session) logger.debug("Reporting secure channel initialized") } diff --git a/src/test/kotlin/io/libp2p/security/noise/NoiseSecureChannelTest.kt b/src/test/kotlin/io/libp2p/security/noise/NoiseSecureChannelTest.kt index 7d3e7c62a..21cdc9ae0 100644 --- a/src/test/kotlin/io/libp2p/security/noise/NoiseSecureChannelTest.kt +++ b/src/test/kotlin/io/libp2p/security/noise/NoiseSecureChannelTest.kt @@ -7,10 +7,12 @@ import io.libp2p.core.crypto.KEY_TYPE import io.libp2p.core.crypto.generateKeyPair import io.libp2p.core.multistream.Mode import io.libp2p.core.multistream.ProtocolMatcher +import io.libp2p.etc.types.toByteArray import io.libp2p.multistream.Negotiator import io.libp2p.multistream.ProtocolSelect import io.libp2p.tools.TestChannel.Companion.interConnect import io.libp2p.tools.TestHandler +import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext import io.netty.handler.logging.LogLevel @@ -193,7 +195,8 @@ class NoiseSecureChannelTest { // Setup alice's pipeline eCh1.pipeline().addLast(object : TestHandler("1") { override fun channelRead(ctx: ChannelHandlerContext, msg: Any?) { - rec1 = msg as String + msg as ByteBuf + rec1 = String(msg.toByteArray()) logger.debug("==$name== read: $msg") latch.countDown() } @@ -202,7 +205,8 @@ class NoiseSecureChannelTest { // Setup bob's pipeline eCh2.pipeline().addLast(object : TestHandler("2") { override fun channelRead(ctx: ChannelHandlerContext, msg: Any?) { - rec2 = msg as String + msg as ByteBuf + rec2 = String(msg.toByteArray()) logger.debug("==$name== read: $msg") latch.countDown() }