Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip session recording reservation files (filessesion) #13826

Merged
merged 9 commits into from
Jun 28, 2022
57 changes: 54 additions & 3 deletions lib/events/filesessions/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.
package filesessions

import (
"bytes"
"context"
"encoding/gob"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -165,6 +167,13 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
}
}()

// If the part file is just a reservation file, skip it to avoid
// generating broken upload files.
if isReservationFile(file) {
h.Debugf("Skipping reservation file %q", path)
return nil
}

_, err = io.Copy(f, file)
return err
}
Expand Down Expand Up @@ -302,10 +311,12 @@ func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpl
return trace.ConvertSystemError(err)
}

// Create a buffer with the max size that a part file can have.
buf := make([]byte, minUploadBytes+events.MaxProtoMessageSizeBytes)
contents, err := generateReservationFileContents(partNumber)
if err != nil {
return trace.Wrap(err)
}

_, err = file.Write(buf)
_, err = file.Write(contents)
if err = trace.NewAggregate(err, file.Close()); err != nil {
if rmErr := os.Remove(partPath); rmErr != nil {
h.WithError(rmErr).Warningf("Failed to remove file %q.", partPath)
Expand Down Expand Up @@ -384,6 +395,46 @@ func checkUploadID(uploadID string) error {
return nil
}

// generateReservationFileContent generates the content placed on the
// reservation files.
func generateReservationFileContents(partNumber int64) ([]byte, error) {
// Create a buffer with the max size that a part file can have.
buf := make([]byte, minUploadBytes+events.MaxProtoMessageSizeBytes)

// Encode reservation content.
encoded := &bytes.Buffer{}
encoder := gob.NewEncoder(encoded)
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
err := encoder.Encode(events.StreamPart{Number: partNumber})
if err != nil {
return nil, trace.Wrap(err)
}

// Copy into the final contents.
copy(buf[0:], encoded.Bytes())

return buf, nil
}

// isReservationFile verifies if the provided file is a reservation file
// generate by `ReservePartUpload`.
func isReservationFile(f *os.File) bool {
// Reset the file pointer to the begining.
defer f.Seek(0, 0)

streamPart := &events.StreamPart{}
decoder := gob.NewDecoder(f)
err := decoder.Decode(streamPart)
if err != nil {
return false
}

if streamPart.Number > 0 {
return true
}

return false
}

const (
// uploadsDir is a directory with multipart uploads
uploadsDir = "multi"
Expand Down
75 changes: 75 additions & 0 deletions lib/events/filesessions/filestream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"os"
"testing"

"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/session"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -75,3 +76,77 @@ func TestUploadPart(t *testing.T) {
require.NoError(t, err)
require.True(t, bytes.Equal(expectedContent, partFileContent))
}

func TestCompleteUpload(t *testing.T) {
ctx := context.Background()

// Create some upload parts using reserve + write.
createPart := func(t *testing.T, handler *Handler, upload *events.StreamUpload, partNumber int64, content []byte) events.StreamPart {
err := handler.ReserveUploadPart(ctx, *upload, partNumber)
require.NoError(t, err)

if len(content) > 0 {
part, err := handler.UploadPart(ctx, *upload, partNumber, bytes.NewReader(content))
require.NoError(t, err)
return *part
}

return events.StreamPart{Number: partNumber}
}

for _, test := range []struct {
desc string
expectedContent []byte
partsFunc func(t *testing.T, handler *Handler, upload *events.StreamUpload) []events.StreamPart
}{
{
desc: "PartsWithContent",
expectedContent: []byte("helloworld"),
partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) []events.StreamPart {
return []events.StreamPart{
createPart(t, handler, upload, int64(1), []byte("hello")),
createPart(t, handler, upload, int64(2), []byte("world")),
}
},
},
{
desc: "ReservationParts",
expectedContent: []byte("helloworld"),
partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) []events.StreamPart {
return []events.StreamPart{
createPart(t, handler, upload, int64(1), []byte{}),
createPart(t, handler, upload, int64(2), []byte("hello")),
createPart(t, handler, upload, int64(3), []byte("world")),
createPart(t, handler, upload, int64(4), []byte{}),
}
},
},
} {
t.Run(test.desc, func(t *testing.T) {
handler, err := NewHandler(Config{
Directory: t.TempDir(),
})
require.NoError(t, err)

upload, err := handler.CreateUpload(ctx, session.NewID())
require.NoError(t, err)

err = handler.CompleteUpload(ctx, *upload, test.partsFunc(t, handler, upload))
require.NoError(t, err)

// Check upload contents
uploadPath := handler.path(upload.SessionID)
f, err := os.Open(uploadPath)
require.NoError(t, err)

contents, err := io.ReadAll(f)
require.NoError(t, err)
require.Equal(t, test.expectedContent, contents)

// Part files directory should no longer exists.
_, err = os.ReadDir(handler.uploadRootPath(*upload))
require.Error(t, err)
require.True(t, os.IsNotExist(err))
})
}
}