diff --git a/cmd/README.md b/cmd/README.md new file mode 100644 index 0000000..d0efa12 --- /dev/null +++ b/cmd/README.md @@ -0,0 +1,6 @@ +# TLS handshake example + +Run +```bash +go run cmd/tlsdiag.go server +``` diff --git a/cmd/tlsdiag.go b/cmd/tlsdiag.go new file mode 100644 index 0000000..73aef5f --- /dev/null +++ b/cmd/tlsdiag.go @@ -0,0 +1,33 @@ +package main + +import ( + "fmt" + "os" + + "github.com/libp2p/go-libp2p-tls/cmd/cmdimpl" +) + +func main() { + if len(os.Args) <= 1 { + fmt.Println("missing argument: client / server") + return + } + + role := os.Args[1] + // remove the role argument from os.Args + os.Args = append([]string{os.Args[0]}, os.Args[2:]...) + + var err error + switch role { + case "client": + err = cmdimpl.StartClient() + case "server": + err = cmdimpl.StartServer() + default: + fmt.Println("invalid argument. Expected client / server") + return + } + if err != nil { + panic(err) + } +} diff --git a/cmd/tlsdiag/client.go b/cmd/tlsdiag/client.go new file mode 100644 index 0000000..9de3347 --- /dev/null +++ b/cmd/tlsdiag/client.go @@ -0,0 +1,62 @@ +package cmdimpl + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "net" + "time" + + peer "github.com/libp2p/go-libp2p-peer" + libp2ptls "github.com/libp2p/go-libp2p-tls" +) + +func StartClient() error { + port := flag.Int("p", 5533, "port") + peerIDString := flag.String("id", "", "peer ID") + keyType := flag.String("key", "ecdsa", "rsa, ecdsa, ed25519 or secp256k1") + flag.Parse() + + priv, err := generateKey(*keyType) + if err != nil { + return err + } + + peerID, err := peer.IDB58Decode(*peerIDString) + if err != nil { + return err + } + + id, err := peer.IDFromPrivateKey(priv) + if err != nil { + return err + } + fmt.Printf(" Peer ID: %s\n", id.Pretty()) + tp, err := libp2ptls.New(priv) + if err != nil { + return err + } + + remoteAddr := fmt.Sprintf("localhost:%d", *port) + fmt.Printf("Dialing %s\n", remoteAddr) + conn, err := net.Dial("tcp", remoteAddr) + if err != nil { + return err + } + fmt.Printf("Dialed raw connection to %s\n", conn.RemoteAddr()) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + sconn, err := tp.SecureOutbound(ctx, conn, peerID) + if err != nil { + return err + } + fmt.Printf("Authenticated server: %s\n", sconn.RemotePeer().Pretty()) + data, err := ioutil.ReadAll(sconn) + if err != nil { + return err + } + fmt.Printf("Received message from server: %s\n", string(data)) + return nil +} diff --git a/cmd/tlsdiag/key.go b/cmd/tlsdiag/key.go new file mode 100644 index 0000000..fc2ad8a --- /dev/null +++ b/cmd/tlsdiag/key.go @@ -0,0 +1,28 @@ +package cmdimpl + +import ( + "crypto/rand" + "fmt" + + ic "github.com/libp2p/go-libp2p-crypto" +) + +func generateKey(keyType string) (priv ic.PrivKey, err error) { + switch keyType { + case "rsa": + fmt.Printf("Generated new peer with an RSA key.") + priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader) + case "ecdsa": + fmt.Printf("Generated new peer with an ECDSA key.") + priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader) + case "ed25519": + fmt.Printf("Generated new peer with an Ed25519 key.") + priv, _, err = ic.GenerateEd25519Key(rand.Reader) + case "secp256k1": + fmt.Printf("Generated new peer with an Secp256k1 key.") + priv, _, err = ic.GenerateSecp256k1Key(rand.Reader) + default: + return nil, fmt.Errorf("unknown key type: %s", keyType) + } + return +} diff --git a/cmd/tlsdiag/server.go b/cmd/tlsdiag/server.go new file mode 100644 index 0000000..b192c44 --- /dev/null +++ b/cmd/tlsdiag/server.go @@ -0,0 +1,67 @@ +package cmdimpl + +import ( + "context" + "flag" + "fmt" + "net" + "time" + + peer "github.com/libp2p/go-libp2p-peer" + libp2ptls "github.com/libp2p/go-libp2p-tls" +) + +func StartServer() error { + port := flag.Int("p", 5533, "port") + keyType := flag.String("key", "ecdsa", "rsa, ecdsa, ed25519 or secp256k1") + flag.Parse() + + priv, err := generateKey(*keyType) + if err != nil { + return err + } + + id, err := peer.IDFromPrivateKey(priv) + if err != nil { + return err + } + fmt.Printf(" Peer ID: %s\n", id.Pretty()) + tp, err := libp2ptls.New(priv) + if err != nil { + return err + } + + ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *port)) + if err != nil { + return err + } + fmt.Printf("Listening for new connections on %s\n", ln.Addr()) + fmt.Printf("Now run the following command in a separate terminal:\n") + fmt.Printf("\tgo run cmd/tlsdiag.go client -p %d -id %s\n", *port, id.Pretty()) + + for { + conn, err := ln.Accept() + if err != nil { + return err + } + fmt.Printf("Accepted raw connection from %s\n", conn.RemoteAddr()) + go func() { + if err := handleConn(tp, conn); err != nil { + fmt.Printf("Error handling connection from %s: %s\n", conn.RemoteAddr(), err) + } + }() + } +} + +func handleConn(tp *libp2ptls.Transport, conn net.Conn) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + sconn, err := tp.SecureInbound(ctx, conn) + if err != nil { + return err + } + fmt.Printf("Authenticated client: %s\n", sconn.RemotePeer().Pretty()) + fmt.Fprintf(sconn, "Hello client!") + fmt.Printf("Closing connection to %s\n", conn.RemoteAddr()) + return sconn.Close() +}