diff --git a/go.mod b/go.mod index 12c1cc2..d0ccd97 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/titan-data/ssh-remote-go require ( github.com/stretchr/testify v1.4.0 - github.com/titan-data/remote-sdk-go v0.1.0 + github.com/titan-data/remote-sdk-go v0.2.1 golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 ) diff --git a/go.sum b/go.sum index 3171489..43867fd 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/hashicorp/go-hclog v0.0.0-20180709165350-ff2cf002a8dd/go.mod h1:9bjs9uLqI8l75knNv3lV1kA55veR+WUPSiKIWcQHudI= github.com/hashicorp/go-hclog v0.10.1 h1:uyt/l0dWjJ879yiAu+T7FG3/6QX+zwm4bQ8P7XsYt3o= github.com/hashicorp/go-hclog v0.10.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= +github.com/hashicorp/go-hclog v0.11.0 h1:zf3QG3ap4KOMHzDLxBvq9ZtEFVSxQzVdH1ccl5NK2tU= +github.com/hashicorp/go-hclog v0.11.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-plugin v1.0.1 h1:4OtAfUGbnKC6yS48p0CtMX2oFYtzFZVv6rok3cRWgnE= github.com/hashicorp/go-plugin v1.0.1/go.mod h1:++UyYGoz3o5w9ZzAdZxtQKrWWP+iqPBn3cQptSMzBuY= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= @@ -37,6 +39,7 @@ github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -48,6 +51,8 @@ github.com/titan-data/remote-sdk-go v0.0.3 h1:kGLc7JP7znTcXyl3gRML2jEST99ebdhz0C github.com/titan-data/remote-sdk-go v0.0.3/go.mod h1:b4McaOFiLYWv2/wCQW/sE2BcCSNz/Ae6iKKEtqw703w= github.com/titan-data/remote-sdk-go v0.1.0 h1:Xab4sduqyGdo04eJ5mQ2lIfAYZDjHJ4AzlK4oJ5EqqE= github.com/titan-data/remote-sdk-go v0.1.0/go.mod h1:u7w3Olu3EtjoFcHzxUOzelxedey4q2nveuKLvgRO7tg= +github.com/titan-data/remote-sdk-go v0.2.1 h1:Aa1CWqSbPIIvuacy27nMcbmLF2eymLIJs8m+yW8ki8E= +github.com/titan-data/remote-sdk-go v0.2.1/go.mod h1:IYCrMWL1hFGoYusrTHgseG/bLVuY72LpQSSdGOrcHb4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/ssh/mock_test.go b/ssh/mock_test.go new file mode 100644 index 0000000..322954e --- /dev/null +++ b/ssh/mock_test.go @@ -0,0 +1,64 @@ +/* + * Copyright The Titan Project Contributors. + */ +package ssh + +import ( + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" + "net" +) + +type MockConn struct { + mock.Mock +} + +func (m *MockConn) User() string { + args := m.Called() + return args.String(0) +} + +func (m *MockConn) SessionID() []byte { + args := m.Called() + return args.Get(0).([]byte) +} + +func (m *MockConn) ClientVersion() []byte { + args := m.Called() + return args.Get(0).([]byte) +} + +func (m *MockConn) ServerVersion() []byte { + args := m.Called() + return args.Get(0).([]byte) +} + +func (m *MockConn) RemoteAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +func (m *MockConn) LocalAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +func (m *MockConn) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + args := m.Called(name, wantReply, payload) + return args.Bool(0), args.Get(1).([]byte), args.Error(2) +} + +func (m MockConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(name, data) + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +func (m MockConn) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m MockConn) Wait() error { + args := m.Called() + return args.Error(0) +} diff --git a/ssh/ssh.go b/ssh/ssh.go index d87061d..99270e6 100644 --- a/ssh/ssh.go +++ b/ssh/ssh.go @@ -4,9 +4,12 @@ package ssh import ( + "bufio" + "encoding/json" "errors" "fmt" "github.com/titan-data/remote-sdk-go/remote" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" "io/ioutil" "net/url" @@ -83,22 +86,35 @@ func (s sshRemote) FromURL(rawUrl string, additionalProperties map[string]string return result, nil } +func getPort(port interface{}) (int, error) { + portval := 0 + if p, ok := port.(int); ok { + portval = p + } + if p, ok := port.(float32); ok { + portval = int(p) + } + if p, ok := port.(float64); ok { + portval = int(p) + } + if portval <= 0 || portval > 65535 { + return 0, errors.New("invalid port") + } + return portval, nil +} + func (s sshRemote) ToURL(properties map[string]interface{}) (string, map[string]string, error) { u := fmt.Sprintf("ssh://%s", properties["username"]) if properties["password"] != nil { u += ":*****" } u += fmt.Sprintf("@%s", properties["address"]) - if properties["port"] != nil { - var port = 0 - if flt, ok := properties["port"].(float32); ok { - port = int(flt) - } else if flt, ok := properties["port"].(float64); ok { - port = int(flt) - } else { - port = properties["port"].(int) + if port, ok := properties["port"]; ok { + portval, err := getPort(port) + if err != nil { + return "", nil, err } - u += fmt.Sprintf(":%d", port) + u += fmt.Sprintf(":%d", portval) } if properties["path"].(string)[0:1] != "/" { u += "/~/" @@ -139,6 +155,139 @@ func (s sshRemote) GetParameters(remoteProperties map[string]interface{}) (map[s return result, nil } +func (s sshRemote) ValidateRemote(properties map[string]interface{}) error { + err := remote.ValidateFields(properties, []string{"username", "address", "path"}, []string{"password", "port", "keyFile"}) + if err != nil { + return err + } + if port, ok := properties["port"]; ok { + _, err := getPort(port) + return err + } + return nil +} + +func (s sshRemote) ValidateParameters(parameters map[string]interface{}) error { + return remote.ValidateFields(parameters, []string{}, []string{"password", "key"}) +} + +/* + * This method will parse the remote configuration and parameters to determine if we should use password + * authentication or key-based authentication. It returns a pair where exactly one element must be set, either + * the first (password) or second (key). + */ +func getAuth(properties map[string]interface{}, parameters map[string]interface{}) (string, string, error) { + paramsPassword, paramsPasswordOk := parameters["password"] + paramsKey, paramsKeyOk := parameters["key"] + remotePassword, remotePasswordOk := properties["password"] + if paramsPasswordOk && paramsKeyOk { + return "", "", errors.New("only one of password or key can be specified") + } + if paramsKeyOk { + return "", paramsKey.(string), nil + } + if paramsPasswordOk { + return paramsPassword.(string), "", nil + } + if remotePasswordOk { + return remotePassword.(string), "", nil + } + return "", "", errors.New("one of password or key must be specified") +} + +var dial = ssh.Dial + +func getConnection(properties map[string]interface{}, parameters map[string]interface{}) (*ssh.Client, error) { + password, key, err := getAuth(properties, parameters) + if err != nil { + return nil, err + } + config := &ssh.ClientConfig{ + User: properties["username"].(string), + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + if key != "" { + parsed, err := ssh.ParsePrivateKey([]byte(key)) + if err != nil { + return nil, err + } + config.Auth = []ssh.AuthMethod{ssh.PublicKeys(parsed)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(password)} + } + + return dial("tcp", properties["address"].(string), config) +} + +func runCommand(conn *ssh.Client, command string) ([]byte, error) { + sess, err := conn.NewSession() + if err != nil { + return nil, err + } + defer sess.Close() + + output, err := sess.CombinedOutput(command) + if err != nil { + return nil, fmt.Errorf("failed to execute '%s': %w\n%s", command, err, string(output)) + } + return output, nil +} + +var run = runCommand + +func readCommit(conn *ssh.Client, properties map[string]interface{}, commitId string) (*remote.Commit, error) { + output, err := run(conn, fmt.Sprintf("cat \"%s/%s/metadata.json\"", properties["path"], commitId)) + if err != nil { + return nil, err + } + + commit := map[string]interface{}{} + err = json.Unmarshal(output, &commit) + if err != nil { + return nil, err + } + + return &remote.Commit{Id: commitId, Properties: commit}, nil +} + +func (s sshRemote) ListCommits(properties map[string]interface{}, parameters map[string]interface{}, tags []remote.Tag) ([]remote.Commit, error) { + conn, err := getConnection(properties, parameters) + if err != nil { + return nil, err + } + defer conn.Close() + + output, err := run(conn, fmt.Sprintf("ls -1 \"%s\"", properties["path"])) + if err != nil { + return nil, err + } + + var ret []remote.Commit + scanner := bufio.NewScanner(strings.NewReader(string(output))) + for scanner.Scan() { + commitId := strings.TrimSpace(scanner.Text()) + commit, err := readCommit(conn, properties, commitId) + if err == nil && remote.MatchTags(commit.Properties, tags) { + ret = append(ret, remote.Commit{Id: commit.Id, Properties: commit.Properties}) + } + } + + remote.SortCommits(ret) + + return ret, nil +} + +func (s sshRemote) GetCommit(properties map[string]interface{}, parameters map[string]interface{}, commitId string) (*remote.Commit, error) { + conn, err := getConnection(properties, parameters) + if err != nil { + return nil, err + } + defer conn.Close() + + return readCommit(conn, properties, commitId) +} + func init() { remote.Register(sshRemote{}) } diff --git a/ssh/ssh_test.go b/ssh/ssh_test.go index d0774f4..4dd8c5c 100644 --- a/ssh/ssh_test.go +++ b/ssh/ssh_test.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "github.com/titan-data/remote-sdk-go/remote" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" "io/ioutil" "os" @@ -17,201 +18,247 @@ import ( func TestRegistered(t *testing.T) { r := remote.Get("ssh") - ret, _ := r.Type() - assert.Equal(t, "ssh", ret) + ret, err := r.Type() + if assert.NoError(t, err) { + assert.Equal(t, "ssh", ret) + } } func TestFromURL(t *testing.T) { r := remote.Get("ssh") - props, _ := r.FromURL("ssh://user:pass@host:8022/path", map[string]string{}) - assert.Equal(t, "user", props["username"]) - assert.Equal(t, "pass", props["password"]) - assert.Equal(t, "host", props["address"]) - assert.Equal(t, 8022, props["port"]) - assert.Equal(t, "/path", props["path"]) - assert.Nil(t, props["keyFile"]) + props, err := r.FromURL("ssh://user:pass@host:8022/path", map[string]string{}) + if assert.NoError(t, err) { + assert.Equal(t, "user", props["username"]) + assert.Equal(t, "pass", props["password"]) + assert.Equal(t, "host", props["address"]) + assert.Equal(t, 8022, props["port"]) + assert.Equal(t, "/path", props["path"]) + assert.Nil(t, props["keyFile"]) + } } func TestSimple(t *testing.T) { r := remote.Get("ssh") - props, _ := r.FromURL("ssh://user@host/path", map[string]string{}) - assert.Equal(t, "user", props["username"]) - assert.Nil(t, props["password"]) - assert.Equal(t, "host", props["address"]) - assert.Nil(t, props["port"]) - assert.Equal(t, "/path", props["path"]) - assert.Nil(t, props["keyFile"]) + props, err := r.FromURL("ssh://user@host/path", map[string]string{}) + if assert.NoError(t, err) { + assert.Equal(t, "user", props["username"]) + assert.Nil(t, props["password"]) + assert.Equal(t, "host", props["address"]) + assert.Nil(t, props["port"]) + assert.Equal(t, "/path", props["path"]) + assert.Nil(t, props["keyFile"]) + } } func TestKeyFile(t *testing.T) { r := remote.Get("ssh") - props, _ := r.FromURL("ssh://user@host/path", map[string]string{"keyFile": "~/.ssh/id_dsa"}) - assert.Equal(t, "~/.ssh/id_dsa", props["keyFile"]) + props, err := r.FromURL("ssh://user@host/path", map[string]string{"keyFile": "~/.ssh/id_dsa"}) + if assert.NoError(t, err) { + assert.Equal(t, "~/.ssh/id_dsa", props["keyFile"]) + } } func TestRelativePath(t *testing.T) { r := remote.Get("ssh") - props, _ := r.FromURL("ssh://user@host/~/relative/path", map[string]string{}) - assert.Equal(t, "relative/path", props["path"]) + props, err := r.FromURL("ssh://user@host/~/relative/path", map[string]string{}) + if assert.NoError(t, err) { + assert.Equal(t, "relative/path", props["path"]) + } +} + +func TestBadUrl(t *testing.T) { + r := remote.Get("ssh") + _, err := r.FromURL("ssh://host\nname", map[string]string{}) + assert.Error(t, err) } func TestBadScheme(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("foo://user:pass@host:8022/path", map[string]string{}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadPasswordAndKeyFile(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh://user:password@host/path", map[string]string{"keyFile": "~/.ssh/id_dsa"}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadProperty(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh://user@host/path", map[string]string{"foo": "bar"}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadMissingHost(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh:///path", map[string]string{}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadSchemeOnly(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh", map[string]string{}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadMissingUsername(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh://host/path", map[string]string{}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadPort(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh://user@host:29348529384572398457932847539/path", map[string]string{}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadMissingPath(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh://user@host", map[string]string{}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestBadMissingHostWithUser(t *testing.T) { r := remote.Get("ssh") _, err := r.FromURL("ssh://user@/path", map[string]string{}) - assert.NotNil(t, err) + assert.Error(t, err) } func TestToURL(t *testing.T) { r := remote.Get("ssh") - u, props, _ := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + u, props, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", "path": "/path"}) - assert.Equal(t, "ssh://username@host/path", u) - assert.Empty(t, props) + if assert.NoError(t, err) { + assert.Equal(t, "ssh://username@host/path", u) + assert.Empty(t, props) + } } func TestToPassword(t *testing.T) { r := remote.Get("ssh") - u, props, _ := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + u, props, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", "path": "/path", "password": "pass"}) - assert.Equal(t, "ssh://username:*****@host/path", u) - assert.Empty(t, props) + if assert.NoError(t, err) { + assert.Equal(t, "ssh://username:*****@host/path", u) + assert.Empty(t, props) + } } func TestToPort(t *testing.T) { r := remote.Get("ssh") - u, props, _ := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + u, props, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", "path": "/path", "port": 812}) - assert.Equal(t, "ssh://username@host:812/path", u) - assert.Empty(t, props) + if assert.NoError(t, err) { + assert.Equal(t, "ssh://username@host:812/path", u) + assert.Empty(t, props) + } +} + +func TestToBadPort(t *testing.T) { + r := remote.Get("ssh") + _, _, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + "path": "/path", "port": "812"}) + assert.Error(t, err) } func TestToRelativePath(t *testing.T) { r := remote.Get("ssh") - u, props, _ := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + u, props, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", "path": "path"}) - assert.Equal(t, "ssh://username@host/~/path", u) - assert.Empty(t, props) + if assert.NoError(t, err) { + assert.Equal(t, "ssh://username@host/~/path", u) + assert.Empty(t, props) + } } func TestToKeyFile(t *testing.T) { r := remote.Get("ssh") - u, props, _ := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + u, props, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", "path": "/path", "keyFile": "keyfile"}) - assert.Equal(t, "ssh://username@host/path", u) - assert.Len(t, props, 1) - assert.Equal(t, "keyfile", props["keyFile"]) + if assert.NoError(t, err) { + assert.Equal(t, "ssh://username@host/path", u) + assert.Len(t, props, 1) + assert.Equal(t, "keyfile", props["keyFile"]) + } } func TestToPortFloat(t *testing.T) { p := float32(812) r := remote.Get("ssh") - u, props, _ := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + u, props, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", "path": "/path", "port": p}) - assert.Equal(t, "ssh://username@host:812/path", u) - assert.Empty(t, props) + if assert.NoError(t, err) { + assert.Equal(t, "ssh://username@host:812/path", u) + assert.Empty(t, props) + } } func TestToPortDouble(t *testing.T) { r := remote.Get("ssh") - u, props, _ := r.ToURL(map[string]interface{}{"username": "username", "address": "host", + u, props, err := r.ToURL(map[string]interface{}{"username": "username", "address": "host", "path": "/path", "port": 812.0}) - assert.Equal(t, "ssh://username@host:812/path", u) - assert.Empty(t, props) + if assert.NoError(t, err) { + assert.Equal(t, "ssh://username@host:812/path", u) + assert.Empty(t, props) + } } func TestGetParameters(t *testing.T) { r := remote.Get("ssh") - props, _ := r.GetParameters(map[string]interface{}{"username": "username", "address": "host", + props, err := r.GetParameters(map[string]interface{}{"username": "username", "address": "host", "path": "/path", "password": "pass"}) - assert.Empty(t, props) + if assert.NoError(t, err) { + assert.Empty(t, props) + } } func TestKeyFileParameters(t *testing.T) { r := remote.Get("ssh") file, err := ioutil.TempFile("", "ssh.test") - if err != nil { - t.Fatal(err) + if !assert.NoError(t, err) { + return } defer os.Remove(file.Name()) path, err := filepath.Abs(file.Name()) - if err != nil { - t.Fatal(err) + if !assert.NoError(t, err) { + return } err = ioutil.WriteFile(path, []byte("KEY"), 0600) - if err != nil { - t.Fatal(err) + if !assert.NoError(t, err) { + return } - props, _ := r.GetParameters(map[string]interface{}{"username": "username", "address": "host", + props, err := r.GetParameters(map[string]interface{}{"username": "username", "address": "host", "path": "/path", "keyFile": path}) - assert.Nil(t, props["password"]) - assert.Equal(t, "KEY", props["key"]) + if assert.NoError(t, err) { + assert.Nil(t, props["password"]) + assert.Equal(t, "KEY", props["key"]) + } } func TestBadKeyFileParameters(t *testing.T) { r := remote.Get("ssh") file, err := ioutil.TempFile("", "ssh.test") - if err != nil { - t.Fatal(err) + if !assert.NoError(t, err) { + return } path, err := filepath.Abs(file.Name()) - if err != nil { - t.Fatal(err) + if !assert.NoError(t, err) { + return + } + err = file.Close() + if !assert.NoError(t, err) { + return + } + err = os.Remove(path) + if assert.NoError(t, err) { + _, err = r.GetParameters(map[string]interface{}{"username": "username", "address": "host", + "path": "/path", "keyFile": path}) + assert.Error(t, err) } - os.Remove(path) - - _, err = r.GetParameters(map[string]interface{}{"username": "username", "address": "host", - "path": "/path", "keyFile": path}) - assert.NotNil(t, err) } func TestPasswordPrompt(t *testing.T) { @@ -222,13 +269,15 @@ func TestPasswordPrompt(t *testing.T) { fmtPrintf = func(format string, a ...interface{}) (n int, err error) { return 0, nil } - props, _ := r.GetParameters(map[string]interface{}{"username": "username", "address": "host", + props, err := r.GetParameters(map[string]interface{}{"username": "username", "address": "host", "path": "/path"}) - readPassword = terminal.ReadPassword - fmtPrintf = fmt.Printf + if assert.NoError(t, err) { + readPassword = terminal.ReadPassword + fmtPrintf = fmt.Printf - assert.Nil(t, props["key"]) - assert.Equal(t, "pass", props["password"]) + assert.Nil(t, props["key"]) + assert.Equal(t, "pass", props["password"]) + } } func TestBadPasswordPrompt(t *testing.T) { @@ -244,5 +293,353 @@ func TestBadPasswordPrompt(t *testing.T) { readPassword = terminal.ReadPassword fmtPrintf = fmt.Printf - assert.NotNil(t, err) + assert.Error(t, err) +} + +func TestValidateRemoteRequiredOnly(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host", "path": "/path"}) + assert.NoError(t, err) +} + +func TestValidateRemoteAllOptional(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host", "path": "/path", + "keyFile": "/keyfile", "password": "password", "port": 8022}) + assert.NoError(t, err) +} + +func TestValidateRemoteBadPort(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host", "path": "/path", + "keyFile": "/keyfile", "password": "password", "port": "foo"}) + assert.Error(t, err) +} + +func TestValidateRemoteBadPortNegative(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host", "path": "/path", + "keyFile": "/keyfile", "password": "password", "port": -1}) + assert.Error(t, err) +} + +func TestValidateRemotePortFloat(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host", "path": "/path", + "keyFile": "/keyfile", "password": "password", "port": 22.0}) + assert.NoError(t, err) +} + +func TestValidateRemotePortFloat32(t *testing.T) { + r := remote.Get("ssh") + var p float32 = 22.0 + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host", "path": "/path", + "keyFile": "/keyfile", "password": "password", "port": p}) + assert.NoError(t, err) +} + +func TestValidateRemoteMissingRequired(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host"}) + assert.Error(t, err) +} + +func TestValidateRemoteExtraProperty(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateRemote(map[string]interface{}{"username": "username", "address": "host", "path": "/path", + "foo": "bar"}) + assert.Error(t, err) +} + +func TestValidateParametersEmpty(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateParameters(map[string]interface{}{}) + assert.NoError(t, err) +} + +func TestValidateParametersAllOptional(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateParameters(map[string]interface{}{"key": "key", "password": "password"}) + assert.NoError(t, err) +} + +func TestValidateParametersUnknown(t *testing.T) { + r := remote.Get("ssh") + err := r.ValidateParameters(map[string]interface{}{"foo": "bar"}) + assert.Error(t, err) +} + +func TestGetAuthBoth(t *testing.T) { + _, _, err := getAuth(map[string]interface{}{"password": "password"}, map[string]interface{}{"password": "password", + "key": "key"}) + assert.Error(t, err) +} + +func TestGetAuthKey(t *testing.T) { + pass, key, err := getAuth(map[string]interface{}{"password": "password"}, map[string]interface{}{"key": "key"}) + assert.NoError(t, err) + assert.Empty(t, pass) + assert.NotEmpty(t, key) +} + +func TestGetAuthParamPassword(t *testing.T) { + pass, key, err := getAuth(map[string]interface{}{"password": "one"}, map[string]interface{}{"password": "two"}) + assert.NoError(t, err) + assert.Equal(t, "two", pass) + assert.Empty(t, key) +} + +func TestGetAuthRemotePassword(t *testing.T) { + pass, key, err := getAuth(map[string]interface{}{"password": "one"}, map[string]interface{}{}) + assert.NoError(t, err) + assert.Equal(t, "one", pass) + assert.Empty(t, key) +} + +func TestGetAuthMissing(t *testing.T) { + _, _, err := getAuth(map[string]interface{}{}, map[string]interface{}{}) + assert.Error(t, err) +} + +func TestGetConnBadAuth(t *testing.T) { + dial = func(network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + return nil, nil + } + _, err := getConnection(map[string]interface{}{}, map[string]interface{}{}) + dial = ssh.Dial + assert.Error(t, err) +} + +func TestGetConnPassword(t *testing.T) { + host := "" + var config *ssh.ClientConfig = nil + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + host = addr + config = cfg + return nil, nil + } + _, err := getConnection(map[string]interface{}{"username": "username", "address": "address"}, + map[string]interface{}{"password": "password"}) + if assert.NoError(t, err) { + assert.Equal(t, "address", host) + assert.Equal(t, "username", config.User) + } + dial = ssh.Dial +} + +func TestGetConnKey(t *testing.T) { + key := ` +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAsXU8SiL4eLBupbLEF9XAy+60Dr5+TPSUm8c27WCUfYOF5Yly +DWZcTS86coEGjgfqDFM6o3wXgadugt/XYi7M2k0QVsmX1577088/SixrNnX8HQyX +f3S4tLGDLX/d48A2Xi6FmJUpHqyPzKzVU1THQPOKoxZUV4qZmbRrR0FO8WmZMQTl +KNNopq4fvEPZw0oNPONS8e28zvCu0qqka06+mB5pIc5+OhXoQK4xPgPr/gW5Cruv +R5IgBt4gdLMSpBp2JB3hFj6U0c+7wmGaZYt5R92/b8tetn/jMhIt7720mJJPfq1d +1W1UpERZjUTMvFzNBLdCtgT59qxqL+Tv4QA6AQIDAQABAoIBAE7EvQgjUaswlUyT +dxslVDixMddBkwpRng0vdiATuJWl5a8nPSrZfqr8BbOBtgkhVjA2WVbr4/s2+IS7 +Gv2HzIIxpsj/HpklBp7T5UHlSYmZAVlbl3uJsdry2Ek/8pv/W6Kef8pkmyX0brfp +F5+vh+o6sBUH+lQJP3jMbrnoMURSX9jFSPg/+J1zb1Nf/SulBro4+Pb4t+i97FUk +mWqMI1jvCkAnQJ0oYQ9CeYBJjvXeENyN7HQ+RM6OEdsHi64EMfJCwrZGMAHIo2Ty +87AQhgoHEKfNC+XotnkPaKmS5qaP2ggPe2Ol63k3FbR6VHlqJny0VR48pbzyQyr2 +feENXWECgYEA2naSDRVCwiAdZAIvMa0cDjpwOYIfLJelc0hljCOaLQeye2oT+hAQ +pCVO7+maD4VbZ1Xmc70LGFSWktVlByJU9UOBY5rq7DTgXoEMoOtf2uUcJupnLLix +we7Fn9TFaM3RWKbbg3G0OjucepB7yVZ2qSVDGPQ8Bl/IKq2hfKsqy3UCgYEAz/L9 +PU0gzxmZlF1rea1d3clNoounW/J1qHXl2nT11RaIzPhct1fKde/wIt/D7gqwI3ba +wBJhNv/a4kDvnJwV3iyEKFs7qqeqaZ1KsLCkaQ0erdhl8LzfE25MWKlTthqjY9yT +f8ohD0r57y8NVInwfXhBKIUZr3qXBA+d0krfft0CgYACgbnLTKMndxbfPucrusDH +qQQApO2WpWbQm9QOd5odSilSITV5eRW3zHXLavLJms4hsWqjiVfHP7E6nhg6rLos +1kl1yyFG9JRegTyT3B+Nc3OPPsFQUg44G3VJEDfzq+jrC38ZUwSuZmC1R1MkTEmw +Ry0t7B+EMzUoyDVCKPSkwQKBgQC+roMWdiYSodfZWzyVK6r6F4AP/90sDA1ltw5Z +HozZo7s3sLpcCK2HLchWQjfIjJZtPqxiGbh5FW3hsEfHpLzMqKda1iXFW8+A3xHB +KYjpJ3WtVdRMRvSLPcXWOxae0phmlrnOIUvlWQwMDmo7zezvMJkXDc26wj++Io/G +aI++JQKBgQDYBW6xXOYHFbCazz7euPRXaV0BX9Pt+ylrQvqDWwa6fk9FDGOrhRW8 +1ywiam3Z+Nup2JNE8PjwP0qQisLbzAbG60HMg2Yx0C6yclIZLUDEwmrjmBVCiP81 +qXdXtd+SfLRrfCd1KJRp8NFIPFsk0T3iy8hxZJZSHtM6/nwM3p2rHw== +-----END RSA PRIVATE KEY-----` + host := "" + var config *ssh.ClientConfig = nil + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + host = addr + config = cfg + return nil, nil + } + _, err := getConnection(map[string]interface{}{"username": "username", "address": "address"}, + map[string]interface{}{"key": key}) + if assert.NoError(t, err) { + assert.Equal(t, "address", host) + assert.Equal(t, "username", config.User) + } + dial = ssh.Dial +} + +func TestGetConnBadKey(t *testing.T) { + key := "notakey" + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return nil, nil + } + _, err := getConnection(map[string]interface{}{"username": "username", "address": "address"}, + map[string]interface{}{"key": key}) + assert.Error(t, err) + dial = ssh.Dial +} + +func TestGetCommit(t *testing.T) { + remoteCommand := "" + conn := new(MockConn) + conn.On("Close").Return(nil) + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return &ssh.Client{Conn: conn}, nil + } + run = func(conn *ssh.Client, command string) (bytes []byte, err error) { + remoteCommand = command + return []byte("{\"a\": \"b\", \"c\": {\"d\": \"e\"}}"), nil + } + r := remote.Get("ssh") + commit, err := r.GetCommit(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, "id") + if assert.NoError(t, err) { + assert.Equal(t, "cat \"/path/id/metadata.json\"", remoteCommand) + assert.Equal(t, "id", commit.Id) + assert.Equal(t, "b", commit.Properties["a"]) + props := commit.Properties["c"].(map[string]interface{}) + assert.Equal(t, "e", props["d"]) + } + + run = runCommand + dial = ssh.Dial +} + +func TestGetCommitBadJson(t *testing.T) { + conn := new(MockConn) + conn.On("Close").Return(nil) + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return &ssh.Client{Conn: conn}, nil + } + run = func(conn *ssh.Client, command string) (bytes []byte, err error) { + return []byte("foo"), nil + } + r := remote.Get("ssh") + _, err := r.GetCommit(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, "id") + assert.Error(t, err) + + run = runCommand + dial = ssh.Dial +} + +func TestGetCommitRunFail(t *testing.T) { + conn := new(MockConn) + conn.On("Close").Return(nil) + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return &ssh.Client{Conn: conn}, nil + } + run = func(conn *ssh.Client, command string) (bytes []byte, err error) { + return nil, errors.New("error") + } + r := remote.Get("ssh") + _, err := r.GetCommit(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, "id") + assert.Error(t, err) + + run = runCommand + dial = ssh.Dial +} + +func TestGetCommitBadConn(t *testing.T) { + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return nil, errors.New("error") + } + r := remote.Get("ssh") + _, err := r.GetCommit(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, "id") + assert.Error(t, err) + dial = ssh.Dial +} + +func TestListCommitsBadConn(t *testing.T) { + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return nil, errors.New("error") + } + r := remote.Get("ssh") + _, err := r.ListCommits(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, []remote.Tag{}) + assert.Error(t, err) + dial = ssh.Dial +} + +func TestListCommitsRunFail(t *testing.T) { + conn := new(MockConn) + conn.On("Close").Return(nil) + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return &ssh.Client{Conn: conn}, nil + } + run = func(conn *ssh.Client, command string) (bytes []byte, err error) { + return nil, errors.New("error") + } + r := remote.Get("ssh") + _, err := r.ListCommits(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, []remote.Tag{}) + assert.Error(t, err) + + run = runCommand + dial = ssh.Dial +} + +func TestListCommits(t *testing.T) { + conn := new(MockConn) + conn.On("Close").Return(nil) + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return &ssh.Client{Conn: conn}, nil + } + run = func(conn *ssh.Client, command string) (bytes []byte, err error) { + if command == "ls -1 \"/path\"" { + return []byte("one\ntwo\n"), nil + } + if command == "cat \"/path/one/metadata.json\"" { + return []byte("{\"timestamp\": \"2019-09-20T13:45:36Z\"}"), nil + } + if command == "cat \"/path/two/metadata.json\"" { + return []byte("{\"timestamp\": \"2019-09-20T13:45:37Z\"}"), nil + } + return nil, errors.New("error") + } + r := remote.Get("ssh") + commits, err := r.ListCommits(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, []remote.Tag{}) + if assert.NoError(t, err) { + assert.Len(t, commits, 2) + assert.Equal(t, "two", commits[0].Id) + assert.Equal(t, "one", commits[1].Id) + } + run = runCommand + dial = ssh.Dial +} + +func TestListCommitsTags(t *testing.T) { + conn := new(MockConn) + conn.On("Close").Return(nil) + dial = func(network string, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + return &ssh.Client{Conn: conn}, nil + } + run = func(conn *ssh.Client, command string) (bytes []byte, err error) { + if command == "ls -1 \"/path\"" { + return []byte("one\ntwo\n"), nil + } + if command == "cat \"/path/one/metadata.json\"" { + return []byte("{\"timestamp\": \"2019-09-20T13:45:36Z\", \"tags\": {\"a\": \"b\"}}"), nil + } + if command == "cat \"/path/two/metadata.json\"" { + return []byte("{\"timestamp\": \"2019-09-20T13:45:37Z\", \"tags\": {\"c\": \"d\"}}"), nil + } + return nil, errors.New("error") + } + r := remote.Get("ssh") + commits, err := r.ListCommits(map[string]interface{}{"username": "username", "address": "address", "path": "/path"}, + map[string]interface{}{"password": "password"}, []remote.Tag{{Key: "a"}}) + if assert.NoError(t, err) { + assert.Len(t, commits, 1) + assert.Equal(t, "one", commits[0].Id) + } + run = runCommand + dial = ssh.Dial }