Skip to content

Commit

Permalink
chore: added feature flag for loadByFolderPath
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr committed Jan 27, 2025
1 parent d8af4c8 commit c10b877
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 54 deletions.
182 changes: 159 additions & 23 deletions warehouse/integrations/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@ package redshift
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
"slices"
"sort"
"strings"
"time"

"github.com/samber/lo"
"github.com/tidwall/gjson"

"github.com/rudderlabs/rudder-go-kit/filemanager"
"github.com/rudderlabs/sqlconnect-go/sqlconnect"
sqlconnectconfig "github.com/rudderlabs/sqlconnect-go/sqlconnect/config"

Expand Down Expand Up @@ -153,9 +160,24 @@ type Redshift struct {
skipDedupDestinationIDs []string
skipComputingUserLatestTraits bool
enableDeleteByJobs bool
loadByFolderPath bool
}
}

// Types modelling the JSON manifest document that is marshalled and uploaded
// by generateManifest.
type (
	// s3Manifest is the top-level manifest document; it serializes as an
	// object with a single "entries" array.
	s3Manifest struct {
		Entries []s3ManifestEntry `json:"entries"`
	}

	// s3ManifestEntry describes one file listed in the manifest: its URL,
	// whether it is mandatory, and its metadata (serialized under "meta").
	s3ManifestEntry struct {
		Url       string                  `json:"url"`
		Mandatory bool                    `json:"mandatory"`
		Metadata  s3ManifestEntryMetadata `json:"meta"`
	}

	// s3ManifestEntryMetadata carries the per-file metadata of a manifest
	// entry; currently only the content length.
	s3ManifestEntryMetadata struct {
		ContentLength int64 `json:"content_length"`
	}
)

func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Redshift {
rs := &Redshift{}

Expand All @@ -170,6 +192,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Redshift {
rs.config.skipComputingUserLatestTraits = conf.GetBool("Warehouse.redshift.skipComputingUserLatestTraits", false)
rs.config.enableDeleteByJobs = conf.GetBool("Warehouse.redshift.enableDeleteByJobs", false)
rs.config.slowQueryThreshold = conf.GetDuration("Warehouse.redshift.slowQueryThreshold", 5, time.Minute)
rs.config.loadByFolderPath = conf.GetBool("Warehouse.redshift.loadByFolderPath", false)

return rs
}
Expand Down Expand Up @@ -319,6 +342,92 @@ func (rs *Redshift) createSchema(ctx context.Context) (err error) {
return
}

// generateManifest builds a manifest listing the load files of tableName,
// uploads it through the S3 file manager and returns the location of the
// uploaded manifest.
//
// Every entry is marked mandatory. When a load file's metadata contains a
// "content_length" value, it is copied into the entry's meta section. The
// manifest is staged in a local temporary file which is removed before the
// function returns.
//
// All failures are reported through the returned error; the function does not
// panic when the temporary directory cannot be created.
func (rs *Redshift) generateManifest(ctx context.Context, tableName string) (string, error) {
	metadata, err := rs.Uploader.GetLoadFilesMetadata(
		ctx,
		warehouseutils.GetLoadFilesOptions{
			Table: tableName,
		},
	)
	if err != nil {
		return "", err
	}
	loadFiles := warehouseutils.GetS3Locations(metadata)

	entries := lo.Map(loadFiles, func(loadFile warehouseutils.LoadFile, _ int) s3ManifestEntry {
		manifestEntry := s3ManifestEntry{
			Url:       loadFile.Location,
			Mandatory: true,
		}

		// add contentLength to manifest entry if it exists
		contentLength := gjson.Get(string(loadFile.Metadata), "content_length")
		if contentLength.Exists() {
			manifestEntry.Metadata.ContentLength = contentLength.Int()
		}

		return manifestEntry
	})

	manifestJSON, err := json.Marshal(&s3Manifest{
		Entries: entries,
	})
	if err != nil {
		return "", fmt.Errorf("marshalling manifest: %w", err)
	}

	tmpDirPath, err := misc.CreateTMPDIR()
	if err != nil {
		// Surface the failure to the caller instead of crashing the process.
		return "", fmt.Errorf("creating tmp directory: %w", err)
	}

	localManifestPath := filepath.Join(tmpDirPath, misc.RudderRedshiftManifests, misc.FastUUID().String())
	if err := os.MkdirAll(filepath.Dir(localManifestPath), os.ModePerm); err != nil {
		return "", fmt.Errorf("creating manifest directory: %w", err)
	}

	// The local copy is only needed for the upload; clean it up on every exit path.
	defer func() {
		misc.RemoveFilePaths(localManifestPath)
	}()

	if err := os.WriteFile(localManifestPath, manifestJSON, 0o644); err != nil {
		return "", fmt.Errorf("writing manifest to file: %w", err)
	}

	file, err := os.Open(localManifestPath)
	if err != nil {
		return "", fmt.Errorf("opening manifest file: %w", err)
	}
	// Read-only handle: a close error carries no information worth reporting.
	defer func() { _ = file.Close() }()

	uploader, err := filemanager.New(&filemanager.Settings{
		Provider: warehouseutils.S3,
		Config: misc.GetObjectStorageConfig(misc.ObjectStorageOptsT{
			Provider:         warehouseutils.S3,
			Config:           rs.Warehouse.Destination.Config,
			UseRudderStorage: rs.Uploader.UseRudderStorage(),
			WorkspaceID:      rs.Warehouse.Destination.WorkspaceID,
		}),
	})
	if err != nil {
		return "", fmt.Errorf("creating uploader: %w", err)
	}

	// Upload prefixes: manifests folder, source, destination, date, table and
	// a fresh UUID.
	uploadOutput, err := uploader.Upload(
		ctx, file, misc.RudderRedshiftManifests,
		rs.Warehouse.Source.ID, rs.Warehouse.Destination.ID,
		time.Now().Format("01-02-2006"), tableName,
		misc.FastUUID().String(),
	)
	if err != nil {
		return "", fmt.Errorf("uploading manifest file: %w", err)
	}

	return uploadOutput.Location, nil
}

func (rs *Redshift) dropStagingTables(ctx context.Context, stagingTableNames []string) {
for _, stagingTableName := range stagingTableNames {
rs.logger.Infof("WH: dropping table %+v\n", stagingTableName)
Expand Down Expand Up @@ -348,11 +457,6 @@ func (rs *Redshift) loadTable(
)
log.Infow("started loading")

objectLocation, err := rs.Uploader.GetSampleLoadFileLocation(ctx, tableName)
if err != nil {
return nil, "", fmt.Errorf("getting sample load file location: %w", err)
}

stagingTableName := warehouseutils.StagingTableName(
provider,
tableName,
Expand All @@ -365,7 +469,7 @@ func (rs *Redshift) loadTable(
stagingTableName,
tableName,
)
if _, err = rs.DB.ExecContext(ctx, createStagingTableStmt); err != nil {
if _, err := rs.DB.ExecContext(ctx, createStagingTableStmt); err != nil {
return nil, "", fmt.Errorf("creating staging table: %w", err)
}

Expand All @@ -390,8 +494,8 @@ func (rs *Redshift) loadTable(

log.Infow("loading data into staging table")
err = rs.copyIntoLoadTable(
ctx, txn, stagingTableName,
objectLocation, strKeys,
ctx, txn, tableName, stagingTableName,
strKeys,
)
if err != nil {
return nil, "", fmt.Errorf("loading data into staging table: %w", err)
Expand Down Expand Up @@ -450,20 +554,42 @@ func (rs *Redshift) loadTable(
func (rs *Redshift) copyIntoLoadTable(
ctx context.Context,
txn *sqlmiddleware.Tx,
tableName string,
stagingTableName string,
objectLocation string,
strKeys []string,
) error {
tempAccessKeyId, tempSecretAccessKey, token, err := warehouseutils.GetTemporaryS3Cred(&rs.Warehouse.Destination)
if err != nil {
return fmt.Errorf("getting temporary s3 credentials: %w", err)
}

s3Location, region := warehouseutils.GetS3Location(objectLocation)
if region == "" {
region = "us-east-1"
var manifestSQL string
if !rs.config.loadByFolderPath {
manifestSQL = "MANIFEST"
}

var s3Location, region string
if rs.config.loadByFolderPath {
objectLocation, err := rs.Uploader.GetSampleLoadFileLocation(ctx, tableName)
if err != nil {
return fmt.Errorf("getting sample load file location: %w", err)
}

Check warning on line 576 in warehouse/integrations/redshift/redshift.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/redshift/redshift.go#L575-L576

Added lines #L575 - L576 were not covered by tests

s3Location, region = warehouseutils.GetS3Location(objectLocation)
if region == "" {
region = "us-east-1"
}
s3Location = warehouseutils.GetLocationFolder(s3Location)
} else {
manifestLocation, err := rs.generateManifest(ctx, tableName)
if err != nil {
return fmt.Errorf("generating manifest: %w", err)
}

Check warning on line 587 in warehouse/integrations/redshift/redshift.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/redshift/redshift.go#L586-L587

Added lines #L586 - L587 were not covered by tests
s3Location, region = warehouseutils.GetS3Location(manifestLocation)
if region == "" {
region = "us-east-1"
}
}
s3LocationFolder := warehouseutils.GetLocationFolder(s3Location)

sortedColumnNames := warehouseutils.JoinWithFormatting(strKeys, func(_ int, name string) string {
return fmt.Sprintf(`%q`, name)
Expand All @@ -477,12 +603,13 @@ func (rs *Redshift) copyIntoLoadTable(
ACCESS_KEY_ID '%s'
SECRET_ACCESS_KEY '%s'
SESSION_TOKEN '%s'
FORMAT PARQUET;`,
%s FORMAT PARQUET;`,
fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName),
s3LocationFolder,
s3Location,
tempAccessKeyId,
tempSecretAccessKey,
token,
manifestSQL,
)
} else {
copyStmt = fmt.Sprintf(
Expand All @@ -495,16 +622,17 @@ func (rs *Redshift) copyIntoLoadTable(
REGION '%s'
DATEFORMAT 'auto'
TIMEFORMAT 'auto'
TRUNCATECOLUMNS EMPTYASNULL BLANKSASNULL FILLRECORD ACCEPTANYDATE TRIMBLANKS ACCEPTINVCHARS
%s TRUNCATECOLUMNS EMPTYASNULL BLANKSASNULL FILLRECORD ACCEPTANYDATE TRIMBLANKS ACCEPTINVCHARS
COMPUPDATE OFF
STATUPDATE OFF;`,
fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName),
sortedColumnNames,
s3LocationFolder,
s3Location,
tempAccessKeyId,
tempSecretAccessKey,
token,
region,
manifestSQL,
)
}

Expand Down Expand Up @@ -1240,18 +1368,26 @@ func (rs *Redshift) LoadTestTable(ctx context.Context, location, tableName strin
return
}

s3Location, region := warehouseutils.GetS3Location(location)
if region == "" {
region = "us-east-1"
var s3Location, region string
if rs.config.loadByFolderPath {
s3Location, region = warehouseutils.GetS3Location(location)
if region == "" {
region = "us-east-1"
}
s3Location = warehouseutils.GetLocationFolder(s3Location)

Check warning on line 1377 in warehouse/integrations/redshift/redshift.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/redshift/redshift.go#L1373-L1377

Added lines #L1373 - L1377 were not covered by tests
} else {
s3Location, region = warehouseutils.GetS3Location(location)
if region == "" {
region = "us-east-1"
}
}
s3LocationFolder := warehouseutils.GetLocationFolder(s3Location)

var sqlStatement string
if format == warehouseutils.LoadFileTypeParquet {
// copy statement for parquet load files
sqlStatement = fmt.Sprintf(`COPY %v FROM '%s' ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' SESSION_TOKEN '%s' FORMAT PARQUET`,
fmt.Sprintf(`%q.%q`, rs.Namespace, tableName),
s3LocationFolder,
s3Location,

Check warning on line 1390 in warehouse/integrations/redshift/redshift.go

View check run for this annotation

Codecov / codecov/patch

warehouse/integrations/redshift/redshift.go#L1390

Added line #L1390 was not covered by tests
tempAccessKeyId,
tempSecretAccessKey,
token,
Expand All @@ -1261,7 +1397,7 @@ func (rs *Redshift) LoadTestTable(ctx context.Context, location, tableName strin
sqlStatement = fmt.Sprintf(`COPY %v(%v) FROM '%v' CSV GZIP ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' SESSION_TOKEN '%s' REGION '%s' DATEFORMAT 'auto' TIMEFORMAT 'auto' TRUNCATECOLUMNS EMPTYASNULL BLANKSASNULL FILLRECORD ACCEPTANYDATE TRIMBLANKS ACCEPTINVCHARS COMPUPDATE OFF STATUPDATE OFF`,
fmt.Sprintf(`%q.%q`, rs.Namespace, tableName),
fmt.Sprintf(`%q, %q`, "id", "val"),
s3LocationFolder,
s3Location,
tempAccessKeyId,
tempSecretAccessKey,
token,
Expand Down
Loading

0 comments on commit c10b877

Please sign in to comment.