Skip to content

Commit

Permalink
Handle v2 errors (v1 fails)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Oct 24, 2023
1 parent a2b1f46 commit 38f7dc1
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 87 deletions.
16 changes: 10 additions & 6 deletions payjoin-cli/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@ impl App {
&self,
client: &reqwest::blocking::Client,
enroll_context: &mut EnrollContext,
) -> Result<UncheckedProposal, reqwest::Error> {
) -> Result<UncheckedProposal> {
loop {
let (enroll_body, context) = enroll_context.enroll_body();
let (enroll_body, context) = enroll_context.enroll_body()?;
let ohttp_response = client.post(&self.config.ohttp_proxy).body(enroll_body).send()?;
let ohttp_response = ohttp_response.bytes()?;
let proposal =
enroll_context.parse_relay_response(ohttp_response.as_ref(), context).unwrap();
let proposal = enroll_context
.parse_relay_response(ohttp_response.as_ref(), context)
.map_err(|e| anyhow!("parse error {}", e))?;
match proposal {
Some(proposal) => return Ok(proposal),
None => std::thread::sleep(std::time::Duration::from_secs(5)),
Expand Down Expand Up @@ -235,8 +236,11 @@ impl App {
.map_err(|e| anyhow!("Failed to parse into UncheckedProposal {}", e))?;

let receive_endpoint = format!("{}/{}", self.config.pj_endpoint, context.receive_subdir());
let (body, ohttp_ctx) =
payjoin_proposal.extract_v2_req(&self.config.ohttp_config, &receive_endpoint);
let ohttp_config =
bitcoin::base64::decode_config(&self.config.ohttp_config, base64::URL_SAFE)?;
let (body, ohttp_ctx) = payjoin_proposal
.extract_v2_req(&ohttp_config, &receive_endpoint)
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
let res = client
.post(&self.config.ohttp_proxy)
.body(body)
Expand Down
52 changes: 31 additions & 21 deletions payjoin-relay/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ fn init_ohttp() -> Result<ohttp::Server> {
&[SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)];

// create or read from file
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap();
let encoded_config = server_config.encode().unwrap();
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?;
let encoded_config = server_config.encode()?;
let b64_config = payjoin::bitcoin::base64::encode_config(
&encoded_config,
payjoin::bitcoin::base64::Config::new(
Expand All @@ -83,43 +83,53 @@ fn init_ohttp() -> Result<ohttp::Server> {

async fn handle_ohttp(
enc_request: Bytes,
mut target: Router,
target: Router,
ohttp: Arc<ohttp::Server>,
) -> (StatusCode, Vec<u8>) {
match handle_ohttp_inner(enc_request, target, ohttp).await {
Ok(res) => res,
Err(e) => {
tracing::error!("ohttp error: {:?}", e);
(StatusCode::INTERNAL_SERVER_ERROR, vec![])
}
}
}

async fn handle_ohttp_inner(
enc_request: Bytes,
mut target: Router,
ohttp: Arc<ohttp::Server>,
) -> Result<(StatusCode, Vec<u8>)> {
use axum::body::Body;
use http::Uri;
use tower_service::Service;

// decapsulate
let (bhttp_req, res_ctx) = ohttp.decapsulate(&enc_request).unwrap();
let (bhttp_req, res_ctx) = ohttp.decapsulate(&enc_request)?;
let mut cursor = std::io::Cursor::new(bhttp_req);
let req = bhttp::Message::read_bhttp(&mut cursor).unwrap();
// let parsed_request: httparse::Request = httparse::Request::new(&mut vec![]).parse(cursor).unwrap();
// // handle request
// Request::new
let req = bhttp::Message::read_bhttp(&mut cursor)?;
let uri = Uri::builder()
.scheme(req.control().scheme().unwrap())
.authority(req.control().authority().unwrap())
.path_and_query(req.control().path().unwrap())
.build()
.unwrap();
.scheme(req.control().scheme().unwrap_or_default())
.authority(req.control().authority().unwrap_or_default())
.path_and_query(req.control().path().unwrap_or_default())
.build()?;
let body = req.content().to_vec();
let mut request = Request::builder().uri(uri).method(req.control().method().unwrap());
let mut request =
Request::builder().uri(uri).method(req.control().method().unwrap_or_default());
for header in req.header().fields() {
request = request.header(header.name(), header.value())
}
let request = request.body(Body::from(body)).unwrap();
let request = request.body(Body::from(body))?;

let response = target.call(request).await.unwrap();
let response = target.call(request).await?;

let (parts, body) = response.into_parts();
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
let full_body = hyper::body::to_bytes(body).await.unwrap();
let full_body = hyper::body::to_bytes(body).await?;
bhttp_res.write_content(&full_body);
let mut bhttp_bytes = Vec::new();
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).unwrap();
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).unwrap();
(StatusCode::OK, ohttp_res)
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes)?;
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes)?;
Ok((StatusCode::OK, ohttp_res))
}

fn ohttp_config(server: &ohttp::Server) -> Result<String> {
Expand Down
16 changes: 16 additions & 0 deletions payjoin/src/receive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@ pub enum Error {
BadRequest(RequestError),
// To be returned as HTTP 500
Server(Box<dyn error::Error>),
// V2 d/encapsulation failed
#[cfg(feature = "v2")]
V2(crate::v2::Error),
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self {
Self::BadRequest(e) => e.fmt(f),
Self::Server(e) => write!(f, "Internal Server Error: {}", e),
#[cfg(feature = "v2")]
Self::V2(e) => e.fmt(f),
}
}
}
Expand All @@ -23,6 +28,8 @@ impl error::Error for Error {
match &self {
Self::BadRequest(_) => None,
Self::Server(e) => Some(e.as_ref()),
#[cfg(feature = "v2")]
Self::V2(e) => Some(e),
}
}
}
Expand All @@ -31,6 +38,15 @@ impl From<RequestError> for Error {
fn from(e: RequestError) -> Self { Error::BadRequest(e) }
}

impl From<InternalRequestError> for Error {
fn from(e: InternalRequestError) -> Self { Error::BadRequest(e.into()) }
}

impl From<crate::v2::Error> for Error {
#[cfg(feature = "v2")]
fn from(e: crate::v2::Error) -> Self { Error::V2(e) }
}

/// Error that may occur when the request from sender is malformed.
///
/// This is currently opaque type because we aren't sure which variants will stay.
Expand Down
36 changes: 20 additions & 16 deletions payjoin/src/receive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ use url::Url;
use crate::input_type::InputType;
use crate::optional_parameters::Params;
use crate::psbt::PsbtExt;
use crate::v2;

pub trait Headers {
fn get_header(&self, key: &str) -> Option<&str>;
Expand Down Expand Up @@ -328,32 +329,30 @@ impl EnrollContext {
format!("{}/{}", self.subdirectory(), crate::v2::RECEIVE)
}

pub fn enroll_body(&mut self) -> (Vec<u8>, ohttp::ClientResponse) {
pub fn enroll_body(&mut self) -> Result<(Vec<u8>, ohttp::ClientResponse), crate::v2::Error> {
let receive_endpoint = self.receive_subdir();
log::debug!("{}{}", self.relay_url.as_str(), receive_endpoint);
let (ohttp_req, ctx) = crate::v2::ohttp_encapsulate(
crate::v2::ohttp_encapsulate(
&self.ohttp_config,
"GET",
format!("{}{}", self.relay_url.as_str(), receive_endpoint).as_str(),
None,
);

(ohttp_req, ctx)
)
}

pub fn parse_relay_response(
&self,
mut body: impl std::io::Read,
context: ohttp::ClientResponse,
) -> Result<Option<UncheckedProposal>, RequestError> {
) -> Result<Option<UncheckedProposal>, Error> {
let mut buf = Vec::new();
let _ = body.read_to_end(&mut buf);
let response = crate::v2::ohttp_decapsulate(context, &buf);
let response = crate::v2::ohttp_decapsulate(context, &buf)?;
if response.is_empty() {
log::debug!("response is empty");
return Ok(None);
}
let (proposal, e) = crate::v2::decrypt_message_a(&response, self.s.secret_key());
let (proposal, e) = crate::v2::decrypt_message_a(&response, self.s.secret_key())?;
let mut proposal = serde_json::from_slice::<UncheckedProposal>(&proposal)
.map_err(InternalRequestError::Json)?;
proposal.psbt = proposal.psbt.validate().map_err(InternalRequestError::InconsistentPsbt)?;
Expand Down Expand Up @@ -935,21 +934,26 @@ impl PayjoinProposal {
#[cfg(feature = "v2")]
pub fn extract_v2_req(
&self,
ohttp_config: &str,
ohttp_config: &Vec<u8>,
receive_endpoint: &str,
) -> (Vec<u8>, ohttp::ClientResponse) {
) -> Result<(Vec<u8>, ohttp::ClientResponse), Error> {
let e = self.v2_context.unwrap(); // TODO make v2 only
let mut payjoin_bytes = self.payjoin_psbt.serialize();
let body = crate::v2::encrypt_message_b(&mut payjoin_bytes, e);
let ohttp_config = bitcoin::base64::decode_config(ohttp_config, base64::URL_SAFE).unwrap();
let body = crate::v2::encrypt_message_b(&mut payjoin_bytes, e)?;
dbg!(receive_endpoint);
crate::v2::ohttp_encapsulate(&ohttp_config, "POST", receive_endpoint, Some(&body))
let (req, ctx) =
crate::v2::ohttp_encapsulate(&ohttp_config, "POST", receive_endpoint, Some(&body))?;
Ok((req, ctx))
}

#[cfg(feature = "v2")]
pub fn deserialize_res(&self, res: Vec<u8>, ohttp_context: ohttp::ClientResponse) -> Vec<u8> {
pub fn deserialize_res(
&self,
res: Vec<u8>,
ohttp_context: ohttp::ClientResponse,
) -> Result<Vec<u8>, Error> {
// display success or failure
crate::v2::ohttp_decapsulate(ohttp_context, &res)
let res = crate::v2::ohttp_decapsulate(ohttp_context, &res)?;
Ok(res)
}
}

Expand Down
44 changes: 37 additions & 7 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,22 @@ pub struct ValidationError {

#[derive(Debug)]
pub(crate) enum InternalValidationError {
Psbt(bitcoin::psbt::PsbtParseError),
PsbtParse(bitcoin::psbt::PsbtParseError),
Io(std::io::Error),
InvalidInputType(InputTypeError),
InvalidProposedInput(crate::psbt::PrevTxOutError),
VersionsDontMatch { proposed: i32, original: i32 },
LockTimesDontMatch { proposed: LockTime, original: LockTime },
SenderTxinSequenceChanged { proposed: Sequence, original: Sequence },
VersionsDontMatch {
proposed: i32,
original: i32,
},
LockTimesDontMatch {
proposed: LockTime,
original: LockTime,
},
SenderTxinSequenceChanged {
proposed: Sequence,
original: Sequence,
},
SenderTxinContainsNonWitnessUtxo,
SenderTxinContainsWitnessUtxo,
SenderTxinContainsFinalScriptSig,
Expand All @@ -32,7 +41,10 @@ pub(crate) enum InternalValidationError {
ReceiverTxinNotFinalized,
ReceiverTxinMissingUtxoInfo,
MixedSequence,
MixedInputTypes { proposed: InputType, original: InputType },
MixedInputTypes {
proposed: InputType,
original: InputType,
},
MissingOrShuffledInputs,
TxOutContainsKeyPaths,
FeeContributionExceedsMaximum,
Expand All @@ -44,6 +56,10 @@ pub(crate) enum InternalValidationError {
PayeeTookContributedFee,
FeeContributionPaysOutputSizeIncrease,
FeeRateBelowMinimum,
#[cfg(feature = "v2")]
V2(crate::v2::Error),
#[cfg(feature = "v2")]
Psbt(bitcoin::psbt::Error),
}

impl From<InternalValidationError> for ValidationError {
Expand All @@ -58,7 +74,7 @@ impl fmt::Display for ValidationError {
use InternalValidationError::*;

match &self.internal {
Psbt(e) => write!(f, "couldn't decode PSBT: {}", e),
PsbtParse(e) => write!(f, "couldn't decode PSBT: {}", e),
Io(e) => write!(f, "couldn't read PSBT: {}", e),
InvalidInputType(e) => write!(f, "invalid transaction input type: {}", e),
InvalidProposedInput(e) => write!(f, "invalid proposed transaction input: {}", e),
Expand Down Expand Up @@ -86,6 +102,10 @@ impl fmt::Display for ValidationError {
PayeeTookContributedFee => write!(f, "payee tried to take fee contribution for himself"),
FeeContributionPaysOutputSizeIncrease => write!(f, "fee contribution pays for additional outputs"),
FeeRateBelowMinimum => write!(f, "the fee rate of proposed transaction is below minimum"),
#[cfg(feature = "v2")]
V2(e) => write!(f, "v2 error: {}", e),
#[cfg(feature = "v2")]
Psbt(e) => write!(f, "psbt error: {}", e),
}
}
}
Expand All @@ -95,7 +115,7 @@ impl std::error::Error for ValidationError {
use InternalValidationError::*;

match &self.internal {
Psbt(error) => Some(error),
PsbtParse(error) => Some(error),
Io(error) => Some(error),
InvalidInputType(error) => Some(error),
InvalidProposedInput(error) => Some(error),
Expand Down Expand Up @@ -123,6 +143,10 @@ impl std::error::Error for ValidationError {
PayeeTookContributedFee => None,
FeeContributionPaysOutputSizeIncrease => None,
FeeRateBelowMinimum => None,
#[cfg(feature = "v2")]
V2(error) => Some(error),
#[cfg(feature = "v2")]
Psbt(error) => Some(error),
}
}
}
Expand Down Expand Up @@ -152,6 +176,8 @@ pub(crate) enum InternalCreateRequestError {
UriDoesNotSupportPayjoin,
PrevTxOut(crate::psbt::PrevTxOutError),
InputType(crate::input_type::InputTypeError),
#[cfg(feature = "v2")]
V2(crate::v2::Error),
}

impl fmt::Display for CreateRequestError {
Expand All @@ -174,6 +200,8 @@ impl fmt::Display for CreateRequestError {
UriDoesNotSupportPayjoin => write!(f, "the URI does not support payjoin"),
PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e),
InputType(e) => write!(f, "invalid input type: {}", e),
#[cfg(feature = "v2")]
V2(e) => write!(f, "v2 error: {}", e),
}
}
}
Expand All @@ -198,6 +226,8 @@ impl std::error::Error for CreateRequestError {
UriDoesNotSupportPayjoin => None,
PrevTxOut(error) => Some(error),
InputType(error) => Some(error),
#[cfg(feature = "v2")]
V2(error) => Some(error),
}
}
}
Expand Down
Loading

0 comments on commit 38f7dc1

Please sign in to comment.