Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several updates to Noise PR #48

Merged
merged 1 commit into from
Oct 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/main/kotlin/io/libp2p/security/noise/NoiseXXCodec.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.libp2p.security.noise

import com.southernstorm.noise.protocol.CipherState
import io.libp2p.etc.types.toByteArray
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.MessageToMessageCodec
import org.apache.logging.log4j.LogManager

private val logger = LogManager.getLogger(NoiseXXSecureChannel::class.java.name)

class NoiseXXCodec(val aliceCipher: CipherState, val bobCipher: CipherState) : MessageToMessageCodec<ByteBuf, ByteBuf>() {

override fun encode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList<Any>) {
val plainLength = msg.readableBytes()
val buf = ByteArray(plainLength + aliceCipher.macLength)
msg.readBytes(buf, 0, plainLength)
val length = aliceCipher.encryptWithAd(null, buf, 0, buf, 0, plainLength)
logger.debug("encrypt length: $length")
out += Unpooled.wrappedBuffer(
Unpooled.buffer().writeShort(length),
Unpooled.wrappedBuffer(buf, 0, length))
logger.trace("channel outbound handler write: $msg")
}

override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList<Any>) {
val length = msg.readShort().toInt()
val buf = msg.toByteArray()
logger.debug("decrypt length: $length")
val decryptLen = bobCipher.decryptWithAd(null, buf, 0, buf, 0, length)
out += Unpooled.wrappedBuffer(buf, 0, decryptLen)
}

override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
logger.error(cause.message)
}
}
39 changes: 5 additions & 34 deletions src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,15 @@ import io.libp2p.core.security.SecureChannel
import io.libp2p.etc.SECURE_SESSION
import io.libp2p.etc.events.SecureChannelFailed
import io.libp2p.etc.events.SecureChannelInitialized
import io.libp2p.etc.types.toByteArray
import io.libp2p.etc.types.toByteBuf
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.channel.ChannelOutboundHandlerAdapter
import io.netty.channel.ChannelPromise
import io.netty.channel.SimpleChannelInboundHandler
import org.apache.logging.log4j.Level
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.core.config.Configurator
import spipe.pb.Spipe
import java.nio.charset.StandardCharsets
import java.util.concurrent.CompletableFuture
import java.util.concurrent.atomic.AtomicInteger

Expand Down Expand Up @@ -71,40 +67,15 @@ open class NoiseXXSecureChannel(private val localKey: PrivKey, private val priva
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
when (evt) {
is SecureChannelInitialized -> {

ctx.channel().attr(SECURE_SESSION).set(evt.session)
val session = evt.session as NoiseSecureChannelSession
ctx.channel().attr(SECURE_SESSION).set(session)

ctx.pipeline().remove(handshakeHandlerName)
ctx.pipeline().remove(this)

ctx.pipeline().addFirst(object : SimpleChannelInboundHandler<ByteBuf>() {
override fun channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf) {
val get: NoiseSecureChannelSession = ctx.channel().attr(SECURE_SESSION).get() as NoiseSecureChannelSession
val additionalData = ByteArray(65535)
val plainText = ByteArray(65535)
val cipherText = msg.toByteArray()
val length = msg.getShort(0).toInt()
logger.debug("decrypt length:$length")
val l = get.bobCipher.decryptWithAd(additionalData, cipherText, 2, plainText, 0, length)
val rec2 = plainText.copyOf(l).toString(StandardCharsets.UTF_8)
ctx.pipeline().fireChannelRead(rec2)
}
})
ctx.pipeline().addFirst(object : ChannelOutboundHandlerAdapter() {
override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) {
msg as ByteBuf
val get: NoiseSecureChannelSession = ctx.channel().attr(SECURE_SESSION).get() as NoiseSecureChannelSession
val additionalData = ByteArray(65535)
val cipherText = ByteArray(65535)
val plaintext = msg.toByteArray()
val length = get.aliceCipher.encryptWithAd(additionalData, plaintext, 0, cipherText, 2, plaintext.size)
logger.debug("encrypt length:$length")
ctx.write(cipherText.copyOf(length + 2).toByteBuf().setShort(0, length))
logger.debug("channel outbound handler write: $msg")
}
})

ret.complete(evt.session)
ctx.pipeline().addLast(NoiseXXCodec(session.aliceCipher, session.bobCipher))

ret.complete(session)

logger.debug("Reporting secure channel initialized")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import io.libp2p.core.crypto.KEY_TYPE
import io.libp2p.core.crypto.generateKeyPair
import io.libp2p.core.multistream.Mode
import io.libp2p.core.multistream.ProtocolMatcher
import io.libp2p.etc.types.toByteArray
import io.libp2p.multistream.Negotiator
import io.libp2p.multistream.ProtocolSelect
import io.libp2p.tools.TestChannel.Companion.interConnect
import io.libp2p.tools.TestHandler
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.logging.LogLevel
Expand Down Expand Up @@ -193,7 +195,8 @@ class NoiseSecureChannelTest {
// Setup alice's pipeline
eCh1.pipeline().addLast(object : TestHandler("1") {
override fun channelRead(ctx: ChannelHandlerContext, msg: Any?) {
rec1 = msg as String
msg as ByteBuf
rec1 = String(msg.toByteArray())
logger.debug("==$name== read: $msg")
latch.countDown()
}
Expand All @@ -202,7 +205,8 @@ class NoiseSecureChannelTest {
// Setup bob's pipeline
eCh2.pipeline().addLast(object : TestHandler("2") {
override fun channelRead(ctx: ChannelHandlerContext, msg: Any?) {
rec2 = msg as String
msg as ByteBuf
rec2 = String(msg.toByteArray())
logger.debug("==$name== read: $msg")
latch.countDown()
}
Expand Down