diff --git a/pkg/git/lfs.go b/pkg/git/lfs.go index 4b0065f3e..acf41468c 100644 --- a/pkg/git/lfs.go +++ b/pkg/git/lfs.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "io/fs" "path" "path/filepath" "strconv" @@ -60,10 +59,7 @@ func LFSTransfer(ctx context.Context, cmd ServiceCommand) error { } // Advertise capabilities. - for _, cap := range []string{ - "version=1", - "locking", - } { + for _, cap := range transfer.Capabilities { if err := handler.WritePacketText(cap); err != nil { logger.Errorf("error sending capability: %s: %v", cap, err) return err @@ -114,34 +110,32 @@ func (t *lfsTransfer) Batch(_ string, pointers []transfer.BatchItem, _ transfer. } // Download implements transfer.Backend. -func (t *lfsTransfer) Download(oid string, _ transfer.Args) (fs.File, error) { +func (t *lfsTransfer) Download(oid string, _ transfer.Args) (io.ReadCloser, int64, error) { cfg := config.FromContext(t.ctx) repoID := strconv.FormatInt(t.repo.ID(), 10) strg := storage.NewLocalStorage(filepath.Join(cfg.DataPath, "lfs", repoID)) pointer := transfer.Pointer{Oid: oid} - return strg.Open(path.Join("objects", pointer.RelativePath())) -} - -type uploadObject struct { - oid string - size int64 - object storage.Object -} - -func (u *uploadObject) Close() error { - return u.object.Close() + obj, err := strg.Open(path.Join("objects", pointer.RelativePath())) + if err != nil { + return nil, 0, err + } + stat, err := obj.Stat() + if err != nil { + return nil, 0, err + } + return obj, stat.Size(), nil } -// StartUpload implements transfer.Backend. -func (t *lfsTransfer) StartUpload(oid string, r io.Reader, _ transfer.Args) (io.Closer, error) { +// Upload implements transfer.Backend. +func (t *lfsTransfer) Upload(oid string, size int64, r io.Reader, _ transfer.Args) error { if r == nil { - return nil, fmt.Errorf("no reader: %w", transfer.ErrMissingData) + return fmt.Errorf("no reader: %w", transfer.ErrMissingData) } tempDir := "incomplete" randBytes := make([]byte, 12) if _, err := rand.Read(randBytes); err != nil { - return nil, err + return err } tempName := fmt.Sprintf("%s%x", oid, randBytes) @@ -150,37 +144,22 @@ func (t *lfsTransfer) StartUpload(oid string, r io.Reader, _ transfer.Args) (io. written, err := t.storage.Put(tempName, r) if err != nil { t.logger.Errorf("error putting object: %v", err) - return nil, err + return err } obj, err := t.storage.Open(tempName) if err != nil { t.logger.Errorf("error opening object: %v", err) - return nil, err - } - - return &uploadObject{ - oid: oid, - size: written, - object: obj, - }, nil -} - -// FinishUpload implements transfer.Backend. -func (t *lfsTransfer) FinishUpload(state io.Closer, args transfer.Args) error { - upl, ok := state.(*uploadObject) - if !ok { - return errors.New("invalid state") + return err } - size, _ := transfer.SizeFromArgs(args) pointer := transfer.Pointer{ - Oid: upl.oid, + Oid: oid, } if size > 0 { pointer.Size = size } else { - pointer.Size = upl.size + pointer.Size = written } if err := t.store.CreateLFSObject(t.ctx, t.dbx, t.repo.ID(), pointer.Oid, pointer.Size); err != nil { @@ -188,7 +167,7 @@ func (t *lfsTransfer) FinishUpload(state io.Closer, args transfer.Args) error { } expectedPath := path.Join("objects", pointer.RelativePath()) - if err := t.storage.Rename(upl.object.Name(), expectedPath); err != nil { + if err := t.storage.Rename(obj.Name(), expectedPath); err != nil { t.logger.Errorf("error renaming object: %v", err) _ = t.store.DeleteLFSObjectByOid(t.ctx, t.dbx, t.repo.ID(), pointer.Oid) return err @@ -198,12 +177,7 @@ func (t *lfsTransfer) FinishUpload(state io.Closer, args transfer.Args) error { } // Verify implements transfer.Backend. -func (t *lfsTransfer) Verify(oid string, args transfer.Args) (transfer.Status, error) { - expectedSize, err := transfer.SizeFromArgs(args) - if err != nil { - return transfer.NewStatus(transfer.StatusBadRequest, "missing size"), nil // nolint: nilerr - } - +func (t *lfsTransfer) Verify(oid string, size int64, args transfer.Args) (transfer.Status, error) { obj, err := t.store.GetLFSObjectByOid(t.ctx, t.dbx, t.repo.ID(), oid) if err != nil { if errors.Is(err, db.ErrRecordNotFound) { @@ -213,8 +187,8 @@ func (t *lfsTransfer) Verify(oid string, args transfer.Args) (transfer.Status, e return nil, err } - if obj.Size != expectedSize { - t.logger.Errorf("size mismatch: %d != %d", obj.Size, expectedSize) + if obj.Size != size { + t.logger.Errorf("size mismatch: %d != %d", obj.Size, size) return transfer.NewStatus(transfer.StatusConflict, "size mismatch"), nil }