Skip to content

Commit

Permalink
Use BlockHeight everywhere (#2129)
Browse files Browse the repository at this point in the history
We now have this better type to remove ambiguity.
We should use it wherever it makes sense.
There shouldn't be any business logic change in this commit.
  • Loading branch information
t-bast authored Jan 19, 2022
1 parent 40f7ff4 commit 58f9ebc
Show file tree
Hide file tree
Showing 85 changed files with 894 additions and 880 deletions.
13 changes: 10 additions & 3 deletions eclair-core/src/main/scala/fr/acinq/eclair/BlockHeight.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,22 @@ package fr.acinq.eclair
case class BlockHeight(private val underlying: Long) extends Ordered[BlockHeight] {
// @formatter:off
override def compare(other: BlockHeight): Int = underlying.compareTo(other.underlying)
def +(other: BlockHeight) = BlockHeight(underlying + other.underlying)
def +(i: Int) = BlockHeight(underlying + i)
def +(l: Long) = BlockHeight(underlying + l)
def -(other: BlockHeight) = BlockHeight(underlying - other.underlying)
def -(i: Int) = BlockHeight(underlying - i)
def -(l: Long) = BlockHeight(underlying - l)
def -(other: BlockHeight): Long = underlying - other.underlying
def unary_- = BlockHeight(-underlying)

def toLong: Long = underlying
def max(other: BlockHeight): BlockHeight = if (this > other) this else other
def min(other: BlockHeight): BlockHeight = if (this < other) this else other

def toInt: Int = underlying.toInt
def toLong: Long = underlying
def toDouble: Double = underlying.toDouble
// @formatter:on
}

object BlockHeight {
def apply(underlying: Int): BlockHeight = BlockHeight(underlying.toLong)
}
18 changes: 13 additions & 5 deletions eclair-core/src/main/scala/fr/acinq/eclair/CltvExpiry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,26 @@ package fr.acinq.eclair
*/

/**
* Bitcoin scripts (in particular HTLCs) need an absolute block expiry (greater than the current block count) to work
* Bitcoin scripts (in particular HTLCs) need an absolute block expiry (greater than the current block height) to work
* with OP_CLTV.
*
* @param underlying the absolute cltv expiry value (current block count + some delta).
* @param underlying the absolute cltv expiry value (current block height + some delta).
*/
case class CltvExpiry(private val underlying: Long) extends Ordered[CltvExpiry] {
case class CltvExpiry(private val underlying: BlockHeight) extends Ordered[CltvExpiry] {
// @formatter:off
def +(d: CltvExpiryDelta): CltvExpiry = CltvExpiry(underlying + d.toInt)
def -(d: CltvExpiryDelta): CltvExpiry = CltvExpiry(underlying - d.toInt)
def -(other: CltvExpiry): CltvExpiryDelta = CltvExpiryDelta((underlying - other.underlying).toInt)
override def compare(other: CltvExpiry): Int = underlying.compareTo(other.underlying)
def toLong: Long = underlying
def blockHeight: BlockHeight = underlying
def toLong: Long = underlying.toLong
// @formatter:on
}

object CltvExpiry {
// @formatter:off
def apply(underlying: Int): CltvExpiry = CltvExpiry(BlockHeight(underlying))
def apply(underlying: Long): CltvExpiry = CltvExpiry(BlockHeight(underlying))
// @formatter:on
}

Expand All @@ -49,7 +57,7 @@ case class CltvExpiryDelta(private val underlying: Int) extends Ordered[CltvExpi
/**
* Adds the current block height to the given delta to obtain an absolute expiry.
*/
def toCltvExpiry(blockHeight: Long) = CltvExpiry(blockHeight + underlying)
def toCltvExpiry(currentBlockHeight: BlockHeight) = CltvExpiry(currentBlockHeight + underlying)

// @formatter:off
def +(other: Int): CltvExpiryDelta = CltvExpiryDelta(underlying + other)
Expand Down
8 changes: 4 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import scala.jdk.CollectionConverters._
case class NodeParams(nodeKeyManager: NodeKeyManager,
channelKeyManager: ChannelKeyManager,
instanceId: UUID, // a unique instance ID regenerated after each restart
private val blockCount: AtomicLong,
private val blockHeight: AtomicLong,
alias: String,
color: Color,
publicAddresses: List[NodeAddress],
Expand Down Expand Up @@ -108,7 +108,7 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,

val pluginMessageTags: Set[Int] = pluginParams.collect { case p: CustomFeaturePlugin => p.messageTags }.toSet.flatten

def currentBlockHeight: Long = blockCount.get
def currentBlockHeight: BlockHeight = BlockHeight(blockHeight.get)

/** Returns the features that should be used in our init message with the given peer. */
def initFeaturesFor(nodeId: PublicKey): Features = overrideFeatures.getOrElse(nodeId, features).initFeatures()
Expand Down Expand Up @@ -186,7 +186,7 @@ object NodeParams extends Logging {
def chainFromHash(chainHash: ByteVector32): String = chain2Hash.map(_.swap).getOrElse(chainHash, throw new RuntimeException(s"invalid chainHash '$chainHash'"))

def makeNodeParams(config: Config, instanceId: UUID, nodeKeyManager: NodeKeyManager, channelKeyManager: ChannelKeyManager,
torAddress_opt: Option[NodeAddress], database: Databases, blockCount: AtomicLong, feeEstimator: FeeEstimator,
torAddress_opt: Option[NodeAddress], database: Databases, blockHeight: AtomicLong, feeEstimator: FeeEstimator,
pluginParams: Seq[PluginParams] = Nil): NodeParams = {
// check configuration for keys that have been renamed
val deprecatedKeyPaths = Map(
Expand Down Expand Up @@ -384,7 +384,7 @@ object NodeParams extends Logging {
nodeKeyManager = nodeKeyManager,
channelKeyManager = channelKeyManager,
instanceId = instanceId,
blockCount = blockCount,
blockHeight = blockHeight,
alias = nodeAlias,
color = Color(color(0), color(1), color(2)),
publicAddresses = addresses,
Expand Down
8 changes: 4 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class Setup(val datadir: File,
* It is mainly used to calculate htlc expiries.
* The value is read by all actors, hence it needs to be thread-safe.
*/
val blockCount = new AtomicLong(0)
val blockHeight = new AtomicLong(0)

/**
* This holds the current feerates, in satoshi-per-kilobytes.
Expand All @@ -132,7 +132,7 @@ class Setup(val datadir: File,
// @formatter:on
}

val nodeParams = NodeParams.makeNodeParams(config, instanceId, nodeKeyManager, channelKeyManager, initTor(), databases, blockCount, feeEstimator, pluginParams)
val nodeParams = NodeParams.makeNodeParams(config, instanceId, nodeKeyManager, channelKeyManager, initTor(), databases, blockHeight, feeEstimator, pluginParams)
pluginParams.foreach(param => logger.info(s"using plugin=${param.name}"))

val serverBindingAddress = new InetSocketAddress(config.getString("server.binding-ip"), config.getInt("server.port"))
Expand Down Expand Up @@ -194,7 +194,7 @@ class Setup(val datadir: File,
assert(progress > 0.999, s"bitcoind should be synchronized (progress=$progress)")
assert(headers - blocks <= 1, s"bitcoind should be synchronized (headers=$headers blocks=$blocks)")
logger.info(s"current blockchain height=$blocks")
blockCount.set(blocks)
blockHeight.set(blocks)
bitcoinClient
}

Expand Down Expand Up @@ -255,7 +255,7 @@ class Setup(val datadir: File,
watcher = {
system.actorOf(SimpleSupervisor.props(Props(new ZMQActor(config.getString("bitcoind.zmqblock"), ZMQActor.Topics.HashBlock, Some(zmqBlockConnected))), "zmqblock", SupervisorStrategy.Restart))
system.actorOf(SimpleSupervisor.props(Props(new ZMQActor(config.getString("bitcoind.zmqtx"), ZMQActor.Topics.RawTx, Some(zmqTxConnected))), "zmqtx", SupervisorStrategy.Restart))
system.spawn(Behaviors.supervise(ZmqWatcher(nodeParams, blockCount, bitcoinClient)).onFailure(typed.SupervisorStrategy.resume), "watcher")
system.spawn(Behaviors.supervise(ZmqWatcher(nodeParams, blockHeight, bitcoinClient)).onFailure(typed.SupervisorStrategy.resume), "watcher")
}

router = system.actorOf(SimpleSupervisor.props(Router.props(nodeParams, watcher, Some(routerInitialized)), "router", SupervisorStrategy.Resume))
Expand Down
20 changes: 9 additions & 11 deletions eclair-core/src/main/scala/fr/acinq/eclair/ShortChannelId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package fr.acinq.eclair

/**
* A short channel id uniquely identifies a channel by the coordinates of its funding tx output in the blockchain.
*
* See BOLT 7: https://github.com/lightningnetwork/lightning-rfc/blob/master/07-routing-gossip.md#requirements
*
*/
* A short channel id uniquely identifies a channel by the coordinates of its funding tx output in the blockchain.
* See BOLT 7: https://github.com/lightningnetwork/lightning-rfc/blob/master/07-routing-gossip.md#requirements
*/
case class ShortChannelId(private val id: Long) extends Ordered[ShortChannelId] {

def toLong: Long = id
Expand All @@ -30,7 +28,7 @@ case class ShortChannelId(private val id: Long) extends Ordered[ShortChannelId]

override def toString: String = {
val TxCoordinates(blockHeight, txIndex, outputIndex) = ShortChannelId.coordinates(this)
s"${blockHeight}x${txIndex}x${outputIndex}"
s"${blockHeight.toLong}x${txIndex}x$outputIndex"
}

// we use an unsigned long comparison here
Expand All @@ -44,20 +42,20 @@ object ShortChannelId {
case _ => throw new IllegalArgumentException(s"Invalid short channel id: $s")
}

def apply(blockHeight: Int, txIndex: Int, outputIndex: Int): ShortChannelId = ShortChannelId(toShortId(blockHeight, txIndex, outputIndex))
def apply(blockHeight: BlockHeight, txIndex: Int, outputIndex: Int): ShortChannelId = ShortChannelId(toShortId(blockHeight.toInt, txIndex, outputIndex))

def toShortId(blockHeight: Int, txIndex: Int, outputIndex: Int): Long = ((blockHeight & 0xFFFFFFL) << 40) | ((txIndex & 0xFFFFFFL) << 16) | (outputIndex & 0xFFFFL)

@inline
def blockHeight(shortChannelId: ShortChannelId) = ((shortChannelId.id >> 40) & 0xFFFFFF).toInt
def blockHeight(shortChannelId: ShortChannelId): BlockHeight = BlockHeight((shortChannelId.id >> 40) & 0xFFFFFF)

@inline
def txIndex(shortChannelId: ShortChannelId) = ((shortChannelId.id >> 16) & 0xFFFFFF).toInt
def txIndex(shortChannelId: ShortChannelId): Int = ((shortChannelId.id >> 16) & 0xFFFFFF).toInt

@inline
def outputIndex(shortChannelId: ShortChannelId) = (shortChannelId.id & 0xFFFF).toInt
def outputIndex(shortChannelId: ShortChannelId): Int = (shortChannelId.id & 0xFFFF).toInt

def coordinates(shortChannelId: ShortChannelId): TxCoordinates = TxCoordinates(blockHeight(shortChannelId), txIndex(shortChannelId), outputIndex(shortChannelId))
}

case class TxCoordinates(blockHeight: Int, txIndex: Int, outputIndex: Int)
case class TxCoordinates(blockHeight: BlockHeight, txIndex: Int, outputIndex: Int)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package fr.acinq.eclair.blockchain

import fr.acinq.bitcoin.{ByteVector32, Transaction}
import fr.acinq.eclair.BlockHeight
import fr.acinq.eclair.blockchain.fee.FeeratesPerKw

/**
Expand All @@ -29,6 +30,6 @@ case class NewBlock(blockHash: ByteVector32) extends BlockchainEvent

case class NewTransaction(tx: Transaction) extends BlockchainEvent

case class CurrentBlockCount(blockCount: Long) extends BlockchainEvent
case class CurrentBlockHeight(blockHeight: BlockHeight) extends BlockchainEvent

case class CurrentFeerates(feeratesPerKw: FeeratesPerKw) extends BlockchainEvent
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import fr.acinq.eclair.blockchain._
import fr.acinq.eclair.blockchain.bitcoind.rpc.BitcoinCoreClient
import fr.acinq.eclair.blockchain.watchdogs.BlockchainWatchdog
import fr.acinq.eclair.wire.protocol.ChannelAnnouncement
import fr.acinq.eclair.{KamonExt, NodeParams, ShortChannelId, TimestampSecond}
import fr.acinq.eclair.{BlockHeight, KamonExt, NodeParams, ShortChannelId, TimestampSecond}

import java.util.concurrent.atomic.AtomicLong
import scala.concurrent.duration._
Expand Down Expand Up @@ -59,8 +59,8 @@ object ZmqWatcher {
private case object TickNewBlock extends Command
private case object TickBlockTimeout extends Command
private case class GetBlockCountFailed(t: Throwable) extends Command
private case class CheckBlockCount(count: Long) extends Command
private case class PublishBlockCount(count: Long) extends Command
private case class CheckBlockHeight(current: BlockHeight) extends Command
private case class PublishBlockHeight(current: BlockHeight) extends Command
private case class ProcessNewBlock(blockHash: ByteVector32) extends Command
private case class ProcessNewTransaction(tx: Transaction) extends Command

Expand Down Expand Up @@ -120,7 +120,7 @@ object ZmqWatcher {
/** This event is sent when a [[WatchConfirmed]] condition is met. */
sealed trait WatchConfirmedTriggered extends WatchTriggered {
/** Block in which the transaction was confirmed. */
def blockHeight: Int
def blockHeight: BlockHeight
/** Index of the transaction in that block. */
def txIndex: Int
/** Transaction that has been confirmed. */
Expand All @@ -146,16 +146,16 @@ object ZmqWatcher {
case class WatchOutputSpentTriggered(spendingTx: Transaction) extends WatchSpentTriggered

case class WatchFundingConfirmed(replyTo: ActorRef[WatchFundingConfirmedTriggered], txId: ByteVector32, minDepth: Long) extends WatchConfirmed[WatchFundingConfirmedTriggered]
case class WatchFundingConfirmedTriggered(blockHeight: Int, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered
case class WatchFundingConfirmedTriggered(blockHeight: BlockHeight, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered

case class WatchFundingDeeplyBuried(replyTo: ActorRef[WatchFundingDeeplyBuriedTriggered], txId: ByteVector32, minDepth: Long) extends WatchConfirmed[WatchFundingDeeplyBuriedTriggered]
case class WatchFundingDeeplyBuriedTriggered(blockHeight: Int, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered
case class WatchFundingDeeplyBuriedTriggered(blockHeight: BlockHeight, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered

case class WatchTxConfirmed(replyTo: ActorRef[WatchTxConfirmedTriggered], txId: ByteVector32, minDepth: Long) extends WatchConfirmed[WatchTxConfirmedTriggered]
case class WatchTxConfirmedTriggered(blockHeight: Int, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered
case class WatchTxConfirmedTriggered(blockHeight: BlockHeight, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered

case class WatchParentTxConfirmed(replyTo: ActorRef[WatchParentTxConfirmedTriggered], txId: ByteVector32, minDepth: Long) extends WatchConfirmed[WatchParentTxConfirmedTriggered]
case class WatchParentTxConfirmedTriggered(blockHeight: Int, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered
case class WatchParentTxConfirmedTriggered(blockHeight: BlockHeight, txIndex: Int, tx: Transaction) extends WatchConfirmedTriggered

// TODO: not implemented yet: notify me if confirmation number gets below minDepth?
case class WatchFundingLost(replyTo: ActorRef[WatchFundingLostTriggered], txId: ByteVector32, minDepth: Long) extends Watch[WatchFundingLostTriggered]
Expand Down Expand Up @@ -213,7 +213,7 @@ object ZmqWatcher {

}

private class ZmqWatcher(nodeParams: NodeParams, blockCount: AtomicLong, client: BitcoinCoreClient, context: ActorContext[ZmqWatcher.Command], timers: TimerScheduler[ZmqWatcher.Command])(implicit ec: ExecutionContext = ExecutionContext.global) {
private class ZmqWatcher(nodeParams: NodeParams, blockHeight: AtomicLong, client: BitcoinCoreClient, context: ActorContext[ZmqWatcher.Command], timers: TimerScheduler[ZmqWatcher.Command])(implicit ec: ExecutionContext = ExecutionContext.global) {

import ZmqWatcher._

Expand Down Expand Up @@ -250,39 +250,39 @@ private class ZmqWatcher(nodeParams: NodeParams, blockCount: AtomicLong, client:
case TickBlockTimeout =>
// we haven't received a block in a while, we check whether we're behind and restart the timer.
timers.startSingleTimer(TickBlockTimeout, blockTimeout)
context.pipeToSelf(client.getBlockCount) {
context.pipeToSelf(client.getBlockHeight()) {
case Failure(t) => GetBlockCountFailed(t)
case Success(count) => CheckBlockCount(count)
case Success(currentHeight) => CheckBlockHeight(currentHeight)
}
Behaviors.same

case GetBlockCountFailed(t) =>
log.error("could not get block count from bitcoind", t)
Behaviors.same

case CheckBlockCount(count) =>
val current = blockCount.get()
if (count > current) {
log.warn("block {} wasn't received via ZMQ, you should verify that your bitcoind node is running", count)
case CheckBlockHeight(height) =>
val current = blockHeight.get()
if (height.toLong > current) {
log.warn("block {} wasn't received via ZMQ, you should verify that your bitcoind node is running", height.toLong)
context.self ! TickNewBlock
}
Behaviors.same

case TickNewBlock =>
context.pipeToSelf(client.getBlockCount) {
context.pipeToSelf(client.getBlockHeight()) {
case Failure(t) => GetBlockCountFailed(t)
case Success(count) => PublishBlockCount(count)
case Success(currentHeight) => PublishBlockHeight(currentHeight)
}
// TODO: beware of the herd effect
KamonExt.timeFuture(Metrics.NewBlockCheckConfirmedDuration.withoutTags()) {
Future.sequence(watches.collect { case w: WatchConfirmed[_] => checkConfirmed(w) })
}
Behaviors.same

case PublishBlockCount(count) =>
log.debug("setting blockCount={}", count)
blockCount.set(count)
context.system.eventStream ! EventStream.Publish(CurrentBlockCount(count))
case PublishBlockHeight(currentHeight) =>
log.debug("setting blockHeight={}", currentHeight)
blockHeight.set(currentHeight.toLong)
context.system.eventStream ! EventStream.Publish(CurrentBlockHeight(currentHeight))
Behaviors.same

case TriggerEvent(replyTo, watch, event) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ package fr.acinq.eclair.blockchain.bitcoind.rpc
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin._
import fr.acinq.eclair.ShortChannelId.coordinates
import fr.acinq.eclair.{TimestampSecond, TxCoordinates}
import fr.acinq.eclair.blockchain.OnChainWallet
import fr.acinq.eclair.blockchain.OnChainWallet.{MakeFundingTxResponse, OnChainBalance}
import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{GetTxWithMetaResponse, UtxoStatus, ValidateResult}
import fr.acinq.eclair.blockchain.fee.{FeeratePerKB, FeeratePerKw}
import fr.acinq.eclair.transactions.Transactions
import fr.acinq.eclair.wire.protocol.ChannelAnnouncement
import fr.acinq.eclair.{BlockHeight, TimestampSecond, TxCoordinates}
import grizzled.slf4j.Logging
import org.json4s.Formats
import org.json4s.JsonAST._
Expand Down Expand Up @@ -84,14 +84,14 @@ class BitcoinCoreClient(val rpcClient: BitcoinJsonRPCClient) extends OnChainWall
* @return a Future[height, index] where height is the height of the block where this transaction was published, and
* index is the index of the transaction in that block.
*/
def getTransactionShortId(txid: ByteVector32)(implicit ec: ExecutionContext): Future[(Int, Int)] =
def getTransactionShortId(txid: ByteVector32)(implicit ec: ExecutionContext): Future[(BlockHeight, Int)] =
for {
Some(blockHash) <- getTxBlockHash(txid)
json <- rpcClient.invoke("getblock", blockHash)
JInt(height) = json \ "height"
JArray(txs) = json \ "tx"
index = txs.indexOf(JString(txid.toHex))
} yield (height.toInt, index)
} yield (BlockHeight(height.toInt), index)

def isTransactionOutputSpendable(txid: ByteVector32, outputIndex: Int, includeMempool: Boolean)(implicit ec: ExecutionContext): Future[Boolean] =
for {
Expand Down Expand Up @@ -381,15 +381,15 @@ class BitcoinCoreClient(val rpcClient: BitcoinJsonRPCClient) extends OnChainWall

//------------------------- BLOCKCHAIN -------------------------//

def getBlockCount(implicit ec: ExecutionContext): Future[Long] =
def getBlockHeight()(implicit ec: ExecutionContext): Future[BlockHeight] =
rpcClient.invoke("getblockcount").collect {
case JInt(count) => count.toLong
case JInt(count) => BlockHeight(count.toLong)
}

def validate(c: ChannelAnnouncement)(implicit ec: ExecutionContext): Future[ValidateResult] = {
val TxCoordinates(blockHeight, txIndex, outputIndex) = coordinates(c.shortChannelId)
for {
blockHash <- rpcClient.invoke("getblockhash", blockHeight).map(_.extractOpt[String].map(ByteVector32.fromValidHex).getOrElse(ByteVector32.Zeroes))
blockHash <- rpcClient.invoke("getblockhash", blockHeight.toInt).map(_.extractOpt[String].map(ByteVector32.fromValidHex).getOrElse(ByteVector32.Zeroes))
txid: ByteVector32 <- rpcClient.invoke("getblock", blockHash).map(json => Try {
val JArray(txs) = json \ "tx"
ByteVector32.fromValidHex(txs(txIndex).extract[String])
Expand Down
Loading

0 comments on commit 58f9ebc

Please sign in to comment.