Skip to content

Commit

Permalink
Handle v2 errors (v1 fails)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Nov 6, 2023
1 parent 7ddabb1 commit 9eb5cfc
Show file tree
Hide file tree
Showing 10 changed files with 286 additions and 122 deletions.
1 change: 1 addition & 0 deletions payjoin-cli/seen_inputs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"][["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"]
30 changes: 18 additions & 12 deletions payjoin-cli/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,20 @@ impl App {
&self,
client: &reqwest::blocking::Client,
enroll_context: &mut EnrollContext,
) -> Result<UncheckedProposal, reqwest::Error> {
) -> Result<UncheckedProposal> {
loop {
let (payjoin_get_body, context) = enroll_context.payjoin_get_body();
let (payjoin_get_body, context) = enroll_context
.payjoin_get_body()
.map_err(|e| anyhow!("Failed to create payjoin GET body: {}", e))?;
let ohttp_response =
client.post(&self.config.ohttp_proxy).body(payjoin_get_body).send()?;
let ohttp_response = ohttp_response.bytes()?;
let proposal =
enroll_context.parse_relay_response(ohttp_response.as_ref(), context).unwrap();
let proposal = enroll_context
.parse_relay_response(ohttp_response.as_ref(), context)
.map_err(|e| anyhow!("parse error {}", e))?;
log::debug!("got response");
match proposal {
Some(proposal) => return Ok(proposal),
Some(proposal) => break Ok(proposal),
None => std::thread::sleep(std::time::Duration::from_secs(5)),
}
}
Expand Down Expand Up @@ -229,17 +233,19 @@ impl App {
.build()
.with_context(|| "Failed to build reqwest http client")?;
log::debug!("Awaiting request");
let _enroll = client.post(&self.config.pj_endpoint).body(context.enroll_body()).send()?;
let (body, _) = context.enroll_body().unwrap();
let _enroll = client.post(&self.config.pj_endpoint).body(body).send()?;

log::debug!("Awaiting proposal");
let res = self.long_poll_get(&client, &mut context)?;
log::debug!("Received request");
let payjoin_proposal = self
.process_proposal(proposal)
.map_err(|e| anyhow!("Failed to process UncheckedProposal {}", e))?;
let payjoin_endpoint = format!("{}/{}/receive", self.config.pj_endpoint, pubkey_base64);
let (body, ohttp_ctx) =
payjoin_proposal.extract_v2_req(&self.config.ohttp_config, &payjoin_endpoint);
let payjoin_proposal =
self.process_proposal(res).map_err(|e| anyhow!("Failed to process proposal {}", e))?;
log::debug!("Posting payjoin back");
let receive_endpoint = format!("{}/{}", self.config.pj_endpoint, context.payjoin_subdir());
let (body, ohttp_ctx) = payjoin_proposal
.extract_v2_req(&self.config.ohttp_config, &receive_endpoint)
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
let res = client
.post(&self.config.ohttp_proxy)
.body(body)
Expand Down
2 changes: 1 addition & 1 deletion payjoin-relay/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
hyper = { version = "0.14", features = ["full"] }
anyhow = "1.0.71"
payjoin = { path = "../payjoin", features = ["base64"] }
payjoin = { path = "../payjoin", features = ["base64", "v2"] }
# ohttp = "0.4.0"
ohttp = { path = "../../ohttp/ohttp" }
bhttp = { version = "0.4.0", features = ["http"] }
Expand Down
84 changes: 52 additions & 32 deletions payjoin-relay/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::env;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;

use anyhow::Result;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, StatusCode, Uri};
use payjoin::{base64, bitcoin};
use tracing::{debug, info, trace};
use tracing::{debug, error, info, trace};
use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::EnvFilter;

Expand Down Expand Up @@ -72,7 +71,7 @@ fn init_ohttp() -> Result<ohttp::Server> {
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?;
let encoded_config = server_config.encode()?;
let b64_config = base64::encode_config(
encoded_config,
&encoded_config,
base64::Config::new(base64::CharacterSet::UrlSafe, false),
);
info!("ohttp server config base64 UrlSafe: {:?}", b64_config);
Expand Down Expand Up @@ -112,33 +111,41 @@ async fn handle_ohttp(
) -> Result<Response<Body>, HandlerError> {
// decapsulate
let ohttp_body =
hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?;

let (bhttp_req, res_ctx) = ohttp.decapsulate(&ohttp_body).unwrap();
let (bhttp_req, res_ctx) =
ohttp.decapsulate(&ohttp_body).map_err(|e| HandlerError::BadRequest(e.into()))?;
let mut cursor = std::io::Cursor::new(bhttp_req);
let req = bhttp::Message::read_bhttp(&mut cursor).unwrap();
let req =
bhttp::Message::read_bhttp(&mut cursor).map_err(|e| HandlerError::BadRequest(e.into()))?;
let uri = Uri::builder()
.scheme(req.control().scheme().unwrap())
.authority(req.control().authority().unwrap())
.path_and_query(req.control().path().unwrap())
.build()
.unwrap();
.scheme(req.control().scheme().unwrap_or_default())
.authority(req.control().authority().unwrap_or_default())
.path_and_query(req.control().path().unwrap_or_default())
.build()?;
let body = req.content().to_vec();
let mut http_req = Request::builder().uri(uri).method(req.control().method().unwrap());
let mut http_req =
Request::builder().uri(uri).method(req.control().method().unwrap_or_default());
for header in req.header().fields() {
http_req = http_req.header(header.name(), header.value())
}
let request = http_req.body(Body::from(body)).unwrap();
let request = http_req.body(Body::from(body))?;

let response = handle_v2(pool, request).await?;

let (parts, body) = response.into_parts();
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
let full_body = hyper::body::to_bytes(body).await.unwrap();
let full_body = hyper::body::to_bytes(body)
.await
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
bhttp_res.write_content(&full_body);
let mut bhttp_bytes = Vec::new();
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).unwrap();
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).unwrap();
bhttp_res
.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes)
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
let ohttp_res = res_ctx
.encapsulate(&bhttp_bytes)
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
Ok(Response::new(Body::from(ohttp_res)))
}

Expand All @@ -159,16 +166,22 @@ async fn handle_v2(pool: DbPool, req: Request<Body>) -> Result<Response<Body>, H

enum HandlerError {
PayloadTooLarge,
InternalServerError,
BadRequest,
InternalServerError(Box<dyn std::error::Error>),
BadRequest(Box<dyn std::error::Error>),
}

impl HandlerError {
fn to_response(&self) -> Response<Body> {
let status = match self {
HandlerError::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
HandlerError::BadRequest => StatusCode::BAD_REQUEST,
_ => StatusCode::INTERNAL_SERVER_ERROR,
Self::InternalServerError(e) => {
error!("Internal server error: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
}
Self::BadRequest(e) => {
error!("Bad request: {}", e);
StatusCode::BAD_REQUEST
}
};

let mut res = Response::new(Body::empty());
Expand All @@ -178,17 +191,19 @@ impl HandlerError {
}

impl From<hyper::http::Error> for HandlerError {
fn from(_: hyper::http::Error) -> Self { HandlerError::InternalServerError }
fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) }
}

async fn post_enroll(body: Body) -> Result<Response<Body>, HandlerError> {
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
let bytes = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::BadRequest)?;
let base64_id = String::from_utf8(bytes.to_vec()).map_err(|_| HandlerError::BadRequest)?;
let pubkey_bytes: Vec<u8> =
base64::decode_config(base64_id, b64_config).map_err(|_| HandlerError::BadRequest)?;
let bytes =
hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?;
let base64_id =
String::from_utf8(bytes.to_vec()).map_err(|e| HandlerError::BadRequest(e.into()))?;
let pubkey_bytes: Vec<u8> = base64::decode_config(base64_id, b64_config)
.map_err(|e| HandlerError::BadRequest(e.into()))?;
let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
.map_err(|_| HandlerError::BadRequest)?;
.map_err(|e| HandlerError::BadRequest(e.into()))?;
tracing::info!("Enrolled valid pubkey: {:?}", pubkey);
Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?)
}
Expand Down Expand Up @@ -223,20 +238,23 @@ async fn post_fallback(
) -> Result<Response<Body>, HandlerError> {
tracing::trace!("Post fallback");
let id = shorten_string(id);
let req = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
let req = hyper::body::to_bytes(body)
.await
.map_err(|e| HandlerError::InternalServerError(e.into()))?;

if req.len() > MAX_BUFFER_SIZE {
return Err(HandlerError::PayloadTooLarge);
}

match pool.push_req(&id, req.into()).await {
Ok(_) => (),
Err(_) => return Err(HandlerError::BadRequest),
Err(e) => return Err(HandlerError::BadRequest(e.into())),
};

match pool.peek_res(&id).await {
Some(result) => match result {
Ok(buffered_res) => Ok(Response::new(Body::from(buffered_res))),
Err(_) => Err(HandlerError::BadRequest),
Err(e) => Err(HandlerError::BadRequest(e.into())),
},
None => Ok(none_response),
}
Expand All @@ -247,19 +265,21 @@ async fn get_fallback(id: &str, pool: DbPool) -> Result<Response<Body>, HandlerE
match pool.peek_req(&id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(Body::from(buffered_req))),
Err(_) => Err(HandlerError::BadRequest),
Err(e) => Err(HandlerError::BadRequest(e.into())),
},
None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?),
}
}

async fn post_payjoin(id: &str, body: Body, pool: DbPool) -> Result<Response<Body>, HandlerError> {
let id = shorten_string(id);
let res = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
let res = hyper::body::to_bytes(body)
.await
.map_err(|e| HandlerError::InternalServerError(e.into()))?;

match pool.push_res(&id, res.into()).await {
Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?),
Err(_) => Err(HandlerError::BadRequest),
Err(e) => Err(HandlerError::BadRequest(e.into())),
}
}

Expand Down
31 changes: 18 additions & 13 deletions payjoin-relay/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ mod integration {
// Enroll with relay
let mut enroll_ctx =
EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL);
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body();
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().expect("Failed to enroll");
let _ohttp_response =
http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request");
log::debug!("Enrolled receiver");
Expand Down Expand Up @@ -150,7 +150,8 @@ mod integration {
// **********************
// Inside the Receiver:
// GET fallback_psbt
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body();
let (payjoin_get_body, ohttp_req_ctx) =
enroll_ctx.payjoin_get_body().expect("Failed to get fallback");
let ohttp_response = http
.post(RELAY_URL)
.body(payjoin_get_body)
Expand All @@ -162,18 +163,20 @@ mod integration {
);
let proposal = enroll_ctx.parse_relay_response(reader, ohttp_req_ctx).unwrap().unwrap();
let payjoin_proposal = handle_proposal(proposal, receiver);

let (body, _ohttp_ctx) = payjoin_proposal.extract_v2_req(
&ohttp_config_base64,
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
);
let (body, _ohttp_ctx) = payjoin_proposal
.extract_v2_req(
&ohttp_config_base64,
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
)
.expect("Failed to extract v2 req");
let _ohttp_response =
http.post(RELAY_URL).body(body).send().await.expect("Failed to post payjoin_psbt");

// **********************
// Inside the Sender:
// Sender checks, signs, finalizes, extracts, and broadcasts
log::info!("replay POST fallback psbt for payjoin_psbt response");
log::info!("Req body {:#?}", &req.body);
let response = http
.post(req.url.as_str())
.body(req.body.clone())
Expand Down Expand Up @@ -256,7 +259,7 @@ mod integration {
// Enroll with relay
let mut enroll_ctx =
EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL);
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body();
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().unwrap();
let enroll =
http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request");

Expand Down Expand Up @@ -331,7 +334,7 @@ mod integration {
.expect("Failed to build reqwest http client");

let proposal = loop {
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body();
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body().unwrap();
let enc_response = http
.post(RELAY_URL)
.body(payjoin_get_body)
Expand All @@ -355,10 +358,12 @@ mod integration {
debug!("handle relay response");
let response = handle_proposal(proposal, receiver);
debug!("Post payjoin_psbt to relay");
let (body, _ohttp_ctx) = response.extract_v2_req(
&ohttp_config_base64,
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
);
let (body, _ohttp_ctx) = response
.extract_v2_req(
&ohttp_config_base64,
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
)
.unwrap();
// Respond with payjoin psbt within the time window the sender is willing to wait
let response = http.post(RELAY_URL).body(body).send().await;
debug!("POSTed with payjoin_psbt");
Expand Down
16 changes: 16 additions & 0 deletions payjoin/src/receive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@ pub enum Error {
BadRequest(RequestError),
// To be returned as HTTP 500
Server(Box<dyn error::Error>),
// V2 d/encapsulation failed
#[cfg(feature = "v2")]
V2(crate::v2::Error),
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self {
Self::BadRequest(e) => e.fmt(f),
Self::Server(e) => write!(f, "Internal Server Error: {}", e),
#[cfg(feature = "v2")]
Self::V2(e) => e.fmt(f),
}
}
}
Expand All @@ -23,6 +28,8 @@ impl error::Error for Error {
match &self {
Self::BadRequest(_) => None,
Self::Server(e) => Some(e.as_ref()),
#[cfg(feature = "v2")]
Self::V2(e) => Some(e),
}
}
}
Expand All @@ -31,6 +38,15 @@ impl From<RequestError> for Error {
fn from(e: RequestError) -> Self { Error::BadRequest(e) }
}

impl From<InternalRequestError> for Error {
fn from(e: InternalRequestError) -> Self { Error::BadRequest(e.into()) }
}

impl From<crate::v2::Error> for Error {
#[cfg(feature = "v2")]
fn from(e: crate::v2::Error) -> Self { Error::V2(e) }
}

/// Error that may occur when the request from sender is malformed.
///
/// This is currently opaque type because we aren't sure which variants will stay.
Expand Down
Loading

0 comments on commit 9eb5cfc

Please sign in to comment.