Skip to content

Commit

Permalink
Add networks to init message (#1254)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-bast authored Jan 21, 2020
1 parent 01a30ed commit ca713ba
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 49 deletions.
28 changes: 17 additions & 11 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
log.debug(s"got authenticated connection to $remoteNodeId@${address.getHostString}:${address.getPort}")
transport ! TransportHandler.Listener(self)
context watch transport
val localInit = nodeParams.overrideFeatures.get(remoteNodeId) match {
case Some(f) => wire.Init(f)
val localFeatures = nodeParams.overrideFeatures.get(remoteNodeId) match {
case Some(f) => f
case None =>
// Eclair-mobile thinks feature bit 15 (payment_secret) is gossip_queries_ex which creates issues, so we mask
// off basic_mpp and payment_secret. As long as they're provided in the invoice it's not an issue.
Expand All @@ -116,9 +116,10 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
// ... and leave the others untouched
case (value, _) => value
}).reverse.bytes.dropWhile(_ == 0)
wire.Init(tweakedFeatures)
tweakedFeatures
}
log.info(s"using features=${localInit.features.toBin}")
log.info(s"using features=${localFeatures.toBin}")
val localInit = wire.Init(localFeatures, TlvStream(InitTlv.Networks(nodeParams.chainHash :: Nil)))
transport ! localInit

val address_opt = if (outgoing) {
Expand Down Expand Up @@ -148,9 +149,19 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
case Event(remoteInit: wire.Init, d: InitializingData) =>
d.transport ! TransportHandler.ReadAck(remoteInit)

log.info(s"peer is using features=${remoteInit.features.toBin}")
log.info(s"peer is using features=${remoteInit.features.toBin}, networks=${remoteInit.networks.mkString(",")}")

if (Features.areSupported(remoteInit.features)) {
if (remoteInit.networks.nonEmpty && !remoteInit.networks.contains(nodeParams.chainHash)) {
log.warning(s"incompatible networks (${remoteInit.networks}), disconnecting")
d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible networks")))
d.transport ! PoisonPill
stay
} else if (!Features.areSupported(remoteInit.features)) {
log.warning("incompatible features, disconnecting")
d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features")))
d.transport ! PoisonPill
stay
} else {
d.origin_opt.foreach(origin => origin ! "connected")

def localHasFeature(f: Feature): Boolean = Features.hasFeature(d.localInit.features, f)
Expand Down Expand Up @@ -181,11 +192,6 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
val rebroadcastDelay = Random.nextInt(nodeParams.routerConf.routerBroadcastInterval.toSeconds.toInt).seconds
log.info(s"rebroadcast will be delayed by $rebroadcastDelay")
goto(CONNECTED) using ConnectedData(d.address_opt, d.transport, d.localInit, remoteInit, d.channels.map { case (k: ChannelId, v) => (k, v) }, rebroadcastDelay) forMax (30 seconds) // forMax will trigger a StateTimeout
} else {
log.warning(s"incompatible features, disconnecting")
d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features")))
d.transport ! PoisonPill
stay
}

case Event(Authenticator.Authenticated(connection, _, _, _, _, origin_opt), _) =>
Expand Down
49 changes: 49 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package fr.acinq.eclair.wire

import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.UInt64
import fr.acinq.eclair.wire.CommonCodecs._
import scodec.Codec
import scodec.codecs.{discriminated, list, variableSizeBytesLong}

/**
* Created by t-bast on 13/12/2019.
*/

/** Tlv types used inside Init messages. */
sealed trait InitTlv extends Tlv

object InitTlv {

/** The chains the node is interested in. */
case class Networks(chainHashes: List[ByteVector32]) extends InitTlv

}

object InitTlvCodecs {

import InitTlv._

private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks]

val initTlvCodec = TlvCodecs.tlvStream(discriminated[InitTlv].by(varint)
.typecase(UInt64(1), networks)
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ object LightningMessageCodecs {
},
{ features => (ByteVector.empty, features) })

val initCodec: Codec[Init] = combinedFeaturesCodec.as[Init]
val initCodec: Codec[Init] = (("features" | combinedFeaturesCodec) :: ("tlvStream" | InitTlvCodecs.initTlvCodec)).as[Init]

val errorCodec: Codec[Error] = (
("channelId" | bytes32) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ sealed trait HasChainHash extends LightningMessage { def chainHash: ByteVector32
sealed trait UpdateMessage extends HtlcMessage // <- not in the spec
// @formatter:on

case class Init(features: ByteVector) extends SetupMessage
case class Init(features: ByteVector, tlvs: TlvStream[InitTlv] = TlvStream.empty) extends SetupMessage {
val networks = tlvs.get[InitTlv.Networks].map(_.chainHashes).getOrElse(Nil)
}

case class Error(channelId: ByteVector32, data: ByteVector) extends SetupMessage with HasChannelId {
def toAscii: String = if (fr.acinq.eclair.isAsciiPrintable(data)) new String(data.toArray, StandardCharsets.US_ASCII) else "n/a"
Expand Down
43 changes: 21 additions & 22 deletions eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,41 +22,40 @@ import scodec.bits.ByteVector
import scala.reflect.ClassTag

/**
* Created by t-bast on 20/06/2019.
*/
* Created by t-bast on 20/06/2019.
*/

trait Tlv

/**
* Generic tlv type we fallback to if we don't understand the incoming tlv.
*
* @param tag tlv tag.
* @param value tlv value (length is implicit, and encoded as a varint).
*/
* Generic tlv type we fallback to if we don't understand the incoming tlv.
*
* @param tag tlv tag.
* @param value tlv value (length is implicit, and encoded as a varint).
*/
case class GenericTlv(tag: UInt64, value: ByteVector) extends Tlv

/**
* A tlv stream is a collection of tlv records.
* A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records.
* That namespace is provided by a trait extending the top-level tlv trait.
*
* @param records known tlv records.
* @param unknown unknown tlv records.
* @tparam T the stream namespace is a trait extending the top-level tlv trait.
*/
* A tlv stream is a collection of tlv records.
* A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records.
* That namespace is provided by a trait extending the top-level tlv trait.
*
* @param records known tlv records.
* @param unknown unknown tlv records.
* @tparam T the stream namespace is a trait extending the top-level tlv trait.
*/
case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil) {
/**
*
* @tparam R input type parameter, must be a subtype of the main TLV type
* @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify
* that TLV records are supposed to be unique)
*/
*
* @tparam R input type parameter, must be a subtype of the main TLV type
* @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify
* that TLV records are supposed to be unique)
*/
def get[R <: T : ClassTag]: Option[R] = records.collectFirst { case r: R => r }
}

object TlvStream {
def empty[T <: Tlv] = TlvStream[T](Nil, Nil)
def empty[T <: Tlv]: TlvStream[T] = TlvStream[T](Nil, Nil)

def apply[T <: Tlv](records: T*): TlvStream[T] = TlvStream(records, Nil)

}
19 changes: 17 additions & 2 deletions eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.net.{Inet4Address, InetAddress, InetSocketAddress, ServerSocket}
import akka.actor.FSM.{CurrentState, SubscribeTransitionCallBack, Transition}
import akka.actor.{ActorRef, PoisonPill}
import akka.testkit.{TestFSMRef, TestProbe}
import fr.acinq.bitcoin.Block
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.TestConstants._
import fr.acinq.eclair._
Expand All @@ -31,7 +32,7 @@ import fr.acinq.eclair.crypto.TransportHandler
import fr.acinq.eclair.io.Peer._
import fr.acinq.eclair.router.RoutingSyncSpec.makeFakeRoutingInfo
import fr.acinq.eclair.router.{Rebroadcast, RoutingSyncSpec, SendChannelQuery}
import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream}
import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, InitTlv, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream}
import org.scalatest.{Outcome, Tag}
import scodec.bits.{ByteVector, _}

Expand Down Expand Up @@ -81,7 +82,8 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods {
probe.send(peer, Peer.Init(None, channels))
authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = true, None))
transport.expectMsgType[TransportHandler.Listener]
transport.expectMsgType[wire.Init]
val localInit = transport.expectMsgType[wire.Init]
assert(localInit.networks === List(Block.RegtestGenesisBlock.hash))
transport.send(peer, remoteInit)
transport.expectMsgType[TransportHandler.ReadAck]
if (expectSync) {
Expand Down Expand Up @@ -255,6 +257,19 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods {
assert(init.features === sentFeatures.bytes)
}
}

test("disconnect if incompatible networks") { f =>
import f._
val probe = TestProbe()
probe.watch(transport.ref)
probe.send(peer, Peer.Init(None, Set.empty))
authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, new InetSocketAddress("1.2.3.4", 42000), outgoing = true, None))
transport.expectMsgType[TransportHandler.Listener]
transport.expectMsgType[wire.Init]
transport.send(peer, wire.Init(Bob.nodeParams.features, TlvStream(InitTlv.Networks(Block.LivenetGenesisBlock.hash :: Block.SegnetGenesisBlock.hash :: Nil))))
transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref)
}

test("handle disconnect in status INITIALIZING") { f =>
import f._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,37 @@ class LightningMessageCodecsSpec extends FunSuite {
def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey

test("encode/decode init message") {
case class TestCase(encoded: ByteVector, features: ByteVector, networks: List[ByteVector32], valid: Boolean, reEncoded: Option[ByteVector] = None)
val chainHash1 = ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101")
val chainHash2 = ByteVector32(hex"0202020202020202020202020202020202020202020202020202020202020202")
val testCases = Seq(
(hex"0000 0000", hex"", hex"0000 0000"), // no features
(hex"0000 0002088a", hex"088a", hex"0000 0002088a"), // no global features
(hex"00020200 0000", hex"0200", hex"0000 00020200"), // no local features
(hex"00020200 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - no conflict - same size
(hex"00020200 0003020002", hex"020202", hex"0000 0003020202"), // local and global - no conflict - different sizes
(hex"00020a02 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - conflict - same size
(hex"00022200 000302aaa2", hex"02aaa2", hex"0000 000302aaa2") // local and global - conflict - different sizes
TestCase(hex"0000 0000", hex"", Nil, valid = true), // no features
TestCase(hex"0000 0002088a", hex"088a", Nil, valid = true), // no global features
TestCase(hex"00020200 0000", hex"0200", Nil, valid = true, Some(hex"0000 00020200")), // no local features
TestCase(hex"00020200 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size
TestCase(hex"00020200 0003020002", hex"020202", Nil, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes
TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size
TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes
TestCase(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, valid = true), // unknown odd records
TestCase(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, valid = false), // unknown even records
TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, valid = false), // invalid tlv stream
TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), valid = true), // single network
TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), valid = true), // multiple networks
TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010103012a", hex"088a", List(chainHash1), valid = true), // network and unknown odd records
TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, valid = false) // network and unknown even records
)

for ((bin, features, encoded) <- testCases) {
val init = initCodec.decode(bin.bits).require.value
assert(init.features === features)
assert(initCodec.encode(init).require.bytes === encoded)
assert(initCodec.decode(encoded.bits).require.value === init)
for (testCase <- testCases) {
if (testCase.valid) {
val init = initCodec.decode(testCase.encoded.bits).require.value
assert(init.features === testCase.features)
assert(init.networks === testCase.networks)
val encoded = initCodec.encode(init).require
assert(encoded.bytes === testCase.reEncoded.getOrElse(testCase.encoded))
assert(initCodec.decode(encoded).require.value === init)
} else {
assert(initCodec.decode(testCase.encoded.bits).isFailure)
}
}
}

Expand Down

0 comments on commit ca713ba

Please sign in to comment.