Skip to content

Commit

Permalink
Merge pull request #151 from aspectron/pskb-sign-fix
Browse files Browse the repository at this point in the history
Pskb sign fix
  • Loading branch information
aspect authored Jan 25, 2025
2 parents b9db244 + 54a5e16 commit cb7dcf7
Showing 1 changed file with 42 additions and 29 deletions.
71 changes: 42 additions & 29 deletions wallet/core/src/account/pskb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ pub async fn bundle_from_pskt_generator(generator: PSKTGenerator) -> Result<Bund

Ok(bundle)
}

pub async fn pskb_signer_for_address(
bundle: &Bundle,
signer: Arc<PSKBSigner>,
Expand All @@ -163,64 +162,78 @@ pub async fn pskb_signer_for_address(
key_fingerprint: KeyFingerprint,
) -> Result<Bundle, Error> {
let mut signed_bundle = Bundle::new();
let reused_values = SigHashReusedValuesUnsync::new();

// If set, sign-for address is used for signing.
// Else, all addresses from inputs are.
let addresses: Vec<Address> = match sign_for_address {
Some(signer) => vec![signer.clone()],
None => bundle
// If sign_for_address is provided, we'll use it for all signatures
// Otherwise, collect addresses per PSKT
let addresses_per_pskt: Vec<Vec<Address>> = if sign_for_address.is_some() {
// Create a vec of single-address vecs
bundle.iter().map(|_| vec![sign_for_address.unwrap().clone()]).collect()
} else {
// Collect addresses for each PSKT separately
bundle
.iter()
.flat_map(|inner| {
inner.inputs
.map(|inner| {
inner
.inputs
.iter()
.filter_map(|input| input.utxo_entry.as_ref()) // Filter out None and get a reference to UtxoEntry if it exists
.filter_map(|input| input.utxo_entry.as_ref())
.filter_map(|utxo_entry| {
extract_script_pub_key_address(&utxo_entry.script_public_key.clone(), network_id.into()).ok()
})
.collect::<Vec<Address>>()
.collect()
})
.collect(),
.collect()
};

// Prepare the signer.
signer.ingest(addresses.as_ref())?;
// Prepare the signer with all unique addresses
let all_addresses: Vec<Address> = addresses_per_pskt.iter().flat_map(|addresses| addresses.iter().cloned()).collect();
signer.ingest(all_addresses.as_slice())?;

for pskt_inner in bundle.iter().cloned() {
// Process each PSKT in the bundle
for (pskt_idx, pskt_inner) in bundle.iter().cloned().enumerate() {
let pskt: PSKT<Signer> = PSKT::from(pskt_inner);
let current_addresses = &addresses_per_pskt[pskt_idx];

// Create new reused values for each PSKT
let reused_values = SigHashReusedValuesUnsync::new();

let sign = |signer_pskt: PSKT<Signer>| {
let sign = |signer_pskt: PSKT<Signer>| -> Result<PSKT<Signer>, Error> {
signer_pskt
.pass_signature_sync(|tx, sighash| -> Result<Vec<SignInputOk>, String> {
tx.tx
.inputs
.iter()
.enumerate()
.map(|(idx, _input)| {
let hash = calc_schnorr_signature_hash(&tx.as_verifiable(), idx, sighash[idx], &reused_values);
let msg = secp256k1::Message::from_digest_slice(hash.as_bytes().as_slice()).unwrap();

// When address represents a locked UTXO, no private key is available.
// Instead, use the account receive address' private key.
let address: &Address = match sign_for_address {
Some(address) => address,
None => addresses.get(idx).expect("Input indexed address"),
.map(|(input_idx, _input)| {
let hash = calc_schnorr_signature_hash(&tx.as_verifiable(), input_idx, sighash[input_idx], &reused_values);
let msg = secp256k1::Message::from_digest_slice(hash.as_bytes().as_slice()).map_err(|e| e.to_string())?;

// Get the appropriate address for this input
let address = if let Some(sign_addr) = sign_for_address {
sign_addr
} else {
current_addresses.get(input_idx).ok_or_else(|| format!("No address found for input {}", input_idx))?
};

let public_key = signer.public_key(address).expect("Public key for input indexed address");
let public_key = signer.public_key(address).map_err(|e| format!("Failed to get public key: {}", e))?;

let signature = signer.sign_schnorr(address, msg).map_err(|e| format!("Failed to sign: {}", e))?;

Ok(SignInputOk {
signature: Signature::Schnorr(signer.sign_schnorr(address, msg).unwrap()),
signature: Signature::Schnorr(signature),
pub_key: public_key,
key_source: Some(KeySource { key_fingerprint, derivation_path: derivation_path.clone() }),
})
})
.collect()
})
.unwrap()
.map_err(Error::from)
};
signed_bundle.add_pskt(sign(pskt.clone()));

let signed_pskt = sign(pskt)?;
signed_bundle.add_pskt(signed_pskt);
}

Ok(signed_bundle)
}

Expand Down

0 comments on commit cb7dcf7

Please sign in to comment.