diff --git a/README.md b/README.md index f4b3847..656b621 100644 --- a/README.md +++ b/README.md @@ -10,29 +10,27 @@ An SSH key pair generator with password protected keys support. Supports generat ## Example ```go -filepath := filepath.Join(".ssh", "my_awesome_key") -passphrase := []byte("awesome_secret") -k, err := NewWithWrite(filepath, passphrase, key.Ed25519) +kp, err := keygen.New("awesome", keygen.WithPassphrase("awesome_secret"), + keygen.WithKeyType(keygen.Ed25519)) if err != nil { - fmt.Printf("error creating SSH key pair: %v", err) - os.Exit(1) + log.Fatalf("error creating SSH key pair: %v", err) } -fmt.Printf("Your authorized key: %s\n", string(k.PublicKey())) +fmt.Printf("Your authorized key: %s\n", kp.AuthorizedKey()) ``` ## Feedback We’d love to hear your thoughts on this project. Feel free to drop us a note! -* [Twitter](https://twitter.com/charmcli) -* [The Fediverse](https://mastodon.social/@charmcli) -* [Discord](https://charm.sh/chat) +- [Twitter](https://twitter.com/charmcli) +- [The Fediverse](https://mastodon.social/@charmcli) +- [Discord](https://charm.sh/chat) ## License [MIT](https://github.com/charmbracelet/keygen/raw/master/LICENSE) -*** +--- Part of [Charm](https://charm.sh). diff --git a/go.mod b/go.mod index 3d37736..73864a8 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.17 require ( github.com/caarlos0/sshmarshal v0.1.0 - github.com/mitchellh/go-homedir v1.1.0 golang.org/x/crypto v0.6.0 ) diff --git a/go.sum b/go.sum index e18d517..5eb889e 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/caarlos0/sshmarshal v0.1.0 h1:zTCZrDORFfWh526Tsb7vCm3+Yg/SfW/Ub8aQDeosk0I= github.com/caarlos0/sshmarshal v0.1.0/go.mod h1:7Pd/0mmq9x/JCzKauogNjSQEhivBclCQHfr9dlpDIyA= -github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/keygen.go b/keygen.go index f35f590..5603205 100644 --- a/keygen.go +++ b/keygen.go @@ -2,7 +2,6 @@ package keygen import ( - "bytes" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -16,9 +15,9 @@ import ( "os" "os/user" "path/filepath" + "strings" "github.com/caarlos0/sshmarshal" - "github.com/mitchellh/go-homedir" "golang.org/x/crypto/ssh" ) @@ -32,6 +31,11 @@ const ( ECDSA KeyType = "ecdsa" ) +// String implements the Stringer interface for KeyType. +func (k KeyType) String() string { + return string(k) +} + const rsaDefaultBits = 4096 // ErrMissingSSHKeys indicates we're missing some keys that we expected to @@ -43,7 +47,7 @@ type ErrUnsupportedKeyType struct { keyType string } -// Error implements the error interface for ErrUnsupportedKeyType +// Error implements the error interface for ErrUnsupportedKeyType. func (e ErrUnsupportedKeyType) Error() string { err := "unsupported key type" if e.keyType != "" { @@ -79,81 +83,132 @@ type SSHKeysAlreadyExistErr struct { // SSHKeyPair holds a pair of SSH keys and associated methods. type SSHKeyPair struct { path string // private key filename path; public key will have .pub appended + writeKeys bool passphrase []byte + rsaBitSize int keyType KeyType privateKey crypto.PrivateKey } func (s SSHKeyPair) privateKeyPath() string { - p := fmt.Sprintf("%s_%s", s.path, s.keyType) - return p + return s.path } func (s SSHKeyPair) publicKeyPath() string { return s.privateKeyPath() + ".pub" } +// Option is a functional option for SSHKeyPair. +type Option func(*SSHKeyPair) + +// WithPassphrase sets the passphrase for the private key. +func WithPassphrase(passphrase string) Option { + return func(s *SSHKeyPair) { + s.passphrase = []byte(passphrase) + } +} + +// WithKeyType sets the key type for the key pair. +// Available key types are RSA, Ed25519, and ECDSA. +func WithKeyType(keyType KeyType) Option { + return func(s *SSHKeyPair) { + s.keyType = keyType + } +} + +// WithBitSize sets the key size for the RSA key pair. +// This option is ignored for other key types. +func WithBitSize(bits int) Option { + return func(s *SSHKeyPair) { + s.rsaBitSize = bits + } +} + +// WithWrite writes the key pair to disk if it doesn't exist. +func WithWrite() Option { + return func(s *SSHKeyPair) { + s.writeKeys = true + } +} + // New generates an SSHKeyPair, which contains a pair of SSH keys. -func New(path string, passphrase []byte, keyType KeyType) (*SSHKeyPair, error) { +// +// If the key pair already exists, it will be loaded from disk, otherwise, a +// new SSH key pair is generated. +// If no key type is specified, Ed25519 will be used. +func New(path string, opts ...Option) (*SSHKeyPair, error) { var err error + for _, kt := range []KeyType{RSA, Ed25519, ECDSA} { + path = strings.TrimSuffix(path, "_"+kt.String()) + } + s := &SSHKeyPair{ path: path, - keyType: keyType, - passphrase: passphrase, + rsaBitSize: rsaDefaultBits, + keyType: Ed25519, + } + + for _, opt := range opts { + opt(s) } + if s.KeyPairExists() { privData, err := ioutil.ReadFile(s.privateKeyPath()) if err != nil { return nil, err } + var k interface{} - if len(passphrase) > 0 { - k, err = ssh.ParseRawPrivateKeyWithPassphrase(privData, passphrase) + if len(s.passphrase) > 0 { + k, err = ssh.ParseRawPrivateKeyWithPassphrase(privData, s.passphrase) } else { k, err = ssh.ParseRawPrivateKey(privData) } + if err != nil { return nil, err } + switch k := k.(type) { - case *rsa.PrivateKey, *ecdsa.PrivateKey, *ed25519.PrivateKey: + case *rsa.PrivateKey: + s.keyType = RSA + s.privateKey = k + case *ecdsa.PrivateKey: + s.keyType = ECDSA + s.privateKey = k + case *ed25519.PrivateKey: + s.keyType = Ed25519 s.privateKey = k default: return nil, ErrUnsupportedKeyType{fmt.Sprintf("%T", k)} } + return s, nil } - switch keyType { + + switch s.keyType { case Ed25519: err = s.generateEd25519Keys() case RSA: - err = s.generateRSAKeys(rsaDefaultBits) + err = s.generateRSAKeys(s.rsaBitSize) case ECDSA: err = s.generateECDSAKeys(elliptic.P384()) default: - return nil, ErrUnsupportedKeyType{string(keyType)} - } - if err != nil { - return nil, err + return nil, ErrUnsupportedKeyType{string(s.keyType)} } - return s, nil -} -// NewWithWrite generates an SSHKeyPair and writes it to disk if not exist. -func NewWithWrite(path string, passphrase []byte, keyType KeyType) (*SSHKeyPair, error) { - s, err := New(path, passphrase, keyType) if err != nil { return nil, err } - if !s.KeyPairExists() { - if err = s.WriteKeys(); err != nil { - return nil, err - } + + if s.writeKeys { + return s, s.WriteKeys() } + return s, nil } -// PrivateKey returns the unencrypted private key. +// PrivateKey returns the unencrypted crypto.PrivateKey. func (s *SSHKeyPair) PrivateKey() crypto.PrivateKey { switch s.keyType { case RSA, Ed25519, ECDSA: @@ -163,49 +218,103 @@ func (s *SSHKeyPair) PrivateKey() crypto.PrivateKey { } } -// PrivateKeyPEM returns the unencrypted private key in OPENSSH PEM format. -func (s *SSHKeyPair) PrivateKeyPEM() []byte { - block, err := s.pemBlock(nil) +// PublicKey returns the ssh.PublicKey for the key pair. +func (s *SSHKeyPair) PublicKey() ssh.PublicKey { + p, err := ssh.NewPublicKey(s.cryptoPublicKey()) if err != nil { return nil } - return pem.EncodeToMemory(block) + return p } -// PublicKey returns the SSH public key (RFC 4253). Ready to be used in an -// OpenSSH authorized_keys file. -func (s *SSHKeyPair) PublicKey() []byte { - var pk crypto.PublicKey - // Prepare public key +func (s *SSHKeyPair) cryptoPublicKey() crypto.PublicKey { switch s.keyType { case RSA: key, ok := s.privateKey.(*rsa.PrivateKey) if !ok { return nil } - pk = key.Public() + return key.Public() case Ed25519: key, ok := s.privateKey.(*ed25519.PrivateKey) if !ok { return nil } - pk = key.Public() + return key.Public() case ECDSA: key, ok := s.privateKey.(*ecdsa.PrivateKey) if !ok { return nil } - pk = key.Public() + return key.Public() default: return nil } - p, err := ssh.NewPublicKey(pk) +} + +// CryptoPublicKey returns the crypto.PublicKey of the SSH key pair. +func (s *SSHKeyPair) CryptoPublicKey() crypto.PublicKey { + return s.cryptoPublicKey() +} + +// RawAuthorizedKey returns the underlying SSH public key (RFC 4253) in OpenSSH +// authorized_keys format. +func (s *SSHKeyPair) RawAuthorizedKey() []byte { + bts, err := os.ReadFile(s.publicKeyPath()) + if err != nil { + return []byte(s.AuthorizedKey()) + } + + _, c, opts, _, err := ssh.ParseAuthorizedKey(bts) + if err != nil { + return []byte(s.AuthorizedKey()) + } + + ak := s.authorizedKey(s.PublicKey()) + if len(opts) > 0 { + ak = fmt.Sprintf("%s %s", strings.Join(opts, ","), ak) + } + + if c != "" { + ak = fmt.Sprintf("%s %s", ak, c) + } + + return []byte(ak) +} + +func (s *SSHKeyPair) authorizedKey(pk ssh.PublicKey) string { + if pk == nil { + return "" + } + + // serialize authorized key + return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pk))) +} + +// AuthorizedKey returns the SSH public key (RFC 4253) in OpenSSH authorized_keys +// format. The returned string is trimmed of sshd options and comments. +func (s *SSHKeyPair) AuthorizedKey() string { + return s.authorizedKey(s.PublicKey()) +} + +// RawPrivateKey returns the raw unencrypted private key bytes in PEM format. +func (s *SSHKeyPair) RawPrivateKey() []byte { + return s.rawPrivateKey(nil) +} + +// RawProtectedPrivateKey returns the raw password protected private key bytes +// in PEM format. +func (s *SSHKeyPair) RawProtectedPrivateKey() []byte { + return s.rawPrivateKey(s.passphrase) +} + +func (s *SSHKeyPair) rawPrivateKey(pass []byte) []byte { + block, err := s.pemBlock(pass) if err != nil { return nil } - // serialize public key - ak := ssh.MarshalAuthorizedKey(p) - return pubKeyWithMemo(ak) + + return pem.EncodeToMemory(block) } func (s *SSHKeyPair) pemBlock(passphrase []byte) (*pem.Block, error) { @@ -273,7 +382,7 @@ func (s *SSHKeyPair) prepFilesystem() error { keyDir := filepath.Dir(s.path) if keyDir != "" { - keyDir, err = homedir.Expand(keyDir) + keyDir, err = filepath.Abs(keyDir) if err != nil { return err } @@ -314,20 +423,11 @@ func (s *SSHKeyPair) prepFilesystem() error { // WriteKeys writes the SSH key pair to disk. func (s *SSHKeyPair) WriteKeys() error { var err error - priv := s.PrivateKeyPEM() - pub := s.PublicKey() - if priv == nil || pub == nil { + priv := s.RawProtectedPrivateKey() + if priv == nil { return ErrMissingSSHKeys } - // Encrypt private key with passphrase - if len(s.passphrase) > 0 { - block, err := s.pemBlock(s.passphrase) - if err != nil { - return err - } - priv = pem.EncodeToMemory(block) - } if err = s.prepFilesystem(); err != nil { return err } @@ -335,7 +435,12 @@ func (s *SSHKeyPair) WriteKeys() error { if err := writeKeyToFile(priv, s.privateKeyPath()); err != nil { return err } - if err := writeKeyToFile(pub, s.publicKeyPath()); err != nil { + + ak := s.AuthorizedKey() + if memo := pubKeyMemo(); memo != "" { + ak = fmt.Sprintf("%s %s", ak, memo) + } + if err := writeKeyToFile([]byte(ak), s.publicKeyPath()); err != nil { return err } @@ -367,17 +472,17 @@ func fileExists(path string) bool { // attaches a user@host suffix to a serialized public key. returns the original // pubkey if we can't get the username or host. -func pubKeyWithMemo(pubKey []byte) []byte { +func pubKeyMemo() string { u, err := user.Current() if err != nil { - return pubKey + return "" } hostname, err := os.Hostname() if err != nil { - return pubKey + return "" } - return append(bytes.TrimRight(pubKey, "\n"), []byte(fmt.Sprintf(" %s@%s\n", u.Username, hostname))...) + return fmt.Sprintf("%s@%s\n", u.Username, hostname) } // Error returns the a human-readable error message for SSHKeysAlreadyExistErr. diff --git a/keygen_test.go b/keygen_test.go index c5c5b35..5762cfe 100644 --- a/keygen_test.go +++ b/keygen_test.go @@ -1,8 +1,8 @@ package keygen import ( + "bytes" "crypto/elliptic" - "fmt" "io/ioutil" "os" "path/filepath" @@ -10,11 +10,13 @@ import ( ) func TestNewSSHKeyPair(t *testing.T) { - p := filepath.Join(t.TempDir(), "test") - _, err := NewWithWrite(p, []byte(""), RSA) + kp, err := New("") if err != nil { t.Errorf("error creating SSH key pair: %v", err) } + if kp.keyType != Ed25519 { + t.Errorf("expected default key type to be Ed25519, got %s", kp.keyType) + } } func TestGenerateEd25519Keys(t *testing.T) { @@ -35,11 +37,11 @@ func TestGenerateEd25519Keys(t *testing.T) { // TODO: is there a good way to validate these? Lengths seem to vary a bit, // so far now we're just asserting that the keys indeed exist. - if len(k.PrivateKeyPEM()) == 0 { + if len(k.RawPrivateKey()) == 0 { t.Error("error creating SSH private key PEM; key is 0 bytes") } - if len(k.PublicKey()) == 0 { - t.Error("error creating SSH public key; key is 0 bytes") + if len(k.AuthorizedKey()) == 0 { + t.Error("error creating SSH authorized key; key is 0 bytes") } }) @@ -101,10 +103,10 @@ func TestGenerateECDSAKeys(t *testing.T) { // TODO: is there a good way to validate these? Lengths seem to vary a bit, // so far now we're just asserting that the keys indeed exist. - if len(k.PrivateKeyPEM()) == 0 { + if len(k.RawPrivateKey()) == 0 { t.Error("error creating SSH private key PEM; key is 0 bytes") } - if len(k.PublicKey()) == 0 { + if len(k.AuthorizedKey()) == 0 { t.Error("error creating SSH public key; key is 0 bytes") } }) @@ -174,11 +176,11 @@ func createEmptyFile(t *testing.T, path string) (ok bool) { func TestGeneratePublicKeyWithEmptyDir(t *testing.T) { for _, keyType := range []KeyType{RSA, ECDSA, Ed25519} { func(t *testing.T) { - k, err := NewWithWrite("testkey", nil, keyType) + fn := "testkey" + k, err := New(fn, WithKeyType(keyType), WithWrite()) if err != nil { t.Fatalf("error creating SSH key pair: %v", err) } - fn := fmt.Sprintf("testkey_%s", keyType) f, err := os.Open(fn + ".pub") if err != nil { t.Fatalf("error opening SSH key file: %v", err) @@ -190,7 +192,7 @@ func TestGeneratePublicKeyWithEmptyDir(t *testing.T) { } defer os.Remove(fn) defer os.Remove(fn + ".pub") - if string(k.PublicKey()) != string(fc) { + if bytes.Equal(k.RawAuthorizedKey(), fc) { t.Errorf("error key mismatch\nprivate key:\n%s\n\nactual file:\n%s", k.PrivateKey(), string(fc)) } }(t) @@ -201,11 +203,11 @@ func TestGenerateKeyWithPassphrase(t *testing.T) { for _, keyType := range []KeyType{RSA, ECDSA, Ed25519} { ph := "testpass" func(t *testing.T) { - _, err := NewWithWrite("testph", []byte(ph), keyType) + _, err := New("testph", WithKeyType(keyType), WithPassphrase(ph), WithWrite()) if err != nil { t.Fatalf("error creating SSH key pair: %v", err) } - fn := fmt.Sprintf("testph_%s", keyType) + fn := "testph" f, err := os.Open(fn) if err != nil { t.Fatalf("error opening SSH key file: %v", err) @@ -217,11 +219,11 @@ func TestGenerateKeyWithPassphrase(t *testing.T) { } defer os.Remove(fn) defer os.Remove(fn + ".pub") - k, err := New("testph", []byte(ph), keyType) + k, err := New("testph", WithKeyType(keyType), WithPassphrase(ph)) if err != nil { t.Fatalf("error reading SSH key pair: %v", err) } - if string(k.PrivateKeyPEM()) == string(fc) { + if bytes.Equal(k.RawPrivateKey(), fc) { t.Errorf("encrypted private key matches file contents") } }(t) @@ -231,7 +233,7 @@ func TestGenerateKeyWithPassphrase(t *testing.T) { func TestReadingKeyWithPassphrase(t *testing.T) { for _, keyType := range []KeyType{RSA, ECDSA, Ed25519} { kp := filepath.Join("testdata", "test") - _, err := New(kp, []byte("test"), keyType) + _, err := New(kp, WithKeyType(keyType), WithPassphrase("test")) if err != nil { t.Fatalf("error reading SSH key pair: %v", err) }