diff --git a/README.md b/README.md index 5106716..9c00139 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,55 @@ # optool -A tool to execute commands on multiple remote hosts +A tool to execute commands,transfer files on multiple remote hosts ------------ ### Usage: ```bash Usage: - -V print sample configure + -V print sample configure -config string - set config file path (default "/optool.yml") + set config file path (default "/optool.yml") -encrypt - encrypt a password/phrase + encrypt a password/phrase -g string - set default group name for hosts + set default group name for hosts + -get string + get a file from remote host -gz - enable gzip for transfer./usr/bin/gzip must be executable at remote host + enable gzip for transfer./usr/bin/gzip must be executable at remote host -host string - set run host + set run host -key string - set private key + set private key -nh int - (1)1<<0=no header,(2)1<<1=no server ip,3=none + (1)1<<0=no header,(2)1<<1=no server ip,3=none -o string - set output file (default "-") + set output file (default "-") + -override + Override remote file if exists + -path string + set path.if get is set this is local path,if put is set this is remote path -port int - set default ssh port + set default ssh port + -put string + put a file to remote host -s string - read commands from script + read commands from script -t string - set tagged command + set tagged command -ta string - append tagged command parameters, overflow params will be dropped, separated by comma(,). - to replace in tags use string: _REPLACE_ + append tagged command parameters, overflow params will be dropped, separated by comma(,). + to replace in tags use string: _REPLACE_ -tl - list all tags + list all tags -tp - print tag line + print tag line -u string - set ssh auth user - -v verbose all configs + set ssh auth user + -v verbose all configs -version - print version and exit + print version and exit -x string - execute command directly + execute command directly ``` ### Sample configure: @@ -67,4 +75,5 @@ tags: ps: "/bin/ps" netstat: "/bin/netstat -lntpu" err: "/bin/grep ERROR /var/log/nginx/error.log_REPLACE_" +# transfer_max_size: 1099511627776 #100MB ``` diff --git a/common/command.go b/common/command.go index cbc90a5..447fb3a 100644 --- a/common/command.go +++ b/common/command.go @@ -1,12 +1,12 @@ package common import ( + "bytes" "compress/gzip" "fmt" "io" "io/ioutil" "log" - "os" "strconv" "strings" "sync" @@ -54,48 +54,16 @@ func NewRemoteCommand(hosts []string, cmd string) *RemoteCommand { } // Start run remote command -func (rc *RemoteCommand) Start() error { +func (rc *RemoteCommand) Start() (err error) { cfg := &ssh.ClientConfig{ HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: time.Second * 10, } - password := C.Auth.Password - if !C.Auth.PlainPassword { - password = string(Decrypt(C.Auth.Password)) - } if C.Auth.User != "" { cfg.User = C.Auth.User - if C.Auth.PrivateKey != "" { - if _, err := os.Stat(C.Auth.PrivateKey); err != nil { - return err - } - key, err := ioutil.ReadFile(C.Auth.PrivateKey) - if err != nil { - return err - } - var signer ssh.Signer - if C.Auth.PrivateKeyPhrase == "" { - signer, err = ssh.ParsePrivateKey(key) - } else { - passphrase := []byte(C.Auth.PrivateKeyPhrase) - if !C.Auth.PlainPassword { - passphrase = Decrypt(C.Auth.PrivateKeyPhrase) - } - signer, err = ssh.ParsePrivateKeyWithPassphrase(key, passphrase) - } - if err != nil { - return err - } - cfg.Auth = []ssh.AuthMethod{ - ssh.PublicKeys(signer), - } - if password != "" { - cfg.Auth = append(cfg.Auth, ssh.Password(password)) - } - } else { - cfg.Auth = []ssh.AuthMethod{ - ssh.Password(password), - } + cfg.Auth, err = GetAuth() + if err != nil { + return err } } for _, host := range rc.Hosts { @@ -172,7 +140,12 @@ func (rc *RemoteCommand) PrettyPrint(wo io.Writer, we io.Writer, noHeader bool, we.Write([]byte("================================= ERROR =================================\n")) } for h, e := range rc.Error { - fmt.Fprintln(we, h, ":\n", e) + e = strings.TrimRight(e, "\n") + if strings.Contains(e, "\n") { + fmt.Fprintln(we, h, ":\n", e) + } else { + fmt.Fprintln(we, h, ":", e) + } } } if len(rc.Output) > 0 { @@ -191,15 +164,23 @@ func (rc *RemoteCommand) PrettyPrint(wo io.Writer, we io.Writer, noHeader bool, if err != nil { log.Println(err) } + data = bytes.TrimRight(data, "\n") if !noHost { - wo.Write([]byte(h + ": \n")) + fmt.Fprintf(wo, "%15s: ", h) + if bytes.Contains(data, []byte("\n")) { + wo.Write([]byte("\n")) + } } wo.Write(data) wo.Write([]byte("\n")) continue } + o = strings.TrimRight(o, "\n") if !noHost { - wo.Write([]byte(h + ": \n")) + fmt.Fprintf(wo, "%15s: ", h) + if strings.Contains(o, "\n") { + wo.Write([]byte("\n")) + } } wo.Write([]byte(o)) wo.Write([]byte("\n")) diff --git a/common/config.go b/common/config.go index f752257..3829936 100644 --- a/common/config.go +++ b/common/config.go @@ -2,6 +2,9 @@ package common import ( "io/ioutil" + "os" + + "golang.org/x/crypto/ssh" "github.com/go-yaml/yaml" ) @@ -24,7 +27,10 @@ type Configure struct { Tags map[string]string `yaml:"tags"` // shortcut for frequently used commands Gzip bool `yaml:"-"` // enable gzip transfer //DefaultGroup string `yaml:"default_group"` // set default host group + TransferMaxSize int64 `yaml:"transfer_max_size"` } + +// Server server groups and default port/group config type Server struct { DefaultGroup string `yaml:"default_group"` DefaultPort int `yaml:"default_port"` @@ -50,3 +56,44 @@ func ParseConfig(f string) error { } return nil } + +// GetAuth get auth method list from configs +func GetAuth() (auth []ssh.AuthMethod, err error) { + password := C.Auth.Password + if !C.Auth.PlainPassword { + password = string(Decrypt(C.Auth.Password)) + } + if C.Auth.PrivateKey != "" { + if _, err := os.Stat(C.Auth.PrivateKey); err != nil { + return nil, err + } + key, err := ioutil.ReadFile(C.Auth.PrivateKey) + if err != nil { + return nil, err + } + var signer ssh.Signer + if C.Auth.PrivateKeyPhrase == "" { + signer, err = ssh.ParsePrivateKey(key) + } else { + passphrase := []byte(C.Auth.PrivateKeyPhrase) + if !C.Auth.PlainPassword { + passphrase = Decrypt(C.Auth.PrivateKeyPhrase) + } + signer, err = ssh.ParsePrivateKeyWithPassphrase(key, passphrase) + } + if err != nil { + return nil, err + } + auth = []ssh.AuthMethod{ + ssh.PublicKeys(signer), + } + if password != "" { + auth = append(auth, ssh.Password(password)) + } + } else { + auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + return +} diff --git a/common/transfer.go b/common/transfer.go new file mode 100644 index 0000000..cd96663 --- /dev/null +++ b/common/transfer.go @@ -0,0 +1,277 @@ +package common + +import ( + "errors" + "fmt" + "log" + "os" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +const ( + // TransferGet get file from remote servers + TransferGet = "GET" + // TransferPut put file to remote servers + TransferPut = "PUT" + // TransferDefaultMaxSize default max size to transfer + TransferDefaultMaxSize = 1099511627776 // 100MB +) + +// Transfer transfer files via ssh +type Transfer struct { + Inited bool + Method string // GET,PUT + LocalPath string + RemotePath string + Recursive bool + Hosts []string + Clients map[string]*ssh.Client + SftpClient map[string]*sftp.Client + Override bool // override remote existed file? + TransferResult map[string]FileTransfer // result of transfering + Lock sync.Mutex +} + +// FileTransfer transfer file info +type FileTransfer struct { + Source string + Target string + Size int64 + Elapse time.Duration +} + +// NewTransfer get file transfer instance +func NewTransfer(method, localPath, remotePath string, hosts []string) *Transfer { + return &Transfer{ + Inited: true, + Method: method, + LocalPath: localPath, + RemotePath: remotePath, + Recursive: false, + Clients: make(map[string]*ssh.Client), + SftpClient: make(map[string]*sftp.Client), + Hosts: hosts, + Override: false, + TransferResult: make(map[string]FileTransfer), + Lock: sync.Mutex{}, + } +} + +// Start start file transfer +func (t *Transfer) Start() (err error) { + if err = t.initClient(); err != nil { + return + } + // close connections + defer func() { + for _, sc := range t.SftpClient { + sc.Close() + } + for _, c := range t.Clients { + c.Close() + } + }() + if t.Method == TransferGet { + return t.batchGet() + } + if t.Method == TransferPut { + return t.batchPut() + } + return nil +} + +func (t *Transfer) batchGet() (err error) { + fi, err := os.Stat(t.LocalPath) + if err != nil { + err = os.MkdirAll(t.LocalPath, 0755) + if err != nil { + return + } + } else { + if !fi.IsDir() { + log.Fatalln("Local path cannot be a file") + } + } + wg := sync.WaitGroup{} + for h, sc := range t.SftpClient { + c := t.Clients[h] + wg.Add(1) + go func(sc *sftp.Client, c *ssh.Client) { + defer wg.Done() + err := t.get(sc, c, t.RemotePath, t.LocalPath) + if err != nil { + fmt.Println(c.Conn.RemoteAddr().String(), err) + } + }(sc, c) + } + wg.Wait() + return +} + +func (t *Transfer) batchPut() (err error) { + fi, err := os.Stat(t.LocalPath) + if err != nil { + return + } + if fi.IsDir() { + return errors.New("Local is dir,recursive transfer not supported now") + } + wg := sync.WaitGroup{} + for h, sc := range t.SftpClient { + c := t.Clients[h] + wg.Add(1) + go func(sc *sftp.Client, c *ssh.Client) { + defer wg.Done() + err := t.put(sc, c, t.LocalPath, t.RemotePath) + if err != nil { + fmt.Println(err) + } + }(sc, c) + } + wg.Wait() + return +} + +func (t *Transfer) get(sc *sftp.Client, c *ssh.Client, remotePath, localPath string) (err error) { + fi, err := sc.Stat(remotePath) + if err != nil { + return + } + if fi.IsDir() { + return errors.New("Remote dir get is not supported") + } + if fi.Size() > C.TransferMaxSize { + return fmt.Errorf("Max transfer size is set to %d", C.TransferMaxSize) + } + basename := path.Base(fi.Name()) + srcFile, err := sc.Open(remotePath) + if err != nil { + return + } + defer srcFile.Close() + addr := c.Conn.RemoteAddr().String() + xaddr := strings.Split(addr, ":") + exp := strings.Split(basename, ".") + var ext, prefName string + lenth := len(exp) + if lenth > 1 { + ext = exp[lenth-1] + prefName = strings.Join(exp[0:lenth-1], ".") + } else { + prefName = basename + } + dstFile, err := os.OpenFile(path.Join(localPath, prefName+"-"+strings.Replace(xaddr[0], ".", "-", -1)+"."+ext), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) + if err != nil { + return + } + defer dstFile.Close() + ft := FileTransfer{ + Source: srcFile.Name(), + Target: dstFile.Name(), + } + ts := time.Now() + buf := make([]byte, 1024) + var size int64 + for { + n, _ := srcFile.Read(buf) + if n < 1 { + break + } + size = size + int64(n) + dstFile.Write(buf[0:n]) + } + ft.Size = size + ft.Elapse = time.Now().Sub(ts) + t.Lock.Lock() + t.TransferResult[addr] = ft + t.Lock.Unlock() + return +} +func (t *Transfer) put(sc *sftp.Client, c *ssh.Client, localPath, remotePath string) (err error) { + // remote path is dir + if strings.HasSuffix(remotePath, "/") { + basename := path.Base(localPath) + remotePath = path.Join(remotePath, basename) + } + _, e := sc.Stat(remotePath) + if e == nil { + if !t.Override { + fmt.Println("Remote file exists") + return errors.New("Remote file exists") + } + } + srcFile, err := os.OpenFile(localPath, os.O_RDONLY, 0755) + if err != nil { + return + } + defer srcFile.Close() + dstFile, err := sc.OpenFile(remotePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC) + if err != nil { + return + } + defer dstFile.Close() + ft := FileTransfer{ + Source: srcFile.Name(), + Target: dstFile.Name(), + } + ts := time.Now() + var size int64 + buf := make([]byte, 1024) + for { + n, _ := srcFile.Read(buf) + if n < 1 { + break + } + size = size + int64(n) + dstFile.Write(buf[0:n]) + } + ft.Size = size + ft.Elapse = time.Now().Sub(ts) + addr := c.Conn.RemoteAddr().String() + t.Lock.Lock() + t.TransferResult[addr] = ft + t.Lock.Unlock() + return +} + +func (t *Transfer) initClient() error { + auth, err := GetAuth() + if err != nil { + log.Fatalln(err) + } + clientConfig := &ssh.ClientConfig{ + User: C.Auth.User, + Auth: auth, + Timeout: 30 * time.Second, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + for _, h := range t.Hosts { + if strings.Index(h, ":") < 0 { + h = h + ":" + strconv.Itoa(C.Server.DefaultPort) + } + client, err := ssh.Dial("tcp", h, clientConfig) + if err != nil { + return err + } + t.Clients[h] = client + t.SftpClient[h], err = sftp.NewClient(client, sftp.MaxPacket(33788)) + if err != nil { + return err + } + } + return nil +} + +// PrettyPrint print transfer result +func (t *Transfer) PrettyPrint() { + for h, ft := range t.TransferResult { + fmt.Printf("%21s: %s => %s %dByte %.2f seconds\n", h, ft.Source, ft.Target, ft.Size, ft.Elapse.Seconds()) + } +} diff --git a/main.go b/main.go index 60cad28..e215a0c 100644 --- a/main.go +++ b/main.go @@ -47,9 +47,15 @@ var ( pSampleConfig = flag.Bool("V", false, "print sample configure") pVersion = flag.Bool("version", false, "print version and exit") pEncrypt = flag.Bool("encrypt", false, "encrypt a password/phrase") + //@todo + pGet = flag.String("get", "", "get a file from remote host") + pPut = flag.String("put", "", "put a file to remote host") + pPath = flag.String("path", "", "set path.if get is set this is local path,if put is set this is remote path") + pOverride = flag.Bool("override", false, "Override remote file if exists") ) func main() { + log.SetFlags(log.LstdFlags | log.Llongfile) flag.Parse() if *pVersion { fmt.Println("Opstool", OptoolVersion) @@ -129,6 +135,31 @@ func main() { common.C.Auth.PrivateKey = *pPrivateKey common.C.Auth.PrivateKeyPhrase = "" } + // Get/Put files + if *pGet != "" && *pPut != "" { + log.Fatalln("Get or put cannot be set at once") + } + transfer := &common.Transfer{ + Inited: false, + } + if *pGet != "" { + transfer = common.NewTransfer(common.TransferGet, *pPath, *pGet, hosts) + } else if *pPut != "" { + transfer = common.NewTransfer(common.TransferPut, *pPut, *pPath, hosts) + } + if transfer.Inited { + if common.C.TransferMaxSize < 1 { + common.C.TransferMaxSize = common.TransferDefaultMaxSize + } + if *pOverride { + transfer.Override = true + } + if err = transfer.Start(); err != nil { + log.Fatalln(err) + } + transfer.PrettyPrint() + os.Exit(0) + } // command var cmd string if *pTag != "" { @@ -162,6 +193,10 @@ func main() { log.Fatalln("Parameter is not enough. Required is", toReplaceCount) } for i := 0; i < toReplaceCount; i++ { + // - stands for skip this args + if tagArgs[i] == "-" { + tagArgs[i] = "" + } cmd = strings.Replace(cmd, REPLACEMENT, tagArgs[i], 1) } if *pVerbose { @@ -202,6 +237,7 @@ tags: ps: "/bin/ps" netstat: "/bin/netstat -lntpu" err: "/bin/grep ERROR /var/log/nginx/error.log_REPLACE_" +# transfer_max_size: 1099511627776 #100MB `) }