From 90855ce5a79d9a93c9b557b1c41dff2a566c3e86 Mon Sep 17 00:00:00 2001
From: David Li
Date: Sun, 26 Jan 2025 23:27:17 -0500
Subject: [PATCH] feat(go/adbc/driver/snowflake): add query tag option
This lets you identify particular queries in the query history.
Fixes #1934.
---
go/adbc/driver/snowflake/driver_test.go | 41 +++++++++++++++++++
go/adbc/driver/snowflake/statement.go | 23 +++++++++++
.../adbc_driver_snowflake/__init__.py | 3 ++
3 files changed, 67 insertions(+)
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index d014ef0ef1..151292be99 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -2253,3 +2253,44 @@ func TestJSONUnmarshal(t *testing.T) {
}
}
}
+
+func (suite *SnowflakeTests) TestQueryTag() {
+ u, err := uuid.NewV7()
+ suite.Require().NoError(err)
+ tag := u.String()
+ suite.Require().NoError(suite.stmt.SetOption(driver.OptionStatementQueryTag, tag))
+
+ val, err := suite.stmt.(adbc.GetSetOptions).GetOption(driver.OptionStatementQueryTag)
+ suite.Require().NoError(err)
+ suite.Require().Equal(tag, val)
+
+ suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT 1"))
+ rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
+ suite.Require().NoError(err)
+ defer rdr.Release()
+
+ suite.EqualValues(1, n)
+ suite.True(rdr.Next())
+ suite.False(rdr.Next())
+ suite.Require().NoError(rdr.Err())
+
+ // Unset tag
+ suite.Require().NoError(suite.stmt.SetOption(driver.OptionStatementQueryTag, ""))
+
+ suite.Require().NoError(suite.stmt.SetSqlQuery(fmt.Sprintf(`
+SELECT query_text
+FROM table(information_schema.query_history())
+WHERE query_tag = '%s'
+ORDER BY start_time;
+`, tag)))
+ rdr, n, err = suite.stmt.ExecuteQuery(suite.ctx)
+ suite.Require().NoError(err)
+ defer rdr.Release()
+
+ suite.EqualValues(1, n)
+ suite.True(rdr.Next())
+ result := rdr.Record()
+ suite.Require().Equal("SELECT 1", result.Column(0).(*array.String).Value(0))
+ suite.False(rdr.Next())
+ suite.Require().NoError(rdr.Err())
+}
diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go
index 1665fa7782..72b095677c 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -33,6 +33,7 @@ import (
)
const (
+ OptionStatementQueryTag = "adbc.snowflake.statement.query_tag"
OptionStatementQueueSize = "adbc.rpc.result_queue_size"
OptionStatementPrefetchConcurrency = "adbc.snowflake.rpc.prefetch_concurrency"
OptionStatementIngestWriterConcurrency = "adbc.snowflake.statement.ingest_writer_concurrency"
@@ -54,11 +55,20 @@ type statement struct {
targetTable string
ingestMode string
ingestOptions *ingestOptions
+ queryTag string
bound arrow.Record
streamBind array.RecordReader
}
+// setQueryContext applies the query tag if present.
+func (st *statement) setQueryContext(ctx context.Context) context.Context {
+ if st.queryTag != "" {
+ ctx = gosnowflake.WithQueryTag(ctx, st.queryTag)
+ }
+ return ctx
+}
+
// Close releases any relevant resources associated with this statement
// and closes it (particularly if it is a prepared statement).
//
@@ -82,6 +92,10 @@ func (st *statement) Close() error {
}
func (st *statement) GetOption(key string) (string, error) {
+ switch key {
+ case OptionStatementQueryTag:
+ return st.queryTag, nil
+ }
return "", adbc.Error{
Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key),
Code: adbc.StatusNotFound,
@@ -186,6 +200,9 @@ func (st *statement) SetOption(key string, val string) error {
}
}
return st.SetOptionInt(key, int64(size))
+ case OptionStatementQueryTag:
+ st.queryTag = val
+ return nil
case OptionUseHighPrecision:
switch val {
case adbc.OptionValueEnabled:
@@ -449,6 +466,8 @@ func (st *statement) executeIngest(ctx context.Context) (int64, error) {
//
// This invalidates any prior result sets on this statement.
func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int64, error) {
+ ctx = st.setQueryContext(ctx)
+
if st.targetTable != "" {
n, err := st.executeIngest(ctx)
return nil, n, err
@@ -500,6 +519,8 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6
// ExecuteUpdate executes a statement that does not generate a result
// set. It returns the number of rows affected if known, otherwise -1.
func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
+ ctx = st.setQueryContext(ctx)
+
if st.targetTable != "" {
return st.executeIngest(ctx)
}
@@ -558,6 +579,8 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
// ExecuteSchema gets the schema of the result set of a query without executing it.
func (st *statement) ExecuteSchema(ctx context.Context) (*arrow.Schema, error) {
+ ctx = st.setQueryContext(ctx)
+
if st.targetTable != "" {
return nil, adbc.Error{
Msg: "cannot execute schema for ingestion",
diff --git a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
index 9c2676cead..1e6420ced5 100644
--- a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
+++ b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
@@ -112,6 +112,9 @@ class StatementOptions(enum.Enum):
#: Number of concurrent streams being prefetched for a result set.
#: Defaults to 10.
PREFETCH_CONCURRENCY = "adbc.snowflake.rpc.prefetch_concurrency"
+ #: An identifier for a query/queries that can be used to find the query in
+ #: the query history. Use a blank string to unset the tag.
+ QUERY_TAG = "adbc.snowflake.statement.query_tag"
#: Number of parquet files to write in parallel for bulk ingestion
#: Defaults to NumCPU
INGEST_WRITER_CONCURRENCY = "adbc.snowflake.statement.ingest_writer_concurrency"