From 90855ce5a79d9a93c9b557b1c41dff2a566c3e86 Mon Sep 17 00:00:00 2001 From: David Li Date: Sun, 26 Jan 2025 23:27:17 -0500 Subject: [PATCH] feat(go/adbc/driver/snowflake): add query tag option This lets you identify particular queries in the query history. Fixes #1934. --- go/adbc/driver/snowflake/driver_test.go | 41 +++++++++++++++++++ go/adbc/driver/snowflake/statement.go | 23 +++++++++++ .../adbc_driver_snowflake/__init__.py | 3 ++ 3 files changed, 67 insertions(+) diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index d014ef0ef1..151292be99 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -2253,3 +2253,44 @@ func TestJSONUnmarshal(t *testing.T) { } } } + +func (suite *SnowflakeTests) TestQueryTag() { + u, err := uuid.NewV7() + suite.Require().NoError(err) + tag := u.String() + suite.Require().NoError(suite.stmt.SetOption(driver.OptionStatementQueryTag, tag)) + + val, err := suite.stmt.(adbc.GetSetOptions).GetOption(driver.OptionStatementQueryTag) + suite.Require().NoError(err) + suite.Require().Equal(tag, val) + + suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT 1")) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(1, n) + suite.True(rdr.Next()) + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) + + // Unset tag + suite.Require().NoError(suite.stmt.SetOption(driver.OptionStatementQueryTag, "")) + + suite.Require().NoError(suite.stmt.SetSqlQuery(fmt.Sprintf(` +SELECT query_text +FROM table(information_schema.query_history()) +WHERE query_tag = '%s' +ORDER BY start_time; +`, tag))) + rdr, n, err = suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(1, n) + suite.True(rdr.Next()) + result := rdr.Record() + suite.Require().Equal("SELECT 1", result.Column(0).(*array.String).Value(0)) + suite.False(rdr.Next()) + suite.Require().NoError(rdr.Err()) +} diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index 1665fa7782..72b095677c 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -33,6 +33,7 @@ import ( ) const ( + OptionStatementQueryTag = "adbc.snowflake.statement.query_tag" OptionStatementQueueSize = "adbc.rpc.result_queue_size" OptionStatementPrefetchConcurrency = "adbc.snowflake.rpc.prefetch_concurrency" OptionStatementIngestWriterConcurrency = "adbc.snowflake.statement.ingest_writer_concurrency" @@ -54,11 +55,20 @@ type statement struct { targetTable string ingestMode string ingestOptions *ingestOptions + queryTag string bound arrow.Record streamBind array.RecordReader } +// setQueryContext applies the query tag if present. +func (st *statement) setQueryContext(ctx context.Context) context.Context { + if st.queryTag != "" { + ctx = gosnowflake.WithQueryTag(ctx, st.queryTag) + } + return ctx +} + // Close releases any relevant resources associated with this statement // and closes it (particularly if it is a prepared statement). // @@ -82,6 +92,10 @@ func (st *statement) Close() error { } func (st *statement) GetOption(key string) (string, error) { + switch key { + case OptionStatementQueryTag: + return st.queryTag, nil + } return "", adbc.Error{ Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), Code: adbc.StatusNotFound, @@ -186,6 +200,9 @@ func (st *statement) SetOption(key string, val string) error { } } return st.SetOptionInt(key, int64(size)) + case OptionStatementQueryTag: + st.queryTag = val + return nil case OptionUseHighPrecision: switch val { case adbc.OptionValueEnabled: @@ -449,6 +466,8 @@ func (st *statement) executeIngest(ctx context.Context) (int64, error) { // // This invalidates any prior result sets on this statement. func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int64, error) { + ctx = st.setQueryContext(ctx) + if st.targetTable != "" { n, err := st.executeIngest(ctx) return nil, n, err @@ -500,6 +519,8 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6 // ExecuteUpdate executes a statement that does not generate a result // set. It returns the number of rows affected if known, otherwise -1. func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) { + ctx = st.setQueryContext(ctx) + if st.targetTable != "" { return st.executeIngest(ctx) } @@ -558,6 +579,8 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) { // ExecuteSchema gets the schema of the result set of a query without executing it. func (st *statement) ExecuteSchema(ctx context.Context) (*arrow.Schema, error) { + ctx = st.setQueryContext(ctx) + if st.targetTable != "" { return nil, adbc.Error{ Msg: "cannot execute schema for ingestion", diff --git a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py index 9c2676cead..1e6420ced5 100644 --- a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py +++ b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py @@ -112,6 +112,9 @@ class StatementOptions(enum.Enum): #: Number of concurrent streams being prefetched for a result set. #: Defaults to 10. PREFETCH_CONCURRENCY = "adbc.snowflake.rpc.prefetch_concurrency" + #: An identifier for a query/queries that can be used to find the query in + #: the query history. Use a blank string to unset the tag. + QUERY_TAG = "adbc.snowflake.statement.query_tag" #: Number of parquet files to write in parallel for bulk ingestion #: Defaults to NumCPU INGEST_WRITER_CONCURRENCY = "adbc.snowflake.statement.ingest_writer_concurrency"