Skip to content

Commit

Permalink
chore: Python format rule adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
michael1011 committed Oct 29, 2023
1 parent e3b387a commit e2bb8ae
Show file tree
Hide file tree
Showing 20 changed files with 117 additions and 229 deletions.
4 changes: 1 addition & 3 deletions docker/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,7 @@ def build_images(
)
else:
extra_tag = "" if no_latest else f"-t {name}:latest"
command = (
f"docker build -t {name}:{tag} {extra_tag} -f {dockerfile} {args} ."
)
command = f"docker build -t {name}:{tag} {extra_tag} -f {dockerfile} {args} ."

if no_cache:
command = command + " --no-cache"
Expand Down
9 changes: 6 additions & 3 deletions lib/swap/SwapNursery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ class SwapNursery extends EventEmitter implements ISwapNursery {
if (
typeof error !== 'object' ||
((error as any).details !== 'unable to locate invoice' &&
(error as any).details !== 'there are no existing invoices')
(error as any).details !== 'there are no existing invoices' &&
(error as any).message !== 'hold invoice not found')
) {
this.logger.error(
`Could not cancel invoice${plural} of Reverse Swap ${
Expand All @@ -446,10 +447,12 @@ class SwapNursery extends EventEmitter implements ISwapNursery {
);
return;
} else {
this.logger.warn(
this.logger.silly(
`Cancelling invoice${plural} of Reverse Swap ${
reverseSwap.id
} failed although they could be found: ${formatError(error)}`,
} failed because they could not be found: ${formatError(
error,
)}`,
);
}
}
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"python:install": "cd tools && poetry install",
"python:proto": "cd tools && poetry run python -m grpc_tools.protoc -I hold/protos --python_out=hold/protos --pyi_out=hold/protos --grpc_python_out=hold/protos hold/protos/hold.proto",
"python:lint": "cd tools && poetry run ruff *.py hold/ backup/*.py ../docker/*.py",
"python:format": "cd tools && poetry run ruff format . ../docker/",
"python:format": "cd tools && poetry run ruff format **/* ../docker/ ",
"python:test": "cd tools && poetry run pytest"
},
"bin": {
Expand Down
15 changes: 5 additions & 10 deletions tools/hold/certs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def get_path(file_name: str) -> str:

if all(Path.exists(Path(path)) for path in [cert_path, key_path]):
return Certificate(
key=Path.read_bytes(Path(key_path)), cert=Path.read_bytes(Path(cert_path))
key=Path.read_bytes(Path(key_path)),
cert=Path.read_bytes(Path(cert_path)),
)

key = ec.generate_private_key(curve=SECP256R1())
Expand All @@ -49,26 +50,20 @@ def get_path(file_name: str) -> str:

cert = (
x509.CertificateBuilder()
.subject_name(
issuer_name if is_ca else create_cert_name(f"{subject_prefix} {name}")
)
.subject_name(issuer_name if is_ca else create_cert_name(f"{subject_prefix} {name}"))
.issuer_name(issuer_name)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(time_now())
.not_valid_after(time_now() + datetime.timedelta(weeks=52 * 10))
.add_extension(
x509.SubjectAlternativeName(
[x509.DNSName(name) for name in ["hold", "localhost"]]
),
x509.SubjectAlternativeName([x509.DNSName(name) for name in ["hold", "localhost"]]),
critical=False,
)
)

if is_ca:
cert = cert.add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True
)
cert = cert.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)

cert = cert.sign(
serialization.load_pem_private_key(ca_key, password=None) if not is_ca else key,
Expand Down
5 changes: 1 addition & 4 deletions tools/hold/hold.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def invoice(
min_final_cltv_expiry: int,
route_hints: list[RouteHint] | None = None,
) -> str:
if (
len(self._plugin.rpc.listinvoices(payment_hash=payment_hash)["invoices"])
> 0
):
if len(self._plugin.rpc.listinvoices(payment_hash=payment_hash)["invoices"]) > 0:
raise InvoiceExistsError

bolt11 = self._encoder.encode(
Expand Down
20 changes: 5 additions & 15 deletions tools/hold/htlc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ class HtlcHandler:
_ds: DataStore
_settler: Settler

def __init__(
self, plugin: Plugin, ds: DataStore, settler: Settler, tracker: Tracker
) -> None:
def __init__(self, plugin: Plugin, ds: DataStore, settler: Settler, tracker: Tracker) -> None:
self._plugin = plugin
self._ds = ds
self._settler = settler
Expand Down Expand Up @@ -84,10 +82,7 @@ def handle_htlc(
self._fail_and_save_htlc(request, invoice, htlc)
return

if (
"payment_secret" not in onion
or onion["payment_secret"] != dec["payment_secret"]
):
if "payment_secret" not in onion or onion["payment_secret"] != dec["payment_secret"]:
self._log_htlc_rejected(
invoice,
htlc,
Expand All @@ -97,9 +92,7 @@ def handle_htlc(
return

if invoice.state != InvoiceState.Unpaid:
self._log_htlc_rejected(
invoice, htlc, f"invoice is in state {invoice.state}"
)
self._log_htlc_rejected(invoice, htlc, f"invoice is in state {invoice.state}")
self._fail_and_save_htlc(request, invoice, htlc)
return

Expand All @@ -122,13 +115,10 @@ def handle_htlc(
invoice.set_state(self._tracker, InvoiceState.Accepted)
self._ds.save_invoice(invoice, mode="must-replace")
self._plugin.log(
f"Accepted hold invoice {invoice.payment_hash} "
f"with {len(invoice.htlcs)} HTLCs",
f"Accepted hold invoice {invoice.payment_hash} " f"with {len(invoice.htlcs)} HTLCs",
)

def handle_known_htlc(
self, invoice: HoldInvoice, htlc: Htlc, request: Request
) -> None:
def handle_known_htlc(self, invoice: HoldInvoice, htlc: Htlc, request: Request) -> None:
if htlc.state == HtlcState.Accepted:
# Pass the request to the settler to handle in the future
self._settler.add_htlc(invoice.payment_hash, request, htlc)
Expand Down
23 changes: 8 additions & 15 deletions tools/hold/invoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def from_dict(cls: type[HtlcType], htlc_dict: dict[str, Any]) -> HtlcType:

@classmethod
def from_json_dict(cls: type[HtlcType], json_dict: dict[str, Any]) -> HtlcType:
json_dict["created_at"] = datetime.fromtimestamp(
json_dict["created_at"], tz=timezone.utc
)
json_dict["created_at"] = datetime.fromtimestamp(json_dict["created_at"], tz=timezone.utc)

return cls(**json_dict)

Expand Down Expand Up @@ -79,14 +77,15 @@ def find_htlc(self, short_channel_id: str, channel_id: int) -> Htlc | None:
(
htlc
for htlc in self.htlcs
if htlc.short_channel_id == short_channel_id
and htlc.channel_id == channel_id
if htlc.short_channel_id == short_channel_id and htlc.channel_id == channel_id
),
None,
)

def cancel_expired(
self, expiry: int, fail_callback: Callable[[Htlc, HtlcFailureMessage], None]
self,
expiry: int,
fail_callback: Callable[[Htlc, HtlcFailureMessage], None],
) -> None:
for htlc in self.htlcs:
if not (time_now() - htlc.created_at).total_seconds() > expiry:
Expand Down Expand Up @@ -156,9 +155,7 @@ def is_fully_paid(self) -> bool:

def sum_paid(self) -> int:
return sum(
h.msat
for h in self.htlcs.htlcs
if h.state in [HtlcState.Paid, HtlcState.Accepted]
h.msat for h in self.htlcs.htlcs if h.state in [HtlcState.Paid, HtlcState.Accepted]
)

def to_json(self) -> str:
Expand Down Expand Up @@ -186,15 +183,11 @@ def from_json(cls: type[HoldInvoiceType], json_str: str) -> HoldInvoiceType:
json_str = json_str.removesuffix("\\}") + "}"

json_dict = json.loads(json_str)
json_dict["created_at"] = datetime.fromtimestamp(
json_dict["created_at"], tz=timezone.utc
)
json_dict["created_at"] = datetime.fromtimestamp(json_dict["created_at"], tz=timezone.utc)

if "amount_msat" not in json_dict:
json_dict["amount_msat"] = bolt11.decode(json_dict["bolt11"]).amount_msat

json_dict["htlcs"] = (
Htlcs.from_json_arr(json_dict["htlcs"]) if "htlcs" in json_dict else []
)
json_dict["htlcs"] = Htlcs.from_json_arr(json_dict["htlcs"]) if "htlcs" in json_dict else []

return cls(**json_dict)
8 changes: 2 additions & 6 deletions tools/hold/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def hold_invoice(
method_name="listholdinvoices",
category=PLUGIN_NAME,
)
def list_hold_invoices(
plugin: Plugin, payment_hash: str = "", invoice: str = ""
) -> dict[str, Any]:
def list_hold_invoices(plugin: Plugin, payment_hash: str = "", invoice: str = "") -> dict[str, Any]:
"""List one or more hold invoices."""
if payment_hash in empty_value and invoice not in empty_value:
payment_hash = bolt11.decode(invoice).payment_hash
Expand All @@ -100,9 +98,7 @@ def list_hold_invoices(
payment_hash = bolt11.decode(payment_hash).payment_hash

return {
"holdinvoices": [
invoice.to_dict() for invoice in hold.list_invoices(payment_hash)
],
"holdinvoices": [invoice.to_dict() for invoice in hold.list_invoices(payment_hash)],
}


Expand Down
9 changes: 2 additions & 7 deletions tools/hold/protos/hold_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 9 additions & 39 deletions tools/hold/protos/hold_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,7 @@ class RoutingHintsRequest(_message.Message):
def __init__(self, node: _Optional[str] = ...) -> None: ...

class Hop(_message.Message):
__slots__ = [
"public_key",
"short_channel_id",
"base_fee",
"ppm_fee",
"cltv_expiry_delta",
]
__slots__ = ["public_key", "short_channel_id", "base_fee", "ppm_fee", "cltv_expiry_delta"]
PUBLIC_KEY_FIELD_NUMBER: _ClassVar[int]
SHORT_CHANNEL_ID_FIELD_NUMBER: _ClassVar[int]
BASE_FEE_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -117,9 +111,7 @@ class RoutingHint(_message.Message):
__slots__ = ["hops"]
HOPS_FIELD_NUMBER: _ClassVar[int]
hops: _containers.RepeatedCompositeFieldContainer[Hop]
def __init__(
self, hops: _Optional[_Iterable[_Union[Hop, _Mapping]]] = ...
) -> None: ...
def __init__(self, hops: _Optional[_Iterable[_Union[Hop, _Mapping]]] = ...) -> None: ...

class RoutingHintsResponse(_message.Message):
__slots__ = ["hints"]
Expand Down Expand Up @@ -195,9 +187,7 @@ class ListResponse(_message.Message):
__slots__ = ["invoices"]
INVOICES_FIELD_NUMBER: _ClassVar[int]
invoices: _containers.RepeatedCompositeFieldContainer[Invoice]
def __init__(
self, invoices: _Optional[_Iterable[_Union[Invoice, _Mapping]]] = ...
) -> None: ...
def __init__(self, invoices: _Optional[_Iterable[_Union[Invoice, _Mapping]]] = ...) -> None: ...

class SettleRequest(_message.Message):
__slots__ = ["payment_preimage"]
Expand Down Expand Up @@ -258,10 +248,8 @@ class PayStatusRequest(_message.Message):

class PayStatusResponse(_message.Message):
__slots__ = ["status"]

class PayStatus(_message.Message):
__slots__ = ["bolt11", "amount_msat", "destination", "attempts"]

class Attempt(_message.Message):
__slots__ = [
"strategy",
Expand All @@ -272,33 +260,23 @@ class PayStatusResponse(_message.Message):
"success",
"failure",
]

class AttemptState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
ATTEMPT_PENDING: _ClassVar[
PayStatusResponse.PayStatus.Attempt.AttemptState
]
ATTEMPT_COMPLETED: _ClassVar[
PayStatusResponse.PayStatus.Attempt.AttemptState
]
ATTEMPT_PENDING: _ClassVar[PayStatusResponse.PayStatus.Attempt.AttemptState]
ATTEMPT_COMPLETED: _ClassVar[PayStatusResponse.PayStatus.Attempt.AttemptState]
ATTEMPT_PENDING: PayStatusResponse.PayStatus.Attempt.AttemptState
ATTEMPT_COMPLETED: PayStatusResponse.PayStatus.Attempt.AttemptState

class Success(_message.Message):
__slots__ = ["id", "payment_preimage"]
ID_FIELD_NUMBER: _ClassVar[int]
PAYMENT_PREIMAGE_FIELD_NUMBER: _ClassVar[int]
id: int
payment_preimage: str
def __init__(
self,
id: _Optional[int] = ...,
payment_preimage: _Optional[str] = ...,
self, id: _Optional[int] = ..., payment_preimage: _Optional[str] = ...
) -> None: ...

class Failure(_message.Message):
__slots__ = ["message", "code", "data"]

class Data(_message.Message):
__slots__ = [
"id",
Expand Down Expand Up @@ -340,9 +318,7 @@ class PayStatusResponse(_message.Message):
message: _Optional[str] = ...,
code: _Optional[int] = ...,
data: _Optional[
_Union[
PayStatusResponse.PayStatus.Attempt.Failure.Data, _Mapping
]
_Union[PayStatusResponse.PayStatus.Attempt.Failure.Data, _Mapping]
] = ...,
) -> None: ...
STRATEGY_FIELD_NUMBER: _ClassVar[int]
Expand Down Expand Up @@ -382,9 +358,7 @@ class PayStatusResponse(_message.Message):
bolt11: str
amount_msat: int
destination: str
attempts: _containers.RepeatedCompositeFieldContainer[
PayStatusResponse.PayStatus.Attempt
]
attempts: _containers.RepeatedCompositeFieldContainer[PayStatusResponse.PayStatus.Attempt]
def __init__(
self,
bolt11: _Optional[str] = ...,
Expand All @@ -397,10 +371,7 @@ class PayStatusResponse(_message.Message):
STATUS_FIELD_NUMBER: _ClassVar[int]
status: _containers.RepeatedCompositeFieldContainer[PayStatusResponse.PayStatus]
def __init__(
self,
status: _Optional[
_Iterable[_Union[PayStatusResponse.PayStatus, _Mapping]]
] = ...,
self, status: _Optional[_Iterable[_Union[PayStatusResponse.PayStatus, _Mapping]]] = ...
) -> None: ...

class GetRouteRequest(_message.Message):
Expand Down Expand Up @@ -436,7 +407,6 @@ class GetRouteRequest(_message.Message):

class GetRouteResponse(_message.Message):
__slots__ = ["hops", "fees_msat"]

class Hop(_message.Message):
__slots__ = ["id", "channel", "direction", "amount_msat", "delay", "style"]
ID_FIELD_NUMBER: _ClassVar[int]
Expand Down
4 changes: 1 addition & 3 deletions tools/hold/protos/hold_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,7 @@ def add_HoldServicer_to_server(servicer, server):
response_serializer=hold__pb2.GetRouteResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
"hold.Hold", rpc_method_handlers
)
generic_handler = grpc.method_handlers_generic_handler("hold.Hold", rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


Expand Down
3 changes: 2 additions & 1 deletion tools/hold/route_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def get_private_channels(self, node: str) -> list[RouteHint]:
]
)
for chan in filter(
lambda chan: not chan["public"] and chan["source"] == node, chans
lambda chan: not chan["public"] and chan["source"] == node,
chans,
)
]
Loading

0 comments on commit e2bb8ae

Please sign in to comment.