diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 88dda7225a..d44eee5fca 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -101,7 +101,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A transport ! TransportHandler.Listener(self) context watch transport val localInit = nodeParams.overrideFeatures.get(remoteNodeId) match { - case Some(f) => wire.Init(f) + case Some(f) => wire.Init(f, TlvStream(InitTlv.Networks(nodeParams.chainHash :: Nil))) case None => // Eclair-mobile thinks feature bit 15 (payment_secret) is gossip_queries_ex which creates issues, so we mask // off basic_mpp and payment_secret. As long as they're provided in the invoice it's not an issue. @@ -116,7 +116,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A // ... and leave the others untouched case (value, _) => value }).reverse.bytes.dropWhile(_ == 0) - wire.Init(tweakedFeatures) + wire.Init(tweakedFeatures, TlvStream(InitTlv.Networks(nodeParams.chainHash :: Nil))) } log.info(s"using features=${localInit.features.toBin}") transport ! localInit @@ -148,9 +148,19 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A case Event(remoteInit: wire.Init, d: InitializingData) => d.transport ! TransportHandler.ReadAck(remoteInit) - log.info(s"peer is using features=${remoteInit.features.toBin}") + log.info(s"peer is using features=${remoteInit.features.toBin}, network=${remoteInit.networks}") - if (Features.areSupported(remoteInit.features)) { + if (remoteInit.networks.nonEmpty && !remoteInit.networks.contains(nodeParams.chainHash)) { + log.warning(s"incompatible networks (${remoteInit.networks}), disconnecting") + d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible networks"))) + d.transport ! PoisonPill + stay + } else if (!Features.areSupported(remoteInit.features)) { + log.warning("incompatible features, disconnecting") + d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features"))) + d.transport ! PoisonPill + stay + } else { d.origin_opt.foreach(origin => origin ! "connected") def localHasFeature(f: Feature): Boolean = Features.hasFeature(d.localInit.features, f) @@ -181,11 +191,6 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A val rebroadcastDelay = Random.nextInt(nodeParams.routerConf.routerBroadcastInterval.toSeconds.toInt).seconds log.info(s"rebroadcast will be delayed by $rebroadcastDelay") goto(CONNECTED) using ConnectedData(d.address_opt, d.transport, d.localInit, remoteInit, d.channels.map { case (k: ChannelId, v) => (k, v) }, rebroadcastDelay) forMax (30 seconds) // forMax will trigger a StateTimeout - } else { - log.warning(s"incompatible features, disconnecting") - d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features"))) - d.transport ! PoisonPill - stay } case Event(Authenticator.Authenticated(connection, _, _, _, _, origin_opt), _) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala index 7f1b6d48b0..275243860d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala @@ -40,10 +40,6 @@ object InitTlvCodecs { import InitTlv._ - // TODO: - // * Send the chainHash from nodeParams when creating Init - // * Add logic to Peer.scala to fail connections to others that don't offer my chainHash - private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks] val initTlvCodec = TlvCodecs.tlvStream(discriminated[InitTlv].by(varint) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 8994595a56..83dbf71b60 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -21,6 +21,7 @@ import java.net.{Inet4Address, InetAddress, InetSocketAddress, ServerSocket} import akka.actor.FSM.{CurrentState, SubscribeTransitionCallBack, Transition} import akka.actor.{ActorRef, PoisonPill} import akka.testkit.{TestFSMRef, TestProbe} +import fr.acinq.bitcoin.Block import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.TestConstants._ import fr.acinq.eclair._ @@ -31,7 +32,7 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer._ import fr.acinq.eclair.router.RoutingSyncSpec.makeFakeRoutingInfo import fr.acinq.eclair.router.{Rebroadcast, RoutingSyncSpec, SendChannelQuery} -import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream} +import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, InitTlv, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream} import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, _} @@ -81,7 +82,8 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { probe.send(peer, Peer.Init(None, channels)) authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = true, None)) transport.expectMsgType[TransportHandler.Listener] - transport.expectMsgType[wire.Init] + val localInit = transport.expectMsgType[wire.Init] + assert(localInit.networks === List(Block.RegtestGenesisBlock.hash)) transport.send(peer, remoteInit) transport.expectMsgType[TransportHandler.ReadAck] if (expectSync) { @@ -255,6 +257,19 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { assert(init.features === sentFeatures.bytes) } } + + test("disconnect if incompatible networks") { f => + import f._ + val probe = TestProbe() + probe.watch(transport.ref) + probe.send(peer, Peer.Init(None, Set.empty)) + authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, new InetSocketAddress("1.2.3.4", 42000), outgoing = true, None)) + transport.expectMsgType[TransportHandler.Listener] + transport.expectMsgType[wire.Init] + transport.send(peer, wire.Init(Bob.nodeParams.features, TlvStream(InitTlv.Networks(Block.LivenetGenesisBlock.hash :: Block.SegnetGenesisBlock.hash :: Nil)))) + transport.expectMsgType[TransportHandler.ReadAck] + probe.expectTerminated(transport.ref) + } test("handle disconnect in status INITIALIZING") { f => import f._