Skip to content

Commit

Permalink
fix(spanner/spannertest): support queries in ExecuteSql (#3640)
Browse files Browse the repository at this point in the history
Normal queries from the Spanner client use the ExecuteStreamingSql method,
while DML statements use ExecuteSql. This distinction was also built into
spannertest where ExecuteSql would only support DML statements and required
a transaction to be specified. The session pool however uses ExecuteSql to
execute a simple `SELECT 1` query without specifying any transaction. This
would cause a nil pointer dereference.

This PR introduces support for queries in the ExecuteSql method. The current
logic assumes that the statement is a query if the transaction is a single-
use read-only transaction.

Fixes #3639
  • Loading branch information
olavloite authored Feb 4, 2021
1 parent b7c3ca6 commit 8eede84
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 26 deletions.
93 changes: 72 additions & 21 deletions spanner/spannertest/inmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,19 @@ func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.Tra
}

func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
// Assume this is probably a DML statement. Queries tend to use ExecuteStreamingSql.
// Assume this is probably a DML statement or a ping from the session pool.
// Queries normally use ExecuteStreamingSql.
// TODO: Expand this to support more things.

// If it is a single-use transaction we assume it is a query.
if req.Transaction.GetSelector() == nil || req.Transaction.GetSingleUse().GetReadOnly() != nil {
ri, err := s.executeQuery(req)
if err != nil {
return nil, err
}
return s.resultSet(ri)
}

obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id)
if !ok {
return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector)
Expand Down Expand Up @@ -527,27 +537,31 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp
}
defer cleanup()

ri, err := s.executeQuery(req)
if err != nil {
return err
}
return s.readStream(stream.Context(), tx, stream.Send, ri)
}

func (s *server) executeQuery(req *spannerpb.ExecuteSqlRequest) (ri rowIter, err error) {
q, err := spansql.ParseQuery(req.Sql)
if err != nil {
// TODO: check what code the real Spanner returns here.
return status.Errorf(codes.InvalidArgument, "bad query: %v", err)
return nil, status.Errorf(codes.InvalidArgument, "bad query: %v", err)
}

params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
if err != nil {
return err
return nil, err
}

s.logf("Querying: %s", q.SQL())
if len(params) > 0 {
s.logf(" ▹ %v", params)
}

ri, err := s.db.Query(q, params)
if err != nil {
return err
}
return s.readStream(stream.Context(), tx, stream.Send, ri)
return s.db.Query(q, params)
}

// TODO: Read
Expand Down Expand Up @@ -591,21 +605,39 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span
return s.readStream(stream.Context(), tx, stream.Send, ri)
}

func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
// Build the result set metadata.
rsm := &spannerpb.ResultSetMetadata{
RowType: &spannerpb.StructType{},
// TODO: transaction info?
func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) {
rsm, err := s.buildResultSetMetadata(ri)
if err != nil {
return nil, err
}
for _, ci := range ri.Cols() {
st, err := spannerTypeFromType(ci.Type)
if err != nil {
return err
rs := &spannerpb.ResultSet{
Metadata: rsm,
}
for {
row, err := ri.Next()
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{
Name: string(ci.Name),
Type: st,
})

values := make([]*structpb.Value, len(row))
for i, x := range row {
v, err := spannerValueFromValue(x)
if err != nil {
return nil, err
}
values[i] = v
}
rs.Rows = append(rs.Rows, &structpb.ListValue{Values: values})
}
return rs, nil
}

func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
rsm, err := s.buildResultSetMetadata(ri)
if err != nil {
return err
}

for {
Expand Down Expand Up @@ -640,6 +672,25 @@ func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spa
return nil
}

func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) {
// Build the result set metadata.
rsm := &spannerpb.ResultSetMetadata{
RowType: &spannerpb.StructType{},
// TODO: transaction info?
}
for _, ci := range ri.Cols() {
st, err := spannerTypeFromType(ci.Type)
if err != nil {
return nil, err
}
rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{
Name: string(ci.Name),
Type: st,
})
}
return rsm, nil
}

func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
//s.logf("BeginTransaction(%v)", req)

Expand Down
52 changes: 47 additions & 5 deletions spanner/spannertest/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"cloud.google.com/go/civil"
"cloud.google.com/go/spanner"
dbadmin "cloud.google.com/go/spanner/admin/database/apiv1"
v1 "cloud.google.com/go/spanner/apiv1"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/grpc"
Expand All @@ -56,7 +57,7 @@ func dbName() string {
return "projects/fake-proj/instances/fake-instance/databases/fake-db"
}

func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, func()) {
func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, *v1.Client, func()) {
// Despite the docs, this context is also used for auth,
// so it needs to be long-lived.
ctx := context.Background()
Expand All @@ -73,7 +74,13 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu
client.Close()
t.Fatalf("Connecting DB admin client: %v", err)
}
return client, adminClient, func() { client.Close(); adminClient.Close() }
gapicClient, err := v1.NewClient(ctx, dialOpt)
if err != nil {
client.Close()
adminClient.Close()
t.Fatalf("Connecting Spanner generated client: %v", err)
}
return client, adminClient, gapicClient, func() { client.Close(); adminClient.Close(); gapicClient.Close() }
}

// Don't use SPANNER_EMULATOR_HOST because we need the raw connection for
Expand Down Expand Up @@ -102,16 +109,23 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu
srv.Close()
t.Fatalf("Connecting to in-memory fake DB admin: %v", err)
}
return client, adminClient, func() {
gapicClient, err := v1.NewClient(ctx, option.WithGRPCConn(conn))
if err != nil {
srv.Close()
t.Fatalf("Connecting to in-memory fake generated Spanner client: %v", err)
}

return client, adminClient, gapicClient, func() {
client.Close()
adminClient.Close()
gapicClient.Close()
conn.Close()
srv.Close()
}
}

func TestIntegration_SpannerBasics(t *testing.T) {
client, adminClient, cleanup := makeClient(t)
client, adminClient, generatedClient, cleanup := makeClient(t)
defer cleanup()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
Expand All @@ -137,6 +151,34 @@ func TestIntegration_SpannerBasics(t *testing.T) {
}
it.Stop()

// Try to execute the equivalent of a session pool ping.
// This used to cause a panic as ExecuteSql did not expect any requests
// that would execute a query without a transaction selector.
// https://github.com/googleapis/google-cloud-go/issues/3639
s, err := generatedClient.CreateSession(ctx, &spannerpb.CreateSessionRequest{Database: dbName()})
if err != nil {
t.Fatalf("Creating session: %v", err)
}
rs, err := generatedClient.ExecuteSql(ctx, &spannerpb.ExecuteSqlRequest{
Session: s.Name,
Sql: "SELECT 1",
})
if err != nil {
t.Fatalf("Executing ping: %v", err)
}
if len(rs.Rows) != 1 {
t.Fatalf("Ping gave %v rows, want 1", len(rs.Rows))
}
if len(rs.Rows[0].Values) != 1 {
t.Fatalf("Ping gave %v cols, want 1", len(rs.Rows[0].Values))
}
if rs.Rows[0].Values[0].GetStringValue() != "1" {
t.Fatalf("Ping gave value %v, want '1'", rs.Rows[0].Values[0].GetStringValue())
}
if err = generatedClient.DeleteSession(ctx, &spannerpb.DeleteSessionRequest{Name: s.Name}); err != nil {
t.Fatalf("Deleting session: %v", err)
}

// Drop any previous test table/index, and make a fresh one in a few stages.
const tableName = "Characters"
err = updateDDL(t, adminClient, "DROP INDEX AgeIndex")
Expand Down Expand Up @@ -400,7 +442,7 @@ func TestIntegration_SpannerBasics(t *testing.T) {
}

func TestIntegration_ReadsAndQueries(t *testing.T) {
client, adminClient, cleanup := makeClient(t)
client, adminClient, _, cleanup := makeClient(t)
defer cleanup()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
Expand Down

0 comments on commit 8eede84

Please sign in to comment.