From 831566fa3ec445a69446c90181ad932128d25192 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 17 Dec 2023 09:45:29 +0000 Subject: [PATCH] style: reuse-code from gallery --- api/config/config.go | 31 ++------ pkg/gallery/models.go | 86 +-------------------- pkg/utils/uri.go | 171 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 167 insertions(+), 121 deletions(-) diff --git a/api/config/config.go b/api/config/config.go index d4f1e5820d14..7ed7061af917 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -265,43 +266,22 @@ func (cm *ConfigLoader) ListConfigs() []string { return res } -func convertURL(s string) string { - switch { - case strings.HasPrefix(s, "huggingface://"): - repository := strings.Replace(s, "huggingface://", "", 1) - // convert repository to a full URL. - // e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf - owner := strings.Split(repository, "/")[0] - repo := strings.Split(repository, "/")[1] - branch := "main" - if strings.Contains(repo, "@") { - branch = strings.Split(repository, "@")[1] - } - filepath := strings.Split(repository, "/")[2] - if strings.Contains(filepath, "@") { - filepath = strings.Split(filepath, "@")[0] - } - - return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) - } - - return s -} - func (cm *ConfigLoader) Preload(modelPath string) error { cm.Lock() defer cm.Unlock() for i, config := range cm.configs { modelURL := config.PredictionOptions.Model - modelURL = convertURL(modelURL) + modelURL = utils.ConvertURL(modelURL) if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { // md5 of model name md5Name := utils.MD5(modelURL) // check if file exists if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist { - err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name)) + err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) { + log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent) + }) if err != nil { return err } @@ -312,7 +292,6 @@ func (cm *ConfigLoader) Preload(modelPath string) error { c.PredictionOptions.Model = md5Name cm.configs[i] = *c } - } return nil } diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index db68f525a29f..9a1697981614 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -5,7 +5,6 @@ import ( "fmt" "hash" "io" - "net/http" "os" "path/filepath" "strconv" @@ -115,89 +114,8 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides // Create file path filePath := filepath.Join(basePath, file.Filename) - // Check if the file already exists - _, err := os.Stat(filePath) - if err == nil { - // File exists, check SHA - if file.SHA256 != "" { - // Verify SHA - calculatedSHA, err := calculateSHA(filePath) - if err != nil { - return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err) - } - if calculatedSHA == file.SHA256 { - // SHA matches, skip downloading - log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename) - continue - } - // SHA doesn't match, delete the file and download again - err = os.Remove(filePath) - if err != nil { - return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err) - } - log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) - - } else { - // SHA is missing, skip downloading - log.Debug().Msgf("File %q already exists. Skipping download", file.Filename) - continue - } - } else if !os.IsNotExist(err) { - // Error occurred while checking file existence - return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err) - } - - log.Debug().Msgf("Downloading %q", file.URI) - - // Download file - resp, err := http.Get(file.URI) - if err != nil { - return fmt.Errorf("failed to download file %q: %v", file.Filename, err) - } - defer resp.Body.Close() - - // Create parent directory - err = os.MkdirAll(filepath.Dir(filePath), 0755) - if err != nil { - return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err) - } - - // Create and write file content - outFile, err := os.Create(filePath) - if err != nil { - return fmt.Errorf("failed to create file %q: %v", file.Filename, err) - } - defer outFile.Close() - - progress := &progressWriter{ - fileName: file.Filename, - total: resp.ContentLength, - hash: sha256.New(), - downloadStatus: downloadStatus, - } - _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) - if err != nil { - return fmt.Errorf("failed to write file %q: %v", file.Filename, err) - } - - if file.SHA256 != "" { - // Verify SHA - calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) - if calculatedSHA != file.SHA256 { - log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) - return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) - } - } else { - log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) - } - - log.Debug().Msgf("File %q downloaded and verified", file.Filename) - if utils.IsArchive(filePath) { - log.Debug().Msgf("File %q is an archive, uncompressing to %s", file.Filename, basePath) - if err := utils.ExtractArchive(filePath, basePath); err != nil { - log.Debug().Msgf("Failed decompressing %q: %s", file.Filename, err.Error()) - return err - } + if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil { + return err } } diff --git a/pkg/utils/uri.go b/pkg/utils/uri.go index da2925665f37..8046b89fffc9 100644 --- a/pkg/utils/uri.go +++ b/pkg/utils/uri.go @@ -2,12 +2,17 @@ package utils import ( "crypto/md5" + "crypto/sha256" "fmt" + "hash" "io" "net/http" "os" "path/filepath" + "strconv" "strings" + + "github.com/rs/zerolog/log" ) const ( @@ -66,28 +71,172 @@ func GetURI(url string, f func(url string, i []byte) error) error { return f(url, body) } -func DownloadFile(url string, filepath string) error { - // Create the file - out, err := os.Create(filepath) - if err != nil { - return err +func ConvertURL(s string) string { + switch { + case strings.HasPrefix(s, "huggingface://"): + repository := strings.Replace(s, "huggingface://", "", 1) + // convert repository to a full URL. + // e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf + owner := strings.Split(repository, "/")[0] + repo := strings.Split(repository, "/")[1] + branch := "main" + if strings.Contains(repo, "@") { + branch = strings.Split(repository, "@")[1] + } + filepath := strings.Split(repository, "/")[2] + if strings.Contains(filepath, "@") { + filepath = strings.Split(filepath, "@")[0] + } + + return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) + } + + return s +} + +func DownloadFile(url string, filePath, sha string, downloadStatus func(string, string, string, float64)) error { + url = ConvertURL(url) + // Check if the file already exists + _, err := os.Stat(filePath) + if err == nil { + // File exists, check SHA + if sha != "" { + // Verify SHA + calculatedSHA, err := calculateSHA(filePath) + if err != nil { + return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err) + } + if calculatedSHA == sha { + // SHA matches, skip downloading + log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath) + return nil + } + // SHA doesn't match, delete the file and download again + err = os.Remove(filePath) + if err != nil { + return fmt.Errorf("failed to remove existing file %q: %v", filePath, err) + } + log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) + + } else { + // SHA is missing, skip downloading + log.Debug().Msgf("File %q already exists. Skipping download", filePath) + return nil + } + } else if !os.IsNotExist(err) { + // Error occurred while checking file existence + return fmt.Errorf("failed to check file %q existence: %v", filePath, err) } - defer out.Close() - // Get the data + log.Info().Msgf("Downloading %q", url) + + // Download file resp, err := http.Get(url) if err != nil { - return err + return fmt.Errorf("failed to download file %q: %v", filePath, err) } defer resp.Body.Close() - // Write the body to file - _, err = io.Copy(out, resp.Body) + // Create parent directory + err = os.MkdirAll(filepath.Dir(filePath), 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err) + } + + // Create and write file content + outFile, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %q: %v", filePath, err) + } + defer outFile.Close() + + progress := &progressWriter{ + fileName: filePath, + total: resp.ContentLength, + hash: sha256.New(), + downloadStatus: downloadStatus, + } + _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", filePath, err) + } + + if sha != "" { + // Verify SHA + calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) + if calculatedSHA != sha { + log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) + return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) + } + } else { + log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath) + } + + log.Info().Msgf("File %q downloaded and verified", filePath) + if IsArchive(filePath) { + basePath := filepath.Dir(filePath) + log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath) + if err := ExtractArchive(filePath, basePath); err != nil { + log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error()) + return err + } + } + + return nil +} + +type progressWriter struct { + fileName string + total int64 + written int64 + downloadStatus func(string, string, string, float64) + hash hash.Hash +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + n, err = pw.hash.Write(p) + pw.written += int64(n) + + if pw.total > 0 { + percentage := float64(pw.written) / float64(pw.total) * 100 + //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + } else { + pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) + } - return err + return } // MD5 of a string func MD5(s string) string { return fmt.Sprintf("%x", md5.Sum([]byte(s))) } + +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return strconv.FormatInt(bytes, 10) + " B" + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +func calculateSHA(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + + return fmt.Sprintf("%x", hash.Sum(nil)), nil +}