Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shorten subdirectory IDs #386

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions payjoin-directory/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ use tracing::debug;
const DEFAULT_COLUMN: &str = "";
const PJ_V1_COLUMN: &str = "pjv1";

// TODO move to payjoin crate as pub?
// TODO impl From<HpkePublicKey> for ShortId
// TODO impl Display for ShortId (Base64)
// TODO impl TryFrom<&str> for ShortId (Base64)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ShortId(pub [u8; 8]);

impl ShortId {
pub fn column_key(&self, column: &str) -> Vec<u8> {
self.0.iter().chain(column.as_bytes()).copied().collect()
}
}

#[derive(Debug, Clone)]
pub(crate) struct DbPool {
client: Client,
Expand All @@ -19,23 +32,28 @@ impl DbPool {
Ok(Self { client, timeout })
}

pub async fn push_default(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
pub async fn push_default(&self, pubkey_id: &ShortId, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, DEFAULT_COLUMN, data).await
}

pub async fn peek_default(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
pub async fn peek_default(&self, pubkey_id: &ShortId) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await
}

pub async fn push_v1(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
pub async fn push_v1(&self, pubkey_id: &ShortId, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, PJ_V1_COLUMN, data).await
}

pub async fn peek_v1(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
pub async fn peek_v1(&self, pubkey_id: &ShortId) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await
}

async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec<u8>) -> RedisResult<()> {
async fn push(
&self,
pubkey_id: &ShortId,
channel_type: &str,
data: Vec<u8>,
) -> RedisResult<()> {
let mut conn = self.client.get_async_connection().await?;
let key = channel_name(pubkey_id, channel_type);
() = conn.set(&key, data.clone()).await?;
Expand All @@ -45,13 +63,13 @@ impl DbPool {

async fn peek_with_timeout(
&self,
pubkey_id: &str,
pubkey_id: &ShortId,
channel_type: &str,
) -> Option<RedisResult<Vec<u8>>> {
tokio::time::timeout(self.timeout, self.peek(pubkey_id, channel_type)).await.ok()
}

async fn peek(&self, pubkey_id: &str, channel_type: &str) -> RedisResult<Vec<u8>> {
async fn peek(&self, pubkey_id: &ShortId, channel_type: &str) -> RedisResult<Vec<u8>> {
let mut conn = self.client.get_async_connection().await?;
let key = channel_name(pubkey_id, channel_type);

Expand Down Expand Up @@ -99,6 +117,6 @@ impl DbPool {
}
}

fn channel_name(pubkey_id: &str, channel_type: &str) -> String {
format!("{}:{}", pubkey_id, channel_type)
fn channel_name(pubkey_id: &ShortId, channel_type: &str) -> Vec<u8> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper is no longer that helpful IMO and could be replaced by direct calls to column_key() in later commits.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed... do you think the TODO comments for ShortId are missing anything or could be simplified?

pubkey_id.column_key(channel_type)
}
22 changes: 17 additions & 5 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD;
use bitcoin::base64::Engine;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::body::{Body, Bytes, Incoming};
Expand All @@ -15,6 +17,8 @@ use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tracing::{debug, error, info, trace};

use crate::db::ShortId;

pub const DEFAULT_DIR_PORT: u16 = 8080;
pub const DEFAULT_DB_HOST: &str = "localhost:6379";
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
Expand Down Expand Up @@ -295,7 +299,7 @@ async fn post_fallback_v1(
};

let v2_compat_body = format!("{}\n{}", body_str, query);
let id = shorten_string(id);
let id = decode_short_id(id)?;
pool.push_default(&id, v2_compat_body.into())
.await
.map_err(|e| HandlerError::BadRequest(e.into()))?;
Expand All @@ -316,7 +320,7 @@ async fn put_payjoin_v1(
trace!("Put_payjoin_v1");
let ok_response = Response::builder().status(StatusCode::OK).body(empty())?;

let id = shorten_string(id);
let id = decode_short_id(id)?;
let req =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
if req.len() > MAX_BUFFER_SIZE {
Expand All @@ -337,7 +341,7 @@ async fn post_subdir(
let none_response = Response::builder().status(StatusCode::OK).body(empty())?;
trace!("post_subdir");

let id = shorten_string(id);
let id = decode_short_id(id)?;
let req =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
if req.len() > MAX_BUFFER_SIZE {
Expand All @@ -355,7 +359,7 @@ async fn get_subdir(
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("get_subdir");
let id = shorten_string(id);
let id = decode_short_id(id)?;
match pool.peek_default(&id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Expand Down Expand Up @@ -385,7 +389,15 @@ async fn get_ohttp_keys(
Ok(res)
}

fn shorten_string(input: &str) -> String { input.chars().take(8).collect() }
fn decode_short_id(input: &str) -> Result<ShortId, HandlerError> {
let decoded =
BASE64_URL_SAFE_NO_PAD.decode(input).map_err(|e| HandlerError::BadRequest(e.into()))?;

decoded[..8]
.try_into()
.map_err(|_| HandlerError::BadRequest(anyhow::anyhow!("Invalid subdirectory ID")))
.map(ShortId)
}

fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new().map_err(|never| match never {}).boxed()
Expand Down
20 changes: 13 additions & 7 deletions payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::time::{Duration, SystemTime};

use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD;
use bitcoin::base64::Engine;
use bitcoin::hashes::{sha256, Hash};
use bitcoin::psbt::Psbt;
use bitcoin::{Address, FeeRate, OutPoint, Script, TxOut};
use serde::de::Deserializer;
Expand Down Expand Up @@ -48,7 +49,8 @@ where
}

fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String {
BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes())
let hash = sha256::Hash::hash(&pubkey.to_compressed_bytes());
BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8])
}

/// A payjoin V2 receiver, allowing for polled requests to the
Expand Down Expand Up @@ -188,22 +190,26 @@ impl Receiver {
)
}

// The contents of the `&pj=` query parameter including the base64url-encoded public key receiver subdirectory.
// The contents of the `&pj=` query parameter.
// This identifies a session at the payjoin directory server.
pub fn pj_url(&self) -> Url {
let pubkey = &self.id();
let pubkey_base64 = BASE64_URL_SAFE_NO_PAD.encode(pubkey);
let id_base64 = BASE64_URL_SAFE_NO_PAD.encode(self.id());
let mut url = self.context.directory.clone();
{
let mut path_segments =
url.path_segments_mut().expect("Payjoin Directory URL cannot be a base");
path_segments.push(&pubkey_base64);
path_segments.push(&id_base64);
}
url
}

/// The per-session public key to use as an identifier
pub fn id(&self) -> [u8; 33] { self.context.s.public_key().to_compressed_bytes() }
/// The per-session identifier
pub fn id(&self) -> [u8; 8] {
let hash = sha256::Hash::hash(&self.context.s.public_key().to_compressed_bytes());
hash.as_byte_array()[..8]
.try_into()
.expect("truncating SHA256 to 8 bytes should always succeed")
}
}

/// The sender's original PSBT and optional parameters
Expand Down
8 changes: 6 additions & 2 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use std::str::FromStr;

#[cfg(feature = "v2")]
use bitcoin::base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
#[cfg(feature = "v2")]
use bitcoin::hashes::{sha256, Hash};
use bitcoin::psbt::Psbt;
use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight};
pub use error::{CreateRequestError, ResponseError, ValidationError};
Expand Down Expand Up @@ -394,8 +396,10 @@ impl V2GetContext {
) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> {
use crate::uri::UrlExt;
let mut url = self.endpoint.clone();
let subdir = BASE64_URL_SAFE_NO_PAD
.encode(self.hpke_ctx.reply_pair.public_key().to_compressed_bytes());

// TODO unify with receiver's fn subdir_path_from_pubkey
let hash = sha256::Hash::hash(&self.hpke_ctx.reply_pair.public_key().to_compressed_bytes());
let subdir = BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8]);
url.set_path(&subdir);
let body = encrypt_message_a(
Vec::new(),
Expand Down