Skip to content

Commit

Permalink
[v10] Skip session recording reservation files (filessesion) (#13947)
Browse files Browse the repository at this point in the history
Skip session recording reservation files (filessesion) (#13826)
  • Loading branch information
gabrielcorado authored Jun 29, 2022
1 parent aa4b3ef commit c2ffcaf
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 11 deletions.
36 changes: 26 additions & 10 deletions lib/events/filesessions/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,25 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa
return nil, trace.Wrap(err)
}

file, partPath, err := h.openUploadPart(upload, partNumber)
file, reservationPath, err := h.openReservationPart(upload, partNumber)
if err != nil {
return nil, trace.ConvertSystemError(err)
}

_, err = io.Copy(file, partBody)
if err = trace.NewAggregate(err, file.Close()); err != nil {
if rmErr := os.Remove(partPath); rmErr != nil {
h.WithError(rmErr).Warningf("Failed to remove file %q.", partPath)
size, err := io.Copy(file, partBody)
if err = trace.NewAggregate(err, file.Truncate(size), file.Close()); err != nil {
if rmErr := os.Remove(reservationPath); rmErr != nil {
h.WithError(rmErr).Warningf("Failed to remove file %q.", reservationPath)
}
return nil, trace.Wrap(err)
}

// Rename reservation to part file.
err = os.Rename(reservationPath, h.partPath(upload, partNumber))
if err != nil {
return nil, trace.ConvertSystemError(err)
}

return &events.StreamPart{Number: partNumber}, nil
}

Expand Down Expand Up @@ -297,7 +303,7 @@ func (h *Handler) GetUploadMetadata(s session.ID) events.UploadMetadata {

// ReserveUploadPart reserves an upload part.
func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
file, partPath, err := h.openUploadPart(upload, partNumber)
file, partPath, err := h.openReservationPart(upload, partNumber)
if err != nil {
return trace.ConvertSystemError(err)
}
Expand All @@ -317,10 +323,10 @@ func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpl
return nil
}

// openUploadPart opens a upload file part.
func (h *Handler) openUploadPart(upload events.StreamUpload, partNumber int64) (*os.File, string, error) {
partPath := h.partPath(upload, partNumber)
file, err := GetOpenFileFunc()(partPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
// openReservationPart opens a reservation upload part file.
func (h *Handler) openReservationPart(upload events.StreamUpload, partNumber int64) (*os.File, string, error) {
partPath := h.reservationPath(upload, partNumber)
file, err := GetOpenFileFunc()(partPath, os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
return nil, partPath, trace.ConvertSystemError(err)
}
Expand All @@ -344,10 +350,18 @@ func (h *Handler) partPath(upload events.StreamUpload, partNumber int64) string
return filepath.Join(h.uploadPath(upload), partFileName(partNumber))
}

func (h *Handler) reservationPath(upload events.StreamUpload, partNumber int64) string {
return filepath.Join(h.uploadPath(upload), reservationFileName(partNumber))
}

func partFileName(partNumber int64) string {
return fmt.Sprintf("%v%v", partNumber, partExt)
}

func reservationFileName(partNumber int64) string {
return fmt.Sprintf("%v%v", partNumber, reservationExt)
}

func partFromFileName(fileName string) (int64, error) {
base := filepath.Base(fileName)
if filepath.Ext(base) != partExt {
Expand Down Expand Up @@ -395,4 +409,6 @@ const (
checkpointExt = ".checkpoint"
// errorExt is a suffix for files storing session errors
errorExt = ".error"
// reservationExt is part reservation extension.
reservationExt = ".reservation"
)
92 changes: 91 additions & 1 deletion lib/events/filesessions/filestream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"os"
"testing"

"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/session"

"github.com/stretchr/testify/require"
Expand All @@ -42,7 +43,7 @@ func TestReserveUploadPart(t *testing.T) {
err = handler.ReserveUploadPart(ctx, *upload, partNumber)
require.NoError(t, err)

fi, err := os.Stat(handler.partPath(*upload, partNumber))
fi, err := os.Stat(handler.reservationPath(*upload, partNumber))
require.NoError(t, err)
require.GreaterOrEqual(t, fi.Size(), int64(minUploadBytes))
}
Expand Down Expand Up @@ -71,7 +72,96 @@ func TestUploadPart(t *testing.T) {
require.NoError(t, err)
defer partFile.Close()

fd, err := partFile.Stat()
require.NoError(t, err)
require.Equal(t, int64(len(expectedContent)), fd.Size())

partFileContent, err := io.ReadAll(partFile)
require.NoError(t, err)
require.True(t, bytes.Equal(expectedContent, partFileContent))
}

func TestCompleteUpload(t *testing.T) {
ctx := context.Background()

// Create some upload parts using reserve + write.
createPart := func(t *testing.T, handler *Handler, upload *events.StreamUpload, partNumber int64, content []byte) events.StreamPart {
err := handler.ReserveUploadPart(ctx, *upload, partNumber)
require.NoError(t, err)

if len(content) > 0 {
part, err := handler.UploadPart(ctx, *upload, partNumber, bytes.NewReader(content))
require.NoError(t, err)
return *part
}

return events.StreamPart{Number: partNumber}
}

for _, test := range []struct {
desc string
expectedContent []byte
partsFunc func(t *testing.T, handler *Handler, upload *events.StreamUpload)
}{
{
desc: "PartsWithContent",
expectedContent: []byte("helloworld"),
partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) {
createPart(t, handler, upload, int64(1), []byte("hello"))
createPart(t, handler, upload, int64(2), []byte("world"))
},
},
{
desc: "ReservationParts",
expectedContent: []byte("helloworldwithreservation"),
partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) {
createPart(t, handler, upload, int64(1), []byte{})
createPart(t, handler, upload, int64(2), []byte("hello"))
createPart(t, handler, upload, int64(3), []byte("world"))
createPart(t, handler, upload, int64(4), []byte{})
createPart(t, handler, upload, int64(5), []byte("withreservation"))
},
},
{
desc: "OnlyReservation",
expectedContent: []byte{},
partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) {
createPart(t, handler, upload, int64(1), []byte{})
createPart(t, handler, upload, int64(2), []byte{})
},
},
} {
t.Run(test.desc, func(t *testing.T) {
handler, err := NewHandler(Config{
Directory: t.TempDir(),
})
require.NoError(t, err)

upload, err := handler.CreateUpload(ctx, session.NewID())
require.NoError(t, err)

// Create upload parts.
test.partsFunc(t, handler, upload)

parts, err := handler.ListParts(ctx, *upload)
require.NoError(t, err)

err = handler.CompleteUpload(ctx, *upload, parts)
require.NoError(t, err)

// Check upload contents
uploadPath := handler.path(upload.SessionID)
f, err := os.Open(uploadPath)
require.NoError(t, err)

contents, err := io.ReadAll(f)
require.NoError(t, err)
require.Equal(t, test.expectedContent, contents)

// Part files directory should no longer exists.
_, err = os.ReadDir(handler.uploadRootPath(*upload))
require.Error(t, err)
require.True(t, os.IsNotExist(err))
})
}
}

0 comments on commit c2ffcaf

Please sign in to comment.