diff --git a/assert_test.go b/assert_test.go index b49213691..91a69c3d5 100644 --- a/assert_test.go +++ b/assert_test.go @@ -58,6 +58,10 @@ func assertBetweenE(t *testing.T, value float64, min float64, max float64, descr errorOnNonEmpty(t, validateValueBetween(value, min, max, descriptions...)) } +func assertEmptyE[T any](t *testing.T, actual []T, descriptions ...string) { + errorOnNonEmpty(t, validateEmpty(actual, descriptions...)) +} + func fatalOnNonEmpty(t *testing.T, errMsg string) { if errMsg != "" { t.Fatal(formatErrorMessage(errMsg)) @@ -122,6 +126,14 @@ func validateValueBetween(value float64, min float64, max float64, descriptions return fmt.Sprintf("expected \"%f\" should be between \"%f\" and \"%f\" but did not. %s", value, min, max, desc) } +func validateEmpty[T any](value []T, descriptions ...string) string { + if len(value) == 0 { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%v\" to be empty. %s", value, desc) +} + func joinDescriptions(descriptions ...string) string { return strings.Join(descriptions, " ") } diff --git a/chunk_downloader.go b/chunk_downloader.go index b68f9ece0..a265f2fa6 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -249,7 +249,7 @@ func (scd *snowflakeChunkDownloader) getRowType() []execResponseRowType { } func (scd *snowflakeChunkDownloader) getArrowBatches() []*ArrowBatch { - if scd.FirstBatch.rec == nil { + if scd.FirstBatch == nil || scd.FirstBatch.rec == nil { return scd.ArrowBatches } return append([]*ArrowBatch{scd.FirstBatch}, scd.ArrowBatches...) diff --git a/chunk_downloader_test.go b/chunk_downloader_test.go index b63ef5736..ea8192d77 100644 --- a/chunk_downloader_test.go +++ b/chunk_downloader_test.go @@ -41,7 +41,7 @@ func TestWithArrowBatchesWhenQueryReturnsNoRowsWhenUsingNativeGoSQLInterface(t * }) } -func TestWithArrowBatchesWhenQueryReturnsNoRows(t *testing.T) { +func TestWithArrowBatchesWhenQueryReturnsRowsAndReadingRows(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1") defer rows.Close() @@ -49,6 +49,30 @@ func TestWithArrowBatchesWhenQueryReturnsNoRows(t *testing.T) { }) } +func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingRows(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0") + defer rows.Close() + assertFalseF(t, rows.Next()) + }) +} + +func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingArrowBatches(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + var rows driver.Rows + var err error + err = dbt.conn.Raw(func(x any) error { + rows, err = x.(driver.QueryerContext).QueryContext(WithArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0", nil) + return err + }) + assertNilF(t, err) + defer rows.Close() + batches, err := rows.(SnowflakeRows).GetArrowBatches() + assertNilF(t, err) + assertEmptyE(t, batches) + }) +} + func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormatUsingNativeGoSQLInterface(t *testing.T) { for _, tc := range []struct { useJSON bool