Skip to content

Commit

Permalink
style: reuse-code from gallery
Browse files Browse the repository at this point in the history
  • Loading branch information
mudler committed Dec 17, 2023
1 parent 18e581f commit 831566f
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 121 deletions.
31 changes: 5 additions & 26 deletions api/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"

"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -265,43 +266,22 @@ func (cm *ConfigLoader) ListConfigs() []string {
return res
}

// convertURL expands shorthand model references into plain HTTP(S) URLs.
//
// Only the "huggingface://" scheme is rewritten; every other input is
// returned untouched. The accepted shorthand is
//
//	huggingface://<owner>/<repo>/<path/to/file>[@<branch>]
//
// e.g. huggingface://TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main
// becomes https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
//
// The branch defaults to "main" when no "@<branch>" suffix is present (or
// when the suffix is empty). A reference that does not contain an owner, a
// repository and a file is returned unchanged instead of panicking, so the
// caller sees the unsupported scheme rather than a crash.
func convertURL(s string) string {
	const scheme = "huggingface://"
	if !strings.HasPrefix(s, scheme) {
		return s
	}

	repository := strings.TrimPrefix(s, scheme)

	// The branch is whatever follows the last "@"; strip it before splitting
	// so it can never leak into the file path.
	branch := "main"
	if at := strings.LastIndex(repository, "@"); at >= 0 {
		if ref := repository[at+1:]; ref != "" {
			branch = ref
		}
		repository = repository[:at]
	}

	// Split into at most three pieces so that files living in
	// sub-directories keep their full path.
	parts := strings.SplitN(repository, "/", 3)
	if len(parts) < 3 || parts[0] == "" || parts[1] == "" || parts[2] == "" {
		return s
	}

	return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", parts[0], parts[1], branch, parts[2])
}

func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()

for i, config := range cm.configs {
modelURL := config.PredictionOptions.Model
modelURL = convertURL(modelURL)
modelURL = utils.ConvertURL(modelURL)
if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
// md5 of model name
md5Name := utils.MD5(modelURL)

// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name))
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent)
})
if err != nil {
return err
}
Expand All @@ -312,7 +292,6 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
c.PredictionOptions.Model = md5Name
cm.configs[i] = *c
}

}
return nil
}
Expand Down
86 changes: 2 additions & 84 deletions pkg/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"hash"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
Expand Down Expand Up @@ -115,89 +114,8 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
// Create file path
filePath := filepath.Join(basePath, file.Filename)

// Check if the file already exists
_, err := os.Stat(filePath)
if err == nil {
// File exists, check SHA
if file.SHA256 != "" {
// Verify SHA
calculatedSHA, err := calculateSHA(filePath)
if err != nil {
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
}
if calculatedSHA == file.SHA256 {
// SHA matches, skip downloading
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
continue
}
// SHA doesn't match, delete the file and download again
err = os.Remove(filePath)
if err != nil {
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
}
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)

} else {
// SHA is missing, skip downloading
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
continue
}
} else if !os.IsNotExist(err) {
// Error occurred while checking file existence
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
}

log.Debug().Msgf("Downloading %q", file.URI)

// Download file
resp, err := http.Get(file.URI)
if err != nil {
return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
}
defer resp.Body.Close()

// Create parent directory
err = os.MkdirAll(filepath.Dir(filePath), 0755)
if err != nil {
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
}

// Create and write file content
outFile, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
}
defer outFile.Close()

progress := &progressWriter{
fileName: file.Filename,
total: resp.ContentLength,
hash: sha256.New(),
downloadStatus: downloadStatus,
}
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
if err != nil {
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
}

if file.SHA256 != "" {
// Verify SHA
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
if calculatedSHA != file.SHA256 {
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
}
} else {
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
}

log.Debug().Msgf("File %q downloaded and verified", file.Filename)
if utils.IsArchive(filePath) {
log.Debug().Msgf("File %q is an archive, uncompressing to %s", file.Filename, basePath)
if err := utils.ExtractArchive(filePath, basePath); err != nil {
log.Debug().Msgf("Failed decompressing %q: %s", file.Filename, err.Error())
return err
}
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
return err
}
}

Expand Down
171 changes: 160 additions & 11 deletions pkg/utils/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package utils

import (
"crypto/md5"
"crypto/sha256"
"fmt"
"hash"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/rs/zerolog/log"
)

const (
Expand Down Expand Up @@ -66,28 +71,172 @@ func GetURI(url string, f func(url string, i []byte) error) error {
return f(url, body)
}

func DownloadFile(url string, filepath string) error {
// Create the file
out, err := os.Create(filepath)
if err != nil {
return err
// ConvertURL expands shorthand model references into plain HTTP(S) URLs.
//
// Only the "huggingface://" scheme is rewritten; every other input is
// returned untouched. The accepted shorthand is
//
//	huggingface://<owner>/<repo>/<path/to/file>[@<branch>]
//
// e.g. huggingface://TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main
// becomes https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
//
// The branch defaults to "main" when no "@<branch>" suffix is present (or
// when the suffix is empty). A reference that does not contain an owner, a
// repository and a file is returned unchanged instead of panicking, so the
// caller sees the unsupported scheme rather than a crash.
func ConvertURL(s string) string {
	const scheme = "huggingface://"
	if !strings.HasPrefix(s, scheme) {
		return s
	}

	repository := strings.TrimPrefix(s, scheme)

	// The branch is whatever follows the last "@"; strip it before splitting
	// so it can never leak into the file path.
	branch := "main"
	if at := strings.LastIndex(repository, "@"); at >= 0 {
		if ref := repository[at+1:]; ref != "" {
			branch = ref
		}
		repository = repository[:at]
	}

	// Split into at most three pieces so that files living in
	// sub-directories keep their full path.
	parts := strings.SplitN(repository, "/", 3)
	if len(parts) < 3 || parts[0] == "" || parts[1] == "" || parts[2] == "" {
		return s
	}

	return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", parts[0], parts[1], branch, parts[2])
}

// DownloadFile fetches url into filePath, reporting progress through
// downloadStatus and optionally verifying the payload against sha.
//
// Parameters:
//   - url: source location; it is first passed through ConvertURL.
//   - filePath: destination on disk; missing parent directories are created.
//   - sha: expected hex-encoded SHA-256 of the file, or "" to skip
//     verification.
//   - downloadStatus: progress callback invoked with (file name, bytes
//     written, total bytes, percentage). A nil callback is accepted and
//     treated as a no-op.
//
// If filePath already exists the download is skipped when sha is empty or
// matches the file on disk; a mismatching file is removed and fetched again.
// After a successful download, archives (as reported by IsArchive) are
// extracted into the directory containing filePath.
//
// An error is returned when the existing file cannot be inspected, the
// request fails or answers with a non-2xx status, the file cannot be written,
// the checksum does not match, or the archive cannot be extracted. A file
// that was only partially written or failed verification is removed so that a
// later call does not mistake it for a complete download.
func DownloadFile(url string, filePath, sha string, downloadStatus func(string, string, string, float64)) error {
	url = ConvertURL(url)

	if downloadStatus == nil {
		// progressWriter calls the callback unconditionally.
		downloadStatus = func(string, string, string, float64) {}
	}

	// Check if the file already exists
	_, err := os.Stat(filePath)
	if err == nil {
		if sha == "" {
			// Nothing to verify against: trust the file that is there.
			log.Debug().Msgf("File %q already exists. Skipping download", filePath)
			return nil
		}

		calculatedSHA, err := calculateSHA(filePath)
		if err != nil {
			return fmt.Errorf("failed to calculate SHA for file %q: %w", filePath, err)
		}
		if calculatedSHA == sha {
			log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath)
			return nil
		}

		// SHA doesn't match, delete the file and download again
		if err := os.Remove(filePath); err != nil {
			return fmt.Errorf("failed to remove existing file %q: %w", filePath, err)
		}
		log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
	} else if !os.IsNotExist(err) {
		// Error occurred while checking file existence
		return fmt.Errorf("failed to check file %q existence: %w", filePath, err)
	}

	log.Info().Msgf("Downloading %q", url)

	// Download file
	resp, err := http.Get(url)
	if err != nil {
		return fmt.Errorf("failed to download file %q: %w", filePath, err)
	}
	defer resp.Body.Close()

	// http.Get only fails on transport errors; an error page must not be
	// stored as if it were the requested file.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("failed to download file %q: unexpected HTTP status %q", filePath, resp.Status)
	}

	// Create parent directory
	if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
		return fmt.Errorf("failed to create parent directory for file %q: %w", filePath, err)
	}

	// Create and write file content
	outFile, err := os.Create(filePath)
	if err != nil {
		return fmt.Errorf("failed to create file %q: %w", filePath, err)
	}

	// discard removes an incomplete or invalid download (best effort).
	discard := func() {
		if rmErr := os.Remove(filePath); rmErr != nil && !os.IsNotExist(rmErr) {
			log.Debug().Msgf("Failed removing %q: %s", filePath, rmErr.Error())
		}
	}

	progress := &progressWriter{
		fileName:       filePath,
		total:          resp.ContentLength,
		hash:           sha256.New(),
		downloadStatus: downloadStatus,
	}
	_, copyErr := io.Copy(io.MultiWriter(outFile, progress), resp.Body)
	// Close before verifying/extracting so buffered write errors surface and
	// the archive is not read while still open for writing.
	closeErr := outFile.Close()
	if copyErr != nil {
		discard()
		return fmt.Errorf("failed to write file %q: %w", filePath, copyErr)
	}
	if closeErr != nil {
		discard()
		return fmt.Errorf("failed to write file %q: %w", filePath, closeErr)
	}

	if sha != "" {
		// Verify SHA
		calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
		if calculatedSHA != sha {
			discard()
			return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha)
		}
	} else {
		log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath)
	}

	log.Info().Msgf("File %q downloaded and verified", filePath)
	if IsArchive(filePath) {
		basePath := filepath.Dir(filePath)
		log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath)
		if err := ExtractArchive(filePath, basePath); err != nil {
			return fmt.Errorf("failed to decompress %q: %w", filePath, err)
		}
	}

	return nil
}

// progressWriter is the sink DownloadFile tees the response body into: it
// hashes every chunk it is given and reports how far the transfer has come.
type progressWriter struct {
	// fileName is handed to downloadStatus as the first argument.
	fileName string

	// total is the expected size in bytes; values <= 0 make Write report
	// progress without a total or a percentage.
	total int64
	// written is the running count of bytes accepted so far.
	written int64

	// hash receives every chunk passed to Write.
	hash hash.Hash
	// downloadStatus is called from Write with
	// (fileName, bytes written, total bytes, percentage).
	downloadStatus func(string, string, string, float64)
}

// Write implements io.Writer. It feeds p into the running hash, adds the
// number of bytes accepted to the written counter and then reports progress
// through downloadStatus.
//
// When the total size is known (total > 0) the callback receives the
// formatted total and the completion percentage; otherwise the total is
// reported as "" and the percentage as 0. A nil downloadStatus is tolerated
// and simply means no progress is reported.
//
// It returns the byte count and error produced by the underlying hash.
func (pw *progressWriter) Write(p []byte) (n int, err error) {
	n, err = pw.hash.Write(p)
	pw.written += int64(n)

	if pw.downloadStatus == nil {
		return n, err
	}

	if pw.total > 0 {
		percentage := float64(pw.written) / float64(pw.total) * 100
		pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
	} else {
		pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
	}

	return n, err
}

// MD5 returns the lowercase hexadecimal MD5 digest of s.
func MD5(s string) string {
	digest := md5.Sum([]byte(s))
	hexDigest := fmt.Sprintf("%x", digest)
	return hexDigest
}

// formatBytes renders a byte count for display using binary (1024-based)
// units. Values below 1024 are printed as a plain integer followed by " B";
// larger values are scaled to the biggest fitting unit and printed with one
// decimal place, e.g. "1.5 KiB" or "1.0 GiB".
func formatBytes(bytes int64) string {
	const (
		unit     = 1024
		prefixes = "KMGTPE"
	)

	if bytes < unit {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	// Walk up the unit ladder until what remains fits below one more step.
	scale := int64(unit)
	idx := 0
	remaining := bytes / unit
	for remaining >= unit {
		remaining /= unit
		scale *= unit
		idx++
	}

	scaled := float64(bytes) / float64(scale)
	return fmt.Sprintf("%.1f %ciB", scaled, prefixes[idx])
}

// calculateSHA streams the file at filePath through SHA-256 and returns the
// digest as a lowercase hexadecimal string. Any error from opening or reading
// the file is returned as-is together with an empty string.
func calculateSHA(filePath string) (string, error) {
	src, openErr := os.Open(filePath)
	if openErr != nil {
		return "", openErr
	}
	defer src.Close()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, src)
	if copyErr != nil {
		return "", copyErr
	}

	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}

0 comments on commit 831566f

Please sign in to comment.