Skip to content

Commit

Permalink
SNOW-1016278 Fix panic on empty arrow batches
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Jan 25, 2024
1 parent 611fe9d commit 1b0b843
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
12 changes: 12 additions & 0 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func assertBetweenE(t *testing.T, value float64, min float64, max float64, descr
errorOnNonEmpty(t, validateValueBetween(value, min, max, descriptions...))
}

func assertEmptyE[T any](t *testing.T, actual []T, descriptions ...string) {
errorOnNonEmpty(t, validateEmpty(actual, descriptions...))
}

func fatalOnNonEmpty(t *testing.T, errMsg string) {
if errMsg != "" {
t.Fatal(formatErrorMessage(errMsg))
Expand Down Expand Up @@ -122,6 +126,14 @@ func validateValueBetween(value float64, min float64, max float64, descriptions
return fmt.Sprintf("expected \"%f\" should be between \"%f\" and \"%f\" but did not. %s", value, min, max, desc)
}

func validateEmpty[T any](value []T, descriptions ...string) string {
if len(value) == 0 {
return ""
}
desc := joinDescriptions(descriptions...)
return fmt.Sprintf("expected \"%v\" to be empty. %s", value, desc)
}

func joinDescriptions(descriptions ...string) string {
return strings.Join(descriptions, " ")
}
Expand Down
2 changes: 1 addition & 1 deletion chunk_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (scd *snowflakeChunkDownloader) getRowType() []execResponseRowType {
}

func (scd *snowflakeChunkDownloader) getArrowBatches() []*ArrowBatch {
if scd.FirstBatch.rec == nil {
if scd.FirstBatch == nil || scd.FirstBatch.rec == nil {
return scd.ArrowBatches
}
return append([]*ArrowBatch{scd.FirstBatch}, scd.ArrowBatches...)
Expand Down
26 changes: 25 additions & 1 deletion chunk_downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,38 @@ func TestWithArrowBatchesWhenQueryReturnsNoRowsWhenUsingNativeGoSQLInterface(t *
})
}

func TestWithArrowBatchesWhenQueryReturnsNoRows(t *testing.T) {
func TestWithArrowBatchesWhenQueryReturnsRowsAndReadingRows(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1")
defer rows.Close()
assertFalseF(t, rows.Next())
})
}

func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingRows(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0")
defer rows.Close()
assertFalseF(t, rows.Next())
})
}

func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingArrowBatches(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
var rows driver.Rows
var err error
err = dbt.conn.Raw(func(x any) error {
rows, err = x.(driver.QueryerContext).QueryContext(WithArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0", nil)
return err
})
assertNilF(t, err)
defer rows.Close()
batches, err := rows.(SnowflakeRows).GetArrowBatches()
assertNilF(t, err)
assertEmptyE(t, batches)
})
}

func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormatUsingNativeGoSQLInterface(t *testing.T) {
for _, tc := range []struct {
useJSON bool
Expand Down

0 comments on commit 1b0b843

Please sign in to comment.