diff --git a/core/alias.go b/core/alias.go new file mode 100644 index 0000000000..1a19343867 --- /dev/null +++ b/core/alias.go @@ -0,0 +1,50 @@ +// Package core provides convenient access to foundational, central go-libp2p primitives via type aliases. +package core + +import ( + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + "github.com/multiformats/go-multiaddr" +) + +// Multiaddr aliases the Multiaddr type from github.com/multiformats/go-multiaddr. +// +// Refer to the docs on that type for more info. +type Multiaddr = multiaddr.Multiaddr + +// PeerID aliases peer.ID. +// +// Refer to the docs on that type for more info. +type PeerID = peer.ID + +// PeerID aliases protocol.ID. +// +// Refer to the docs on that type for more info. +type ProtocolID = protocol.ID + +// PeerAddrInfo aliases peer.AddrInfo. +// +// Refer to the docs on that type for more info. +type PeerAddrInfo = peer.AddrInfo + +// Host aliases host.Host. +// +// Refer to the docs on that type for more info. +type Host = host.Host + +// Network aliases network.Network. +// +// Refer to the docs on that type for more info. +type Network = network.Network + +// Conn aliases network.Conn. +// +// Refer to the docs on that type for more info. +type Conn = network.Conn + +// Stream aliases network.Stream. +// +// Refer to the docs on that type for more info. +type Stream = network.Stream diff --git a/core/connmgr/connmgr.go b/core/connmgr/connmgr.go new file mode 100644 index 0000000000..98baf8b9a2 --- /dev/null +++ b/core/connmgr/connmgr.go @@ -0,0 +1,78 @@ +// Package connmgr provides connection tracking and management interfaces for libp2p. +// +// The ConnManager interface exported from this package allows libp2p to enforce an +// upper bound on the total number of open connections. To avoid service disruptions, +// connections can be tagged with metadata and optionally "protected" to ensure that +// essential connections are not arbitrarily cut. +package connmgr + +import ( + "context" + "time" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" +) + +// ConnManager tracks connections to peers, and allows consumers to associate metadata +// with each peer. +// +// It enables connections to be trimmed based on implementation-defined heuristics. +// The ConnManager allows libp2p to enforce an upper bound on the total number of +// open connections. +type ConnManager interface { + + // TagPeer tags a peer with a string, associating a weight with the tag. + TagPeer(peer.ID, string, int) + + // Untag removes the tagged value from the peer. + UntagPeer(p peer.ID, tag string) + + // UpsertTag updates an existing tag or inserts a new one. + // + // The connection manager calls the upsert function supplying the current + // value of the tag (or zero if inexistent). The return value is used as + // the new value of the tag. + UpsertTag(p peer.ID, tag string, upsert func(int) int) + + // GetTagInfo returns the metadata associated with the peer, + // or nil if no metadata has been recorded for the peer. + GetTagInfo(p peer.ID) *TagInfo + + // TrimOpenConns terminates open connections based on an implementation-defined + // heuristic. + TrimOpenConns(ctx context.Context) + + // Notifee returns an implementation that can be called back to inform of + // opened and closed connections. + Notifee() network.Notifiee + + // Protect protects a peer from having its connection(s) pruned. + // + // Tagging allows different parts of the system to manage protections without interfering with one another. + // + // Calls to Protect() with the same tag are idempotent. They are not refcounted, so after multiple calls + // to Protect() with the same tag, a single Unprotect() call bearing the same tag will revoke the protection. + Protect(id peer.ID, tag string) + + // Unprotect removes a protection that may have been placed on a peer, under the specified tag. + // + // The return value indicates whether the peer continues to be protected after this call, by way of a different tag. + // See notes on Protect() for more info. + Unprotect(id peer.ID, tag string) (protected bool) + + // Close closes the connection manager and stops background processes + Close() error +} + +// TagInfo stores metadata associated with a peer. +type TagInfo struct { + FirstSeen time.Time + Value int + + // Tags maps tag ids to the numerical values. + Tags map[string]int + + // Conns maps connection ids (such as remote multiaddr) to their creation time. + Conns map[string]time.Time +} diff --git a/core/connmgr/null.go b/core/connmgr/null.go new file mode 100644 index 0000000000..473af32895 --- /dev/null +++ b/core/connmgr/null.go @@ -0,0 +1,23 @@ +package connmgr + +import ( + "context" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" +) + +// NullConnMgr is a ConnMgr that provides no functionality. +type NullConnMgr struct{} + +var _ ConnManager = (*NullConnMgr)(nil) + +func (_ NullConnMgr) TagPeer(peer.ID, string, int) {} +func (_ NullConnMgr) UntagPeer(peer.ID, string) {} +func (_ NullConnMgr) UpsertTag(peer.ID, string, func(int) int) {} +func (_ NullConnMgr) GetTagInfo(peer.ID) *TagInfo { return &TagInfo{} } +func (_ NullConnMgr) TrimOpenConns(ctx context.Context) {} +func (_ NullConnMgr) Notifee() network.Notifiee { return network.GlobalNoopNotifiee } +func (_ NullConnMgr) Protect(peer.ID, string) {} +func (_ NullConnMgr) Unprotect(peer.ID, string) bool { return false } +func (_ NullConnMgr) Close() error { return nil } diff --git a/core/crypto/bench_test.go b/core/crypto/bench_test.go new file mode 100644 index 0000000000..fffc3bb7de --- /dev/null +++ b/core/crypto/bench_test.go @@ -0,0 +1,84 @@ +package crypto + +import "testing" + +func BenchmarkSignRSA1B(b *testing.B) { RunBenchmarkSignRSA(b, 1) } +func BenchmarkSignRSA10B(b *testing.B) { RunBenchmarkSignRSA(b, 10) } +func BenchmarkSignRSA100B(b *testing.B) { RunBenchmarkSignRSA(b, 100) } +func BenchmarkSignRSA1000B(b *testing.B) { RunBenchmarkSignRSA(b, 1000) } +func BenchmarkSignRSA10000B(b *testing.B) { RunBenchmarkSignRSA(b, 10000) } +func BenchmarkSignRSA100000B(b *testing.B) { RunBenchmarkSignRSA(b, 100000) } + +func BenchmarkVerifyRSA1B(b *testing.B) { RunBenchmarkVerifyRSA(b, 1) } +func BenchmarkVerifyRSA10B(b *testing.B) { RunBenchmarkVerifyRSA(b, 10) } +func BenchmarkVerifyRSA100B(b *testing.B) { RunBenchmarkVerifyRSA(b, 100) } +func BenchmarkVerifyRSA1000B(b *testing.B) { RunBenchmarkVerifyRSA(b, 1000) } +func BenchmarkVerifyRSA10000B(b *testing.B) { RunBenchmarkVerifyRSA(b, 10000) } +func BenchmarkVerifyRSA100000B(b *testing.B) { RunBenchmarkVerifyRSA(b, 100000) } + +func BenchmarkSignEd255191B(b *testing.B) { RunBenchmarkSignEd25519(b, 1) } +func BenchmarkSignEd2551910B(b *testing.B) { RunBenchmarkSignEd25519(b, 10) } +func BenchmarkSignEd25519100B(b *testing.B) { RunBenchmarkSignEd25519(b, 100) } +func BenchmarkSignEd255191000B(b *testing.B) { RunBenchmarkSignEd25519(b, 1000) } +func BenchmarkSignEd2551910000B(b *testing.B) { RunBenchmarkSignEd25519(b, 10000) } +func BenchmarkSignEd25519100000B(b *testing.B) { RunBenchmarkSignEd25519(b, 100000) } + +func BenchmarkVerifyEd255191B(b *testing.B) { RunBenchmarkVerifyEd25519(b, 1) } +func BenchmarkVerifyEd2551910B(b *testing.B) { RunBenchmarkVerifyEd25519(b, 10) } +func BenchmarkVerifyEd25519100B(b *testing.B) { RunBenchmarkVerifyEd25519(b, 100) } +func BenchmarkVerifyEd255191000B(b *testing.B) { RunBenchmarkVerifyEd25519(b, 1000) } +func BenchmarkVerifyEd2551910000B(b *testing.B) { RunBenchmarkVerifyEd25519(b, 10000) } +func BenchmarkVerifyEd25519100000B(b *testing.B) { RunBenchmarkVerifyEd25519(b, 100000) } + +func RunBenchmarkSignRSA(b *testing.B, numBytes int) { + runBenchmarkSign(b, numBytes, RSA) +} + +func RunBenchmarkSignEd25519(b *testing.B, numBytes int) { + runBenchmarkSign(b, numBytes, Ed25519) +} + +func runBenchmarkSign(b *testing.B, numBytes int, t int) { + secret, _, err := GenerateKeyPair(t, 1024) + if err != nil { + b.Fatal(err) + } + someData := make([]byte, numBytes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := secret.Sign(someData) + if err != nil { + b.Fatal(err) + } + } +} + +func RunBenchmarkVerifyRSA(b *testing.B, numBytes int) { + runBenchmarkSign(b, numBytes, RSA) +} + +func RunBenchmarkVerifyEd25519(b *testing.B, numBytes int) { + runBenchmarkSign(b, numBytes, Ed25519) +} + +func runBenchmarkVerify(b *testing.B, numBytes int, t int) { + secret, public, err := GenerateKeyPair(t, 1024) + if err != nil { + b.Fatal(err) + } + someData := make([]byte, numBytes) + signature, err := secret.Sign(someData) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + valid, err := public.Verify(someData, signature) + if err != nil { + b.Fatal(err) + } + if !valid { + b.Fatal("signature should be valid") + } + } +} diff --git a/core/crypto/ecdsa.go b/core/crypto/ecdsa.go new file mode 100644 index 0000000000..58e5d5f5ab --- /dev/null +++ b/core/crypto/ecdsa.go @@ -0,0 +1,186 @@ +package crypto + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/asn1" + "errors" + "io" + "math/big" + + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + + sha256 "github.com/minio/sha256-simd" +) + +// ECDSAPrivateKey is an implementation of an ECDSA private key +type ECDSAPrivateKey struct { + priv *ecdsa.PrivateKey +} + +// ECDSAPublicKey is an implementation of an ECDSA public key +type ECDSAPublicKey struct { + pub *ecdsa.PublicKey +} + +// ECDSASig holds the r and s values of an ECDSA signature +type ECDSASig struct { + R, S *big.Int +} + +var ( + // ErrNotECDSAPubKey is returned when the public key passed is not an ecdsa public key + ErrNotECDSAPubKey = errors.New("not an ecdsa public key") + // ErrNilSig is returned when the signature is nil + ErrNilSig = errors.New("sig is nil") + // ErrNilPrivateKey is returned when a nil private key is provided + ErrNilPrivateKey = errors.New("private key is nil") + // ECDSACurve is the default ecdsa curve used + ECDSACurve = elliptic.P256() +) + +// GenerateECDSAKeyPair generates a new ecdsa private and public key +func GenerateECDSAKeyPair(src io.Reader) (PrivKey, PubKey, error) { + return GenerateECDSAKeyPairWithCurve(ECDSACurve, src) +} + +// GenerateECDSAKeyPairWithCurve generates a new ecdsa private and public key with a speicified curve +func GenerateECDSAKeyPairWithCurve(curve elliptic.Curve, src io.Reader) (PrivKey, PubKey, error) { + priv, err := ecdsa.GenerateKey(curve, src) + if err != nil { + return nil, nil, err + } + + return &ECDSAPrivateKey{priv}, &ECDSAPublicKey{&priv.PublicKey}, nil +} + +// ECDSAKeyPairFromKey generates a new ecdsa private and public key from an input private key +func ECDSAKeyPairFromKey(priv *ecdsa.PrivateKey) (PrivKey, PubKey, error) { + if priv == nil { + return nil, nil, ErrNilPrivateKey + } + + return &ECDSAPrivateKey{priv}, &ECDSAPublicKey{&priv.PublicKey}, nil +} + +// MarshalECDSAPrivateKey returns x509 bytes from a private key +func MarshalECDSAPrivateKey(ePriv ECDSAPrivateKey) ([]byte, error) { + return x509.MarshalECPrivateKey(ePriv.priv) +} + +// MarshalECDSAPublicKey returns x509 bytes from a public key +func MarshalECDSAPublicKey(ePub ECDSAPublicKey) ([]byte, error) { + return x509.MarshalPKIXPublicKey(ePub.pub) +} + +// UnmarshalECDSAPrivateKey returns a private key from x509 bytes +func UnmarshalECDSAPrivateKey(data []byte) (PrivKey, error) { + priv, err := x509.ParseECPrivateKey(data) + if err != nil { + return nil, err + } + + return &ECDSAPrivateKey{priv}, nil +} + +// UnmarshalECDSAPublicKey returns the public key from x509 bytes +func UnmarshalECDSAPublicKey(data []byte) (PubKey, error) { + pubIfc, err := x509.ParsePKIXPublicKey(data) + if err != nil { + return nil, err + } + + pub, ok := pubIfc.(*ecdsa.PublicKey) + if !ok { + return nil, ErrNotECDSAPubKey + } + + return &ECDSAPublicKey{pub}, nil +} + +// Bytes returns the private key as protobuf bytes +func (ePriv *ECDSAPrivateKey) Bytes() ([]byte, error) { + return MarshalPrivateKey(ePriv) +} + +// Type returns the key type +func (ePriv *ECDSAPrivateKey) Type() pb.KeyType { + return pb.KeyType_ECDSA +} + +// Raw returns x509 bytes from a private key +func (ePriv *ECDSAPrivateKey) Raw() ([]byte, error) { + return x509.MarshalECPrivateKey(ePriv.priv) +} + +// Equals compares to private keys +func (ePriv *ECDSAPrivateKey) Equals(o Key) bool { + oPriv, ok := o.(*ECDSAPrivateKey) + if !ok { + return false + } + + return ePriv.priv.D.Cmp(oPriv.priv.D) == 0 +} + +// Sign returns the signature of the input data +func (ePriv *ECDSAPrivateKey) Sign(data []byte) ([]byte, error) { + hash := sha256.Sum256(data) + r, s, err := ecdsa.Sign(rand.Reader, ePriv.priv, hash[:]) + if err != nil { + return nil, err + } + + return asn1.Marshal(ECDSASig{ + R: r, + S: s, + }) +} + +// GetPublic returns a public key +func (ePriv *ECDSAPrivateKey) GetPublic() PubKey { + return &ECDSAPublicKey{&ePriv.priv.PublicKey} +} + +// Bytes returns the public key as protobuf bytes +func (ePub *ECDSAPublicKey) Bytes() ([]byte, error) { + return MarshalPublicKey(ePub) +} + +// Type returns the key type +func (ePub *ECDSAPublicKey) Type() pb.KeyType { + return pb.KeyType_ECDSA +} + +// Raw returns x509 bytes from a public key +func (ePub ECDSAPublicKey) Raw() ([]byte, error) { + return x509.MarshalPKIXPublicKey(ePub.pub) +} + +// Equals compares to public keys +func (ePub *ECDSAPublicKey) Equals(o Key) bool { + oPub, ok := o.(*ECDSAPublicKey) + if !ok { + return false + } + + return ePub.pub.X != nil && ePub.pub.Y != nil && oPub.pub.X != nil && oPub.pub.Y != nil && + 0 == ePub.pub.X.Cmp(oPub.pub.X) && 0 == ePub.pub.Y.Cmp(oPub.pub.Y) +} + +// Verify compares data to a signature +func (ePub *ECDSAPublicKey) Verify(data, sigBytes []byte) (bool, error) { + sig := new(ECDSASig) + if _, err := asn1.Unmarshal(sigBytes, sig); err != nil { + return false, err + } + if sig == nil { + return false, ErrNilSig + } + + hash := sha256.Sum256(data) + + return ecdsa.Verify(ePub.pub, hash[:], sig.R, sig.S), nil +} diff --git a/core/crypto/ecdsa_test.go b/core/crypto/ecdsa_test.go new file mode 100644 index 0000000000..328b9a7c6c --- /dev/null +++ b/core/crypto/ecdsa_test.go @@ -0,0 +1,96 @@ +package crypto + +import ( + "crypto/rand" + "testing" +) + +func TestECDSABasicSignAndVerify(t *testing.T) { + priv, pub, err := GenerateECDSAKeyPair(rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := []byte("hello! and welcome to some awesome crypto primitives") + + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + + if !ok { + t.Fatal("signature didnt match") + } + + // change data + data[0] = ^data[0] + ok, err = pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + + if ok { + t.Fatal("signature matched and shouldn't") + } +} + +func TestECDSASignZero(t *testing.T) { + priv, pub, err := GenerateECDSAKeyPair(rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := make([]byte, 0) + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("signature didn't match") + } +} + +func TestECDSAMarshalLoop(t *testing.T) { + priv, pub, err := GenerateECDSAKeyPair(rand.Reader) + if err != nil { + t.Fatal(err) + } + + privB, err := priv.Bytes() + if err != nil { + t.Fatal(err) + } + + privNew, err := UnmarshalPrivateKey(privB) + if err != nil { + t.Fatal(err) + } + + if !priv.Equals(privNew) || !privNew.Equals(priv) { + t.Fatal("keys are not equal") + } + + pubB, err := pub.Bytes() + if err != nil { + t.Fatal(err) + } + pubNew, err := UnmarshalPublicKey(pubB) + if err != nil { + t.Fatal(err) + } + + if !pub.Equals(pubNew) || !pubNew.Equals(pub) { + t.Fatal("keys are not equal") + } + +} diff --git a/core/crypto/ed25519.go b/core/crypto/ed25519.go new file mode 100644 index 0000000000..b6e553147d --- /dev/null +++ b/core/crypto/ed25519.go @@ -0,0 +1,155 @@ +package crypto + +import ( + "bytes" + "errors" + "fmt" + "io" + + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + + "golang.org/x/crypto/ed25519" +) + +// Ed25519PrivateKey is an ed25519 private key. +type Ed25519PrivateKey struct { + k ed25519.PrivateKey +} + +// Ed25519PublicKey is an ed25519 public key. +type Ed25519PublicKey struct { + k ed25519.PublicKey +} + +// GenerateEd25519Key generates a new ed25519 private and public key pair. +func GenerateEd25519Key(src io.Reader) (PrivKey, PubKey, error) { + pub, priv, err := ed25519.GenerateKey(src) + if err != nil { + return nil, nil, err + } + + return &Ed25519PrivateKey{ + k: priv, + }, + &Ed25519PublicKey{ + k: pub, + }, + nil +} + +// Type of the private key (Ed25519). +func (k *Ed25519PrivateKey) Type() pb.KeyType { + return pb.KeyType_Ed25519 +} + +// Bytes marshals an ed25519 private key to protobuf bytes. +func (k *Ed25519PrivateKey) Bytes() ([]byte, error) { + return MarshalPrivateKey(k) +} + +// Raw private key bytes. +func (k *Ed25519PrivateKey) Raw() ([]byte, error) { + // The Ed25519 private key contains two 32-bytes curve points, the private + // key and the public key. + // It makes it more efficient to get the public key without re-computing an + // elliptic curve multiplication. + buf := make([]byte, len(k.k)) + copy(buf, k.k) + + return buf, nil +} + +func (k *Ed25519PrivateKey) pubKeyBytes() []byte { + return k.k[ed25519.PrivateKeySize-ed25519.PublicKeySize:] +} + +// Equals compares two ed25519 private keys. +func (k *Ed25519PrivateKey) Equals(o Key) bool { + edk, ok := o.(*Ed25519PrivateKey) + if !ok { + return false + } + + return bytes.Equal(k.k, edk.k) +} + +// GetPublic returns an ed25519 public key from a private key. +func (k *Ed25519PrivateKey) GetPublic() PubKey { + return &Ed25519PublicKey{k: k.pubKeyBytes()} +} + +// Sign returns a signature from an input message. +func (k *Ed25519PrivateKey) Sign(msg []byte) ([]byte, error) { + return ed25519.Sign(k.k, msg), nil +} + +// Type of the public key (Ed25519). +func (k *Ed25519PublicKey) Type() pb.KeyType { + return pb.KeyType_Ed25519 +} + +// Bytes returns a ed25519 public key as protobuf bytes. +func (k *Ed25519PublicKey) Bytes() ([]byte, error) { + return MarshalPublicKey(k) +} + +// Raw public key bytes. +func (k *Ed25519PublicKey) Raw() ([]byte, error) { + return k.k, nil +} + +// Equals compares two ed25519 public keys. +func (k *Ed25519PublicKey) Equals(o Key) bool { + edk, ok := o.(*Ed25519PublicKey) + if !ok { + return false + } + + return bytes.Equal(k.k, edk.k) +} + +// Verify checks a signature agains the input data. +func (k *Ed25519PublicKey) Verify(data []byte, sig []byte) (bool, error) { + return ed25519.Verify(k.k, data, sig), nil +} + +// UnmarshalEd25519PublicKey returns a public key from input bytes. +func UnmarshalEd25519PublicKey(data []byte) (PubKey, error) { + if len(data) != 32 { + return nil, errors.New("expect ed25519 public key data size to be 32") + } + + return &Ed25519PublicKey{ + k: ed25519.PublicKey(data), + }, nil +} + +// UnmarshalEd25519PrivateKey returns a private key from input bytes. +func UnmarshalEd25519PrivateKey(data []byte) (PrivKey, error) { + switch len(data) { + case ed25519.PrivateKeySize + ed25519.PublicKeySize: + // Remove the redundant public key. See issue #36. + redundantPk := data[ed25519.PrivateKeySize:] + pk := data[ed25519.PrivateKeySize-ed25519.PublicKeySize : ed25519.PrivateKeySize] + if !bytes.Equal(pk, redundantPk) { + return nil, errors.New("expected redundant ed25519 public key to be redundant") + } + + // No point in storing the extra data. + newKey := make([]byte, ed25519.PrivateKeySize) + copy(newKey, data[:ed25519.PrivateKeySize]) + data = newKey + case ed25519.PrivateKeySize: + default: + return nil, fmt.Errorf( + "expected ed25519 data size to be %d or %d, got %d", + ed25519.PrivateKeySize, + ed25519.PrivateKeySize+ed25519.PublicKeySize, + len(data), + ) + } + + return &Ed25519PrivateKey{ + k: ed25519.PrivateKey(data), + }, nil +} diff --git a/core/crypto/ed25519_test.go b/core/crypto/ed25519_test.go new file mode 100644 index 0000000000..222b5154b6 --- /dev/null +++ b/core/crypto/ed25519_test.go @@ -0,0 +1,220 @@ +package crypto + +import ( + "crypto/rand" + "testing" + + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + + "golang.org/x/crypto/ed25519" +) + +func TestBasicSignAndVerify(t *testing.T) { + priv, pub, err := GenerateEd25519Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := []byte("hello! and welcome to some awesome crypto primitives") + + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + + if !ok { + t.Fatal("signature didn't match") + } + + // change data + data[0] = ^data[0] + ok, err = pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + + if ok { + t.Fatal("signature matched and shouldn't") + } +} + +func TestSignZero(t *testing.T) { + priv, pub, err := GenerateEd25519Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := make([]byte, 0) + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("signature didn't match") + } +} + +func TestMarshalLoop(t *testing.T) { + priv, pub, err := GenerateEd25519Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + t.Run("PrivateKey", func(t *testing.T) { + for name, f := range map[string]func() ([]byte, error){ + "Bytes": priv.Bytes, + "Marshal": func() ([]byte, error) { + return MarshalPrivateKey(priv) + }, + "Redundant": func() ([]byte, error) { + // See issue #36. + // Ed25519 private keys used to contain the public key twice. + // For backwards-compatibility, we need to continue supporting + // that scenario. + pbmes := new(pb.PrivateKey) + pbmes.Type = priv.Type() + data, err := priv.Raw() + if err != nil { + t.Fatal(err) + } + + pbmes.Data = append(data, data[len(data)-ed25519.PublicKeySize:]...) + return pbmes.Marshal() + }, + } { + t.Run(name, func(t *testing.T) { + bts, err := f() + if err != nil { + t.Fatal(err) + } + + privNew, err := UnmarshalPrivateKey(bts) + if err != nil { + t.Fatal(err) + } + + if !priv.Equals(privNew) || !privNew.Equals(priv) { + t.Fatal("keys are not equal") + } + + msg := []byte("My child, my sister,\nThink of the rapture\nOf living together there!") + signed, err := privNew.Sign(msg) + if err != nil { + t.Fatal(err) + } + + ok, err := privNew.GetPublic().Verify(msg, signed) + if err != nil { + t.Fatal(err) + } + + if !ok { + t.Fatal("signature didn't match") + } + }) + } + }) + + t.Run("PublicKey", func(t *testing.T) { + for name, f := range map[string]func() ([]byte, error){ + "Bytes": pub.Bytes, + "Marshal": func() ([]byte, error) { + return MarshalPublicKey(pub) + }, + } { + t.Run(name, func(t *testing.T) { + bts, err := f() + if err != nil { + t.Fatal(err) + } + pubNew, err := UnmarshalPublicKey(bts) + if err != nil { + t.Fatal(err) + } + + if !pub.Equals(pubNew) || !pubNew.Equals(pub) { + t.Fatal("keys are not equal") + } + }) + } + }) +} + +func TestUnmarshalErrors(t *testing.T) { + t.Run("PublicKey", func(t *testing.T) { + t.Run("Invalid data length", func(t *testing.T) { + pbmes := &pb.PublicKey{ + Type: pb.KeyType_Ed25519, + Data: []byte{42}, + } + + data, err := pbmes.Marshal() + if err != nil { + t.Fatal(err) + } + + _, err = UnmarshalPublicKey(data) + if err == nil { + t.Fatal("expected an error") + } + }) + }) + + t.Run("PrivateKey", func(t *testing.T) { + t.Run("Redundant public key mismatch", func(t *testing.T) { + priv, _, err := GenerateEd25519Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + pbmes := new(pb.PrivateKey) + pbmes.Type = priv.Type() + data, err := priv.Raw() + if err != nil { + t.Fatal(err) + } + + // Append the private key instead of the public key. + pbmes.Data = append(data, data[:ed25519.PublicKeySize]...) + b, err := pbmes.Marshal() + if err != nil { + t.Fatal(err) + } + + _, err = UnmarshalPrivateKey(b) + if err == nil { + t.Fatal("expected an error") + } + if err.Error() != "expected redundant ed25519 public key to be redundant" { + t.Fatalf("invalid error received: %s", err.Error()) + } + }) + + t.Run("Invalid data length", func(t *testing.T) { + pbmes := &pb.PrivateKey{ + Type: pb.KeyType_Ed25519, + Data: []byte{42}, + } + + data, err := pbmes.Marshal() + if err != nil { + t.Fatal(err) + } + + _, err = UnmarshalPrivateKey(data) + if err == nil { + t.Fatal("expected an error") + } + }) + }) +} diff --git a/core/crypto/fixture_test.go b/core/crypto/fixture_test.go new file mode 100644 index 0000000000..471e85472f --- /dev/null +++ b/core/crypto/fixture_test.go @@ -0,0 +1,132 @@ +package crypto_test + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "testing" + + crypto "github.com/libp2p/go-libp2p-core/crypto" + crypto_pb "github.com/libp2p/go-libp2p-core/crypto/pb" +) + +var message = []byte("Libp2p is the _best_!") + +type testCase struct { + keyType crypto_pb.KeyType + gen func(i io.Reader) (crypto.PrivKey, crypto.PubKey, error) + sigDeterministic bool +} + +var keyTypes = []testCase{ + { + keyType: crypto_pb.KeyType_ECDSA, + gen: crypto.GenerateECDSAKeyPair, + }, + { + keyType: crypto_pb.KeyType_Secp256k1, + sigDeterministic: true, + gen: crypto.GenerateSecp256k1Key, + }, + { + keyType: crypto_pb.KeyType_RSA, + sigDeterministic: true, + gen: func(i io.Reader) (crypto.PrivKey, crypto.PubKey, error) { + return crypto.GenerateRSAKeyPair(2048, i) + }, + }, +} + +func fname(kt crypto_pb.KeyType, ext string) string { + return fmt.Sprintf("test_data/%d.%s", kt, ext) +} + +func TestFixtures(t *testing.T) { + for _, tc := range keyTypes { + t.Run(tc.keyType.String(), func(t *testing.T) { + pubBytes, err := ioutil.ReadFile(fname(tc.keyType, "pub")) + if err != nil { + t.Fatal(err) + } + privBytes, err := ioutil.ReadFile(fname(tc.keyType, "priv")) + if err != nil { + t.Fatal(err) + } + sigBytes, err := ioutil.ReadFile(fname(tc.keyType, "sig")) + if err != nil { + t.Fatal(err) + } + pub, err := crypto.UnmarshalPublicKey(pubBytes) + if err != nil { + t.Fatal(err) + } + pubBytes2, err := pub.Bytes() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pubBytes2, pubBytes) { + t.Fatal("encoding round-trip failed") + } + priv, err := crypto.UnmarshalPrivateKey(privBytes) + if err != nil { + t.Fatal(err) + } + privBytes2, err := priv.Bytes() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(privBytes2, privBytes) { + t.Fatal("encoding round-trip failed") + } + ok, err := pub.Verify(message, sigBytes) + if !ok || err != nil { + t.Fatal("failed to validate signature with public key") + } + + if tc.sigDeterministic { + sigBytes2, err := priv.Sign(message) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sigBytes2, sigBytes) { + t.Fatal("signature not deterministic") + } + } + }) + } +} + +func init() { + // set to true to re-generate test data + if false { + generate() + panic("generated") + } +} + +// generate re-generates test data +func generate() { + for _, tc := range keyTypes { + priv, pub, err := tc.gen(rand.Reader) + if err != nil { + panic(err) + } + pubb, err := pub.Bytes() + if err != nil { + panic(err) + } + privb, err := priv.Bytes() + if err != nil { + panic(err) + } + sig, err := priv.Sign(message) + if err != nil { + panic(err) + } + ioutil.WriteFile(fname(tc.keyType, "pub"), pubb, 0666) + ioutil.WriteFile(fname(tc.keyType, "priv"), privb, 0666) + ioutil.WriteFile(fname(tc.keyType, "sig"), sig, 0666) + } +} diff --git a/core/crypto/key.go b/core/crypto/key.go new file mode 100644 index 0000000000..50167dbe02 --- /dev/null +++ b/core/crypto/key.go @@ -0,0 +1,351 @@ +// Package crypto implements various cryptographic utilities used by libp2p. +// This includes a Public and Private key interface and key implementations +// for supported key algorithms. +package crypto + +import ( + "bytes" + "crypto/elliptic" + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "crypto/sha512" + "encoding/base64" + "errors" + "fmt" + "hash" + "io" + + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + + "github.com/gogo/protobuf/proto" + sha256 "github.com/minio/sha256-simd" +) + +const ( + // RSA is an enum for the supported RSA key type + RSA = iota + // Ed25519 is an enum for the supported Ed25519 key type + Ed25519 + // Secp256k1 is an enum for the supported Secp256k1 key type + Secp256k1 + // ECDSA is an enum for the supported ECDSA key type + ECDSA +) + +var ( + // ErrBadKeyType is returned when a key is not supported + ErrBadKeyType = errors.New("invalid or unsupported key type") + // KeyTypes is a list of supported keys + KeyTypes = []int{ + RSA, + Ed25519, + Secp256k1, + ECDSA, + } +) + +// PubKeyUnmarshaller is a func that creates a PubKey from a given slice of bytes +type PubKeyUnmarshaller func(data []byte) (PubKey, error) + +// PrivKeyUnmarshaller is a func that creates a PrivKey from a given slice of bytes +type PrivKeyUnmarshaller func(data []byte) (PrivKey, error) + +// PubKeyUnmarshallers is a map of unmarshallers by key type +var PubKeyUnmarshallers = map[pb.KeyType]PubKeyUnmarshaller{ + pb.KeyType_RSA: UnmarshalRsaPublicKey, + pb.KeyType_Ed25519: UnmarshalEd25519PublicKey, + pb.KeyType_Secp256k1: UnmarshalSecp256k1PublicKey, + pb.KeyType_ECDSA: UnmarshalECDSAPublicKey, +} + +// PrivKeyUnmarshallers is a map of unmarshallers by key type +var PrivKeyUnmarshallers = map[pb.KeyType]PrivKeyUnmarshaller{ + pb.KeyType_RSA: UnmarshalRsaPrivateKey, + pb.KeyType_Ed25519: UnmarshalEd25519PrivateKey, + pb.KeyType_Secp256k1: UnmarshalSecp256k1PrivateKey, + pb.KeyType_ECDSA: UnmarshalECDSAPrivateKey, +} + +// Key represents a crypto key that can be compared to another key +type Key interface { + // Bytes returns a serialized, storeable representation of this key + // DEPRECATED in favor of Marshal / Unmarshal + Bytes() ([]byte, error) + + // Equals checks whether two PubKeys are the same + Equals(Key) bool + + // Raw returns the raw bytes of the key (not wrapped in the + // libp2p-crypto protobuf). + // + // This function is the inverse of {Priv,Pub}KeyUnmarshaler. + Raw() ([]byte, error) + + // Type returns the protobof key type. + Type() pb.KeyType +} + +// PrivKey represents a private key that can be used to generate a public key and sign data +type PrivKey interface { + Key + + // Cryptographically sign the given bytes + Sign([]byte) ([]byte, error) + + // Return a public key paired with this private key + GetPublic() PubKey +} + +// PubKey is a public key that can be used to verifiy data signed with the corresponding private key +type PubKey interface { + Key + + // Verify that 'sig' is the signed hash of 'data' + Verify(data []byte, sig []byte) (bool, error) +} + +// GenSharedKey generates the shared key from a given private key +type GenSharedKey func([]byte) ([]byte, error) + +// GenerateKeyPair generates a private and public key +func GenerateKeyPair(typ, bits int) (PrivKey, PubKey, error) { + return GenerateKeyPairWithReader(typ, bits, rand.Reader) +} + +// GenerateKeyPairWithReader returns a keypair of the given type and bitsize +func GenerateKeyPairWithReader(typ, bits int, src io.Reader) (PrivKey, PubKey, error) { + switch typ { + case RSA: + return GenerateRSAKeyPair(bits, src) + case Ed25519: + return GenerateEd25519Key(src) + case Secp256k1: + return GenerateSecp256k1Key(src) + case ECDSA: + return GenerateECDSAKeyPair(src) + default: + return nil, nil, ErrBadKeyType + } +} + +// GenerateEKeyPair returns an ephemeral public key and returns a function that will compute +// the shared secret key. Used in the identify module. +// +// Focuses only on ECDH now, but can be made more general in the future. +func GenerateEKeyPair(curveName string) ([]byte, GenSharedKey, error) { + var curve elliptic.Curve + + switch curveName { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + } + + priv, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + if err != nil { + return nil, nil, err + } + + pubKey := elliptic.Marshal(curve, x, y) + + done := func(theirPub []byte) ([]byte, error) { + // Verify and unpack node's public key. + x, y := elliptic.Unmarshal(curve, theirPub) + if x == nil { + return nil, fmt.Errorf("malformed public key: %d %v", len(theirPub), theirPub) + } + + if !curve.IsOnCurve(x, y) { + return nil, errors.New("invalid public key") + } + + // Generate shared secret. + secret, _ := curve.ScalarMult(x, y, priv) + + return secret.Bytes(), nil + } + + return pubKey, done, nil +} + +// StretchedKeys ... +type StretchedKeys struct { + IV []byte + MacKey []byte + CipherKey []byte +} + +// KeyStretcher returns a set of keys for each party by stretching the shared key. +// (myIV, theirIV, myCipherKey, theirCipherKey, myMACKey, theirMACKey) +func KeyStretcher(cipherType string, hashType string, secret []byte) (StretchedKeys, StretchedKeys) { + var cipherKeySize int + var ivSize int + switch cipherType { + case "AES-128": + ivSize = 16 + cipherKeySize = 16 + case "AES-256": + ivSize = 16 + cipherKeySize = 32 + case "Blowfish": + ivSize = 8 + // Note: cypherKeySize arbitrarily selected, needs more thought + cipherKeySize = 32 + } + + hmacKeySize := 20 + + seed := []byte("key expansion") + + result := make([]byte, 2*(ivSize+cipherKeySize+hmacKeySize)) + + var h func() hash.Hash + + switch hashType { + case "SHA1": + h = sha1.New + case "SHA256": + h = sha256.New + case "SHA512": + h = sha512.New + default: + panic("Unrecognized hash function, programmer error?") + } + + m := hmac.New(h, secret) + // note: guaranteed to never return an error + m.Write(seed) + + a := m.Sum(nil) + + j := 0 + for j < len(result) { + m.Reset() + + // note: guaranteed to never return an error. + m.Write(a) + m.Write(seed) + + b := m.Sum(nil) + + todo := len(b) + + if j+todo > len(result) { + todo = len(result) - j + } + + copy(result[j:j+todo], b) + + j += todo + + m.Reset() + + // note: guaranteed to never return an error. + m.Write(a) + + a = m.Sum(nil) + } + + half := len(result) / 2 + r1 := result[:half] + r2 := result[half:] + + var k1 StretchedKeys + var k2 StretchedKeys + + k1.IV = r1[0:ivSize] + k1.CipherKey = r1[ivSize : ivSize+cipherKeySize] + k1.MacKey = r1[ivSize+cipherKeySize:] + + k2.IV = r2[0:ivSize] + k2.CipherKey = r2[ivSize : ivSize+cipherKeySize] + k2.MacKey = r2[ivSize+cipherKeySize:] + + return k1, k2 +} + +// UnmarshalPublicKey converts a protobuf serialized public key into its +// representative object +func UnmarshalPublicKey(data []byte) (PubKey, error) { + pmes := new(pb.PublicKey) + err := proto.Unmarshal(data, pmes) + if err != nil { + return nil, err + } + + um, ok := PubKeyUnmarshallers[pmes.GetType()] + if !ok { + return nil, ErrBadKeyType + } + + return um(pmes.GetData()) +} + +// MarshalPublicKey converts a public key object into a protobuf serialized +// public key +func MarshalPublicKey(k PubKey) ([]byte, error) { + pbmes := new(pb.PublicKey) + pbmes.Type = k.Type() + data, err := k.Raw() + if err != nil { + return nil, err + } + pbmes.Data = data + + return proto.Marshal(pbmes) +} + +// UnmarshalPrivateKey converts a protobuf serialized private key into its +// representative object +func UnmarshalPrivateKey(data []byte) (PrivKey, error) { + pmes := new(pb.PrivateKey) + err := proto.Unmarshal(data, pmes) + if err != nil { + return nil, err + } + + um, ok := PrivKeyUnmarshallers[pmes.GetType()] + if !ok { + return nil, ErrBadKeyType + } + + return um(pmes.GetData()) +} + +// MarshalPrivateKey converts a key object into its protobuf serialized form. +func MarshalPrivateKey(k PrivKey) ([]byte, error) { + pbmes := new(pb.PrivateKey) + pbmes.Type = k.Type() + data, err := k.Raw() + if err != nil { + return nil, err + } + + pbmes.Data = data + return proto.Marshal(pbmes) +} + +// ConfigDecodeKey decodes from b64 (for config file), and unmarshals. +func ConfigDecodeKey(b string) ([]byte, error) { + return base64.StdEncoding.DecodeString(b) +} + +// ConfigEncodeKey encodes to b64 (for config file), and marshals. +func ConfigEncodeKey(b []byte) string { + return base64.StdEncoding.EncodeToString(b) +} + +// KeyEqual checks whether two Keys are equivalent (have identical byte representations). +func KeyEqual(k1, k2 Key) bool { + if k1 == k2 { + return true + } + + b1, err1 := k1.Bytes() + b2, err2 := k2.Bytes() + return bytes.Equal(b1, b2) && err1 == err2 +} diff --git a/core/crypto/key_test.go b/core/crypto/key_test.go new file mode 100644 index 0000000000..a93fbf0857 --- /dev/null +++ b/core/crypto/key_test.go @@ -0,0 +1,147 @@ +package crypto_test + +import ( + "bytes" + "crypto/rand" + "testing" + + . "github.com/libp2p/go-libp2p-core/crypto" + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + "github.com/libp2p/go-libp2p-testing/crypto" +) + +func TestKeys(t *testing.T) { + for _, typ := range KeyTypes { + testKeyType(typ, t) + } +} + +func testKeyType(typ int, t *testing.T) { + sk, pk, err := tcrypto.RandTestKeyPair(typ, 512) + if err != nil { + t.Fatal(err) + } + + testKeySignature(t, sk) + testKeyEncoding(t, sk) + testKeyEquals(t, sk) + testKeyEquals(t, pk) +} + +func testKeySignature(t *testing.T, sk PrivKey) { + pk := sk.GetPublic() + + text := make([]byte, 16) + if _, err := rand.Read(text); err != nil { + t.Fatal(err) + } + + sig, err := sk.Sign(text) + if err != nil { + t.Fatal(err) + } + + valid, err := pk.Verify(text, sig) + if err != nil { + t.Fatal(err) + } + + if !valid { + t.Fatal("Invalid signature.") + } +} + +func testKeyEncoding(t *testing.T, sk PrivKey) { + skbm, err := MarshalPrivateKey(sk) + if err != nil { + t.Fatal(err) + } + + sk2, err := UnmarshalPrivateKey(skbm) + if err != nil { + t.Fatal(err) + } + + if !sk.Equals(sk2) { + t.Error("Unmarshaled private key didn't match original.\n") + } + + skbm2, err := MarshalPrivateKey(sk2) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(skbm, skbm2) { + t.Error("skb -> marshal -> unmarshal -> skb failed.\n", skbm, "\n", skbm2) + } + + pk := sk.GetPublic() + pkbm, err := MarshalPublicKey(pk) + if err != nil { + t.Fatal(err) + } + + pk2, err := UnmarshalPublicKey(pkbm) + if err != nil { + t.Fatal(err) + } + + if !pk.Equals(pk2) { + t.Error("Unmarshaled public key didn't match original.\n") + } + + pkbm2, err := MarshalPublicKey(pk) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(pkbm, pkbm2) { + t.Error("skb -> marshal -> unmarshal -> skb failed.\n", pkbm, "\n", pkbm2) + } +} + +func testKeyEquals(t *testing.T, k Key) { + kb, err := k.Bytes() + if err != nil { + t.Fatal(err) + } + + if !KeyEqual(k, k) { + t.Fatal("Key not equal to itself.") + } + + if !KeyEqual(k, testkey(kb)) { + t.Fatal("Key not equal to key with same bytes.") + } + + sk, pk, err := tcrypto.RandTestKeyPair(RSA, 512) + if err != nil { + t.Fatal(err) + } + + if KeyEqual(k, sk) { + t.Fatal("Keys should not equal.") + } + + if KeyEqual(k, pk) { + t.Fatal("Keys should not equal.") + } +} + +type testkey []byte + +func (pk testkey) Bytes() ([]byte, error) { + return pk, nil +} + +func (pk testkey) Type() pb.KeyType { + return pb.KeyType_RSA +} + +func (pk testkey) Raw() ([]byte, error) { + return pk, nil +} + +func (pk testkey) Equals(k Key) bool { + return KeyEqual(pk, k) +} diff --git a/core/crypto/openssl_common.go b/core/crypto/openssl_common.go new file mode 100644 index 0000000000..f235e75f33 --- /dev/null +++ b/core/crypto/openssl_common.go @@ -0,0 +1,98 @@ +// +build openssl + +package crypto + +import ( + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + + openssl "github.com/spacemonkeygo/openssl" +) + +// define these as separate types so we can add more key types later and reuse +// code. + +type opensslPublicKey struct { + key openssl.PublicKey +} + +type opensslPrivateKey struct { + key openssl.PrivateKey +} + +func unmarshalOpensslPrivateKey(b []byte) (opensslPrivateKey, error) { + sk, err := openssl.LoadPrivateKeyFromDER(b) + if err != nil { + return opensslPrivateKey{}, err + } + return opensslPrivateKey{sk}, nil +} + +func unmarshalOpensslPublicKey(b []byte) (opensslPublicKey, error) { + sk, err := openssl.LoadPublicKeyFromDER(b) + if err != nil { + return opensslPublicKey{}, err + } + return opensslPublicKey{sk}, nil +} + +// Verify compares a signature against input data +func (pk *opensslPublicKey) Verify(data, sig []byte) (bool, error) { + err := pk.key.VerifyPKCS1v15(openssl.SHA256_Method, data, sig) + return err == nil, err +} + +func (pk *opensslPublicKey) Type() pb.KeyType { + switch pk.key.KeyType() { + case openssl.KeyTypeRSA: + return pb.KeyType_RSA + default: + return -1 + } +} + +// Bytes returns protobuf bytes of a public key +func (pk *opensslPublicKey) Bytes() ([]byte, error) { + return MarshalPublicKey(pk) +} + +func (pk *opensslPublicKey) Raw() ([]byte, error) { + return pk.key.MarshalPKIXPublicKeyDER() +} + +// Equals checks whether this key is equal to another +func (pk *opensslPublicKey) Equals(k Key) bool { + return KeyEqual(pk, k) +} + +// Sign returns a signature of the input data +func (sk *opensslPrivateKey) Sign(message []byte) ([]byte, error) { + return sk.key.SignPKCS1v15(openssl.SHA256_Method, message) +} + +// GetPublic returns a public key +func (sk *opensslPrivateKey) GetPublic() PubKey { + return &opensslPublicKey{sk.key} +} + +func (sk *opensslPrivateKey) Type() pb.KeyType { + switch sk.key.KeyType() { + case openssl.KeyTypeRSA: + return pb.KeyType_RSA + default: + return -1 + } +} + +// Bytes returns protobuf bytes from a private key +func (sk *opensslPrivateKey) Bytes() ([]byte, error) { + return MarshalPrivateKey(sk) +} + +func (sk *opensslPrivateKey) Raw() ([]byte, error) { + return sk.key.MarshalPKCS1PrivateKeyDER() +} + +// Equals checks whether this key is equal to another +func (sk *opensslPrivateKey) Equals(k Key) bool { + return KeyEqual(sk, k) +} diff --git a/core/crypto/pb/Makefile b/core/crypto/pb/Makefile new file mode 100644 index 0000000000..df34e54b01 --- /dev/null +++ b/core/crypto/pb/Makefile @@ -0,0 +1,11 @@ +PB = $(wildcard *.proto) +GO = $(PB:.proto=.pb.go) + +all: $(GO) + +%.pb.go: %.proto + protoc --proto_path=$(GOPATH)/src:. --gogofaster_out=. $< + +clean: + rm -f *.pb.go + rm -f *.go diff --git a/core/crypto/pb/crypto.pb.go b/core/crypto/pb/crypto.pb.go new file mode 100644 index 0000000000..5fa7aec73f --- /dev/null +++ b/core/crypto/pb/crypto.pb.go @@ -0,0 +1,643 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: crypto.proto + +package crypto_pb + +import ( + fmt "fmt" + github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +type KeyType int32 + +const ( + KeyType_RSA KeyType = 0 + KeyType_Ed25519 KeyType = 1 + KeyType_Secp256k1 KeyType = 2 + KeyType_ECDSA KeyType = 3 +) + +var KeyType_name = map[int32]string{ + 0: "RSA", + 1: "Ed25519", + 2: "Secp256k1", + 3: "ECDSA", +} + +var KeyType_value = map[string]int32{ + "RSA": 0, + "Ed25519": 1, + "Secp256k1": 2, + "ECDSA": 3, +} + +func (x KeyType) Enum() *KeyType { + p := new(KeyType) + *p = x + return p +} + +func (x KeyType) String() string { + return proto.EnumName(KeyType_name, int32(x)) +} + +func (x *KeyType) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(KeyType_value, data, "KeyType") + if err != nil { + return err + } + *x = KeyType(value) + return nil +} + +func (KeyType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_527278fb02d03321, []int{0} +} + +type PublicKey struct { + Type KeyType `protobuf:"varint,1,req,name=Type,enum=crypto.pb.KeyType" json:"Type"` + Data []byte `protobuf:"bytes,2,req,name=Data" json:"Data"` +} + +func (m *PublicKey) Reset() { *m = PublicKey{} } +func (m *PublicKey) String() string { return proto.CompactTextString(m) } +func (*PublicKey) ProtoMessage() {} +func (*PublicKey) Descriptor() ([]byte, []int) { + return fileDescriptor_527278fb02d03321, []int{0} +} +func (m *PublicKey) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *PublicKey) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_PublicKey.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalTo(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *PublicKey) XXX_Merge(src proto.Message) { + xxx_messageInfo_PublicKey.Merge(m, src) +} +func (m *PublicKey) XXX_Size() int { + return m.Size() +} +func (m *PublicKey) XXX_DiscardUnknown() { + xxx_messageInfo_PublicKey.DiscardUnknown(m) +} + +var xxx_messageInfo_PublicKey proto.InternalMessageInfo + +func (m *PublicKey) GetType() KeyType { + if m != nil { + return m.Type + } + return KeyType_RSA +} + +func (m *PublicKey) GetData() []byte { + if m != nil { + return m.Data + } + return nil +} + +type PrivateKey struct { + Type KeyType `protobuf:"varint,1,req,name=Type,enum=crypto.pb.KeyType" json:"Type"` + Data []byte `protobuf:"bytes,2,req,name=Data" json:"Data"` +} + +func (m *PrivateKey) Reset() { *m = PrivateKey{} } +func (m *PrivateKey) String() string { return proto.CompactTextString(m) } +func (*PrivateKey) ProtoMessage() {} +func (*PrivateKey) Descriptor() ([]byte, []int) { + return fileDescriptor_527278fb02d03321, []int{1} +} +func (m *PrivateKey) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *PrivateKey) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_PrivateKey.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalTo(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *PrivateKey) XXX_Merge(src proto.Message) { + xxx_messageInfo_PrivateKey.Merge(m, src) +} +func (m *PrivateKey) XXX_Size() int { + return m.Size() +} +func (m *PrivateKey) XXX_DiscardUnknown() { + xxx_messageInfo_PrivateKey.DiscardUnknown(m) +} + +var xxx_messageInfo_PrivateKey proto.InternalMessageInfo + +func (m *PrivateKey) GetType() KeyType { + if m != nil { + return m.Type + } + return KeyType_RSA +} + +func (m *PrivateKey) GetData() []byte { + if m != nil { + return m.Data + } + return nil +} + +func init() { + proto.RegisterEnum("crypto.pb.KeyType", KeyType_name, KeyType_value) + proto.RegisterType((*PublicKey)(nil), "crypto.pb.PublicKey") + proto.RegisterType((*PrivateKey)(nil), "crypto.pb.PrivateKey") +} + +func init() { proto.RegisterFile("crypto.proto", fileDescriptor_527278fb02d03321) } + +var fileDescriptor_527278fb02d03321 = []byte{ + // 203 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0x2e, 0xaa, 0x2c, + 0x28, 0xc9, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0xf1, 0x92, 0x94, 0x82, 0xb9, + 0x38, 0x03, 0x4a, 0x93, 0x72, 0x32, 0x93, 0xbd, 0x53, 0x2b, 0x85, 0x74, 0xb8, 0x58, 0x42, 0x2a, + 0x0b, 0x52, 0x25, 0x18, 0x15, 0x98, 0x34, 0xf8, 0x8c, 0x84, 0xf4, 0xe0, 0xca, 0xf4, 0xbc, 0x53, + 0x2b, 0x41, 0x32, 0x4e, 0x2c, 0x27, 0xee, 0xc9, 0x33, 0x04, 0x81, 0x55, 0x09, 0x49, 0x70, 0xb1, + 0xb8, 0x24, 0x96, 0x24, 0x4a, 0x30, 0x29, 0x30, 0x69, 0xf0, 0xc0, 0x64, 0x40, 0x22, 0x4a, 0x21, + 0x5c, 0x5c, 0x01, 0x45, 0x99, 0x65, 0x89, 0x25, 0xa9, 0x54, 0x34, 0x55, 0xcb, 0x92, 0x8b, 0x1d, + 0xaa, 0x41, 0x88, 0x9d, 0x8b, 0x39, 0x28, 0xd8, 0x51, 0x80, 0x41, 0x88, 0x9b, 0x8b, 0xdd, 0x35, + 0xc5, 0xc8, 0xd4, 0xd4, 0xd0, 0x52, 0x80, 0x51, 0x88, 0x97, 0x8b, 0x33, 0x38, 0x35, 0xb9, 0xc0, + 0xc8, 0xd4, 0x2c, 0xdb, 0x50, 0x80, 0x49, 0x88, 0x93, 0x8b, 0xd5, 0xd5, 0xd9, 0x25, 0xd8, 0x51, + 0x80, 0xd9, 0x49, 0xe2, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, + 0x9c, 0xf0, 0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0x00, 0x01, 0x00, + 0x00, 0xff, 0xff, 0x13, 0xbe, 0xd4, 0xff, 0x19, 0x01, 0x00, 0x00, +} + +func (m *PublicKey) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PublicKey) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + dAtA[i] = 0x8 + i++ + i = encodeVarintCrypto(dAtA, i, uint64(m.Type)) + if m.Data != nil { + dAtA[i] = 0x12 + i++ + i = encodeVarintCrypto(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + return i, nil +} + +func (m *PrivateKey) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PrivateKey) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + dAtA[i] = 0x8 + i++ + i = encodeVarintCrypto(dAtA, i, uint64(m.Type)) + if m.Data != nil { + dAtA[i] = 0x12 + i++ + i = encodeVarintCrypto(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + return i, nil +} + +func encodeVarintCrypto(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *PublicKey) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + n += 1 + sovCrypto(uint64(m.Type)) + if m.Data != nil { + l = len(m.Data) + n += 1 + l + sovCrypto(uint64(l)) + } + return n +} + +func (m *PrivateKey) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + n += 1 + sovCrypto(uint64(m.Type)) + if m.Data != nil { + l = len(m.Data) + n += 1 + l + sovCrypto(uint64(l)) + } + return n +} + +func sovCrypto(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozCrypto(x uint64) (n int) { + return sovCrypto(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *PublicKey) Unmarshal(dAtA []byte) error { + var hasFields [1]uint64 + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCrypto + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: PublicKey: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PublicKey: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + m.Type = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCrypto + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Type |= KeyType(b&0x7F) << shift + if b < 0x80 { + break + } + } + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCrypto + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthCrypto + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthCrypto + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) + if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skipCrypto(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthCrypto + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthCrypto + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Type") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Data") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *PrivateKey) Unmarshal(dAtA []byte) error { + var hasFields [1]uint64 + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCrypto + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: PrivateKey: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PrivateKey: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + m.Type = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCrypto + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Type |= KeyType(b&0x7F) << shift + if b < 0x80 { + break + } + } + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowCrypto + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthCrypto + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthCrypto + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) + if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skipCrypto(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthCrypto + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthCrypto + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Type") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Data") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipCrypto(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowCrypto + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowCrypto + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowCrypto + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthCrypto + } + iNdEx += length + if iNdEx < 0 { + return 0, ErrInvalidLengthCrypto + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowCrypto + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipCrypto(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + if iNdEx < 0 { + return 0, ErrInvalidLengthCrypto + } + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthCrypto = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowCrypto = fmt.Errorf("proto: integer overflow") +) diff --git a/core/crypto/pb/crypto.proto b/core/crypto/pb/crypto.proto new file mode 100644 index 0000000000..cb5cee8a22 --- /dev/null +++ b/core/crypto/pb/crypto.proto @@ -0,0 +1,20 @@ +syntax = "proto2"; + +package crypto.pb; + +enum KeyType { + RSA = 0; + Ed25519 = 1; + Secp256k1 = 2; + ECDSA = 3; +} + +message PublicKey { + required KeyType Type = 1; + required bytes Data = 2; +} + +message PrivateKey { + required KeyType Type = 1; + required bytes Data = 2; +} diff --git a/core/crypto/rsa_common.go b/core/crypto/rsa_common.go new file mode 100644 index 0000000000..d50651f218 --- /dev/null +++ b/core/crypto/rsa_common.go @@ -0,0 +1,10 @@ +package crypto + +import ( + "errors" +) + +// ErrRsaKeyTooSmall is returned when trying to generate or parse an RSA key +// that's smaller than 512 bits. Keys need to be larger enough to sign a 256bit +// hash so this is a reasonable absolute minimum. +var ErrRsaKeyTooSmall = errors.New("rsa keys must be >= 512 bits to be useful") diff --git a/core/crypto/rsa_go.go b/core/crypto/rsa_go.go new file mode 100644 index 0000000000..e9813779dc --- /dev/null +++ b/core/crypto/rsa_go.go @@ -0,0 +1,125 @@ +// +build !openssl + +package crypto + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "errors" + "io" + + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + + "github.com/minio/sha256-simd" +) + +// RsaPrivateKey is an rsa private key +type RsaPrivateKey struct { + sk rsa.PrivateKey +} + +// RsaPublicKey is an rsa public key +type RsaPublicKey struct { + k rsa.PublicKey +} + +// GenerateRSAKeyPair generates a new rsa private and public key +func GenerateRSAKeyPair(bits int, src io.Reader) (PrivKey, PubKey, error) { + if bits < 512 { + return nil, nil, ErrRsaKeyTooSmall + } + priv, err := rsa.GenerateKey(src, bits) + if err != nil { + return nil, nil, err + } + pk := priv.PublicKey + return &RsaPrivateKey{sk: *priv}, &RsaPublicKey{pk}, nil +} + +// Verify compares a signature against input data +func (pk *RsaPublicKey) Verify(data, sig []byte) (bool, error) { + hashed := sha256.Sum256(data) + err := rsa.VerifyPKCS1v15(&pk.k, crypto.SHA256, hashed[:], sig) + if err != nil { + return false, err + } + return true, nil +} + +func (pk *RsaPublicKey) Type() pb.KeyType { + return pb.KeyType_RSA +} + +// Bytes returns protobuf bytes of a public key +func (pk *RsaPublicKey) Bytes() ([]byte, error) { + return MarshalPublicKey(pk) +} + +func (pk *RsaPublicKey) Raw() ([]byte, error) { + return x509.MarshalPKIXPublicKey(&pk.k) +} + +// Equals checks whether this key is equal to another +func (pk *RsaPublicKey) Equals(k Key) bool { + return KeyEqual(pk, k) +} + +// Sign returns a signature of the input data +func (sk *RsaPrivateKey) Sign(message []byte) ([]byte, error) { + hashed := sha256.Sum256(message) + return rsa.SignPKCS1v15(rand.Reader, &sk.sk, crypto.SHA256, hashed[:]) +} + +// GetPublic returns a public key +func (sk *RsaPrivateKey) GetPublic() PubKey { + return &RsaPublicKey{sk.sk.PublicKey} +} + +func (sk *RsaPrivateKey) Type() pb.KeyType { + return pb.KeyType_RSA +} + +// Bytes returns protobuf bytes from a private key +func (sk *RsaPrivateKey) Bytes() ([]byte, error) { + return MarshalPrivateKey(sk) +} + +func (sk *RsaPrivateKey) Raw() ([]byte, error) { + b := x509.MarshalPKCS1PrivateKey(&sk.sk) + return b, nil +} + +// Equals checks whether this key is equal to another +func (sk *RsaPrivateKey) Equals(k Key) bool { + return KeyEqual(sk, k) +} + +// UnmarshalRsaPrivateKey returns a private key from the input x509 bytes +func UnmarshalRsaPrivateKey(b []byte) (PrivKey, error) { + sk, err := x509.ParsePKCS1PrivateKey(b) + if err != nil { + return nil, err + } + if sk.N.BitLen() < 512 { + return nil, ErrRsaKeyTooSmall + } + return &RsaPrivateKey{sk: *sk}, nil +} + +// UnmarshalRsaPublicKey returns a public key from the input x509 bytes +func UnmarshalRsaPublicKey(b []byte) (PubKey, error) { + pub, err := x509.ParsePKIXPublicKey(b) + if err != nil { + return nil, err + } + pk, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, errors.New("not actually an rsa public key") + } + if pk.N.BitLen() < 512 { + return nil, ErrRsaKeyTooSmall + } + return &RsaPublicKey{*pk}, nil +} diff --git a/core/crypto/rsa_openssl.go b/core/crypto/rsa_openssl.go new file mode 100644 index 0000000000..96c5588615 --- /dev/null +++ b/core/crypto/rsa_openssl.go @@ -0,0 +1,62 @@ +// +build openssl + +package crypto + +import ( + "errors" + "io" + + openssl "github.com/spacemonkeygo/openssl" +) + +// RsaPrivateKey is an rsa private key +type RsaPrivateKey struct { + opensslPrivateKey +} + +// RsaPublicKey is an rsa public key +type RsaPublicKey struct { + opensslPublicKey +} + +// GenerateRSAKeyPair generates a new rsa private and public key +func GenerateRSAKeyPair(bits int, _ io.Reader) (PrivKey, PubKey, error) { + if bits < 512 { + return nil, nil, ErrRsaKeyTooSmall + } + + key, err := openssl.GenerateRSAKey(bits) + if err != nil { + return nil, nil, err + } + return &RsaPrivateKey{opensslPrivateKey{key}}, &RsaPublicKey{opensslPublicKey{key}}, nil +} + +// GetPublic returns a public key +func (sk *RsaPrivateKey) GetPublic() PubKey { + return &RsaPublicKey{opensslPublicKey{sk.opensslPrivateKey.key}} +} + +// UnmarshalRsaPrivateKey returns a private key from the input x509 bytes +func UnmarshalRsaPrivateKey(b []byte) (PrivKey, error) { + key, err := unmarshalOpensslPrivateKey(b) + if err != nil { + return nil, err + } + if key.Type() != RSA { + return nil, errors.New("not actually an rsa public key") + } + return &RsaPrivateKey{key}, nil +} + +// UnmarshalRsaPublicKey returns a public key from the input x509 bytes +func UnmarshalRsaPublicKey(b []byte) (PubKey, error) { + key, err := unmarshalOpensslPublicKey(b) + if err != nil { + return nil, err + } + if key.Type() != RSA { + return nil, errors.New("not actually an rsa public key") + } + return &RsaPublicKey{key}, nil +} diff --git a/core/crypto/rsa_test.go b/core/crypto/rsa_test.go new file mode 100644 index 0000000000..7ee520acb7 --- /dev/null +++ b/core/crypto/rsa_test.go @@ -0,0 +1,102 @@ +package crypto + +import ( + "crypto/rand" + "testing" +) + +func TestRSABasicSignAndVerify(t *testing.T) { + priv, pub, err := GenerateRSAKeyPair(512, rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := []byte("hello! and welcome to some awesome crypto primitives") + + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + + if !ok { + t.Fatal("signature didnt match") + } + + // change data + data[0] = ^data[0] + ok, err = pub.Verify(data, sig) + if err == nil { + t.Fatal("should have produced a verification error") + } + + if ok { + t.Fatal("signature matched and shouldn't") + } +} + +func TestRSASmallKey(t *testing.T) { + _, _, err := GenerateRSAKeyPair(384, rand.Reader) + if err != ErrRsaKeyTooSmall { + t.Fatal("should have refused to create small RSA key") + } +} + +func TestRSASignZero(t *testing.T) { + priv, pub, err := GenerateRSAKeyPair(512, rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := make([]byte, 0) + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("signature didn't match") + } +} + +func TestRSAMarshalLoop(t *testing.T) { + priv, pub, err := GenerateRSAKeyPair(512, rand.Reader) + if err != nil { + t.Fatal(err) + } + + privB, err := priv.Bytes() + if err != nil { + t.Fatal(err) + } + + privNew, err := UnmarshalPrivateKey(privB) + if err != nil { + t.Fatal(err) + } + + if !priv.Equals(privNew) || !privNew.Equals(priv) { + t.Fatal("keys are not equal") + } + + pubB, err := pub.Bytes() + if err != nil { + t.Fatal(err) + } + pubNew, err := UnmarshalPublicKey(pubB) + if err != nil { + t.Fatal(err) + } + + if !pub.Equals(pubNew) || !pubNew.Equals(pub) { + t.Fatal("keys are not equal") + } +} diff --git a/core/crypto/secp256k1.go b/core/crypto/secp256k1.go new file mode 100644 index 0000000000..d2ac74b996 --- /dev/null +++ b/core/crypto/secp256k1.go @@ -0,0 +1,125 @@ +package crypto + +import ( + "fmt" + "io" + + pb "github.com/libp2p/go-libp2p-core/crypto/pb" + + btcec "github.com/btcsuite/btcd/btcec" + sha256 "github.com/minio/sha256-simd" +) + +// Secp256k1PrivateKey is an Secp256k1 private key +type Secp256k1PrivateKey btcec.PrivateKey + +// Secp256k1PublicKey is an Secp256k1 public key +type Secp256k1PublicKey btcec.PublicKey + +// GenerateSecp256k1Key generates a new Secp256k1 private and public key pair +func GenerateSecp256k1Key(src io.Reader) (PrivKey, PubKey, error) { + privk, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, nil, err + } + + k := (*Secp256k1PrivateKey)(privk) + return k, k.GetPublic(), nil +} + +// UnmarshalSecp256k1PrivateKey returns a private key from bytes +func UnmarshalSecp256k1PrivateKey(data []byte) (PrivKey, error) { + if len(data) != btcec.PrivKeyBytesLen { + return nil, fmt.Errorf("expected secp256k1 data size to be %d", btcec.PrivKeyBytesLen) + } + + privk, _ := btcec.PrivKeyFromBytes(btcec.S256(), data) + return (*Secp256k1PrivateKey)(privk), nil +} + +// UnmarshalSecp256k1PublicKey returns a public key from bytes +func UnmarshalSecp256k1PublicKey(data []byte) (PubKey, error) { + k, err := btcec.ParsePubKey(data, btcec.S256()) + if err != nil { + return nil, err + } + + return (*Secp256k1PublicKey)(k), nil +} + +// Bytes returns protobuf bytes from a private key +func (k *Secp256k1PrivateKey) Bytes() ([]byte, error) { + return MarshalPrivateKey(k) +} + +// Type returns the private key type +func (k *Secp256k1PrivateKey) Type() pb.KeyType { + return pb.KeyType_Secp256k1 +} + +// Raw returns the bytes of the key +func (k *Secp256k1PrivateKey) Raw() ([]byte, error) { + return (*btcec.PrivateKey)(k).Serialize(), nil +} + +// Equals compares two private keys +func (k *Secp256k1PrivateKey) Equals(o Key) bool { + sk, ok := o.(*Secp256k1PrivateKey) + if !ok { + return false + } + + return k.D.Cmp(sk.D) == 0 +} + +// Sign returns a signature from input data +func (k *Secp256k1PrivateKey) Sign(data []byte) ([]byte, error) { + hash := sha256.Sum256(data) + sig, err := (*btcec.PrivateKey)(k).Sign(hash[:]) + if err != nil { + return nil, err + } + + return sig.Serialize(), nil +} + +// GetPublic returns a public key +func (k *Secp256k1PrivateKey) GetPublic() PubKey { + return (*Secp256k1PublicKey)((*btcec.PrivateKey)(k).PubKey()) +} + +// Bytes returns protobuf bytes from a public key +func (k *Secp256k1PublicKey) Bytes() ([]byte, error) { + return MarshalPublicKey(k) +} + +// Type returns the public key type +func (k *Secp256k1PublicKey) Type() pb.KeyType { + return pb.KeyType_Secp256k1 +} + +// Raw returns the bytes of the key +func (k *Secp256k1PublicKey) Raw() ([]byte, error) { + return (*btcec.PublicKey)(k).SerializeCompressed(), nil +} + +// Equals compares two public keys +func (k *Secp256k1PublicKey) Equals(o Key) bool { + sk, ok := o.(*Secp256k1PublicKey) + if !ok { + return false + } + + return (*btcec.PublicKey)(k).IsEqual((*btcec.PublicKey)(sk)) +} + +// Verify compares a signature against the input data +func (k *Secp256k1PublicKey) Verify(data []byte, sigStr []byte) (bool, error) { + sig, err := btcec.ParseDERSignature(sigStr, btcec.S256()) + if err != nil { + return false, err + } + + hash := sha256.Sum256(data) + return sig.Verify(hash[:], (*btcec.PublicKey)(k)), nil +} diff --git a/core/crypto/secp256k1_test.go b/core/crypto/secp256k1_test.go new file mode 100644 index 0000000000..aff2bd9773 --- /dev/null +++ b/core/crypto/secp256k1_test.go @@ -0,0 +1,96 @@ +package crypto + +import ( + "crypto/rand" + "testing" +) + +func TestSecp256k1BasicSignAndVerify(t *testing.T) { + priv, pub, err := GenerateSecp256k1Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := []byte("hello! and welcome to some awesome crypto primitives") + + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + + if !ok { + t.Fatal("signature didnt match") + } + + // change data + data[0] = ^data[0] + ok, err = pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + + if ok { + t.Fatal("signature matched and shouldn't") + } +} + +func TestSecp256k1SignZero(t *testing.T) { + priv, pub, err := GenerateSecp256k1Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + data := make([]byte, 0) + sig, err := priv.Sign(data) + if err != nil { + t.Fatal(err) + } + + ok, err := pub.Verify(data, sig) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("signature didn't match") + } +} + +func TestSecp256k1MarshalLoop(t *testing.T) { + priv, pub, err := GenerateSecp256k1Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + privB, err := priv.Bytes() + if err != nil { + t.Fatal(err) + } + + privNew, err := UnmarshalPrivateKey(privB) + if err != nil { + t.Fatal(err) + } + + if !priv.Equals(privNew) || !privNew.Equals(priv) { + t.Fatal("keys are not equal") + } + + pubB, err := pub.Bytes() + if err != nil { + t.Fatal(err) + } + pubNew, err := UnmarshalPublicKey(pubB) + if err != nil { + t.Fatal(err) + } + + if !pub.Equals(pubNew) || !pubNew.Equals(pub) { + t.Fatal("keys are not equal") + } + +} diff --git a/core/crypto/test_data/0.priv b/core/crypto/test_data/0.priv new file mode 100644 index 0000000000..9047d5d95d Binary files /dev/null and b/core/crypto/test_data/0.priv differ diff --git a/core/crypto/test_data/0.pub b/core/crypto/test_data/0.pub new file mode 100644 index 0000000000..d4295e8893 Binary files /dev/null and b/core/crypto/test_data/0.pub differ diff --git a/core/crypto/test_data/0.sig b/core/crypto/test_data/0.sig new file mode 100644 index 0000000000..2f16825274 Binary files /dev/null and b/core/crypto/test_data/0.sig differ diff --git a/core/crypto/test_data/2.priv b/core/crypto/test_data/2.priv new file mode 100644 index 0000000000..1004ab9153 --- /dev/null +++ b/core/crypto/test_data/2.priv @@ -0,0 +1 @@ + 1A�`jPLD�4���N�[��-��X����F�X \ No newline at end of file diff --git a/core/crypto/test_data/2.pub b/core/crypto/test_data/2.pub new file mode 100644 index 0000000000..dd984bdcc8 --- /dev/null +++ b/core/crypto/test_data/2.pub @@ -0,0 +1 @@ +!5@��*��5Q�M���U��&Pk��S�����֢ \ No newline at end of file diff --git a/core/crypto/test_data/2.sig b/core/crypto/test_data/2.sig new file mode 100644 index 0000000000..b96001bc8b --- /dev/null +++ b/core/crypto/test_data/2.sig @@ -0,0 +1 @@ +0D 1��3���Z�Cu��ܛ�@��������L�� I�!�E���Gu�Cꏲ�pCG�5I<@;��Y��� \ No newline at end of file diff --git a/core/crypto/test_data/3.priv b/core/crypto/test_data/3.priv new file mode 100644 index 0000000000..7a05f359f8 Binary files /dev/null and b/core/crypto/test_data/3.priv differ diff --git a/core/crypto/test_data/3.pub b/core/crypto/test_data/3.pub new file mode 100644 index 0000000000..f4551f8811 Binary files /dev/null and b/core/crypto/test_data/3.pub differ diff --git a/core/crypto/test_data/3.sig b/core/crypto/test_data/3.sig new file mode 100644 index 0000000000..253c09f6d7 Binary files /dev/null and b/core/crypto/test_data/3.sig differ diff --git a/core/discovery/discovery.go b/core/discovery/discovery.go new file mode 100644 index 0000000000..f463e0e852 --- /dev/null +++ b/core/discovery/discovery.go @@ -0,0 +1,27 @@ +// Package discovery provides service advertisement and peer discovery interfaces for libp2p. +package discovery + +import ( + "context" + "time" + + "github.com/libp2p/go-libp2p-core/peer" +) + +// Advertiser is an interface for advertising services +type Advertiser interface { + // Advertise advertises a service + Advertise(ctx context.Context, ns string, opts ...Option) (time.Duration, error) +} + +// Discoverer is an interface for peer discovery +type Discoverer interface { + // FindPeers discovers peers providing a service + FindPeers(ctx context.Context, ns string, opts ...Option) (<-chan peer.AddrInfo, error) +} + +// Discovery is an interface that combines service advertisement and peer discovery +type Discovery interface { + Advertiser + Discoverer +} diff --git a/core/discovery/options.go b/core/discovery/options.go new file mode 100644 index 0000000000..7b28305268 --- /dev/null +++ b/core/discovery/options.go @@ -0,0 +1,41 @@ +package discovery + +import "time" + +// DiscoveryOpt is a single discovery option. +type Option func(opts *Options) error + +// DiscoveryOpts is a set of discovery options. +type Options struct { + Ttl time.Duration + Limit int + + // Other (implementation-specific) options + Other map[interface{}]interface{} +} + +// Apply applies the given options to this DiscoveryOpts +func (opts *Options) Apply(options ...Option) error { + for _, o := range options { + if err := o(opts); err != nil { + return err + } + } + return nil +} + +// TTL is an option that provides a hint for the duration of an advertisement +func TTL(ttl time.Duration) Option { + return func(opts *Options) error { + opts.Ttl = ttl + return nil + } +} + +// Limit is an option that provides an upper bound on the peer count for discovery +func Limit(limit int) Option { + return func(opts *Options) error { + opts.Limit = limit + return nil + } +} diff --git a/core/helpers/match.go b/core/helpers/match.go new file mode 100644 index 0000000000..2e24667897 --- /dev/null +++ b/core/helpers/match.go @@ -0,0 +1,42 @@ +package helpers + +import ( + "strings" + + "github.com/coreos/go-semver/semver" + "github.com/libp2p/go-libp2p-core/protocol" +) + +// MultistreamSemverMatcher returns a matcher function for a given base protocol. +// The matcher function will return a boolean indicating whether a protocol ID +// matches the base protocol. A given protocol ID matches the base protocol if +// the IDs are the same and if the semantic version of the base protocol is the +// same or higher than that of the protocol ID provided. +// TODO +func MultistreamSemverMatcher(base protocol.ID) (func(string) bool, error) { + parts := strings.Split(string(base), "/") + vers, err := semver.NewVersion(parts[len(parts)-1]) + if err != nil { + return nil, err + } + + return func(check string) bool { + chparts := strings.Split(check, "/") + if len(chparts) != len(parts) { + return false + } + + for i, v := range chparts[:len(chparts)-1] { + if parts[i] != v { + return false + } + } + + chvers, err := semver.NewVersion(chparts[len(chparts)-1]) + if err != nil { + return false + } + + return vers.Major == chvers.Major && vers.Minor >= chvers.Minor + }, nil +} diff --git a/core/helpers/match_test.go b/core/helpers/match_test.go new file mode 100644 index 0000000000..8a6e4b678d --- /dev/null +++ b/core/helpers/match_test.go @@ -0,0 +1,33 @@ +package helpers + +import ( + "testing" +) + +func TestSemverMatching(t *testing.T) { + m, err := MultistreamSemverMatcher("/testing/4.3.5") + if err != nil { + t.Fatal(err) + } + + cases := map[string]bool{ + "/testing/4.3.0": true, + "/testing/4.3.7": true, + "/testing/4.3.5": true, + "/testing/4.2.7": true, + "/testing/4.0.0": true, + "/testing/5.0.0": false, + "/cars/dogs/4.3.5": false, + "/foo/1.0.0": false, + "": false, + "dogs": false, + "/foo": false, + "/foo/1.1.1.1": false, + } + + for p, ok := range cases { + if m(p) != ok { + t.Fatalf("expected %s to be %t", p, ok) + } + } +} diff --git a/core/helpers/stream.go b/core/helpers/stream.go new file mode 100644 index 0000000000..0b4c1f3e00 --- /dev/null +++ b/core/helpers/stream.go @@ -0,0 +1,56 @@ +package helpers + +import ( + "errors" + "io" + "time" + + "github.com/libp2p/go-libp2p-core/network" +) + +// EOFTimeout is the maximum amount of time to wait to successfully observe an +// EOF on the stream. Defaults to 60 seconds. +var EOFTimeout = time.Second * 60 + +// ErrExpectedEOF is returned when we read data while expecting an EOF. +var ErrExpectedEOF = errors.New("read data when expecting EOF") + +// FullClose closes the stream and waits to read an EOF from the other side. +// +// * If it reads any data *before* the EOF, it resets the stream. +// * If it doesn't read an EOF within EOFTimeout, it resets the stream. +// +// You'll likely want to invoke this as `go FullClose(stream)` to close the +// stream in the background. +func FullClose(s network.Stream) error { + if err := s.Close(); err != nil { + s.Reset() + return err + } + return AwaitEOF(s) +} + +// AwaitEOF waits for an EOF on the given stream, returning an error if that +// fails. It waits at most EOFTimeout (defaults to 1 minute) after which it +// resets the stream. +func AwaitEOF(s network.Stream) error { + // So we don't wait forever + s.SetDeadline(time.Now().Add(EOFTimeout)) + + // We *have* to observe the EOF. Otherwise, we leak the stream. + // Now, technically, we should do this *before* + // returning from SendMessage as the message + // hasn't really been sent yet until we see the + // EOF but we don't actually *know* what + // protocol the other side is speaking. + n, err := s.Read([]byte{0}) + if n > 0 || err == nil { + s.Reset() + return ErrExpectedEOF + } + if err != io.EOF { + s.Reset() + return err + } + return nil +} diff --git a/core/helpers/stream_test.go b/core/helpers/stream_test.go new file mode 100644 index 0000000000..d267f1ba3d --- /dev/null +++ b/core/helpers/stream_test.go @@ -0,0 +1,151 @@ +package helpers_test + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/helpers" + network "github.com/libp2p/go-libp2p-core/network" +) + +var errCloseFailed = errors.New("close failed") +var errWriteFailed = errors.New("write failed") +var errReadFailed = errors.New("read failed") + +type stream struct { + network.Stream + + data []byte + + failRead, failWrite, failClose bool + + reset bool +} + +func (s *stream) Reset() error { + s.reset = true + return nil +} + +func (s *stream) Close() error { + if s.failClose { + return errCloseFailed + } + return nil +} + +func (s *stream) SetDeadline(t time.Time) error { + s.SetReadDeadline(t) + s.SetWriteDeadline(t) + return nil +} + +func (s *stream) SetReadDeadline(t time.Time) error { + return nil +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + return nil +} + +func (s *stream) Write(b []byte) (int, error) { + if s.failWrite { + return 0, errWriteFailed + } + return len(b), nil +} + +func (s *stream) Read(b []byte) (int, error) { + var err error + if s.failRead { + err = errReadFailed + } + if len(s.data) == 0 { + if err == nil { + err = io.EOF + } + return 0, err + } + n := copy(b, s.data) + s.data = s.data[n:] + return n, err +} + +func TestNormal(t *testing.T) { + var s stream + if err := helpers.FullClose(&s); err != nil { + t.Fatal(err) + } + if s.reset { + t.Fatal("stream should not have been reset") + } +} + +func TestFailRead(t *testing.T) { + var s stream + s.failRead = true + if helpers.FullClose(&s) != errReadFailed { + t.Fatal("expected read to fail with:", errReadFailed) + } + if !s.reset { + t.Fatal("expected stream to be reset") + } +} + +func TestFailClose(t *testing.T) { + var s stream + s.failClose = true + if helpers.FullClose(&s) != errCloseFailed { + t.Fatal("expected close to fail with:", errCloseFailed) + } + if !s.reset { + t.Fatal("expected stream to be reset") + } +} + +func TestFailWrite(t *testing.T) { + var s stream + s.failWrite = true + if err := helpers.FullClose(&s); err != nil { + t.Fatal(err) + } + if s.reset { + t.Fatal("stream should not have been reset") + } +} + +func TestReadDataOne(t *testing.T) { + var s stream + s.data = []byte{0} + if err := helpers.FullClose(&s); err != helpers.ErrExpectedEOF { + t.Fatal("expected:", helpers.ErrExpectedEOF) + } + if !s.reset { + t.Fatal("stream have been reset") + } +} + +func TestReadDataMany(t *testing.T) { + var s stream + s.data = []byte{0, 1, 2, 3} + if err := helpers.FullClose(&s); err != helpers.ErrExpectedEOF { + t.Fatal("expected:", helpers.ErrExpectedEOF) + } + if !s.reset { + t.Fatal("stream have been reset") + } +} + +func TestReadDataError(t *testing.T) { + var s stream + s.data = []byte{0, 1, 2, 3} + s.failRead = true + if err := helpers.FullClose(&s); err != helpers.ErrExpectedEOF { + t.Fatal("expected:", helpers.ErrExpectedEOF) + } + if !s.reset { + t.Fatal("stream have been reset") + } +} diff --git a/core/host/helpers.go b/core/host/helpers.go new file mode 100644 index 0000000000..a24beb1bbe --- /dev/null +++ b/core/host/helpers.go @@ -0,0 +1,11 @@ +package host + +import "github.com/libp2p/go-libp2p-core/peer" + +// InfoFromHost returns a peer.AddrInfo struct with the Host's ID and all of its Addrs. +func InfoFromHost(h Host) *peer.AddrInfo { + return &peer.AddrInfo{ + ID: h.ID(), + Addrs: h.Addrs(), + } +} diff --git a/core/host/host.go b/core/host/host.go new file mode 100644 index 0000000000..d4b3c03fdc --- /dev/null +++ b/core/host/host.go @@ -0,0 +1,71 @@ +// Package host provides the core Host interface for libp2p. +// +// Host represents a single libp2p node in a peer-to-peer network. +package host + +import ( + "context" + + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + "github.com/libp2p/go-libp2p-core/protocol" + + ma "github.com/multiformats/go-multiaddr" +) + +// Host is an object participating in a p2p network, which +// implements protocols or provides services. It handles +// requests like a Server, and issues requests like a Client. +// It is called Host because it is both Server and Client (and Peer +// may be confusing). +type Host interface { + // ID returns the (local) peer.ID associated with this Host + ID() peer.ID + + // Peerstore returns the Host's repository of Peer Addresses and Keys. + Peerstore() peerstore.Peerstore + + // Returns the listen addresses of the Host + Addrs() []ma.Multiaddr + + // Networks returns the Network interface of the Host + Network() network.Network + + // Mux returns the Mux multiplexing incoming streams to protocol handlers + Mux() protocol.Switch + + // Connect ensures there is a connection between this host and the peer with + // given peer.ID. Connect will absorb the addresses in pi into its internal + // peerstore. If there is not an active connection, Connect will issue a + // h.Network.Dial, and block until a connection is open, or an error is + // returned. // TODO: Relay + NAT. + Connect(ctx context.Context, pi peer.AddrInfo) error + + // SetStreamHandler sets the protocol handler on the Host's Mux. + // This is equivalent to: + // host.Mux().SetHandler(proto, handler) + // (Threadsafe) + SetStreamHandler(pid protocol.ID, handler network.StreamHandler) + + // SetStreamHandlerMatch sets the protocol handler on the Host's Mux + // using a matching function for protocol selection. + SetStreamHandlerMatch(protocol.ID, func(string) bool, network.StreamHandler) + + // RemoveStreamHandler removes a handler on the mux that was set by + // SetStreamHandler + RemoveStreamHandler(pid protocol.ID) + + // NewStream opens a new stream to given peer p, and writes a p2p/protocol + // header with given ProtocolID. If there is no connection to p, attempts + // to create one. If ProtocolID is "", writes no header. + // (Threadsafe) + NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) + + // Close shuts down the host, its Network, and services. + Close() error + + // ConnManager returns this hosts connection manager + ConnManager() connmgr.ConnManager +} diff --git a/core/metrics/bandwidth.go b/core/metrics/bandwidth.go new file mode 100644 index 0000000000..e2c8acf30d --- /dev/null +++ b/core/metrics/bandwidth.go @@ -0,0 +1,153 @@ +// Package metrics provides metrics collection and reporting interfaces for libp2p. +package metrics + +import ( + "github.com/libp2p/go-flow-metrics" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" +) + +// BandwidthCounter tracks incoming and outgoing data transferred by the local peer. +// Metrics are available for total bandwidth across all peers / protocols, as well +// as segmented by remote peer ID and protocol ID. +type BandwidthCounter struct { + totalIn flow.Meter + totalOut flow.Meter + + protocolIn flow.MeterRegistry + protocolOut flow.MeterRegistry + + peerIn flow.MeterRegistry + peerOut flow.MeterRegistry +} + +// NewBandwidthCounter creates a new BandwidthCounter. +func NewBandwidthCounter() *BandwidthCounter { + return new(BandwidthCounter) +} + +// LogSentMessage records the size of an outgoing message +// without associating the bandwidth to a specific peer or protocol. +func (bwc *BandwidthCounter) LogSentMessage(size int64) { + bwc.totalOut.Mark(uint64(size)) +} + +// LogRecvMessage records the size of an incoming message +// without associating the bandwith to a specific peer or protocol. +func (bwc *BandwidthCounter) LogRecvMessage(size int64) { + bwc.totalIn.Mark(uint64(size)) +} + +// LogSentMessageStream records the size of an outgoing message over a single logical stream. +// Bandwidth is associated with the given protocol.ID and peer.ID. +func (bwc *BandwidthCounter) LogSentMessageStream(size int64, proto protocol.ID, p peer.ID) { + bwc.protocolOut.Get(string(proto)).Mark(uint64(size)) + bwc.peerOut.Get(string(p)).Mark(uint64(size)) +} + +// LogRecvMessageStream records the size of an incoming message over a single logical stream. +// Bandwidth is associated with the given protocol.ID and peer.ID. +func (bwc *BandwidthCounter) LogRecvMessageStream(size int64, proto protocol.ID, p peer.ID) { + bwc.protocolIn.Get(string(proto)).Mark(uint64(size)) + bwc.peerIn.Get(string(p)).Mark(uint64(size)) +} + +// GetBandwidthForPeer returns a Stats struct with bandwidth metrics associated with the given peer.ID. +// The metrics returned include all traffic sent / received for the peer, regardless of protocol. +func (bwc *BandwidthCounter) GetBandwidthForPeer(p peer.ID) (out Stats) { + inSnap := bwc.peerIn.Get(string(p)).Snapshot() + outSnap := bwc.peerOut.Get(string(p)).Snapshot() + + return Stats{ + TotalIn: int64(inSnap.Total), + TotalOut: int64(outSnap.Total), + RateIn: inSnap.Rate, + RateOut: outSnap.Rate, + } +} + +// GetBandwidthForProtocol returns a Stats struct with bandwidth metrics associated with the given protocol.ID. +// The metrics returned include all traffic sent / recieved for the protocol, regardless of which peers were +// involved. +func (bwc *BandwidthCounter) GetBandwidthForProtocol(proto protocol.ID) (out Stats) { + inSnap := bwc.protocolIn.Get(string(proto)).Snapshot() + outSnap := bwc.protocolOut.Get(string(proto)).Snapshot() + + return Stats{ + TotalIn: int64(inSnap.Total), + TotalOut: int64(outSnap.Total), + RateIn: inSnap.Rate, + RateOut: outSnap.Rate, + } +} + +// GetBandwidthTotals returns a Stats struct with bandwidth metrics for all data sent / recieved by the +// local peer, regardless of protocol or remote peer IDs. +func (bwc *BandwidthCounter) GetBandwidthTotals() (out Stats) { + inSnap := bwc.totalIn.Snapshot() + outSnap := bwc.totalOut.Snapshot() + + return Stats{ + TotalIn: int64(inSnap.Total), + TotalOut: int64(outSnap.Total), + RateIn: inSnap.Rate, + RateOut: outSnap.Rate, + } +} + +// GetBandwidthByPeer returns a map of all remembered peers and the bandwidth +// metrics with respect to each. This method may be very expensive. +func (bwc *BandwidthCounter) GetBandwidthByPeer() map[peer.ID]Stats { + peers := make(map[peer.ID]Stats) + + bwc.peerIn.ForEach(func(p string, meter *flow.Meter) { + id := peer.ID(p) + snap := meter.Snapshot() + + stat := peers[id] + stat.TotalIn = int64(snap.Total) + stat.RateIn = snap.Rate + peers[id] = stat + }) + + bwc.peerOut.ForEach(func(p string, meter *flow.Meter) { + id := peer.ID(p) + snap := meter.Snapshot() + + stat := peers[id] + stat.TotalOut = int64(snap.Total) + stat.RateOut = snap.Rate + peers[id] = stat + }) + + return peers +} + +// GetBandwidthByProtocol returns a map of all remembered protocols and +// the bandwidth metrics with respect to each. This method may be moderately +// expensive. +func (bwc *BandwidthCounter) GetBandwidthByProtocol() map[protocol.ID]Stats { + protocols := make(map[protocol.ID]Stats) + + bwc.protocolIn.ForEach(func(p string, meter *flow.Meter) { + id := protocol.ID(p) + snap := meter.Snapshot() + + stat := protocols[id] + stat.TotalIn = int64(snap.Total) + stat.RateIn = snap.Rate + protocols[id] = stat + }) + + bwc.protocolOut.ForEach(func(p string, meter *flow.Meter) { + id := protocol.ID(p) + snap := meter.Snapshot() + + stat := protocols[id] + stat.TotalOut = int64(snap.Total) + stat.RateOut = snap.Rate + protocols[id] = stat + }) + + return protocols +} diff --git a/core/metrics/bandwidth_test.go b/core/metrics/bandwidth_test.go new file mode 100644 index 0000000000..de7b0fdcfd --- /dev/null +++ b/core/metrics/bandwidth_test.go @@ -0,0 +1,158 @@ +package metrics + +import ( + "fmt" + "math" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" +) + +func BenchmarkBandwidthCounter(b *testing.B) { + b.StopTimer() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + bwc := NewBandwidthCounter() + round(bwc, b) + } +} + +func round(bwc *BandwidthCounter, b *testing.B) { + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(10000) + for i := 0; i < 1000; i++ { + p := peer.ID(fmt.Sprintf("peer-%d", i)) + for j := 0; j < 10; j++ { + proto := protocol.ID(fmt.Sprintf("bitswap-%d", j)) + go func() { + defer wg.Done() + <-start + + for i := 0; i < 1000; i++ { + bwc.LogSentMessage(100) + bwc.LogSentMessageStream(100, proto, p) + time.Sleep(1 * time.Millisecond) + } + }() + } + } + + b.StartTimer() + close(start) + wg.Wait() + b.StopTimer() +} + +// Allow 7% errors for bw calculations. +const acceptableError = 0.07 + +func TestBandwidthCounter(t *testing.T) { + bwc := NewBandwidthCounter() + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(200) + for i := 0; i < 100; i++ { + p := peer.ID(fmt.Sprintf("peer-%d", i)) + for j := 0; j < 2; j++ { + proto := protocol.ID(fmt.Sprintf("proto-%d", j)) + go func() { + defer wg.Done() + <-start + + t := time.NewTicker(100 * time.Millisecond) + defer t.Stop() + + for i := 0; i < 40; i++ { + bwc.LogSentMessage(100) + bwc.LogRecvMessage(50) + bwc.LogSentMessageStream(100, proto, p) + bwc.LogRecvMessageStream(50, proto, p) + <-t.C + } + }() + } + } + + assertProtocols := func(check func(Stats)) { + byProtocol := bwc.GetBandwidthByProtocol() + if len(byProtocol) != 2 { + t.Errorf("expected 2 protocols, got %d", len(byProtocol)) + } + for i := 0; i < 2; i++ { + p := protocol.ID(fmt.Sprintf("proto-%d", i)) + for _, stats := range [...]Stats{bwc.GetBandwidthForProtocol(p), byProtocol[p]} { + check(stats) + } + } + } + + assertPeers := func(check func(Stats)) { + byPeer := bwc.GetBandwidthByPeer() + if len(byPeer) != 100 { + t.Errorf("expected 100 peers, got %d", len(byPeer)) + } + for i := 0; i < 100; i++ { + p := peer.ID(fmt.Sprintf("peer-%d", i)) + for _, stats := range [...]Stats{bwc.GetBandwidthForPeer(p), byPeer[p]} { + check(stats) + } + } + } + + close(start) + time.Sleep(2*time.Second + 100*time.Millisecond) + + assertPeers(func(stats Stats) { + assertApproxEq(t, 2000, stats.RateOut) + assertApproxEq(t, 1000, stats.RateIn) + }) + + assertProtocols(func(stats Stats) { + assertApproxEq(t, 100000, stats.RateOut) + assertApproxEq(t, 50000, stats.RateIn) + }) + + { + stats := bwc.GetBandwidthTotals() + assertApproxEq(t, 200000, stats.RateOut) + assertApproxEq(t, 100000, stats.RateIn) + } + + wg.Wait() + time.Sleep(1 * time.Second) + + assertPeers(func(stats Stats) { + assertEq(t, 8000, stats.TotalOut) + assertEq(t, 4000, stats.TotalIn) + }) + + assertProtocols(func(stats Stats) { + assertEq(t, 400000, stats.TotalOut) + assertEq(t, 200000, stats.TotalIn) + }) + + { + stats := bwc.GetBandwidthTotals() + assertEq(t, 800000, stats.TotalOut) + assertEq(t, 400000, stats.TotalIn) + } +} + +func assertEq(t *testing.T, expected, actual int64) { + if expected != actual { + t.Errorf("expected %d, got %d", expected, actual) + } +} + +func assertApproxEq(t *testing.T, expected, actual float64) { + t.Helper() + margin := expected * acceptableError + if !(math.Abs(expected-actual) <= margin) { + t.Errorf("expected %f (±%f), got %f", expected, margin, actual) + } +} diff --git a/core/metrics/reporter.go b/core/metrics/reporter.go new file mode 100644 index 0000000000..860345d639 --- /dev/null +++ b/core/metrics/reporter.go @@ -0,0 +1,31 @@ +// Package metrics provides metrics collection and reporting interfaces for libp2p. +package metrics + +import ( + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" +) + +// Stats represents a point-in-time snapshot of bandwidth metrics. +// +// The TotalIn and TotalOut fields record cumulative bytes sent / received. +// The RateIn and RateOut fields record bytes sent / received per second. +type Stats struct { + TotalIn int64 + TotalOut int64 + RateIn float64 + RateOut float64 +} + +// Reporter provides methods for logging and retrieving metrics. +type Reporter interface { + LogSentMessage(int64) + LogRecvMessage(int64) + LogSentMessageStream(int64, protocol.ID, peer.ID) + LogRecvMessageStream(int64, protocol.ID, peer.ID) + GetBandwidthForPeer(peer.ID) Stats + GetBandwidthForProtocol(protocol.ID) Stats + GetBandwidthTotals() Stats + GetBandwidthByPeer() map[peer.ID]Stats + GetBandwidthByProtocol() map[protocol.ID]Stats +} diff --git a/core/mux/mux.go b/core/mux/mux.go new file mode 100644 index 0000000000..39f2e51c1b --- /dev/null +++ b/core/mux/mux.go @@ -0,0 +1,70 @@ +// Package mux provides stream multiplexing interfaces for libp2p. +// +// For a conceptual overview of stream multiplexing in libp2p, see +// https://docs.libp2p.io/concepts/stream-multiplexing/ +package mux + +import ( + "errors" + "io" + "net" + "time" +) + +// ErrReset is returned when reading or writing on a reset stream. +var ErrReset = errors.New("stream reset") + +// Stream is a bidirectional io pipe within a connection. +type MuxedStream interface { + io.Reader + io.Writer + + // Close closes the stream for writing. Reading will still work (that + // is, the remote side can still write). + io.Closer + + // Reset closes both ends of the stream. Use this to tell the remote + // side to hang up and go away. + Reset() error + + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +// NoopHandler do nothing. Resets streams as soon as they are opened. +var NoopHandler = func(s MuxedStream) { s.Reset() } + +// MuxedConn represents a connection to a remote peer that has been +// extended to support stream multiplexing. +// +// A MuxedConn allows a single net.Conn connection to carry many logically +// independent bidirectional streams of binary data. +// +// Together with network.ConnSecurity, MuxedConn is a component of the +// transport.CapableConn interface, which represents a "raw" network +// connection that has been "upgraded" to support the libp2p capabilities +// of secure communication and stream multiplexing. +type MuxedConn interface { + // Close closes the stream muxer and the the underlying net.Conn. + io.Closer + + // IsClosed returns whether a connection is fully closed, so it can + // be garbage collected. + IsClosed() bool + + // OpenStream creates a new stream. + OpenStream() (MuxedStream, error) + + // AcceptStream accepts a stream opened by the other side. + AcceptStream() (MuxedStream, error) +} + +// Multiplexer wraps a net.Conn with a stream multiplexing +// implementation and returns a MuxedConn that supports opening +// multiple streams over the underlying net.Conn +type Multiplexer interface { + + // NewConn constructs a new connection + NewConn(c net.Conn, isServer bool) (MuxedConn, error) +} diff --git a/core/network/conn.go b/core/network/conn.go new file mode 100644 index 0000000000..e7844a0e80 --- /dev/null +++ b/core/network/conn.go @@ -0,0 +1,58 @@ +package network + +import ( + "io" + + ic "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" + + ma "github.com/multiformats/go-multiaddr" +) + +// Conn is a connection to a remote peer. It multiplexes streams. +// Usually there is no need to use a Conn directly, but it may +// be useful to get information about the peer on the other side: +// stream.Conn().RemotePeer() +type Conn interface { + io.Closer + + ConnSecurity + ConnMultiaddrs + + // NewStream constructs a new Stream over this conn. + NewStream() (Stream, error) + + // GetStreams returns all open streams over this conn. + GetStreams() []Stream + + // Stat stores metadata pertaining to this conn. + Stat() Stat +} + +// ConnSecurity is the interface that one can mix into a connection interface to +// give it the security methods. +type ConnSecurity interface { + // LocalPeer returns our peer ID + LocalPeer() peer.ID + + // LocalPrivateKey returns our private key + LocalPrivateKey() ic.PrivKey + + // RemotePeer returns the peer ID of the remote peer. + RemotePeer() peer.ID + + // RemotePublicKey returns the public key of the remote peer. + RemotePublicKey() ic.PubKey +} + +// ConnMultiaddrs is an interface mixin for connection types that provide multiaddr +// addresses for the endpoints. +type ConnMultiaddrs interface { + // LocalMultiaddr returns the local Multiaddr associated + // with this connection + LocalMultiaddr() ma.Multiaddr + + // RemoteMultiaddr returns the remote Multiaddr associated + // with this connection + RemoteMultiaddr() ma.Multiaddr +} diff --git a/core/network/context.go b/core/network/context.go new file mode 100644 index 0000000000..9025f83a22 --- /dev/null +++ b/core/network/context.go @@ -0,0 +1,48 @@ +package network + +import ( + "context" + "time" +) + +// DialPeerTimeout is the default timeout for a single call to `DialPeer`. When +// there are multiple concurrent calls to `DialPeer`, this timeout will apply to +// each independently. +var DialPeerTimeout = 60 * time.Second + +type noDialCtxKey struct{} +type dialPeerTimeoutCtxKey struct{} + +var noDial = noDialCtxKey{} + +// WithNoDial constructs a new context with an option that instructs the network +// to not attempt a new dial when opening a stream. +func WithNoDial(ctx context.Context, reason string) context.Context { + return context.WithValue(ctx, noDial, reason) +} + +// GetNoDial returns true if the no dial option is set in the context. +func GetNoDial(ctx context.Context) (nodial bool, reason string) { + v := ctx.Value(noDial) + if v != nil { + return true, v.(string) + } + + return false, "" +} + +// GetDialPeerTimeout returns the current DialPeer timeout (or the default). +func GetDialPeerTimeout(ctx context.Context) time.Duration { + if to, ok := ctx.Value(dialPeerTimeoutCtxKey{}).(time.Duration); ok { + return to + } + return DialPeerTimeout +} + +// WithDialPeerTimeout returns a new context with the DialPeer timeout applied. +// +// This timeout overrides the default DialPeerTimeout and applies per-dial +// independently. +func WithDialPeerTimeout(ctx context.Context, timeout time.Duration) context.Context { + return context.WithValue(ctx, dialPeerTimeoutCtxKey{}, timeout) +} diff --git a/core/network/context_test.go b/core/network/context_test.go new file mode 100644 index 0000000000..09125516ad --- /dev/null +++ b/core/network/context_test.go @@ -0,0 +1,40 @@ +package network + +import ( + "context" + "testing" + "time" +) + +func TestDefaultTimeout(t *testing.T) { + ctx := context.Background() + dur := GetDialPeerTimeout(ctx) + if dur != DialPeerTimeout { + t.Fatal("expected default peer timeout") + } +} + +func TestNonDefaultTimeout(t *testing.T) { + customTimeout := time.Duration(1) + ctx := context.WithValue( + context.Background(), + dialPeerTimeoutCtxKey{}, + customTimeout, + ) + dur := GetDialPeerTimeout(ctx) + if dur != customTimeout { + t.Fatal("peer timeout doesn't match set timeout") + } +} + +func TestSettingTimeout(t *testing.T) { + customTimeout := time.Duration(1) + ctx := WithDialPeerTimeout( + context.Background(), + customTimeout, + ) + dur := GetDialPeerTimeout(ctx) + if dur != customTimeout { + t.Fatal("peer timeout doesn't match set timeout") + } +} diff --git a/core/network/errors.go b/core/network/errors.go new file mode 100644 index 0000000000..f0cf7291e8 --- /dev/null +++ b/core/network/errors.go @@ -0,0 +1,10 @@ +package network + +import "errors" + +// ErrNoRemoteAddrs is returned when there are no addresses associated with a peer during a dial. +var ErrNoRemoteAddrs = errors.New("no remote addresses") + +// ErrNoConn is returned when attempting to open a stream to a peer with the NoDial +// option and no usable connection is available. +var ErrNoConn = errors.New("no usable connection to peer") diff --git a/core/network/network.go b/core/network/network.go new file mode 100644 index 0000000000..467109c48c --- /dev/null +++ b/core/network/network.go @@ -0,0 +1,138 @@ +// Package network provides core networking abstractions for libp2p. +// +// The network package provides the high-level Network interface for interacting +// with other libp2p peers, which is the primary public API for initiating and +// accepting connections to remote peers. +package network + +import ( + "context" + "io" + + "github.com/jbenet/goprocess" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + + ma "github.com/multiformats/go-multiaddr" +) + +// MessageSizeMax is a soft (recommended) maximum for network messages. +// One can write more, as the interface is a stream. But it is useful +// to bunch it up into multiple read/writes when the whole message is +// a single, large serialized object. +const MessageSizeMax = 1 << 22 // 4 MB + +// Direction represents which peer in a stream initiated a connection. +type Direction int + +const ( + // DirUnknown is the default direction. + DirUnknown Direction = iota + // DirInbound is for when the remote peer initiated a connection. + DirInbound + // DirOutbound is for when the local peer initiated a connection. + DirOutbound +) + +// Connectedness signals the capacity for a connection with a given node. +// It is used to signal to services and other peers whether a node is reachable. +type Connectedness int + +const ( + // NotConnected means no connection to peer, and no extra information (default) + NotConnected Connectedness = iota + + // Connected means has an open, live connection to peer + Connected + + // CanConnect means recently connected to peer, terminated gracefully + CanConnect + + // CannotConnect means recently attempted connecting but failed to connect. + // (should signal "made effort, failed") + CannotConnect +) + +// Stat stores metadata pertaining to a given Stream/Conn. +type Stat struct { + Direction Direction + Extra map[interface{}]interface{} +} + +// StreamHandler is the type of function used to listen for +// streams opened by the remote side. +type StreamHandler func(Stream) + +// ConnHandler is the type of function used to listen for +// connections opened by the remote side. +type ConnHandler func(Conn) + +// Network is the interface used to connect to the outside world. +// It dials and listens for connections. it uses a Swarm to pool +// connections (see swarm pkg, and peerstream.Swarm). Connections +// are encrypted with a TLS-like protocol. +type Network interface { + Dialer + io.Closer + + // SetStreamHandler sets the handler for new streams opened by the + // remote side. This operation is threadsafe. + SetStreamHandler(StreamHandler) + + // SetConnHandler sets the handler for new connections opened by the + // remote side. This operation is threadsafe. + SetConnHandler(ConnHandler) + + // NewStream returns a new stream to given peer p. + // If there is no connection to p, attempts to create one. + NewStream(context.Context, peer.ID) (Stream, error) + + // Listen tells the network to start listening on given multiaddrs. + Listen(...ma.Multiaddr) error + + // ListenAddresses returns a list of addresses at which this network listens. + ListenAddresses() []ma.Multiaddr + + // InterfaceListenAddresses returns a list of addresses at which this network + // listens. It expands "any interface" addresses (/ip4/0.0.0.0, /ip6/::) to + // use the known local interfaces. + InterfaceListenAddresses() ([]ma.Multiaddr, error) + + // Process returns the network's Process + Process() goprocess.Process +} + +// Dialer represents a service that can dial out to peers +// (this is usually just a Network, but other services may not need the whole +// stack, and thus it becomes easier to mock) +type Dialer interface { + // Peerstore returns the internal peerstore + // This is useful to tell the dialer about a new address for a peer. + // Or use one of the public keys found out over the network. + Peerstore() peerstore.Peerstore + + // LocalPeer returns the local peer associated with this network + LocalPeer() peer.ID + + // DialPeer establishes a connection to a given peer + DialPeer(context.Context, peer.ID) (Conn, error) + + // ClosePeer closes the connection to a given peer + ClosePeer(peer.ID) error + + // Connectedness returns a state signaling connection capabilities + Connectedness(peer.ID) Connectedness + + // Peers returns the peers connected + Peers() []peer.ID + + // Conns returns the connections in this Netowrk + Conns() []Conn + + // ConnsToPeer returns the connections in this Netowrk for given peer. + ConnsToPeer(p peer.ID) []Conn + + // Notify/StopNotify register and unregister a notifiee for signals + Notify(Notifiee) + StopNotify(Notifiee) +} diff --git a/core/network/notifee.go b/core/network/notifee.go new file mode 100644 index 0000000000..10ef72f1e1 --- /dev/null +++ b/core/network/notifee.go @@ -0,0 +1,92 @@ +package network + +import ( + ma "github.com/multiformats/go-multiaddr" +) + +// Notifiee is an interface for an object wishing to receive +// notifications from a Network. +type Notifiee interface { + Listen(Network, ma.Multiaddr) // called when network starts listening on an addr + ListenClose(Network, ma.Multiaddr) // called when network stops listening on an addr + Connected(Network, Conn) // called when a connection opened + Disconnected(Network, Conn) // called when a connection closed + OpenedStream(Network, Stream) // called when a stream opened + ClosedStream(Network, Stream) // called when a stream closed + + // TODO + // PeerConnected(Network, peer.ID) // called when a peer connected + // PeerDisconnected(Network, peer.ID) // called when a peer disconnected +} + +// NotifyBundle implements Notifiee by calling any of the functions set on it, +// and nop'ing if they are unset. This is the easy way to register for +// notifications. +type NotifyBundle struct { + ListenF func(Network, ma.Multiaddr) + ListenCloseF func(Network, ma.Multiaddr) + + ConnectedF func(Network, Conn) + DisconnectedF func(Network, Conn) + + OpenedStreamF func(Network, Stream) + ClosedStreamF func(Network, Stream) +} + +var _ Notifiee = (*NotifyBundle)(nil) + +// Listen calls ListenF if it is not null. +func (nb *NotifyBundle) Listen(n Network, a ma.Multiaddr) { + if nb.ListenF != nil { + nb.ListenF(n, a) + } +} + +// ListenClose calls ListenCloseF if it is not null. +func (nb *NotifyBundle) ListenClose(n Network, a ma.Multiaddr) { + if nb.ListenCloseF != nil { + nb.ListenCloseF(n, a) + } +} + +// Connected calls ConnectedF if it is not null. +func (nb *NotifyBundle) Connected(n Network, c Conn) { + if nb.ConnectedF != nil { + nb.ConnectedF(n, c) + } +} + +// Disconnected calls DisconnectedF if it is not null. +func (nb *NotifyBundle) Disconnected(n Network, c Conn) { + if nb.DisconnectedF != nil { + nb.DisconnectedF(n, c) + } +} + +// OpenedStream calls OpenedStreamF if it is not null. +func (nb *NotifyBundle) OpenedStream(n Network, s Stream) { + if nb.OpenedStreamF != nil { + nb.OpenedStreamF(n, s) + } +} + +// ClosedStream calls ClosedStreamF if it is not null. +func (nb *NotifyBundle) ClosedStream(n Network, s Stream) { + if nb.ClosedStreamF != nil { + nb.ClosedStreamF(n, s) + } +} + +// Global noop notifiee. Do not change. +var GlobalNoopNotifiee = &NoopNotifiee{} + +type NoopNotifiee struct{} + +var _ Notifiee = (*NoopNotifiee)(nil) + +func (nn *NoopNotifiee) Connected(n Network, c Conn) {} +func (nn *NoopNotifiee) Disconnected(n Network, c Conn) {} +func (nn *NoopNotifiee) Listen(n Network, addr ma.Multiaddr) {} +func (nn *NoopNotifiee) ListenClose(n Network, addr ma.Multiaddr) {} +func (nn *NoopNotifiee) OpenedStream(Network, Stream) {} +func (nn *NoopNotifiee) ClosedStream(Network, Stream) {} diff --git a/core/network/notifee_test.go b/core/network/notifee_test.go new file mode 100644 index 0000000000..1948074a39 --- /dev/null +++ b/core/network/notifee_test.go @@ -0,0 +1,123 @@ +package network + +import ( + "testing" + + ma "github.com/multiformats/go-multiaddr" +) + +func TestListen(T *testing.T) { + var notifee NotifyBundle + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234") + if err != nil { + T.Fatal("unexpected multiaddr error") + } + notifee.Listen(nil, addr) + + called := false + notifee.ListenF = func(Network, ma.Multiaddr) { + called = true + } + if called { + T.Fatal("called should be false") + } + + notifee.Listen(nil, addr) + if !called { + T.Fatal("Listen should have been called") + } +} + +func TestListenClose(T *testing.T) { + var notifee NotifyBundle + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234") + if err != nil { + T.Fatal("unexpected multiaddr error") + } + notifee.ListenClose(nil, addr) + + called := false + notifee.ListenCloseF = func(Network, ma.Multiaddr) { + called = true + } + if called { + T.Fatal("called should be false") + } + + notifee.ListenClose(nil, addr) + if !called { + T.Fatal("ListenClose should have been called") + } +} + +func TestConnected(T *testing.T) { + var notifee NotifyBundle + notifee.Connected(nil, nil) + + called := false + notifee.ConnectedF = func(Network, Conn) { + called = true + } + if called { + T.Fatal("called should be false") + } + + notifee.Connected(nil, nil) + if !called { + T.Fatal("Connected should have been called") + } +} + +func TestDisconnected(T *testing.T) { + var notifee NotifyBundle + notifee.Disconnected(nil, nil) + + called := false + notifee.DisconnectedF = func(Network, Conn) { + called = true + } + if called { + T.Fatal("called should be false") + } + + notifee.Disconnected(nil, nil) + if !called { + T.Fatal("Disconnected should have been called") + } +} + +func TestOpenedStream(T *testing.T) { + var notifee NotifyBundle + notifee.OpenedStream(nil, nil) + + called := false + notifee.OpenedStreamF = func(Network, Stream) { + called = true + } + if called { + T.Fatal("called should be false") + } + + notifee.OpenedStream(nil, nil) + if !called { + T.Fatal("OpenedStream should have been called") + } +} + +func TestClosedStream(T *testing.T) { + var notifee NotifyBundle + notifee.ClosedStream(nil, nil) + + called := false + notifee.ClosedStreamF = func(Network, Stream) { + called = true + } + if called { + T.Fatal("called should be false") + } + + notifee.ClosedStream(nil, nil) + if !called { + T.Fatal("ClosedStream should have been called") + } +} diff --git a/core/network/stream.go b/core/network/stream.go new file mode 100644 index 0000000000..d13dc30519 --- /dev/null +++ b/core/network/stream.go @@ -0,0 +1,24 @@ +package network + +import ( + "github.com/libp2p/go-libp2p-core/mux" + "github.com/libp2p/go-libp2p-core/protocol" +) + +// Stream represents a bidirectional channel between two agents in +// a libp2p network. "agent" is as granular as desired, potentially +// being a "request -> reply" pair, or whole protocols. +// +// Streams are backed by a multiplexer underneath the hood. +type Stream interface { + mux.MuxedStream + + Protocol() protocol.ID + SetProtocol(id protocol.ID) + + // Stat returns metadata pertaining to this stream. + Stat() Stat + + // Conn returns the connection this stream is part of. + Conn() Conn +} diff --git a/core/peer/addrinfo.go b/core/peer/addrinfo.go new file mode 100644 index 0000000000..c7324e4f3e --- /dev/null +++ b/core/peer/addrinfo.go @@ -0,0 +1,80 @@ +package peer + +import ( + "fmt" + "strings" + + ma "github.com/multiformats/go-multiaddr" +) + +// AddrInfo is a small struct used to pass around a peer with +// a set of addresses (and later, keys?). +type AddrInfo struct { + ID ID + Addrs []ma.Multiaddr +} + +var _ fmt.Stringer = AddrInfo{} + +func (pi AddrInfo) String() string { + return fmt.Sprintf("{%v: %v}", pi.ID, pi.Addrs) +} + +var ErrInvalidAddr = fmt.Errorf("invalid p2p multiaddr") + +func AddrInfoFromP2pAddr(m ma.Multiaddr) (*AddrInfo, error) { + if m == nil { + return nil, ErrInvalidAddr + } + + // make sure it's a P2P addr + parts := ma.Split(m) + if len(parts) < 1 { + return nil, ErrInvalidAddr + } + + // TODO(lgierth): we shouldn't assume /p2p is the last part + p2ppart := parts[len(parts)-1] + if p2ppart.Protocols()[0].Code != ma.P_P2P { + return nil, ErrInvalidAddr + } + + // make sure the /p2p value parses as a peer.ID + peerIdParts := strings.Split(p2ppart.String(), "/") + peerIdStr := peerIdParts[len(peerIdParts)-1] + id, err := IDB58Decode(peerIdStr) + if err != nil { + return nil, err + } + + // we might have received just an /p2p part, which means there's no addr. + var addrs []ma.Multiaddr + if len(parts) > 1 { + addrs = append(addrs, ma.Join(parts[:len(parts)-1]...)) + } + + return &AddrInfo{ + ID: id, + Addrs: addrs, + }, nil +} + +func AddrInfoToP2pAddrs(pi *AddrInfo) ([]ma.Multiaddr, error) { + var addrs []ma.Multiaddr + tpl := "/" + ma.ProtocolWithCode(ma.P_P2P).Name + "/" + for _, addr := range pi.Addrs { + p2paddr, err := ma.NewMultiaddr(tpl + IDB58Encode(pi.ID)) + if err != nil { + return nil, err + } + addrs = append(addrs, addr.Encapsulate(p2paddr)) + } + return addrs, nil +} + +func (pi *AddrInfo) Loggable() map[string]interface{} { + return map[string]interface{}{ + "peerID": pi.ID.Pretty(), + "addrs": pi.Addrs, + } +} diff --git a/core/peer/addrinfo_serde.go b/core/peer/addrinfo_serde.go new file mode 100644 index 0000000000..1df24e2b73 --- /dev/null +++ b/core/peer/addrinfo_serde.go @@ -0,0 +1,38 @@ +package peer + +import ( + "encoding/json" + + ma "github.com/multiformats/go-multiaddr" +) + +func (pi AddrInfo) MarshalJSON() ([]byte, error) { + out := make(map[string]interface{}) + out["ID"] = pi.ID.Pretty() + var addrs []string + for _, a := range pi.Addrs { + addrs = append(addrs, a.String()) + } + out["Addrs"] = addrs + return json.Marshal(out) +} + +func (pi *AddrInfo) UnmarshalJSON(b []byte) error { + var data map[string]interface{} + err := json.Unmarshal(b, &data) + if err != nil { + return err + } + pid, err := IDB58Decode(data["ID"].(string)) + if err != nil { + return err + } + pi.ID = pid + addrs, ok := data["Addrs"].([]interface{}) + if ok { + for _, a := range addrs { + pi.Addrs = append(pi.Addrs, ma.StringCast(a.(string))) + } + } + return nil +} diff --git a/core/peer/peer.go b/core/peer/peer.go new file mode 100644 index 0000000000..db307ee1e3 --- /dev/null +++ b/core/peer/peer.go @@ -0,0 +1,186 @@ +// Package peer implements an object used to represent peers in the libp2p network. +package peer + +import ( + "encoding/hex" + "errors" + "fmt" + + ic "github.com/libp2p/go-libp2p-core/crypto" + b58 "github.com/mr-tron/base58/base58" + mh "github.com/multiformats/go-multihash" +) + +var ( + // ErrEmptyPeerID is an error for empty peer ID. + ErrEmptyPeerID = errors.New("empty peer ID") + // ErrNoPublicKey is an error for peer IDs that don't embed public keys + ErrNoPublicKey = errors.New("public key is not embedded in peer ID") +) + +// AdvancedEnableInlining enables automatically inlining keys shorter than +// 42 bytes into the peer ID (using the "identity" multihash function). +// +// WARNING: This flag will likely be set to false in the future and eventually +// be removed in favor of using a hash function specified by the key itself. +// See: https://github.com/libp2p/specs/issues/138 +// +// DO NOT change this flag unless you know what you're doing. +// +// This currently defaults to true for backwards compatibility but will likely +// be set to false by default when an upgrade path is determined. +var AdvancedEnableInlining = true + +const maxInlineKeyLength = 42 + +// ID is a libp2p peer identity. +// +// Peer IDs are derived by hashing a peer's public key and encoding the +// hash output as a multihash. See IDFromPublicKey for details. +type ID string + +// Pretty returns a base58-encoded string representation of the ID. +func (id ID) Pretty() string { + return IDB58Encode(id) +} + +// Loggable returns a pretty peer ID string in loggable JSON format. +func (id ID) Loggable() map[string]interface{} { + return map[string]interface{}{ + "peerID": id.Pretty(), + } +} + +func (id ID) String() string { + return id.Pretty() +} + +// String prints out the peer ID. +// +// TODO(brian): ensure correctness at ID generation and +// enforce this by only exposing functions that generate +// IDs safely. Then any peer.ID type found in the +// codebase is known to be correct. +func (id ID) ShortString() string { + pid := id.Pretty() + if len(pid) <= 10 { + return fmt.Sprintf("", pid) + } + return fmt.Sprintf("", pid[:2], pid[len(pid)-6:]) +} + +// MatchesPrivateKey tests whether this ID was derived from the secret key sk. +func (id ID) MatchesPrivateKey(sk ic.PrivKey) bool { + return id.MatchesPublicKey(sk.GetPublic()) +} + +// MatchesPublicKey tests whether this ID was derived from the public key pk. +func (id ID) MatchesPublicKey(pk ic.PubKey) bool { + oid, err := IDFromPublicKey(pk) + if err != nil { + return false + } + return oid == id +} + +// ExtractPublicKey attempts to extract the public key from an ID +// +// This method returns ErrNoPublicKey if the peer ID looks valid but it can't extract +// the public key. +func (id ID) ExtractPublicKey() (ic.PubKey, error) { + decoded, err := mh.Decode([]byte(id)) + if err != nil { + return nil, err + } + if decoded.Code != mh.ID { + return nil, ErrNoPublicKey + } + pk, err := ic.UnmarshalPublicKey(decoded.Digest) + if err != nil { + return nil, err + } + return pk, nil +} + +// Validate checks if ID is empty or not. +func (id ID) Validate() error { + if id == ID("") { + return ErrEmptyPeerID + } + + return nil +} + +// IDFromString casts a string to the ID type, and validates +// the value to make sure it is a multihash. +func IDFromString(s string) (ID, error) { + if _, err := mh.Cast([]byte(s)); err != nil { + return ID(""), err + } + return ID(s), nil +} + +// IDFromBytes casts a byte slice to the ID type, and validates +// the value to make sure it is a multihash. +func IDFromBytes(b []byte) (ID, error) { + if _, err := mh.Cast(b); err != nil { + return ID(""), err + } + return ID(b), nil +} + +// IDB58Decode accepts a base58-encoded multihash representing a peer ID +// and returns the decoded ID if the input is valid. +func IDB58Decode(s string) (ID, error) { + m, err := mh.FromB58String(s) + if err != nil { + return "", err + } + return ID(m), err +} + +// IDB58Encode returns the base58-encoded multihash representation of the ID. +func IDB58Encode(id ID) string { + return b58.Encode([]byte(id)) +} + +// IDHexDecode accepts a hex-encoded multihash representing a peer ID +// and returns the decoded ID if the input is valid. +func IDHexDecode(s string) (ID, error) { + m, err := mh.FromHexString(s) + if err != nil { + return "", err + } + return ID(m), err +} + +// IDHexEncode returns the hex-encoded multihash representation of the ID. +func IDHexEncode(id ID) string { + return hex.EncodeToString([]byte(id)) +} + +// IDFromPublicKey returns the Peer ID corresponding to the public key pk. +func IDFromPublicKey(pk ic.PubKey) (ID, error) { + b, err := pk.Bytes() + if err != nil { + return "", err + } + var alg uint64 = mh.SHA2_256 + if AdvancedEnableInlining && len(b) <= maxInlineKeyLength { + alg = mh.ID + } + hash, _ := mh.Sum(b, alg, -1) + return ID(hash), nil +} + +// IDFromPrivateKey returns the Peer ID corresponding to the secret key sk. +func IDFromPrivateKey(sk ic.PrivKey) (ID, error) { + return IDFromPublicKey(sk.GetPublic()) +} + +// IDSlice for sorting peers +type IDSlice []ID + +func (es IDSlice) Len() int { return len(es) } +func (es IDSlice) Swap(i, j int) { es[i], es[j] = es[j], es[i] } +func (es IDSlice) Less(i, j int) bool { return string(es[i]) < string(es[j]) } diff --git a/core/peer/peer_serde.go b/core/peer/peer_serde.go new file mode 100644 index 0000000000..7f1b3e6a65 --- /dev/null +++ b/core/peer/peer_serde.go @@ -0,0 +1,75 @@ +// This file contains Protobuf and JSON serialization/deserialization methods for peer IDs. +package peer + +import ( + "encoding" + "encoding/json" +) + +// Interface assertions commented out to avoid introducing hard dependencies to protobuf. +// var _ proto.Marshaler = (*ID)(nil) +// var _ proto.Unmarshaler = (*ID)(nil) +var _ json.Marshaler = (*ID)(nil) +var _ json.Unmarshaler = (*ID)(nil) + +var _ encoding.BinaryMarshaler = (*ID)(nil) +var _ encoding.BinaryUnmarshaler = (*ID)(nil) +var _ encoding.TextMarshaler = (*ID)(nil) +var _ encoding.TextUnmarshaler = (*ID)(nil) + +func (id ID) Marshal() ([]byte, error) { + return []byte(id), nil +} + +// BinaryMarshal returns the byte representation of the peer ID. +func (id ID) MarshalBinary() ([]byte, error) { + return id.Marshal() +} + +func (id ID) MarshalTo(data []byte) (n int, err error) { + return copy(data, []byte(id)), nil +} + +func (id *ID) Unmarshal(data []byte) (err error) { + *id, err = IDFromBytes(data) + return err +} + +// BinaryUnmarshal sets the ID from its binary representation. +func (id *ID) UnmarshalBinary(data []byte) error { + return id.Unmarshal(data) +} + +// Implements Gogo's proto.Sizer, but we omit the compile-time assertion to avoid introducing a hard +// dependency on gogo. +func (id ID) Size() int { + return len([]byte(id)) +} + +func (id ID) MarshalJSON() ([]byte, error) { + return json.Marshal(IDB58Encode(id)) +} + +func (id *ID) UnmarshalJSON(data []byte) (err error) { + var v string + if err = json.Unmarshal(data, &v); err != nil { + return err + } + *id, err = IDB58Decode(v) + return err +} + +// TextMarshal returns the text encoding of the ID. +func (id ID) MarshalText() ([]byte, error) { + return []byte(IDB58Encode(id)), nil +} + +// TextUnmarshal restores the ID from its text encoding. +func (id *ID) UnmarshalText(data []byte) error { + pid, err := IDB58Decode(string(data)) + if err != nil { + return err + } + *id = pid + return nil +} diff --git a/core/peer/peer_serde_test.go b/core/peer/peer_serde_test.go new file mode 100644 index 0000000000..62ecb0f725 --- /dev/null +++ b/core/peer/peer_serde_test.go @@ -0,0 +1,83 @@ +package peer_test + +import ( + "testing" + + "github.com/libp2p/go-libp2p-core/peer" + . "github.com/libp2p/go-libp2p-testing/peer" +) + +func TestPeerSerdePB(t *testing.T) { + id, err := RandPeerID() + if err != nil { + t.Fatal(err) + } + b, err := id.Marshal() + if err != nil { + t.Fatal(err) + } + + var id2 peer.ID + if err = id2.Unmarshal(b); err != nil { + t.Fatal(err) + } + if id != id2 { + t.Error("expected equal ids in circular serde test") + } +} + +func TestPeerSerdeJSON(t *testing.T) { + id, err := RandPeerID() + if err != nil { + t.Fatal(err) + } + b, err := id.MarshalJSON() + if err != nil { + t.Fatal(err) + } + var id2 peer.ID + if err = id2.UnmarshalJSON(b); err != nil { + t.Fatal(err) + } + if id != id2 { + t.Error("expected equal ids in circular serde test") + } +} + +func TestBinaryMarshaler(t *testing.T) { + id, err := RandPeerID() + if err != nil { + t.Fatal(err) + } + b, err := id.MarshalBinary() + if err != nil { + t.Fatal(err) + } + + var id2 peer.ID + if err = id2.UnmarshalBinary(b); err != nil { + t.Fatal(err) + } + if id != id2 { + t.Error("expected equal ids in circular serde test") + } +} + +func TestTextMarshaler(t *testing.T) { + id, err := RandPeerID() + if err != nil { + t.Fatal(err) + } + b, err := id.MarshalText() + if err != nil { + t.Fatal(err) + } + + var id2 peer.ID + if err = id2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + if id != id2 { + t.Error("expected equal ids in circular serde test") + } +} diff --git a/core/peer/peer_test.go b/core/peer/peer_test.go new file mode 100644 index 0000000000..a134532140 --- /dev/null +++ b/core/peer/peer_test.go @@ -0,0 +1,244 @@ +package peer_test + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "strings" + "testing" + + ic "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-testing/crypto" + + . "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-testing/peer" + + b58 "github.com/mr-tron/base58/base58" + mh "github.com/multiformats/go-multihash" +) + +var gen1 keyset // generated +var gen2 keyset // generated +var man keyset // manual + +func hash(b []byte) []byte { + h, _ := mh.Sum(b, mh.SHA2_256, -1) + return []byte(h) +} + +func init() { + if err := gen1.generate(); err != nil { + panic(err) + } + if err := gen2.generate(); err != nil { + panic(err) + } + + skManBytes = strings.Replace(skManBytes, "\n", "", -1) + if err := man.load(hpkpMan, skManBytes); err != nil { + panic(err) + } +} + +type keyset struct { + sk ic.PrivKey + pk ic.PubKey + hpk string + hpkp string +} + +func (ks *keyset) generate() error { + var err error + ks.sk, ks.pk, err = tcrypto.RandTestKeyPair(ic.RSA, 512) + if err != nil { + return err + } + + bpk, err := ks.pk.Bytes() + if err != nil { + return err + } + + ks.hpk = string(hash(bpk)) + ks.hpkp = b58.Encode([]byte(ks.hpk)) + return nil +} + +func (ks *keyset) load(hpkp, skBytesStr string) error { + skBytes, err := base64.StdEncoding.DecodeString(skBytesStr) + if err != nil { + return err + } + + ks.sk, err = ic.UnmarshalPrivateKey(skBytes) + if err != nil { + return err + } + + ks.pk = ks.sk.GetPublic() + bpk, err := ks.pk.Bytes() + if err != nil { + return err + } + + ks.hpk = string(hash(bpk)) + ks.hpkp = b58.Encode([]byte(ks.hpk)) + if ks.hpkp != hpkp { + return fmt.Errorf("hpkp doesn't match key. %s", hpkp) + } + return nil +} + +func TestIDMatchesPublicKey(t *testing.T) { + + test := func(ks keyset) { + p1, err := IDB58Decode(ks.hpkp) + if err != nil { + t.Fatal(err) + } + + if ks.hpk != string(p1) { + t.Error("p1 and hpk differ") + } + + if !p1.MatchesPublicKey(ks.pk) { + t.Fatal("p1 does not match pk") + } + + p2, err := IDFromPublicKey(ks.pk) + if err != nil { + t.Fatal(err) + } + + if p1 != p2 { + t.Error("p1 and p2 differ", p1.Pretty(), p2.Pretty()) + } + + if p2.Pretty() != ks.hpkp { + t.Error("hpkp and p2.Pretty differ", ks.hpkp, p2.Pretty()) + } + } + + test(gen1) + test(gen2) + test(man) +} + +func TestIDMatchesPrivateKey(t *testing.T) { + + test := func(ks keyset) { + p1, err := IDB58Decode(ks.hpkp) + if err != nil { + t.Fatal(err) + } + + if ks.hpk != string(p1) { + t.Error("p1 and hpk differ") + } + + if !p1.MatchesPrivateKey(ks.sk) { + t.Fatal("p1 does not match sk") + } + + p2, err := IDFromPrivateKey(ks.sk) + if err != nil { + t.Fatal(err) + } + + if p1 != p2 { + t.Error("p1 and p2 differ", p1.Pretty(), p2.Pretty()) + } + } + + test(gen1) + test(gen2) + test(man) +} + +func TestPublicKeyExtraction(t *testing.T) { + t.Skip("disabled until libp2p/go-libp2p-crypto#51 is fixed") + // Happy path + _, originalPub, err := ic.GenerateEd25519Key(rand.Reader) + if err != nil { + t.Fatal(err) + } + + id, err := IDFromPublicKey(originalPub) + if err != nil { + t.Fatal(err) + } + + extractedPub, err := id.ExtractPublicKey() + if err != nil { + t.Fatal(err) + } + if extractedPub == nil { + t.Fatal("failed to extract public key") + } + if !originalPub.Equals(extractedPub) { + t.Fatal("extracted public key doesn't match") + } + + // Test invalid multihash (invariant of the type of public key) + pk, err := ID("").ExtractPublicKey() + if err == nil { + t.Fatal("expected an error") + } + if pk != nil { + t.Fatal("expected a nil public key") + } + + // Shouldn't work for, e.g. RSA keys (too large) + + _, rsaPub, err := ic.GenerateKeyPair(ic.RSA, 2048) + if err != nil { + t.Fatal(err) + } + rsaId, err := IDFromPublicKey(rsaPub) + if err != nil { + t.Fatal(err) + } + extractedRsaPub, err := rsaId.ExtractPublicKey() + if err != ErrNoPublicKey { + t.Fatal(err) + } + if extractedRsaPub != nil { + t.Fatal("expected to fail to extract public key from rsa ID") + } +} + +func TestValidate(t *testing.T) { + // Empty peer ID invalidates + err := ID("").Validate() + if err == nil { + t.Error("expected error") + } else if err != ErrEmptyPeerID { + t.Error("expected error message: " + ErrEmptyPeerID.Error()) + } + + // Non-empty peer ID validates + p, err := tpeer.RandPeerID() + if err != nil { + t.Fatal(err) + } + + err = p.Validate() + if err != nil { + t.Error("expected nil, but found " + err.Error()) + } +} + +var hpkpMan = `QmRK3JgmVEGiewxWbhpXLJyjWuGuLeSTMTndA1coMHEy5o` +var skManBytes = ` +CAAS4AQwggJcAgEAAoGBAL7w+Wc4VhZhCdM/+Hccg5Nrf4q9NXWwJylbSrXz/unFS24wyk6pEk0zi3W +7li+vSNVO+NtJQw9qGNAMtQKjVTP+3Vt/jfQRnQM3s6awojtjueEWuLYVt62z7mofOhCtj+VwIdZNBo +/EkLZ0ETfcvN5LVtLYa8JkXybnOPsLvK+PAgMBAAECgYBdk09HDM7zzL657uHfzfOVrdslrTCj6p5mo +DzvCxLkkjIzYGnlPuqfNyGjozkpSWgSUc+X+EGLLl3WqEOVdWJtbM61fewEHlRTM5JzScvwrJ39t7o6 +CCAjKA0cBWBd6UWgbN/t53RoWvh9HrA2AW5YrT0ZiAgKe9y7EMUaENVJ8QJBAPhpdmb4ZL4Fkm4OKia +NEcjzn6mGTlZtef7K/0oRC9+2JkQnCuf6HBpaRhJoCJYg7DW8ZY+AV6xClKrgjBOfERMCQQDExhnzu2 +dsQ9k8QChBlpHO0TRbZBiQfC70oU31kM1AeLseZRmrxv9Yxzdl8D693NNWS2JbKOXl0kMHHcuGQLMVA +kBZ7WvkmPV3aPL6jnwp2pXepntdVnaTiSxJ1dkXShZ/VSSDNZMYKY306EtHrIu3NZHtXhdyHKcggDXr +qkBrdgErAkAlpGPojUwemOggr4FD8sLX1ot2hDJyyV7OK2FXfajWEYJyMRL1Gm9Uk1+Un53RAkJneqp +JGAzKpyttXBTIDO51AkEA98KTiROMnnU8Y6Mgcvr68/SMIsvCYMt9/mtwSBGgl80VaTQ5Hpaktl6Xbh +VUt5Wv0tRxlXZiViCGCD1EtrrwTw== +` diff --git a/core/peer/set.go b/core/peer/set.go new file mode 100644 index 0000000000..ea82ea807e --- /dev/null +++ b/core/peer/set.go @@ -0,0 +1,71 @@ +package peer + +import ( + "sync" +) + +// PeerSet is a threadsafe set of peers +type Set struct { + lk sync.RWMutex + ps map[ID]struct{} + + size int +} + +func NewSet() *Set { + ps := new(Set) + ps.ps = make(map[ID]struct{}) + ps.size = -1 + return ps +} + +func NewLimitedSet(size int) *Set { + ps := new(Set) + ps.ps = make(map[ID]struct{}) + ps.size = size + return ps +} + +func (ps *Set) Add(p ID) { + ps.lk.Lock() + ps.ps[p] = struct{}{} + ps.lk.Unlock() +} + +func (ps *Set) Contains(p ID) bool { + ps.lk.RLock() + _, ok := ps.ps[p] + ps.lk.RUnlock() + return ok +} + +func (ps *Set) Size() int { + ps.lk.RLock() + defer ps.lk.RUnlock() + return len(ps.ps) +} + +// TryAdd Attempts to add the given peer into the set. +// This operation can fail for one of two reasons: +// 1) The given peer is already in the set +// 2) The number of peers in the set is equal to size +func (ps *Set) TryAdd(p ID) bool { + var success bool + ps.lk.Lock() + if _, ok := ps.ps[p]; !ok && (len(ps.ps) < ps.size || ps.size == -1) { + success = true + ps.ps[p] = struct{}{} + } + ps.lk.Unlock() + return success +} + +func (ps *Set) Peers() []ID { + ps.lk.Lock() + out := make([]ID, 0, len(ps.ps)) + for p, _ := range ps.ps { + out = append(out, p) + } + ps.lk.Unlock() + return out +} diff --git a/core/peerstore/peerstore.go b/core/peerstore/peerstore.go new file mode 100644 index 0000000000..2e1b644989 --- /dev/null +++ b/core/peerstore/peerstore.go @@ -0,0 +1,162 @@ +// Package peerstore provides types and interfaces for local storage of address information, +// metadata, and public key material about libp2p peers. +package peerstore + +import ( + "context" + "errors" + "io" + "math" + "time" + + "github.com/libp2p/go-libp2p-core/peer" + + ic "github.com/libp2p/go-libp2p-core/crypto" + + ma "github.com/multiformats/go-multiaddr" +) + +var ErrNotFound = errors.New("item not found") + +var ( + // AddressTTL is the expiration time of addresses. + AddressTTL = time.Hour + + // TempAddrTTL is the ttl used for a short lived address + TempAddrTTL = time.Minute * 2 + + // ProviderAddrTTL is the TTL of an address we've received from a provider. + // This is also a temporary address, but lasts longer. After this expires, + // the records we return will require an extra lookup. + ProviderAddrTTL = time.Minute * 10 + + // RecentlyConnectedAddrTTL is used when we recently connected to a peer. + // It means that we are reasonably certain of the peer's address. + RecentlyConnectedAddrTTL = time.Minute * 10 + + // OwnObservedAddrTTL is used for our own external addresses observed by peers. + OwnObservedAddrTTL = time.Minute * 10 +) + +// Permanent TTLs (distinct so we can distinguish between them, constant as they +// are, in fact, permanent) +const ( + // PermanentAddrTTL is the ttl for a "permanent address" (e.g. bootstrap nodes). + PermanentAddrTTL = math.MaxInt64 - iota + + // ConnectedAddrTTL is the ttl used for the addresses of a peer to whom + // we're connected directly. This is basically permanent, as we will + // clear them + re-add under a TempAddrTTL after disconnecting. + ConnectedAddrTTL +) + +// Peerstore provides a threadsafe store of Peer related +// information. +type Peerstore interface { + io.Closer + + AddrBook + KeyBook + PeerMetadata + Metrics + ProtoBook + + // PeerInfo returns a peer.PeerInfo struct for given peer.ID. + // This is a small slice of the information Peerstore has on + // that peer, useful to other services. + PeerInfo(peer.ID) peer.AddrInfo + + // Peers returns all of the peer IDs stored across all inner stores. + Peers() peer.IDSlice +} + +// PeerMetadata can handle values of any type. Serializing values is +// up to the implementation. Dynamic type introspection may not be +// supported, in which case explicitly enlisting types in the +// serializer may be required. +// +// Refer to the docs of the underlying implementation for more +// information. +type PeerMetadata interface { + // Get/Put is a simple registry for other peer-related key/value pairs. + // if we find something we use often, it should become its own set of + // methods. this is a last resort. + Get(p peer.ID, key string) (interface{}, error) + Put(p peer.ID, key string, val interface{}) error +} + +// AddrBook holds the multiaddrs of peers. +type AddrBook interface { + + // AddAddr calls AddAddrs(p, []ma.Multiaddr{addr}, ttl) + AddAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) + + // AddAddrs gives this AddrBook addresses to use, with a given ttl + // (time-to-live), after which the address is no longer valid. + // If the manager has a longer TTL, the operation is a no-op for that address + AddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) + + // SetAddr calls mgr.SetAddrs(p, addr, ttl) + SetAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) + + // SetAddrs sets the ttl on addresses. This clears any TTL there previously. + // This is used when we receive the best estimate of the validity of an address. + SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) + + // UpdateAddrs updates the addresses associated with the given peer that have + // the given oldTTL to have the given newTTL. + UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL time.Duration) + + // Addresses returns all known (and valid) addresses for a given peer + Addrs(p peer.ID) []ma.Multiaddr + + // AddrStream returns a channel that gets all addresses for a given + // peer sent on it. If new addresses are added after the call is made + // they will be sent along through the channel as well. + AddrStream(context.Context, peer.ID) <-chan ma.Multiaddr + + // ClearAddresses removes all previously stored addresses + ClearAddrs(p peer.ID) + + // PeersWithAddrs returns all of the peer IDs stored in the AddrBook + PeersWithAddrs() peer.IDSlice +} + +// KeyBook tracks the keys of Peers. +type KeyBook interface { + // PubKey stores the public key of a peer. + PubKey(peer.ID) ic.PubKey + + // AddPubKey stores the public key of a peer. + AddPubKey(peer.ID, ic.PubKey) error + + // PrivKey returns the private key of a peer, if known. Generally this might only be our own + // private key, see + // https://discuss.libp2p.io/t/what-is-the-purpose-of-having-map-peer-id-privatekey-in-peerstore/74. + PrivKey(peer.ID) ic.PrivKey + + // AddPrivKey stores the private key of a peer. + AddPrivKey(peer.ID, ic.PrivKey) error + + // PeersWithKeys returns all the peer IDs stored in the KeyBook + PeersWithKeys() peer.IDSlice +} + +// Metrics is just an object that tracks metrics +// across a set of peers. +type Metrics interface { + // RecordLatency records a new latency measurement + RecordLatency(peer.ID, time.Duration) + + // LatencyEWMA returns an exponentially-weighted moving avg. + // of all measurements of a peer's latency. + LatencyEWMA(peer.ID) time.Duration +} + +// ProtoBook tracks the protocols supported by peers +type ProtoBook interface { + GetProtocols(peer.ID) ([]string, error) + AddProtocols(peer.ID, ...string) error + SetProtocols(peer.ID, ...string) error + SupportsProtocols(peer.ID, ...string) ([]string, error) +} diff --git a/core/pnet/env.go b/core/pnet/env.go new file mode 100644 index 0000000000..c8db5e3cbd --- /dev/null +++ b/core/pnet/env.go @@ -0,0 +1,19 @@ +package pnet + +import "os" + +// EnvKey defines environment variable name for forcing usage of PNet in libp2p +// When environment variable of this name is set to "1" the ForcePrivateNetwork +// variable will be set to true. +const EnvKey = "LIBP2P_FORCE_PNET" + +// ForcePrivateNetwork is boolean variable that forces usage of PNet in libp2p +// Setting this variable to true or setting LIBP2P_FORCE_PNET environment variable +// to true will make libp2p to require private network protector. +// If no network protector is provided and this variable is set to true libp2p will +// refuse to connect. +var ForcePrivateNetwork = false + +func init() { + ForcePrivateNetwork = os.Getenv(EnvKey) == "1" +} diff --git a/core/pnet/error.go b/core/pnet/error.go new file mode 100644 index 0000000000..184b71d6ac --- /dev/null +++ b/core/pnet/error.go @@ -0,0 +1,34 @@ +package pnet + +// ErrNotInPrivateNetwork is an error that should be returned by libp2p when it +// tries to dial with ForcePrivateNetwork set and no PNet Protector +var ErrNotInPrivateNetwork = NewError("private network was not configured but" + + " is enforced by the environment") + +// Error is error type for ease of detecting PNet errors +type Error interface { + IsPNetError() bool +} + +// NewError creates new Error +func NewError(err string) error { + return pnetErr("privnet: " + err) +} + +// IsPNetError checks if given error is PNet Error +func IsPNetError(err error) bool { + v, ok := err.(Error) + return ok && v.IsPNetError() +} + +type pnetErr string + +var _ Error = (*pnetErr)(nil) + +func (p pnetErr) Error() string { + return string(p) +} + +func (pnetErr) IsPNetError() bool { + return true +} diff --git a/core/pnet/error_test.go b/core/pnet/error_test.go new file mode 100644 index 0000000000..e1fe462b2a --- /dev/null +++ b/core/pnet/error_test.go @@ -0,0 +1,20 @@ +package pnet + +import ( + "errors" + "testing" +) + +func TestIsPnetErr(t *testing.T) { + err := NewError("test") + + if err.Error() != "privnet: test" { + t.Fatalf("expected 'privnet: test' got '%s'", err.Error()) + } + if !IsPNetError(err) { + t.Fatal("expected the pnetErr to be detected by IsPnetError") + } + if IsPNetError(errors.New("not pnet error")) { + t.Fatal("expected generic error not to be pnetError") + } +} diff --git a/core/pnet/protector.go b/core/pnet/protector.go new file mode 100644 index 0000000000..c443ee8fda --- /dev/null +++ b/core/pnet/protector.go @@ -0,0 +1,15 @@ +// Package pnet provides interfaces for private networking in libp2p. +package pnet + +import "net" + +// Protector interface is a way for private network implementation to be transparent in +// libp2p. It is created by implementation and use by libp2p-conn to secure connections +// so they can be only established with selected number of peers. +type Protector interface { + // Wraps passed connection to protect it + Protect(net.Conn) (net.Conn, error) + + // Returns key fingerprint that is safe to expose + Fingerprint() []byte +} diff --git a/core/protocol/id.go b/core/protocol/id.go new file mode 100644 index 0000000000..f7e4a32baf --- /dev/null +++ b/core/protocol/id.go @@ -0,0 +1,9 @@ +package protocol + +// ID is an identifier used to write protocol headers in streams. +type ID string + +// These are reserved protocol.IDs. +const ( + TestingID ID = "/p2p/_testing" +) diff --git a/core/protocol/switch.go b/core/protocol/switch.go new file mode 100644 index 0000000000..f3ac369bec --- /dev/null +++ b/core/protocol/switch.go @@ -0,0 +1,81 @@ +// Package protocol provides core interfaces for protocol routing and negotiation in libp2p. +package protocol + +import ( + "io" +) + +// HandlerFunc is a user-provided function used by the Router to +// handle a protocol/stream. +// +// Will be invoked with the protocol ID string as the first argument, +// which may differ from the ID used for registration if the handler +// was registered using a match function. +type HandlerFunc = func(protocol string, rwc io.ReadWriteCloser) error + +// Router is an interface that allows users to add and remove protocol handlers, +// which will be invoked when incoming stream requests for registered protocols +// are accepted. +// +// Upon receiving an incoming stream request, the Router will check all registered +// protocol handlers to determine which (if any) is capable of handling the stream. +// The handlers are checked in order of registration; if multiple handlers are +// eligible, only the first to be registered will be invoked. +type Router interface { + + // AddHandler registers the given handler to be invoked for + // an exact literal match of the given protocol ID string. + AddHandler(protocol string, handler HandlerFunc) + + // AddHandlerWithFunc registers the given handler to be invoked + // when the provided match function returns true. + // + // The match function will be invoked with an incoming protocol + // ID string, and should return true if the handler supports + // the protocol. Note that the protocol ID argument is not + // used for matching; if you want to match the protocol ID + // string exactly, you must check for it in your match function. + AddHandlerWithFunc(protocol string, match func(string) bool, handler HandlerFunc) + + // RemoveHandler removes the registered handler (if any) for the + // given protocol ID string. + RemoveHandler(protocol string) + + // Protocols returns a list of all registered protocol ID strings. + // Note that the Router may be able to handle protocol IDs not + // included in this list if handlers were added with match functions + // using AddHandlerWithFunc. + Protocols() []string +} + +// Negotiator is a component capable of reaching agreement over what protocols +// to use for inbound streams of communication. +type Negotiator interface { + + // NegotiateLazy will return the registered protocol handler to use + // for a given inbound stream, returning as soon as the protocol has been + // determined. Returns an error if negotiation fails. + // + // NegotiateLazy may return before all protocol negotiation responses have been + // written to the stream. This is in contrast to Negotiate, which will block until + // the Negotiator is finished with the stream. + NegotiateLazy(rwc io.ReadWriteCloser) (io.ReadWriteCloser, string, HandlerFunc, error) + + // Negotiate will return the registered protocol handler to use for a given + // inbound stream, returning after the protocol has been determined and the + // Negotiator has finished using the stream for negotiation. Returns an + // error if negotiation fails. + Negotiate(rwc io.ReadWriteCloser) (string, HandlerFunc, error) + + // Handle calls Negotiate to determine which protocol handler to use for an + // inbound stream, then invokes the protocol handler function, passing it + // the protocol ID and the stream. Returns an error if negotiation fails. + Handle(rwc io.ReadWriteCloser) error +} + +// Switch is the component responsible for "dispatching" incoming stream requests to +// their corresponding stream handlers. It is both a Negotiator and a Router. +type Switch interface { + Router + Negotiator +} diff --git a/core/routing/options.go b/core/routing/options.go new file mode 100644 index 0000000000..4b235cbfc0 --- /dev/null +++ b/core/routing/options.go @@ -0,0 +1,50 @@ +package routing + +// Option is a single routing option. +type Option func(opts *Options) error + +// Options is a set of routing options +type Options struct { + // Allow expired values. + Expired bool + Offline bool + // Other (ValueStore implementation specific) options. + Other map[interface{}]interface{} +} + +// Apply applies the given options to this Options +func (opts *Options) Apply(options ...Option) error { + for _, o := range options { + if err := o(opts); err != nil { + return err + } + } + return nil +} + +// ToOption converts this Options to a single Option. +func (opts *Options) ToOption() Option { + return func(nopts *Options) error { + *nopts = *opts + if opts.Other != nil { + nopts.Other = make(map[interface{}]interface{}, len(opts.Other)) + for k, v := range opts.Other { + nopts.Other[k] = v + } + } + return nil + } +} + +// Expired is an option that tells the routing system to return expired records +// when no newer records are known. +var Expired Option = func(opts *Options) error { + opts.Expired = true + return nil +} + +// Offline is an option that tells the routing system to operate offline (i.e., rely on cached/local data only). +var Offline Option = func(opts *Options) error { + opts.Offline = true + return nil +} diff --git a/core/routing/query.go b/core/routing/query.go new file mode 100644 index 0000000000..3769d1c265 --- /dev/null +++ b/core/routing/query.go @@ -0,0 +1,86 @@ +package routing + +import ( + "context" + "sync" + + "github.com/libp2p/go-libp2p-core/peer" +) + +type QueryEventType int + +// Number of events to buffer. +var QueryEventBufferSize = 16 + +const ( + SendingQuery QueryEventType = iota + PeerResponse + FinalPeer + QueryError + Provider + Value + AddingPeer + DialingPeer +) + +type QueryEvent struct { + ID peer.ID + Type QueryEventType + Responses []*peer.AddrInfo + Extra string +} + +type routingQueryKey struct{} +type eventChannel struct { + mu sync.Mutex + ctx context.Context + ch chan<- *QueryEvent +} + +// waitThenClose is spawned in a goroutine when the channel is registered. This +// safely cleans up the channel when the context has been canceled. +func (e *eventChannel) waitThenClose() { + <-e.ctx.Done() + e.mu.Lock() + close(e.ch) + // 1. Signals that we're done. + // 2. Frees memory (in case we end up hanging on to this for a while). + e.ch = nil + e.mu.Unlock() +} + +// send sends an event on the event channel, aborting if either the passed or +// the internal context expire. +func (e *eventChannel) send(ctx context.Context, ev *QueryEvent) { + e.mu.Lock() + // Closed. + if e.ch == nil { + e.mu.Unlock() + return + } + // in case the passed context is unrelated, wait on both. + select { + case e.ch <- ev: + case <-e.ctx.Done(): + case <-ctx.Done(): + } + e.mu.Unlock() +} + +func RegisterForQueryEvents(ctx context.Context) (context.Context, <-chan *QueryEvent) { + ch := make(chan *QueryEvent, QueryEventBufferSize) + ech := &eventChannel{ch: ch, ctx: ctx} + go ech.waitThenClose() + return context.WithValue(ctx, routingQueryKey{}, ech), ch +} + +func PublishQueryEvent(ctx context.Context, ev *QueryEvent) { + ich := ctx.Value(routingQueryKey{}) + if ich == nil { + return + } + + // We *want* to panic here. + ech := ich.(*eventChannel) + ech.send(ctx, ev) +} diff --git a/core/routing/query_serde.go b/core/routing/query_serde.go new file mode 100644 index 0000000000..c7423ffe4c --- /dev/null +++ b/core/routing/query_serde.go @@ -0,0 +1,40 @@ +package routing + +import ( + "encoding/json" + + "github.com/libp2p/go-libp2p-core/peer" +) + +func (qe *QueryEvent) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "ID": peer.IDB58Encode(qe.ID), + "Type": int(qe.Type), + "Responses": qe.Responses, + "Extra": qe.Extra, + }) +} + +func (qe *QueryEvent) UnmarshalJSON(b []byte) error { + temp := struct { + ID string + Type int + Responses []*peer.AddrInfo + Extra string + }{} + err := json.Unmarshal(b, &temp) + if err != nil { + return err + } + if len(temp.ID) > 0 { + pid, err := peer.IDB58Decode(temp.ID) + if err != nil { + return err + } + qe.ID = pid + } + qe.Type = QueryEventType(temp.Type) + qe.Responses = temp.Responses + qe.Extra = temp.Extra + return nil +} diff --git a/core/routing/query_test.go b/core/routing/query_test.go new file mode 100644 index 0000000000..15b4846dbb --- /dev/null +++ b/core/routing/query_test.go @@ -0,0 +1,44 @@ +package routing + +import ( + "context" + "fmt" + "sync" + "testing" +) + +func TestEventsCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + ctx, events := RegisterForQueryEvents(ctx) + goch := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + PublishQueryEvent(ctx, &QueryEvent{Extra: fmt.Sprint(i)}) + } + close(goch) + for i := 100; i < 1000; i++ { + PublishQueryEvent(ctx, &QueryEvent{Extra: fmt.Sprint(i)}) + } + }() + go func() { + defer wg.Done() + i := 0 + for e := range events { + if i < 100 { + if e.Extra != fmt.Sprint(i) { + t.Errorf("expected %d, got %s", i, e.Extra) + } + } + i++ + } + if i < 100 { + t.Errorf("expected at least 100 events, got %d", i) + } + }() + <-goch + cancel() + wg.Wait() +} diff --git a/core/routing/routing.go b/core/routing/routing.go new file mode 100644 index 0000000000..c7078054a6 --- /dev/null +++ b/core/routing/routing.go @@ -0,0 +1,125 @@ +// Package routing provides interfaces for peer routing and content routing in libp2p. +package routing + +import ( + "context" + "errors" + + ci "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" + + cid "github.com/ipfs/go-cid" +) + +// ErrNotFound is returned when the router fails to find the requested record. +var ErrNotFound = errors.New("routing: not found") + +// ErrNotSupported is returned when the router doesn't support the given record +// type/operation. +var ErrNotSupported = errors.New("routing: operation or key not supported") + +// ContentRouting is a value provider layer of indirection. It is used to find +// information about who has what content. +// +// Content is identified by CID (content identifier), which encodes a hash +// of the identified content in a future-proof manner. +type ContentRouting interface { + // Provide adds the given cid to the content routing system. If 'true' is + // passed, it also announces it, otherwise it is just kept in the local + // accounting of which objects are being provided. + Provide(context.Context, cid.Cid, bool) error + + // Search for peers who are able to provide a given key + FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.AddrInfo +} + +// PeerRouting is a way to find address information about certain peers. +// This can be implemented by a simple lookup table, a tracking server, +// or even a DHT. +type PeerRouting interface { + // Find specific Peer + // FindPeer searches for a peer with given ID, returns a peer.AddrInfo + // with relevant addresses. + FindPeer(context.Context, peer.ID) (peer.AddrInfo, error) +} + +// ValueStore is a basic Put/Get interface. +type ValueStore interface { + + // PutValue adds value corresponding to given Key. + PutValue(context.Context, string, []byte, ...Option) error + + // GetValue searches for the value corresponding to given Key. + GetValue(context.Context, string, ...Option) ([]byte, error) + + // SearchValue searches for better and better values from this value + // store corresponding to the given Key. By default implementations must + // stop the search after a good value is found. A 'good' value is a value + // that would be returned from GetValue. + // + // Useful when you want a result *now* but still want to hear about + // better/newer results. + // + // Implementations of this methods won't return ErrNotFound. When a value + // couldn't be found, the channel will get closed without passing any results + SearchValue(context.Context, string, ...Option) (<-chan []byte, error) +} + +// Routing is the combination of different routing types supported by libp2p. +// It can be satisfied by a single item (such as a DHT) or multiple different +// pieces that are more optimized to each task. +type Routing interface { + ContentRouting + PeerRouting + ValueStore + + // Bootstrap allows callers to hint to the routing system to get into a + // Boostrapped state and remain there. It is not a synchronous call. + Bootstrap(context.Context) error + + // TODO expose io.Closer or plain-old Close error +} + +// PubKeyFetcher is an interfaces that should be implemented by value stores +// that can optimize retrieval of public keys. +// +// TODO(steb): Consider removing, see https://github.com/libp2p/go-libp2p-routing/issues/22. +type PubKeyFetcher interface { + // GetPublicKey returns the public key for the given peer. + GetPublicKey(context.Context, peer.ID) (ci.PubKey, error) +} + +// KeyForPublicKey returns the key used to retrieve public keys +// from a value store. +func KeyForPublicKey(id peer.ID) string { + return "/pk/" + string(id) +} + +// GetPublicKey retrieves the public key associated with the given peer ID from +// the value store. +// +// If the ValueStore is also a PubKeyFetcher, this method will call GetPublicKey +// (which may be better optimized) instead of GetValue. +func GetPublicKey(r ValueStore, ctx context.Context, p peer.ID) (ci.PubKey, error) { + switch k, err := p.ExtractPublicKey(); err { + case peer.ErrNoPublicKey: + // check the datastore + case nil: + return k, nil + default: + return nil, err + } + + if dht, ok := r.(PubKeyFetcher); ok { + // If we have a DHT as our routing system, use optimized fetcher + return dht.GetPublicKey(ctx, p) + } + key := KeyForPublicKey(p) + pkval, err := r.GetValue(ctx, key) + if err != nil { + return nil, err + } + + // get PublicKey from node.Data + return ci.UnmarshalPublicKey(pkval) +} diff --git a/core/sec/insecure/insecure.go b/core/sec/insecure/insecure.go new file mode 100644 index 0000000000..801569cb0b --- /dev/null +++ b/core/sec/insecure/insecure.go @@ -0,0 +1,90 @@ +// Package insecure provides an insecure, unencrypted implementation of the the SecureConn and SecureTransport interfaces. +// +// Recommended only for testing and other non-production usage. +package insecure + +import ( + "context" + "net" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/sec" + + ci "github.com/libp2p/go-libp2p-core/crypto" +) + +// ID is the multistream-select protocol ID that should be used when identifying +// this security transport. +const ID = "/plaintext/1.0.0" + +// Transport is a no-op stream security transport. It provides no +// security and simply mocks the security and identity methods to +// return peer IDs known ahead of time. +type Transport struct { + id peer.ID +} + +// New constructs a new insecure transport. +func New(id peer.ID) *Transport { + return &Transport{ + id: id, + } +} + +// LocalPeer returns the transports local peer ID. +func (t *Transport) LocalPeer() peer.ID { + return t.id +} + +// LocalPrivateKey returns nil. This transport is not secure. +func (t *Transport) LocalPrivateKey() ci.PrivKey { + return nil +} + +// SecureInbound *pretends to secure* an outbound connection to the given peer. +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) { + return &Conn{ + Conn: insecure, + local: t.id, + }, nil +} + +// SecureOutbound *pretends to secure* an outbound connection to the given peer. +func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { + return &Conn{ + Conn: insecure, + local: t.id, + remote: p, + }, nil +} + +// Conn is the connection type returned by the insecure transport. +type Conn struct { + net.Conn + local peer.ID + remote peer.ID +} + +// LocalPeer returns the local peer ID. +func (ic *Conn) LocalPeer() peer.ID { + return ic.local +} + +// RemotePeer returns the remote peer ID if we initiated the dial. Otherwise, it +// returns "" (because this connection isn't actually secure). +func (ic *Conn) RemotePeer() peer.ID { + return ic.remote +} + +// RemotePublicKey returns nil. This connection is not secure +func (ic *Conn) RemotePublicKey() ci.PubKey { + return nil +} + +// LocalPrivateKey returns nil. This connection is not secure. +func (ic *Conn) LocalPrivateKey() ci.PrivKey { + return nil +} + +var _ sec.SecureTransport = (*Transport)(nil) +var _ sec.SecureConn = (*Conn)(nil) diff --git a/core/sec/security.go b/core/sec/security.go new file mode 100644 index 0000000000..95c8e6a651 --- /dev/null +++ b/core/sec/security.go @@ -0,0 +1,26 @@ +// Package sec provides secure connection and transport interfaces for libp2p. +package sec + +import ( + "context" + "net" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" +) + +// SecureConn is an authenticated, encrypted connection. +type SecureConn interface { + net.Conn + network.ConnSecurity +} + +// A SecureTransport turns inbound and outbound unauthenticated, +// plain-text, native connections into authenticated, encrypted connections. +type SecureTransport interface { + // SecureInbound secures an inbound connection. + SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, error) + + // SecureOutbound secures an outbound connection. + SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) +} diff --git a/core/transport/transport.go b/core/transport/transport.go new file mode 100644 index 0000000000..774b8d824d --- /dev/null +++ b/core/transport/transport.go @@ -0,0 +1,113 @@ +// Package transport provides the Transport interface, which represents +// the devices and network protocols used to send and recieve data. +package transport + +import ( + "context" + "net" + "time" + + "github.com/libp2p/go-libp2p-core/mux" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + ma "github.com/multiformats/go-multiaddr" +) + +// DialTimeout is the maximum duration a Dial is allowed to take. +// This includes the time between dialing the raw network connection, +// protocol selection as well the handshake, if applicable. +var DialTimeout = 60 * time.Second + +// AcceptTimeout is the maximum duration an Accept is allowed to take. +// This includes the time between accepting the raw network connection, +// protocol selection as well as the handshake, if applicable. +var AcceptTimeout = 60 * time.Second + +// A CapableConn represents a connection that has offers the basic +// capabilities required by libp2p: stream multiplexing, encryption and +// peer authentication. +// +// These capabilities may be natively provided by the transport, or they +// may be shimmed via the "connection upgrade" process, which converts a +// "raw" network connection into one that supports such capabilities by +// layering an encryption channel and a stream multiplexer. +// +// CapableConn provides accessors for the local and remote multiaddrs used to +// establish the connection and an accessor for the underlying Transport. +type CapableConn interface { + mux.MuxedConn + network.ConnSecurity + network.ConnMultiaddrs + + // Transport returns the transport to which this connection belongs. + Transport() Transport +} + +// Transport represents any device by which you can connect to and accept +// connections from other peers. +// +// The Transport interface allows you to open connections to other peers +// by dialing them, and also lets you listen for incoming connections. +// +// Connections returned by Dial and passed into Listeners are of type +// CapableConn, which means that they have been upgraded to support +// stream multiplexing and connection security (encryption and authentication). +// +// For a conceptual overview, see https://docs.libp2p.io/concepts/transport/ +type Transport interface { + // Dial dials a remote peer. It should try to reuse local listener + // addresses if possible but it may choose not to. + Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (CapableConn, error) + + // CanDial returns true if this transport knows how to dial the given + // multiaddr. + // + // Returning true does not guarantee that dialing this multiaddr will + // succeed. This function should *only* be used to preemptively filter + // out addresses that we can't dial. + CanDial(addr ma.Multiaddr) bool + + // Listen listens on the passed multiaddr. + Listen(laddr ma.Multiaddr) (Listener, error) + + // Protocol returns the set of protocols handled by this transport. + // + // See the Network interface for an explanation of how this is used. + Protocols() []int + + // Proxy returns true if this is a proxy transport. + // + // See the Network interface for an explanation of how this is used. + // TODO: Make this a part of the go-multiaddr protocol instead? + Proxy() bool +} + +// Listener is an interface closely resembling the net.Listener interface. The +// only real difference is that Accept() returns Conn's of the type in this +// package, and also exposes a Multiaddr method as opposed to a regular Addr +// method +type Listener interface { + Accept() (CapableConn, error) + Close() error + Addr() net.Addr + Multiaddr() ma.Multiaddr +} + +// Network is an inet.Network with methods for managing transports. +type TransportNetwork interface { + network.Network + + // AddTransport adds a transport to this Network. + // + // When dialing, this Network will iterate over the protocols in the + // remote multiaddr and pick the first protocol registered with a proxy + // transport, if any. Otherwise, it'll pick the transport registered to + // handle the last protocol in the multiaddr. + // + // When listening, this Network will iterate over the protocols in the + // local multiaddr and pick the *last* protocol registered with a proxy + // transport, if any. Otherwise, it'll pick the transport registered to + // handle the last protocol in the multiaddr. + AddTransport(t Transport) error +}