diff --git a/github.com/kr/fs/LICENSE b/github.com/kr/fs/LICENSE new file mode 100644 index 0000000000..7448756763 --- /dev/null +++ b/github.com/kr/fs/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/github.com/kr/fs/Readme b/github.com/kr/fs/Readme new file mode 100644 index 0000000000..c95e13fc84 --- /dev/null +++ b/github.com/kr/fs/Readme @@ -0,0 +1,3 @@ +Filesystem Package + +http://godoc.org/github.com/kr/fs diff --git a/github.com/kr/fs/example_test.go b/github.com/kr/fs/example_test.go new file mode 100644 index 0000000000..77e0db9fee --- /dev/null +++ b/github.com/kr/fs/example_test.go @@ -0,0 +1,19 @@ +package fs_test + +import ( + "fmt" + "os" + + "github.com/kr/fs" +) + +func ExampleWalker() { + walker := fs.Walk("/usr/lib") + for walker.Step() { + if err := walker.Err(); err != nil { + fmt.Fprintln(os.Stderr, err) + continue + } + fmt.Println(walker.Path()) + } +} diff --git a/github.com/kr/fs/filesystem.go b/github.com/kr/fs/filesystem.go new file mode 100644 index 0000000000..f1c4805fbd --- /dev/null +++ b/github.com/kr/fs/filesystem.go @@ -0,0 +1,36 @@ +package fs + +import ( + "io/ioutil" + "os" + "path/filepath" +) + +// FileSystem defines the methods of an abstract filesystem. +type FileSystem interface { + + // ReadDir reads the directory named by dirname and returns a + // list of directory entries. + ReadDir(dirname string) ([]os.FileInfo, error) + + // Lstat returns a FileInfo describing the named file. If the file is a + // symbolic link, the returned FileInfo describes the symbolic link. Lstat + // makes no attempt to follow the link. + Lstat(name string) (os.FileInfo, error) + + // Join joins any number of path elements into a single path, adding a + // separator if necessary. The result is Cleaned; in particular, all + // empty strings are ignored. + // + // The separator is FileSystem specific. + Join(elem ...string) string +} + +// fs represents a FileSystem provided by the os package. +type fs struct{} + +func (f *fs) ReadDir(dirname string) ([]os.FileInfo, error) { return ioutil.ReadDir(dirname) } + +func (f *fs) Lstat(name string) (os.FileInfo, error) { return os.Lstat(name) } + +func (f *fs) Join(elem ...string) string { return filepath.Join(elem...) } diff --git a/github.com/kr/fs/walk.go b/github.com/kr/fs/walk.go new file mode 100644 index 0000000000..6ffa1e0b24 --- /dev/null +++ b/github.com/kr/fs/walk.go @@ -0,0 +1,95 @@ +// Package fs provides filesystem-related functions. +package fs + +import ( + "os" +) + +// Walker provides a convenient interface for iterating over the +// descendants of a filesystem path. +// Successive calls to the Step method will step through each +// file or directory in the tree, including the root. The files +// are walked in lexical order, which makes the output deterministic +// but means that for very large directories Walker can be inefficient. +// Walker does not follow symbolic links. +type Walker struct { + fs FileSystem + cur item + stack []item + descend bool +} + +type item struct { + path string + info os.FileInfo + err error +} + +// Walk returns a new Walker rooted at root. +func Walk(root string) *Walker { + return WalkFS(root, new(fs)) +} + +// WalkFS returns a new Walker rooted at root on the FileSystem fs. +func WalkFS(root string, fs FileSystem) *Walker { + info, err := fs.Lstat(root) + return &Walker{ + fs: fs, + stack: []item{{root, info, err}}, + } +} + +// Step advances the Walker to the next file or directory, +// which will then be available through the Path, Stat, +// and Err methods. +// It returns false when the walk stops at the end of the tree. +func (w *Walker) Step() bool { + if w.descend && w.cur.err == nil && w.cur.info.IsDir() { + list, err := w.fs.ReadDir(w.cur.path) + if err != nil { + w.cur.err = err + w.stack = append(w.stack, w.cur) + } else { + for i := len(list) - 1; i >= 0; i-- { + path := w.fs.Join(w.cur.path, list[i].Name()) + w.stack = append(w.stack, item{path, list[i], nil}) + } + } + } + + if len(w.stack) == 0 { + return false + } + i := len(w.stack) - 1 + w.cur = w.stack[i] + w.stack = w.stack[:i] + w.descend = true + return true +} + +// Path returns the path to the most recent file or directory +// visited by a call to Step. It contains the argument to Walk +// as a prefix; that is, if Walk is called with "dir", which is +// a directory containing the file "a", Path will return "dir/a". +func (w *Walker) Path() string { + return w.cur.path +} + +// Stat returns info for the most recent file or directory +// visited by a call to Step. +func (w *Walker) Stat() os.FileInfo { + return w.cur.info +} + +// Err returns the error, if any, for the most recent attempt +// by Step to visit a file or directory. If a directory has +// an error, w will not descend into that directory. +func (w *Walker) Err() error { + return w.cur.err +} + +// SkipDir causes the currently visited directory to be skipped. +// If w is not on a directory, SkipDir has no effect. +func (w *Walker) SkipDir() { + w.descend = false +} diff --git a/github.com/kr/fs/walk_test.go b/github.com/kr/fs/walk_test.go new file mode 100644 index 0000000000..6f5ad2ad30 --- /dev/null +++ b/github.com/kr/fs/walk_test.go @@ -0,0 +1,209 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/kr/fs" +) + +type PathTest struct { + path, result string +} + +type Node struct { + name string + entries []*Node // nil if the entry is a file + mark int +} + +var tree = &Node{ + "testdata", + []*Node{ + {"a", nil, 0}, + {"b", []*Node{}, 0}, + {"c", nil, 0}, + { + "d", + []*Node{ + {"x", nil, 0}, + {"y", []*Node{}, 0}, + { + "z", + []*Node{ + {"u", nil, 0}, + {"v", nil, 0}, + }, + 0, + }, + }, + 0, + }, + }, + 0, +} + +func walkTree(n *Node, path string, f func(path string, n *Node)) { + f(path, n) + for _, e := range n.entries { + walkTree(e, filepath.Join(path, e.name), f) + } +} + +func makeTree(t *testing.T) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.entries == nil { + fd, err := os.Create(path) + if err != nil { + t.Errorf("makeTree: %v", err) + return + } + fd.Close() + } else { + os.Mkdir(path, 0770) + } + }) +} + +func markTree(n *Node) { walkTree(n, "", func(path string, n *Node) { n.mark++ }) } + +func checkMarks(t *testing.T, report bool) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.mark != 1 && report { + t.Errorf("node %s mark = %d; expected 1", path, n.mark) + } + n.mark = 0 + }) +} + +// Assumes that each node name is unique. Good enough for a test. +// If clear is true, any incoming error is cleared before return. The errors +// are always accumulated, though. +func mark(path string, info os.FileInfo, err error, errors *[]error, clear bool) error { + if err != nil { + *errors = append(*errors, err) + if clear { + return nil + } + return err + } + name := info.Name() + walkTree(tree, tree.name, func(path string, n *Node) { + if n.name == name { + n.mark++ + } + }) + return nil +} + +func TestWalk(t *testing.T) { + makeTree(t) + errors := make([]error, 0, 10) + clear := true + markFn := func(walker *fs.Walker) (err error) { + for walker.Step() { + err = mark(walker.Path(), walker.Stat(), walker.Err(), &errors, clear) + if err != nil { + break + } + } + return err + } + // Expect no errors. + err := markFn(fs.Walk(tree.name)) + if err != nil { + t.Fatalf("no error expected, found: %s", err) + } + if len(errors) != 0 { + t.Fatalf("unexpected errors: %s", errors) + } + checkMarks(t, true) + errors = errors[0:0] + + // Test permission errors. Only possible if we're not root + // and only on some file systems (AFS, FAT). To avoid errors during + // all.bash on those file systems, skip during go test -short. + if os.Getuid() > 0 && !testing.Short() { + // introduce 2 errors: chmod top-level directories to 0 + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0) + + // 3) capture errors, expect two. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + err := markFn(fs.Walk(tree.name)) + if err != nil { + t.Fatalf("expected no error return from Walk, got %s", err) + } + if len(errors) != 2 { + t.Errorf("expected 2 errors, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, true) + errors = errors[0:0] + + // 4) capture errors, stop after first error. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + clear = false // error will stop processing + err = markFn(fs.Walk(tree.name)) + if err == nil { + t.Fatalf("expected error return from Walk") + } + if len(errors) != 1 { + t.Errorf("expected 1 error, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, false) + errors = errors[0:0] + + // restore permissions + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0770) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0770) + } + + // cleanup + if err := os.RemoveAll(tree.name); err != nil { + t.Errorf("removeTree: %v", err) + } +} + +func TestBug3486(t *testing.T) { // http://code.google.com/p/go/issues/detail?id=3486 + root, err := filepath.EvalSymlinks(runtime.GOROOT()) + if err != nil { + t.Fatal(err) + } + lib := filepath.Join(root, "lib") + src := filepath.Join(root, "src") + seenSrc := false + walker := fs.Walk(root) + for walker.Step() { + if walker.Err() != nil { + t.Fatal(walker.Err()) + } + + switch walker.Path() { + case lib: + walker.SkipDir() + case src: + seenSrc = true + } + } + if !seenSrc { + t.Fatalf("%q not seen", src) + } +} diff --git a/github.com/pkg/sftp/.gitignore b/github.com/pkg/sftp/.gitignore new file mode 100644 index 0000000000..e1ec837c3f --- /dev/null +++ b/github.com/pkg/sftp/.gitignore @@ -0,0 +1,7 @@ +.*.swo +.*.swp + +server_standalone/server_standalone + +examples/*/id_rsa +examples/*/id_rsa.pub diff --git a/github.com/pkg/sftp/.travis.yml b/github.com/pkg/sftp/.travis.yml new file mode 100644 index 0000000000..1cb8382e15 --- /dev/null +++ b/github.com/pkg/sftp/.travis.yml @@ -0,0 +1,37 @@ +language: go +go_import_path: github.com/pkg/sftp + +# current and previous stable releases, and tip +go: + - 1.7.x + - 1.8.x + - tip + +os: + - linux + - osx + +matrix: + exclude: + - os: osx + go: 1.7.x + - os: osx + go: tip + +sudo: false + +addons: + ssh_known_hosts: + - bitbucket.org + +install: + - go get -t -v ./... + - ssh-keygen -t rsa -q -P "" -f $HOME/.ssh/id_rsa + +script: + - go test -integration -v ./... + - go test -testserver -v ./... + - go test -integration -testserver -v ./... + - go test -race -integration -v ./... + - go test -race -testserver -v ./... + - go test -race -integration -testserver -v ./... diff --git a/github.com/pkg/sftp/CONTRIBUTORS b/github.com/pkg/sftp/CONTRIBUTORS new file mode 100644 index 0000000000..5c7196ae6a --- /dev/null +++ b/github.com/pkg/sftp/CONTRIBUTORS @@ -0,0 +1,3 @@ +Dave Cheney +Saulius Gurklys +John Eikenberry diff --git a/github.com/pkg/sftp/LICENSE b/github.com/pkg/sftp/LICENSE new file mode 100644 index 0000000000..b7b53921e9 --- /dev/null +++ b/github.com/pkg/sftp/LICENSE @@ -0,0 +1,9 @@ +Copyright (c) 2013, Dave Cheney +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/github.com/pkg/sftp/README.md b/github.com/pkg/sftp/README.md new file mode 100644 index 0000000000..1fb700c41a --- /dev/null +++ b/github.com/pkg/sftp/README.md @@ -0,0 +1,44 @@ +sftp +---- + +The `sftp` package provides support for file system operations on remote ssh +servers using the SFTP subsystem. It also implements an SFTP server for serving +files from the filesystem. + +[![UNIX Build Status](https://travis-ci.org/pkg/sftp.svg?branch=master)](https://travis-ci.org/pkg/sftp) [![GoDoc](http://godoc.org/github.com/pkg/sftp?status.svg)](http://godoc.org/github.com/pkg/sftp) + +usage and examples +------------------ + +See [godoc.org/github.com/pkg/sftp](http://godoc.org/github.com/pkg/sftp) for +examples and usage. + +The basic operation of the package mirrors the facilities of the +[os](http://golang.org/pkg/os) package. + +The Walker interface for directory traversal is heavily inspired by Keith +Rarick's [fs](http://godoc.org/github.com/kr/fs) package. + +roadmap +------- + + * There is way too much duplication in the Client methods. If there was an + unmarshal(interface{}) method this would reduce a heap of the duplication. + +contributing +------------ + +We welcome pull requests, bug fixes and issue reports. + +Before proposing a large change, first please discuss your change by raising an +issue. + +For API/code bugs, please include a small, self contained code example to +reproduce the issue. For pull requests, remember test coverage. + +We try to handle issues and pull requests with a 0 open philosophy. That means +we will try to address the submission as soon as possible and will work toward +a resolution. If progress can no longer be made (eg. unreproducible bug) or +stops (eg. unresponsive submitter), we will close the bug. + +Thanks. diff --git a/github.com/pkg/sftp/attrs.go b/github.com/pkg/sftp/attrs.go new file mode 100644 index 0000000000..3e4c2912db --- /dev/null +++ b/github.com/pkg/sftp/attrs.go @@ -0,0 +1,237 @@ +package sftp + +// ssh_FXP_ATTRS support +// see http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-5 + +import ( + "os" + "syscall" + "time" +) + +const ( + ssh_FILEXFER_ATTR_SIZE = 0x00000001 + ssh_FILEXFER_ATTR_UIDGID = 0x00000002 + ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 + ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 + ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 +) + +// fileInfo is an artificial type designed to satisfy os.FileInfo. +type fileInfo struct { + name string + size int64 + mode os.FileMode + mtime time.Time + sys interface{} +} + +// Name returns the base name of the file. +func (fi *fileInfo) Name() string { return fi.name } + +// Size returns the length in bytes for regular files; system-dependent for others. +func (fi *fileInfo) Size() int64 { return fi.size } + +// Mode returns file mode bits. +func (fi *fileInfo) Mode() os.FileMode { return fi.mode } + +// ModTime returns the last modification time of the file. +func (fi *fileInfo) ModTime() time.Time { return fi.mtime } + +// IsDir returns true if the file is a directory. +func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() } + +func (fi *fileInfo) Sys() interface{} { return fi.sys } + +// FileStat holds the original unmarshalled values from a call to READDIR or *STAT. +// It is exported for the purposes of accessing the raw values via os.FileInfo.Sys() +type FileStat struct { + Size uint64 + Mode uint32 + Mtime uint32 + Atime uint32 + UID uint32 + GID uint32 + Extended []StatExtended +} + +// StatExtended contains additional, extended information for a FileStat. +type StatExtended struct { + ExtType string + ExtData string +} + +func fileInfoFromStat(st *FileStat, name string) os.FileInfo { + fs := &fileInfo{ + name: name, + size: int64(st.Size), + mode: toFileMode(st.Mode), + mtime: time.Unix(int64(st.Mtime), 0), + sys: st, + } + return fs +} + +func fileStatFromInfo(fi os.FileInfo) (uint32, FileStat) { + mtime := fi.ModTime().Unix() + atime := mtime + var flags uint32 = ssh_FILEXFER_ATTR_SIZE | + ssh_FILEXFER_ATTR_PERMISSIONS | + ssh_FILEXFER_ATTR_ACMODTIME + + fileStat := FileStat{ + Size: uint64(fi.Size()), + Mode: fromFileMode(fi.Mode()), + Mtime: uint32(mtime), + Atime: uint32(atime), + } + + // os specific file stat decoding + fileStatFromInfoOs(fi, &flags, &fileStat) + + return flags, fileStat +} + +func unmarshalAttrs(b []byte) (*FileStat, []byte) { + flags, b := unmarshalUint32(b) + var fs FileStat + if flags&ssh_FILEXFER_ATTR_SIZE == ssh_FILEXFER_ATTR_SIZE { + fs.Size, b = unmarshalUint64(b) + } + if flags&ssh_FILEXFER_ATTR_UIDGID == ssh_FILEXFER_ATTR_UIDGID { + fs.UID, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_UIDGID == ssh_FILEXFER_ATTR_UIDGID { + fs.GID, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_PERMISSIONS == ssh_FILEXFER_ATTR_PERMISSIONS { + fs.Mode, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_ACMODTIME == ssh_FILEXFER_ATTR_ACMODTIME { + fs.Atime, b = unmarshalUint32(b) + fs.Mtime, b = unmarshalUint32(b) + } + if flags&ssh_FILEXFER_ATTR_EXTENDED == ssh_FILEXFER_ATTR_EXTENDED { + var count uint32 + count, b = unmarshalUint32(b) + ext := make([]StatExtended, count, count) + for i := uint32(0); i < count; i++ { + var typ string + var data string + typ, b = unmarshalString(b) + data, b = unmarshalString(b) + ext[i] = StatExtended{typ, data} + } + fs.Extended = ext + } + return &fs, b +} + +func marshalFileInfo(b []byte, fi os.FileInfo) []byte { + // attributes variable struct, and also variable per protocol version + // spec version 3 attributes: + // uint32 flags + // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE + // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS + // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED + // string extended_type + // string extended_data + // ... more extended data (extended_type - extended_data pairs), + // so that number of pairs equals extended_count + + flags, fileStat := fileStatFromInfo(fi) + + b = marshalUint32(b, flags) + if flags&ssh_FILEXFER_ATTR_SIZE != 0 { + b = marshalUint64(b, fileStat.Size) + } + if flags&ssh_FILEXFER_ATTR_UIDGID != 0 { + b = marshalUint32(b, fileStat.UID) + b = marshalUint32(b, fileStat.GID) + } + if flags&ssh_FILEXFER_ATTR_PERMISSIONS != 0 { + b = marshalUint32(b, fileStat.Mode) + } + if flags&ssh_FILEXFER_ATTR_ACMODTIME != 0 { + b = marshalUint32(b, fileStat.Atime) + b = marshalUint32(b, fileStat.Mtime) + } + + return b +} + +// toFileMode converts sftp filemode bits to the os.FileMode specification +func toFileMode(mode uint32) os.FileMode { + var fm = os.FileMode(mode & 0777) + switch mode & syscall.S_IFMT { + case syscall.S_IFBLK: + fm |= os.ModeDevice + case syscall.S_IFCHR: + fm |= os.ModeDevice | os.ModeCharDevice + case syscall.S_IFDIR: + fm |= os.ModeDir + case syscall.S_IFIFO: + fm |= os.ModeNamedPipe + case syscall.S_IFLNK: + fm |= os.ModeSymlink + case syscall.S_IFREG: + // nothing to do + case syscall.S_IFSOCK: + fm |= os.ModeSocket + } + if mode&syscall.S_ISGID != 0 { + fm |= os.ModeSetgid + } + if mode&syscall.S_ISUID != 0 { + fm |= os.ModeSetuid + } + if mode&syscall.S_ISVTX != 0 { + fm |= os.ModeSticky + } + return fm +} + +// fromFileMode converts from the os.FileMode specification to sftp filemode bits +func fromFileMode(mode os.FileMode) uint32 { + ret := uint32(0) + + if mode&os.ModeDevice != 0 { + if mode&os.ModeCharDevice != 0 { + ret |= syscall.S_IFCHR + } else { + ret |= syscall.S_IFBLK + } + } + if mode&os.ModeDir != 0 { + ret |= syscall.S_IFDIR + } + if mode&os.ModeSymlink != 0 { + ret |= syscall.S_IFLNK + } + if mode&os.ModeNamedPipe != 0 { + ret |= syscall.S_IFIFO + } + if mode&os.ModeSetgid != 0 { + ret |= syscall.S_ISGID + } + if mode&os.ModeSetuid != 0 { + ret |= syscall.S_ISUID + } + if mode&os.ModeSticky != 0 { + ret |= syscall.S_ISVTX + } + if mode&os.ModeSocket != 0 { + ret |= syscall.S_IFSOCK + } + + if mode&os.ModeType == 0 { + ret |= syscall.S_IFREG + } + ret |= uint32(mode & os.ModePerm) + + return ret +} diff --git a/github.com/pkg/sftp/attrs_stubs.go b/github.com/pkg/sftp/attrs_stubs.go new file mode 100644 index 0000000000..81cf3eac2b --- /dev/null +++ b/github.com/pkg/sftp/attrs_stubs.go @@ -0,0 +1,11 @@ +// +build !cgo,!plan9 windows android + +package sftp + +import ( + "os" +) + +func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) { + // todo +} diff --git a/github.com/pkg/sftp/attrs_test.go b/github.com/pkg/sftp/attrs_test.go new file mode 100644 index 0000000000..e234649cd9 --- /dev/null +++ b/github.com/pkg/sftp/attrs_test.go @@ -0,0 +1,45 @@ +package sftp + +import ( + "bytes" + "os" + "reflect" + "testing" + "time" +) + +// ensure that attrs implemenst os.FileInfo +var _ os.FileInfo = new(fileInfo) + +var unmarshalAttrsTests = []struct { + b []byte + want *fileInfo + rest []byte +}{ + {marshal(nil, struct{ Flags uint32 }{}), &fileInfo{mtime: time.Unix(int64(0), 0)}, nil}, + {marshal(nil, struct { + Flags uint32 + Size uint64 + }{ssh_FILEXFER_ATTR_SIZE, 20}), &fileInfo{size: 20, mtime: time.Unix(int64(0), 0)}, nil}, + {marshal(nil, struct { + Flags uint32 + Size uint64 + Permissions uint32 + }{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, + {marshal(nil, struct { + Flags uint32 + Size uint64 + UID, GID, Permissions uint32 + }{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 1000, 1000, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, +} + +func TestUnmarshalAttrs(t *testing.T) { + for _, tt := range unmarshalAttrsTests { + stat, rest := unmarshalAttrs(tt.b) + got := fileInfoFromStat(stat, "") + tt.want.sys = got.Sys() + if !reflect.DeepEqual(got, tt.want) || !bytes.Equal(tt.rest, rest) { + t.Errorf("unmarshalAttrs(%#v): want %#v, %#v, got: %#v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} diff --git a/github.com/pkg/sftp/attrs_unix.go b/github.com/pkg/sftp/attrs_unix.go new file mode 100644 index 0000000000..ab6ecdea9f --- /dev/null +++ b/github.com/pkg/sftp/attrs_unix.go @@ -0,0 +1,17 @@ +// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris +// +build cgo + +package sftp + +import ( + "os" + "syscall" +) + +func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) { + if statt, ok := fi.Sys().(*syscall.Stat_t); ok { + *flags |= ssh_FILEXFER_ATTR_UIDGID + fileStat.UID = statt.Uid + fileStat.GID = statt.Gid + } +} diff --git a/github.com/pkg/sftp/client.go b/github.com/pkg/sftp/client.go new file mode 100644 index 0000000000..c3012c0751 --- /dev/null +++ b/github.com/pkg/sftp/client.go @@ -0,0 +1,1151 @@ +package sftp + +import ( + "bytes" + "encoding/binary" + "io" + "os" + "path" + "sync/atomic" + "time" + + "github.com/kr/fs" + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +// InternalInconsistency indicates the packets sent and the data queued to be +// written to the file don't match up. It is an unusual error and if you get it +// you should file a ticket. +var InternalInconsistency = errors.New("internal inconsistency") + +// A ClientOption is a function which applies configuration to a Client. +type ClientOption func(*Client) error + +// This is based on Openssh's max accepted size of 1<<18 - overhead +const maxMaxPacket = (1 << 18) - 1024 + +// MaxPacket sets the maximum size of the payload. The size param must be +// between 32768 (1<<15) and 261120 ((1 << 18) - 1024). The minimum size is +// given by the RFC, while the maximum size is a de-facto standard based on +// Openssh's SFTP server which won't accept packets much larger than that. +// +// Note if you aren't using Openssh's sftp server and get the error "failed to +// send packet header: EOF" when copying a large file try lowering this number. +func MaxPacket(size int) ClientOption { + return func(c *Client) error { + if size < 1<<15 { + return errors.Errorf("size must be greater or equal to 32k") + } + if size > maxMaxPacket { + return errors.Errorf("max packet size is too large (see docs)") + } + c.maxPacket = size + return nil + } +} + +// NewClient creates a new SFTP client on conn, using zero or more option +// functions. +func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { + s, err := conn.NewSession() + if err != nil { + return nil, err + } + if err := s.RequestSubsystem("sftp"); err != nil { + return nil, err + } + pw, err := s.StdinPipe() + if err != nil { + return nil, err + } + pr, err := s.StdoutPipe() + if err != nil { + return nil, err + } + + return NewClientPipe(pr, pw, opts...) +} + +// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. +// This can be used for connecting to an SFTP server over TCP/TLS or by using +// the system's ssh client program (e.g. via exec.Command). +func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { + sftp := &Client{ + clientConn: clientConn{ + conn: conn{ + Reader: rd, + WriteCloser: wr, + }, + inflight: make(map[uint32]chan<- result), + }, + maxPacket: 1 << 15, + } + if err := sftp.applyOptions(opts...); err != nil { + wr.Close() + return nil, err + } + if err := sftp.sendInit(); err != nil { + wr.Close() + return nil, err + } + if err := sftp.recvVersion(); err != nil { + wr.Close() + return nil, err + } + sftp.clientConn.wg.Add(1) + go sftp.loop() + return sftp, nil +} + +// Client represents an SFTP session on a *ssh.ClientConn SSH connection. +// Multiple Clients can be active on a single SSH connection, and a Client +// may be called concurrently from multiple Goroutines. +// +// Client implements the github.com/kr/fs.FileSystem interface. +type Client struct { + clientConn + + maxPacket int // max packet size read or written. + nextid uint32 +} + +// Create creates the named file mode 0666 (before umask), truncating it if +// it already exists. If successful, methods on the returned File can be +// used for I/O; the associated file descriptor has mode O_RDWR. +func (c *Client) Create(path string) (*File, error) { + return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) +} + +const sftpProtocolVersion = 3 // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 + +func (c *Client) sendInit() error { + return c.clientConn.conn.sendPacket(sshFxInitPacket{ + Version: sftpProtocolVersion, // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 + }) +} + +// returns the next value of c.nextid +func (c *Client) nextID() uint32 { + return atomic.AddUint32(&c.nextid, 1) +} + +func (c *Client) recvVersion() error { + typ, data, err := c.recvPacket() + if err != nil { + return err + } + if typ != ssh_FXP_VERSION { + return &unexpectedPacketErr{ssh_FXP_VERSION, typ} + } + + version, _ := unmarshalUint32(data) + if version != sftpProtocolVersion { + return &unexpectedVersionErr{sftpProtocolVersion, version} + } + + return nil +} + +// Walk returns a new Walker rooted at root. +func (c *Client) Walk(root string) *fs.Walker { + return fs.WalkFS(root, c) +} + +// ReadDir reads the directory named by dirname and returns a list of +// directory entries. +func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { + handle, err := c.opendir(p) + if err != nil { + return nil, err + } + defer c.close(handle) // this has to defer earlier than the lock below + var attrs []os.FileInfo + var done = false + for !done { + id := c.nextID() + typ, data, err1 := c.sendPacket(sshFxpReaddirPacket{ + ID: id, + Handle: handle, + }) + if err1 != nil { + err = err1 + done = true + break + } + switch typ { + case ssh_FXP_NAME: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + count, data := unmarshalUint32(data) + for i := uint32(0); i < count; i++ { + var filename string + filename, data = unmarshalString(data) + _, data = unmarshalString(data) // discard longname + var attr *FileStat + attr, data = unmarshalAttrs(data) + if filename == "." || filename == ".." { + continue + } + attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename))) + } + case ssh_FXP_STATUS: + // TODO(dfc) scope warning! + err = normaliseError(unmarshalStatus(id, data)) + done = true + default: + return nil, unimplementedPacketErr(typ) + } + } + if err == io.EOF { + err = nil + } + return attrs, err +} + +func (c *Client) opendir(path string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpOpendirPacket{ + ID: id, + Path: path, + }) + if err != nil { + return "", err + } + switch typ { + case ssh_FXP_HANDLE: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIDErr{id, sid} + } + handle, _ := unmarshalString(data) + return handle, nil + case ssh_FXP_STATUS: + return "", normaliseError(unmarshalStatus(id, data)) + default: + return "", unimplementedPacketErr(typ) + } +} + +// Stat returns a FileInfo structure describing the file specified by path 'p'. +// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file. +func (c *Client) Stat(p string) (os.FileInfo, error) { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpStatPacket{ + ID: id, + Path: p, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_ATTRS: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return fileInfoFromStat(attr, path.Base(p)), nil + case ssh_FXP_STATUS: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// Lstat returns a FileInfo structure describing the file specified by path 'p'. +// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link. +func (c *Client) Lstat(p string) (os.FileInfo, error) { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpLstatPacket{ + ID: id, + Path: p, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_ATTRS: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return fileInfoFromStat(attr, path.Base(p)), nil + case ssh_FXP_STATUS: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// ReadLink reads the target of a symbolic link. +func (c *Client) ReadLink(p string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpReadlinkPacket{ + ID: id, + Path: p, + }) + if err != nil { + return "", err + } + switch typ { + case ssh_FXP_NAME: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIDErr{id, sid} + } + count, data := unmarshalUint32(data) + if count != 1 { + return "", unexpectedCount(1, count) + } + filename, _ := unmarshalString(data) // ignore dummy attributes + return filename, nil + case ssh_FXP_STATUS: + return "", normaliseError(unmarshalStatus(id, data)) + default: + return "", unimplementedPacketErr(typ) + } +} + +// Symlink creates a symbolic link at 'newname', pointing at target 'oldname' +func (c *Client) Symlink(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpSymlinkPacket{ + ID: id, + Linkpath: newname, + Targetpath: oldname, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// setstat is a convience wrapper to allow for changing of various parts of the file descriptor. +func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpSetstatPacket{ + ID: id, + Path: path, + Flags: flags, + Attrs: attrs, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Chtimes changes the access and modification times of the named file. +func (c *Client) Chtimes(path string, atime time.Time, mtime time.Time) error { + type times struct { + Atime uint32 + Mtime uint32 + } + attrs := times{uint32(atime.Unix()), uint32(mtime.Unix())} + return c.setstat(path, ssh_FILEXFER_ATTR_ACMODTIME, attrs) +} + +// Chown changes the user and group owners of the named file. +func (c *Client) Chown(path string, uid, gid int) error { + type owner struct { + UID uint32 + GID uint32 + } + attrs := owner{uint32(uid), uint32(gid)} + return c.setstat(path, ssh_FILEXFER_ATTR_UIDGID, attrs) +} + +// Chmod changes the permissions of the named file. +func (c *Client) Chmod(path string, mode os.FileMode) error { + return c.setstat(path, ssh_FILEXFER_ATTR_PERMISSIONS, uint32(mode)) +} + +// Truncate sets the size of the named file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +func (c *Client) Truncate(path string, size int64) error { + return c.setstat(path, ssh_FILEXFER_ATTR_SIZE, uint64(size)) +} + +// Open opens the named file for reading. If successful, methods on the +// returned file can be used for reading; the associated file descriptor +// has mode O_RDONLY. +func (c *Client) Open(path string) (*File, error) { + return c.open(path, flags(os.O_RDONLY)) +} + +// OpenFile is the generalized open call; most users will use Open or +// Create instead. It opens the named file with specified flag (O_RDONLY +// etc.). If successful, methods on the returned File can be used for I/O. +func (c *Client) OpenFile(path string, f int) (*File, error) { + return c.open(path, flags(f)) +} + +func (c *Client) open(path string, pflags uint32) (*File, error) { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpOpenPacket{ + ID: id, + Path: path, + Pflags: pflags, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_HANDLE: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + handle, _ := unmarshalString(data) + return &File{c: c, path: path, handle: handle}, nil + case ssh_FXP_STATUS: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// close closes a handle handle previously returned in the response +// to SSH_FXP_OPEN or SSH_FXP_OPENDIR. The handle becomes invalid +// immediately after this request has been sent. +func (c *Client) close(handle string) error { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpClosePacket{ + ID: id, + Handle: handle, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) fstat(handle string) (*FileStat, error) { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpFstatPacket{ + ID: id, + Handle: handle, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_ATTRS: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return attr, nil + case ssh_FXP_STATUS: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// StatVFS retrieves VFS statistics from a remote host. +// +// It implements the statvfs@openssh.com SSH_FXP_EXTENDED feature +// from http://www.opensource.apple.com/source/OpenSSH/OpenSSH-175/openssh/PROTOCOL?txt. +func (c *Client) StatVFS(path string) (*StatVFS, error) { + // send the StatVFS packet to the server + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpStatvfsPacket{ + ID: id, + Path: path, + }) + if err != nil { + return nil, err + } + + switch typ { + // server responded with valid data + case ssh_FXP_EXTENDED_REPLY: + var response StatVFS + err = binary.Read(bytes.NewReader(data), binary.BigEndian, &response) + if err != nil { + return nil, errors.New("can not parse reply") + } + + return &response, nil + + // the resquest failed + case ssh_FXP_STATUS: + return nil, errors.New(fxp(ssh_FXP_STATUS).String()) + + default: + return nil, unimplementedPacketErr(typ) + } +} + +// Join joins any number of path elements into a single path, adding a +// separating slash if necessary. The result is Cleaned; in particular, all +// empty strings are ignored. +func (c *Client) Join(elem ...string) string { return path.Join(elem...) } + +// Remove removes the specified file or directory. An error will be returned if no +// file or directory with the specified path exists, or if the specified directory +// is not empty. +func (c *Client) Remove(path string) error { + err := c.removeFile(path) + if err, ok := err.(*StatusError); ok { + switch err.Code { + // some servers, *cough* osx *cough*, return EPERM, not ENODIR. + // serv-u returns ssh_FX_FILE_IS_A_DIRECTORY + case ssh_FX_PERMISSION_DENIED, ssh_FX_FAILURE, ssh_FX_FILE_IS_A_DIRECTORY: + return c.RemoveDirectory(path) + } + } + return err +} + +func (c *Client) removeFile(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpRemovePacket{ + ID: id, + Filename: path, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// RemoveDirectory removes a directory path. +func (c *Client) RemoveDirectory(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpRmdirPacket{ + ID: id, + Path: path, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Rename renames a file. +func (c *Client) Rename(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpRenamePacket{ + ID: id, + Oldpath: oldname, + Newpath: newname, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) realpath(path string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpRealpathPacket{ + ID: id, + Path: path, + }) + if err != nil { + return "", err + } + switch typ { + case ssh_FXP_NAME: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIDErr{id, sid} + } + count, data := unmarshalUint32(data) + if count != 1 { + return "", unexpectedCount(1, count) + } + filename, _ := unmarshalString(data) // ignore attributes + return filename, nil + case ssh_FXP_STATUS: + return "", normaliseError(unmarshalStatus(id, data)) + default: + return "", unimplementedPacketErr(typ) + } +} + +// Getwd returns the current working directory of the server. Operations +// involving relative paths will be based at this location. +func (c *Client) Getwd() (string, error) { + return c.realpath(".") +} + +// Mkdir creates the specified directory. An error will be returned if a file or +// directory with the specified path already exists, or if the directory's +// parent folder does not exist (the method cannot create complete paths). +func (c *Client) Mkdir(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(sshFxpMkdirPacket{ + ID: id, + Path: path, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// applyOptions applies options functions to the Client. +// If an error is encountered, option processing ceases. +func (c *Client) applyOptions(opts ...ClientOption) error { + for _, f := range opts { + if err := f(c); err != nil { + return err + } + } + return nil +} + +// File represents a remote file. +type File struct { + c *Client + path string + handle string + offset uint64 // current offset within remote file +} + +// Close closes the File, rendering it unusable for I/O. It returns an +// error, if any. +func (f *File) Close() error { + return f.c.close(f.handle) +} + +// Name returns the name of the file as presented to Open or Create. +func (f *File) Name() string { + return f.path +} + +const maxConcurrentRequests = 64 + +// Read reads up to len(b) bytes from the File. It returns the number of bytes +// read and an error, if any. Read follows io.Reader semantics, so when Read +// encounters an error or EOF condition after successfully reading n > 0 bytes, +// it returns the number of bytes read. +func (f *File) Read(b []byte) (int, error) { + // Split the read into multiple maxPacket sized concurrent reads + // bounded by maxConcurrentRequests. This allows reads with a suitably + // large buffer to transfer data at a much faster rate due to + // overlapping round trip times. + inFlight := 0 + desiredInFlight := 1 + offset := f.offset + // maxConcurrentRequests buffer to deal with broadcastErr() floods + // also must have a buffer of max value of (desiredInFlight - inFlight) + ch := make(chan result, maxConcurrentRequests) + type inflightRead struct { + b []byte + offset uint64 + } + reqs := map[uint32]inflightRead{} + type offsetErr struct { + offset uint64 + err error + } + var firstErr offsetErr + + sendReq := func(b []byte, offset uint64) { + reqID := f.c.nextID() + f.c.dispatchRequest(ch, sshFxpReadPacket{ + ID: reqID, + Handle: f.handle, + Offset: offset, + Len: uint32(len(b)), + }) + inFlight++ + reqs[reqID] = inflightRead{b: b, offset: offset} + } + + var read int + for len(b) > 0 || inFlight > 0 { + for inFlight < desiredInFlight && len(b) > 0 && firstErr.err == nil { + l := min(len(b), f.c.maxPacket) + rb := b[:l] + sendReq(rb, offset) + offset += uint64(l) + b = b[l:] + } + + if inFlight == 0 { + break + } + res := <-ch + inFlight-- + if res.err != nil { + firstErr = offsetErr{offset: 0, err: res.err} + continue + } + reqID, data := unmarshalUint32(res.data) + req, ok := reqs[reqID] + if !ok { + firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)} + continue + } + delete(reqs, reqID) + switch res.typ { + case ssh_FXP_STATUS: + if firstErr.err == nil || req.offset < firstErr.offset { + firstErr = offsetErr{ + offset: req.offset, + err: normaliseError(unmarshalStatus(reqID, res.data)), + } + } + case ssh_FXP_DATA: + l, data := unmarshalUint32(data) + n := copy(req.b, data[:l]) + read += n + if n < len(req.b) { + sendReq(req.b[l:], req.offset+uint64(l)) + } + if desiredInFlight < maxConcurrentRequests { + desiredInFlight++ + } + default: + firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)} + } + } + // If the error is anything other than EOF, then there + // may be gaps in the data copied to the buffer so it's + // best to return 0 so the caller can't make any + // incorrect assumptions about the state of the buffer. + if firstErr.err != nil && firstErr.err != io.EOF { + read = 0 + } + f.offset += uint64(read) + return read, firstErr.err +} + +// WriteTo writes the file to w. The return value is the number of bytes +// written. Any error encountered during the write is also returned. +func (f *File) WriteTo(w io.Writer) (int64, error) { + fi, err := f.Stat() + if err != nil { + return 0, err + } + inFlight := 0 + desiredInFlight := 1 + offset := f.offset + writeOffset := offset + fileSize := uint64(fi.Size()) + // see comment on same line in Read() above + ch := make(chan result, maxConcurrentRequests) + type inflightRead struct { + b []byte + offset uint64 + } + reqs := map[uint32]inflightRead{} + pendingWrites := map[uint64][]byte{} + type offsetErr struct { + offset uint64 + err error + } + var firstErr offsetErr + + sendReq := func(b []byte, offset uint64) { + reqID := f.c.nextID() + f.c.dispatchRequest(ch, sshFxpReadPacket{ + ID: reqID, + Handle: f.handle, + Offset: offset, + Len: uint32(len(b)), + }) + inFlight++ + reqs[reqID] = inflightRead{b: b, offset: offset} + } + + var copied int64 + for firstErr.err == nil || inFlight > 0 { + if firstErr.err == nil { + for inFlight+len(pendingWrites) < desiredInFlight { + b := make([]byte, f.c.maxPacket) + sendReq(b, offset) + offset += uint64(f.c.maxPacket) + if offset > fileSize { + desiredInFlight = 1 + } + } + } + + if inFlight == 0 { + if firstErr.err == nil && len(pendingWrites) > 0 { + return copied, InternalInconsistency + } + break + } + res := <-ch + inFlight-- + if res.err != nil { + firstErr = offsetErr{offset: 0, err: res.err} + continue + } + reqID, data := unmarshalUint32(res.data) + req, ok := reqs[reqID] + if !ok { + firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)} + continue + } + delete(reqs, reqID) + switch res.typ { + case ssh_FXP_STATUS: + if firstErr.err == nil || req.offset < firstErr.offset { + firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data))} + } + case ssh_FXP_DATA: + l, data := unmarshalUint32(data) + if req.offset == writeOffset { + nbytes, err := w.Write(data) + copied += int64(nbytes) + if err != nil { + // We will never receive another DATA with offset==writeOffset, so + // the loop will drain inFlight and then exit. + firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: err} + break + } + if nbytes < int(l) { + firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: io.ErrShortWrite} + break + } + switch { + case offset > fileSize: + desiredInFlight = 1 + case desiredInFlight < maxConcurrentRequests: + desiredInFlight++ + } + writeOffset += uint64(nbytes) + for { + pendingData, ok := pendingWrites[writeOffset] + if !ok { + break + } + // Give go a chance to free the memory. + delete(pendingWrites, writeOffset) + nbytes, err := w.Write(pendingData) + // Do not move writeOffset on error so subsequent iterations won't trigger + // any writes. + if err != nil { + firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: err} + break + } + if nbytes < len(pendingData) { + firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: io.ErrShortWrite} + break + } + writeOffset += uint64(nbytes) + } + } else { + // Don't write the data yet because + // this response came in out of order + // and we need to wait for responses + // for earlier segments of the file. + pendingWrites[req.offset] = data + } + default: + firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)} + } + } + if firstErr.err != io.EOF { + return copied, firstErr.err + } + return copied, nil +} + +// Stat returns the FileInfo structure describing file. If there is an +// error. +func (f *File) Stat() (os.FileInfo, error) { + fs, err := f.c.fstat(f.handle) + if err != nil { + return nil, err + } + return fileInfoFromStat(fs, path.Base(f.path)), nil +} + +// Write writes len(b) bytes to the File. It returns the number of bytes +// written and an error, if any. Write returns a non-nil error when n != +// len(b). +func (f *File) Write(b []byte) (int, error) { + // Split the write into multiple maxPacket sized concurrent writes + // bounded by maxConcurrentRequests. This allows writes with a suitably + // large buffer to transfer data at a much faster rate due to + // overlapping round trip times. + inFlight := 0 + desiredInFlight := 1 + offset := f.offset + // see comment on same line in Read() above + ch := make(chan result, maxConcurrentRequests) + var firstErr error + written := len(b) + for len(b) > 0 || inFlight > 0 { + for inFlight < desiredInFlight && len(b) > 0 && firstErr == nil { + l := min(len(b), f.c.maxPacket) + rb := b[:l] + f.c.dispatchRequest(ch, sshFxpWritePacket{ + ID: f.c.nextID(), + Handle: f.handle, + Offset: offset, + Length: uint32(len(rb)), + Data: rb, + }) + inFlight++ + offset += uint64(l) + b = b[l:] + } + + if inFlight == 0 { + break + } + res := <-ch + inFlight-- + if res.err != nil { + firstErr = res.err + continue + } + switch res.typ { + case ssh_FXP_STATUS: + id, _ := unmarshalUint32(res.data) + err := normaliseError(unmarshalStatus(id, res.data)) + if err != nil && firstErr == nil { + firstErr = err + break + } + if desiredInFlight < maxConcurrentRequests { + desiredInFlight++ + } + default: + firstErr = unimplementedPacketErr(res.typ) + } + } + // If error is non-nil, then there may be gaps in the data written to + // the file so it's best to return 0 so the caller can't make any + // incorrect assumptions about the state of the file. + if firstErr != nil { + written = 0 + } + f.offset += uint64(written) + return written, firstErr +} + +// ReadFrom reads data from r until EOF and writes it to the file. The return +// value is the number of bytes read. Any error except io.EOF encountered +// during the read is also returned. +func (f *File) ReadFrom(r io.Reader) (int64, error) { + inFlight := 0 + desiredInFlight := 1 + offset := f.offset + // see comment on same line in Read() above + ch := make(chan result, maxConcurrentRequests) + var firstErr error + read := int64(0) + b := make([]byte, f.c.maxPacket) + for inFlight > 0 || firstErr == nil { + for inFlight < desiredInFlight && firstErr == nil { + n, err := r.Read(b) + if err != nil { + firstErr = err + } + f.c.dispatchRequest(ch, sshFxpWritePacket{ + ID: f.c.nextID(), + Handle: f.handle, + Offset: offset, + Length: uint32(n), + Data: b[:n], + }) + inFlight++ + offset += uint64(n) + read += int64(n) + } + + if inFlight == 0 { + break + } + res := <-ch + inFlight-- + if res.err != nil { + firstErr = res.err + continue + } + switch res.typ { + case ssh_FXP_STATUS: + id, _ := unmarshalUint32(res.data) + err := normaliseError(unmarshalStatus(id, res.data)) + if err != nil && firstErr == nil { + firstErr = err + break + } + if desiredInFlight < maxConcurrentRequests { + desiredInFlight++ + } + default: + firstErr = unimplementedPacketErr(res.typ) + } + } + if firstErr == io.EOF { + firstErr = nil + } + // If error is non-nil, then there may be gaps in the data written to + // the file so it's best to return 0 so the caller can't make any + // incorrect assumptions about the state of the file. + if firstErr != nil { + read = 0 + } + f.offset += uint64(read) + return read, firstErr +} + +// Seek implements io.Seeker by setting the client offset for the next Read or +// Write. It returns the next offset read. Seeking before or after the end of +// the file is undefined. Seeking relative to the end calls Stat. +func (f *File) Seek(offset int64, whence int) (int64, error) { + switch whence { + case os.SEEK_SET: + f.offset = uint64(offset) + case os.SEEK_CUR: + f.offset = uint64(int64(f.offset) + offset) + case os.SEEK_END: + fi, err := f.Stat() + if err != nil { + return int64(f.offset), err + } + f.offset = uint64(fi.Size() + offset) + default: + return int64(f.offset), unimplementedSeekWhence(whence) + } + return int64(f.offset), nil +} + +// Chown changes the uid/gid of the current file. +func (f *File) Chown(uid, gid int) error { + return f.c.Chown(f.path, uid, gid) +} + +// Chmod changes the permissions of the current file. +func (f *File) Chmod(mode os.FileMode) error { + return f.c.Chmod(f.path, mode) +} + +// Truncate sets the size of the current file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +func (f *File) Truncate(size int64) error { + return f.c.Truncate(f.path, size) +} + +func min(a, b int) int { + if a > b { + return b + } + return a +} + +// normaliseError normalises an error into a more standard form that can be +// checked against stdlib errors like io.EOF or os.ErrNotExist. +func normaliseError(err error) error { + switch err := err.(type) { + case *StatusError: + switch err.Code { + case ssh_FX_EOF: + return io.EOF + case ssh_FX_NO_SUCH_FILE: + return os.ErrNotExist + case ssh_FX_OK: + return nil + default: + return err + } + default: + return err + } +} + +func unmarshalStatus(id uint32, data []byte) error { + sid, data := unmarshalUint32(data) + if sid != id { + return &unexpectedIDErr{id, sid} + } + code, data := unmarshalUint32(data) + msg, data, _ := unmarshalStringSafe(data) + lang, _, _ := unmarshalStringSafe(data) + return &StatusError{ + Code: code, + msg: msg, + lang: lang, + } +} + +func marshalStatus(b []byte, err StatusError) []byte { + b = marshalUint32(b, err.Code) + b = marshalString(b, err.msg) + b = marshalString(b, err.lang) + return b +} + +// flags converts the flags passed to OpenFile into ssh flags. +// Unsupported flags are ignored. +func flags(f int) uint32 { + var out uint32 + switch f & os.O_WRONLY { + case os.O_WRONLY: + out |= ssh_FXF_WRITE + case os.O_RDONLY: + out |= ssh_FXF_READ + } + if f&os.O_RDWR == os.O_RDWR { + out |= ssh_FXF_READ | ssh_FXF_WRITE + } + if f&os.O_APPEND == os.O_APPEND { + out |= ssh_FXF_APPEND + } + if f&os.O_CREATE == os.O_CREATE { + out |= ssh_FXF_CREAT + } + if f&os.O_TRUNC == os.O_TRUNC { + out |= ssh_FXF_TRUNC + } + if f&os.O_EXCL == os.O_EXCL { + out |= ssh_FXF_EXCL + } + return out +} diff --git a/github.com/pkg/sftp/client_integration_darwin_test.go b/github.com/pkg/sftp/client_integration_darwin_test.go new file mode 100644 index 0000000000..6c72536f31 --- /dev/null +++ b/github.com/pkg/sftp/client_integration_darwin_test.go @@ -0,0 +1,42 @@ +package sftp + +import ( + "syscall" + "testing" +) + +const sftpServer = "/usr/libexec/sftp-server" + +func TestClientStatVFS(t *testing.T) { + if *testServerImpl { + t.Skipf("go server does not support FXP_EXTENDED") + } + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + vfs, err := sftp.StatVFS("/") + if err != nil { + t.Fatal(err) + } + + // get system stats + s := syscall.Statfs_t{} + err = syscall.Statfs("/", &s) + if err != nil { + t.Fatal(err) + } + + // check some stats + if vfs.Files != uint64(s.Files) { + t.Fatal("fr_size does not match") + } + + if vfs.Bfree != uint64(s.Bfree) { + t.Fatal("f_bsize does not match") + } + + if vfs.Favail != uint64(s.Ffree) { + t.Fatal("f_namemax does not match") + } +} diff --git a/github.com/pkg/sftp/client_integration_linux_test.go b/github.com/pkg/sftp/client_integration_linux_test.go new file mode 100644 index 0000000000..1517998e2d --- /dev/null +++ b/github.com/pkg/sftp/client_integration_linux_test.go @@ -0,0 +1,42 @@ +package sftp + +import ( + "syscall" + "testing" +) + +const sftpServer = "/usr/lib/openssh/sftp-server" + +func TestClientStatVFS(t *testing.T) { + if *testServerImpl { + t.Skipf("go server does not support FXP_EXTENDED") + } + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + vfs, err := sftp.StatVFS("/") + if err != nil { + t.Fatal(err) + } + + // get system stats + s := syscall.Statfs_t{} + err = syscall.Statfs("/", &s) + if err != nil { + t.Fatal(err) + } + + // check some stats + if vfs.Frsize != uint64(s.Frsize) { + t.Fatalf("fr_size does not match, expected: %v, got: %v", s.Frsize, vfs.Frsize) + } + + if vfs.Bsize != uint64(s.Bsize) { + t.Fatalf("f_bsize does not match, expected: %v, got: %v", s.Bsize, vfs.Bsize) + } + + if vfs.Namemax != uint64(s.Namelen) { + t.Fatalf("f_namemax does not match, expected: %v, got: %v", s.Namelen, vfs.Namemax) + } +} diff --git a/github.com/pkg/sftp/client_integration_test.go b/github.com/pkg/sftp/client_integration_test.go new file mode 100644 index 0000000000..ef73a08c1d --- /dev/null +++ b/github.com/pkg/sftp/client_integration_test.go @@ -0,0 +1,2167 @@ +package sftp + +// sftp integration tests +// enable with -integration + +import ( + "bytes" + "crypto/sha1" + "encoding" + "errors" + "flag" + "io" + "io/ioutil" + "math/rand" + "net" + "os" + "os/exec" + "os/user" + "path" + "path/filepath" + "reflect" + "regexp" + "strconv" + "testing" + "testing/quick" + "time" + + "sort" + + "github.com/kr/fs" +) + +const ( + READONLY = true + READWRITE = false + NO_DELAY time.Duration = 0 + + debuglevel = "ERROR" // set to "DEBUG" for debugging +) + +var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") +var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") +var testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary") + +type delayedWrite struct { + t time.Time + b []byte +} + +// delayedWriter wraps a writer and artificially delays the write. This is +// meant to mimic connections with various latencies. Error's returned from the +// underlying writer will panic so this should only be used over reliable +// connections. +type delayedWriter struct { + w io.WriteCloser + ch chan delayedWrite + closed chan struct{} +} + +func newDelayedWriter(w io.WriteCloser, delay time.Duration) io.WriteCloser { + ch := make(chan delayedWrite, 128) + closed := make(chan struct{}) + go func() { + for writeMsg := range ch { + time.Sleep(writeMsg.t.Add(delay).Sub(time.Now())) + n, err := w.Write(writeMsg.b) + if err != nil { + panic("write error") + } + if n < len(writeMsg.b) { + panic("showrt write") + } + } + w.Close() + close(closed) + }() + return delayedWriter{w: w, ch: ch, closed: closed} +} + +func (w delayedWriter) Write(b []byte) (int, error) { + bcopy := make([]byte, len(b)) + copy(bcopy, b) + w.ch <- delayedWrite{t: time.Now(), b: bcopy} + return len(b), nil +} + +func (w delayedWriter) Close() error { + close(w.ch) + <-w.closed + return nil +} + +// netPipe provides a pair of io.ReadWriteClosers connected to each other. +// The functions is identical to os.Pipe with the exception that netPipe +// provides the Read/Close guarantees that os.File derrived pipes do not. +func netPipe(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { + type result struct { + net.Conn + error + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + ch := make(chan result, 1) + go func() { + conn, err := l.Accept() + ch <- result{conn, err} + err = l.Close() + if err != nil { + t.Error(err) + } + }() + c1, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + l.Close() // might cause another in the listening goroutine, but too bad + t.Fatal(err) + } + r := <-ch + if r.error != nil { + t.Fatal(err) + } + return c1, r.Conn +} + +func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, *exec.Cmd) { + c1, c2 := netPipe(t) + + options := []ServerOption{WithDebug(os.Stderr)} + if readonly { + options = append(options, ReadOnly()) + } + + server, err := NewServer(c1, options...) + if err != nil { + t.Fatal(err) + } + go server.Serve() + + var ctx io.WriteCloser = c2 + if delay > NO_DELAY { + ctx = newDelayedWriter(ctx, delay) + } + + client, err := NewClientPipe(c2, ctx) + if err != nil { + t.Fatal(err) + } + + // dummy command... + return client, exec.Command("true") +} + +// testClient returns a *Client connected to a localy running sftp-server +// the *exec.Cmd returned must be defer Wait'd. +func testClient(t testing.TB, readonly bool, delay time.Duration) (*Client, *exec.Cmd) { + if !*testIntegration { + t.Skip("skipping intergration test") + } + + if *testServerImpl { + return testClientGoSvr(t, readonly, delay) + } + + cmd := exec.Command(*testSftp, "-e", "-R", "-l", debuglevel) // log to stderr, read only + if !readonly { + cmd = exec.Command(*testSftp, "-e", "-l", debuglevel) // log to stderr + } + cmd.Stderr = os.Stdout + pw, err := cmd.StdinPipe() + if err != nil { + t.Fatal(err) + } + if delay > NO_DELAY { + pw = newDelayedWriter(pw, delay) + } + pr, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Skipf("could not start sftp-server process: %v", err) + } + + sftp, err := NewClientPipe(pr, pw) + if err != nil { + t.Fatal(err) + } + + return sftp, cmd +} + +func TestNewClient(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + + if err := sftp.Close(); err != nil { + t.Fatal(err) + } +} + +func TestClientLstat(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + want, err := os.Lstat(f.Name()) + if err != nil { + t.Fatal(err) + } + + got, err := sftp.Lstat(f.Name()) + if err != nil { + t.Fatal(err) + } + + if !sameFile(want, got) { + t.Fatalf("Lstat(%q): want %#v, got %#v", f.Name(), want, got) + } +} + +func TestClientLstatIsNotExist(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + os.Remove(f.Name()) + + if _, err := sftp.Lstat(f.Name()); !os.IsNotExist(err) { + t.Errorf("os.IsNotExist(%v) = false, want true", err) + } +} + +func TestClientMkdir(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + dir, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + sub := path.Join(dir, "mkdir1") + if err := sftp.Mkdir(sub); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(sub); err != nil { + t.Fatal(err) + } +} + +func TestClientOpen(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + got, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + if err := got.Close(); err != nil { + t.Fatal(err) + } +} + +func TestClientOpenIsNotExist(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + if _, err := sftp.Open("/doesnt/exist/"); !os.IsNotExist(err) { + t.Errorf("os.IsNotExist(%v) = false, want true", err) + } +} + +func TestClientStatIsNotExist(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + if _, err := sftp.Stat("/doesnt/exist/"); !os.IsNotExist(err) { + t.Errorf("os.IsNotExist(%v) = false, want true", err) + } +} + +const seekBytes = 128 * 1024 + +type seek struct { + offset int64 +} + +func (s seek) Generate(r *rand.Rand, _ int) reflect.Value { + s.offset = int64(r.Int31n(seekBytes)) + return reflect.ValueOf(s) +} + +func (s seek) set(t *testing.T, r io.ReadSeeker) { + if _, err := r.Seek(s.offset, os.SEEK_SET); err != nil { + t.Fatalf("error while seeking with %+v: %v", s, err) + } +} + +func (s seek) current(t *testing.T, r io.ReadSeeker) { + const mid = seekBytes / 2 + + skip := s.offset / 2 + if s.offset > mid { + skip = -skip + } + + if _, err := r.Seek(mid, os.SEEK_SET); err != nil { + t.Fatalf("error seeking to midpoint with %+v: %v", s, err) + } + if _, err := r.Seek(skip, os.SEEK_CUR); err != nil { + t.Fatalf("error seeking from %d with %+v: %v", mid, s, err) + } +} + +func (s seek) end(t *testing.T, r io.ReadSeeker) { + if _, err := r.Seek(-s.offset, os.SEEK_END); err != nil { + t.Fatalf("error seeking from end with %+v: %v", s, err) + } +} + +func TestClientSeek(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + fOS, err := ioutil.TempFile("", "seek-test") + if err != nil { + t.Fatal(err) + } + defer fOS.Close() + + fSFTP, err := sftp.Open(fOS.Name()) + if err != nil { + t.Fatal(err) + } + defer fSFTP.Close() + + writeN(t, fOS, seekBytes) + + if err := quick.CheckEqual( + func(s seek) (string, int64) { s.set(t, fOS); return readHash(t, fOS) }, + func(s seek) (string, int64) { s.set(t, fSFTP); return readHash(t, fSFTP) }, + nil, + ); err != nil { + t.Errorf("Seek: expected equal absolute seeks: %v", err) + } + + if err := quick.CheckEqual( + func(s seek) (string, int64) { s.current(t, fOS); return readHash(t, fOS) }, + func(s seek) (string, int64) { s.current(t, fSFTP); return readHash(t, fSFTP) }, + nil, + ); err != nil { + t.Errorf("Seek: expected equal seeks from middle: %v", err) + } + + if err := quick.CheckEqual( + func(s seek) (string, int64) { s.end(t, fOS); return readHash(t, fOS) }, + func(s seek) (string, int64) { s.end(t, fSFTP); return readHash(t, fSFTP) }, + nil, + ); err != nil { + t.Errorf("Seek: expected equal seeks from end: %v", err) + } +} + +func TestClientCreate(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err != nil { + t.Fatal(err) + } + defer f2.Close() +} + +func TestClientAppend(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + f2, err := sftp.OpenFile(f.Name(), os.O_RDWR|os.O_APPEND) + if err != nil { + t.Fatal(err) + } + defer f2.Close() +} + +func TestClientCreateFailed(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err1, ok := err.(*StatusError); !ok || err1.Code != ssh_FX_PERMISSION_DENIED { + t.Fatalf("Create: want: %v, got %#v", ssh_FX_PERMISSION_DENIED, err) + } + if err == nil { + f2.Close() + } +} + +func TestClientFileName(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + f2, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + + if got, want := f2.Name(), f.Name(); got != want { + t.Fatalf("Name: got %q want %q", want, got) + } +} + +func TestClientFileStat(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + want, err := os.Lstat(f.Name()) + if err != nil { + t.Fatal(err) + } + + f2, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + + got, err := f2.Stat() + if err != nil { + t.Fatal(err) + } + + if !sameFile(want, got) { + t.Fatalf("Lstat(%q): want %#v, got %#v", f.Name(), want, got) + } +} + +func TestClientStatLink(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + realName := f.Name() + linkName := f.Name() + ".softlink" + + // create a symlink that points at sftptest + if err := os.Symlink(realName, linkName); err != nil { + t.Fatal(err) + } + defer os.Remove(linkName) + + // compare Lstat of links + wantLstat, err := os.Lstat(linkName) + if err != nil { + t.Fatal(err) + } + wantStat, err := os.Stat(linkName) + if err != nil { + t.Fatal(err) + } + + gotLstat, err := sftp.Lstat(linkName) + if err != nil { + t.Fatal(err) + } + gotStat, err := sftp.Stat(linkName) + if err != nil { + t.Fatal(err) + } + + // check that stat is not lstat from os package + if sameFile(wantLstat, wantStat) { + t.Fatalf("Lstat / Stat(%q): both %#v %#v", f.Name(), wantLstat, wantStat) + } + + // compare Lstat of links + if !sameFile(wantLstat, gotLstat) { + t.Fatalf("Lstat(%q): want %#v, got %#v", f.Name(), wantLstat, gotLstat) + } + + // compare Stat of links + if !sameFile(wantStat, gotStat) { + t.Fatalf("Stat(%q): want %#v, got %#v", f.Name(), wantStat, gotStat) + } + + // check that stat is not lstat + if sameFile(gotLstat, gotStat) { + t.Fatalf("Lstat / Stat(%q): both %#v %#v", f.Name(), gotLstat, gotStat) + } +} + +func TestClientRemove(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Remove(f.Name()); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(f.Name()); !os.IsNotExist(err) { + t.Fatal(err) + } +} + +func TestClientRemoveDir(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + dir, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Remove(dir); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(dir); !os.IsNotExist(err) { + t.Fatal(err) + } +} + +func TestClientRemoveFailed(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Remove(f.Name()); err == nil { + t.Fatalf("Remove(%v): want: permission denied, got %v", f.Name(), err) + } + if _, err := os.Lstat(f.Name()); err != nil { + t.Fatal(err) + } +} + +func TestClientRename(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + f2 := f.Name() + ".new" + if err := sftp.Rename(f.Name(), f2); err != nil { + t.Fatal(err) + } + if _, err := os.Lstat(f.Name()); !os.IsNotExist(err) { + t.Fatal(err) + } + if _, err := os.Lstat(f2); err != nil { + t.Fatal(err) + } +} + +func TestClientGetwd(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + lwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + rwd, err := sftp.Getwd() + if err != nil { + t.Fatal(err) + } + if !path.IsAbs(rwd) { + t.Fatalf("Getwd: wanted absolute path, got %q", rwd) + } + if lwd != rwd { + t.Fatalf("Getwd: want %q, got %q", lwd, rwd) + } +} + +func TestClientReadLink(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + f2 := f.Name() + ".sym" + if err := os.Symlink(f.Name(), f2); err != nil { + t.Fatal(err) + } + if rl, err := sftp.ReadLink(f2); err != nil { + t.Fatal(err) + } else if rl != f.Name() { + t.Fatalf("unexpected link target: %v, not %v", rl, f.Name()) + } +} + +func TestClientSymlink(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + f2 := f.Name() + ".sym" + if err := sftp.Symlink(f.Name(), f2); err != nil { + t.Fatal(err) + } + if rl, err := sftp.ReadLink(f2); err != nil { + t.Fatal(err) + } else if rl != f.Name() { + t.Fatalf("unexpected link target: %v, not %v", rl, f.Name()) + } +} + +func TestClientChmod(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Chmod(f.Name(), 0531); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(f.Name()); err != nil { + t.Fatal(err) + } else if stat.Mode()&os.ModePerm != 0531 { + t.Fatalf("invalid perm %o\n", stat.Mode()) + } +} + +func TestClientChmodReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Chmod(f.Name(), 0531); err == nil { + t.Fatal("expected error") + } +} + +func TestClientChown(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + chownto, err := user.Lookup("daemon") // seems common-ish... + if err != nil { + t.Fatal(err) + } + + if usr.Uid != "0" { + t.Log("must be root to run chown tests") + t.Skip() + } + toUID, err := strconv.Atoi(chownto.Uid) + if err != nil { + t.Fatal(err) + } + toGID, err := strconv.Atoi(chownto.Gid) + if err != nil { + t.Fatal(err) + } + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + before, err := exec.Command("ls", "-nl", f.Name()).Output() + if err != nil { + t.Fatal(err) + } + if err := sftp.Chown(f.Name(), toUID, toGID); err != nil { + t.Fatal(err) + } + after, err := exec.Command("ls", "-nl", f.Name()).Output() + if err != nil { + t.Fatal(err) + } + + spaceRegex := regexp.MustCompile(`\s+`) + + beforeWords := spaceRegex.Split(string(before), -1) + if beforeWords[2] != "0" { + t.Fatalf("bad previous user? should be root") + } + afterWords := spaceRegex.Split(string(after), -1) + if afterWords[2] != chownto.Uid || afterWords[3] != chownto.Gid { + t.Fatalf("bad chown: %#v", afterWords) + } + t.Logf("before: %v", string(before)) + t.Logf(" after: %v", string(after)) +} + +func TestClientChownReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + chownto, err := user.Lookup("daemon") // seems common-ish... + if err != nil { + t.Fatal(err) + } + + if usr.Uid != "0" { + t.Log("must be root to run chown tests") + t.Skip() + } + toUID, err := strconv.Atoi(chownto.Uid) + if err != nil { + t.Fatal(err) + } + toGID, err := strconv.Atoi(chownto.Gid) + if err != nil { + t.Fatal(err) + } + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Chown(f.Name(), toUID, toGID); err == nil { + t.Fatal("expected error") + } +} + +func TestClientChtimes(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + + atime := time.Date(2013, 2, 23, 13, 24, 35, 0, time.UTC) + mtime := time.Date(1985, 6, 12, 6, 6, 6, 0, time.UTC) + if err := sftp.Chtimes(f.Name(), atime, mtime); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(f.Name()); err != nil { + t.Fatal(err) + } else if stat.ModTime().Sub(mtime) != 0 { + t.Fatalf("incorrect mtime: %v vs %v", stat.ModTime(), mtime) + } +} + +func TestClientChtimesReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + + atime := time.Date(2013, 2, 23, 13, 24, 35, 0, time.UTC) + mtime := time.Date(1985, 6, 12, 6, 6, 6, 0, time.UTC) + if err := sftp.Chtimes(f.Name(), atime, mtime); err == nil { + t.Fatal("expected error") + } +} + +func TestClientTruncate(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + fname := f.Name() + + if n, err := f.Write([]byte("hello world")); n != 11 || err != nil { + t.Fatal(err) + } + f.Close() + + if err := sftp.Truncate(fname, 5); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(fname); err != nil { + t.Fatal(err) + } else if stat.Size() != 5 { + t.Fatalf("unexpected size: %d", stat.Size()) + } +} + +func TestClientTruncateReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + fname := f.Name() + + if n, err := f.Write([]byte("hello world")); n != 11 || err != nil { + t.Fatal(err) + } + f.Close() + + if err := sftp.Truncate(fname, 5); err == nil { + t.Fatal("expected error") + } + if stat, err := os.Stat(fname); err != nil { + t.Fatal(err) + } else if stat.Size() != 11 { + t.Fatalf("unexpected size: %d", stat.Size()) + } +} + +func sameFile(want, got os.FileInfo) bool { + return want.Name() == got.Name() && + want.Size() == got.Size() +} + +func TestClientReadSimple(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f, err := ioutil.TempFile(d, "read-test") + if err != nil { + t.Fatal(err) + } + fname := f.Name() + f.Write([]byte("hello")) + f.Close() + + f2, err := sftp.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + stuff := make([]byte, 32) + n, err := f2.Read(stuff) + if err != nil && err != io.EOF { + t.Fatalf("err: %v", err) + } + if n != 5 { + t.Fatalf("n should be 5, is %v", n) + } + if string(stuff[0:5]) != "hello" { + t.Fatalf("invalid contents") + } +} + +func TestClientReadDir(t *testing.T) { + sftp1, cmd1 := testClient(t, READONLY, NO_DELAY) + sftp2, cmd2 := testClientGoSvr(t, READONLY, NO_DELAY) + defer cmd1.Wait() + defer cmd2.Wait() + defer sftp1.Close() + defer sftp2.Close() + + dir := "/dev/" + + d, err := os.Open(dir) + if err != nil { + t.Fatal(err) + } + defer d.Close() + osfiles, err := d.Readdir(4096) + if err != nil { + t.Fatal(err) + } + + sftp1Files, err := sftp1.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + sftp2Files, err := sftp2.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + + osFilesByName := map[string]os.FileInfo{} + for _, f := range osfiles { + osFilesByName[f.Name()] = f + } + sftp1FilesByName := map[string]os.FileInfo{} + for _, f := range sftp1Files { + sftp1FilesByName[f.Name()] = f + } + sftp2FilesByName := map[string]os.FileInfo{} + for _, f := range sftp2Files { + sftp2FilesByName[f.Name()] = f + } + + if len(osFilesByName) != len(sftp1FilesByName) || len(sftp1FilesByName) != len(sftp2FilesByName) { + t.Fatalf("os gives %v, sftp1 gives %v, sftp2 gives %v", len(osFilesByName), len(sftp1FilesByName), len(sftp2FilesByName)) + } + + for name, osF := range osFilesByName { + sftp1F, ok := sftp1FilesByName[name] + if !ok { + t.Fatalf("%v present in os but not sftp1", name) + } + sftp2F, ok := sftp2FilesByName[name] + if !ok { + t.Fatalf("%v present in os but not sftp2", name) + } + + //t.Logf("%v: %v %v %v", name, osF, sftp1F, sftp2F) + if osF.Size() != sftp1F.Size() || sftp1F.Size() != sftp2F.Size() { + t.Fatalf("size %v %v %v", osF.Size(), sftp1F.Size(), sftp2F.Size()) + } + if osF.IsDir() != sftp1F.IsDir() || sftp1F.IsDir() != sftp2F.IsDir() { + t.Fatalf("isdir %v %v %v", osF.IsDir(), sftp1F.IsDir(), sftp2F.IsDir()) + } + if osF.ModTime().Sub(sftp1F.ModTime()) > time.Second || sftp1F.ModTime() != sftp2F.ModTime() { + t.Fatalf("modtime %v %v %v", osF.ModTime(), sftp1F.ModTime(), sftp2F.ModTime()) + } + if osF.Mode() != sftp1F.Mode() || sftp1F.Mode() != sftp2F.Mode() { + t.Fatalf("mode %x %x %x", osF.Mode(), sftp1F.Mode(), sftp2F.Mode()) + } + } +} + +var clientReadTests = []struct { + n int64 +}{ + {0}, + {1}, + {1000}, + {1024}, + {1025}, + {2048}, + {4096}, + {1 << 12}, + {1 << 13}, + {1 << 14}, + {1 << 15}, + {1 << 16}, + {1 << 17}, + {1 << 18}, + {1 << 19}, + {1 << 20}, +} + +func TestClientRead(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + for _, tt := range clientReadTests { + f, err := ioutil.TempFile(d, "read-test") + if err != nil { + t.Fatal(err) + } + defer f.Close() + hash := writeN(t, f, tt.n) + f2, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + hash2, n := readHash(t, f2) + if hash != hash2 || tt.n != n { + t.Errorf("Read: hash: want: %q, got %q, read: want: %v, got %v", hash, hash2, tt.n, n) + } + } +} + +// readHash reads r until EOF returning the number of bytes read +// and the hash of the contents. +func readHash(t *testing.T, r io.Reader) (string, int64) { + h := sha1.New() + tr := io.TeeReader(r, h) + read, err := io.Copy(ioutil.Discard, tr) + if err != nil { + t.Fatal(err) + } + return string(h.Sum(nil)), read +} + +// writeN writes n bytes of random data to w and returns the +// hash of that data. +func writeN(t *testing.T, w io.Writer, n int64) string { + rand, err := os.Open("/dev/urandom") + if err != nil { + t.Fatal(err) + } + defer rand.Close() + + h := sha1.New() + + mw := io.MultiWriter(w, h) + + written, err := io.CopyN(mw, rand, n) + if err != nil { + t.Fatal(err) + } + if written != n { + t.Fatalf("CopyN(%v): wrote: %v", n, written) + } + return string(h.Sum(nil)) +} + +var clientWriteTests = []struct { + n int + total int64 // cumulative file size +}{ + {0, 0}, + {1, 1}, + {0, 1}, + {999, 1000}, + {24, 1024}, + {1023, 2047}, + {2048, 4095}, + {1 << 12, 8191}, + {1 << 13, 16383}, + {1 << 14, 32767}, + {1 << 15, 65535}, + {1 << 16, 131071}, + {1 << 17, 262143}, + {1 << 18, 524287}, + {1 << 19, 1048575}, + {1 << 20, 2097151}, + {1 << 21, 4194303}, +} + +func TestClientWrite(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + for _, tt := range clientWriteTests { + got, err := w.Write(make([]byte, tt.n)) + if err != nil { + t.Fatal(err) + } + if got != tt.n { + t.Errorf("Write(%v): wrote: want: %v, got %v", tt.n, tt.n, got) + } + fi, err := os.Stat(f) + if err != nil { + t.Fatal(err) + } + if total := fi.Size(); total != tt.total { + t.Errorf("Write(%v): size: want: %v, got %v", tt.n, tt.total, total) + } + } +} + +// ReadFrom is basically Write with io.Reader as the arg +func TestClientReadFrom(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + for _, tt := range clientWriteTests { + got, err := w.ReadFrom(bytes.NewReader(make([]byte, tt.n))) + if err != nil { + t.Fatal(err) + } + if got != int64(tt.n) { + t.Errorf("Write(%v): wrote: want: %v, got %v", tt.n, tt.n, got) + } + fi, err := os.Stat(f) + if err != nil { + t.Fatal(err) + } + if total := fi.Size(); total != tt.total { + t.Errorf("Write(%v): size: want: %v, got %v", tt.n, tt.total, total) + } + } +} + +// Issue #145 in github +// Deadlock in ReadFrom when network drops after 1 good packet. +// Deadlock would occur anytime desiredInFlight-inFlight==2 and 2 errors +// occured in a row. The channel to report the errors only had a buffer +// of 1 and 2 would be sent. +var fakeNetErr = errors.New("Fake network issue") + +func TestClientReadFromDeadlock(t *testing.T) { + clientWriteDeadlock(t, 1, func(f *File) { + b := make([]byte, 32768*4) + content := bytes.NewReader(b) + n, err := f.ReadFrom(content) + if n != 0 { + t.Fatal("Write should return 0", n) + } + if err != fakeNetErr { + t.Fatal("Didn't recieve correct error", err) + } + }) +} + +// Write has exact same problem +func TestClientWriteDeadlock(t *testing.T) { + clientWriteDeadlock(t, 1, func(f *File) { + b := make([]byte, 32768*4) + n, err := f.Write(b) + if n != 0 { + t.Fatal("Write should return 0", n) + } + if err != fakeNetErr { + t.Fatal("Didn't recieve correct error", err) + } + }) +} + +// shared body for both previous tests +func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) { + if !*testServerImpl { + t.Skipf("skipping without -testserver") + } + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + // Override sendPacket with failing version + // Replicates network error/drop part way through (after 1 good packet) + count := 0 + sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error { + count++ + if count > N { + return fakeNetErr + } + return sendPacket(w, m) + } + sftp.clientConn.conn.sendPacketTest = sendPacketTest + defer func() { + sftp.clientConn.conn.sendPacketTest = nil + }() + + // this locked (before the fix) + badfunc(w) +} + +// Read/WriteTo has this issue as well +func TestClientReadDeadlock(t *testing.T) { + clientReadDeadlock(t, 1, func(f *File) { + b := make([]byte, 32768*4) + n, err := f.Read(b) + if n != 0 { + t.Fatal("Write should return 0", n) + } + if err != fakeNetErr { + t.Fatal("Didn't recieve correct error", err) + } + }) +} + +func TestClientWriteToDeadlock(t *testing.T) { + clientReadDeadlock(t, 2, func(f *File) { + b := make([]byte, 32768*4) + buf := bytes.NewBuffer(b) + n, err := f.WriteTo(buf) + if n != 32768 { + t.Fatal("Write should return 0", n) + } + if err != fakeNetErr { + t.Fatal("Didn't recieve correct error", err) + } + }) +} + +func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) { + if !*testServerImpl { + t.Skipf("skipping without -testserver") + } + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) + } + // write the data for the read tests + b := make([]byte, 32768*4) + w.Write(b) + defer w.Close() + + // open new copy of file for read tests + r, err := sftp.Open(f) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + // Override sendPacket with failing version + // Replicates network error/drop part way through (after 1 good packet) + count := 0 + sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error { + count++ + if count > N { + return fakeNetErr + } + return sendPacket(w, m) + } + sftp.clientConn.conn.sendPacketTest = sendPacketTest + defer func() { + sftp.clientConn.conn.sendPacketTest = nil + }() + + // this locked (before the fix) + badfunc(r) +} + +// taken from github.com/kr/fs/walk_test.go + +type Node struct { + name string + entries []*Node // nil if the entry is a file + mark int +} + +var tree = &Node{ + "testdata", + []*Node{ + {"a", nil, 0}, + {"b", []*Node{}, 0}, + {"c", nil, 0}, + { + "d", + []*Node{ + {"x", nil, 0}, + {"y", []*Node{}, 0}, + { + "z", + []*Node{ + {"u", nil, 0}, + {"v", nil, 0}, + }, + 0, + }, + }, + 0, + }, + }, + 0, +} + +func walkTree(n *Node, path string, f func(path string, n *Node)) { + f(path, n) + for _, e := range n.entries { + walkTree(e, filepath.Join(path, e.name), f) + } +} + +func makeTree(t *testing.T) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.entries == nil { + fd, err := os.Create(path) + if err != nil { + t.Errorf("makeTree: %v", err) + return + } + fd.Close() + } else { + os.Mkdir(path, 0770) + } + }) +} + +func markTree(n *Node) { walkTree(n, "", func(path string, n *Node) { n.mark++ }) } + +func checkMarks(t *testing.T, report bool) { + walkTree(tree, tree.name, func(path string, n *Node) { + if n.mark != 1 && report { + t.Errorf("node %s mark = %d; expected 1", path, n.mark) + } + n.mark = 0 + }) +} + +// Assumes that each node name is unique. Good enough for a test. +// If clear is true, any incoming error is cleared before return. The errors +// are always accumulated, though. +func mark(path string, info os.FileInfo, err error, errors *[]error, clear bool) error { + if err != nil { + *errors = append(*errors, err) + if clear { + return nil + } + return err + } + name := info.Name() + walkTree(tree, tree.name, func(path string, n *Node) { + if n.name == name { + n.mark++ + } + }) + return nil +} + +func TestClientWalk(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + makeTree(t) + errors := make([]error, 0, 10) + clear := true + markFn := func(walker *fs.Walker) error { + for walker.Step() { + err := mark(walker.Path(), walker.Stat(), walker.Err(), &errors, clear) + if err != nil { + return err + } + } + return nil + } + // Expect no errors. + err := markFn(sftp.Walk(tree.name)) + if err != nil { + t.Fatalf("no error expected, found: %s", err) + } + if len(errors) != 0 { + t.Fatalf("unexpected errors: %s", errors) + } + checkMarks(t, true) + errors = errors[0:0] + + // Test permission errors. Only possible if we're not root + // and only on some file systems (AFS, FAT). To avoid errors during + // all.bash on those file systems, skip during go test -short. + if os.Getuid() > 0 && !testing.Short() { + // introduce 2 errors: chmod top-level directories to 0 + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0) + + // 3) capture errors, expect two. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + err := markFn(sftp.Walk(tree.name)) + if err != nil { + t.Fatalf("expected no error return from Walk, got %s", err) + } + if len(errors) != 2 { + t.Errorf("expected 2 errors, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, true) + errors = errors[0:0] + + // 4) capture errors, stop after first error. + // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + clear = false // error will stop processing + err = markFn(sftp.Walk(tree.name)) + if err == nil { + t.Fatalf("expected error return from Walk") + } + if len(errors) != 1 { + t.Errorf("expected 1 error, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, false) + errors = errors[0:0] + + // restore permissions + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0770) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0770) + } + + // cleanup + if err := os.RemoveAll(tree.name); err != nil { + t.Errorf("removeTree: %v", err) + } +} + +type MatchTest struct { + pattern, s string + match bool + err error +} + +var matchTests = []MatchTest{ + {"abc", "abc", true, nil}, + {"*", "abc", true, nil}, + {"*c", "abc", true, nil}, + {"a*", "a", true, nil}, + {"a*", "abc", true, nil}, + {"a*", "ab/c", false, nil}, + {"a*/b", "abc/b", true, nil}, + {"a*/b", "a/c/b", false, nil}, + {"a*b*c*d*e*/f", "axbxcxdxe/f", true, nil}, + {"a*b*c*d*e*/f", "axbxcxdxexxx/f", true, nil}, + {"a*b*c*d*e*/f", "axbxcxdxe/xxx/f", false, nil}, + {"a*b*c*d*e*/f", "axbxcxdxexxx/fff", false, nil}, + {"a*b?c*x", "abxbbxdbxebxczzx", true, nil}, + {"a*b?c*x", "abxbbxdbxebxczzy", false, nil}, + {"ab[c]", "abc", true, nil}, + {"ab[b-d]", "abc", true, nil}, + {"ab[e-g]", "abc", false, nil}, + {"ab[^c]", "abc", false, nil}, + {"ab[^b-d]", "abc", false, nil}, + {"ab[^e-g]", "abc", true, nil}, + {"a\\*b", "a*b", true, nil}, + {"a\\*b", "ab", false, nil}, + {"a?b", "a☺b", true, nil}, + {"a[^a]b", "a☺b", true, nil}, + {"a???b", "a☺b", false, nil}, + {"a[^a][^a][^a]b", "a☺b", false, nil}, + {"[a-ζ]*", "α", true, nil}, + {"*[a-ζ]", "A", false, nil}, + {"a?b", "a/b", false, nil}, + {"a*b", "a/b", false, nil}, + {"[\\]a]", "]", true, nil}, + {"[\\-]", "-", true, nil}, + {"[x\\-]", "x", true, nil}, + {"[x\\-]", "-", true, nil}, + {"[x\\-]", "z", false, nil}, + {"[\\-x]", "x", true, nil}, + {"[\\-x]", "-", true, nil}, + {"[\\-x]", "a", false, nil}, + {"[]a]", "]", false, ErrBadPattern}, + {"[-]", "-", false, ErrBadPattern}, + {"[x-]", "x", false, ErrBadPattern}, + {"[x-]", "-", false, ErrBadPattern}, + {"[x-]", "z", false, ErrBadPattern}, + {"[-x]", "x", false, ErrBadPattern}, + {"[-x]", "-", false, ErrBadPattern}, + {"[-x]", "a", false, ErrBadPattern}, + {"\\", "a", false, ErrBadPattern}, + {"[a-b-c]", "a", false, ErrBadPattern}, + {"[", "a", false, ErrBadPattern}, + {"[^", "a", false, ErrBadPattern}, + {"[^bc", "a", false, ErrBadPattern}, + {"a[", "a", false, nil}, + {"a[", "ab", false, ErrBadPattern}, + {"*x", "xxx", true, nil}, +} + +func errp(e error) string { + if e == nil { + return "" + } + return e.Error() +} + +// contains returns true if vector contains the string s. +func contains(vector []string, s string) bool { + for _, elem := range vector { + if elem == s { + return true + } + } + return false +} + +var globTests = []struct { + pattern, result string +}{ + {"match.go", "match.go"}, + {"mat?h.go", "match.go"}, + {"ma*ch.go", "match.go"}, + {"../*/match.go", "../sftp/match.go"}, +} + +type globTest struct { + pattern string + matches []string +} + +func (test *globTest) buildWant(root string) []string { + var want []string + for _, m := range test.matches { + want = append(want, root+filepath.FromSlash(m)) + } + sort.Strings(want) + return want +} + +func TestMatch(t *testing.T) { + for _, tt := range matchTests { + pattern := tt.pattern + s := tt.s + ok, err := Match(pattern, s) + if ok != tt.match || err != tt.err { + t.Errorf("Match(%#q, %#q) = %v, %q want %v, %q", pattern, s, ok, errp(err), tt.match, errp(tt.err)) + } + } +} + +func TestGlob(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + for _, tt := range globTests { + pattern := tt.pattern + result := tt.result + matches, err := sftp.Glob(pattern) + if err != nil { + t.Errorf("Glob error for %q: %s", pattern, err) + continue + } + if !contains(matches, result) { + t.Errorf("Glob(%#q) = %#v want %v", pattern, matches, result) + } + } + for _, pattern := range []string{"no_match", "../*/no_match"} { + matches, err := sftp.Glob(pattern) + if err != nil { + t.Errorf("Glob error for %q: %s", pattern, err) + continue + } + if len(matches) != 0 { + t.Errorf("Glob(%#q) = %#v want []", pattern, matches) + } + } +} + +func TestGlobError(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + _, err := sftp.Glob("[7]") + if err != nil { + t.Error("expected error for bad pattern; got none") + } +} + +func TestGlobUNC(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + // Just make sure this runs without crashing for now. + // See issue 15879. + sftp.Glob(`\\?\C:\*`) +} + +// sftp/issue/42, abrupt server hangup would result in client hangs. +func TestServerRoughDisconnect(t *testing.T) { + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := sftp.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + io.Copy(ioutil.Discard, f) +} + +// sftp/issue/181, abrupt server hangup would result in client hangs. +// due to broadcastErr filling up the request channel +// this reproduces it about 50% of the time +func TestServerRoughDisconnect2(t *testing.T) { + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := sftp.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + b := make([]byte, 32768*100) + go func() { + time.Sleep(1 * time.Millisecond) + cmd.Process.Kill() + }() + for { + _, err = f.Read(b) + if err != nil { + break + } + } +} + +// sftp/issue/26 writing to a read only file caused client to loop. +func TestClientWriteToROFile(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := sftp.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + _, err = f.Write([]byte("hello")) + if err == nil { + t.Fatal("expected error, got", err) + } +} + +func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, READONLY, delay) + defer cmd.Wait() + // defer sftp.Close() + + buf := make([]byte, bufsize) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + offset := 0 + + f2, err := sftp.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + for offset < size { + n, err := io.ReadFull(f2, buf) + offset += n + if err == io.ErrUnexpectedEOF && offset != size { + b.Fatalf("read too few bytes! want: %d, got: %d", size, n) + } + + if err != nil { + b.Fatal(err) + } + + offset += n + } + } +} + +func BenchmarkRead1k(b *testing.B) { + benchmarkRead(b, 1*1024, NO_DELAY) +} + +func BenchmarkRead16k(b *testing.B) { + benchmarkRead(b, 16*1024, NO_DELAY) +} + +func BenchmarkRead32k(b *testing.B) { + benchmarkRead(b, 32*1024, NO_DELAY) +} + +func BenchmarkRead128k(b *testing.B) { + benchmarkRead(b, 128*1024, NO_DELAY) +} + +func BenchmarkRead512k(b *testing.B) { + benchmarkRead(b, 512*1024, NO_DELAY) +} + +func BenchmarkRead1MiB(b *testing.B) { + benchmarkRead(b, 1024*1024, NO_DELAY) +} + +func BenchmarkRead4MiB(b *testing.B) { + benchmarkRead(b, 4*1024*1024, NO_DELAY) +} + +func BenchmarkRead4MiBDelay10Msec(b *testing.B) { + benchmarkRead(b, 4*1024*1024, 10*time.Millisecond) +} + +func BenchmarkRead4MiBDelay50Msec(b *testing.B) { + benchmarkRead(b, 4*1024*1024, 50*time.Millisecond) +} + +func BenchmarkRead4MiBDelay150Msec(b *testing.B) { + benchmarkRead(b, 4*1024*1024, 150*time.Millisecond) +} + +func benchmarkWrite(b *testing.B, bufsize int, delay time.Duration) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + // defer sftp.Close() + + data := make([]byte, size) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + offset := 0 + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + for offset < size { + n, err := f2.Write(data[offset:min(len(data), offset+bufsize)]) + if err != nil { + b.Fatal(err) + } + + if offset+n < size && n != bufsize { + b.Fatalf("wrote too few bytes! want: %d, got: %d", size, n) + } + + offset += n + } + + f2.Close() + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: want %d, got %d", size, fi.Size()) + } + + os.Remove(f.Name()) + } +} + +func BenchmarkWrite1k(b *testing.B) { + benchmarkWrite(b, 1*1024, NO_DELAY) +} + +func BenchmarkWrite16k(b *testing.B) { + benchmarkWrite(b, 16*1024, NO_DELAY) +} + +func BenchmarkWrite32k(b *testing.B) { + benchmarkWrite(b, 32*1024, NO_DELAY) +} + +func BenchmarkWrite128k(b *testing.B) { + benchmarkWrite(b, 128*1024, NO_DELAY) +} + +func BenchmarkWrite512k(b *testing.B) { + benchmarkWrite(b, 512*1024, NO_DELAY) +} + +func BenchmarkWrite1MiB(b *testing.B) { + benchmarkWrite(b, 1024*1024, NO_DELAY) +} + +func BenchmarkWrite4MiB(b *testing.B) { + benchmarkWrite(b, 4*1024*1024, NO_DELAY) +} + +func BenchmarkWrite4MiBDelay10Msec(b *testing.B) { + benchmarkWrite(b, 4*1024*1024, 10*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay50Msec(b *testing.B) { + benchmarkWrite(b, 4*1024*1024, 50*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay150Msec(b *testing.B) { + benchmarkWrite(b, 4*1024*1024, 150*time.Millisecond) +} + +func benchmarkReadFrom(b *testing.B, bufsize int, delay time.Duration) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + // defer sftp.Close() + + data := make([]byte, size) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + f2.ReadFrom(bytes.NewReader(data)) + f2.Close() + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: want %d, got %d", size, fi.Size()) + } + + os.Remove(f.Name()) + } +} + +func BenchmarkReadFrom1k(b *testing.B) { + benchmarkReadFrom(b, 1*1024, NO_DELAY) +} + +func BenchmarkReadFrom16k(b *testing.B) { + benchmarkReadFrom(b, 16*1024, NO_DELAY) +} + +func BenchmarkReadFrom32k(b *testing.B) { + benchmarkReadFrom(b, 32*1024, NO_DELAY) +} + +func BenchmarkReadFrom128k(b *testing.B) { + benchmarkReadFrom(b, 128*1024, NO_DELAY) +} + +func BenchmarkReadFrom512k(b *testing.B) { + benchmarkReadFrom(b, 512*1024, NO_DELAY) +} + +func BenchmarkReadFrom1MiB(b *testing.B) { + benchmarkReadFrom(b, 1024*1024, NO_DELAY) +} + +func BenchmarkReadFrom4MiB(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, NO_DELAY) +} + +func BenchmarkReadFrom4MiBDelay10Msec(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, 10*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay50Msec(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, 50*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond) +} + +func benchmarkCopyDown(b *testing.B, fileSize int64, delay time.Duration) { + // Create a temp file and fill it with zero's. + src, err := ioutil.TempFile("", "sftptest") + if err != nil { + b.Fatal(err) + } + defer src.Close() + srcFilename := src.Name() + defer os.Remove(srcFilename) + zero, err := os.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + n, err := io.Copy(src, io.LimitReader(zero, fileSize)) + if err != nil { + b.Fatal(err) + } + if n < fileSize { + b.Fatal("short copy") + } + zero.Close() + src.Close() + + sftp, cmd := testClient(b, READONLY, delay) + defer cmd.Wait() + // defer sftp.Close() + b.ResetTimer() + b.SetBytes(fileSize) + + for i := 0; i < b.N; i++ { + dst, err := ioutil.TempFile("", "sftptest") + if err != nil { + b.Fatal(err) + } + defer os.Remove(dst.Name()) + + src, err := sftp.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + n, err := io.Copy(dst, src) + if err != nil { + b.Fatalf("copy error: %v", err) + } + if n < fileSize { + b.Fatal("unable to copy all bytes") + } + dst.Close() + fi, err := os.Stat(dst.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != fileSize { + b.Fatalf("wrong file size: want %d, got %d", fileSize, fi.Size()) + } + os.Remove(dst.Name()) + } +} + +func BenchmarkCopyDown10MiBDelay10Msec(b *testing.B) { + benchmarkCopyDown(b, 10*1024*1024, 10*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay50Msec(b *testing.B) { + benchmarkCopyDown(b, 10*1024*1024, 50*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay150Msec(b *testing.B) { + benchmarkCopyDown(b, 10*1024*1024, 150*time.Millisecond) +} + +func benchmarkCopyUp(b *testing.B, fileSize int64, delay time.Duration) { + // Create a temp file and fill it with zero's. + src, err := ioutil.TempFile("", "sftptest") + if err != nil { + b.Fatal(err) + } + defer src.Close() + srcFilename := src.Name() + defer os.Remove(srcFilename) + zero, err := os.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + n, err := io.Copy(src, io.LimitReader(zero, fileSize)) + if err != nil { + b.Fatal(err) + } + if n < fileSize { + b.Fatal("short copy") + } + zero.Close() + src.Close() + + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + // defer sftp.Close() + + b.ResetTimer() + b.SetBytes(fileSize) + + for i := 0; i < b.N; i++ { + tmp, err := ioutil.TempFile("", "sftptest") + if err != nil { + b.Fatal(err) + } + tmp.Close() + defer os.Remove(tmp.Name()) + + dst, err := sftp.Create(tmp.Name()) + if err != nil { + b.Fatal(err) + } + defer dst.Close() + src, err := os.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + n, err := io.Copy(dst, src) + if err != nil { + b.Fatalf("copy error: %v", err) + } + if n < fileSize { + b.Fatal("unable to copy all bytes") + } + + fi, err := os.Stat(tmp.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != fileSize { + b.Fatalf("wrong file size: want %d, got %d", fileSize, fi.Size()) + } + os.Remove(tmp.Name()) + } +} + +func BenchmarkCopyUp10MiBDelay10Msec(b *testing.B) { + benchmarkCopyUp(b, 10*1024*1024, 10*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay50Msec(b *testing.B) { + benchmarkCopyUp(b, 10*1024*1024, 50*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay150Msec(b *testing.B) { + benchmarkCopyUp(b, 10*1024*1024, 150*time.Millisecond) +} diff --git a/github.com/pkg/sftp/client_test.go b/github.com/pkg/sftp/client_test.go new file mode 100644 index 0000000000..0c9994ce30 --- /dev/null +++ b/github.com/pkg/sftp/client_test.go @@ -0,0 +1,147 @@ +package sftp + +import ( + "errors" + "io" + "os" + "reflect" + "testing" + + "github.com/kr/fs" +) + +// assert that *Client implements fs.FileSystem +var _ fs.FileSystem = new(Client) + +// assert that *File implements io.ReadWriteCloser +var _ io.ReadWriteCloser = new(File) + +func TestNormaliseError(t *testing.T) { + var ( + ok = &StatusError{Code: ssh_FX_OK} + eof = &StatusError{Code: ssh_FX_EOF} + fail = &StatusError{Code: ssh_FX_FAILURE} + noSuchFile = &StatusError{Code: ssh_FX_NO_SUCH_FILE} + foo = errors.New("foo") + ) + + var tests = []struct { + desc string + err error + want error + }{ + { + desc: "nil error", + }, + { + desc: "not *StatusError", + err: foo, + want: foo, + }, + { + desc: "*StatusError with ssh_FX_EOF", + err: eof, + want: io.EOF, + }, + { + desc: "*StatusError with ssh_FX_NO_SUCH_FILE", + err: noSuchFile, + want: os.ErrNotExist, + }, + { + desc: "*StatusError with ssh_FX_OK", + err: ok, + }, + { + desc: "*StatusError with ssh_FX_FAILURE", + err: fail, + want: fail, + }, + } + + for _, tt := range tests { + got := normaliseError(tt.err) + if got != tt.want { + t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n- got: %#v", + tt.err, tt.desc, tt.want, got) + } + } +} + +var flagsTests = []struct { + flags int + want uint32 +}{ + {os.O_RDONLY, ssh_FXF_READ}, + {os.O_WRONLY, ssh_FXF_WRITE}, + {os.O_RDWR, ssh_FXF_READ | ssh_FXF_WRITE}, + {os.O_RDWR | os.O_CREATE | os.O_TRUNC, ssh_FXF_READ | ssh_FXF_WRITE | ssh_FXF_CREAT | ssh_FXF_TRUNC}, + {os.O_WRONLY | os.O_APPEND, ssh_FXF_WRITE | ssh_FXF_APPEND}, +} + +func TestFlags(t *testing.T) { + for i, tt := range flagsTests { + got := flags(tt.flags) + if got != tt.want { + t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) + } + } +} + +func TestUnmarshalStatus(t *testing.T) { + requestID := uint32(1) + + id := marshalUint32([]byte{}, requestID) + idCode := marshalUint32(id, ssh_FX_FAILURE) + idCodeMsg := marshalString(idCode, "err msg") + idCodeMsgLang := marshalString(idCodeMsg, "lang tag") + + var tests = []struct { + desc string + reqID uint32 + status []byte + want error + }{ + { + desc: "well-formed status", + reqID: 1, + status: idCodeMsgLang, + want: &StatusError{ + Code: ssh_FX_FAILURE, + msg: "err msg", + lang: "lang tag", + }, + }, + { + desc: "missing error message and language tag", + reqID: 1, + status: idCode, + want: &StatusError{ + Code: ssh_FX_FAILURE, + }, + }, + { + desc: "missing language tag", + reqID: 1, + status: idCodeMsg, + want: &StatusError{ + Code: ssh_FX_FAILURE, + msg: "err msg", + }, + }, + { + desc: "request identifier mismatch", + reqID: 2, + status: idCodeMsgLang, + want: &unexpectedIDErr{2, requestID}, + }, + } + + for _, tt := range tests { + got := unmarshalStatus(tt.reqID, tt.status) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unmarshalStatus(%v, %v), test %q\n- want: %#v\n- got: %#v", + requestID, tt.status, tt.desc, tt.want, got) + } + } +} diff --git a/github.com/pkg/sftp/conn.go b/github.com/pkg/sftp/conn.go new file mode 100644 index 0000000000..d9e10952d6 --- /dev/null +++ b/github.com/pkg/sftp/conn.go @@ -0,0 +1,133 @@ +package sftp + +import ( + "encoding" + "io" + "sync" + + "github.com/pkg/errors" +) + +// conn implements a bidirectional channel on which client and server +// connections are multiplexed. +type conn struct { + io.Reader + io.WriteCloser + sync.Mutex // used to serialise writes to sendPacket + // sendPacketTest is needed to replicate packet issues in testing + sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error +} + +func (c *conn) recvPacket() (uint8, []byte, error) { + return recvPacket(c) +} + +func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { + c.Lock() + defer c.Unlock() + if c.sendPacketTest != nil { + return c.sendPacketTest(c, m) + } + return sendPacket(c, m) +} + +type clientConn struct { + conn + wg sync.WaitGroup + sync.Mutex // protects inflight + inflight map[uint32]chan<- result // outstanding requests +} + +// Close closes the SFTP session. +func (c *clientConn) Close() error { + defer c.wg.Wait() + return c.conn.Close() +} + +func (c *clientConn) loop() { + defer c.wg.Done() + err := c.recv() + if err != nil { + c.broadcastErr(err) + } +} + +// recv continuously reads from the server and forwards responses to the +// appropriate channel. +func (c *clientConn) recv() error { + defer func() { + c.conn.Lock() + c.conn.Close() + c.conn.Unlock() + }() + for { + typ, data, err := c.recvPacket() + if err != nil { + return err + } + sid, _ := unmarshalUint32(data) + c.Lock() + ch, ok := c.inflight[sid] + delete(c.inflight, sid) + c.Unlock() + if !ok { + // This is an unexpected occurrence. Send the error + // back to all listeners so that they terminate + // gracefully. + return errors.Errorf("sid: %v not fond", sid) + } + ch <- result{typ: typ, data: data} + } +} + +// result captures the result of receiving the a packet from the server +type result struct { + typ byte + data []byte + err error +} + +type idmarshaler interface { + id() uint32 + encoding.BinaryMarshaler +} + +func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) { + ch := make(chan result, 1) + c.dispatchRequest(ch, p) + s := <-ch + return s.typ, s.data, s.err +} + +func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { + c.Lock() + c.inflight[p.id()] = ch + c.Unlock() + if err := c.conn.sendPacket(p); err != nil { + c.Lock() + delete(c.inflight, p.id()) + c.Unlock() + ch <- result{err: err} + } +} + +// broadcastErr sends an error to all goroutines waiting for a response. +func (c *clientConn) broadcastErr(err error) { + c.Lock() + listeners := make([]chan<- result, 0, len(c.inflight)) + for _, ch := range c.inflight { + listeners = append(listeners, ch) + } + c.Unlock() + for _, ch := range listeners { + ch <- result{err: err} + } +} + +type serverConn struct { + conn +} + +func (s *serverConn) sendError(p ider, err error) error { + return s.sendPacket(statusFromError(p, err)) +} diff --git a/github.com/pkg/sftp/debug.go b/github.com/pkg/sftp/debug.go new file mode 100644 index 0000000000..3e264abe30 --- /dev/null +++ b/github.com/pkg/sftp/debug.go @@ -0,0 +1,9 @@ +// +build debug + +package sftp + +import "log" + +func debug(fmt string, args ...interface{}) { + log.Printf(fmt, args...) +} diff --git a/github.com/pkg/sftp/example_test.go b/github.com/pkg/sftp/example_test.go new file mode 100644 index 0000000000..c6c9009a68 --- /dev/null +++ b/github.com/pkg/sftp/example_test.go @@ -0,0 +1,135 @@ +package sftp_test + +import ( + "fmt" + "log" + "os" + "os/exec" + "path" + "strings" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +func Example() { + var conn *ssh.Client + + // open an SFTP session over an existing ssh connection. + sftp, err := sftp.NewClient(conn) + if err != nil { + log.Fatal(err) + } + defer sftp.Close() + + // walk a directory + w := sftp.Walk("/home/user") + for w.Step() { + if w.Err() != nil { + continue + } + log.Println(w.Path()) + } + + // leave your mark + f, err := sftp.Create("hello.txt") + if err != nil { + log.Fatal(err) + } + if _, err := f.Write([]byte("Hello world!")); err != nil { + log.Fatal(err) + } + + // check it's there + fi, err := sftp.Lstat("hello.txt") + if err != nil { + log.Fatal(err) + } + log.Println(fi) +} + +func ExampleNewClientPipe() { + // Connect to a remote host and request the sftp subsystem via the 'ssh' + // command. This assumes that passwordless login is correctly configured. + cmd := exec.Command("ssh", "example.com", "-s", "sftp") + + // send errors from ssh to stderr + cmd.Stderr = os.Stderr + + // get stdin and stdout + wr, err := cmd.StdinPipe() + if err != nil { + log.Fatal(err) + } + rd, err := cmd.StdoutPipe() + if err != nil { + log.Fatal(err) + } + + // start the process + if err := cmd.Start(); err != nil { + log.Fatal(err) + } + defer cmd.Wait() + + // open the SFTP session + client, err := sftp.NewClientPipe(rd, wr) + if err != nil { + log.Fatal(err) + } + + // read a directory + list, err := client.ReadDir("/") + if err != nil { + log.Fatal(err) + } + + // print contents + for _, item := range list { + fmt.Println(item.Name()) + } + + // close the connection + client.Close() +} + +func ExampleClient_Mkdir_parents() { + // Example of mimicing 'mkdir --parents'; I.E. recursively create + // directoryies and don't error if any directories already exists. + var conn *ssh.Client + + client, err := sftp.NewClient(conn) + if err != nil { + log.Fatal(err) + } + defer client.Close() + + ssh_fx_failure := uint32(4) + mkdirParents := func(client *sftp.Client, dir string) (err error) { + var parents string + for _, name := range strings.Split(dir, "/") { + parents = path.Join(parents, name) + err = client.Mkdir(parents) + if status, ok := err.(*sftp.StatusError); ok { + if status.Code == ssh_fx_failure { + var fi os.FileInfo + fi, err = client.Stat(parents) + if err == nil { + if !fi.IsDir() { + return fmt.Errorf("File exists: %s", parents) + } + } + } + } + if err != nil { + break + } + } + return err + } + + err = mkdirParents(client, "/tmp/foo/bar") + if err != nil { + log.Fatal(err) + } +} diff --git a/github.com/pkg/sftp/examples/buffered-read-benchmark/main.go b/github.com/pkg/sftp/examples/buffered-read-benchmark/main.go new file mode 100644 index 0000000000..36ac6d726c --- /dev/null +++ b/github.com/pkg/sftp/examples/buffered-read-benchmark/main.go @@ -0,0 +1,78 @@ +// buffered-read-benchmark benchmarks the peformance of reading +// from /dev/zero on the server to a []byte on the client via io.Copy. +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + r, err := c.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer r.Close() + + const size = 1e9 + + log.Printf("reading %v bytes", size) + t1 := time.Now() + n, err := io.ReadFull(r, make([]byte, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("read %v bytes in %s", size, time.Since(t1)) +} diff --git a/github.com/pkg/sftp/examples/buffered-write-benchmark/main.go b/github.com/pkg/sftp/examples/buffered-write-benchmark/main.go new file mode 100644 index 0000000000..d1babedb2d --- /dev/null +++ b/github.com/pkg/sftp/examples/buffered-write-benchmark/main.go @@ -0,0 +1,84 @@ +// buffered-write-benchmark benchmarks the peformance of writing +// a single large []byte on the client to /dev/null on the server via io.Copy. +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "syscall" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + w, err := c.OpenFile("/dev/null", syscall.O_WRONLY) + if err != nil { + log.Fatal(err) + } + defer w.Close() + + f, err := os.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer f.Close() + + const size = 1e9 + + log.Printf("writing %v bytes", size) + t1 := time.Now() + n, err := w.Write(make([]byte, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("wrote %v bytes in %s", size, time.Since(t1)) +} diff --git a/github.com/pkg/sftp/examples/request-server/main.go b/github.com/pkg/sftp/examples/request-server/main.go new file mode 100644 index 0000000000..fd21b43eaf --- /dev/null +++ b/github.com/pkg/sftp/examples/request-server/main.go @@ -0,0 +1,131 @@ +// An example SFTP server implementation using the golang SSH package. +// Serves the whole filesystem visible to the user, and has a hard-coded username and password, +// so not for real use! +package main + +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +// Based on example server code from golang.org/x/crypto/ssh and server_standalone +func main() { + + var ( + readOnly bool + debugStderr bool + ) + + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.Parse() + + debugStream := ioutil.Discard + if debugStderr { + debugStream = os.Stderr + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + fmt.Fprintf(debugStream, "Login: %s\n", c.User()) + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := ioutil.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key", err) + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection", err) + } + fmt.Printf("Listening on %v\n", listener.Addr()) + + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection", err) + } + + // Before use, a handshake must be performed on the incoming net.Conn. + sconn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake", err) + } + log.Println("login detected:", sconn.User()) + fmt.Fprintf(debugStream, "SSH server established\n") + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of an SFTP session, this is "subsystem" + // with a payload string of "sftp" + fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType()) + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType()) + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatal("could not accept channel.", err) + } + fmt.Fprintf(debugStream, "Channel accepted\n") + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "subsystem" request. + go func(in <-chan *ssh.Request) { + for req := range in { + fmt.Fprintf(debugStream, "Request: %v\n", req.Type) + ok := false + switch req.Type { + case "subsystem": + fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:]) + if string(req.Payload[4:]) == "sftp" { + ok = true + } + } + fmt.Fprintf(debugStream, " - accepted: %v\n", ok) + req.Reply(ok, nil) + } + }(requests) + + root := sftp.InMemHandler() + server := sftp.NewRequestServer(channel, root) + if err := server.Serve(); err == io.EOF { + server.Close() + log.Print("sftp client exited session.") + } else if err != nil { + log.Fatal("sftp server completed with error:", err) + } + } +} diff --git a/github.com/pkg/sftp/examples/sftp-server/README.md b/github.com/pkg/sftp/examples/sftp-server/README.md new file mode 100644 index 0000000000..bd96f2d8ab --- /dev/null +++ b/github.com/pkg/sftp/examples/sftp-server/README.md @@ -0,0 +1,12 @@ +Example SFTP server implementation +=== + +In order to use this example you will need an RSA key. + +On linux-like systems with openssh installed, you can use the command: + +``` +ssh-keygen -t rsa -f id_rsa +``` + +Then you will be able to run the sftp-server command in the current directory. diff --git a/github.com/pkg/sftp/examples/sftp-server/main.go b/github.com/pkg/sftp/examples/sftp-server/main.go new file mode 100644 index 0000000000..48e0e8684f --- /dev/null +++ b/github.com/pkg/sftp/examples/sftp-server/main.go @@ -0,0 +1,147 @@ +// An example SFTP server implementation using the golang SSH package. +// Serves the whole filesystem visible to the user, and has a hard-coded username and password, +// so not for real use! +package main + +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +// Based on example server code from golang.org/x/crypto/ssh and server_standalone +func main() { + + var ( + readOnly bool + debugStderr bool + ) + + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.Parse() + + debugStream := ioutil.Discard + if debugStderr { + debugStream = os.Stderr + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + fmt.Fprintf(debugStream, "Login: %s\n", c.User()) + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := ioutil.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key", err) + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection", err) + } + fmt.Printf("Listening on %v\n", listener.Addr()) + + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake", err) + } + fmt.Fprintf(debugStream, "SSH server established\n") + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of an SFTP session, this is "subsystem" + // with a payload string of "sftp" + fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType()) + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType()) + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatal("could not accept channel.", err) + } + fmt.Fprintf(debugStream, "Channel accepted\n") + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "subsystem" request. + go func(in <-chan *ssh.Request) { + for req := range in { + fmt.Fprintf(debugStream, "Request: %v\n", req.Type) + ok := false + switch req.Type { + case "subsystem": + fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:]) + if string(req.Payload[4:]) == "sftp" { + ok = true + } + } + fmt.Fprintf(debugStream, " - accepted: %v\n", ok) + req.Reply(ok, nil) + } + }(requests) + + serverOptions := []sftp.ServerOption{ + sftp.WithDebug(debugStream), + } + + if readOnly { + serverOptions = append(serverOptions, sftp.ReadOnly()) + fmt.Fprintf(debugStream, "Read-only server\n") + } else { + fmt.Fprintf(debugStream, "Read write server\n") + } + + server, err := sftp.NewServer( + channel, + serverOptions..., + ) + if err != nil { + log.Fatal(err) + } + if err := server.Serve(); err == io.EOF { + server.Close() + log.Print("sftp client exited session.") + } else if err != nil { + log.Fatal("sftp server completed with error:", err) + } + } +} diff --git a/github.com/pkg/sftp/examples/streaming-read-benchmark/main.go b/github.com/pkg/sftp/examples/streaming-read-benchmark/main.go new file mode 100644 index 0000000000..87afc5a324 --- /dev/null +++ b/github.com/pkg/sftp/examples/streaming-read-benchmark/main.go @@ -0,0 +1,85 @@ +// streaming-read-benchmark benchmarks the peformance of reading +// from /dev/zero on the server to /dev/null on the client via io.Copy. +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "syscall" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + r, err := c.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer r.Close() + + w, err := os.OpenFile("/dev/null", syscall.O_WRONLY, 0600) + if err != nil { + log.Fatal(err) + } + defer w.Close() + + const size int64 = 1e9 + + log.Printf("reading %v bytes", size) + t1 := time.Now() + n, err := io.Copy(w, io.LimitReader(r, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("read %v bytes in %s", size, time.Since(t1)) +} diff --git a/github.com/pkg/sftp/examples/streaming-write-benchmark/main.go b/github.com/pkg/sftp/examples/streaming-write-benchmark/main.go new file mode 100644 index 0000000000..8f432d3914 --- /dev/null +++ b/github.com/pkg/sftp/examples/streaming-write-benchmark/main.go @@ -0,0 +1,85 @@ +// streaming-write-benchmark benchmarks the peformance of writing +// from /dev/zero on the client to /dev/null on the server via io.Copy. +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "syscall" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/pkg/sftp" +) + +var ( + USER = flag.String("user", os.Getenv("USER"), "ssh username") + HOST = flag.String("host", "localhost", "ssh server hostname") + PORT = flag.Int("port", 22, "ssh server port") + PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") +) + +func init() { + flag.Parse() +} + +func main() { + var auths []ssh.AuthMethod + if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) + + } + if *PASS != "" { + auths = append(auths, ssh.Password(*PASS)) + } + + config := ssh.ClientConfig{ + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + addr := fmt.Sprintf("%s:%d", *HOST, *PORT) + conn, err := ssh.Dial("tcp", addr, &config) + if err != nil { + log.Fatalf("unable to connect to [%s]: %v", addr, err) + } + defer conn.Close() + + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) + if err != nil { + log.Fatalf("unable to start sftp subsytem: %v", err) + } + defer c.Close() + + w, err := c.OpenFile("/dev/null", syscall.O_WRONLY) + if err != nil { + log.Fatal(err) + } + defer w.Close() + + f, err := os.Open("/dev/zero") + if err != nil { + log.Fatal(err) + } + defer f.Close() + + const size int64 = 1e9 + + log.Printf("writing %v bytes", size) + t1 := time.Now() + n, err := io.Copy(w, io.LimitReader(f, size)) + if err != nil { + log.Fatal(err) + } + if n != size { + log.Fatalf("copy: expected %v bytes, got %d", size, n) + } + log.Printf("wrote %v bytes in %s", size, time.Since(t1)) +} diff --git a/github.com/pkg/sftp/match.go b/github.com/pkg/sftp/match.go new file mode 100644 index 0000000000..e2f2ba409e --- /dev/null +++ b/github.com/pkg/sftp/match.go @@ -0,0 +1,295 @@ +package sftp + +import ( + "path" + "strings" + "unicode/utf8" +) + +// ErrBadPattern indicates a globbing pattern was malformed. +var ErrBadPattern = path.ErrBadPattern + +// Unix separator +const separator = "/" + +// Match reports whether name matches the shell file name pattern. +// The pattern syntax is: +// +// pattern: +// { term } +// term: +// '*' matches any sequence of non-Separator characters +// '?' matches any single non-Separator character +// '[' [ '^' ] { character-range } ']' +// character class (must be non-empty) +// c matches character c (c != '*', '?', '\\', '[') +// '\\' c matches character c +// +// character-range: +// c matches character c (c != '\\', '-', ']') +// '\\' c matches character c +// lo '-' hi matches character c for lo <= c <= hi +// +// Match requires pattern to match all of name, not just a substring. +// The only possible returned error is ErrBadPattern, when pattern +// is malformed. +// +// +func Match(pattern, name string) (matched bool, err error) { + return path.Match(pattern, name) +} + +// detect if byte(char) is path separator +func isPathSeparator(c byte) bool { + return string(c) == "/" +} + +// scanChunk gets the next segment of pattern, which is a non-star string +// possibly preceded by a star. +func scanChunk(pattern string) (star bool, chunk, rest string) { + for len(pattern) > 0 && pattern[0] == '*' { + pattern = pattern[1:] + star = true + } + inrange := false + var i int +Scan: + for i = 0; i < len(pattern); i++ { + switch pattern[i] { + case '\\': + + // error check handled in matchChunk: bad pattern. + if i+1 < len(pattern) { + i++ + } + case '[': + inrange = true + case ']': + inrange = false + case '*': + if !inrange { + break Scan + } + } + } + return star, pattern[0:i], pattern[i:] +} + +// matchChunk checks whether chunk matches the beginning of s. +// If so, it returns the remainder of s (after the match). +// Chunk is all single-character operators: literals, char classes, and ?. +func matchChunk(chunk, s string) (rest string, ok bool, err error) { + for len(chunk) > 0 { + if len(s) == 0 { + return + } + switch chunk[0] { + case '[': + // character class + r, n := utf8.DecodeRuneInString(s) + s = s[n:] + chunk = chunk[1:] + // We can't end right after '[', we're expecting at least + // a closing bracket and possibly a caret. + if len(chunk) == 0 { + err = ErrBadPattern + return + } + // possibly negated + negated := chunk[0] == '^' + if negated { + chunk = chunk[1:] + } + // parse all ranges + match := false + nrange := 0 + for { + if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 { + chunk = chunk[1:] + break + } + var lo, hi rune + if lo, chunk, err = getEsc(chunk); err != nil { + return + } + hi = lo + if chunk[0] == '-' { + if hi, chunk, err = getEsc(chunk[1:]); err != nil { + return + } + } + if lo <= r && r <= hi { + match = true + } + nrange++ + } + if match == negated { + return + } + + case '?': + if isPathSeparator(s[0]) { + return + } + _, n := utf8.DecodeRuneInString(s) + s = s[n:] + chunk = chunk[1:] + + case '\\': + chunk = chunk[1:] + if len(chunk) == 0 { + err = ErrBadPattern + return + } + fallthrough + + default: + if chunk[0] != s[0] { + return + } + s = s[1:] + chunk = chunk[1:] + } + } + return s, true, nil +} + +// getEsc gets a possibly-escaped character from chunk, for a character class. +func getEsc(chunk string) (r rune, nchunk string, err error) { + if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' { + err = ErrBadPattern + return + } + if chunk[0] == '\\' { + chunk = chunk[1:] + if len(chunk) == 0 { + err = ErrBadPattern + return + } + } + r, n := utf8.DecodeRuneInString(chunk) + if r == utf8.RuneError && n == 1 { + err = ErrBadPattern + } + nchunk = chunk[n:] + if len(nchunk) == 0 { + err = ErrBadPattern + } + return +} + +// Split splits path immediately following the final Separator, +// separating it into a directory and file name component. +// If there is no Separator in path, Split returns an empty dir +// and file set to path. +// The returned values have the property that path = dir+file. +func Split(path string) (dir, file string) { + i := len(path) - 1 + for i >= 0 && !isPathSeparator(path[i]) { + i-- + } + return path[:i+1], path[i+1:] +} + +// Glob returns the names of all files matching pattern or nil +// if there is no matching file. The syntax of patterns is the same +// as in Match. The pattern may describe hierarchical names such as +// /usr/*/bin/ed (assuming the Separator is '/'). +// +// Glob ignores file system errors such as I/O errors reading directories. +// The only possible returned error is ErrBadPattern, when pattern +// is malformed. +func (c *Client) Glob(pattern string) (matches []string, err error) { + if !hasMeta(pattern) { + file, err := c.Lstat(pattern) + if err != nil { + return nil, nil + } + dir, _ := Split(pattern) + dir = cleanGlobPath(dir) + return []string{Join(dir, file.Name())}, nil + } + + dir, file := Split(pattern) + dir = cleanGlobPath(dir) + + if !hasMeta(dir) { + return c.glob(dir, file, nil) + } + + // Prevent infinite recursion. See issue 15879. + if dir == pattern { + return nil, ErrBadPattern + } + + var m []string + m, err = c.Glob(dir) + if err != nil { + return + } + for _, d := range m { + matches, err = c.glob(d, file, matches) + if err != nil { + return + } + } + return +} + +// cleanGlobPath prepares path for glob matching. +func cleanGlobPath(path string) string { + switch path { + case "": + return "." + case string(separator): + // do nothing to the path + return path + default: + return path[0 : len(path)-1] // chop off trailing separator + } +} + +// glob searches for files matching pattern in the directory dir +// and appends them to matches. If the directory cannot be +// opened, it returns the existing matches. New matches are +// added in lexicographical order. +func (c *Client) glob(dir, pattern string, matches []string) (m []string, e error) { + m = matches + fi, err := c.Stat(dir) + if err != nil { + return + } + if !fi.IsDir() { + return + } + names, err := c.ReadDir(dir) + if err != nil { + return + } + //sort.Strings(names) + + for _, n := range names { + matched, err := Match(pattern, n.Name()) + if err != nil { + return m, err + } + if matched { + m = append(m, Join(dir, n.Name())) + } + } + return +} + +// Join joins any number of path elements into a single path, adding +// a Separator if necessary. +// all empty strings are ignored. +func Join(elem ...string) string { + return path.Join(elem...) +} + +// hasMeta reports whether path contains any of the magic characters +// recognized by Match. +func hasMeta(path string) bool { + // TODO(niemeyer): Should other magic characters be added here? + return strings.ContainsAny(path, "*?[") +} diff --git a/github.com/pkg/sftp/other_test.go b/github.com/pkg/sftp/other_test.go new file mode 100644 index 0000000000..1b84ccfa84 --- /dev/null +++ b/github.com/pkg/sftp/other_test.go @@ -0,0 +1,5 @@ +// +build !linux,!darwin + +package sftp + +const sftpServer = "/usr/bin/false" // unsupported diff --git a/github.com/pkg/sftp/packet-manager.go b/github.com/pkg/sftp/packet-manager.go new file mode 100644 index 0000000000..6d1a8e5d6c --- /dev/null +++ b/github.com/pkg/sftp/packet-manager.go @@ -0,0 +1,156 @@ +package sftp + +import ( + "encoding" + "sync" +) + +// The goal of the packetManager is to keep the outgoing packets in the same +// order as the incoming. This is due to some sftp clients requiring this +// behavior (eg. winscp). + +type packetSender interface { + sendPacket(encoding.BinaryMarshaler) error +} + +type packetManager struct { + requests chan requestPacket + responses chan responsePacket + fini chan struct{} + incoming requestPacketIDs + outgoing responsePackets + sender packetSender // connection object + working *sync.WaitGroup +} + +func newPktMgr(sender packetSender) packetManager { + s := packetManager{ + requests: make(chan requestPacket, sftpServerWorkerCount), + responses: make(chan responsePacket, sftpServerWorkerCount), + fini: make(chan struct{}), + incoming: make([]uint32, 0, sftpServerWorkerCount), + outgoing: make([]responsePacket, 0, sftpServerWorkerCount), + sender: sender, + working: &sync.WaitGroup{}, + } + go s.controller() + return s +} + +// register incoming packets to be handled +// send id of 0 for packets without id +func (s packetManager) incomingPacket(pkt requestPacket) { + s.working.Add(1) + s.requests <- pkt // buffer == sftpServerWorkerCount +} + +// register outgoing packets as being ready +func (s packetManager) readyPacket(pkt responsePacket) { + s.responses <- pkt + s.working.Done() +} + +// shut down packetManager controller +func (s packetManager) close() { + // pause until current packets are processed + s.working.Wait() + close(s.fini) +} + +// Passed a worker function, returns a channel for incoming packets. +// The goal is to process packets in the order they are received as is +// requires by section 7 of the RFC, while maximizing throughput of file +// transfers. +func (s *packetManager) workerChan(runWorker func(requestChan)) requestChan { + + rwChan := make(chan requestPacket, sftpServerWorkerCount) + for i := 0; i < sftpServerWorkerCount; i++ { + runWorker(rwChan) + } + + cmdChan := make(chan requestPacket) + runWorker(cmdChan) + + pktChan := make(chan requestPacket, sftpServerWorkerCount) + go func() { + // start with cmdChan + curChan := cmdChan + for pkt := range pktChan { + // on file open packet, switch to rwChan + switch pkt.(type) { + case *sshFxpOpenPacket: + curChan = rwChan + // on file close packet, switch back to cmdChan + // after waiting for any reads/writes to finish + case *sshFxpClosePacket: + // wait for rwChan to finish + s.working.Wait() + // stop using rwChan + curChan = cmdChan + } + s.incomingPacket(pkt) + curChan <- pkt + } + close(rwChan) + close(cmdChan) + s.close() + }() + + return pktChan +} + +// process packets +func (s *packetManager) controller() { + for { + select { + case pkt := <-s.requests: + debug("incoming id: %v", pkt.id()) + s.incoming = append(s.incoming, pkt.id()) + if len(s.incoming) > 1 { + s.incoming.Sort() + } + case pkt := <-s.responses: + debug("outgoing pkt: %v", pkt.id()) + s.outgoing = append(s.outgoing, pkt) + if len(s.outgoing) > 1 { + s.outgoing.Sort() + } + case <-s.fini: + return + } + s.maybeSendPackets() + } +} + +// send as many packets as are ready +func (s *packetManager) maybeSendPackets() { + for { + if len(s.outgoing) == 0 || len(s.incoming) == 0 { + debug("break! -- outgoing: %v; incoming: %v", + len(s.outgoing), len(s.incoming)) + break + } + out := s.outgoing[0] + in := s.incoming[0] + // debug("incoming: %v", s.incoming) + // debug("outgoing: %v", outfilter(s.outgoing)) + if in == out.id() { + s.sender.sendPacket(out) + // pop off heads + copy(s.incoming, s.incoming[1:]) // shift left + s.incoming = s.incoming[:len(s.incoming)-1] // remove last + copy(s.outgoing, s.outgoing[1:]) // shift left + s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last + } else { + break + } + } +} + +func outfilter(o []responsePacket) []uint32 { + res := make([]uint32, 0, len(o)) + for _, v := range o { + res = append(res, v.id()) + } + return res +} diff --git a/github.com/pkg/sftp/packet-manager_go1.8.go b/github.com/pkg/sftp/packet-manager_go1.8.go new file mode 100644 index 0000000000..ccae8c1d35 --- /dev/null +++ b/github.com/pkg/sftp/packet-manager_go1.8.go @@ -0,0 +1,21 @@ +// +build go1.8 + +package sftp + +import "sort" + +type responsePackets []responsePacket + +func (r responsePackets) Sort() { + sort.Slice(r, func(i, j int) bool { + return r[i].id() < r[j].id() + }) +} + +type requestPacketIDs []uint32 + +func (r requestPacketIDs) Sort() { + sort.Slice(r, func(i, j int) bool { + return r[i] < r[j] + }) +} diff --git a/github.com/pkg/sftp/packet-manager_legacy.go b/github.com/pkg/sftp/packet-manager_legacy.go new file mode 100644 index 0000000000..97f0ff096d --- /dev/null +++ b/github.com/pkg/sftp/packet-manager_legacy.go @@ -0,0 +1,21 @@ +// +build !go1.8 + +package sftp + +import "sort" + +// for sorting/ordering outgoing +type responsePackets []responsePacket + +func (r responsePackets) Len() int { return len(r) } +func (r responsePackets) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +func (r responsePackets) Less(i, j int) bool { return r[i].id() < r[j].id() } +func (r responsePackets) Sort() { sort.Sort(r) } + +// for sorting/ordering incoming +type requestPacketIDs []uint32 + +func (r requestPacketIDs) Len() int { return len(r) } +func (r requestPacketIDs) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +func (r requestPacketIDs) Less(i, j int) bool { return r[i] < r[j] } +func (r requestPacketIDs) Sort() { sort.Sort(r) } diff --git a/github.com/pkg/sftp/packet-manager_test.go b/github.com/pkg/sftp/packet-manager_test.go new file mode 100644 index 0000000000..7e187d94bf --- /dev/null +++ b/github.com/pkg/sftp/packet-manager_test.go @@ -0,0 +1,154 @@ +package sftp + +import ( + "encoding" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type _testSender struct { + sent chan encoding.BinaryMarshaler +} + +func newTestSender() *_testSender { + return &_testSender{make(chan encoding.BinaryMarshaler)} +} + +func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error { + s.sent <- p + return nil +} + +type fakepacket uint32 + +func (fakepacket) MarshalBinary() ([]byte, error) { + return []byte{}, nil +} + +func (fakepacket) UnmarshalBinary([]byte) error { + return nil +} + +func (f fakepacket) id() uint32 { + return uint32(f) +} + +type pair struct { + in fakepacket + out fakepacket +} + +// basic test +var ttable1 = []pair{ + pair{fakepacket(0), fakepacket(0)}, + pair{fakepacket(1), fakepacket(1)}, + pair{fakepacket(2), fakepacket(2)}, + pair{fakepacket(3), fakepacket(3)}, +} + +// outgoing packets out of order +var ttable2 = []pair{ + pair{fakepacket(0), fakepacket(0)}, + pair{fakepacket(1), fakepacket(4)}, + pair{fakepacket(2), fakepacket(1)}, + pair{fakepacket(3), fakepacket(3)}, + pair{fakepacket(4), fakepacket(2)}, +} + +// incoming packets out of order +var ttable3 = []pair{ + pair{fakepacket(2), fakepacket(0)}, + pair{fakepacket(1), fakepacket(1)}, + pair{fakepacket(3), fakepacket(2)}, + pair{fakepacket(0), fakepacket(3)}, +} + +var tables = [][]pair{ttable1, ttable2, ttable3} + +func TestPacketManager(t *testing.T) { + sender := newTestSender() + s := newPktMgr(sender) + + for i := range tables { + table := tables[i] + for _, p := range table { + s.incomingPacket(p.in) + } + for _, p := range table { + s.readyPacket(p.out) + } + for i := 0; i < len(table); i++ { + pkt := <-sender.sent + id := pkt.(fakepacket).id() + assert.Equal(t, id, uint32(i)) + } + } + s.close() +} + +func (p sshFxpRemovePacket) String() string { + return fmt.Sprintf("RmPct:%d", p.ID) +} +func (p sshFxpOpenPacket) String() string { + return fmt.Sprintf("OpPct:%d", p.ID) +} +func (p sshFxpWritePacket) String() string { + return fmt.Sprintf("WrPct:%d", p.ID) +} +func (p sshFxpClosePacket) String() string { + return fmt.Sprintf("ClPct:%d", p.ID) +} + +// Test what happens when the pool processes a close packet on a file that it +// is still reading from. +func TestCloseOutOfOrder(t *testing.T) { + packets := []requestPacket{ + &sshFxpRemovePacket{ID: 0, Filename: "foo"}, + &sshFxpOpenPacket{ID: 1}, + &sshFxpWritePacket{ID: 2, Handle: "foo"}, + &sshFxpWritePacket{ID: 3, Handle: "foo"}, + &sshFxpWritePacket{ID: 4, Handle: "foo"}, + &sshFxpWritePacket{ID: 5, Handle: "foo"}, + &sshFxpClosePacket{ID: 6, Handle: "foo"}, + &sshFxpRemovePacket{ID: 7, Filename: "foo"}, + } + + recvChan := make(chan requestPacket, len(packets)+1) + sender := newTestSender() + pktMgr := newPktMgr(sender) + wg := sync.WaitGroup{} + wg.Add(len(packets)) + runWorker := func(ch requestChan) { + go func() { + for pkt := range ch { + if _, ok := pkt.(*sshFxpWritePacket); ok { + // sleep to cause writes to come after close/remove + time.Sleep(time.Millisecond) + } + pktMgr.working.Done() + recvChan <- pkt + wg.Done() + } + }() + } + pktChan := pktMgr.workerChan(runWorker) + for _, p := range packets { + pktChan <- p + } + wg.Wait() + close(recvChan) + received := []requestPacket{} + for p := range recvChan { + received = append(received, p) + } + if received[len(received)-2].id() != packets[len(packets)-2].id() { + t.Fatal("Packets processed out of order1:", received, packets) + } + if received[len(received)-1].id() != packets[len(packets)-1].id() { + t.Fatal("Packets processed out of order2:", received, packets) + } +} diff --git a/github.com/pkg/sftp/packet-typing.go b/github.com/pkg/sftp/packet-typing.go new file mode 100644 index 0000000000..920851ddbd --- /dev/null +++ b/github.com/pkg/sftp/packet-typing.go @@ -0,0 +1,141 @@ +package sftp + +import ( + "encoding" + + "github.com/pkg/errors" +) + +// all incoming packets +type requestPacket interface { + encoding.BinaryUnmarshaler + id() uint32 +} + +type requestChan chan requestPacket + +type responsePacket interface { + encoding.BinaryMarshaler + id() uint32 +} + +// interfaces to group types +type hasPath interface { + requestPacket + getPath() string +} + +type hasHandle interface { + requestPacket + getHandle() string +} + +type isOpener interface { + hasPath + isOpener() +} + +type notReadOnly interface { + notReadOnly() +} + +//// define types by adding methods +// hasPath +func (p sshFxpLstatPacket) getPath() string { return p.Path } +func (p sshFxpStatPacket) getPath() string { return p.Path } +func (p sshFxpRmdirPacket) getPath() string { return p.Path } +func (p sshFxpReadlinkPacket) getPath() string { return p.Path } +func (p sshFxpRealpathPacket) getPath() string { return p.Path } +func (p sshFxpMkdirPacket) getPath() string { return p.Path } +func (p sshFxpSetstatPacket) getPath() string { return p.Path } +func (p sshFxpStatvfsPacket) getPath() string { return p.Path } +func (p sshFxpRemovePacket) getPath() string { return p.Filename } +func (p sshFxpRenamePacket) getPath() string { return p.Oldpath } +func (p sshFxpSymlinkPacket) getPath() string { return p.Targetpath } + +// Openers implement hasPath and isOpener +func (p sshFxpOpendirPacket) getPath() string { return p.Path } +func (p sshFxpOpendirPacket) isOpener() {} +func (p sshFxpOpenPacket) getPath() string { return p.Path } +func (p sshFxpOpenPacket) isOpener() {} + +// hasHandle +func (p sshFxpFstatPacket) getHandle() string { return p.Handle } +func (p sshFxpFsetstatPacket) getHandle() string { return p.Handle } +func (p sshFxpReadPacket) getHandle() string { return p.Handle } +func (p sshFxpWritePacket) getHandle() string { return p.Handle } +func (p sshFxpReaddirPacket) getHandle() string { return p.Handle } + +// notReadOnly +func (p sshFxpWritePacket) notReadOnly() {} +func (p sshFxpSetstatPacket) notReadOnly() {} +func (p sshFxpFsetstatPacket) notReadOnly() {} +func (p sshFxpRemovePacket) notReadOnly() {} +func (p sshFxpMkdirPacket) notReadOnly() {} +func (p sshFxpRmdirPacket) notReadOnly() {} +func (p sshFxpRenamePacket) notReadOnly() {} +func (p sshFxpSymlinkPacket) notReadOnly() {} + +// this has a handle, but is only used for close +func (p sshFxpClosePacket) getHandle() string { return p.Handle } + +// some packets with ID are missing id() +func (p sshFxpDataPacket) id() uint32 { return p.ID } +func (p sshFxpStatusPacket) id() uint32 { return p.ID } +func (p sshFxpStatResponse) id() uint32 { return p.ID } +func (p sshFxpNamePacket) id() uint32 { return p.ID } +func (p sshFxpHandlePacket) id() uint32 { return p.ID } +func (p sshFxVersionPacket) id() uint32 { return 0 } + +// take raw incoming packet data and build packet objects +func makePacket(p rxPacket) (requestPacket, error) { + var pkt requestPacket + switch p.pktType { + case ssh_FXP_INIT: + pkt = &sshFxInitPacket{} + case ssh_FXP_LSTAT: + pkt = &sshFxpLstatPacket{} + case ssh_FXP_OPEN: + pkt = &sshFxpOpenPacket{} + case ssh_FXP_CLOSE: + pkt = &sshFxpClosePacket{} + case ssh_FXP_READ: + pkt = &sshFxpReadPacket{} + case ssh_FXP_WRITE: + pkt = &sshFxpWritePacket{} + case ssh_FXP_FSTAT: + pkt = &sshFxpFstatPacket{} + case ssh_FXP_SETSTAT: + pkt = &sshFxpSetstatPacket{} + case ssh_FXP_FSETSTAT: + pkt = &sshFxpFsetstatPacket{} + case ssh_FXP_OPENDIR: + pkt = &sshFxpOpendirPacket{} + case ssh_FXP_READDIR: + pkt = &sshFxpReaddirPacket{} + case ssh_FXP_REMOVE: + pkt = &sshFxpRemovePacket{} + case ssh_FXP_MKDIR: + pkt = &sshFxpMkdirPacket{} + case ssh_FXP_RMDIR: + pkt = &sshFxpRmdirPacket{} + case ssh_FXP_REALPATH: + pkt = &sshFxpRealpathPacket{} + case ssh_FXP_STAT: + pkt = &sshFxpStatPacket{} + case ssh_FXP_RENAME: + pkt = &sshFxpRenamePacket{} + case ssh_FXP_READLINK: + pkt = &sshFxpReadlinkPacket{} + case ssh_FXP_SYMLINK: + pkt = &sshFxpSymlinkPacket{} + case ssh_FXP_EXTENDED: + pkt = &sshFxpExtendedPacket{} + default: + return nil, errors.Errorf("unhandled packet type: %s", p.pktType) + } + if err := pkt.UnmarshalBinary(p.pktBytes); err != nil { + return nil, err + } + return pkt, nil +} diff --git a/github.com/pkg/sftp/packet.go b/github.com/pkg/sftp/packet.go new file mode 100644 index 0000000000..db4fbb3d41 --- /dev/null +++ b/github.com/pkg/sftp/packet.go @@ -0,0 +1,898 @@ +package sftp + +import ( + "bytes" + "encoding" + "encoding/binary" + "fmt" + "io" + "os" + "reflect" + + "github.com/pkg/errors" +) + +var ( + errShortPacket = errors.New("packet too short") + errUnknownExtendedPacket = errors.New("unknown extended packet") +) + +const ( + debugDumpTxPacket = false + debugDumpRxPacket = false + debugDumpTxPacketBytes = false + debugDumpRxPacketBytes = false +) + +func marshalUint32(b []byte, v uint32) []byte { + return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +func marshalUint64(b []byte, v uint64) []byte { + return marshalUint32(marshalUint32(b, uint32(v>>32)), uint32(v)) +} + +func marshalString(b []byte, v string) []byte { + return append(marshalUint32(b, uint32(len(v))), v...) +} + +func marshal(b []byte, v interface{}) []byte { + if v == nil { + return b + } + switch v := v.(type) { + case uint8: + return append(b, v) + case uint32: + return marshalUint32(b, v) + case uint64: + return marshalUint64(b, v) + case string: + return marshalString(b, v) + case os.FileInfo: + return marshalFileInfo(b, v) + default: + switch d := reflect.ValueOf(v); d.Kind() { + case reflect.Struct: + for i, n := 0, d.NumField(); i < n; i++ { + b = append(marshal(b, d.Field(i).Interface())) + } + return b + case reflect.Slice: + for i, n := 0, d.Len(); i < n; i++ { + b = append(marshal(b, d.Index(i).Interface())) + } + return b + default: + panic(fmt.Sprintf("marshal(%#v): cannot handle type %T", v, v)) + } + } +} + +func unmarshalUint32(b []byte) (uint32, []byte) { + v := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 + return v, b[4:] +} + +func unmarshalUint32Safe(b []byte) (uint32, []byte, error) { + var v uint32 + if len(b) < 4 { + return 0, nil, errShortPacket + } + v, b = unmarshalUint32(b) + return v, b, nil +} + +func unmarshalUint64(b []byte) (uint64, []byte) { + h, b := unmarshalUint32(b) + l, b := unmarshalUint32(b) + return uint64(h)<<32 | uint64(l), b +} + +func unmarshalUint64Safe(b []byte) (uint64, []byte, error) { + var v uint64 + if len(b) < 8 { + return 0, nil, errShortPacket + } + v, b = unmarshalUint64(b) + return v, b, nil +} + +func unmarshalString(b []byte) (string, []byte) { + n, b := unmarshalUint32(b) + return string(b[:n]), b[n:] +} + +func unmarshalStringSafe(b []byte) (string, []byte, error) { + n, b, err := unmarshalUint32Safe(b) + if err != nil { + return "", nil, err + } + if int64(n) > int64(len(b)) { + return "", nil, errShortPacket + } + return string(b[:n]), b[n:], nil +} + +// sendPacket marshals p according to RFC 4234. +func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { + bb, err := m.MarshalBinary() + if err != nil { + return errors.Errorf("binary marshaller failed: %v", err) + } + if debugDumpTxPacketBytes { + debug("send packet: %s %d bytes %x", fxp(bb[0]), len(bb), bb[1:]) + } else if debugDumpTxPacket { + debug("send packet: %s %d bytes", fxp(bb[0]), len(bb)) + } + l := uint32(len(bb)) + hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)} + _, err = w.Write(hdr) + if err != nil { + return errors.Errorf("failed to send packet header: %v", err) + } + _, err = w.Write(bb) + if err != nil { + return errors.Errorf("failed to send packet body: %v", err) + } + return nil +} + +func recvPacket(r io.Reader) (uint8, []byte, error) { + var b = []byte{0, 0, 0, 0} + if _, err := io.ReadFull(r, b); err != nil { + return 0, nil, err + } + l, _ := unmarshalUint32(b) + b = make([]byte, l) + if _, err := io.ReadFull(r, b); err != nil { + debug("recv packet %d bytes: err %v", l, err) + return 0, nil, err + } + if debugDumpRxPacketBytes { + debug("recv packet: %s %d bytes %x", fxp(b[0]), l, b[1:]) + } else if debugDumpRxPacket { + debug("recv packet: %s %d bytes", fxp(b[0]), l) + } + return b[0], b[1:], nil +} + +type extensionPair struct { + Name string + Data string +} + +func unmarshalExtensionPair(b []byte) (extensionPair, []byte, error) { + var ep extensionPair + var err error + ep.Name, b, err = unmarshalStringSafe(b) + if err != nil { + return ep, b, err + } + ep.Data, b, err = unmarshalStringSafe(b) + return ep, b, err +} + +// Here starts the definition of packets along with their MarshalBinary +// implementations. +// Manually writing the marshalling logic wins us a lot of time and +// allocation. + +type sshFxInitPacket struct { + Version uint32 + Extensions []extensionPair +} + +func (p sshFxInitPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 // byte + uint32 + for _, e := range p.Extensions { + l += 4 + len(e.Name) + 4 + len(e.Data) + } + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_INIT) + b = marshalUint32(b, p.Version) + for _, e := range p.Extensions { + b = marshalString(b, e.Name) + b = marshalString(b, e.Data) + } + return b, nil +} + +func (p *sshFxInitPacket) UnmarshalBinary(b []byte) error { + var err error + if p.Version, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + for len(b) > 0 { + var ep extensionPair + ep, b, err = unmarshalExtensionPair(b) + if err != nil { + return err + } + p.Extensions = append(p.Extensions, ep) + } + return nil +} + +type sshFxVersionPacket struct { + Version uint32 + Extensions []struct { + Name, Data string + } +} + +func (p sshFxVersionPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 // byte + uint32 + for _, e := range p.Extensions { + l += 4 + len(e.Name) + 4 + len(e.Data) + } + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_VERSION) + b = marshalUint32(b, p.Version) + for _, e := range p.Extensions { + b = marshalString(b, e.Name) + b = marshalString(b, e.Data) + } + return b, nil +} + +func marshalIDString(packetType byte, id uint32, str string) ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(str) + + b := make([]byte, 0, l) + b = append(b, packetType) + b = marshalUint32(b, id) + b = marshalString(b, str) + return b, nil +} + +func unmarshalIDString(b []byte, id *uint32, str *string) error { + var err error + *id, b, err = unmarshalUint32Safe(b) + if err != nil { + return err + } + *str, b, err = unmarshalStringSafe(b) + return err +} + +type sshFxpReaddirPacket struct { + ID uint32 + Handle string +} + +func (p sshFxpReaddirPacket) id() uint32 { return p.ID } + +func (p sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_READDIR, p.ID, p.Handle) +} + +func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) +} + +type sshFxpOpendirPacket struct { + ID uint32 + Path string +} + +func (p sshFxpOpendirPacket) id() uint32 { return p.ID } + +func (p sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_OPENDIR, p.ID, p.Path) +} + +func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpLstatPacket struct { + ID uint32 + Path string +} + +func (p sshFxpLstatPacket) id() uint32 { return p.ID } + +func (p sshFxpLstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_LSTAT, p.ID, p.Path) +} + +func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpStatPacket struct { + ID uint32 + Path string +} + +func (p sshFxpStatPacket) id() uint32 { return p.ID } + +func (p sshFxpStatPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_STAT, p.ID, p.Path) +} + +func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpFstatPacket struct { + ID uint32 + Handle string +} + +func (p sshFxpFstatPacket) id() uint32 { return p.ID } + +func (p sshFxpFstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_FSTAT, p.ID, p.Handle) +} + +func (p *sshFxpFstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) +} + +type sshFxpClosePacket struct { + ID uint32 + Handle string +} + +func (p sshFxpClosePacket) id() uint32 { return p.ID } + +func (p sshFxpClosePacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_CLOSE, p.ID, p.Handle) +} + +func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) +} + +type sshFxpRemovePacket struct { + ID uint32 + Filename string +} + +func (p sshFxpRemovePacket) id() uint32 { return p.ID } + +func (p sshFxpRemovePacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_REMOVE, p.ID, p.Filename) +} + +func (p *sshFxpRemovePacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Filename) +} + +type sshFxpRmdirPacket struct { + ID uint32 + Path string +} + +func (p sshFxpRmdirPacket) id() uint32 { return p.ID } + +func (p sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_RMDIR, p.ID, p.Path) +} + +func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpSymlinkPacket struct { + ID uint32 + Targetpath string + Linkpath string +} + +func (p sshFxpSymlinkPacket) id() uint32 { return p.ID } + +func (p sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Targetpath) + + 4 + len(p.Linkpath) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_SYMLINK) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Targetpath) + b = marshalString(b, p.Linkpath) + return b, nil +} + +func (p *sshFxpSymlinkPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Targetpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Linkpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpReadlinkPacket struct { + ID uint32 + Path string +} + +func (p sshFxpReadlinkPacket) id() uint32 { return p.ID } + +func (p sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_READLINK, p.ID, p.Path) +} + +func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpRealpathPacket struct { + ID uint32 + Path string +} + +func (p sshFxpRealpathPacket) id() uint32 { return p.ID } + +func (p sshFxpRealpathPacket) MarshalBinary() ([]byte, error) { + return marshalIDString(ssh_FXP_REALPATH, p.ID, p.Path) +} + +func (p *sshFxpRealpathPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpNameAttr struct { + Name string + LongName string + Attrs []interface{} +} + +func (p sshFxpNameAttr) MarshalBinary() ([]byte, error) { + b := []byte{} + b = marshalString(b, p.Name) + b = marshalString(b, p.LongName) + for _, attr := range p.Attrs { + b = marshal(b, attr) + } + return b, nil +} + +type sshFxpNamePacket struct { + ID uint32 + NameAttrs []sshFxpNameAttr +} + +func (p sshFxpNamePacket) MarshalBinary() ([]byte, error) { + b := []byte{} + b = append(b, ssh_FXP_NAME) + b = marshalUint32(b, p.ID) + b = marshalUint32(b, uint32(len(p.NameAttrs))) + for _, na := range p.NameAttrs { + ab, err := na.MarshalBinary() + if err != nil { + return nil, err + } + + b = append(b, ab...) + } + return b, nil +} + +type sshFxpOpenPacket struct { + ID uint32 + Path string + Pflags uint32 + Flags uint32 // ignored +} + +func (p sshFxpOpenPacket) id() uint32 { return p.ID } + +func (p sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + + 4 + len(p.Path) + + 4 + 4 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_OPEN) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Pflags) + b = marshalUint32(b, p.Flags) + return b, nil +} + +func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + +type sshFxpReadPacket struct { + ID uint32 + Handle string + Offset uint64 + Len uint32 +} + +func (p sshFxpReadPacket) id() uint32 { return p.ID } + +func (p sshFxpReadPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Handle) + + 8 + 4 // uint64 + uint32 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_READ) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint64(b, p.Offset) + b = marshalUint32(b, p.Len) + return b, nil +} + +func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return err + } else if p.Len, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + +type sshFxpRenamePacket struct { + ID uint32 + Oldpath string + Newpath string +} + +func (p sshFxpRenamePacket) id() uint32 { return p.ID } + +func (p sshFxpRenamePacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Oldpath) + + 4 + len(p.Newpath) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_RENAME) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Oldpath) + b = marshalString(b, p.Newpath) + return b, nil +} + +func (p *sshFxpRenamePacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Newpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpWritePacket struct { + ID uint32 + Handle string + Offset uint64 + Length uint32 + Data []byte +} + +func (p sshFxpWritePacket) id() uint32 { return p.ID } + +func (p sshFxpWritePacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Handle) + + 8 + 4 + // uint64 + uint32 + len(p.Data) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_WRITE) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint64(b, p.Offset) + b = marshalUint32(b, p.Length) + b = append(b, p.Data...) + return b, nil +} + +func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return err + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if uint32(len(b)) < p.Length { + return errShortPacket + } + + p.Data = append([]byte{}, b[:p.Length]...) + return nil +} + +type sshFxpMkdirPacket struct { + ID uint32 + Path string + Flags uint32 // ignored +} + +func (p sshFxpMkdirPacket) id() uint32 { return p.ID } + +func (p sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Path) + + 4 // uint32 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_MKDIR) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Flags) + return b, nil +} + +func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + +type sshFxpSetstatPacket struct { + ID uint32 + Path string + Flags uint32 + Attrs interface{} +} + +type sshFxpFsetstatPacket struct { + ID uint32 + Handle string + Flags uint32 + Attrs interface{} +} + +func (p sshFxpSetstatPacket) id() uint32 { return p.ID } +func (p sshFxpFsetstatPacket) id() uint32 { return p.ID } + +func (p sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Path) + + 4 // uint32 + uint64 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_SETSTAT) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Flags) + b = marshal(b, p.Attrs) + return b, nil +} + +func (p sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Handle) + + 4 // uint32 + uint64 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_FSETSTAT) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint32(b, p.Flags) + b = marshal(b, p.Attrs) + return b, nil +} + +func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +type sshFxpHandlePacket struct { + ID uint32 + Handle string +} + +func (p sshFxpHandlePacket) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_HANDLE} + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + return b, nil +} + +type sshFxpStatusPacket struct { + ID uint32 + StatusError +} + +func (p sshFxpStatusPacket) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_STATUS} + b = marshalUint32(b, p.ID) + b = marshalStatus(b, p.StatusError) + return b, nil +} + +type sshFxpDataPacket struct { + ID uint32 + Length uint32 + Data []byte +} + +func (p sshFxpDataPacket) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_DATA} + b = marshalUint32(b, p.ID) + b = marshalUint32(b, p.Length) + b = append(b, p.Data[:p.Length]...) + return b, nil +} + +func (p *sshFxpDataPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if uint32(len(b)) < p.Length { + return errors.New("truncated packet") + } + + p.Data = make([]byte, p.Length) + copy(p.Data, b) + return nil +} + +type sshFxpStatvfsPacket struct { + ID uint32 + Path string +} + +func (p sshFxpStatvfsPacket) id() uint32 { return p.ID } + +func (p sshFxpStatvfsPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + len(p.Path) + + len("statvfs@openssh.com") + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_EXTENDED) + b = marshalUint32(b, p.ID) + b = marshalString(b, "statvfs@openssh.com") + b = marshalString(b, p.Path) + return b, nil +} + +// A StatVFS contains statistics about a filesystem. +type StatVFS struct { + ID uint32 + Bsize uint64 /* file system block size */ + Frsize uint64 /* fundamental fs block size */ + Blocks uint64 /* number of blocks (unit f_frsize) */ + Bfree uint64 /* free blocks in file system */ + Bavail uint64 /* free blocks for non-root */ + Files uint64 /* total file inodes */ + Ffree uint64 /* free file inodes */ + Favail uint64 /* free file inodes for to non-root */ + Fsid uint64 /* file system id */ + Flag uint64 /* bit mask of f_flag values */ + Namemax uint64 /* maximum filename length */ +} + +// TotalSpace calculates the amount of total space in a filesystem. +func (p *StatVFS) TotalSpace() uint64 { + return p.Frsize * p.Blocks +} + +// FreeSpace calculates the amount of free space in a filesystem. +func (p *StatVFS) FreeSpace() uint64 { + return p.Frsize * p.Bfree +} + +// Convert to ssh_FXP_EXTENDED_REPLY packet binary format +func (p *StatVFS) MarshalBinary() ([]byte, error) { + var buf bytes.Buffer + buf.Write([]byte{ssh_FXP_EXTENDED_REPLY}) + err := binary.Write(&buf, binary.BigEndian, p) + return buf.Bytes(), err +} + +type sshFxpExtendedPacket struct { + ID uint32 + ExtendedRequest string + SpecificPacket interface { + serverRespondablePacket + readonly() bool + } +} + +func (p sshFxpExtendedPacket) id() uint32 { return p.ID } +func (p sshFxpExtendedPacket) readonly() bool { return p.SpecificPacket.readonly() } + +func (p sshFxpExtendedPacket) respond(svr *Server) error { + return p.SpecificPacket.respond(svr) +} + +func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error { + var err error + bOrig := b + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } + + // specific unmarshalling + switch p.ExtendedRequest { + case "statvfs@openssh.com": + p.SpecificPacket = &sshFxpExtendedPacketStatVFS{} + default: + return errUnknownExtendedPacket + } + + return p.SpecificPacket.UnmarshalBinary(bOrig) +} + +type sshFxpExtendedPacketStatVFS struct { + ID uint32 + ExtendedRequest string + Path string +} + +func (p sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID } +func (p sshFxpExtendedPacketStatVFS) readonly() bool { return true } +func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} diff --git a/github.com/pkg/sftp/packet_test.go b/github.com/pkg/sftp/packet_test.go new file mode 100644 index 0000000000..2a948dbd84 --- /dev/null +++ b/github.com/pkg/sftp/packet_test.go @@ -0,0 +1,345 @@ +package sftp + +import ( + "bytes" + "encoding" + "os" + "testing" +) + +var marshalUint32Tests = []struct { + v uint32 + want []byte +}{ + {1, []byte{0, 0, 0, 1}}, + {256, []byte{0, 0, 1, 0}}, + {^uint32(0), []byte{255, 255, 255, 255}}, +} + +func TestMarshalUint32(t *testing.T) { + for _, tt := range marshalUint32Tests { + got := marshalUint32(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshalUint32(%d): want %v, got %v", tt.v, tt.want, got) + } + } +} + +var marshalUint64Tests = []struct { + v uint64 + want []byte +}{ + {1, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}}, + {256, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0}}, + {^uint64(0), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {1 << 32, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, +} + +func TestMarshalUint64(t *testing.T) { + for _, tt := range marshalUint64Tests { + got := marshalUint64(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshalUint64(%d): want %#v, got %#v", tt.v, tt.want, got) + } + } +} + +var marshalStringTests = []struct { + v string + want []byte +}{ + {"", []byte{0, 0, 0, 0}}, + {"/foo", []byte{0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f}}, +} + +func TestMarshalString(t *testing.T) { + for _, tt := range marshalStringTests { + got := marshalString(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshalString(%q): want %#v, got %#v", tt.v, tt.want, got) + } + } +} + +var marshalTests = []struct { + v interface{} + want []byte +}{ + {uint8(1), []byte{1}}, + {byte(1), []byte{1}}, + {uint32(1), []byte{0, 0, 0, 1}}, + {uint64(1), []byte{0, 0, 0, 0, 0, 0, 0, 1}}, + {"foo", []byte{0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f}}, + {[]uint32{1, 2, 3, 4}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x4}}, +} + +func TestMarshal(t *testing.T) { + for _, tt := range marshalTests { + got := marshal(nil, tt.v) + if !bytes.Equal(tt.want, got) { + t.Errorf("marshal(%v): want %#v, got %#v", tt.v, tt.want, got) + } + } +} + +var unmarshalUint32Tests = []struct { + b []byte + want uint32 + rest []byte +}{ + {[]byte{0, 0, 0, 0}, 0, nil}, + {[]byte{0, 0, 1, 0}, 256, nil}, + {[]byte{255, 0, 0, 255}, 4278190335, nil}, +} + +func TestUnmarshalUint32(t *testing.T) { + for _, tt := range unmarshalUint32Tests { + got, rest := unmarshalUint32(tt.b) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("unmarshalUint32(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +var unmarshalUint64Tests = []struct { + b []byte + want uint64 + rest []byte +}{ + {[]byte{0, 0, 0, 0, 0, 0, 0, 0}, 0, nil}, + {[]byte{0, 0, 0, 0, 0, 0, 1, 0}, 256, nil}, + {[]byte{255, 0, 0, 0, 0, 0, 0, 255}, 18374686479671623935, nil}, +} + +func TestUnmarshalUint64(t *testing.T) { + for _, tt := range unmarshalUint64Tests { + got, rest := unmarshalUint64(tt.b) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("unmarshalUint64(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +var unmarshalStringTests = []struct { + b []byte + want string + rest []byte +}{ + {marshalString(nil, ""), "", nil}, + {marshalString(nil, "blah"), "blah", nil}, +} + +func TestUnmarshalString(t *testing.T) { + for _, tt := range unmarshalStringTests { + got, rest := unmarshalString(tt.b) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("unmarshalUint64(%v): want %q, %#v, got %q, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +var sendPacketTests = []struct { + p encoding.BinaryMarshaler + want []byte +}{ + {sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, + + {sshFxpOpenPacket{ + ID: 1, + Path: "/foo", + Pflags: flags(os.O_RDONLY), + }, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, + + {sshFxpWritePacket{ + ID: 124, + Handle: "foo", + Offset: 13, + Length: uint32(len([]byte("bar"))), + Data: []byte("bar"), + }, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}}, + + {sshFxpSetstatPacket{ + ID: 31, + Path: "/bar", + Flags: flags(os.O_WRONLY), + Attrs: struct { + UID uint32 + GID uint32 + }{1000, 100}, + }, []byte{0x0, 0x0, 0x0, 0x19, 0x9, 0x0, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x3, 0xe8, 0x0, 0x0, 0x0, 0x64}}, +} + +func TestSendPacket(t *testing.T) { + for _, tt := range sendPacketTests { + var w bytes.Buffer + sendPacket(&w, tt.p) + if got := w.Bytes(); !bytes.Equal(tt.want, got) { + t.Errorf("sendPacket(%v): want %#v, got %#v", tt.p, tt.want, got) + } + } +} + +func sp(p encoding.BinaryMarshaler) []byte { + var w bytes.Buffer + sendPacket(&w, p) + return w.Bytes() +} + +var recvPacketTests = []struct { + b []byte + want uint8 + rest []byte +}{ + {sp(sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }), ssh_FXP_INIT, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, +} + +func TestRecvPacket(t *testing.T) { + for _, tt := range recvPacketTests { + r := bytes.NewReader(tt.b) + got, rest, _ := recvPacket(r) + if got != tt.want || !bytes.Equal(rest, tt.rest) { + t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + } + } +} + +func TestSSHFxpOpenPacketreadonly(t *testing.T) { + var tests = []struct { + pflags uint32 + ok bool + }{ + { + pflags: ssh_FXF_READ, + ok: true, + }, + { + pflags: ssh_FXF_WRITE, + ok: false, + }, + { + pflags: ssh_FXF_READ | ssh_FXF_WRITE, + ok: false, + }, + } + + for _, tt := range tests { + p := &sshFxpOpenPacket{ + Pflags: tt.pflags, + } + + if want, got := tt.ok, p.readonly(); want != got { + t.Errorf("unexpected value for p.readonly(): want: %v, got: %v", + want, got) + } + } +} + +func TestSSHFxpOpenPackethasPflags(t *testing.T) { + var tests = []struct { + desc string + haveFlags uint32 + testFlags []uint32 + ok bool + }{ + { + desc: "have read, test against write", + haveFlags: ssh_FXF_READ, + testFlags: []uint32{ssh_FXF_WRITE}, + ok: false, + }, + { + desc: "have write, test against read", + haveFlags: ssh_FXF_WRITE, + testFlags: []uint32{ssh_FXF_READ}, + ok: false, + }, + { + desc: "have read+write, test against read", + haveFlags: ssh_FXF_READ | ssh_FXF_WRITE, + testFlags: []uint32{ssh_FXF_READ}, + ok: true, + }, + { + desc: "have read+write, test against write", + haveFlags: ssh_FXF_READ | ssh_FXF_WRITE, + testFlags: []uint32{ssh_FXF_WRITE}, + ok: true, + }, + { + desc: "have read+write, test against read+write", + haveFlags: ssh_FXF_READ | ssh_FXF_WRITE, + testFlags: []uint32{ssh_FXF_READ, ssh_FXF_WRITE}, + ok: true, + }, + } + + for _, tt := range tests { + t.Log(tt.desc) + + p := &sshFxpOpenPacket{ + Pflags: tt.haveFlags, + } + + if want, got := tt.ok, p.hasPflags(tt.testFlags...); want != got { + t.Errorf("unexpected value for p.hasPflags(%#v): want: %v, got: %v", + tt.testFlags, want, got) + } + } +} + +func BenchmarkMarshalInit(b *testing.B) { + for i := 0; i < b.N; i++ { + sp(sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }) + } +} + +func BenchmarkMarshalOpen(b *testing.B) { + for i := 0; i < b.N; i++ { + sp(sshFxpOpenPacket{ + ID: 1, + Path: "/home/test/some/random/path", + Pflags: flags(os.O_RDONLY), + }) + } +} + +func BenchmarkMarshalWriteWorstCase(b *testing.B) { + data := make([]byte, 32*1024) + for i := 0; i < b.N; i++ { + sp(sshFxpWritePacket{ + ID: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) + } +} + +func BenchmarkMarshalWrite1k(b *testing.B) { + data := make([]byte, 1024) + for i := 0; i < b.N; i++ { + sp(sshFxpWritePacket{ + ID: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) + } +} diff --git a/github.com/pkg/sftp/release.go b/github.com/pkg/sftp/release.go new file mode 100644 index 0000000000..b695528fde --- /dev/null +++ b/github.com/pkg/sftp/release.go @@ -0,0 +1,5 @@ +// +build !debug + +package sftp + +func debug(fmt string, args ...interface{}) {} diff --git a/github.com/pkg/sftp/request-example.go b/github.com/pkg/sftp/request-example.go new file mode 100644 index 0000000000..3333a8d6fe --- /dev/null +++ b/github.com/pkg/sftp/request-example.go @@ -0,0 +1,244 @@ +package sftp + +// This serves as an example of how to implement the request server handler as +// well as a dummy backend for testing. It implements an in-memory backend that +// works as a very simple filesystem with simple flat key-value lookup system. + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strconv" + "sync" + "time" +) + +// InMemHandler returns a Hanlders object with the test handlers +func InMemHandler() Handlers { + root := &root{ + files: make(map[string]*memFile), + } + root.memFile = newMemFile("/", true) + return Handlers{root, root, root, root} +} + +// Handlers +func (fs *root) Fileread(r Request) (io.ReaderAt, error) { + fs.filesLock.Lock() + defer fs.filesLock.Unlock() + file, err := fs.fetch(r.Filepath) + if err != nil { + return nil, err + } + if file.symlink != "" { + file, err = fs.fetch(file.symlink) + if err != nil { + return nil, err + } + } + return file.ReaderAt() +} + +func (fs *root) Filewrite(r Request) (io.WriterAt, error) { + fs.filesLock.Lock() + defer fs.filesLock.Unlock() + file, err := fs.fetch(r.Filepath) + if err == os.ErrNotExist { + dir, err := fs.fetch(filepath.Dir(r.Filepath)) + if err != nil { + return nil, err + } + if !dir.isdir { + return nil, os.ErrInvalid + } + file = newMemFile(r.Filepath, false) + fs.files[r.Filepath] = file + } + return file.WriterAt() +} + +func (fs *root) Filecmd(r Request) error { + fs.filesLock.Lock() + defer fs.filesLock.Unlock() + switch r.Method { + case "Setstat": + return nil + case "Rename": + file, err := fs.fetch(r.Filepath) + if err != nil { + return err + } + if _, ok := fs.files[r.Target]; ok { + return &os.LinkError{Op: "rename", Old: r.Filepath, New: r.Target, + Err: fmt.Errorf("dest file exists")} + } + fs.files[r.Target] = file + delete(fs.files, r.Filepath) + case "Rmdir", "Remove": + _, err := fs.fetch(filepath.Dir(r.Filepath)) + if err != nil { + return err + } + delete(fs.files, r.Filepath) + case "Mkdir": + _, err := fs.fetch(filepath.Dir(r.Filepath)) + if err != nil { + return err + } + fs.files[r.Filepath] = newMemFile(r.Filepath, true) + case "Symlink": + _, err := fs.fetch(r.Filepath) + if err != nil { + return err + } + link := newMemFile(r.Target, false) + link.symlink = r.Filepath + fs.files[r.Target] = link + } + return nil +} + +func (fs *root) Fileinfo(r Request) ([]os.FileInfo, error) { + fs.filesLock.Lock() + defer fs.filesLock.Unlock() + switch r.Method { + case "List": + var err error + batch_size := 10 + current_offset := 0 + if token := r.LsNext(); token != "" { + current_offset, err = strconv.Atoi(token) + if err != nil { + return nil, os.ErrInvalid + } + } + ordered_names := []string{} + for fn, _ := range fs.files { + if filepath.Dir(fn) == r.Filepath { + ordered_names = append(ordered_names, fn) + } + } + sort.Sort(sort.StringSlice(ordered_names)) + list := make([]os.FileInfo, len(ordered_names)) + for i, fn := range ordered_names { + list[i] = fs.files[fn] + } + if len(list) < current_offset { + return nil, io.EOF + } + new_offset := current_offset + batch_size + if new_offset > len(list) { + new_offset = len(list) + } + r.LsSave(strconv.Itoa(new_offset)) + return list[current_offset:new_offset], nil + case "Stat": + file, err := fs.fetch(r.Filepath) + if err != nil { + return nil, err + } + return []os.FileInfo{file}, nil + case "Readlink": + file, err := fs.fetch(r.Filepath) + if err != nil { + return nil, err + } + if file.symlink != "" { + file, err = fs.fetch(file.symlink) + if err != nil { + return nil, err + } + } + return []os.FileInfo{file}, nil + } + return nil, nil +} + +// In memory file-system-y thing that the Hanlders live on +type root struct { + *memFile + files map[string]*memFile + filesLock sync.Mutex +} + +func (fs *root) fetch(path string) (*memFile, error) { + if path == "/" { + return fs.memFile, nil + } + if file, ok := fs.files[path]; ok { + return file, nil + } + return nil, os.ErrNotExist +} + +// Implements os.FileInfo, Reader and Writer interfaces. +// These are the 3 interfaces necessary for the Handlers. +type memFile struct { + name string + modtime time.Time + symlink string + isdir bool + content []byte + contentLock sync.RWMutex +} + +// factory to make sure modtime is set +func newMemFile(name string, isdir bool) *memFile { + return &memFile{ + name: name, + modtime: time.Now(), + isdir: isdir, + } +} + +// Have memFile fulfill os.FileInfo interface +func (f *memFile) Name() string { return filepath.Base(f.name) } +func (f *memFile) Size() int64 { return int64(len(f.content)) } +func (f *memFile) Mode() os.FileMode { + ret := os.FileMode(0644) + if f.isdir { + ret = os.FileMode(0755) | os.ModeDir + } + if f.symlink != "" { + ret = os.FileMode(0777) | os.ModeSymlink + } + return ret +} +func (f *memFile) ModTime() time.Time { return f.modtime } +func (f *memFile) IsDir() bool { return f.isdir } +func (f *memFile) Sys() interface{} { + return fakeFileInfoSys() +} + +// Read/Write +func (f *memFile) ReaderAt() (io.ReaderAt, error) { + if f.isdir { + return nil, os.ErrInvalid + } + return bytes.NewReader(f.content), nil +} + +func (f *memFile) WriterAt() (io.WriterAt, error) { + if f.isdir { + return nil, os.ErrInvalid + } + return f, nil +} +func (f *memFile) WriteAt(p []byte, off int64) (int, error) { + // fmt.Println(string(p), off) + // mimic write delays, should be optional + time.Sleep(time.Microsecond * time.Duration(len(p))) + f.contentLock.Lock() + defer f.contentLock.Unlock() + plen := len(p) + int(off) + if plen >= len(f.content) { + nc := make([]byte, plen) + copy(nc, f.content) + f.content = nc + } + copy(f.content[off:], p) + return len(p), nil +} diff --git a/github.com/pkg/sftp/request-interfaces.go b/github.com/pkg/sftp/request-interfaces.go new file mode 100644 index 0000000000..46db449c4f --- /dev/null +++ b/github.com/pkg/sftp/request-interfaces.go @@ -0,0 +1,30 @@ +package sftp + +import ( + "io" + "os" +) + +// Interfaces are differentiated based on required returned values. +// All input arguments are to be pulled from Request (the only arg). + +// FileReader should return an io.Reader for the filepath +type FileReader interface { + Fileread(Request) (io.ReaderAt, error) +} + +// FileWriter should return an io.Writer for the filepath +type FileWriter interface { + Filewrite(Request) (io.WriterAt, error) +} + +// FileCmder should return an error (rename, remove, setstate, etc.) +type FileCmder interface { + Filecmd(Request) error +} + +// FileInfoer should return file listing info and errors (readdir, stat) +// note stat requests would return a list of 1 +type FileInfoer interface { + Fileinfo(Request) ([]os.FileInfo, error) +} diff --git a/github.com/pkg/sftp/request-readme.md b/github.com/pkg/sftp/request-readme.md new file mode 100644 index 0000000000..7a54ecf774 --- /dev/null +++ b/github.com/pkg/sftp/request-readme.md @@ -0,0 +1,48 @@ +# Request Based SFTP API + +The request based API allows for custom backends in a way similar to the http +package. In order to create a backend you need to implement 4 handler +interfaces; one for reading, one for writing, one for misc commands and one for +listing files. Each has 1 required method and in each case those methods take +the Request as the only parameter and they each return something different. +These 4 interfaces are enough to handle all the SFTP traffic in a simplified +manner. + +The Request structure has 5 public fields which you will deal with. + +- Method (string) - string name of incoming call +- Filepath (string) - path of file to act on +- Attrs ([]byte) - byte string of file attribute data +- Target (string) - target path for renames and sym-links + +Below are the methods and a brief description of what they need to do. + +### Fileread(*Request) (io.Reader, error) + +Handler for "Get" method and returns an io.Reader for the file which the server +then sends to the client. + +### Filewrite(*Request) (io.Writer, error) + +Handler for "Put" method and returns an io.Writer for the file which the server +then writes the uploaded file to. + +### Filecmd(*Request) error + +Handles "SetStat", "Rename", "Rmdir", "Mkdir" and "Symlink" methods. Makes the +appropriate changes and returns nil for success or an filesystem like error +(eg. os.ErrNotExist). + +### Fileinfo(*Request) ([]os.FileInfo, error) + +Handles "List", "Stat", "Readlink" methods. Gathers/creates FileInfo structs +with the data on the files and returns in a list (list of 1 for Stat and +Readlink). + + +## TODO + +- Add support for API users to see trace/debugging info of what is going on +inside SFTP server. +- Consider adding support for SFTP file append only mode. + diff --git a/github.com/pkg/sftp/request-server.go b/github.com/pkg/sftp/request-server.go new file mode 100644 index 0000000000..b51a10b350 --- /dev/null +++ b/github.com/pkg/sftp/request-server.go @@ -0,0 +1,231 @@ +package sftp + +import ( + "encoding" + "io" + "os" + "path/filepath" + "strconv" + "sync" + "syscall" + + "github.com/pkg/errors" +) + +var maxTxPacket uint32 = 1 << 15 + +type handleHandler func(string) string + +// Handlers contains the 4 SFTP server request handlers. +type Handlers struct { + FileGet FileReader + FilePut FileWriter + FileCmd FileCmder + FileInfo FileInfoer +} + +// RequestServer abstracts the sftp protocol with an http request-like protocol +type RequestServer struct { + *serverConn + Handlers Handlers + pktMgr packetManager + openRequests map[string]Request + openRequestLock sync.RWMutex + handleCount int +} + +// NewRequestServer creates/allocates/returns new RequestServer. +// Normally there there will be one server per user-session. +func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer { + svrConn := &serverConn{ + conn: conn{ + Reader: rwc, + WriteCloser: rwc, + }, + } + return &RequestServer{ + serverConn: svrConn, + Handlers: h, + pktMgr: newPktMgr(svrConn), + openRequests: make(map[string]Request), + } +} + +func (rs *RequestServer) nextRequest(r Request) string { + rs.openRequestLock.Lock() + defer rs.openRequestLock.Unlock() + rs.handleCount++ + handle := strconv.Itoa(rs.handleCount) + rs.openRequests[handle] = r + return handle +} + +func (rs *RequestServer) getRequest(handle string) (Request, bool) { + rs.openRequestLock.RLock() + defer rs.openRequestLock.RUnlock() + r, ok := rs.openRequests[handle] + return r, ok +} + +func (rs *RequestServer) closeRequest(handle string) { + rs.openRequestLock.Lock() + defer rs.openRequestLock.Unlock() + if r, ok := rs.openRequests[handle]; ok { + r.close() + delete(rs.openRequests, handle) + } +} + +// Close the read/write/closer to trigger exiting the main server loop +func (rs *RequestServer) Close() error { return rs.conn.Close() } + +// Serve requests for user session +func (rs *RequestServer) Serve() error { + var wg sync.WaitGroup + runWorker := func(ch requestChan) { + wg.Add(1) + go func() { + defer wg.Done() + if err := rs.packetWorker(ch); err != nil { + rs.conn.Close() // shuts down recvPacket + } + }() + } + pktChan := rs.pktMgr.workerChan(runWorker) + + var err error + var pkt requestPacket + var pktType uint8 + var pktBytes []byte + for { + pktType, pktBytes, err = rs.recvPacket() + if err != nil { + break + } + + pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + if err != nil { + debug("makePacket err: %v", err) + rs.conn.Close() // shuts down recvPacket + break + } + + pktChan <- pkt + } + + close(pktChan) // shuts down sftpServerWorkers + wg.Wait() // wait for all workers to exit + + return err +} + +func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error { + for pkt := range pktChan { + var rpkt responsePacket + switch pkt := pkt.(type) { + case *sshFxInitPacket: + rpkt = sshFxVersionPacket{sftpProtocolVersion, nil} + case *sshFxpClosePacket: + handle := pkt.getHandle() + rs.closeRequest(handle) + rpkt = statusFromError(pkt, nil) + case *sshFxpRealpathPacket: + rpkt = cleanPath(pkt) + case isOpener: + handle := rs.nextRequest(requestFromPacket(pkt)) + rpkt = sshFxpHandlePacket{pkt.id(), handle} + case *sshFxpFstatPacket: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt, syscall.EBADF) + } else { + request = requestFromPacket( + &sshFxpStatPacket{ID: pkt.id(), Path: request.Filepath}) + rpkt = rs.handle(request, pkt) + } + case *sshFxpFsetstatPacket: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt, syscall.EBADF) + } else { + request = requestFromPacket( + &sshFxpSetstatPacket{ID: pkt.id(), Path: request.Filepath, + Flags: pkt.Flags, Attrs: pkt.Attrs, + }) + rpkt = rs.handle(request, pkt) + } + case hasHandle: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + request.update(pkt) + if !ok { + rpkt = statusFromError(pkt, syscall.EBADF) + } else { + rpkt = rs.handle(request, pkt) + } + case hasPath: + request := requestFromPacket(pkt) + rpkt = rs.handle(request, pkt) + default: + return errors.Errorf("unexpected packet type %T", pkt) + } + + err := rs.sendPacket(rpkt) + if err != nil { + return err + } + } + return nil +} + +func cleanPath(pkt *sshFxpRealpathPacket) responsePacket { + path := pkt.getPath() + if !filepath.IsAbs(path) { + path = "/" + path + } // all paths are absolute + + cleaned_path := filepath.Clean(path) + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: []sshFxpNameAttr{{ + Name: cleaned_path, + LongName: cleaned_path, + Attrs: emptyFileStat, + }}, + } +} + +func (rs *RequestServer) handle(request Request, pkt requestPacket) responsePacket { + // fmt.Println("Request Method: ", request.Method) + rpkt, err := request.handle(rs.Handlers) + if err != nil { + err = errorAdapter(err) + rpkt = statusFromError(pkt, err) + } + return rpkt +} + +// Wrap underlying connection methods to use packetManager +func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error { + if pkt, ok := m.(responsePacket); ok { + rs.pktMgr.readyPacket(pkt) + } else { + return errors.Errorf("unexpected packet type %T", m) + } + return nil +} + +func (rs *RequestServer) sendError(p ider, err error) error { + return rs.sendPacket(statusFromError(p, err)) +} + +// os.ErrNotExist should convert to ssh_FX_NO_SUCH_FILE, but is not recognized +// by statusFromError. So we convert to syscall.ENOENT which it does. +func errorAdapter(err error) error { + if err == os.ErrNotExist { + return syscall.ENOENT + } + return err +} diff --git a/github.com/pkg/sftp/request-server_test.go b/github.com/pkg/sftp/request-server_test.go new file mode 100644 index 0000000000..ee9621be7d --- /dev/null +++ b/github.com/pkg/sftp/request-server_test.go @@ -0,0 +1,329 @@ +package sftp + +import ( + "fmt" + "io" + "net" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +var _ = fmt.Print + +type csPair struct { + cli *Client + svr *RequestServer +} + +// these must be closed in order, else client.Close will hang +func (cs csPair) Close() { + cs.svr.Close() + cs.cli.Close() + os.Remove(sock) +} + +func (cs csPair) testHandler() *root { + return cs.svr.Handlers.FileGet.(*root) +} + +const sock = "/tmp/rstest.sock" + +func clientRequestServerPair(t *testing.T) *csPair { + ready := make(chan bool) + os.Remove(sock) // either this or signal handling + var server *RequestServer + go func() { + l, err := net.Listen("unix", sock) + if err != nil { + // neither assert nor t.Fatal reliably exit before Accept errors + panic(err) + } + ready <- true + fd, err := l.Accept() + assert.Nil(t, err) + handlers := InMemHandler() + server = NewRequestServer(fd, handlers) + server.Serve() + }() + <-ready + defer os.Remove(sock) + c, err := net.Dial("unix", sock) + assert.Nil(t, err) + client, err := NewClientPipe(c, c) + if err != nil { + t.Fatalf("%+v\n", err) + } + return &csPair{client, server} +} + +// after adding logging, maybe check log to make sure packet handling +// was split over more than one worker +func TestRequestSplitWrite(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + w, err := p.cli.Create("/foo") + assert.Nil(t, err) + p.cli.maxPacket = 3 // force it to send in small chunks + contents := "one two three four five six seven eight nine ten" + w.Write([]byte(contents)) + w.Close() + r := p.testHandler() + f, _ := r.fetch("/foo") + assert.Equal(t, contents, string(f.content)) +} + +func TestRequestCache(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + foo := NewRequest("", "foo") + bar := NewRequest("", "bar") + fh := p.svr.nextRequest(foo) + bh := p.svr.nextRequest(bar) + assert.Len(t, p.svr.openRequests, 2) + _foo, ok := p.svr.getRequest(fh) + assert.Equal(t, foo, _foo) + assert.True(t, ok) + _, ok = p.svr.getRequest("zed") + assert.False(t, ok) + p.svr.closeRequest(fh) + p.svr.closeRequest(bh) + assert.Len(t, p.svr.openRequests, 0) +} + +func TestRequestCacheState(t *testing.T) { + // test operation that uses open/close + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + assert.Len(t, p.svr.openRequests, 0) + // test operation that doesn't open/close + err = p.cli.Remove("/foo") + assert.Nil(t, err) + assert.Len(t, p.svr.openRequests, 0) +} + +func putTestFile(cli *Client, path, content string) (int, error) { + w, err := cli.Create(path) + if err == nil { + defer w.Close() + return w.Write([]byte(content)) + } + return 0, err +} + +func TestRequestWrite(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + n, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + assert.Equal(t, 5, n) + r := p.testHandler() + f, err := r.fetch("/foo") + assert.Nil(t, err) + assert.False(t, f.isdir) + assert.Equal(t, f.content, []byte("hello")) +} + +// needs fail check +func TestRequestFilename(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + r := p.testHandler() + f, err := r.fetch("/foo") + assert.Nil(t, err) + assert.Equal(t, f.Name(), "foo") +} + +func TestRequestRead(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + rf, err := p.cli.Open("/foo") + assert.Nil(t, err) + defer rf.Close() + contents := make([]byte, 5) + n, err := rf.Read(contents) + if err != nil && err != io.EOF { + t.Fatalf("err: %v", err) + } + assert.Equal(t, 5, n) + assert.Equal(t, "hello", string(contents[0:5])) +} + +func TestRequestReadFail(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + rf, err := p.cli.Open("/foo") + assert.Nil(t, err) + contents := make([]byte, 5) + n, err := rf.Read(contents) + assert.Equal(t, n, 0) + assert.Exactly(t, os.ErrNotExist, err) +} + +func TestRequestOpen(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + fh, err := p.cli.Open("foo") + assert.Nil(t, err) + err = fh.Close() + assert.Nil(t, err) +} + +func TestRequestMkdir(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + err := p.cli.Mkdir("/foo") + assert.Nil(t, err) + r := p.testHandler() + f, err := r.fetch("/foo") + assert.Nil(t, err) + assert.True(t, f.isdir) +} + +func TestRequestRemove(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + r := p.testHandler() + _, err = r.fetch("/foo") + assert.Nil(t, err) + err = p.cli.Remove("/foo") + assert.Nil(t, err) + _, err = r.fetch("/foo") + assert.Equal(t, err, os.ErrNotExist) +} + +func TestRequestRename(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + r := p.testHandler() + _, err = r.fetch("/foo") + assert.Nil(t, err) + err = p.cli.Rename("/foo", "/bar") + assert.Nil(t, err) + _, err = r.fetch("/bar") + assert.Nil(t, err) + _, err = r.fetch("/foo") + assert.Equal(t, err, os.ErrNotExist) +} + +func TestRequestRenameFail(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + _, err = putTestFile(p.cli, "/bar", "goodbye") + assert.Nil(t, err) + err = p.cli.Rename("/foo", "/bar") + assert.IsType(t, &StatusError{}, err) +} + +func TestRequestStat(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + fi, err := p.cli.Stat("/foo") + assert.Equal(t, fi.Name(), "foo") + assert.Equal(t, fi.Size(), int64(5)) + assert.Equal(t, fi.Mode(), os.FileMode(0644)) + assert.NoError(t, testOsSys(fi.Sys())) +} + +// NOTE: Setstat is a noop in the request server tests, but we want to test +// that is does nothing without crapping out. +func TestRequestSetstat(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + mode := os.FileMode(0644) + err = p.cli.Chmod("/foo", mode) + assert.Nil(t, err) + fi, err := p.cli.Stat("/foo") + assert.Nil(t, err) + assert.Equal(t, fi.Name(), "foo") + assert.Equal(t, fi.Size(), int64(5)) + assert.Equal(t, fi.Mode(), os.FileMode(0644)) + assert.NoError(t, testOsSys(fi.Sys())) +} + +func TestRequestFstat(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + fp, err := p.cli.Open("/foo") + assert.Nil(t, err) + fi, err := fp.Stat() + assert.Nil(t, err) + assert.Equal(t, fi.Name(), "foo") + assert.Equal(t, fi.Size(), int64(5)) + assert.Equal(t, fi.Mode(), os.FileMode(0644)) + assert.NoError(t, testOsSys(fi.Sys())) +} + +func TestRequestStatFail(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + fi, err := p.cli.Stat("/foo") + assert.Nil(t, fi) + assert.True(t, os.IsNotExist(err)) +} + +func TestRequestSymlink(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + err = p.cli.Symlink("/foo", "/bar") + assert.Nil(t, err) + r := p.testHandler() + fi, err := r.fetch("/bar") + assert.Nil(t, err) + assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink) +} + +func TestRequestSymlinkFail(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + err := p.cli.Symlink("/foo", "/bar") + assert.True(t, os.IsNotExist(err)) +} + +func TestRequestReadlink(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + assert.Nil(t, err) + err = p.cli.Symlink("/foo", "/bar") + assert.Nil(t, err) + rl, err := p.cli.ReadLink("/bar") + assert.Nil(t, err) + assert.Equal(t, "foo", rl) +} + +func TestRequestReaddir(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + for i := 0; i < 100; i++ { + fname := fmt.Sprintf("/foo_%02d", i) + _, err := putTestFile(p.cli, fname, fname) + assert.Nil(t, err) + } + di, err := p.cli.ReadDir("/") + assert.Nil(t, err) + assert.Len(t, di, 100) + names := []string{di[18].Name(), di[81].Name()} + assert.Equal(t, []string{"foo_18", "foo_81"}, names) +} diff --git a/github.com/pkg/sftp/request-unix.go b/github.com/pkg/sftp/request-unix.go new file mode 100644 index 0000000000..a71a8980ae --- /dev/null +++ b/github.com/pkg/sftp/request-unix.go @@ -0,0 +1,23 @@ +// +build !windows + +package sftp + +import ( + "errors" + "syscall" +) + +func fakeFileInfoSys() interface{} { + return &syscall.Stat_t{Uid: 65534, Gid: 65534} +} + +func testOsSys(sys interface{}) error { + fstat := sys.(*FileStat) + if fstat.UID != uint32(65534) { + return errors.New("Uid failed to match.") + } + if fstat.GID != uint32(65534) { + return errors.New("Gid failed to match:") + } + return nil +} diff --git a/github.com/pkg/sftp/request.go b/github.com/pkg/sftp/request.go new file mode 100644 index 0000000000..6d7d6ffdf9 --- /dev/null +++ b/github.com/pkg/sftp/request.go @@ -0,0 +1,334 @@ +package sftp + +import ( + "io" + "os" + "path" + "path/filepath" + "sync" + "syscall" + + "github.com/pkg/errors" +) + +// Request contains the data and state for the incoming service request. +type Request struct { + // Get, Put, Setstat, Stat, Rename, Remove + // Rmdir, Mkdir, List, Readlink, Symlink + Method string + Filepath string + Flags uint32 + Attrs []byte // convert to sub-struct + Target string // for renames and sym-links + // packet data + pkt_id uint32 + packets chan packet_data + // reader/writer/readdir from handlers + stateLock *sync.RWMutex + state *state +} + +type state struct { + writerAt io.WriterAt + readerAt io.ReaderAt + endofdir bool // in case handler doesn't use EOF on file list + readdirToken string +} + +type packet_data struct { + id uint32 + data []byte + length uint32 + offset int64 +} + +// New Request initialized based on packet data +func requestFromPacket(pkt hasPath) Request { + method := requestMethod(pkt) + request := NewRequest(method, pkt.getPath()) + request.pkt_id = pkt.id() + switch p := pkt.(type) { + case *sshFxpSetstatPacket: + request.Flags = p.Flags + request.Attrs = p.Attrs.([]byte) + case *sshFxpRenamePacket: + request.Target = filepath.Clean(p.Newpath) + case *sshFxpSymlinkPacket: + request.Target = filepath.Clean(p.Linkpath) + } + return request +} + +// NewRequest creates a new Request object. +func NewRequest(method, path string) Request { + request := Request{Method: method, Filepath: filepath.Clean(path)} + request.packets = make(chan packet_data, sftpServerWorkerCount) + request.state = &state{} + request.stateLock = &sync.RWMutex{} + return request +} + +// LsSave takes a token to keep track of file list batches. Openssh uses a +// batch size of 100, so I suggest sticking close to that. +func (r Request) LsSave(token string) { + r.stateLock.RLock() + defer r.stateLock.RUnlock() + r.state.readdirToken = token +} + +// LsNext should return the token from the previous call to know which batch +// to return next. +func (r Request) LsNext() string { + r.stateLock.RLock() + defer r.stateLock.RUnlock() + return r.state.readdirToken +} + +// manage file read/write state +func (r Request) setFileState(s interface{}) { + r.stateLock.Lock() + defer r.stateLock.Unlock() + switch s := s.(type) { + case io.WriterAt: + r.state.writerAt = s + case io.ReaderAt: + r.state.readerAt = s + + } +} + +func (r Request) getWriter() io.WriterAt { + r.stateLock.RLock() + defer r.stateLock.RUnlock() + return r.state.writerAt +} + +func (r Request) getReader() io.ReaderAt { + r.stateLock.RLock() + defer r.stateLock.RUnlock() + return r.state.readerAt +} + +// For backwards compatibility. The Handler didn't have batch handling at +// first, and just always assumed 1 batch. This preserves that behavior. +func (r Request) setEOD(eod bool) { + r.stateLock.RLock() + defer r.stateLock.RUnlock() + r.state.endofdir = eod +} + +func (r Request) getEOD() bool { + r.stateLock.RLock() + defer r.stateLock.RUnlock() + return r.state.endofdir +} + +// Close reader/writer if possible +func (r Request) close() { + rd := r.getReader() + if c, ok := rd.(io.Closer); ok { + c.Close() + } + wt := r.getWriter() + if c, ok := wt.(io.Closer); ok { + c.Close() + } +} + +// push packet_data into fifo +func (r Request) pushPacket(pd packet_data) { + r.packets <- pd +} + +// pop packet_data into fifo +func (r *Request) popPacket() packet_data { + return <-r.packets +} + +// called from worker to handle packet/request +func (r Request) handle(handlers Handlers) (responsePacket, error) { + var err error + var rpkt responsePacket + switch r.Method { + case "Get": + rpkt, err = fileget(handlers.FileGet, r) + case "Put": // add "Append" to this to handle append only file writes + rpkt, err = fileput(handlers.FilePut, r) + case "Setstat", "Rename", "Rmdir", "Mkdir", "Symlink", "Remove": + rpkt, err = filecmd(handlers.FileCmd, r) + case "List", "Stat", "Readlink": + rpkt, err = fileinfo(handlers.FileInfo, r) + default: + return rpkt, errors.Errorf("unexpected method: %s", r.Method) + } + return rpkt, err +} + +// wrap FileReader handler +func fileget(h FileReader, r Request) (responsePacket, error) { + var err error + reader := r.getReader() + if reader == nil { + reader, err = h.Fileread(r) + if err != nil { + return nil, err + } + r.setFileState(reader) + } + + pd := r.popPacket() + data := make([]byte, clamp(pd.length, maxTxPacket)) + n, err := reader.ReadAt(data, pd.offset) + if err != nil && (err != io.EOF || n == 0) { + return nil, err + } + return &sshFxpDataPacket{ + ID: pd.id, + Length: uint32(n), + Data: data[:n], + }, nil +} + +// wrap FileWriter handler +func fileput(h FileWriter, r Request) (responsePacket, error) { + var err error + writer := r.getWriter() + if writer == nil { + writer, err = h.Filewrite(r) + if err != nil { + return nil, err + } + r.setFileState(writer) + } + + pd := r.popPacket() + _, err = writer.WriteAt(pd.data, pd.offset) + if err != nil { + return nil, err + } + return &sshFxpStatusPacket{ + ID: pd.id, + StatusError: StatusError{ + Code: ssh_FX_OK, + }}, nil +} + +// wrap FileCmder handler +func filecmd(h FileCmder, r Request) (responsePacket, error) { + err := h.Filecmd(r) + if err != nil { + return nil, err + } + return &sshFxpStatusPacket{ + ID: r.pkt_id, + StatusError: StatusError{ + Code: ssh_FX_OK, + }}, nil +} + +// wrap FileInfoer handler +func fileinfo(h FileInfoer, r Request) (responsePacket, error) { + if r.getEOD() { + return nil, io.EOF + } + finfo, err := h.Fileinfo(r) + if err != nil { + return nil, err + } + + switch r.Method { + case "List": + pd := r.popPacket() + dirname := path.Base(r.Filepath) + ret := &sshFxpNamePacket{ID: pd.id} + for _, fi := range finfo { + ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{ + Name: fi.Name(), + LongName: runLs(dirname, fi), + Attrs: []interface{}{fi}, + }) + } + // No entries means we should return EOF as the Handler didn't. + if len(finfo) == 0 { + return nil, io.EOF + } + // If files are returned but no token is set, return EOF next call. + if r.LsNext() == "" { + r.setEOD(true) + } + return ret, nil + case "Stat": + if len(finfo) == 0 { + err = &os.PathError{Op: "stat", Path: r.Filepath, + Err: syscall.ENOENT} + return nil, err + } + return &sshFxpStatResponse{ + ID: r.pkt_id, + info: finfo[0], + }, nil + case "Readlink": + if len(finfo) == 0 { + err = &os.PathError{Op: "readlink", Path: r.Filepath, + Err: syscall.ENOENT} + return nil, err + } + filename := finfo[0].Name() + return &sshFxpNamePacket{ + ID: r.pkt_id, + NameAttrs: []sshFxpNameAttr{{ + Name: filename, + LongName: filename, + Attrs: emptyFileStat, + }}, + }, nil + } + return nil, err +} + +// file data for additional read/write packets +func (r *Request) update(p hasHandle) error { + pd := packet_data{id: p.id()} + switch p := p.(type) { + case *sshFxpReadPacket: + r.Method = "Get" + pd.length = p.Len + pd.offset = int64(p.Offset) + case *sshFxpWritePacket: + r.Method = "Put" + pd.data = p.Data + pd.length = p.Length + pd.offset = int64(p.Offset) + case *sshFxpReaddirPacket: + r.Method = "List" + default: + return errors.Errorf("unexpected packet type %T", p) + } + r.pushPacket(pd) + return nil +} + +// init attributes of request object from packet data +func requestMethod(p hasPath) (method string) { + switch p.(type) { + case *sshFxpOpenPacket, *sshFxpOpendirPacket: + method = "Open" + case *sshFxpSetstatPacket: + method = "Setstat" + case *sshFxpRenamePacket: + method = "Rename" + case *sshFxpSymlinkPacket: + method = "Symlink" + case *sshFxpRemovePacket: + method = "Remove" + case *sshFxpStatPacket, *sshFxpLstatPacket: + method = "Stat" + case *sshFxpRmdirPacket: + method = "Rmdir" + case *sshFxpReadlinkPacket: + method = "Readlink" + case *sshFxpMkdirPacket: + method = "Mkdir" + } + return method +} diff --git a/github.com/pkg/sftp/request_test.go b/github.com/pkg/sftp/request_test.go new file mode 100644 index 0000000000..e537fc16e7 --- /dev/null +++ b/github.com/pkg/sftp/request_test.go @@ -0,0 +1,182 @@ +package sftp + +import ( + "sync" + + "github.com/stretchr/testify/assert" + + "bytes" + "errors" + "io" + "os" + "testing" +) + +type testHandler struct { + filecontents []byte // dummy contents + output io.WriterAt // dummy file out + err error // dummy error, should be file related +} + +func (t *testHandler) Fileread(r Request) (io.ReaderAt, error) { + if t.err != nil { + return nil, t.err + } + return bytes.NewReader(t.filecontents), nil +} + +func (t *testHandler) Filewrite(r Request) (io.WriterAt, error) { + if t.err != nil { + return nil, t.err + } + return io.WriterAt(t.output), nil +} + +func (t *testHandler) Filecmd(r Request) error { + if t.err != nil { + return t.err + } + return nil +} + +func (t *testHandler) Fileinfo(r Request) ([]os.FileInfo, error) { + if t.err != nil { + return nil, t.err + } + f, err := os.Open(r.Filepath) + if err != nil { + return nil, err + } + fi, err := f.Stat() + if err != nil { + return nil, err + } + return []os.FileInfo{fi}, nil +} + +// make sure len(fakefile) == len(filecontents) +type fakefile [10]byte + +var filecontents = []byte("file-data.") + +func testRequest(method string) Request { + request := Request{ + Filepath: "./request_test.go", + Method: method, + Attrs: []byte("foo"), + Target: "foo", + packets: make(chan packet_data, sftpServerWorkerCount), + state: &state{}, + stateLock: &sync.RWMutex{}, + } + for _, p := range []packet_data{ + packet_data{id: 1, data: filecontents[:5], length: 5}, + packet_data{id: 2, data: filecontents[5:], length: 5, offset: 5}} { + request.packets <- p + } + return request +} + +func (ff *fakefile) WriteAt(p []byte, off int64) (int, error) { + n := copy(ff[off:], p) + return n, nil +} + +func (ff fakefile) string() string { + b := make([]byte, len(ff)) + copy(b, ff[:]) + return string(b) +} + +func newTestHandlers() Handlers { + handler := &testHandler{ + filecontents: filecontents, + output: &fakefile{}, + err: nil, + } + return Handlers{ + FileGet: handler, + FilePut: handler, + FileCmd: handler, + FileInfo: handler, + } +} + +func (h Handlers) getOutString() string { + handler := h.FilePut.(*testHandler) + return handler.output.(*fakefile).string() +} + +var errTest = errors.New("test error") + +func (h *Handlers) returnError() { + handler := h.FilePut.(*testHandler) + handler.err = errTest +} + +func statusOk(t *testing.T, p interface{}) { + if pkt, ok := p.(*sshFxpStatusPacket); ok { + assert.Equal(t, pkt.StatusError.Code, uint32(ssh_FX_OK)) + } +} + +func TestRequestGet(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Get") + // req.length is 5, so we test reads in 5 byte chunks + for i, txt := range []string{"file-", "data."} { + pkt, err := request.handle(handlers) + assert.Nil(t, err) + dpkt := pkt.(*sshFxpDataPacket) + assert.Equal(t, dpkt.id(), uint32(i+1)) + assert.Equal(t, string(dpkt.Data), txt) + } +} + +func TestRequestPut(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Put") + pkt, err := request.handle(handlers) + assert.Nil(t, err) + statusOk(t, pkt) + pkt, err = request.handle(handlers) + assert.Nil(t, err) + statusOk(t, pkt) + assert.Equal(t, "file-data.", handlers.getOutString()) +} + +func TestRequestCmdr(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Mkdir") + pkt, err := request.handle(handlers) + assert.Nil(t, err) + statusOk(t, pkt) + + handlers.returnError() + pkt, err = request.handle(handlers) + assert.Nil(t, pkt) + assert.Equal(t, err, errTest) +} + +func TestRequestInfoList(t *testing.T) { testInfoMethod(t, "List") } +func TestRequestInfoReadlink(t *testing.T) { testInfoMethod(t, "Readlink") } +func TestRequestInfoStat(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Stat") + pkt, err := request.handle(handlers) + assert.Nil(t, err) + spkt, ok := pkt.(*sshFxpStatResponse) + assert.True(t, ok) + assert.Equal(t, spkt.info.Name(), "request_test.go") +} + +func testInfoMethod(t *testing.T, method string) { + handlers := newTestHandlers() + request := testRequest(method) + pkt, err := request.handle(handlers) + assert.Nil(t, err) + npkt, ok := pkt.(*sshFxpNamePacket) + assert.True(t, ok) + assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0]) + assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go") +} diff --git a/github.com/pkg/sftp/request_windows.go b/github.com/pkg/sftp/request_windows.go new file mode 100644 index 0000000000..94d306b6e9 --- /dev/null +++ b/github.com/pkg/sftp/request_windows.go @@ -0,0 +1,11 @@ +package sftp + +import "syscall" + +func fakeFileInfoSys() interface{} { + return syscall.Win32FileAttributeData{} +} + +func testOsSys(sys interface{}) error { + return nil +} diff --git a/github.com/pkg/sftp/server.go b/github.com/pkg/sftp/server.go new file mode 100644 index 0000000000..42afc9cbf1 --- /dev/null +++ b/github.com/pkg/sftp/server.go @@ -0,0 +1,575 @@ +package sftp + +// sftp server counterpart + +import ( + "encoding" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "sync" + "syscall" + "time" + + "github.com/pkg/errors" +) + +const ( + sftpServerWorkerCount = 8 +) + +// Server is an SSH File Transfer Protocol (sftp) server. +// This is intended to provide the sftp subsystem to an ssh server daemon. +// This implementation currently supports most of sftp server protocol version 3, +// as specified at http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 +type Server struct { + *serverConn + debugStream io.Writer + readOnly bool + pktMgr packetManager + openFiles map[string]*os.File + openFilesLock sync.RWMutex + handleCount int + maxTxPacket uint32 +} + +func (svr *Server) nextHandle(f *os.File) string { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + svr.handleCount++ + handle := strconv.Itoa(svr.handleCount) + svr.openFiles[handle] = f + return handle +} + +func (svr *Server) closeHandle(handle string) error { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + if f, ok := svr.openFiles[handle]; ok { + delete(svr.openFiles, handle) + return f.Close() + } + + return syscall.EBADF +} + +func (svr *Server) getHandle(handle string) (*os.File, bool) { + svr.openFilesLock.RLock() + defer svr.openFilesLock.RUnlock() + f, ok := svr.openFiles[handle] + return f, ok +} + +type serverRespondablePacket interface { + encoding.BinaryUnmarshaler + id() uint32 + respond(svr *Server) error +} + +// NewServer creates a new Server instance around the provided streams, serving +// content from the root of the filesystem. Optionally, ServerOption +// functions may be specified to further configure the Server. +// +// A subsequent call to Serve() is required to begin serving files over SFTP. +func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) { + svrConn := &serverConn{ + conn: conn{ + Reader: rwc, + WriteCloser: rwc, + }, + } + s := &Server{ + serverConn: svrConn, + debugStream: ioutil.Discard, + pktMgr: newPktMgr(svrConn), + openFiles: make(map[string]*os.File), + maxTxPacket: 1 << 15, + } + + for _, o := range options { + if err := o(s); err != nil { + return nil, err + } + } + + return s, nil +} + +// A ServerOption is a function which applies configuration to a Server. +type ServerOption func(*Server) error + +// WithDebug enables Server debugging output to the supplied io.Writer. +func WithDebug(w io.Writer) ServerOption { + return func(s *Server) error { + s.debugStream = w + return nil + } +} + +// ReadOnly configures a Server to serve files in read-only mode. +func ReadOnly() ServerOption { + return func(s *Server) error { + s.readOnly = true + return nil + } +} + +type rxPacket struct { + pktType fxp + pktBytes []byte +} + +// Up to N parallel servers +func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error { + for pkt := range pktChan { + + // readonly checks + readonly := true + switch pkt := pkt.(type) { + case notReadOnly: + readonly = false + case *sshFxpOpenPacket: + readonly = pkt.readonly() + case *sshFxpExtendedPacket: + readonly = pkt.SpecificPacket.readonly() + } + + // If server is operating read-only and a write operation is requested, + // return permission denied + if !readonly && svr.readOnly { + if err := svr.sendError(pkt, syscall.EPERM); err != nil { + return errors.Wrap(err, "failed to send read only packet response") + } + continue + } + + if err := handlePacket(svr, pkt); err != nil { + return err + } + } + return nil +} + +func handlePacket(s *Server, p interface{}) error { + switch p := p.(type) { + case *sshFxInitPacket: + return s.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) + case *sshFxpStatPacket: + // stat the requested file + info, err := os.Stat(p.Path) + if err != nil { + return s.sendError(p, err) + } + return s.sendPacket(sshFxpStatResponse{ + ID: p.ID, + info: info, + }) + case *sshFxpLstatPacket: + // stat the requested file + info, err := os.Lstat(p.Path) + if err != nil { + return s.sendError(p, err) + } + return s.sendPacket(sshFxpStatResponse{ + ID: p.ID, + info: info, + }) + case *sshFxpFstatPacket: + f, ok := s.getHandle(p.Handle) + if !ok { + return s.sendError(p, syscall.EBADF) + } + + info, err := f.Stat() + if err != nil { + return s.sendError(p, err) + } + + return s.sendPacket(sshFxpStatResponse{ + ID: p.ID, + info: info, + }) + case *sshFxpMkdirPacket: + // TODO FIXME: ignore flags field + err := os.Mkdir(p.Path, 0755) + return s.sendError(p, err) + case *sshFxpRmdirPacket: + err := os.Remove(p.Path) + return s.sendError(p, err) + case *sshFxpRemovePacket: + err := os.Remove(p.Filename) + return s.sendError(p, err) + case *sshFxpRenamePacket: + err := os.Rename(p.Oldpath, p.Newpath) + return s.sendError(p, err) + case *sshFxpSymlinkPacket: + err := os.Symlink(p.Targetpath, p.Linkpath) + return s.sendError(p, err) + case *sshFxpClosePacket: + return s.sendError(p, s.closeHandle(p.Handle)) + case *sshFxpReadlinkPacket: + f, err := os.Readlink(p.Path) + if err != nil { + return s.sendError(p, err) + } + + return s.sendPacket(sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []sshFxpNameAttr{{ + Name: f, + LongName: f, + Attrs: emptyFileStat, + }}, + }) + + case *sshFxpRealpathPacket: + f, err := filepath.Abs(p.Path) + if err != nil { + return s.sendError(p, err) + } + f = filepath.Clean(f) + f = filepath.ToSlash(f) // make path more Unix like on windows servers + return s.sendPacket(sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []sshFxpNameAttr{{ + Name: f, + LongName: f, + Attrs: emptyFileStat, + }}, + }) + case *sshFxpOpendirPacket: + return sshFxpOpenPacket{ + ID: p.ID, + Path: p.Path, + Pflags: ssh_FXF_READ, + }.respond(s) + case *sshFxpReadPacket: + f, ok := s.getHandle(p.Handle) + if !ok { + return s.sendError(p, syscall.EBADF) + } + + data := make([]byte, clamp(p.Len, s.maxTxPacket)) + n, err := f.ReadAt(data, int64(p.Offset)) + if err != nil && (err != io.EOF || n == 0) { + return s.sendError(p, err) + } + return s.sendPacket(sshFxpDataPacket{ + ID: p.ID, + Length: uint32(n), + Data: data[:n], + }) + case *sshFxpWritePacket: + f, ok := s.getHandle(p.Handle) + if !ok { + return s.sendError(p, syscall.EBADF) + } + + _, err := f.WriteAt(p.Data, int64(p.Offset)) + return s.sendError(p, err) + case serverRespondablePacket: + err := p.respond(s) + return errors.Wrap(err, "pkt.respond failed") + default: + return errors.Errorf("unexpected packet type %T", p) + } +} + +// Serve serves SFTP connections until the streams stop or the SFTP subsystem +// is stopped. +func (svr *Server) Serve() error { + var wg sync.WaitGroup + runWorker := func(ch requestChan) { + wg.Add(1) + go func() { + defer wg.Done() + if err := svr.sftpServerWorker(ch); err != nil { + svr.conn.Close() // shuts down recvPacket + } + }() + } + pktChan := svr.pktMgr.workerChan(runWorker) + + var err error + var pkt requestPacket + var pktType uint8 + var pktBytes []byte + for { + pktType, pktBytes, err = svr.recvPacket() + if err != nil { + break + } + + pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + if err != nil { + debug("makePacket err: %v", err) + svr.conn.Close() // shuts down recvPacket + break + } + + pktChan <- pkt + } + + close(pktChan) // shuts down sftpServerWorkers + wg.Wait() // wait for all workers to exit + + // close any still-open files + for handle, file := range svr.openFiles { + fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Name()) + file.Close() + } + return err // error from recvPacket +} + +// Wrap underlying connection methods to use packetManager +func (svr *Server) sendPacket(m encoding.BinaryMarshaler) error { + if pkt, ok := m.(responsePacket); ok { + svr.pktMgr.readyPacket(pkt) + } else { + return errors.Errorf("unexpected packet type %T", m) + } + return nil +} + +func (svr *Server) sendError(p ider, err error) error { + return svr.sendPacket(statusFromError(p, err)) +} + +type ider interface { + id() uint32 +} + +// The init packet has no ID, so we just return a zero-value ID +func (p sshFxInitPacket) id() uint32 { return 0 } + +type sshFxpStatResponse struct { + ID uint32 + info os.FileInfo +} + +func (p sshFxpStatResponse) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_ATTRS} + b = marshalUint32(b, p.ID) + b = marshalFileInfo(b, p.info) + return b, nil +} + +var emptyFileStat = []interface{}{uint32(0)} + +func (p sshFxpOpenPacket) readonly() bool { + return !p.hasPflags(ssh_FXF_WRITE) +} + +func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool { + for _, f := range flags { + if p.Pflags&f == 0 { + return false + } + } + return true +} + +func (p sshFxpOpenPacket) respond(svr *Server) error { + var osFlags int + if p.hasPflags(ssh_FXF_READ, ssh_FXF_WRITE) { + osFlags |= os.O_RDWR + } else if p.hasPflags(ssh_FXF_WRITE) { + osFlags |= os.O_WRONLY + } else if p.hasPflags(ssh_FXF_READ) { + osFlags |= os.O_RDONLY + } else { + // how are they opening? + return svr.sendError(p, syscall.EINVAL) + } + + if p.hasPflags(ssh_FXF_APPEND) { + osFlags |= os.O_APPEND + } + if p.hasPflags(ssh_FXF_CREAT) { + osFlags |= os.O_CREATE + } + if p.hasPflags(ssh_FXF_TRUNC) { + osFlags |= os.O_TRUNC + } + if p.hasPflags(ssh_FXF_EXCL) { + osFlags |= os.O_EXCL + } + + f, err := os.OpenFile(p.Path, osFlags, 0644) + if err != nil { + return svr.sendError(p, err) + } + + handle := svr.nextHandle(f) + return svr.sendPacket(sshFxpHandlePacket{p.ID, handle}) +} + +func (p sshFxpReaddirPacket) respond(svr *Server) error { + f, ok := svr.getHandle(p.Handle) + if !ok { + return svr.sendError(p, syscall.EBADF) + } + + dirname := f.Name() + dirents, err := f.Readdir(128) + if err != nil { + return svr.sendError(p, err) + } + + ret := sshFxpNamePacket{ID: p.ID} + for _, dirent := range dirents { + ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{ + Name: dirent.Name(), + LongName: runLs(dirname, dirent), + Attrs: []interface{}{dirent}, + }) + } + return svr.sendPacket(ret) +} + +func (p sshFxpSetstatPacket) respond(svr *Server) error { + // additional unmarshalling is required for each possibility here + b := p.Attrs.([]byte) + var err error + + debug("setstat name \"%s\"", p.Path) + if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 { + var size uint64 + if size, b, err = unmarshalUint64Safe(b); err == nil { + err = os.Truncate(p.Path, int64(size)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 { + var mode uint32 + if mode, b, err = unmarshalUint32Safe(b); err == nil { + err = os.Chmod(p.Path, os.FileMode(mode)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 { + var atime uint32 + var mtime uint32 + if atime, b, err = unmarshalUint32Safe(b); err != nil { + } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { + } else { + atimeT := time.Unix(int64(atime), 0) + mtimeT := time.Unix(int64(mtime), 0) + err = os.Chtimes(p.Path, atimeT, mtimeT) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 { + var uid uint32 + var gid uint32 + if uid, b, err = unmarshalUint32Safe(b); err != nil { + } else if gid, b, err = unmarshalUint32Safe(b); err != nil { + } else { + err = os.Chown(p.Path, int(uid), int(gid)) + } + } + + return svr.sendError(p, err) +} + +func (p sshFxpFsetstatPacket) respond(svr *Server) error { + f, ok := svr.getHandle(p.Handle) + if !ok { + return svr.sendError(p, syscall.EBADF) + } + + // additional unmarshalling is required for each possibility here + b := p.Attrs.([]byte) + var err error + + debug("fsetstat name \"%s\"", f.Name()) + if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 { + var size uint64 + if size, b, err = unmarshalUint64Safe(b); err == nil { + err = f.Truncate(int64(size)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 { + var mode uint32 + if mode, b, err = unmarshalUint32Safe(b); err == nil { + err = f.Chmod(os.FileMode(mode)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 { + var atime uint32 + var mtime uint32 + if atime, b, err = unmarshalUint32Safe(b); err != nil { + } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { + } else { + atimeT := time.Unix(int64(atime), 0) + mtimeT := time.Unix(int64(mtime), 0) + err = os.Chtimes(f.Name(), atimeT, mtimeT) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 { + var uid uint32 + var gid uint32 + if uid, b, err = unmarshalUint32Safe(b); err != nil { + } else if gid, b, err = unmarshalUint32Safe(b); err != nil { + } else { + err = f.Chown(int(uid), int(gid)) + } + } + + return svr.sendError(p, err) +} + +// translateErrno translates a syscall error number to a SFTP error code. +func translateErrno(errno syscall.Errno) uint32 { + switch errno { + case 0: + return ssh_FX_OK + case syscall.ENOENT: + return ssh_FX_NO_SUCH_FILE + case syscall.EPERM: + return ssh_FX_PERMISSION_DENIED + } + + return ssh_FX_FAILURE +} + +func statusFromError(p ider, err error) sshFxpStatusPacket { + ret := sshFxpStatusPacket{ + ID: p.id(), + StatusError: StatusError{ + // ssh_FX_OK = 0 + // ssh_FX_EOF = 1 + // ssh_FX_NO_SUCH_FILE = 2 ENOENT + // ssh_FX_PERMISSION_DENIED = 3 + // ssh_FX_FAILURE = 4 + // ssh_FX_BAD_MESSAGE = 5 + // ssh_FX_NO_CONNECTION = 6 + // ssh_FX_CONNECTION_LOST = 7 + // ssh_FX_OP_UNSUPPORTED = 8 + Code: ssh_FX_OK, + }, + } + if err != nil { + debug("statusFromError: error is %T %#v", err, err) + ret.StatusError.Code = ssh_FX_FAILURE + ret.StatusError.msg = err.Error() + if err == io.EOF { + ret.StatusError.Code = ssh_FX_EOF + } else if errno, ok := err.(syscall.Errno); ok { + ret.StatusError.Code = translateErrno(errno) + } else if pathError, ok := err.(*os.PathError); ok { + debug("statusFromError: error is %T %#v", pathError.Err, pathError.Err) + if errno, ok := pathError.Err.(syscall.Errno); ok { + ret.StatusError.Code = translateErrno(errno) + } + } + } + return ret +} + +func clamp(v, max uint32) uint32 { + if v > max { + return max + } + return v +} diff --git a/github.com/pkg/sftp/server_integration_test.go b/github.com/pkg/sftp/server_integration_test.go new file mode 100644 index 0000000000..c5786074ab --- /dev/null +++ b/github.com/pkg/sftp/server_integration_test.go @@ -0,0 +1,671 @@ +package sftp + +// sftp server integration tests +// enable with -integration +// example invokation (darwin): gofmt -w `find . -name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/pkg/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems + +import ( + "bytes" + "encoding/hex" + "flag" + "fmt" + "io/ioutil" + "math/rand" + "net" + "os" + "os/exec" + "path" + "path/filepath" + "regexp" + "strconv" + "strings" + "testing" + "time" + + "github.com/kr/fs" + "golang.org/x/crypto/ssh" +) + +var testSftpClientBin = flag.String("sftp_client", "/usr/bin/sftp", "location of the sftp client binary") +var sshServerDebugStream = ioutil.Discard +var sftpServerDebugStream = ioutil.Discard +var sftpClientDebugStream = ioutil.Discard + +const ( + GOLANG_SFTP = true + OPENSSH_SFTP = false +) + +var ( + hostPrivateKeySigner ssh.Signer + privKey = []byte(` +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEArhp7SqFnXVZAgWREL9Ogs+miy4IU/m0vmdkoK6M97G9NX/Pj +wf8I/3/ynxmcArbt8Rc4JgkjT2uxx/NqR0yN42N1PjO5Czu0dms1PSqcKIJdeUBV +7gdrKSm9Co4d2vwfQp5mg47eG4w63pz7Drk9+VIyi9YiYH4bve7WnGDswn4ycvYZ +slV5kKnjlfCdPig+g5P7yQYud0cDWVwyA0+kxvL6H3Ip+Fu8rLDZn4/P1WlFAIuc +PAf4uEKDGGmC2URowi5eesYR7f6GN/HnBs2776laNlAVXZUmYTUfOGagwLsEkx8x +XdNqntfbs2MOOoK+myJrNtcB9pCrM0H6um19uQIDAQABAoIBABkWr9WdVKvalgkP +TdQmhu3mKRNyd1wCl+1voZ5IM9Ayac/98UAvZDiNU4Uhx52MhtVLJ0gz4Oa8+i16 +IkKMAZZW6ro/8dZwkBzQbieWUFJ2Fso2PyvB3etcnGU8/Yhk9IxBDzy+BbuqhYE2 +1ebVQtz+v1HvVZzaD11bYYm/Xd7Y28QREVfFen30Q/v3dv7dOteDE/RgDS8Czz7w +jMW32Q8JL5grz7zPkMK39BLXsTcSYcaasT2ParROhGJZDmbgd3l33zKCVc1zcj9B +SA47QljGd09Tys958WWHgtj2o7bp9v1Ufs4LnyKgzrB80WX1ovaSQKvd5THTLchO +kLIhUAECgYEA2doGXy9wMBmTn/hjiVvggR1aKiBwUpnB87Hn5xCMgoECVhFZlT6l +WmZe7R2klbtG1aYlw+y+uzHhoVDAJW9AUSV8qoDUwbRXvBVlp+In5wIqJ+VjfivK +zgIfzomL5NvDz37cvPmzqIeySTowEfbQyq7CUQSoDtE9H97E2wWZhDkCgYEAzJdJ +k+NSFoTkHhfD3L0xCDHpRV3gvaOeew8524fVtVUq53X8m91ng4AX1r74dCUYwwiF +gqTtSSJfx2iH1xKnNq28M9uKg7wOrCKrRqNPnYUO3LehZEC7rwUr26z4iJDHjjoB +uBcS7nw0LJ+0Zeg1IF+aIdZGV3MrAKnrzWPixYECgYBsffX6ZWebrMEmQ89eUtFF +u9ZxcGI/4K8ErC7vlgBD5ffB4TYZ627xzFWuBLs4jmHCeNIJ9tct5rOVYN+wRO1k +/CRPzYUnSqb+1jEgILL6istvvv+DkE+ZtNkeRMXUndWwel94BWsBnUKe0UmrSJ3G +sq23J3iCmJW2T3z+DpXbkQKBgQCK+LUVDNPE0i42NsRnm+fDfkvLP7Kafpr3Umdl +tMY474o+QYn+wg0/aPJIf9463rwMNyyhirBX/k57IIktUdFdtfPicd2MEGETElWv +nN1GzYxD50Rs2f/jKisZhEwqT9YNyV9DkgDdGGdEbJNYqbv0qpwDIg8T9foe8E1p +bdErgQKBgAt290I3L316cdxIQTkJh1DlScN/unFffITwu127WMr28Jt3mq3cZpuM +Aecey/eEKCj+Rlas5NDYKsB18QIuAw+qqWyq0LAKLiAvP1965Rkc4PLScl3MgJtO +QYa37FK0p8NcDeUuF86zXBVutwS5nJLchHhKfd590ks57OROtm29 +-----END RSA PRIVATE KEY----- +`) +) + +func init() { + var err error + hostPrivateKeySigner, err = ssh.ParsePrivateKey(privKey) + if err != nil { + panic(err) + } +} + +func keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + } + return permissions, nil +} + +func pwAuth(conn ssh.ConnMetadata, pw []byte) (*ssh.Permissions, error) { + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + } + return permissions, nil +} + +func basicServerConfig() *ssh.ServerConfig { + config := ssh.ServerConfig{ + Config: ssh.Config{ + MACs: []string{"hmac-sha1"}, + }, + PasswordCallback: pwAuth, + PublicKeyCallback: keyAuth, + } + config.AddHostKey(hostPrivateKeySigner) + return &config +} + +type sshServer struct { + useSubsystem bool + conn net.Conn + config *ssh.ServerConfig + sshConn *ssh.ServerConn + newChans <-chan ssh.NewChannel + newReqs <-chan *ssh.Request +} + +func sshServerFromConn(conn net.Conn, useSubsystem bool, config *ssh.ServerConfig) (*sshServer, error) { + // From a standard TCP connection to an encrypted SSH connection + sshConn, newChans, newReqs, err := ssh.NewServerConn(conn, config) + if err != nil { + return nil, err + } + + svr := &sshServer{useSubsystem, conn, config, sshConn, newChans, newReqs} + svr.listenChannels() + return svr, nil +} + +func (svr *sshServer) Wait() error { + return svr.sshConn.Wait() +} + +func (svr *sshServer) Close() error { + return svr.sshConn.Close() +} + +func (svr *sshServer) listenChannels() { + go func() { + for chanReq := range svr.newChans { + go svr.handleChanReq(chanReq) + } + }() + go func() { + for req := range svr.newReqs { + go svr.handleReq(req) + } + }() +} + +func (svr *sshServer) handleReq(req *ssh.Request) { + switch req.Type { + default: + rejectRequest(req) + } +} + +type sshChannelServer struct { + svr *sshServer + chanReq ssh.NewChannel + ch ssh.Channel + newReqs <-chan *ssh.Request +} + +type sshSessionChannelServer struct { + *sshChannelServer + env []string +} + +func (svr *sshServer) handleChanReq(chanReq ssh.NewChannel) { + fmt.Fprintf(sshServerDebugStream, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData())) + switch chanReq.ChannelType() { + case "session": + if ch, reqs, err := chanReq.Accept(); err != nil { + fmt.Fprintf(sshServerDebugStream, "fail to accept channel request: %v\n", err) + chanReq.Reject(ssh.ResourceShortage, "channel accept failure") + } else { + chsvr := &sshSessionChannelServer{ + sshChannelServer: &sshChannelServer{svr, chanReq, ch, reqs}, + env: append([]string{}, os.Environ()...), + } + chsvr.handle() + } + default: + chanReq.Reject(ssh.UnknownChannelType, "channel type is not a session") + } +} + +func (chsvr *sshSessionChannelServer) handle() { + // should maybe do something here... + go chsvr.handleReqs() +} + +func (chsvr *sshSessionChannelServer) handleReqs() { + for req := range chsvr.newReqs { + chsvr.handleReq(req) + } + fmt.Fprintf(sshServerDebugStream, "ssh server session channel complete\n") +} + +func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) { + switch req.Type { + case "env": + chsvr.handleEnv(req) + case "subsystem": + chsvr.handleSubsystem(req) + default: + rejectRequest(req) + } +} + +func rejectRequest(req *ssh.Request) error { + fmt.Fprintf(sshServerDebugStream, "ssh rejecting request, type: %s\n", req.Type) + err := req.Reply(false, []byte{}) + if err != nil { + fmt.Fprintf(sshServerDebugStream, "ssh request reply had error: %v\n", err) + } + return err +} + +func rejectRequestUnmarshalError(req *ssh.Request, s interface{}, err error) error { + fmt.Fprintf(sshServerDebugStream, "ssh request unmarshaling error, type '%T': %v\n", s, err) + rejectRequest(req) + return err +} + +// env request form: +type sshEnvRequest struct { + Envvar string + Value string +} + +func (chsvr *sshSessionChannelServer) handleEnv(req *ssh.Request) error { + envReq := &sshEnvRequest{} + if err := ssh.Unmarshal(req.Payload, envReq); err != nil { + return rejectRequestUnmarshalError(req, envReq, err) + } + req.Reply(true, nil) + + found := false + for i, envstr := range chsvr.env { + if strings.HasPrefix(envstr, envReq.Envvar+"=") { + found = true + chsvr.env[i] = envReq.Envvar + "=" + envReq.Value + } + } + if !found { + chsvr.env = append(chsvr.env, envReq.Envvar+"="+envReq.Value) + } + + return nil +} + +// Payload: int: command size, string: command +type sshSubsystemRequest struct { + Name string +} + +type sshSubsystemExitStatus struct { + Status uint32 +} + +func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { + defer func() { + err1 := chsvr.ch.CloseWrite() + err2 := chsvr.ch.Close() + fmt.Fprintf(sshServerDebugStream, "ssh server subsystem request complete, err: %v %v\n", err1, err2) + }() + + subsystemReq := &sshSubsystemRequest{} + if err := ssh.Unmarshal(req.Payload, subsystemReq); err != nil { + return rejectRequestUnmarshalError(req, subsystemReq, err) + } + + // reply to the ssh client + + // no idea if this is actually correct spec-wise. + // just enough for an sftp server to start. + if subsystemReq.Name != "sftp" { + return req.Reply(false, nil) + } + + req.Reply(true, nil) + + if !chsvr.svr.useSubsystem { + // use the openssh sftp server backend; this is to test the ssh code, not the sftp code, + // or is used for comparison between our sftp subsystem and the openssh sftp subsystem + cmd := exec.Command(*testSftp, "-e", "-l", "DEBUG") // log to stderr + cmd.Stdin = chsvr.ch + cmd.Stdout = chsvr.ch + cmd.Stderr = sftpServerDebugStream + if err := cmd.Start(); err != nil { + return err + } + return cmd.Wait() + } + + sftpServer, err := NewServer( + chsvr.ch, + WithDebug(sftpServerDebugStream), + ) + if err != nil { + return err + } + + // wait for the session to close + runErr := sftpServer.Serve() + exitStatus := uint32(1) + if runErr == nil { + exitStatus = uint32(0) + } + + _, exitStatusErr := chsvr.ch.SendRequest("exit-status", false, ssh.Marshal(sshSubsystemExitStatus{exitStatus})) + return exitStatusErr +} + +// starts an ssh server to test. returns: host string and port +func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) { + if !*testIntegration { + t.Skip("skipping intergration test") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + host, portStr, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + t.Fatal(err) + } + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + fmt.Fprintf(sshServerDebugStream, "ssh server socket closed: %v\n", err) + break + } + + go func() { + defer conn.Close() + sshSvr, err := sshServerFromConn(conn, useSubsystem, basicServerConfig()) + if err != nil { + t.Error(err) + return + } + err = sshSvr.Wait() + fmt.Fprintf(sshServerDebugStream, "ssh server finished, err: %v\n", err) + }() + } + }() + + return listener, host, port +} + +func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) { + // if sftp client binary is unavailable, skip test + if _, err := os.Stat(*testSftpClientBin); err != nil { + t.Skip("sftp client binary unavailable") + } + args := []string{ + // "-vvvv", + "-b", "-", + "-o", "StrictHostKeyChecking=no", + "-o", "LogLevel=ERROR", + "-o", "UserKnownHostsFile /dev/null", + "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path), + } + cmd := exec.Command(*testSftpClientBin, args...) + var stdout bytes.Buffer + cmd.Stdin = bytes.NewBufferString(script) + cmd.Stdout = &stdout + cmd.Stderr = sftpClientDebugStream + if err := cmd.Start(); err != nil { + return "", err + } + err := cmd.Wait() + return string(stdout.Bytes()), err +} + +func TestServerCompareSubsystems(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + listenerOp, hostOp, portOp := testServer(t, OPENSSH_SFTP, READONLY) + defer listenerGo.Close() + defer listenerOp.Close() + + script := ` +ls / +ls -l / +ls /dev/ +ls -l /dev/ +ls -l /etc/ +ls -l /bin/ +ls -l /usr/bin/ +` + outputGo, err := runSftpClient(t, script, "/", hostGo, portGo) + if err != nil { + t.Fatal(err) + } + + outputOp, err := runSftpClient(t, script, "/", hostOp, portOp) + if err != nil { + t.Fatal(err) + } + + newlineRegex := regexp.MustCompile(`\r*\n`) + spaceRegex := regexp.MustCompile(`\s+`) + outputGoLines := newlineRegex.Split(outputGo, -1) + outputOpLines := newlineRegex.Split(outputOp, -1) + + for i, goLine := range outputGoLines { + if i > len(outputOpLines) { + t.Fatalf("output line count differs") + } + opLine := outputOpLines[i] + bad := false + if goLine != opLine { + goWords := spaceRegex.Split(goLine, -1) + opWords := spaceRegex.Split(opLine, -1) + // allow words[2] and [3] to be different as these are users & groups + // also allow words[1] to differ as the link count for directories like + // proc is unstable during testing as processes are created/destroyed. + for j, goWord := range goWords { + if j > len(opWords) { + bad = true + } + opWord := opWords[j] + if goWord != opWord && j != 1 && j != 2 && j != 3 { + bad = true + } + } + } + + if bad { + t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", goLine, opLine) + } + } +} + +var rng = rand.New(rand.NewSource(time.Now().Unix())) + +func randData(length int) []byte { + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = byte(rng.Uint32()) + } + return data +} + +func randName() string { + return "sftp." + hex.EncodeToString(randData(16)) +} + +func TestServerMkdirRmdir(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + tmpDir := "/tmp/" + randName() + defer os.RemoveAll(tmpDir) + + // mkdir remote + if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + // directory should now exist + if _, err := os.Stat(tmpDir); err != nil { + t.Fatal(err) + } + + // now remove the directory + if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpDir); err == nil { + t.Fatal("should have error after deleting the directory") + } +} + +func TestServerSymlink(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + link := "/tmp/" + randName() + defer os.RemoveAll(link) + + // now create a symbolic link within the new directory + if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { + t.Fatalf("failed: %v %v", err, string(output)) + } + + // symlink should now exist + if stat, err := os.Lstat(link); err != nil { + t.Fatal(err) + } else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink { + t.Fatalf("is not a symlink: %v", stat.Mode()) + } +} + +func TestServerPut(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + tmpFileLocal := "/tmp/" + randName() + tmpFileRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpFileLocal) + defer os.RemoveAll(tmpFileRemote) + + t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote) + + // create a file with random contents. This will be the local file pushed to the server + tmpFileLocalData := randData(10 * 1024 * 1024) + if err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644); err != nil { + t.Fatal(err) + } + + // sftp the file to the server + if output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // tmpFile2 should now exist, with the same contents + if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil { + t.Fatal(err) + } else if string(tmpFileLocalData) != string(tmpFileRemoteData) { + t.Fatal("contents of file incorrect after put") + } +} + +func TestServerGet(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + tmpFileLocal := "/tmp/" + randName() + tmpFileRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpFileLocal) + defer os.RemoveAll(tmpFileRemote) + + t.Logf("get: local %v remote %v", tmpFileLocal, tmpFileRemote) + + // create a file with random contents. This will be the remote file pulled from the server + tmpFileRemoteData := randData(10 * 1024 * 1024) + if err := ioutil.WriteFile(tmpFileRemote, tmpFileRemoteData, 0644); err != nil { + t.Fatal(err) + } + + // sftp the file to the server + if output, err := runSftpClient(t, "get "+tmpFileRemote+" "+tmpFileLocal, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // tmpFile2 should now exist, with the same contents + if tmpFileLocalData, err := ioutil.ReadFile(tmpFileLocal); err != nil { + t.Fatal(err) + } else if string(tmpFileLocalData) != string(tmpFileRemoteData) { + t.Fatal("contents of file incorrect after put") + } +} + +func compareDirectoriesRecursive(t *testing.T, aroot, broot string) { + walker := fs.Walk(aroot) + for walker.Step() { + if err := walker.Err(); err != nil { + t.Fatal(err) + } + // find paths + aPath := walker.Path() + aRel, err := filepath.Rel(aroot, aPath) + if err != nil { + t.Fatalf("could not find relative path for %v: %v", aPath, err) + } + bPath := path.Join(broot, aRel) + + if aRel == "." { + continue + } + + //t.Logf("comparing: %v a: %v b %v", aRel, aPath, bPath) + + // if a is a link, the sftp recursive copy won't have copied it. ignore + aLink, err := os.Lstat(aPath) + if err != nil { + t.Fatalf("could not lstat %v: %v", aPath, err) + } + if aLink.Mode()&os.ModeSymlink != 0 { + continue + } + + // stat the files + aFile, err := os.Stat(aPath) + if err != nil { + t.Fatalf("could not stat %v: %v", aPath, err) + } + bFile, err := os.Stat(bPath) + if err != nil { + t.Fatalf("could not stat %v: %v", bPath, err) + } + + // compare stats, with some leniency for the timestamp + if aFile.Mode() != bFile.Mode() { + t.Fatalf("modes different for %v: %v vs %v", aRel, aFile.Mode(), bFile.Mode()) + } + if !aFile.IsDir() { + if aFile.Size() != bFile.Size() { + t.Fatalf("sizes different for %v: %v vs %v", aRel, aFile.Size(), bFile.Size()) + } + } + timeDiff := aFile.ModTime().Sub(bFile.ModTime()) + if timeDiff > time.Second || timeDiff < -time.Second { + t.Fatalf("mtimes different for %v: %v vs %v", aRel, aFile.ModTime(), bFile.ModTime()) + } + + // compare contents + if !aFile.IsDir() { + if aContents, err := ioutil.ReadFile(aPath); err != nil { + t.Fatal(err) + } else if bContents, err := ioutil.ReadFile(bPath); err != nil { + t.Fatal(err) + } else if string(aContents) != string(bContents) { + t.Fatalf("contents different for %v", aRel) + } + } + } +} + +func TestServerPutRecursive(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + dirLocal, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + tmpDirRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpDirRemote) + + t.Logf("put recursive: local %v remote %v", dirLocal, tmpDirRemote) + + // push this directory (source code etc) recursively to the server + if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -r -P "+dirLocal+"/ "+tmpDirRemote+"/", "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + compareDirectoriesRecursive(t, dirLocal, path.Join(tmpDirRemote, path.Base(dirLocal))) +} + +func TestServerGetRecursive(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + dirRemote, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + tmpDirLocal := "/tmp/" + randName() + defer os.RemoveAll(tmpDirLocal) + + t.Logf("get recursive: local %v remote %v", tmpDirLocal, dirRemote) + + // pull this directory (source code etc) recursively from the server + if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -r -P "+dirRemote+"/ "+tmpDirLocal+"/", "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + compareDirectoriesRecursive(t, dirRemote, path.Join(tmpDirLocal, path.Base(dirRemote))) +} diff --git a/github.com/pkg/sftp/server_standalone/main.go b/github.com/pkg/sftp/server_standalone/main.go new file mode 100644 index 0000000000..0b8e102a03 --- /dev/null +++ b/github.com/pkg/sftp/server_standalone/main.go @@ -0,0 +1,52 @@ +package main + +// small wrapper around sftp server that allows it to be used as a separate process subsystem call by the ssh server. +// in practice this will statically link; however this allows unit testing from the sftp client. + +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "os" + + "github.com/pkg/sftp" +) + +func main() { + var ( + readOnly bool + debugStderr bool + debugLevel string + options []sftp.ServerOption + ) + + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.StringVar(&debugLevel, "l", "none", "debug level (ignored)") + flag.Parse() + + debugStream := ioutil.Discard + if debugStderr { + debugStream = os.Stderr + } + options = append(options, sftp.WithDebug(debugStream)) + + if readOnly { + options = append(options, sftp.ReadOnly()) + } + + svr, _ := sftp.NewServer( + struct { + io.Reader + io.WriteCloser + }{os.Stdin, + os.Stdout, + }, + options..., + ) + if err := svr.Serve(); err != nil { + fmt.Fprintf(debugStream, "sftp server completed with error: %v", err) + os.Exit(1) + } +} diff --git a/github.com/pkg/sftp/server_statvfs_darwin.go b/github.com/pkg/sftp/server_statvfs_darwin.go new file mode 100644 index 0000000000..8c01dac52d --- /dev/null +++ b/github.com/pkg/sftp/server_statvfs_darwin.go @@ -0,0 +1,21 @@ +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Bsize), // fragment size is a linux thing; use block size here + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Fsid: uint64(uint64(stat.Fsid.Val[1])<<32 | uint64(stat.Fsid.Val[0])), // endianness? + Flag: uint64(stat.Flags), // assuming POSIX? + Namemax: 1024, // man 2 statfs shows: #define MAXPATHLEN 1024 + }, nil +} diff --git a/github.com/pkg/sftp/server_statvfs_impl.go b/github.com/pkg/sftp/server_statvfs_impl.go new file mode 100644 index 0000000000..c26870ebe7 --- /dev/null +++ b/github.com/pkg/sftp/server_statvfs_impl.go @@ -0,0 +1,25 @@ +// +build darwin linux,!gccgo + +// fill in statvfs structure with OS specific values +// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance) + +package sftp + +import ( + "syscall" +) + +func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { + stat := &syscall.Statfs_t{} + if err := syscall.Statfs(p.Path, stat); err != nil { + return svr.sendPacket(statusFromError(p, err)) + } + + retPkt, err := statvfsFromStatfst(stat) + if err != nil { + return svr.sendPacket(statusFromError(p, err)) + } + retPkt.ID = p.ID + + return svr.sendPacket(retPkt) +} diff --git a/github.com/pkg/sftp/server_statvfs_linux.go b/github.com/pkg/sftp/server_statvfs_linux.go new file mode 100644 index 0000000000..43478e890c --- /dev/null +++ b/github.com/pkg/sftp/server_statvfs_linux.go @@ -0,0 +1,22 @@ +// +build !gccgo,linux + +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Frsize), + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Flag: uint64(stat.Flags), // assuming POSIX? + Namemax: uint64(stat.Namelen), + }, nil +} diff --git a/github.com/pkg/sftp/server_statvfs_stubs.go b/github.com/pkg/sftp/server_statvfs_stubs.go new file mode 100644 index 0000000000..1512a132bb --- /dev/null +++ b/github.com/pkg/sftp/server_statvfs_stubs.go @@ -0,0 +1,11 @@ +// +build !darwin,!linux gccgo + +package sftp + +import ( + "syscall" +) + +func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { + return syscall.ENOTSUP +} diff --git a/github.com/pkg/sftp/server_stubs.go b/github.com/pkg/sftp/server_stubs.go new file mode 100644 index 0000000000..3b1ddbdbbc --- /dev/null +++ b/github.com/pkg/sftp/server_stubs.go @@ -0,0 +1,12 @@ +// +build !cgo,!plan9 windows android + +package sftp + +import ( + "os" + "path" +) + +func runLs(dirname string, dirent os.FileInfo) string { + return path.Join(dirname, dirent.Name()) +} diff --git a/github.com/pkg/sftp/server_test.go b/github.com/pkg/sftp/server_test.go new file mode 100644 index 0000000000..721acc6778 --- /dev/null +++ b/github.com/pkg/sftp/server_test.go @@ -0,0 +1,95 @@ +package sftp + +import ( + "io" + "sync" + "testing" +) + +func clientServerPair(t *testing.T) (*Client, *Server) { + cr, sw := io.Pipe() + sr, cw := io.Pipe() + server, err := NewServer(struct { + io.Reader + io.WriteCloser + }{sr, sw}) + if err != nil { + t.Fatal(err) + } + go server.Serve() + client, err := NewClientPipe(cr, cw) + if err != nil { + t.Fatalf("%+v\n", err) + } + return client, server +} + +type sshFxpTestBadExtendedPacket struct { + ID uint32 + Extension string + Data string +} + +func (p sshFxpTestBadExtendedPacket) id() uint32 { return p.ID } + +func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + 4 + // type(byte) + uint32 + uint32 + len(p.Extension) + + len(p.Data) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_EXTENDED) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Extension) + b = marshalString(b, p.Data) + return b, nil +} + +// test that errors are sent back when we request an invalid extended packet operation +func TestInvalidExtendedPacket(t *testing.T) { + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"} + _, _, err := client.clientConn.sendPacket(badPacket) + if err == nil { + t.Fatal("expected error from bad packet") + } + + // try to stat a file; the client should have shut down. + filePath := "/etc/passwd" + _, err = client.Stat(filePath) + if err == nil { + t.Fatal("expected error from closed connection") + } + +} + +// test that server handles concurrent requests correctly +func TestConcurrentRequests(t *testing.T) { + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + concurrency := 2 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + for j := 0; j < 1024; j++ { + f, err := client.Open("/etc/passwd") + if err != nil { + t.Errorf("failed to open file: %v", err) + } + if err := f.Close(); err != nil { + t.Errorf("failed t close file: %v", err) + } + } + }() + } + wg.Wait() +} diff --git a/github.com/pkg/sftp/server_unix.go b/github.com/pkg/sftp/server_unix.go new file mode 100644 index 0000000000..8c3f0b44ea --- /dev/null +++ b/github.com/pkg/sftp/server_unix.go @@ -0,0 +1,143 @@ +// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris +// +build cgo + +package sftp + +import ( + "fmt" + "os" + "path" + "syscall" + "time" +) + +func runLsTypeWord(dirent os.FileInfo) string { + // find first character, the type char + // b Block special file. + // c Character special file. + // d Directory. + // l Symbolic link. + // s Socket link. + // p FIFO. + // - Regular file. + tc := '-' + mode := dirent.Mode() + if (mode & os.ModeDir) != 0 { + tc = 'd' + } else if (mode & os.ModeDevice) != 0 { + tc = 'b' + if (mode & os.ModeCharDevice) != 0 { + tc = 'c' + } + } else if (mode & os.ModeSymlink) != 0 { + tc = 'l' + } else if (mode & os.ModeSocket) != 0 { + tc = 's' + } else if (mode & os.ModeNamedPipe) != 0 { + tc = 'p' + } + + // owner + orc := '-' + if (mode & 0400) != 0 { + orc = 'r' + } + owc := '-' + if (mode & 0200) != 0 { + owc = 'w' + } + oxc := '-' + ox := (mode & 0100) != 0 + setuid := (mode & os.ModeSetuid) != 0 + if ox && setuid { + oxc = 's' + } else if setuid { + oxc = 'S' + } else if ox { + oxc = 'x' + } + + // group + grc := '-' + if (mode & 040) != 0 { + grc = 'r' + } + gwc := '-' + if (mode & 020) != 0 { + gwc = 'w' + } + gxc := '-' + gx := (mode & 010) != 0 + setgid := (mode & os.ModeSetgid) != 0 + if gx && setgid { + gxc = 's' + } else if setgid { + gxc = 'S' + } else if gx { + gxc = 'x' + } + + // all / others + arc := '-' + if (mode & 04) != 0 { + arc = 'r' + } + awc := '-' + if (mode & 02) != 0 { + awc = 'w' + } + axc := '-' + ax := (mode & 01) != 0 + sticky := (mode & os.ModeSticky) != 0 + if ax && sticky { + axc = 't' + } else if sticky { + axc = 'T' + } else if ax { + axc = 'x' + } + + return fmt.Sprintf("%c%c%c%c%c%c%c%c%c%c", tc, orc, owc, oxc, grc, gwc, gxc, arc, awc, axc) +} + +func runLsStatt(dirname string, dirent os.FileInfo, statt *syscall.Stat_t) string { + // example from openssh sftp server: + // crw-rw-rw- 1 root wheel 0 Jul 31 20:52 ttyvd + // format: + // {directory / char device / etc}{rwxrwxrwx} {number of links} owner group size month day [time (this year) | year (otherwise)] name + + typeword := runLsTypeWord(dirent) + numLinks := statt.Nlink + uid := statt.Uid + gid := statt.Gid + username := fmt.Sprintf("%d", uid) + groupname := fmt.Sprintf("%d", gid) + // TODO FIXME: uid -> username, gid -> groupname lookup for ls -l format output + + mtime := dirent.ModTime() + monthStr := mtime.Month().String()[0:3] + day := mtime.Day() + year := mtime.Year() + now := time.Now() + isOld := mtime.Before(now.Add(-time.Hour * 24 * 365 / 2)) + + yearOrTime := fmt.Sprintf("%02d:%02d", mtime.Hour(), mtime.Minute()) + if isOld { + yearOrTime = fmt.Sprintf("%d", year) + } + + return fmt.Sprintf("%s %4d %-8s %-8s %8d %s %2d %5s %s", typeword, numLinks, username, groupname, dirent.Size(), monthStr, day, yearOrTime, dirent.Name()) +} + +// ls -l style output for a file, which is in the 'long output' section of a readdir response packet +// this is a very simple (lazy) implementation, just enough to look almost like openssh in a few basic cases +func runLs(dirname string, dirent os.FileInfo) string { + dsys := dirent.Sys() + if dsys == nil { + } else if statt, ok := dsys.(*syscall.Stat_t); !ok { + } else { + return runLsStatt(dirname, dirent, statt) + } + + return path.Join(dirname, dirent.Name()) +} diff --git a/github.com/pkg/sftp/sftp.go b/github.com/pkg/sftp/sftp.go new file mode 100644 index 0000000000..22184afe0c --- /dev/null +++ b/github.com/pkg/sftp/sftp.go @@ -0,0 +1,217 @@ +// Package sftp implements the SSH File Transfer Protocol as described in +// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt +package sftp + +import ( + "fmt" + + "github.com/pkg/errors" +) + +const ( + ssh_FXP_INIT = 1 + ssh_FXP_VERSION = 2 + ssh_FXP_OPEN = 3 + ssh_FXP_CLOSE = 4 + ssh_FXP_READ = 5 + ssh_FXP_WRITE = 6 + ssh_FXP_LSTAT = 7 + ssh_FXP_FSTAT = 8 + ssh_FXP_SETSTAT = 9 + ssh_FXP_FSETSTAT = 10 + ssh_FXP_OPENDIR = 11 + ssh_FXP_READDIR = 12 + ssh_FXP_REMOVE = 13 + ssh_FXP_MKDIR = 14 + ssh_FXP_RMDIR = 15 + ssh_FXP_REALPATH = 16 + ssh_FXP_STAT = 17 + ssh_FXP_RENAME = 18 + ssh_FXP_READLINK = 19 + ssh_FXP_SYMLINK = 20 + ssh_FXP_STATUS = 101 + ssh_FXP_HANDLE = 102 + ssh_FXP_DATA = 103 + ssh_FXP_NAME = 104 + ssh_FXP_ATTRS = 105 + ssh_FXP_EXTENDED = 200 + ssh_FXP_EXTENDED_REPLY = 201 +) + +const ( + ssh_FX_OK = 0 + ssh_FX_EOF = 1 + ssh_FX_NO_SUCH_FILE = 2 + ssh_FX_PERMISSION_DENIED = 3 + ssh_FX_FAILURE = 4 + ssh_FX_BAD_MESSAGE = 5 + ssh_FX_NO_CONNECTION = 6 + ssh_FX_CONNECTION_LOST = 7 + ssh_FX_OP_UNSUPPORTED = 8 + + // see draft-ietf-secsh-filexfer-13 + // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1 + ssh_FX_INVALID_HANDLE = 9 + ssh_FX_NO_SUCH_PATH = 10 + ssh_FX_FILE_ALREADY_EXISTS = 11 + ssh_FX_WRITE_PROTECT = 12 + ssh_FX_NO_MEDIA = 13 + ssh_FX_NO_SPACE_ON_FILESYSTEM = 14 + ssh_FX_QUOTA_EXCEEDED = 15 + ssh_FX_UNKNOWN_PRINCIPAL = 16 + ssh_FX_LOCK_CONFLICT = 17 + ssh_FX_DIR_NOT_EMPTY = 18 + ssh_FX_NOT_A_DIRECTORY = 19 + ssh_FX_INVALID_FILENAME = 20 + ssh_FX_LINK_LOOP = 21 + ssh_FX_CANNOT_DELETE = 22 + ssh_FX_INVALID_PARAMETER = 23 + ssh_FX_FILE_IS_A_DIRECTORY = 24 + ssh_FX_BYTE_RANGE_LOCK_CONFLICT = 25 + ssh_FX_BYTE_RANGE_LOCK_REFUSED = 26 + ssh_FX_DELETE_PENDING = 27 + ssh_FX_FILE_CORRUPT = 28 + ssh_FX_OWNER_INVALID = 29 + ssh_FX_GROUP_INVALID = 30 + ssh_FX_NO_MATCHING_BYTE_RANGE_LOCK = 31 +) + +const ( + ssh_FXF_READ = 0x00000001 + ssh_FXF_WRITE = 0x00000002 + ssh_FXF_APPEND = 0x00000004 + ssh_FXF_CREAT = 0x00000008 + ssh_FXF_TRUNC = 0x00000010 + ssh_FXF_EXCL = 0x00000020 +) + +type fxp uint8 + +func (f fxp) String() string { + switch f { + case ssh_FXP_INIT: + return "SSH_FXP_INIT" + case ssh_FXP_VERSION: + return "SSH_FXP_VERSION" + case ssh_FXP_OPEN: + return "SSH_FXP_OPEN" + case ssh_FXP_CLOSE: + return "SSH_FXP_CLOSE" + case ssh_FXP_READ: + return "SSH_FXP_READ" + case ssh_FXP_WRITE: + return "SSH_FXP_WRITE" + case ssh_FXP_LSTAT: + return "SSH_FXP_LSTAT" + case ssh_FXP_FSTAT: + return "SSH_FXP_FSTAT" + case ssh_FXP_SETSTAT: + return "SSH_FXP_SETSTAT" + case ssh_FXP_FSETSTAT: + return "SSH_FXP_FSETSTAT" + case ssh_FXP_OPENDIR: + return "SSH_FXP_OPENDIR" + case ssh_FXP_READDIR: + return "SSH_FXP_READDIR" + case ssh_FXP_REMOVE: + return "SSH_FXP_REMOVE" + case ssh_FXP_MKDIR: + return "SSH_FXP_MKDIR" + case ssh_FXP_RMDIR: + return "SSH_FXP_RMDIR" + case ssh_FXP_REALPATH: + return "SSH_FXP_REALPATH" + case ssh_FXP_STAT: + return "SSH_FXP_STAT" + case ssh_FXP_RENAME: + return "SSH_FXP_RENAME" + case ssh_FXP_READLINK: + return "SSH_FXP_READLINK" + case ssh_FXP_SYMLINK: + return "SSH_FXP_SYMLINK" + case ssh_FXP_STATUS: + return "SSH_FXP_STATUS" + case ssh_FXP_HANDLE: + return "SSH_FXP_HANDLE" + case ssh_FXP_DATA: + return "SSH_FXP_DATA" + case ssh_FXP_NAME: + return "SSH_FXP_NAME" + case ssh_FXP_ATTRS: + return "SSH_FXP_ATTRS" + case ssh_FXP_EXTENDED: + return "SSH_FXP_EXTENDED" + case ssh_FXP_EXTENDED_REPLY: + return "SSH_FXP_EXTENDED_REPLY" + default: + return "unknown" + } +} + +type fx uint8 + +func (f fx) String() string { + switch f { + case ssh_FX_OK: + return "SSH_FX_OK" + case ssh_FX_EOF: + return "SSH_FX_EOF" + case ssh_FX_NO_SUCH_FILE: + return "SSH_FX_NO_SUCH_FILE" + case ssh_FX_PERMISSION_DENIED: + return "SSH_FX_PERMISSION_DENIED" + case ssh_FX_FAILURE: + return "SSH_FX_FAILURE" + case ssh_FX_BAD_MESSAGE: + return "SSH_FX_BAD_MESSAGE" + case ssh_FX_NO_CONNECTION: + return "SSH_FX_NO_CONNECTION" + case ssh_FX_CONNECTION_LOST: + return "SSH_FX_CONNECTION_LOST" + case ssh_FX_OP_UNSUPPORTED: + return "SSH_FX_OP_UNSUPPORTED" + default: + return "unknown" + } +} + +type unexpectedPacketErr struct { + want, got uint8 +} + +func (u *unexpectedPacketErr) Error() string { + return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) +} + +func unimplementedPacketErr(u uint8) error { + return errors.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) +} + +type unexpectedIDErr struct{ want, got uint32 } + +func (u *unexpectedIDErr) Error() string { + return fmt.Sprintf("sftp: unexpected id: want %v, got %v", u.want, u.got) +} + +func unimplementedSeekWhence(whence int) error { + return errors.Errorf("sftp: unimplemented seek whence %v", whence) +} + +func unexpectedCount(want, got uint32) error { + return errors.Errorf("sftp: unexpected count: want %v, got %v", want, got) +} + +type unexpectedVersionErr struct{ want, got uint32 } + +func (u *unexpectedVersionErr) Error() string { + return fmt.Sprintf("sftp: unexpected server version: want %v, got %v", u.want, u.got) +} + +// A StatusError is returned when an SFTP operation fails, and provides +// additional information about the failure. +type StatusError struct { + Code uint32 + msg, lang string +} + +func (s *StatusError) Error() string { return fmt.Sprintf("sftp: %q (%v)", s.msg, fx(s.Code)) }