diff --git a/README.md b/README.md index 8caa19f..a03f9db 100644 --- a/README.md +++ b/README.md @@ -29,11 +29,8 @@ negotiation in libp2p can be found in the [connection establishment spec][conn-s ## Status -This implementation is being updated to track some recent changes to the [spec][noise-libp2p-spec]: - -- [ ] [use of channel binding token to prevent replay attacks](https://github.com/libp2p/specs/pull/234) - -We recommend waiting until those changes are in place before adopting go-libp2p-noise for production use. +This implementation is currently considered "feature complete," but it has not yet +been widely tested in a production environment. ## Install @@ -45,7 +42,7 @@ As `go-libp2p-noise` is still in development, it is not included as a default de go get github.com/libp2p/go-libp2p-noise ``` -This repo is [gomod](https://github.com/golang/go/wiki/Modules)-compatible, and users of +This repo is [gomod](https://github.com/golang/go/wiki/Modules) compatible, and users of go 1.12 and later with modules enabled will automatically pull the latest tagged release by referencing this package. Upgrades to future releases can be managed using `go get`, or by editing your `go.mod` file as [described by the gomod documentation](https://github.com/golang/go/wiki/Modules#how-to-upgrade-and-downgrade-dependencies). @@ -55,8 +52,27 @@ or by editing your `go.mod` file as [described by the gomod documentation](https `go-libp2p-noise` is not currently enabled by default when constructing a new libp2p [Host][godoc-host], so you will need to explicitly enable it in order to use it. -The API for constructing a new Noise transport is currently in flux. This README will -be updated once it stabilizes. +To do so, you can pass `noise.New` as an argument to a `libp2p.Security` `Option` when +constructing a libp2p `Host` with `libp2p.New`: + +```go +import ( + libp2p "github.com/libp2p/go-libp2p" + noise "github.com/libp2p/go-libp2p-noise" +) + +// wherever you create your libp2p instance: +host := libp2p.New( + libp2p.Security(noise.ID, noise.New) +) +``` + +Note that the above snippet will _replace_ the default security protocols. To add Noise +as an additional protocol, chain it to the default options instead: + +```go +libp2p.ChainOptions(libp2p.DefaultSecurity, libp2p.Security(noise.ID, noise.New)) +``` ## Contribute diff --git a/crypto.go b/crypto.go index 9d53d4e..e837c7a 100644 --- a/crypto.go +++ b/crypto.go @@ -2,56 +2,39 @@ package noise import ( "errors" - ik "github.com/libp2p/go-libp2p-noise/ik" xx "github.com/libp2p/go-libp2p-noise/xx" ) func (s *secureSession) Encrypt(plaintext []byte) (ciphertext []byte, err error) { - if s.xx_complete { - if s.initiator { - cs := s.xx_ns.CS1() - _, ciphertext = xx.EncryptWithAd(cs, nil, plaintext) - } else { - cs := s.xx_ns.CS2() - _, ciphertext = xx.EncryptWithAd(cs, nil, plaintext) - } - } else if s.ik_complete { - if s.initiator { - cs := s.ik_ns.CS1() - _, ciphertext = ik.EncryptWithAd(cs, nil, plaintext) - } else { - cs := s.ik_ns.CS2() - _, ciphertext = ik.EncryptWithAd(cs, nil, plaintext) - } - } else { + if !s.handshakeComplete { return nil, errors.New("encrypt err: haven't completed handshake") } + if s.initiator { + cs := s.ns.CS1() + _, ciphertext = xx.EncryptWithAd(cs, nil, plaintext) + } else { + cs := s.ns.CS2() + _, ciphertext = xx.EncryptWithAd(cs, nil, plaintext) + } + return ciphertext, nil } func (s *secureSession) Decrypt(ciphertext []byte) (plaintext []byte, err error) { var ok bool - if s.xx_complete { - if s.initiator { - cs := s.xx_ns.CS2() - _, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext) - } else { - cs := s.xx_ns.CS1() - _, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext) - } - } else if s.ik_complete { - if s.initiator { - cs := s.ik_ns.CS2() - _, plaintext, ok = ik.DecryptWithAd(cs, nil, ciphertext) - } else { - cs := s.ik_ns.CS1() - _, plaintext, ok = ik.DecryptWithAd(cs, nil, ciphertext) - } - } else { + if !s.handshakeComplete { return nil, errors.New("decrypt err: haven't completed handshake") } + if s.initiator { + cs := s.ns.CS2() + _, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext) + } else { + cs := s.ns.CS1() + _, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext) + } + if !ok { return nil, errors.New("decrypt err: could not decrypt") } diff --git a/crypto_test.go b/crypto_test.go index dff7ab7..efc1a8f 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -2,6 +2,8 @@ package noise import ( "bytes" + "context" + "net" "testing" "github.com/libp2p/go-libp2p-core/crypto" @@ -63,3 +65,41 @@ func TestEncryptAndDecrypt_RespToInit(t *testing.T) { t.Fatal(err) } } + +func TestCryptoFailsIfCiphertextIsAltered(t *testing.T) { + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + + initConn, respConn := connect(t, initTransport, respTransport) + defer initConn.Close() + defer respConn.Close() + + plaintext := []byte("helloworld") + ciphertext, err := respConn.Encrypt(plaintext) + if err != nil { + t.Fatal(err) + } + + ciphertext[0] = ^ciphertext[0] + + _, err = initConn.Decrypt(ciphertext) + if err == nil { + t.Fatal("expected decryption to fail when ciphertext altered") + } +} + +func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) { + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + init, resp := net.Pipe() + _ = resp.Close() + + session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true) + _, err := session.Encrypt([]byte("hi")) + if err == nil { + t.Error("expected encryption error when handshake incomplete") + } + _, err = session.Decrypt([]byte("it's a secret")) + if err == nil { + t.Error("expected decryption error when handshake incomplete") + } +} diff --git a/handshake.go b/handshake.go new file mode 100644 index 0000000..7fcd857 --- /dev/null +++ b/handshake.go @@ -0,0 +1,240 @@ +package noise + +import ( + "context" + "fmt" + "github.com/gogo/protobuf/proto" + "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" + + "github.com/libp2p/go-libp2p-noise/pb" + "github.com/libp2p/go-libp2p-noise/xx" +) + +// payloadSigPrefix is prepended to our Noise static key before signing with +// our libp2p identity key. +const payloadSigPrefix = "noise-libp2p-static-key:" + +func (s *secureSession) setRemotePeerInfo(key []byte) (err error) { + s.remote.libp2pKey, err = crypto.UnmarshalPublicKey(key) + return err +} + +func (s *secureSession) setRemotePeerID(key crypto.PubKey) (err error) { + s.remotePeer, err = peer.IDFromPublicKey(key) + return err +} + +func (s *secureSession) verifyPayload(payload *pb.NoiseHandshakePayload, noiseKey [32]byte) (err error) { + sig := payload.GetIdentitySig() + msg := append([]byte(payloadSigPrefix), noiseKey[:]...) + + ok, err := s.RemotePublicKey().Verify(msg, sig) + if err != nil { + return err + } else if !ok { + return fmt.Errorf("did not verify payload") + } + + return nil +} + +func (s *secureSession) sendHandshakeMessage(payload []byte, initialStage bool) error { + var msgbuf xx.MessageBuffer + s.ns, msgbuf = xx.SendMessage(s.ns, payload) + var encMsgBuf []byte + if initialStage { + encMsgBuf = msgbuf.Encode0() + } else { + encMsgBuf = msgbuf.Encode1() + } + + err := s.writeMsgInsecure(encMsgBuf) + if err != nil { + return fmt.Errorf("sendHandshakeMessage write to conn err=%s", err) + } + + return nil +} + +func (s *secureSession) recvHandshakeMessage(initialStage bool) (buf []byte, plaintext []byte, valid bool, err error) { + buf, err = s.readMsgInsecure() + if err != nil { + return nil, nil, false, fmt.Errorf("recvHandshakeMessage read length err=%s", err) + } + + var msgbuf *xx.MessageBuffer + if initialStage { + msgbuf, err = xx.Decode0(buf) + } else { + msgbuf, err = xx.Decode1(buf) + } + + if err != nil { + return buf, nil, false, fmt.Errorf("recvHandshakeMessage decode msg err=%s", err) + } + + s.ns, plaintext, valid = xx.RecvMessage(s.ns, msgbuf) + if !valid { + return buf, nil, false, fmt.Errorf("recvHandshakeMessage validation fail") + } + + return buf, plaintext, valid, nil +} + +// Runs the XX handshake +// XX: +// -> e +// <- e, ee, s, es +// -> s, se +func (s *secureSession) runHandshake(ctx context.Context) (err error) { + + // setup libp2p keys + localKeyRaw, err := s.LocalPublicKey().Bytes() + if err != nil { + return fmt.Errorf("runHandshake err getting raw pubkey: %s", err) + } + + // sign noise data for payload + noise_pub := s.noiseKeypair.publicKey + signedPayload, err := s.localKey.Sign(append([]byte(payloadSigPrefix), noise_pub[:]...)) + if err != nil { + return fmt.Errorf("runHandshake signing payload err=%s", err) + } + + // create payload + payload := new(pb.NoiseHandshakePayload) + payload.IdentityKey = localKeyRaw + payload.IdentitySig = signedPayload + payloadEnc, err := proto.Marshal(payload) + if err != nil { + return fmt.Errorf("runHandshake proto marshal payload err=%s", err) + } + + kp := xx.NewKeypair(s.noiseKeypair.publicKey, s.noiseKeypair.privateKey) + + // new XX noise session + s.ns = xx.InitSession(s.initiator, s.prologue, kp, [32]byte{}) + + if s.initiator { + // stage 0 // + + err = s.sendHandshakeMessage(nil, true) + if err != nil { + return fmt.Errorf("runHandshake stage 0 initiator fail: %s", err) + } + + // stage 1 // + + var plaintext []byte + var valid bool + // read reply + _, plaintext, valid, err = s.recvHandshakeMessage(false) + if err != nil { + return fmt.Errorf("runHandshake initiator stage 1 fail: %s", err) + } + if !valid { + return fmt.Errorf("runHandshake stage 1 initiator validation fail") + } + + // stage 2 // + + err = s.sendHandshakeMessage(payloadEnc, false) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=true err=%s", err) + } + + // unmarshal payload + nhp := new(pb.NoiseHandshakePayload) + err = proto.Unmarshal(plaintext, nhp) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=true err=cannot unmarshal payload") + } + + // set remote libp2p public key + err = s.setRemotePeerInfo(nhp.GetIdentityKey()) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=true read remote libp2p key fail") + } + + // assert that remote peer ID matches libp2p public key + pid, err := peer.IDFromPublicKey(s.RemotePublicKey()) + if pid != s.remotePeer { + return fmt.Errorf("runHandshake stage=2 initiator=true check remote peer id err: expected %x got %x", s.remotePeer, pid) + } else if err != nil { + return fmt.Errorf("runHandshake stage 2 initiator check remote peer id err %s", err) + } + + // verify payload is signed by libp2p key + err = s.verifyPayload(nhp, s.ns.RemoteKey()) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=true verify payload err=%s", err) + } + + } else { + + // stage 0 // + + var plaintext []byte + var valid bool + nhp := new(pb.NoiseHandshakePayload) + + // read message + _, plaintext, valid, err = s.recvHandshakeMessage(true) + if err != nil { + return fmt.Errorf("runHandshake stage=0 initiator=false err=%s", err) + } + + if !valid { + return fmt.Errorf("runHandshake stage=0 initiator=false err=validation fail") + } + + // stage 1 // + + err = s.sendHandshakeMessage(payloadEnc, false) + if err != nil { + return fmt.Errorf("runHandshake stage=1 initiator=false err=%s", err) + } + + // stage 2 // + + // read message + _, plaintext, valid, err = s.recvHandshakeMessage(false) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=false err=%s", err) + } + + if !valid { + return fmt.Errorf("runHandshake stage=2 initiator=false err=validation fail") + } + + // unmarshal payload + err = proto.Unmarshal(plaintext, nhp) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=false err=cannot unmarshal payload") + } + + // set remote libp2p public key + err = s.setRemotePeerInfo(nhp.GetIdentityKey()) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=false read remote libp2p key fail") + } + + // set remote libp2p public key from payload + err = s.setRemotePeerID(s.RemotePublicKey()) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=false set remote peer id err=%s", err) + } + + s.remote.noiseKey = s.ns.RemoteKey() + + // verify payload is signed by libp2p key + err = s.verifyPayload(nhp, s.remote.noiseKey) + if err != nil { + return fmt.Errorf("runHandshake stage=2 initiator=false err=%s", err) + } + } + + s.handshakeComplete = true + return nil +} diff --git a/ik/IK.noise.go b/ik/IK.noise.go deleted file mode 100644 index 724a98e..0000000 --- a/ik/IK.noise.go +++ /dev/null @@ -1,533 +0,0 @@ -/* -IK: - <- s - ... - -> e, es, s, ss - <- e, ee, se -*/ - -// Implementation Version: 1.0.0 - -/* ---------------------------------------------------------------- * - * PARAMETERS * - * ---------------------------------------------------------------- */ - -package ik - -import ( - "crypto/rand" - "crypto/subtle" - "encoding/binary" - "errors" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" - "hash" - "io" - "math" -) - -/* ---------------------------------------------------------------- * - * TYPES * - * ---------------------------------------------------------------- */ - -type Keypair struct { - public_key [32]byte - private_key [32]byte -} - -type MessageBuffer struct { - ne [32]byte - ns []byte - ciphertext []byte -} - -type cipherstate struct { - k [32]byte - n uint32 -} - -type symmetricstate struct { - cs cipherstate - ck [32]byte - h [32]byte -} - -type handshakestate struct { - ss symmetricstate - s Keypair - e Keypair - rs [32]byte - re [32]byte - psk [32]byte -} - -type NoiseSession struct { - hs handshakestate - h [32]byte - cs1 cipherstate - cs2 cipherstate - mc uint64 - i bool -} - -func (ns *NoiseSession) CS1() *cipherstate { - return &ns.cs1 -} - -func (ns *NoiseSession) CS2() *cipherstate { - return &ns.cs2 -} - -func (ns *NoiseSession) Ephemeral() *Keypair { - return &ns.hs.e -} - -func NewKeypair(pub [32]byte, priv [32]byte) Keypair { - return Keypair{ - public_key: pub, - private_key: priv, - } -} - -func (kp Keypair) PubKey() [32]byte { - return kp.public_key -} - -func (kp Keypair) PrivKey() [32]byte { - return kp.private_key -} - -func (ns *NoiseSession) RemoteKey() [32]byte { - return ns.hs.rs -} - -func NewMessageBuffer(ne [32]byte, ns []byte, ciphertext []byte) MessageBuffer { - return MessageBuffer{ - ne: ne, - ns: ns, - ciphertext: ciphertext, - } -} - -/* ---------------------------------------------------------------- * - * CONSTANTS * - * ---------------------------------------------------------------- */ - -var emptyKey = [32]byte{ - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, -} - -var minNonce = uint32(0) - -/* ---------------------------------------------------------------- * - * UTILITY FUNCTIONS * - * ---------------------------------------------------------------- */ - -func getPublicKey(kp *Keypair) [32]byte { - return kp.public_key -} - -func isEmptyKey(k [32]byte) bool { - return subtle.ConstantTimeCompare(k[:], emptyKey[:]) == 1 -} - -func (mb *MessageBuffer) NE() [32]byte { - return mb.ne -} - -func (mb *MessageBuffer) NS() []byte { - return mb.ns -} - -func (mb *MessageBuffer) Ciphertext() []byte { - return mb.ciphertext -} - -// Encodes a MessageBuffer from stage 0 -func (mb *MessageBuffer) Encode0() []byte { - enc := []byte{} - - enc = append(enc, mb.ne[:]...) - enc = append(enc, mb.ns...) - enc = append(enc, mb.ciphertext...) - - return enc -} - -// Encodes a MessageBuffer from stage 1 -func (mb *MessageBuffer) Encode1() []byte { - enc := []byte{} - - enc = append(enc, mb.ne[:]...) - enc = append(enc, mb.ciphertext...) - - return enc -} - -// Decodes initial message (stage 0) into MessageBuffer -func Decode0(in []byte) (*MessageBuffer, error) { - if len(in) < 80 { - return nil, errors.New("cannot decode stage 0 MessageBuffer: length less than 80 bytes") - } - - mb := new(MessageBuffer) - copy(mb.ne[:], in[:32]) - mb.ns = in[32:80] - mb.ciphertext = in[80:] - - return mb, nil -} - -// Decodes messages at stage 1 into MessageBuffer -func Decode1(in []byte) (*MessageBuffer, error) { - if len(in) < 32 { - return nil, errors.New("cannot decode stage 1 MessageBuffer: length less than 32 bytes") - } - - mb := new(MessageBuffer) - copy(mb.ne[:], in[:32]) - mb.ciphertext = in[32:] - - return mb, nil -} - -func validatePublicKey(k []byte) bool { - forbiddenCurveValues := [12][]byte{ - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - {224, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 0}, - {95, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 87}, - {236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, - {237, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, - {238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, - {205, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 128}, - {76, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 215}, - {217, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, - {218, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, - {219, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 25}, - } - - for _, testValue := range forbiddenCurveValues { - if subtle.ConstantTimeCompare(k[:], testValue[:]) == 1 { - panic("Invalid public key") - } - } - return true -} - -/* ---------------------------------------------------------------- * - * PRIMITIVES * - * ---------------------------------------------------------------- */ - -func incrementNonce(n uint32) uint32 { - return n + 1 -} - -func dh(private_key [32]byte, public_key [32]byte) [32]byte { - var ss [32]byte - curve25519.ScalarMult(&ss, &private_key, &public_key) - return ss -} - -func GenerateKeypair() Keypair { - var public_key [32]byte - var private_key [32]byte - _, _ = rand.Read(private_key[:]) - curve25519.ScalarBaseMult(&public_key, &private_key) - if validatePublicKey(public_key[:]) { - return Keypair{public_key, private_key} - } - return GenerateKeypair() -} - -func GeneratePublicKey(private_key [32]byte) [32]byte { - var public_key [32]byte - curve25519.ScalarBaseMult(&public_key, &private_key) - return public_key -} - -func encrypt(k [32]byte, n uint32, ad []byte, plaintext []byte) []byte { - var nonce [12]byte - var ciphertext []byte - enc, _ := chacha20poly1305.New(k[:]) - binary.LittleEndian.PutUint32(nonce[4:], n) - ciphertext = enc.Seal(nil, nonce[:], plaintext, ad) - return ciphertext -} - -func decrypt(k [32]byte, n uint32, ad []byte, ciphertext []byte) (bool, []byte, []byte) { - var nonce [12]byte - var plaintext []byte - enc, err := chacha20poly1305.New(k[:]) - binary.LittleEndian.PutUint32(nonce[4:], n) - plaintext, err = enc.Open(nil, nonce[:], ciphertext, ad) - return (err == nil), ad, plaintext -} - -func getHash(a []byte, b []byte) [32]byte { - return blake2s.Sum256(append(a, b...)) -} - -func hashProtocolName(protocolName []byte) [32]byte { - var h [32]byte - if len(protocolName) <= 32 { - copy(h[:], protocolName) - } else { - h = getHash(protocolName, []byte{}) - } - return h -} - -func blake2HkdfInterface() hash.Hash { - h, _ := blake2s.New256([]byte{}) - return h -} - -func getHkdf(ck [32]byte, ikm []byte) ([32]byte, [32]byte, [32]byte) { - var k1 [32]byte - var k2 [32]byte - var k3 [32]byte - output := hkdf.New(blake2HkdfInterface, ikm[:], ck[:], []byte{}) - io.ReadFull(output, k1[:]) - io.ReadFull(output, k2[:]) - io.ReadFull(output, k3[:]) - return k1, k2, k3 -} - -/* ---------------------------------------------------------------- * - * STATE MANAGEMENT * - * ---------------------------------------------------------------- */ - -/* CipherState */ -func initializeKey(k [32]byte) cipherstate { - return cipherstate{k, minNonce} -} - -func hasKey(cs *cipherstate) bool { - return !isEmptyKey(cs.k) -} - -func setNonce(cs *cipherstate, newNonce uint32) *cipherstate { - cs.n = newNonce - return cs -} - -func EncryptWithAd(cs *cipherstate, ad []byte, plaintext []byte) (*cipherstate, []byte) { - e := encrypt(cs.k, cs.n, ad, plaintext) - cs = setNonce(cs, incrementNonce(cs.n)) - return cs, e -} - -func DecryptWithAd(cs *cipherstate, ad []byte, ciphertext []byte) (*cipherstate, []byte, bool) { - valid, ad, plaintext := decrypt(cs.k, cs.n, ad, ciphertext) - cs = setNonce(cs, incrementNonce(cs.n)) - return cs, plaintext, valid -} - -func reKey(cs *cipherstate) *cipherstate { - e := encrypt(cs.k, math.MaxUint32, []byte{}, emptyKey[:]) - copy(cs.k[:], e) - return cs -} - -/* SymmetricState */ - -func initializeSymmetric(protocolName []byte) symmetricstate { - h := hashProtocolName(protocolName) - ck := h - cs := initializeKey(emptyKey) - return symmetricstate{cs, ck, h} -} - -func mixKey(ss *symmetricstate, ikm [32]byte) *symmetricstate { - ck, tempK, _ := getHkdf(ss.ck, ikm[:]) - ss.cs = initializeKey(tempK) - ss.ck = ck - return ss -} - -func mixHash(ss *symmetricstate, data []byte) *symmetricstate { - ss.h = getHash(ss.h[:], data) - return ss -} - -func mixKeyAndHash(ss *symmetricstate, ikm [32]byte) *symmetricstate { - var tempH [32]byte - var tempK [32]byte - ss.ck, tempH, tempK = getHkdf(ss.ck, ikm[:]) - ss = mixHash(ss, tempH[:]) - ss.cs = initializeKey(tempK) - return ss -} - -func getHandshakeHash(ss *symmetricstate) [32]byte { - return ss.h -} - -func encryptAndHash(ss *symmetricstate, plaintext []byte) (*symmetricstate, []byte) { - var ciphertext []byte - if hasKey(&ss.cs) { - _, ciphertext = EncryptWithAd(&ss.cs, ss.h[:], plaintext) - } else { - ciphertext = plaintext - } - ss = mixHash(ss, ciphertext) - return ss, ciphertext -} - -func decryptAndHash(ss *symmetricstate, ciphertext []byte) (*symmetricstate, []byte, bool) { - var plaintext []byte - var valid bool - if hasKey(&ss.cs) { - _, plaintext, valid = DecryptWithAd(&ss.cs, ss.h[:], ciphertext) - } else { - plaintext, valid = ciphertext, true - } - ss = mixHash(ss, ciphertext) - return ss, plaintext, valid -} - -func split(ss *symmetricstate) (cipherstate, cipherstate) { - tempK1, tempK2, _ := getHkdf(ss.ck, []byte{}) - cs1 := initializeKey(tempK1) - cs2 := initializeKey(tempK2) - return cs1, cs2 -} - -/* HandshakeState */ - -func initializeInitiator(prologue []byte, s Keypair, rs [32]byte, psk [32]byte) handshakestate { - var ss symmetricstate - var e Keypair - var re [32]byte - name := []byte("Noise_IK_25519_ChaChaPoly_SHA256") - ss = initializeSymmetric(name) - mixHash(&ss, prologue) - mixHash(&ss, rs[:]) - return handshakestate{ss, s, e, rs, re, psk} -} - -func initializeResponder(prologue []byte, s Keypair, rs [32]byte, psk [32]byte) handshakestate { - var ss symmetricstate - var e Keypair - var re [32]byte - name := []byte("Noise_IK_25519_ChaChaPoly_SHA256") - ss = initializeSymmetric(name) - mixHash(&ss, prologue) - mixHash(&ss, s.public_key[:]) - return handshakestate{ss, s, e, rs, re, psk} -} - -func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, MessageBuffer) { - ne, ns, ciphertext := emptyKey, []byte{}, []byte{} - hs.e = GenerateKeypair() - ne = hs.e.public_key - mixHash(&hs.ss, ne[:]) - /* No PSK, so skipping mixKey */ - mixKey(&hs.ss, dh(hs.e.private_key, hs.rs)) - spk := make([]byte, len(hs.s.public_key)) - copy(spk[:], hs.s.public_key[:]) - _, ns = encryptAndHash(&hs.ss, spk) - mixKey(&hs.ss, dh(hs.s.private_key, hs.rs)) - _, ciphertext = encryptAndHash(&hs.ss, payload) - messageBuffer := MessageBuffer{ne, ns, ciphertext} - return hs, messageBuffer -} - -func writeMessageB(hs *handshakestate, payload []byte) ([32]byte, MessageBuffer, cipherstate, cipherstate) { - ne, ns, ciphertext := emptyKey, []byte{}, []byte{} - hs.e = GenerateKeypair() - ne = hs.e.public_key - mixHash(&hs.ss, ne[:]) - /* No PSK, so skipping mixKey */ - mixKey(&hs.ss, dh(hs.e.private_key, hs.re)) - mixKey(&hs.ss, dh(hs.e.private_key, hs.rs)) - _, ciphertext = encryptAndHash(&hs.ss, payload) - messageBuffer := MessageBuffer{ne, ns, ciphertext} - cs1, cs2 := split(&hs.ss) - return hs.ss.h, messageBuffer, cs1, cs2 -} - -func readMessageA(hs *handshakestate, message *MessageBuffer) (*handshakestate, []byte, bool) { - valid1 := true - if validatePublicKey(message.ne[:]) { - hs.re = message.ne - } - mixHash(&hs.ss, hs.re[:]) - /* No PSK, so skipping mixKey */ - mixKey(&hs.ss, dh(hs.s.private_key, hs.re)) - _, ns, valid1 := decryptAndHash(&hs.ss, message.ns) - if valid1 && len(ns) == 32 && validatePublicKey(message.ns[:]) { - copy(hs.rs[:], ns) - } - mixKey(&hs.ss, dh(hs.s.private_key, hs.rs)) - _, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext) - return hs, plaintext, (valid1 && valid2) -} - -func readMessageB(hs *handshakestate, message *MessageBuffer) ([32]byte, []byte, bool, cipherstate, cipherstate) { - valid1 := true - if validatePublicKey(message.ne[:]) { - hs.re = message.ne - } - mixHash(&hs.ss, hs.re[:]) - /* No PSK, so skipping mixKey */ - mixKey(&hs.ss, dh(hs.e.private_key, hs.re)) - mixKey(&hs.ss, dh(hs.s.private_key, hs.re)) - _, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext) - cs1, cs2 := split(&hs.ss) - return hs.ss.h, plaintext, (valid1 && valid2), cs1, cs2 -} - -/* ---------------------------------------------------------------- * - * PROCESSES * - * ---------------------------------------------------------------- */ - -func InitSession(initiator bool, prologue []byte, s Keypair, rs [32]byte) *NoiseSession { - var session NoiseSession - psk := emptyKey - if initiator { - session.hs = initializeInitiator(prologue, s, rs, psk) - } else { - session.hs = initializeResponder(prologue, s, rs, psk) - } - session.i = initiator - session.mc = 0 - return &session -} - -func SendMessage(session *NoiseSession, message []byte) (*NoiseSession, MessageBuffer) { - var messageBuffer MessageBuffer - if session.mc == 0 { - _, messageBuffer = writeMessageA(&session.hs, message) - } - if session.mc == 1 { - session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message) - } - session.mc = session.mc + 1 - return session, messageBuffer -} - -func RecvMessage(session *NoiseSession, message *MessageBuffer) (*NoiseSession, []byte, bool) { - var plaintext []byte - var valid bool - if session.mc == 0 { - _, plaintext, valid = readMessageA(&session.hs, message) - } - if session.mc == 1 { - session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message) - } - session.mc = session.mc + 1 - return session, plaintext, valid -} - -func main() {} diff --git a/ik/IK_test.go b/ik/IK_test.go deleted file mode 100644 index f657769..0000000 --- a/ik/IK_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package ik - -import ( - "crypto/rand" - proto "github.com/gogo/protobuf/proto" - "github.com/libp2p/go-libp2p-core/crypto" - pb "github.com/libp2p/go-libp2p-noise/pb" - "testing" -) - -func TestHandshake(t *testing.T) { - // generate local static noise key - kp_init := GenerateKeypair() - kp_resp := GenerateKeypair() - - payload_string := []byte("noise-libp2p-static-key:") - prologue := []byte("/noise/0.0.0") - - // initiator setup - init_pub := kp_init.PubKey() - libp2p_priv_init, libp2p_pub_init, err := crypto.GenerateEd25519Key(rand.Reader) - if err != nil { - t.Fatal(err) - } - libp2p_pub_init_raw, err := libp2p_pub_init.Raw() - if err != nil { - t.Fatal(err) - } - libp2p_init_signed_payload, err := libp2p_priv_init.Sign(append(payload_string, init_pub[:]...)) - if err != nil { - t.Fatal(err) - } - - // respoonder setup - resp_pub := kp_resp.PubKey() - libp2p_priv_resp, libp2p_pub_resp, err := crypto.GenerateEd25519Key(rand.Reader) - if err != nil { - t.Fatal(err) - } - libp2p_pub_resp_raw, err := libp2p_pub_resp.Raw() - if err != nil { - t.Fatal(err) - } - libp2p_resp_signed_payload, err := libp2p_priv_resp.Sign(append(payload_string, resp_pub[:]...)) - if err != nil { - t.Fatal(err) - } - - // initiator: new IK noise session - ns_init := InitSession(true, prologue, kp_init, kp_resp.PubKey()) - - // responder: new IK noise session - ns_resp := InitSession(false, prologue, kp_resp, [32]byte{}) - - // stage 0: initiator - // create payload - payload_init := new(pb.NoiseHandshakePayload) - payload_init.IdentityKey = libp2p_pub_init_raw - payload_init.IdentitySig = libp2p_init_signed_payload - payload_init_enc, err := proto.Marshal(payload_init) - if err != nil { - t.Fatalf("proto marshal payload fail: %s", err) - } - - // send message - var msgbuf MessageBuffer - msg := []byte{} - msg = append(msg, payload_init_enc[:]...) - ns_init, msgbuf = SendMessage(ns_init, msg) - - t.Logf("stage 0 msgbuf: %v", msgbuf) - t.Logf("stage 0 msgbuf ne len: %d", len(msgbuf.NE())) - - // stage 0: responder - var plaintext []byte - var valid bool - ns_resp, plaintext, valid = RecvMessage(ns_resp, &msgbuf) - if !valid { - t.Fatalf("stage 0 receive not valid") - } - - t.Logf("stage 0 resp payload: %x", plaintext) - - // stage 1: responder - // create payload - payload_resp := new(pb.NoiseHandshakePayload) - payload_resp.IdentityKey = libp2p_pub_resp_raw - payload_resp.IdentitySig = libp2p_resp_signed_payload - payload_resp_enc, err := proto.Marshal(payload_resp) - if err != nil { - t.Fatalf("proto marshal payload fail: %s", err) - } - msg = append(msg, payload_resp_enc[:]...) - ns_resp, msgbuf = SendMessage(ns_resp, msg) - - t.Logf("stage 1 msgbuf: %v", msgbuf) - t.Logf("stage 1 msgbuf ne len: %d", len(msgbuf.NE())) - t.Logf("stage 1 msgbuf ns len: %d", len(msgbuf.NS())) - - // stage 1: initiator - ns_init, plaintext, valid = RecvMessage(ns_init, &msgbuf) - if !valid { - t.Fatalf("stage 1 receive not valid") - } - - t.Logf("stage 1 resp payload: %x", plaintext) - -} diff --git a/ik_handshake.go b/ik_handshake.go deleted file mode 100644 index 2ee6797..0000000 --- a/ik_handshake.go +++ /dev/null @@ -1,182 +0,0 @@ -package noise - -import ( - "context" - "fmt" - "github.com/libp2p/go-libp2p-core/peer" - "io" - - proto "github.com/gogo/protobuf/proto" - ik "github.com/libp2p/go-libp2p-noise/ik" - pb "github.com/libp2p/go-libp2p-noise/pb" -) - -func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bool) error { - var msgbuf ik.MessageBuffer - s.ik_ns, msgbuf = ik.SendMessage(s.ik_ns, payload) - var encMsgBuf []byte - if initial_stage { - encMsgBuf = msgbuf.Encode0() - } else { - encMsgBuf = msgbuf.Encode1() - } - - err := s.writeLength(len(encMsgBuf)) - if err != nil { - return fmt.Errorf("ik_sendHandshakeMessage write length err=%s", err) - } - - // send message - _, err = s.insecure.Write(encMsgBuf) - if err != nil { - return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err) - } - - return nil -} - -func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte, plaintext []byte, valid bool, err error) { - l, err := s.readLength() - if err != nil { - return nil, nil, false, fmt.Errorf("ik_recvHandshakeMessage read length err=%s", err) - } - - buf = make([]byte, l) - - _, err = io.ReadFull(s.insecure, buf) - if err != nil { - return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err) - } - - var msgbuf *ik.MessageBuffer - if initial_stage { - msgbuf, err = ik.Decode0(buf) - } else { - msgbuf, err = ik.Decode1(buf) - } - - if err != nil { - return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage decode msg fail: %s", err) - } - - s.ik_ns, plaintext, valid = ik.RecvMessage(s.ik_ns, msgbuf) - if !valid { - return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage validation fail") - } - - return buf, plaintext, valid, nil -} - -// IK: -// <- s -// ... -// -> e, es, s, ss -// <- e, ee, se -// returns last successful message upon error -func (s *secureSession) runHandshake_ik(ctx context.Context, payload []byte) ([]byte, error) { - kp := ik.NewKeypair(s.noiseKeypair.publicKey, s.noiseKeypair.privateKey) - - remoteNoiseKey := s.noiseStaticKeyCache.Load(s.remotePeer) - - // new IK noise session - s.ik_ns = ik.InitSession(s.initiator, s.prologue, kp, remoteNoiseKey) - - if s.initiator { - // bail out early if we don't know the remote Noise key - if remoteNoiseKey == [32]byte{} { - return nil, fmt.Errorf("runHandshake_ik aborting - unknown static key for peer %s", s.remotePeer.Pretty()) - } - - // stage 0 // - err := s.ik_sendHandshakeMessage(payload, true) - if err != nil { - return nil, fmt.Errorf("runHandshake_ik stage=0 initiator=true err=%s", err) - } - - // stage 1 // - - // read message - buf, plaintext, valid, err := s.ik_recvHandshakeMessage(false) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=1 initiator=true err=%s", err) - } - - if !valid { - return buf, fmt.Errorf("runHandshake_ik stage=1 initiator=true err=validation fail") - } - - // unmarshal payload - nhp := new(pb.NoiseHandshakePayload) - err = proto.Unmarshal(plaintext, nhp) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=1 initiator=true err=validation fail: cannot unmarshal payload") - } - - // set remote libp2p public key - err = s.setRemotePeerInfo(nhp.GetIdentityKey()) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=1 initiator=true err=read remote libp2p key fail") - } - - // assert that remote peer ID matches libp2p public key - pid, err := peer.IDFromPublicKey(s.RemotePublicKey()) - if pid != s.remotePeer { - return buf, fmt.Errorf("runHandshake_ik stage=1 initiator=true check remote peer id err: expected %x got %x", s.remotePeer, pid) - } else if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=1 initiator=true check remote peer id err %s", err) - } - - // verify payload is signed by libp2p key - err = s.verifyPayload(nhp, remoteNoiseKey) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=1 initiator=true verify payload err=%s", err) - } - - } else { - // stage 0 // - - // read message - buf, plaintext, valid, err := s.ik_recvHandshakeMessage(true) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=0 initiator=false err=%s", err) - } - - if !valid { - return buf, fmt.Errorf("runHandshake_ik stage=0 initiator=false err: validation fail") - } - - // unmarshal payload - nhp := new(pb.NoiseHandshakePayload) - err = proto.Unmarshal(plaintext, nhp) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=0 initiator=false err=validation fail: cannot unmarshal payload") - } - - // set remote libp2p public key - err = s.setRemotePeerInfo(nhp.GetIdentityKey()) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=0 initiator=false err=read remote libp2p key fail") - } - - // assert that remote peer ID matches libp2p key - err = s.setRemotePeerID(s.RemotePublicKey()) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=0 initiator=false set remote peer id err=%s:", err) - } - - // verify payload is signed by libp2p key - err = s.verifyPayload(nhp, s.ik_ns.RemoteKey()) - if err != nil { - return buf, fmt.Errorf("runHandshake_ik stage=0 initiator=false verify payload err=%s", err) - } - - // stage 1 // - - err = s.ik_sendHandshakeMessage(payload, false) - if err != nil { - return nil, fmt.Errorf("runHandshake_ik stage=1 initiator=false send err=%s", err) - } - - } - return nil, nil -} diff --git a/integration_test.go b/integration_test.go index 0cb2a93..bd793b3 100644 --- a/integration_test.go +++ b/integration_test.go @@ -35,13 +35,13 @@ func generateKey(seed int64) (crypto.PrivKey, error) { return priv, nil } -func makeNode(t *testing.T, seed int64, port int, kp *Keypair) (host.Host, error) { +func makeNode(t *testing.T, seed int64, port int) (host.Host, error) { priv, err := generateKey(seed) if err != nil { t.Fatal(err) } - tpt, err := New(priv, NoiseKeyPair(kp)) + tpt, err := New(priv) if err != nil { t.Fatal(err) } @@ -62,49 +62,17 @@ func makeNode(t *testing.T, seed int64, port int, kp *Keypair) (host.Host, error return libp2p.New(ctx, options...) } -func makeNodePipes(t *testing.T, seed int64, port int, rpid peer.ID, rpubkey [32]byte, kp *Keypair) (host.Host, error) { - priv, err := generateKey(seed) - if err != nil { - t.Fatal(err) - } - - tpt, err := New(priv, UseNoisePipes, NoiseKeyPair(kp)) - if err != nil { - t.Fatal(err) - } - - tpt.noiseStaticKeyCache = NewKeyCache() - tpt.noiseStaticKeyCache.Store(rpid, rpubkey) - - ip := "0.0.0.0" - addr, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/%s/tcp/%d", ip, port)) - if err != nil { - t.Fatal(err) - } - - options := []libp2p.Option{ - libp2p.Identity(priv), - libp2p.Security(ID, tpt), - libp2p.ListenAddrs(addr), - } - - ctx := context.Background() - - h, err := libp2p.New(ctx, options...) - return h, err -} - -func TestLibp2pIntegration_NoPipes(t *testing.T) { +func TestLibp2pIntegration(t *testing.T) { ctx := context.Background() - ha, err := makeNode(t, 1, 33333, nil) + ha, err := makeNode(t, 1, 33333) if err != nil { t.Fatal(err) } defer ha.Close() - hb, err := makeNode(t, 2, 34343, nil) + hb, err := makeNode(t, 2, 34343) if err != nil { t.Fatal(err) } @@ -146,137 +114,6 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) { time.Sleep(time.Second) } -func TestLibp2pIntegration_WithPipes(t *testing.T) { - ctx := context.Background() - - kpa, err := GenerateKeypair() - if err != nil { - t.Fatal(err) - } - - ha, err := makeNodePipes(t, 1, 33333, "", [32]byte{}, kpa) - if err != nil { - t.Fatal(err) - } - - defer ha.Close() - - hb, err := makeNodePipes(t, 2, 34343, ha.ID(), kpa.publicKey, nil) - if err != nil { - t.Fatal(err) - } - - defer hb.Close() - - ha.SetStreamHandler(testProtocolID, streamHandler(t)) - hb.SetStreamHandler(testProtocolID, streamHandler(t)) - - addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID())) - if err != nil { - t.Fatal(err) - } - - fmt.Printf("ha: %s\n", addr) - - addrInfo, err := peer.AddrInfoFromP2pAddr(addr) - if err != nil { - t.Fatal(err) - } - - err = hb.Connect(ctx, *addrInfo) - if err != nil { - t.Fatal(err) - } - - stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID) - if err != nil { - t.Fatal(err) - } - - err = writeRandomPayloadAndClose(t, stream) - if err != nil { - t.Fatal(err) - } - - fmt.Println("fin") - - time.Sleep(time.Second) -} - -func TestLibp2pIntegration_XXFallback(t *testing.T) { - ctx := context.Background() - - kpa, err := GenerateKeypair() - if err != nil { - t.Fatal(err) - } - - ha, err := makeNode(t, 1, 33333, kpa) - if err != nil { - t.Fatal(err) - } - - defer ha.Close() - - hb, err := makeNodePipes(t, 2, 34343, ha.ID(), kpa.publicKey, nil) - if err != nil { - t.Fatal(err) - } - - defer hb.Close() - - ha.SetStreamHandler(testProtocolID, streamHandler(t)) - hb.SetStreamHandler(testProtocolID, streamHandler(t)) - - addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID())) - if err != nil { - t.Fatal(err) - } - - fmt.Printf("ha: %s\n", addr) - - addrInfo, err := peer.AddrInfoFromP2pAddr(addr) - if err != nil { - t.Fatal(err) - } - - err = ha.Connect(ctx, *addrInfo) - if err != nil { - t.Fatal(err) - } - - stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID) - if err != nil { - t.Fatal(err) - } - - err = writeRandomPayloadAndClose(t, stream) - if err != nil { - t.Fatal(err) - } - - fmt.Println("fin") - - time.Sleep(time.Second) -} - -func TestConstrucingWithMaker(t *testing.T) { - kp, err := GenerateKeypair() - if err != nil { - t.Fatal(err) - } - - ctx := context.Background() - h, err := libp2p.New(ctx, - libp2p.Security(ID, - Maker(NoiseKeyPair(kp), UseNoisePipes))) - - if err != nil { - t.Fatalf("unable to create libp2p host with Maker: %v", err) - } - _ = h.Close() -} - func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error { t.Helper() size := 1 << 24 diff --git a/keycache.go b/keycache.go deleted file mode 100644 index f43ce78..0000000 --- a/keycache.go +++ /dev/null @@ -1,30 +0,0 @@ -package noise - -import ( - "github.com/libp2p/go-libp2p-core/peer" - "sync" -) - -type KeyCache struct { - lock sync.RWMutex - m map[peer.ID][32]byte -} - -func NewKeyCache() *KeyCache { - return &KeyCache{ - lock: sync.RWMutex{}, - m: make(map[peer.ID][32]byte), - } -} - -func (kc *KeyCache) Store(p peer.ID, key [32]byte) { - kc.lock.Lock() - kc.m[p] = key - kc.lock.Unlock() -} - -func (kc *KeyCache) Load(p peer.ID) [32]byte { - kc.lock.RLock() - defer kc.lock.RUnlock() - return kc.m[p] -} diff --git a/options.go b/options.go deleted file mode 100644 index 8ea2051..0000000 --- a/options.go +++ /dev/null @@ -1,44 +0,0 @@ -package noise - -// UseNoisePipes configures the Noise transport to use the Noise Pipes pattern. -// Noise Pipes attempts to use the more efficient IK handshake pattern when -// dialing a remote peer, if that peer's static Noise key is known. If this -// is unsuccessful, the transport will fallback to using the default XX pattern. -// -// Note that the fallback does not add any additional round-trips vs. simply -// using XX in the first place, however there is a slight processing overhead -// due to the initial decryption attempt of the IK message. -func UseNoisePipes(cfg *config) { - cfg.noisePipesSupport = true -} - -// NoiseKeyPair configures the Noise transport to use the given Noise static -// keypair. This is distinct from the libp2p Host's identity keypair and is -// used only for Noise. If this option is not provided, a new Noise static -// keypair will be generated when the transport is initialized. -// -// This option is most useful when Noise Pipes is enabled, as longer static -// key lifetimes may lead to more successful IK handshake attempts. -// -// If you do use this option with a key that's been saved to disk, you must -// take care to store the key securely! -func NoiseKeyPair(kp *Keypair) Option { - return func(cfg *config) { - cfg.noiseKeypair = kp - } -} - -type config struct { - noiseKeypair *Keypair - noisePipesSupport bool -} - -type Option func(cfg *config) - -func (cfg *config) applyOptions(opts ...Option) { - for _, opt := range opts { - if opt != nil { - opt(cfg) - } - } -} diff --git a/protocol.go b/protocol.go deleted file mode 100644 index 4f5b778..0000000 --- a/protocol.go +++ /dev/null @@ -1,351 +0,0 @@ -package noise - -import ( - "context" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "sync" - "time" - - "github.com/gogo/protobuf/proto" - logging "github.com/ipfs/go-log" - "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" - - "github.com/libp2p/go-libp2p-noise/ik" - "github.com/libp2p/go-libp2p-noise/pb" - "github.com/libp2p/go-libp2p-noise/xx" -) - -const payload_string = "noise-libp2p-static-key:" - -// Each encrypted transport message must be <= 65,535 bytes, including 16 -// bytes of authentication data. To write larger plaintexts, we split them -// into fragments of maxPlaintextLength before encrypting. -const maxPlaintextLength = 65519 - -var log = logging.Logger("noise") - -var errNoKeypair = errors.New("cannot initiate secureSession - transport has no noise keypair") - -type secureSession struct { - insecure net.Conn - - initiator bool - prologue []byte - - localKey crypto.PrivKey - localPeer peer.ID - remotePeer peer.ID - - local peerInfo - remote peerInfo - - xx_ns *xx.NoiseSession - ik_ns *ik.NoiseSession - - xx_complete bool - ik_complete bool - - noisePipesSupport bool - noiseStaticKeyCache *KeyCache - - noiseKeypair *Keypair - - msgBuffer []byte - readLock sync.Mutex - writeLock sync.Mutex -} - -type peerInfo struct { - noiseKey [32]byte // static noise public key - libp2pKey crypto.PubKey -} - -// newSecureSession creates a noise session over the given insecure Conn, using the static -// Noise keypair and libp2p identity keypair from the given Transport. -// -// If tpt.noisePipesSupport == true, the Noise Pipes handshake protocol will be used, -// which consists of the IK and XXfallback handshake patterns. With Noise Pipes on, we first try IK, -// if that fails, move to XXfallback. With Noise Pipes off, we always do XX. -func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) { - if tpt.noiseKeypair == nil { - return nil, errNoKeypair - } - - // if the transport doesn't have a key cache, we make a new one just for - // this session. it's a bit of a waste, but saves us having to check if - // it's nil later - keyCache := tpt.noiseStaticKeyCache - if keyCache == nil { - keyCache = NewKeyCache() - } - - localPeerInfo := peerInfo{ - noiseKey: tpt.noiseKeypair.publicKey, - libp2pKey: tpt.privateKey.GetPublic(), - } - - s := &secureSession{ - insecure: insecure, - initiator: initiator, - prologue: []byte{}, - localKey: tpt.privateKey, - localPeer: tpt.localID, - remotePeer: remote, - local: localPeerInfo, - noisePipesSupport: tpt.noisePipesSupport, - noiseStaticKeyCache: keyCache, - msgBuffer: []byte{}, - noiseKeypair: tpt.noiseKeypair, - } - - err := s.runHandshake(ctx) - - return s, err -} - -func (s *secureSession) NoiseStaticKeyCache() *KeyCache { - return s.noiseStaticKeyCache -} - -func (s *secureSession) NoisePublicKey() [32]byte { - return s.noiseKeypair.publicKey -} - -func (s *secureSession) NoisePrivateKey() [32]byte { - return s.noiseKeypair.privateKey -} - -func (s *secureSession) readLength() (int, error) { - buf := make([]byte, 2) - _, err := io.ReadFull(s.insecure, buf) - return int(binary.BigEndian.Uint16(buf)), err -} - -func (s *secureSession) writeLength(length int) error { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, uint16(length)) - _, err := s.insecure.Write(buf) - return err -} - -func (s *secureSession) setRemotePeerInfo(key []byte) (err error) { - s.remote.libp2pKey, err = crypto.UnmarshalPublicKey(key) - return err -} - -func (s *secureSession) setRemotePeerID(key crypto.PubKey) (err error) { - s.remotePeer, err = peer.IDFromPublicKey(key) - return err -} - -func (s *secureSession) verifyPayload(payload *pb.NoiseHandshakePayload, noiseKey [32]byte) (err error) { - sig := payload.GetIdentitySig() - msg := append([]byte(payload_string), noiseKey[:]...) - - ok, err := s.RemotePublicKey().Verify(msg, sig) - if err != nil { - return err - } else if !ok { - return fmt.Errorf("did not verify payload") - } - - return nil -} - -func (s *secureSession) runHandshake(ctx context.Context) error { - // setup libp2p keys - localKeyRaw, err := s.LocalPublicKey().Bytes() - if err != nil { - return fmt.Errorf("runHandshake err getting raw pubkey: %s", err) - } - - // sign noise data for payload - noise_pub := s.noiseKeypair.publicKey - signedPayload, err := s.localKey.Sign(append([]byte(payload_string), noise_pub[:]...)) - if err != nil { - return fmt.Errorf("runHandshake signing payload err=%s", err) - } - - // create payload - payload := new(pb.NoiseHandshakePayload) - payload.IdentityKey = localKeyRaw - payload.IdentitySig = signedPayload - payloadEnc, err := proto.Marshal(payload) - if err != nil { - return fmt.Errorf("runHandshake proto marshal payload err=%s", err) - } - - // If we support Noise pipes, we try IK first, falling back to XX if IK fails. - // The exception is when we're the initiator and don't know the other party's - // static Noise key. Then IK will always fail, so we go straight to XX. - tryIK := s.noisePipesSupport - if s.initiator && s.noiseStaticKeyCache.Load(s.remotePeer) == [32]byte{} { - tryIK = false - } - if tryIK { - // we're either a responder or an initiator with a known static key for the remote peer, try IK - buf, err := s.runHandshake_ik(ctx, payloadEnc) - if err != nil { - // IK failed, pipe to XXfallback - err = s.runHandshake_xx(ctx, true, payloadEnc, buf) - if err != nil { - return fmt.Errorf("runHandshake xx err=%s", err) - } - - s.xx_complete = true - } else { - s.ik_complete = true - } - } else { - // unknown static key for peer, try XX - err := s.runHandshake_xx(ctx, false, payloadEnc, nil) - if err != nil { - return err - } - - s.xx_complete = true - } - - return nil -} - -func (s *secureSession) LocalAddr() net.Addr { - return s.insecure.LocalAddr() -} - -func (s *secureSession) LocalPeer() peer.ID { - return s.localPeer -} - -func (s *secureSession) LocalPrivateKey() crypto.PrivKey { - return s.localKey -} - -func (s *secureSession) LocalPublicKey() crypto.PubKey { - return s.localKey.GetPublic() -} - -func (s *secureSession) Read(buf []byte) (int, error) { - s.readLock.Lock() - defer s.readLock.Unlock() - - l := len(buf) - - // if we have previously unread bytes, and they fit into the buf, copy them over and return - if l <= len(s.msgBuffer) { - copy(buf, s.msgBuffer) - s.msgBuffer = s.msgBuffer[l:] - return l, nil - } - - readChunk := func(buf []byte) (int, error) { - // read length of encrypted message - l, err := s.readLength() - if err != nil { - return 0, err - } - - // read and decrypt ciphertext - ciphertext := make([]byte, l) - _, err = io.ReadFull(s.insecure, ciphertext) - if err != nil { - return 0, err - } - - plaintext, err := s.Decrypt(ciphertext) - if err != nil { - return 0, err - } - - // append plaintext to message buffer, copy over what can fit in the buf - // then advance message buffer to remove what was copied - s.msgBuffer = append(s.msgBuffer, plaintext...) - c := copy(buf, s.msgBuffer) - s.msgBuffer = s.msgBuffer[c:] - return c, nil - } - - total := 0 - for i := 0; i < len(buf); i += maxPlaintextLength { - end := i + maxPlaintextLength - if end > len(buf) { - end = len(buf) - } - - c, err := readChunk(buf[i:end]) - total += c - if err != nil { - return total, err - } - } - - return total, nil -} - -func (s *secureSession) RemoteAddr() net.Addr { - return s.insecure.RemoteAddr() -} - -func (s *secureSession) RemotePeer() peer.ID { - return s.remotePeer -} - -func (s *secureSession) RemotePublicKey() crypto.PubKey { - return s.remote.libp2pKey -} - -func (s *secureSession) SetDeadline(t time.Time) error { - return s.insecure.SetDeadline(t) -} - -func (s *secureSession) SetReadDeadline(t time.Time) error { - return s.insecure.SetReadDeadline(t) -} - -func (s *secureSession) SetWriteDeadline(t time.Time) error { - return s.insecure.SetWriteDeadline(t) -} - -func (s *secureSession) Write(in []byte) (int, error) { - s.writeLock.Lock() - defer s.writeLock.Unlock() - - writeChunk := func(in []byte) (int, error) { - ciphertext, err := s.Encrypt(in) - if err != nil { - return 0, err - } - - err = s.writeLength(len(ciphertext)) - if err != nil { - return 0, err - } - - _, err = s.insecure.Write(ciphertext) - return len(in), err - } - - written := 0 - for i := 0; i < len(in); i += maxPlaintextLength { - end := i + maxPlaintextLength - if end > len(in) { - end = len(in) - } - - l, err := writeChunk(in[i:end]) - written += l - if err != nil { - return written, err - } - } - return written, nil -} - -func (s *secureSession) Close() error { - return s.insecure.Close() -} diff --git a/rw.go b/rw.go new file mode 100644 index 0000000..fa39cff --- /dev/null +++ b/rw.go @@ -0,0 +1,127 @@ +package noise + +import ( + "encoding/binary" + "fmt" + "io" +) + +// Each encrypted transport message must be <= 65,535 bytes, including 16 +// bytes of authentication data. To write larger plaintexts, we split them +// into fragments of maxPlaintextLength before encrypting. +const maxPlaintextLength = 65519 + +// Read reads from the secure connection, filling `buf` with plaintext data. +// May read less than len(buf) if data is available immediately. +func (s *secureSession) Read(buf []byte) (int, error) { + s.readLock.Lock() + defer s.readLock.Unlock() + + l := len(buf) + + // if we have previously unread bytes, and they fit into the buf, copy them over and return + if l <= len(s.msgBuffer) { + copy(buf, s.msgBuffer) + s.msgBuffer = s.msgBuffer[l:] + return l, nil + } + + readChunk := func(buf []byte) (int, error) { + // read length of encrypted message + ciphertext, err := s.readMsgInsecure() + if err != nil { + return 0, err + } + + plaintext, err := s.Decrypt(ciphertext) + if err != nil { + return 0, err + } + + // append plaintext to message buffer, copy over what can fit in the buf + // then advance message buffer to remove what was copied + s.msgBuffer = append(s.msgBuffer, plaintext...) + c := copy(buf, s.msgBuffer) + s.msgBuffer = s.msgBuffer[c:] + return c, nil + } + + total := 0 + for i := 0; i < len(buf); i += maxPlaintextLength { + end := i + maxPlaintextLength + if end > len(buf) { + end = len(buf) + } + + c, err := readChunk(buf[i:end]) + total += c + if err != nil { + return total, err + } + } + + return total, nil +} + +// Write encrypts the plaintext `in` data and sends it on the +// secure connection. +func (s *secureSession) Write(in []byte) (int, error) { + s.writeLock.Lock() + defer s.writeLock.Unlock() + + writeChunk := func(in []byte) (int, error) { + ciphertext, err := s.Encrypt(in) + if err != nil { + return 0, err + } + + err = s.writeMsgInsecure(ciphertext) + if err != nil { + return 0, err + } + return len(in), err + } + + written := 0 + for i := 0; i < len(in); i += maxPlaintextLength { + end := i + maxPlaintextLength + if end > len(in) { + end = len(in) + } + + l, err := writeChunk(in[i:end]) + written += l + if err != nil { + return written, err + } + } + return written, nil +} + +// readMsgInsecure reads a message from the insecure channel. +// it first reads the message length, then consumes that many bytes +// from the insecure conn. +func (s *secureSession) readMsgInsecure() ([]byte, error) { + buf := make([]byte, 2) + _, err := io.ReadFull(s.insecure, buf) + if err != nil { + return nil, fmt.Errorf("error reading length prefix: %s", err) + } + size := int(binary.BigEndian.Uint16(buf)) + buf = make([]byte, size) + _, err = io.ReadFull(s.insecure, buf) + return buf, err +} + +// writeMsgInsecure writes to the insecure conn. +// data will be prefixed with its length in bytes, written as a 16-bit uint in network order. +func (s *secureSession) writeMsgInsecure(data []byte) error { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(len(data))) + _, err := s.insecure.Write(buf) + if err != nil { + return fmt.Errorf("error writing length prefix: %s", err) + } + _, err = s.insecure.Write(data) + return err +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..0afb27f --- /dev/null +++ b/session.go @@ -0,0 +1,127 @@ +package noise + +import ( + "context" + "errors" + "net" + "sync" + "time" + + logging "github.com/ipfs/go-log" + "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" + + "github.com/libp2p/go-libp2p-noise/xx" +) + +var log = logging.Logger("noise") + +var errNoKeypair = errors.New("cannot initiate secureSession - transport has no noise keypair") + +type secureSession struct { + insecure net.Conn + + initiator bool + prologue []byte + + localKey crypto.PrivKey + localPeer peer.ID + remotePeer peer.ID + + local peerInfo + remote peerInfo + + ns *xx.NoiseSession + + handshakeComplete bool + noiseKeypair *Keypair + + msgBuffer []byte + readLock sync.Mutex + writeLock sync.Mutex +} + +type peerInfo struct { + noiseKey [32]byte // static noise public key + libp2pKey crypto.PubKey +} + +// newSecureSession creates a noise session over the given insecure Conn, using the static +// Noise keypair and libp2p identity keypair from the given Transport. +func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) { + if tpt.noiseKeypair == nil { + return nil, errNoKeypair + } + + localPeerInfo := peerInfo{ + noiseKey: tpt.noiseKeypair.publicKey, + libp2pKey: tpt.privateKey.GetPublic(), + } + + s := &secureSession{ + insecure: insecure, + initiator: initiator, + prologue: []byte{}, + localKey: tpt.privateKey, + localPeer: tpt.localID, + remotePeer: remote, + local: localPeerInfo, + msgBuffer: []byte{}, + noiseKeypair: tpt.noiseKeypair, + } + + err := s.runHandshake(ctx) + return s, err +} + +func (s *secureSession) NoisePublicKey() [32]byte { + return s.noiseKeypair.publicKey +} + +func (s *secureSession) NoisePrivateKey() [32]byte { + return s.noiseKeypair.privateKey +} + +func (s *secureSession) LocalAddr() net.Addr { + return s.insecure.LocalAddr() +} + +func (s *secureSession) LocalPeer() peer.ID { + return s.localPeer +} + +func (s *secureSession) LocalPrivateKey() crypto.PrivKey { + return s.localKey +} + +func (s *secureSession) LocalPublicKey() crypto.PubKey { + return s.localKey.GetPublic() +} + +func (s *secureSession) RemoteAddr() net.Addr { + return s.insecure.RemoteAddr() +} + +func (s *secureSession) RemotePeer() peer.ID { + return s.remotePeer +} + +func (s *secureSession) RemotePublicKey() crypto.PubKey { + return s.remote.libp2pKey +} + +func (s *secureSession) SetDeadline(t time.Time) error { + return s.insecure.SetDeadline(t) +} + +func (s *secureSession) SetReadDeadline(t time.Time) error { + return s.insecure.SetReadDeadline(t) +} + +func (s *secureSession) SetWriteDeadline(t time.Time) error { + return s.insecure.SetWriteDeadline(t) +} + +func (s *secureSession) Close() error { + return s.insecure.Close() +} diff --git a/transport.go b/transport.go index d6c33d7..9ad3f37 100644 --- a/transport.go +++ b/transport.go @@ -17,91 +17,28 @@ var _ sec.SecureTransport = &Transport{} // Transport implements the interface sec.SecureTransport // https://godoc.org/github.com/libp2p/go-libp2p-core/sec#SecureConn type Transport struct { - localID peer.ID - privateKey crypto.PrivKey - noisePipesSupport bool - noiseStaticKeyCache *KeyCache - noiseKeypair *Keypair -} - -type transportConstructor func(crypto.PrivKey) (*Transport, error) - -// Maker returns a function that will construct a new Noise transport -// using the given Options. The returned function may be provided as a libp2p.Security -// option when configuring a libp2p Host using libp2p.New, and is compatible with the -// "reflection magic" that libp2p.New uses to inject the private identity key: -// -// host := libp2p.New( -// libp2p.Security(noise.ID, noise.Maker())) -// -// The transport can be configured by passing in Options. -// -// To enable the Noise Pipes pattern (which can be more efficient when reconnecting -// to a known peer), pass in the UseNoisePipes Option: -// -// Maker(UseNoisePipes) -// -// To use a specific Noise keypair, pass in the NoiseKeyPair(kp) option, where -// kp is a noise.Keypair struct. This is most useful when using Noise Pipes, whose -// efficiency gains rely on the static Noise key being known in advance. Persisting -// the Noise keypair across process restarts makes it more likely that other peers -// will be able to use the more efficient IK handshake pattern. -// -// Maker(UseNoisePipes, NoiseKeypair(keypairLoadedFromDisk)) -func Maker(options ...Option) transportConstructor { - return func(privKey crypto.PrivKey) (*Transport, error) { - return New(privKey, options...) - } + localID peer.ID + privateKey crypto.PrivKey + noiseKeypair *Keypair } // New creates a new Noise transport using the given private key as its -// libp2p identity key. This function may be used when you want a transport -// instance and know the libp2p Host's identity key before the Host is initialized. -// When configuring a go-libp2p Host using libp2p.New, it's simpler to use -// Maker instead, which will receive the identity key when the Host -// is initialized. -// -// New supports all the same Options as noise.Maker. -// -// To configure a go-libp2p Host to use the newly created transport, pass it into -// libp2p.New wrapped in a libp2p.Security Option. You will also need to -// make sure to set the libp2p.Identity option so that the Host uses the same -// identity key: -// -// privkey := loadPrivateKeyFromSomewhere() -// noiseTpt := noise.New(privkey) -// host := libp2p.New( -// libp2p.Identity(privkey), -// libp2p.Security(noise.ID, noiseTpt)) -func New(privkey crypto.PrivKey, options ...Option) (*Transport, error) { +// libp2p identity key. +func New(privkey crypto.PrivKey) (*Transport, error) { localID, err := peer.IDFromPrivateKey(privkey) if err != nil { return nil, err } - cfg := config{} - cfg.applyOptions(options...) - - kp := cfg.noiseKeypair - if kp == nil { - kp, err = GenerateKeypair() - if err != nil { - return nil, err - } - } - - // the static key cache is only useful if Noise Pipes is enabled - var keyCache *KeyCache - if cfg.noisePipesSupport { - keyCache = NewKeyCache() + kp, err := GenerateKeypair() + if err != nil { + return nil, err } return &Transport{ - localID: localID, - privateKey: privkey, - noisePipesSupport: cfg.noisePipesSupport, - noiseKeypair: kp, - noiseStaticKeyCache: keyCache, + localID: localID, + privateKey: privkey, + noiseKeypair: kp, }, nil } diff --git a/transport_test.go b/transport_test.go index a8194fc..4c4b050 100644 --- a/transport_test.go +++ b/transport_test.go @@ -7,7 +7,6 @@ import ( "net" "testing" - //ik "github.com/libp2p/go-libp2p-noise/ik" crypto "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec" @@ -33,12 +32,6 @@ func newTestTransport(t *testing.T, typ, bits int) *Transport { } } -func newTestTransportPipes(t *testing.T, typ, bits int) *Transport { - tpt := newTestTransport(t, typ, bits) - tpt.noisePipesSupport = true - return tpt -} - // Create a new pair of connected TCP sockets. func newConnPair(t *testing.T) (net.Conn, net.Conn) { lstnr, err := net.Listen("tcp", "localhost:0") @@ -227,73 +220,3 @@ func TestHandshakeXX(t *testing.T) { t.Errorf("Message mismatch. %v != %v", before, after) } } - -// Test IK handshake -func TestHandshakeIK(t *testing.T) { - initTransport := newTestTransportPipes(t, crypto.Ed25519, 2048) - respTransport := newTestTransportPipes(t, crypto.Ed25519, 2048) - - // add responder's static key to initiator's key cache - keycache := NewKeyCache() - keycache.Store(respTransport.localID, respTransport.noiseKeypair.publicKey) - initTransport.noiseStaticKeyCache = keycache - - // do IK handshake - initConn, respConn := connect(t, initTransport, respTransport) - defer initConn.Close() - defer respConn.Close() - - before := []byte("hello world") - _, err := initConn.Write(before) - if err != nil { - t.Fatal(err) - } - - after := make([]byte, len(before)) - _, err = respConn.Read(after) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(before, after) { - t.Errorf("Message mismatch. %v != %v", before, after) - } - - // make sure IK was actually used - if !(initConn.ik_complete && respConn.ik_complete) { - t.Error("Expected IK handshake to be used") - } -} - -// Test noise pipes -func TestHandshakeXXfallback(t *testing.T) { - initTransport := newTestTransportPipes(t, crypto.Ed25519, 2048) - respTransport := newTestTransportPipes(t, crypto.Ed25519, 2048) - - // turning on pipes causes it to default to IK, but since we haven't already - // done a handshake, it'll fallback to XX - initConn, respConn := connect(t, initTransport, respTransport) - defer initConn.Close() - defer respConn.Close() - - before := []byte("hello world") - _, err := initConn.Write(before) - if err != nil { - t.Fatal(err) - } - - after := make([]byte, len(before)) - _, err = respConn.Read(after) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(before, after) { - t.Errorf("Message mismatch. %v != %v", before, after) - } - - // make sure XX was actually used - if !(initConn.xx_complete && respConn.xx_complete) { - t.Error("Expected XXfallback handshake to be used") - } -} diff --git a/xx/XX.noise.go b/xx/XX.noise.go index 9bf37b6..f046eae 100644 --- a/xx/XX.noise.go +++ b/xx/XX.noise.go @@ -427,15 +427,10 @@ func initializeResponder(prologue []byte, s Keypair, rs [32]byte, psk [32]byte) return handshakestate{ss, s, e, rs, re, psk} } -func writeMessageA(hs *handshakestate, payload []byte, e *Keypair) (*handshakestate, MessageBuffer) { +func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, MessageBuffer) { ne, ns, ciphertext := emptyKey, []byte{}, []byte{} - if e == nil { - hs.e = GenerateKeypair() - } else { - hs.e = *e - } - + hs.e = GenerateKeypair() ne = hs.e.public_key mixHash(&hs.ss, ne[:]) /* No PSK, so skipping mixKey */ @@ -529,10 +524,10 @@ func InitSession(initiator bool, prologue []byte, s Keypair, rs [32]byte) *Noise return &session } -func SendMessage(session *NoiseSession, message []byte, ephemeral *Keypair) (*NoiseSession, MessageBuffer) { +func SendMessage(session *NoiseSession, message []byte) (*NoiseSession, MessageBuffer) { var messageBuffer MessageBuffer if session.mc == 0 { - _, messageBuffer = writeMessageA(&session.hs, message, ephemeral) + _, messageBuffer = writeMessageA(&session.hs, message) } if session.mc == 1 { _, messageBuffer = writeMessageB(&session.hs, message) diff --git a/xx/XX_test.go b/xx/XX_test.go index 4949978..63f6d71 100644 --- a/xx/XX_test.go +++ b/xx/XX_test.go @@ -86,7 +86,7 @@ func doHandshake(t *testing.T) (*NoiseSession, *NoiseSession) { var msgbuf MessageBuffer msg := []byte{} msg = append(msg, payload_init_enc[:]...) - ns_init, msgbuf = SendMessage(ns_init, msg, nil) + ns_init, msgbuf = SendMessage(ns_init, msg) t.Logf("stage 0 msgbuf: %v", msgbuf) t.Logf("stage 0 msgbuf ne len: %d", len(msgbuf.NE())) @@ -111,7 +111,7 @@ func doHandshake(t *testing.T) (*NoiseSession, *NoiseSession) { t.Fatalf("proto marshal payload fail: %s", err) } msg = append(msg, payload_resp_enc[:]...) - ns_resp, msgbuf = SendMessage(ns_resp, msg, nil) + ns_resp, msgbuf = SendMessage(ns_resp, msg) t.Logf("stage 1 msgbuf: %v", msgbuf) t.Logf("stage 1 msgbuf ne len: %d", len(msgbuf.NE())) @@ -128,7 +128,7 @@ func doHandshake(t *testing.T) (*NoiseSession, *NoiseSession) { // stage 2: initiator // send message //msg = append(msg, payload_init_enc[:]...) - ns_init, msgbuf = SendMessage(ns_init, nil, nil) + ns_init, msgbuf = SendMessage(ns_init, nil) t.Logf("stage 2 msgbuf: %v", msgbuf) t.Logf("stage 2 msgbuf ne len: %d", len(msgbuf.NE())) diff --git a/xx_handshake.go b/xx_handshake.go deleted file mode 100644 index 8bc16d0..0000000 --- a/xx_handshake.go +++ /dev/null @@ -1,250 +0,0 @@ -package noise - -import ( - "context" - "fmt" - "io" - - proto "github.com/gogo/protobuf/proto" - "github.com/libp2p/go-libp2p-core/peer" - - pb "github.com/libp2p/go-libp2p-noise/pb" - xx "github.com/libp2p/go-libp2p-noise/xx" -) - -func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bool) error { - var msgbuf xx.MessageBuffer - s.xx_ns, msgbuf = xx.SendMessage(s.xx_ns, payload, nil) - var encMsgBuf []byte - if initial_stage { - encMsgBuf = msgbuf.Encode0() - } else { - encMsgBuf = msgbuf.Encode1() - } - - err := s.writeLength(len(encMsgBuf)) - if err != nil { - return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err) - } - - _, err = s.insecure.Write(encMsgBuf) - if err != nil { - return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err) - } - - return nil -} - -func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte, plaintext []byte, valid bool, err error) { - l, err := s.readLength() - if err != nil { - return nil, nil, false, fmt.Errorf("xx_recvHandshakeMessage read length err=%s", err) - } - - buf = make([]byte, l) - - _, err = io.ReadFull(s.insecure, buf) - if err != nil { - return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err) - } - - var msgbuf *xx.MessageBuffer - if initial_stage { - msgbuf, err = xx.Decode0(buf) - } else { - msgbuf, err = xx.Decode1(buf) - } - - if err != nil { - return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage decode msg err=%s", err) - } - - s.xx_ns, plaintext, valid = xx.RecvMessage(s.xx_ns, msgbuf) - if !valid { - return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage validation fail") - } - - return buf, plaintext, valid, nil -} - -// Runs the XX handshake -// XX: -// -> e -// <- e, ee, s, es -// -> s, se -// if fallback = true, initialMsg is used as the message in stage 1 of the initiator and stage 0 -// of the responder -func (s *secureSession) runHandshake_xx(ctx context.Context, fallback bool, payload []byte, initialMsg []byte) (err error) { - kp := xx.NewKeypair(s.noiseKeypair.publicKey, s.noiseKeypair.privateKey) - - // new XX noise session - s.xx_ns = xx.InitSession(s.initiator, s.prologue, kp, [32]byte{}) - - if s.initiator { - // stage 0 // - - if !fallback { - err = s.xx_sendHandshakeMessage(nil, true) - if err != nil { - return fmt.Errorf("runHandshake_xx stage 0 initiator fail: %s", err) - } - } else { - e_ik := s.ik_ns.Ephemeral() - e_xx := xx.NewKeypair(e_ik.PubKey(), e_ik.PrivKey()) - - // initialize state as if we sent the first message - s.xx_ns, _ = xx.SendMessage(s.xx_ns, nil, &e_xx) - } - - // stage 1 // - - var plaintext []byte - var valid bool - if !fallback { - // read reply - _, plaintext, valid, err = s.xx_recvHandshakeMessage(false) - if err != nil { - return fmt.Errorf("runHandshake_xx initiator stage 1 fail: %s", err) - } - - if !valid { - return fmt.Errorf("runHandshake_xx stage 1 initiator validation fail") - } - } else { - var msgbuf *xx.MessageBuffer - msgbuf, err = xx.Decode1(initialMsg) - - if err != nil { - return fmt.Errorf("runHandshake_xx decode msg fail: %s", err) - } - - s.xx_ns, plaintext, valid = xx.RecvMessage(s.xx_ns, msgbuf) - if !valid { - return fmt.Errorf("runHandshake_xx validation fail") - } - - } - - // stage 2 // - - err = s.xx_sendHandshakeMessage(payload, false) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=true err=%s", err) - } - - // unmarshal payload - nhp := new(pb.NoiseHandshakePayload) - err = proto.Unmarshal(plaintext, nhp) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=true err=cannot unmarshal payload") - } - - // set remote libp2p public key - err = s.setRemotePeerInfo(nhp.GetIdentityKey()) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=true read remote libp2p key fail") - } - - // assert that remote peer ID matches libp2p public key - pid, err := peer.IDFromPublicKey(s.RemotePublicKey()) - if pid != s.remotePeer { - return fmt.Errorf("runHandshake_xx stage=2 initiator=true check remote peer id err: expected %x got %x", s.remotePeer, pid) - } else if err != nil { - return fmt.Errorf("runHandshake_xx stage 2 initiator check remote peer id err %s", err) - } - - // verify payload is signed by libp2p key - err = s.verifyPayload(nhp, s.xx_ns.RemoteKey()) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=true verify payload err=%s", err) - } - - if s.noisePipesSupport { - s.noiseStaticKeyCache.Store(s.remotePeer, s.xx_ns.RemoteKey()) - } - - } else { - - // stage 0 // - - var plaintext []byte - var valid bool - nhp := new(pb.NoiseHandshakePayload) - - if !fallback { - // read message - _, plaintext, valid, err = s.xx_recvHandshakeMessage(true) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=0 initiator=false err=%s", err) - } - - if !valid { - return fmt.Errorf("runHandshake_xx stage=0 initiator=false err=validation fail") - } - - } else { - var msgbuf *xx.MessageBuffer - msgbuf, err = xx.Decode0(initialMsg) - if err != nil { - return err - } - - xx_msgbuf := xx.NewMessageBuffer(msgbuf.NE(), nil, nil) - s.xx_ns, plaintext, valid = xx.RecvMessage(s.xx_ns, &xx_msgbuf) - if !valid { - return fmt.Errorf("runHandshake_xx validation fail") - } - } - - // stage 1 // - - err = s.xx_sendHandshakeMessage(payload, false) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=1 initiator=false err=%s", err) - } - - // stage 2 // - - // read message - _, plaintext, valid, err = s.xx_recvHandshakeMessage(false) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=false err=%s", err) - } - - if !valid { - return fmt.Errorf("runHandshake_xx stage=2 initiator=false err=validation fail") - } - - // unmarshal payload - err = proto.Unmarshal(plaintext, nhp) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=false err=cannot unmarshal payload") - } - - // set remote libp2p public key - err = s.setRemotePeerInfo(nhp.GetIdentityKey()) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=false read remote libp2p key fail") - } - - // set remote libp2p public key from payload - err = s.setRemotePeerID(s.RemotePublicKey()) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=false set remote peer id err=%s", err) - } - - s.remote.noiseKey = s.xx_ns.RemoteKey() - - // verify payload is signed by libp2p key - err = s.verifyPayload(nhp, s.remote.noiseKey) - if err != nil { - return fmt.Errorf("runHandshake_xx stage=2 initiator=false err=%s", err) - } - - if s.noisePipesSupport { - s.noiseStaticKeyCache.Store(s.remotePeer, s.remote.noiseKey) - } - } - - return nil -}