Skip to content

Commit

Permalink
Support TVF (#209)
Browse files Browse the repository at this point in the history
* Support TVF

* Update testdata

* Do go test --update

* Fix indentation

* Rename to update_with_safe_ml_predict.sql

* Fix doc comment

* Simplify implementation
  • Loading branch information
apstndb authored Dec 18, 2024
1 parent 438cf9d commit e0d66c0
Show file tree
Hide file tree
Showing 19 changed files with 2,031 additions and 0 deletions.
56 changes: 56 additions & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func (PathTableExpr) isTableExpr() {}
func (SubQueryTableExpr) isTableExpr() {}
func (ParenTableExpr) isTableExpr() {}
func (Join) isTableExpr() {}
func (TVFCallExpr) isTableExpr() {}

// JoinCondition represents condition part of JOIN expression.
type JoinCondition interface {
Expand Down Expand Up @@ -221,6 +222,15 @@ func (IntervalArg) isArg() {}
func (SequenceArg) isArg() {}
func (LambdaArg) isArg() {}

// TVFArg represents argument of table-valued function call.
// It is implemented by ExprArg, ModelArg and TableArg.
type TVFArg interface {
Node
isTVFArg()
}

// Marker methods that make these argument nodes usable as TVFArg.
func (ExprArg) isTVFArg() {}
func (ModelArg) isTVFArg() {}
func (TableArg) isTVFArg() {}

// NullHandlingModifier represents IGNORE/RESPECT NULLS of aggregate function calls
type NullHandlingModifier interface {
Node
Expand Down Expand Up @@ -1162,6 +1172,28 @@ type CallExpr struct {
Hint *Hint // optional
}

// TVFCallExpr is table-valued function call expression node.
// It is a table expression, e.g. ML.PREDICT(MODEL m, TABLE t) in a FROM clause.
//
// {{.Name | sql}}(
// {{.Args | sqlJoin ", "}}
// {{if len(.Args) > 0 && len(.NamedArgs) > 0}}, {{end}}
// {{.NamedArgs | sqlJoin ", "}}
// )
// {{.Hint | sqlOpt}}
// {{.Sample | sqlOpt}}
type TVFCallExpr struct {
// pos = Name.pos
// end = (Sample ?? Hint).end || Rparen + 1

Rparen token.Pos // position of ")"

Name *Path // function name, possibly a dotted path such as ML.PREDICT
Args []TVFArg // positional arguments, written before NamedArgs
NamedArgs []*NamedArg // arguments in "name => value" form
Hint *Hint // optional
Sample *TableSample // optional
}

// ExprArg is argument of the generic function call.
//
// {{.Expr | sql}}
Expand Down Expand Up @@ -1212,6 +1244,30 @@ type LambdaArg struct {
Expr Expr
}

// ModelArg is MODEL model_name argument of table-valued function call.
//
// MODEL {{.Name | sql}}
type ModelArg struct {
// pos = Model
// end = Name.end

Model token.Pos // position of "MODEL" keyword

Name *Path // path of the model that follows the MODEL keyword
}

// TableArg is TABLE table_name argument of table-valued function call.
//
// TABLE {{.Name | sql}}
type TableArg struct {
// pos = Table
// end = Name.end

Table token.Pos // position of "TABLE" keyword

Name *Path // path of the table that follows the TABLE keyword
}

// NamedArg represents a name and value pair in named arguments
//
// {{.Name | sql}} => {{.Value | sql}}
Expand Down
24 changes: 24 additions & 0 deletions ast/pos.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions ast/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,15 @@ func (l *LambdaArg) SQL() string {
l.Expr.SQL()
}

// SQL returns the SQL text of the table-valued function call.
//
// The output is the function name, then the positional arguments followed by
// the named arguments inside parentheses (a ", " separator is inserted between
// the two groups only when both are non-empty), then the optional Hint and the
// optional Sample, each preceded by a single space.
//
// Sample is part of the node's documented SQL template and is filled in by
// the parser, so it must be emitted here as well; otherwise printing a parsed
// call silently drops its sample clause.
func (c *TVFCallExpr) SQL() string {
	return c.Name.SQL() + "(" +
		sqlJoin(c.Args, ", ") +
		strOpt(len(c.Args) > 0 && len(c.NamedArgs) > 0, ", ") +
		sqlJoin(c.NamedArgs, ", ") +
		")" +
		sqlOpt(" ", c.Hint, "") +
		sqlOpt(" ", c.Sample, "")
}

// SQL renders the named argument as its name, " => ", and its value.
func (n *NamedArg) SQL() string {
	name, value := n.Name.SQL(), n.Value.SQL()
	return name + " => " + value
}

// SQL returns the fixed text of the IGNORE NULLS modifier.
func (*IgnoreNulls) SQL() string {
	return "IGNORE NULLS"
}
Expand All @@ -542,6 +551,14 @@ func (s *SequenceArg) SQL() string {
return "SEQUENCE " + s.Expr.SQL()
}

// SQL renders the argument as the MODEL keyword followed by the model path.
func (m *ModelArg) SQL() string {
	name := m.Name.SQL()
	return "MODEL " + name
}

// SQL renders the argument as the TABLE keyword followed by the table path.
func (t *TableArg) SQL() string {
	name := t.Name.SQL()
	return "TABLE " + name
}

// SQL returns the fixed text of the COUNT(*) expression.
func (*CountStarExpr) SQL() string {
	const text = "COUNT(*)"
	return text
}
Expand Down
60 changes: 60 additions & 0 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,9 @@ func (p *Parser) parseSimpleTableExpr() ast.TableExpr {

if p.Token.Kind == token.TokenIdent {
ids := p.parseIdentOrPath()
if p.Token.Kind == "(" {
return p.parseTVFCallExpr(ids)
}
if len(ids) == 1 {
return p.parseTableNameSuffix(ids[0])
}
Expand All @@ -826,6 +829,63 @@ func (p *Parser) parseSimpleTableExpr() ast.TableExpr {
panic(p.errorfAtToken(&p.Token, "expected token: (, UNNEST, <ident>, but: %s", p.Token.Kind))
}

// parseTVFCallExpr parses the argument list and trailing modifiers of a
// table-valued function call whose name (ids) has already been consumed.
// The current token must be "(".
//
// Positional arguments come first; once a named argument is seen, the rest of
// the list is parsed as named arguments. After ")" an optional hint and an
// optional table sample are parsed.
func (p *Parser) parseTVFCallExpr(ids []*ast.Ident) *ast.TVFCallExpr {
	p.expect("(")

	var positional []ast.TVFArg
	// The ")" check happens only once, before the first argument, so a comma
	// must always be followed by another argument.
	if p.Token.Kind != ")" {
		for {
			if p.lookaheadNamedArg() {
				break
			}
			positional = append(positional, p.parseTVFArg())
			if p.Token.Kind != "," {
				break
			}
			p.nextToken()
		}
	}

	var named []*ast.NamedArg
	if p.lookaheadNamedArg() {
		named = parseCommaSeparatedList(p, p.parseNamedArg)
	}

	call := &ast.TVFCallExpr{
		Name:      &ast.Path{Idents: ids},
		Args:      positional,
		NamedArgs: named,
	}
	// These three must be parsed in source order: ")", hint, sample.
	call.Rparen = p.expect(")").Pos
	call.Hint = p.tryParseHint()
	call.Sample = p.tryParseTableSample()
	return call
}

// parseTVFArg parses one positional argument of a table-valued function call.
//
// A token that is keyword-like TABLE or MODEL starts a TableArg or ModelArg
// respectively, consisting of that keyword followed by a path; anything else
// is parsed as an ordinary expression argument.
func (p *Parser) parseTVFArg() ast.TVFArg {
	// Remember where the argument starts; it is the keyword position for the
	// TABLE and MODEL forms.
	keyword := p.Token.Pos

	if p.Token.IsKeywordLike("TABLE") {
		p.nextToken()
		return &ast.TableArg{
			Table: keyword,
			Name:  p.parsePath(),
		}
	}

	if p.Token.IsKeywordLike("MODEL") {
		p.nextToken()
		return &ast.ModelArg{
			Model: keyword,
			Name:  p.parsePath(),
		}
	}

	return p.parseExprArg()
}

func (p *Parser) parseIdentOrPath() []*ast.Ident {
ids := []*ast.Ident{p.parseIdent()}
for p.Token.Kind == "." {
Expand Down
12 changes: 12 additions & 0 deletions testdata/input/dml/update_with_safe_ml_predict.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- https://cloud.google.com/spanner/docs/backfill-embeddings?hl=en#backfill
UPDATE products
SET
products.desc_embed = (
SELECT embeddings.values
FROM SAFE.ML.PREDICT(
MODEL gecko_model,
(SELECT products.description AS content)
) @{remote_udf_max_rows_per_rpc=200}
),
products.desc_embed_model_version = 3
WHERE products.desc_embed IS NULL
6 changes: 6 additions & 0 deletions testdata/input/query/select_from_change_stream.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT ChangeRecord FROM READ_SingersNameStream (
start_timestamp => "2022-05-01T09:00:00Z",
end_timestamp => NULL,
partition_token => NULL,
heartbeat_milliseconds => 10000
)
7 changes: 7 additions & 0 deletions testdata/input/query/select_from_ml_predict_hint.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- https://cloud.google.com/spanner/docs/ml-tutorial-generative-ai?hl=en#register_a_generative_ai_model_in_a_schema
SELECT content
FROM ML.PREDICT(
MODEL TextBison,
(SELECT "Is 13 prime?" AS prompt),
STRUCT(256 AS maxOutputTokens, 0.2 AS temperature, 40 as topK, 0.95 AS topP)
) @{remote_udf_max_rows_per_rpc=1}
2 changes: 2 additions & 0 deletions testdata/input/query/select_from_ml_predict_simple.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT id, color, value
FROM ML.PREDICT(MODEL DiamondAppraise, TABLE Diamonds)
15 changes: 15 additions & 0 deletions testdata/input/query/select_from_ml_predict_textbison.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
SELECT product_id, product_name, content
FROM ML.PREDICT(
MODEL TextBison,
(SELECT
product.id as product_id,
product.name as product_name,
CONCAT("Is this product safe for infants?", "\n",
"Product Name: ", product.name, "\n",
"Category Name: ", category.name, "\n",
"Product Description:", product.description) AS prompt
FROM
Products AS product JOIN Categories AS category
ON product.category_id = category.id),
STRUCT(100 AS maxOutputTokens)
) @{remote_udf_max_rows_per_rpc=1}
Loading

0 comments on commit e0d66c0

Please sign in to comment.