Skip to content

Commit

Permalink
lnprototest: refactroing abstract class and fixed the #14
Browse files Browse the repository at this point in the history
kill all the process when the class was removed from the scope.

Signed-off-by: Vincenzo Palazzo <[email protected]>
  • Loading branch information
vincenzopalazzo committed Mar 3, 2022
1 parent 8530319 commit e0cce0a
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 105 deletions.
1 change: 1 addition & 0 deletions lnprototest/backend/bitcoind.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def start(self) -> None:

def stop(self) -> None:
self.proc.kill()
shutil.rmtree(self.bitcoin_dir)

def restart(self) -> None:
# Only restart if we have to.
Expand Down
102 changes: 53 additions & 49 deletions lnprototest/clightning/clightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import lnprototest
import bitcoin.core
import struct
import shutil

from concurrent import futures
from ephemeral_port_reserve import reserve
Expand Down Expand Up @@ -53,13 +54,13 @@ def __init__(self, connprivkey: str, port: int):
class Runner(lnprototest.Runner):
def __init__(self, config: Any):
super().__init__(config)
self.running = False
self.cleanup_callbacks: List[Callable[[], None]] = []
self.fundchannel_future: Optional[Any] = None
self.is_fundchannel_kill = False

directory = tempfile.mkdtemp(prefix="lnpt-cl-")
self.bitcoind = Bitcoind(directory)
self.bitcoind.start()
self.executor = futures.ThreadPoolExecutor(max_workers=20)

self.lightning_dir = os.path.join(directory, "lightningd")
Expand All @@ -80,8 +81,8 @@ def __init__(self, config: Any):
stdout=subprocess.PIPE,
check=True,
)
.stdout.decode("utf-8")
.splitlines()
.stdout.decode("utf-8")
.splitlines()
)
self.options: Dict[str, str] = {}
for o in opts:
Expand All @@ -106,7 +107,12 @@ def get_node_privkey(self) -> str:
def get_node_bitcoinkey(self) -> str:
return "0000000000000000000000000000000000000000000000000000000000000010"

def is_running(self) -> bool:
return self.running

def start(self) -> None:
if self.running:
return
self.proc = subprocess.Popen(
[
"{}/lightningd/lightningd".format(LIGHTNING_SRC),
Expand All @@ -128,10 +134,11 @@ def start(self) -> None:
]
+ self.startup_flags
)
self.running = True
self.bitcoind.start()
self.rpc = pyln.client.LightningRpc(
os.path.join(self.lightning_dir, "regtest", "lightning-rpc")
)

def node_ready(rpc: pyln.client.LightningRpc) -> bool:
try:
rpc.getinfo()
Expand Down Expand Up @@ -160,38 +167,25 @@ def shutdown(self) -> None:
cb()

def stop(self) -> None:
for cb in self.cleanup_callbacks:
cb()
if self.running is False:
return
self.shutdown()
self.rpc.stop()
self.bitcoind.stop()
self.running = False
for c in self.conns.values():
cast(CLightningConn, c).connection.connection.close()

def connect(self, event: Event, connprivkey: str) -> None:
self.add_conn(CLightningConn(connprivkey, self.lightning_port))

def __enter__(self) -> "Runner":
self.start()
return self

def __exit__(self, type: Any, value: Any, tb: Any) -> None:
self.stop()

def restart(self) -> None:
super().restart()
if self.config.getoption("verbose"):
print("[RESTART]")
for cb in self.cleanup_callbacks:
cb()
self.rpc.stop()
self.bitcoind.restart()
for c in self.conns.values():
cast(CLightningConn, c).connection.connection.close()

self.stop()
# Make a clean start
os.remove(os.path.join(self.lightning_dir, "regtest", "gossip_store"))
os.remove(os.path.join(self.lightning_dir, "regtest", "lightningd.sqlite3"))
os.remove(os.path.join(self.lightning_dir, "regtest", "log"))
super().restart()
shutil.rmtree(self.lightning_dir)
self.start()

def getblockheight(self) -> int:
Expand Down Expand Up @@ -229,12 +223,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
raise EventError(event, "Connection closed")

def fundchannel(
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
) -> None:
"""
event - the event which cause this, for error logging
Expand All @@ -254,11 +248,11 @@ def fundchannel(
self.fundchannel_future = None

def _fundchannel(
runner: Runner,
conn: Conn,
amount: int,
feerate: int,
expect_fail: bool = False,
runner: Runner,
conn: Conn,
amount: int,
feerate: int,
expect_fail: bool = False,
) -> str:
peer_id = conn.pubkey.format().hex()
# Need to supply feerate here, since regtest cannot estimate fees
Expand All @@ -282,14 +276,14 @@ def _done(fut: Any) -> None:
self.cleanup_callbacks.append(self.kill_fundchannel)

def init_rbf(
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
) -> None:

if self.fundchannel_future:
Expand Down Expand Up @@ -362,7 +356,7 @@ def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None:
self.rpc.sendpay([routestep], payhash)

def get_output_message(
self, conn: Conn, event: Event, timeout: int = TIMEOUT
self, conn: Conn, event: Event, timeout: int = TIMEOUT
) -> Optional[bytes]:
fut = self.executor.submit(cast(CLightningConn, conn).connection.read_message)
try:
Expand All @@ -379,11 +373,11 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]:
return msg.hex()

def check_final_error(
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
) -> None:
if not expected:
# Inject raw packet to ensure it hangs up *after* processing all previous ones.
Expand Down Expand Up @@ -435,3 +429,13 @@ def add_startup_flag(self, flag: str) -> None:
if self.config.getoption("verbose"):
print("[ADD STARTUP FLAG '{}']".format(flag))
self.startup_flags.append("--{}".format(flag))

def close_channel(self, channel_id: str) -> bool:
if self.config.getoption("verbose"):
print("[CLOSE CHANNEL '{}']".format(channel_id))
try:
self.rpc.close(peer_id=channel_id)
except Exception as ex:
print(ex)
return False
return True
73 changes: 41 additions & 32 deletions lnprototest/dummyrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


class DummyRunner(Runner):

def __init__(self, config: Any):
super().__init__(config)

Expand Down Expand Up @@ -86,12 +87,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
print("[RECV {} {}]".format(event, outbuf.hex()))

def fundchannel(
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
) -> None:
if self.config.getoption("verbose"):
print(
Expand All @@ -101,14 +102,14 @@ def fundchannel(
)

def init_rbf(
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
) -> None:
if self.config.getoption("verbose"):
print(
Expand Down Expand Up @@ -143,21 +144,21 @@ def fake_field(ftype: FieldType) -> str:
if ftype.elemtype.name == "byte":
return "00" * ftype.arraysize
return (
"["
+ ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize)
+ "]"
"["
+ ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize)
+ "]"
)
elif ftype.name in (
"byte",
"u8",
"u16",
"u32",
"u64",
"tu16",
"tu32",
"tu64",
"bigsize",
"varint",
"byte",
"u8",
"u16",
"u32",
"u64",
"tu16",
"tu32",
"tu64",
"bigsize",
"varint",
):
return "0"
elif ftype.name in ("chain_hash", "channel_id", "sha256"):
Expand Down Expand Up @@ -200,10 +201,18 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]:
return "Dummy error"

def check_final_error(
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
) -> None:
pass

def close_channel(self, channel_id: str) -> bool:
if self.config.getoption("verbose"):
print("[CLOSE-CHANNEL {}]".format(channel_id))
return True

def is_running(self) -> bool:
return True
Loading

0 comments on commit e0cce0a

Please sign in to comment.