Skip to content

Commit

Permalink
feat: Allow master configuration for ssh key type (#10072)
Browse files Browse the repository at this point in the history
  • Loading branch information
gt2345 authored and thiagodallacqua-hpe committed Oct 28, 2024
1 parent 6e0c7e7 commit 79ebe68
Show file tree
Hide file tree
Showing 16 changed files with 157 additions and 56 deletions.
5 changes: 5 additions & 0 deletions docs/reference/deploy/master-config-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,11 @@ Specifies configuration settings for SSH.

Number of bits to use when generating RSA keys for SSH for tasks. Maximum size is 16384.

``key_type``
============

Specifies the crypto system for SSH. Currently accepts ``RSA``, ``ECDSA`` or ``ED25519``.

``authz``
=========

Expand Down
8 changes: 8 additions & 0 deletions docs/release-notes/ssh-crypto-system.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
:orphan:

**Improvements**

- Master Configuration: Add support for crypto system configuration for ssh connection.
``security.key_type`` now accepts ``RSA``, ``ECDSA`` or ``ED25519``. Default key type is changed
from ``1024-bit RSA`` to ``ED25519``, since ``ED25519`` keys are faster and more secure than the
old default, and ``ED25519`` is also the default key type for ``ssh-keygen``.
11 changes: 0 additions & 11 deletions harness/determined/cli/shell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import contextlib
import functools
import getpass
import os
import pathlib
import platform
Expand All @@ -22,9 +21,6 @@

def start_shell(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
data = {}
if args.passphrase:
data["passphrase"] = getpass.getpass("Enter new passphrase: ")
config = ntsc.parse_config(args.config_file, None, args.config, args.volume)
workspace_id = cli.workspace.get_workspace_id_from_args(args)

Expand All @@ -35,7 +31,6 @@ def start_shell(args: argparse.Namespace) -> None:
args.template,
context_path=args.context,
includes=args.include,
data=data,
workspace_id=workspace_id,
)
shell = bindings.v1LaunchShellResponse.from_json(resp).shell
Expand Down Expand Up @@ -280,12 +275,6 @@ def _open_shell(
help=ntsc.INCLUDE_DESC,
),
cli.Arg("--config", action="append", default=[], help=ntsc.CONFIG_DESC),
cli.Arg(
"-p",
"--passphrase",
action="store_true",
help="passphrase to encrypt the shell private key",
),
cli.Arg(
"--template",
type=str,
Expand Down
16 changes: 1 addition & 15 deletions master/internal/api_shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package internal
import (
"archive/tar"
"context"
"encoding/json"
"fmt"
"strconv"

Expand Down Expand Up @@ -253,20 +252,7 @@ func (a *apiServer) LaunchShell(
}
maps.Copy(launchReq.Spec.Base.ExtraEnvVars, oidcPachydermEnvVars)

var passphrase *string
if len(req.Data) > 0 {
var data map[string]interface{}
if err = json.Unmarshal(req.Data, &data); err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse data %s: %s", req.Data, err)
}
if pwd, ok := data["passphrase"]; ok {
if typed, typedOK := pwd.(string); typedOK {
passphrase = &typed
}
}
}

keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHRsaSize, passphrase)
keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHConfig)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion master/internal/api_user_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func setupAPITest(t *testing.T, pgdb *db.PgDB,
TaskContainerDefaults: model.TaskContainerDefaultsConfig{},
ResourceConfig: *config.DefaultResourceConfig(),
},
taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024},
taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}},
allRms: map[string]rm.ResourceManager{config.DefaultClusterName: mockRM},
},
}
Expand Down
27 changes: 21 additions & 6 deletions master/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ const (
preemptionScheduler = "preemption"
)

const (
// KeyTypeRSA uses RSA.
KeyTypeRSA = "RSA"
// KeyTypeECDSA uses ECDSA.
KeyTypeECDSA = "ECDSA"
// KeyTypeED25519 uses ED25519.
KeyTypeED25519 = "ED25519"
)

type (
// ExperimentConfigPatch is the updatedble fields for patching an experiment.
ExperimentConfigPatch struct {
Expand Down Expand Up @@ -108,7 +117,7 @@ func DefaultConfig() *Config {
Group: "root",
},
SSH: SSHConfig{
RsaKeySize: 1024,
KeyType: KeyTypeED25519,
},
AuthZ: *DefaultAuthZConfig(),
},
Expand Down Expand Up @@ -452,7 +461,8 @@ type SecurityConfig struct {

// SSHConfig is the configuration setting for SSH.
type SSHConfig struct {
RsaKeySize int `json:"rsa_key_size"`
RsaKeySize int `json:"rsa_key_size"`
KeyType string `json:"key_type"`
}

// TLSConfig is the configuration for setting up serving over TLS.
Expand All @@ -475,10 +485,15 @@ func (t *TLSConfig) Validate() []error {
// Validate implements the check.Validatable interface.
func (t *SSHConfig) Validate() []error {
var errs []error
if t.RsaKeySize < 1 {
errs = append(errs, errors.New("RSA Key size must be greater than 0"))
} else if t.RsaKeySize > 16384 {
errs = append(errs, errors.New("RSA Key size must be less than 16,384"))
if t.KeyType != KeyTypeRSA && t.KeyType != KeyTypeECDSA && t.KeyType != KeyTypeED25519 {
errs = append(errs, errors.New("Crypto system must be one of 'RSA', 'ECDSA' or 'ED25519'"))
}
if t.KeyType == KeyTypeRSA {
if t.RsaKeySize < 1 {
errs = append(errs, errors.New("RSA Key size must be greater than 0"))
} else if t.RsaKeySize > 16384 {
errs = append(errs, errors.New("RSA Key size must be less than 16,384"))
}
}
return errs
}
Expand Down
2 changes: 1 addition & 1 deletion master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,7 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
HarnessPath: filepath.Join(m.config.Root, "wheels"),
TaskContainerDefaults: m.config.TaskContainerDefaults,
MasterCert: config.GetCertPEM(cert),
SSHRsaSize: m.config.Security.SSH.RsaKeySize,
SSHConfig: m.config.Security.SSH,
SegmentEnabled: m.config.Telemetry.Enabled && m.config.Telemetry.SegmentMasterKey != "",
SegmentAPIKey: m.config.Telemetry.SegmentMasterKey,
LogRetentionDays: m.config.RetentionPolicy.LogRetentionDays,
Expand Down
2 changes: 1 addition & 1 deletion master/internal/core_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestRun(t *testing.T) {
DefaultLoggingConfig: &model.DefaultLoggingConfig{},
},
},
taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024},
taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}},
}
require.NoError(t, m.config.Resolve())
m.config.DB = config.DBConfig{
Expand Down
2 changes: 1 addition & 1 deletion master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func newExperiment(

taskSpec.AgentUserGroup = agentUserGroup

generatedKeys, err := ssh.GenerateKey(taskSpec.SSHRsaSize, nil)
generatedKeys, err := ssh.GenerateKey(taskSpec.SSHConfig)
if err != nil {
return nil, nil, errors.Wrap(err, "generating ssh keys for trials")
}
Expand Down
1 change: 0 additions & 1 deletion master/internal/trial_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ func setup(t *testing.T) (
&model.Checkpoint{},
&tasks.TaskSpec{
AgentUserGroup: &model.AgentUserGroup{},
SSHRsaSize: 1024,
Workspace: model.DefaultWorkspaceName,
},
ssh.PrivateAndPublicKeys{},
Expand Down
89 changes: 74 additions & 15 deletions master/pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
package ssh

import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"

"github.com/pkg/errors"
sshlib "golang.org/x/crypto/ssh"

"github.com/determined-ai/determined/master/internal/config"
)

const (
trialPEMBlockType = "RSA PRIVATE KEY"
rsaPEMBlockType = "RSA PRIVATE KEY"
ecdsaPEMBlockType = "EC PRIVATE KEY"
)

// PrivateAndPublicKeys contains a private and public key.
Expand All @@ -21,40 +27,93 @@ type PrivateAndPublicKeys struct {
}

// GenerateKey returns a private and public SSH key.
func GenerateKey(rsaKeySize int, passphrase *string) (PrivateAndPublicKeys, error) {
func GenerateKey(conf config.SSHConfig) (PrivateAndPublicKeys, error) {
var generatedKeys PrivateAndPublicKeys
switch conf.KeyType {
case config.KeyTypeRSA:
return generateRSAKey(conf.RsaKeySize)
case config.KeyTypeECDSA:
return generateECDSAKey()
case config.KeyTypeED25519:
return generateED25519Key()
default:
return generatedKeys, errors.New("Invalid crypto system")
}
}

func generateRSAKey(rsaKeySize int) (PrivateAndPublicKeys, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to generate private key")
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA private key")
}

if err = privateKey.Validate(); err != nil {
return generatedKeys, err
return PrivateAndPublicKeys{}, err
}

block := &pem.Block{
Type: trialPEMBlockType,
Type: rsaPEMBlockType,
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}

if passphrase != nil {
// TODO: Replace usage of deprecated x509.EncryptPEMBlock.
block, err = x509.EncryptPEMBlock( //nolint: staticcheck
rand.Reader, block.Type, block.Bytes, []byte(*passphrase), x509.PEMCipherAES256)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to encrypt private key")
}
publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA public key")
}

return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}

func generateECDSAKey() (PrivateAndPublicKeys, error) {
// Curve size currently not configurable, using the NIST recommendation.
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA private key")
}

privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ECDSA private key")
}

block := &pem.Block{
Type: ecdsaPEMBlockType,
Bytes: privateKeyBytes,
}

publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to generate public key")
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA public key")
}

generatedKeys = PrivateAndPublicKeys{
return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}

func generateED25519Key() (PrivateAndPublicKeys, error) {
ed25519PublicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 private key")
}

return generatedKeys, nil
// Before OpenSSH 9.6, for ED25519 keys, only the OpenSSH private key format was supported.
block, err := sshlib.MarshalPrivateKey(privateKey, "")
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ED25519 private key")
}

publicKey, err := sshlib.NewPublicKey(ed25519PublicKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 public key")
}

return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}
39 changes: 39 additions & 0 deletions master/pkg/ssh/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ssh

import (
"testing"

"golang.org/x/crypto/ssh"
"gotest.tools/assert"

"github.com/determined-ai/determined/master/internal/config"
)

func verifyKeys(t *testing.T, keys PrivateAndPublicKeys) {
privateKey, err := ssh.ParsePrivateKey(keys.PrivateKey)
assert.NilError(t, err)

publickKey, _, _, _, err := ssh.ParseAuthorizedKey(keys.PublicKey) //nolint:dogsled
assert.NilError(t, err)
assert.Equal(t, string(publickKey.Marshal()), string(privateKey.PublicKey().Marshal()))
}

func TestSSHKeyGenerate(t *testing.T) {
t.Run("generate RSA key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeRSA, RsaKeySize: 512})
assert.NilError(t, err)
verifyKeys(t, keys)
})

t.Run("generate ECDSA key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeECDSA})
assert.NilError(t, err)
verifyKeys(t, keys)
})

t.Run("generate ED25519 key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeED25519})
assert.NilError(t, err)
verifyKeys(t, keys)
})
}
3 changes: 2 additions & 1 deletion master/pkg/tasks/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/docker/docker/api/types/mount"
"github.com/jinzhu/copier"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/pkg/archive"
"github.com/determined-ai/determined/master/pkg/cproto"
"github.com/determined-ai/determined/master/pkg/device"
Expand Down Expand Up @@ -70,7 +71,7 @@ type TaskSpec struct {
ClusterID string
HarnessPath string
MasterCert []byte
SSHRsaSize int
SSHConfig config.SSHConfig

SegmentEnabled bool
SegmentAPIKey string
Expand Down
2 changes: 1 addition & 1 deletion proto/pkg/apiv1/shell.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion proto/src/determined/api/v1/shell.proto
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ message LaunchShellRequest {
string template_name = 2;
// The files to run with the command.
repeated determined.util.v1.File files = 3;
// Additional data.
// Deprecated: Do not use.
bytes data = 4;
// Workspace ID. Defaults to 'Uncategorized' workspace if not specified.
int32 workspace_id = 5;
Expand Down
Loading

0 comments on commit 79ebe68

Please sign in to comment.