diff --git a/common/sshkeys/sshkeys.go b/common/sshkeys/sshkeys.go index 5936689f2..eb0bc605f 100644 --- a/common/sshkeys/sshkeys.go +++ b/common/sshkeys/sshkeys.go @@ -34,7 +34,7 @@ import ( ) const ( - userProvidedKey = "user_provided_key" + shoudlAskUserForSSHKey = "____m2k_ask_user_to_provide_the_ssh_private_key" ) var ( @@ -107,7 +107,7 @@ func loadSSHKeysOfCurrentUser() { // Ask whether to load private keys or provide own key options := []string{ - "Load private ssh keys from " + privateKeyDir, + fmt.Sprintf("Load the private SSH keys from the directory '%s'" + privateKeyDir), "Provide your own key", "No, I will add them later if necessary.", } @@ -115,57 +115,48 @@ func loadSSHKeysOfCurrentUser() { If any of the repos require ssh keys you will need to provide them. Select an option:` selectedOption := qaengine.FetchSelectAnswer(common.ConfigRepoLoadPrivKey, message, nil, "", options, nil) - switch selectedOption { case options[0]: - if err := loadKeysFromDirectory(privateKeyDir); err != nil { - logrus.Warn("Can't load keys from directory. Error:", err) + selectedKeyFilenames, err := loadKeysFromDirectory(privateKeyDir) + if err != nil { + logrus.Warnf("Failed to load the keys from the SSH directory '%s'. Error: %q", privateKeyDir, err) return } - + // Save the filenames for now. We will decrypt them if and when we need them. + privateKeysToConsider = selectedKeyFilenames case options[1]: - privateKeysToConsider = []string{userProvidedKey} - + privateKeysToConsider = []string{shoudlAskUserForSSHKey} default: logrus.Debug("Don't read private keys. They will be added later if necessary.") return } - } -func loadKeysFromDirectory(directory string) error { +func loadKeysFromDirectory(directory string) ([]string, error) { finfos, err := os.ReadDir(directory) if err != nil { - return fmt.Errorf("failed to read the SSH directory at path %q: %w", directory, err) + return nil, fmt.Errorf("failed to read the directory '%s'. Error: %w", directory, err) } - if len(finfos) == 0 { - logrus.Warn("No key files were found in", directory) - return nil + return nil, fmt.Errorf("no key files were found in the directory '%s'", directory) } - filenames := []string{} for _, finfo := range finfos { filenames = append(filenames, finfo.Name()) } - selectedFilenames := qaengine.FetchMultiSelectAnswer( common.ConfigRepoKeyPathsKey, - fmt.Sprintf("These are the files found in %q. Select the keys to consider:", directory), - []string{"Select all the keys that give access to git repos."}, + fmt.Sprintf("These are the files we found in the SSH directory '%s'. Select the keys to consider:", directory), + []string{"Select all the keys that give access to the git repos."}, filenames, filenames, nil, ) - if len(selectedFilenames) == 0 { logrus.Info("All key files ignored.") - return nil + return nil, nil } - - // Save the filenames for now. We will decrypt them if and when we need them. - privateKeysToConsider = selectedFilenames - return nil + return selectedFilenames, nil } func marshalRSAIntoPEM(key *rsa.PrivateKey) string { @@ -186,29 +177,22 @@ func marshalECDSAIntoPEM(key *ecdsa.PrivateKey) string { return string(PEMBytes) } -func loadSSHKey(filename string) (string, error) { - path := filepath.Join(privateKeyDir, filename) - fileBytes, err := os.ReadFile(path) - if err != nil { - logrus.Errorf("Failed to read the private key file at path %q Error: %q", path, err) - return "", err - } - key, err := ssh.ParseRawPrivateKey(fileBytes) +// loadSSHPrivateKeyFromBytes tries to parse the bytes as an SSH private key. +// The keyName is optional (used to ask the user for the password if necessary). +func loadSSHPrivateKeyFromBytes(keyBytes []byte, keyName string) (string, error) { + key, err := ssh.ParseRawPrivateKey(keyBytes) if err != nil { // Could be an encrypted private key. if _, ok := err.(*ssh.PassphraseMissingError); !ok { - logrus.Errorf("Failed to parse the private key file at path %q Error %q", path, err) - return "", err + return "", fmt.Errorf("failed to parse as a SSH private key. Error %w", err) } - - qaKey := common.JoinQASubKeys(common.ConfigRepoPrivKey, `"`+filename+`"`, "password") - desc := fmt.Sprintf("Enter the password to decrypt the private key %q : ", filename) + qaKey := common.JoinQASubKeys(common.ConfigRepoPrivKey, `"`+keyName+`"`, "password") + desc := fmt.Sprintf("Enter the password to decrypt the SSH private key '%s' : ", keyName) hints := []string{"Password:"} password := qaengine.FetchPasswordAnswer(qaKey, desc, hints, nil) - key, err = ssh.ParseRawPrivateKeyWithPassphrase(fileBytes, []byte(password)) + key, err = ssh.ParseRawPrivateKeyWithPassphrase(keyBytes, []byte(password)) if err != nil { - logrus.Errorf("Failed to parse the encrypted private key file at path %q Error %q", path, err) - return "", err + return "", fmt.Errorf("failed to decrypt and parse the encrypted private SSH key '%s' . Error %w", keyName, err) } } // *ecdsa.PrivateKey @@ -218,61 +202,70 @@ func loadSSHKey(filename string) (string, error) { case *ecdsa.PrivateKey: return marshalECDSAIntoPEM(actualKey), nil default: - logrus.Errorf("Unknown key type [%T]", key) return "", fmt.Errorf("unknown key type [%T]", key) } } +func loadSSHPrivateKey(filename string) (string, error) { + path := filepath.Join(privateKeyDir, filename) + fileBytes, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read the SSH private key file '%s' . Error: %w", path, err) + } + return loadSSHPrivateKeyFromBytes(fileBytes, filename) +} + // GetSSHKey returns the private key for the given domain. func GetSSHKey(domain string) (string, bool) { loadSSHKeysOfCurrentUser() if len(privateKeysToConsider) == 0 { return "", false } - if privateKeysToConsider[0] == userProvidedKey { - key := qaengine.FetchStringAnswer(common.ConfigRepoPrivKey, "Provide your own PEM-formatted private key:", []string{"Should not be empty"}, "", nil) + + if len(privateKeysToConsider) == 1 && privateKeysToConsider[0] == shoudlAskUserForSSHKey { + qaKey := common.JoinQASubKeys(common.ConfigRepoKeysKey, `"`+domain+`"`, "keyData") + validatedKey := "" + key := qaengine.FetchStringAnswer( + qaKey, + fmt.Sprintf("Provide a PEM-formatted SSH private key for the domain '%s':", domain), + []string{"To skip this question, just leave the answer empty"}, + "", + func(ansI interface{}) error { + ans := ansI.(string) + if ans == "" { + return nil + } + t1, err := loadSSHPrivateKeyFromBytes([]byte(ans), domain) + if err == nil { + validatedKey = t1 + } + return err + }, + ) if key == "" { - logrus.Error("User-provided private key is empty.") - return "", false - } - if err := validatePEMPrivateKey(key); err != nil { - logrus.Error("Can't validate the PEM-formatted private key. Error:", err) + logrus.Debugf("No key was provided for the domain '%s'", domain) return "", false } - return key, true + return validatedKey, true } filenames := privateKeysToConsider noAnswer := "none of the above" filenames = append(filenames, noAnswer) qaKey := common.JoinQASubKeys(common.ConfigRepoKeysKey, `"`+domain+`"`, "key") - desc := fmt.Sprintf("Select the key to use for the git domain %s :", domain) - hints := []string{fmt.Sprintf("If none of the keys are correct, select %s", noAnswer)} + desc := fmt.Sprintf("Select the key to use for the git domain '%s' :", domain) + hints := []string{fmt.Sprintf("If none of the keys are correct, select '%s'", noAnswer)} filename := qaengine.FetchSelectAnswer(qaKey, desc, hints, noAnswer, filenames, nil) if filename == noAnswer { - logrus.Debugf("No key selected for domain %s", domain) + logrus.Debugf("No key was selected for domain '%s'", domain) return "", false } logrus.Debug("Loading the key", filename) - key, err := loadSSHKey(filename) + key, err := loadSSHPrivateKey(filename) if err != nil { - logrus.Warnf("Failed to load the key %q Error %q", filename, err) + logrus.Warnf("Failed to load the SSH private key file '%s' . Error %q", filename, err) return "", false } return key, true } - -func validatePEMPrivateKey(key string) error { - block, _ := pem.Decode([]byte(key)) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return fmt.Errorf("invalid PEM private key format") - } - - _, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return err - } - - return nil -}