Skip to content

Commit

Permalink
btcec/schnorr/musig2: update to musig 1.0.0
Browse files Browse the repository at this point in the history
The major change in musig 1.0.0 is that plain public keys are used as
input to key aggregation.
  • Loading branch information
Roasbeef committed Oct 21, 2022
1 parent a34e777 commit 1567f20
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 55 deletions.
8 changes: 2 additions & 6 deletions btcec/schnorr/musig2/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,7 @@ func NewContext(signingKey *btcec.PrivateKey, shouldSort bool,
option(opts)
}

pubKey, err := schnorr.ParsePubKey(
schnorr.SerializePubKey(signingKey.PubKey()),
)
if err != nil {
return nil, err
}
pubKey := signingKey.PubKey()

ctx := &Context{
signingKey: signingKey,
Expand Down Expand Up @@ -243,6 +238,7 @@ func NewContext(signingKey *btcec.PrivateKey, shouldSort bool,
// the nonce now to pass in to the session once all the callers
// are known.
if opts.earlyNonce {
var err error
ctx.sessionNonce, err = GenNonces()
if err != nil {
return nil, err
Expand Down
31 changes: 14 additions & 17 deletions btcec/schnorr/musig2/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ type sortableKeys []*btcec.PublicKey
// with index j.
func (s sortableKeys) Less(i, j int) bool {
// TODO(roasbeef): more efficient way to compare...
keyIBytes := schnorr.SerializePubKey(s[i])
keyJBytes := schnorr.SerializePubKey(s[j])
keyIBytes := s[i].SerializeCompressed()
keyJBytes := s[j].SerializeCompressed()

return bytes.Compare(keyIBytes, keyJBytes) == -1
}
Expand All @@ -56,9 +56,9 @@ func (s sortableKeys) Len() int {
return len(s)
}

// sortKeys takes a set of schnorr public keys and returns a new slice that is
// a copy of the keys sorted in lexicographical order bytes on the x-only
// pubkey serialization.
// sortKeys takes a set of public keys and returns a new slice that is a copy
// of the keys sorted in lexicographical order bytes on the x-only pubkey
// serialization.
func sortKeys(keys []*btcec.PublicKey) []*btcec.PublicKey {
keySet := sortableKeys(keys)
if sort.IsSorted(keySet) {
Expand All @@ -72,36 +72,33 @@ func sortKeys(keys []*btcec.PublicKey) []*btcec.PublicKey {
// keyHashFingerprint computes the tagged hash of the series of (sorted) public
// keys passed as input. This is used to compute the aggregation coefficient
// for each key. The final computation is:
// * H(tag=KeyAgg list, pk1 || pk2..)
// - H(tag=KeyAgg list, pk1 || pk2..)
func keyHashFingerprint(keys []*btcec.PublicKey, sort bool) []byte {
if sort {
keys = sortKeys(keys)
}

// We'll create a single buffer and slice into that so the bytes buffer
// doesn't continually need to grow the underlying buffer.
keyAggBuf := make([]byte, 32*len(keys))
keyAggBuf := make([]byte, 33*len(keys))
keyBytes := bytes.NewBuffer(keyAggBuf[0:0])
for _, key := range keys {
keyBytes.Write(schnorr.SerializePubKey(key))
keyBytes.Write(key.SerializeCompressed())
}

h := chainhash.TaggedHash(KeyAggTagList, keyBytes.Bytes())
return h[:]
}

// keyBytesEqual returns true if two keys are the same from the PoV of BIP
// 340's 32-byte x-only public keys.
// keyBytesEqual returns true if two keys are the same based on the compressed
// serialization of each key.
func keyBytesEqual(a, b *btcec.PublicKey) bool {
return bytes.Equal(
schnorr.SerializePubKey(a),
schnorr.SerializePubKey(b),
)
return bytes.Equal(a.SerializeCompressed(), b.SerializeCompressed())
}

// aggregationCoefficient computes the key aggregation coefficient for the
// specified target key. The coefficient is computed as:
// * H(tag=KeyAgg coefficient, keyHashFingerprint(pks) || pk)
// - H(tag=KeyAgg coefficient, keyHashFingerprint(pks) || pk)
func aggregationCoefficient(keySet []*btcec.PublicKey,
targetKey *btcec.PublicKey, keysHash []byte,
secondKeyIdx int) *btcec.ModNScalar {
Expand All @@ -116,9 +113,9 @@ func aggregationCoefficient(keySet []*btcec.PublicKey,
// Otherwise, we'll compute the full finger print hash for this given
// key and then use that to compute the coefficient tagged hash:
// * H(tag=KeyAgg coefficient, keyHashFingerprint(pks, pk) || pk)
var coefficientBytes [64]byte
var coefficientBytes [65]byte
copy(coefficientBytes[:], keysHash[:])
copy(coefficientBytes[32:], schnorr.SerializePubKey(targetKey))
copy(coefficientBytes[32:], targetKey.SerializeCompressed())

muHash := chainhash.TaggedHash(KeyAggTagCoeff, coefficientBytes[:])

Expand Down
7 changes: 1 addition & 6 deletions btcec/schnorr/musig2/musig2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,12 +993,7 @@ func testMultiPartySign(t *testing.T, taprootTweak []byte,
t.Fatalf("unable to gen priv key: %v", err)
}

pubKey, err := schnorr.ParsePubKey(
schnorr.SerializePubKey(privKey.PubKey()),
)
if err != nil {
t.Fatalf("unable to gen key: %v", err)
}
pubKey := privKey.PubKey()

signerKeys[i] = privKey
signSet[i] = pubKey
Expand Down
2 changes: 1 addition & 1 deletion btcec/schnorr/musig2/nonces.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ func withCustomOptions(customOpts nonceGenOpts) NonceGenOption {
o.randReader = customOpts.randReader
o.secretKey = customOpts.secretKey
o.combinedKey = customOpts.combinedKey
o.msg = customOpts.msg
o.auxInput = customOpts.auxInput
o.msg = customOpts.msg
}
}

Expand Down
36 changes: 11 additions & 25 deletions btcec/schnorr/musig2/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,31 +299,22 @@ func Sign(secNonce [SecNonceSize]byte, privKey *btcec.PrivateKey,
}

pubKey := privKey.PubKey()
pubKeyYIsOdd := func() bool {
pubKeyBytes := pubKey.SerializeCompressed()
return pubKeyBytes[0] == secp.PubKeyFormatCompressedOdd
}()
combinedKeyYIsOdd := func() bool {
combinedKeyBytes := combinedKey.FinalKey.SerializeCompressed()
return combinedKeyBytes[0] == secp.PubKeyFormatCompressedOdd
}()

// Next we'll compute our two parity factors for Q the combined public
// key, and P, the public key we're signing with. If the keys are odd,
// then we'll negate them.
// Next we'll compute the two parity factors for Q, the combined key.
// If the key is odd, then we'll negate it.
parityCombinedKey := new(btcec.ModNScalar).SetInt(1)
paritySignKey := new(btcec.ModNScalar).SetInt(1)
if combinedKeyYIsOdd {
parityCombinedKey.Negate()
}
if pubKeyYIsOdd {
paritySignKey.Negate()
}

// Before we sign below, we'll multiply by our various parity factors
// to ensure that the signing key is properly negated (if necessary):
// * d = gv⋅gaccv⋅gp⋅d'
privKeyScalar.Mul(parityCombinedKey).Mul(paritySignKey).Mul(parityAcc)
// * d = g⋅gacc⋅d'
privKeyScalar.Mul(parityCombinedKey).Mul(parityAcc)

// Next we'll create the challenge hash that commits to the combined
// nonce, combined public key and also the message:
Expand Down Expand Up @@ -372,7 +363,7 @@ func (p *PartialSignature) Verify(pubNonce [PubNonceSize]byte,
combinedNonce [PubNonceSize]byte, keySet []*btcec.PublicKey,
signingKey *btcec.PublicKey, msg [32]byte, signOpts ...SignOption) bool {

pubKey := schnorr.SerializePubKey(signingKey)
pubKey := signingKey.SerializeCompressed()

return verifyPartialSig(
p, pubNonce, combinedNonce, keySet, pubKey, msg, signOpts...,
Expand All @@ -398,7 +389,6 @@ func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,
// Next we'll parse out the two public nonces into something we can
// use.
//

// Compute the hash of all the keys here as we'll need it do aggregate
// the keys and also at the final step of verification.
keysHash := keyHashFingerprint(keySet, opts.sortKeys)
Expand Down Expand Up @@ -458,7 +448,6 @@ func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,
// With our nonce blinding value, we'll now combine both the public
// nonces, using the blinding factor to tweak the second nonce:
// * R = R_1 + b*R_2

var nonce btcec.JacobianPoint
btcec.ScalarMultNonConst(&nonceBlinder, &r2J, &r2J)
btcec.AddNonConst(&r1J, &r2J, &nonce)
Expand Down Expand Up @@ -516,7 +505,7 @@ func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,
var e btcec.ModNScalar
e.SetByteSlice(challengeBytes[:])

signingKey, err := schnorr.ParsePubKey(pubKey)
signingKey, err := btcec.ParsePubKey(pubKey)
if err != nil {
return err
}
Expand All @@ -527,27 +516,24 @@ func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,

// If the combined key has an odd y coordinate, then we'll negate
// parity factor for the signing key.
paritySignKey := new(btcec.ModNScalar).SetInt(1)
parityCombinedKey := new(btcec.ModNScalar).SetInt(1)
combinedKeyBytes := combinedKey.FinalKey.SerializeCompressed()
if combinedKeyBytes[0] == secp.PubKeyFormatCompressedOdd {
paritySignKey.Negate()
parityCombinedKey.Negate()
}

// Next, we'll construct the final parity factor by multiplying the
// sign key parity factor with the accumulated parity factor for all
// the keys.
finalParityFactor := paritySignKey.Mul(parityAcc)
finalParityFactor := parityCombinedKey.Mul(parityAcc)

// Now we'll multiply the parity factor by our signing key, which'll
// take care of the amount of negation needed.
var signKeyJ btcec.JacobianPoint
signingKey.AsJacobian(&signKeyJ)
btcec.ScalarMultNonConst(finalParityFactor, &signKeyJ, &signKeyJ)

// In the final set, we'll check that: s*G == R' + e*a*P.
// In the final set, we'll check that: s*G == R' + e*a*g*P.
var sG, rP btcec.JacobianPoint
btcec.ScalarBaseMultNonConst(s, &sG)
btcec.ScalarMultNonConst(e.Mul(a), &signKeyJ, &rP)
btcec.ScalarMultNonConst(e.Mul(a).Mul(finalParityFactor), &signKeyJ, &rP)
btcec.AddNonConst(&rP, &pubNonceJ, &rP)

sG.ToAffine()
Expand Down

0 comments on commit 1567f20

Please sign in to comment.