diff --git a/cmd/admin/handlers-get.go b/cmd/admin/handlers-get.go index 4e7ca6ad..3f90ce1b 100644 --- a/cmd/admin/handlers-get.go +++ b/cmd/admin/handlers-get.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "os" + "strconv" "strings" "github.com/jmpsec/osctrl/pkg/carves" @@ -1129,7 +1130,7 @@ func carvesDownloadHandler(w http.ResponseWriter, r *http.Request) { return } // Prepare file to download - f, err := carvesmgr.Archive(carveSession, carvedFilesFolder) + result, err := carvesmgr.Archive(carveSession, carvedFilesFolder) if err != nil { incMetric(metricAdminErr) log.Printf("error downloading carve - %v", err) @@ -1142,15 +1143,15 @@ func carvesDownloadHandler(w http.ResponseWriter, r *http.Request) { // Send response w.Header().Set("Content-Description", "File Carve Download") w.Header().Set("Content-Type", "application/octet-stream") - w.Header().Set("Content-Disposition", "attachment; filename="+f["file"]) + w.Header().Set("Content-Disposition", "attachment; filename="+result.File) w.Header().Set("Content-Transfer-Encoding", "binary") w.Header().Set("Connection", "Keep-Alive") w.Header().Set("Expires", "0") w.Header().Set("Cache-Control", "must-revalidate, post-check=0, pre-check=0") w.Header().Set("Pragma", "public") - w.Header().Set("Content-Length", f["size"]) + w.Header().Set("Content-Length", strconv.FormatInt(result.Size, 10)) w.WriteHeader(http.StatusOK) var fileReader io.Reader - fileReader, _ = os.Open(f["file"]) + fileReader, _ = os.Open(result.File) _, _ = io.Copy(w, fileReader) } diff --git a/pkg/carves/carves.go b/pkg/carves/carves.go index 6f73b0f2..1aaa94e9 100644 --- a/pkg/carves/carves.go +++ b/pkg/carves/carves.go @@ -6,7 +6,6 @@ import ( "fmt" "log" "os" - "strconv" "strings" "time" @@ -70,6 +69,12 @@ type CarvedBlock struct { Size int } +// CarveResult holds metadata related to a carve +type CarveResult struct { + Size int64 + File string +} + // Carves to handle file carves from nodes type Carves struct { DB *gorm.DB @@ -93,13 +98,9 @@ func CreateFileCarves(backend *gorm.DB) *Carves { // CreateCarve to create a new carved file for a node func (c *Carves) CreateCarve(carve CarvedFile) error { if c.DB.NewRecord(carve) { - if err := c.DB.Create(&carve).Error; err != nil { - return err - } - } else { - return fmt.Errorf("db.NewRecord did not return true") + return c.DB.Create(&carve).Error // can be nil or err } - return nil + return fmt.Errorf("db.NewRecord did not return true") } // CheckCarve to verify a session belong to a carve @@ -114,13 +115,9 @@ func (c *Carves) CheckCarve(sessionid, requestid string) bool { // CreateBlock to create a new block for a carve func (c *Carves) CreateBlock(block CarvedBlock) error { if c.DB.NewRecord(block) { - if err := c.DB.Create(&block).Error; err != nil { - return err - } - } else { - return fmt.Errorf("db.NewRecord did not return true") + return c.DB.Create(&block).Error // can be nil or err } - return nil + return fmt.Errorf("db.NewRecord did not return true") } // Delete to delete a carve by id @@ -255,25 +252,24 @@ func (c *Carves) Completed(sessionid string) bool { } // Archive to convert finalize a completed carve and create a file ready to download -func (c *Carves) Archive(sessionid, path string) (map[string]string, error) { - res := make(map[string]string) +func (c *Carves) Archive(sessionid, path string) (*CarveResult, error) { + res := &CarveResult{ + File: path, + } // Make sure last character is a slash - finalFile := path if path[len(path)-1:] != "/" { - finalFile += "/" + res.File += "/" } - finalFile += sessionid + ".tar" + res.File += sessionid + ".tar" // If file already exists, no need to re-generate it from blocks - _f, err := os.Stat(finalFile) + _f, err := os.Stat(res.File) if err == nil { - res["file"] = finalFile - res["size"] = strconv.FormatInt(_f.Size(), 10) + res.Size = _f.Size() return res, nil } - _f, err = os.Stat(finalFile + ".zst") + _f, err = os.Stat(res.File + ".zst") if err == nil { - res["file"] = finalFile - res["size"] = strconv.FormatInt(_f.Size(), 10) + res.Size = _f.Size() return res, nil } // Get all blocks @@ -286,14 +282,13 @@ func (c *Carves) Archive(sessionid, path string) (map[string]string, error) { return res, fmt.Errorf("Compression check - %v", err) } if zstd { - finalFile += ".zst" + res.File += ".zst" } - f, err := os.OpenFile(finalFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + f, err := os.OpenFile(res.File, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return res, fmt.Errorf("File creation - %v", err) } defer f.Close() - bytesWritten := 0 // Iterate through blocks and write decoded content to file for _, b := range blocks { toFile, err := base64.StdEncoding.DecodeString(b.Data) @@ -303,9 +298,7 @@ func (c *Carves) Archive(sessionid, path string) (map[string]string, error) { if _, err := f.Write(toFile); err != nil { return res, fmt.Errorf("Writing to file - %v", err) } - bytesWritten += len(toFile) + res.Size += int64(len(toFile)) } - res["file"] = finalFile - res["size"] = strconv.Itoa(bytesWritten) return res, nil }