Skip to content

Commit

Permalink
move wrapTriggerRollback logic (#2720)
Browse files Browse the repository at this point in the history
  • Loading branch information
jycor authored Oct 28, 2024
1 parent a6973b5 commit 2c6e3dd
Show file tree
Hide file tree
Showing 14 changed files with 2,781 additions and 2,944 deletions.
7 changes: 4 additions & 3 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ func (e *Engine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlpar

// PrepQueryPlanForExecution prepares a query plan for execution and returns the result schema with a row iterator to
// begin spooling results
func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.Node, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
// Give the integrator a chance to reject the session before proceeding
// TODO: this check doesn't belong here
err := ctx.Session.ValidateSession(ctx)
Expand Down Expand Up @@ -482,9 +482,9 @@ func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.
return nil, nil, nil, err
}

iter = finalizeIters(ctx, plan, nil, iter)
iter = finalizeIters(ctx, plan, qFlags, iter)

return plan.Schema(), iter, nil, nil
return plan.Schema(), iter, qFlags, nil
}

// BoundQueryPlan returns query plan for the given statement with the given bindings applied
Expand Down Expand Up @@ -866,6 +866,7 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {

// finalizeIters applies the final transformations on sql.RowIter before execution.
func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
iter = rowexec.AddTriggerRollbackIter(ctx, qFlags, iter)
iter = rowexec.AddTransactionCommittingIter(qFlags, iter)
iter = plan.AddTrackedRowIter(ctx, analyzed, iter)
iter = rowexec.AddExpressionCloser(analyzed, iter)
Expand Down
5,471 changes: 2,728 additions & 2,743 deletions enginetest/queries/integration_plans.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,9 +1082,9 @@ func (h *Handler) executeBoundPlan(
_ sqlparser.Statement,
plan sql.Node,
_ map[string]*querypb.BindVariable,
_ *sql.QueryFlags,
qFlags *sql.QueryFlags,
) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
return h.e.PrepQueryPlanForExecution(ctx, query, plan)
return h.e.PrepQueryPlanForExecution(ctx, query, plan, qFlags)
}

func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sqlparser.Expr, error) {
Expand Down
1 change: 0 additions & 1 deletion sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ const (
assignRoutinesId // assignRoutines
modifyUpdateExprsForJoinId // modifyUpdateExprsForJoin
applyUpdateAccumulatorsId // applyUpdateAccumulators
wrapWithRollbackId // wrapWithRollback
applyForeignKeysId // applyForeignKeys

// validate
Expand Down
35 changes: 17 additions & 18 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ func init() {
{applyTriggersId, applyTriggers},
{applyProceduresId, applyProcedures},
{applyUpdateAccumulatorsId, applyUpdateAccumulators},
{wrapWithRollbackId, wrapWithRollback},
{inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs},
{cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins},
{BacktickDefaulColumnValueNamesId, backtickDefaultColumnValueNames},
Expand Down
41 changes: 3 additions & 38 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope

switch n := c.Node.(type) {
case *plan.InsertInto:
qFlags.Set(sql.QFlagTrigger)
if trigger.TriggerTime == sqlparser.BeforeStr {
triggerExecutor := plan.NewTriggerExecutor(n.Source, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
Expand All @@ -359,6 +360,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
}), transform.NewTree, nil
}
case *plan.Update:
qFlags.Set(sql.QFlagTrigger)
if trigger.TriggerTime == sqlparser.BeforeStr {
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
Expand Down Expand Up @@ -387,6 +389,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
"does not support triggers; retry with single table deletes")
}

qFlags.Set(sql.QFlagTrigger)
if trigger.TriggerTime == sqlparser.BeforeStr {
triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
Expand Down Expand Up @@ -517,41 +520,3 @@ func orderTriggersAndReverseAfter(triggers []*plan.CreateTrigger) []*plan.Create
func triggerEventsMatch(event plan.TriggerEvent, event2 string) bool {
return strings.ToLower((string)(event)) == strings.ToLower(event2)
}

// wrapWithRollback wraps the entire tree iff it contains a trigger, allowing rollback when a trigger errors
func wrapWithRollback(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
// Check if tree contains a TriggerExecutor
containsTrigger := false
transform.Inspect(n, func(n sql.Node) bool {
// After Triggers wrap nodes
if _, ok := n.(*plan.TriggerExecutor); ok {
containsTrigger = true
return false // done, don't bother to recurse
}

// Before Triggers on Inserts are inside Source
if n, ok := n.(*plan.InsertInto); ok {
if _, ok := n.Source.(*plan.TriggerExecutor); ok {
containsTrigger = true
return false
}
}

// Before Triggers on Delete and Update should be in children
return true
})

// No TriggerExecutor, so return same tree
if !containsTrigger {
return n, transform.SameTree, nil
}

// If we don't have a transaction session we can't do rollbacks
_, ok := ctx.Session.(sql.TransactionSession)
if !ok {
return plan.NewNoopTriggerRollback(n), transform.NewTree, nil
}

// Wrap tree with new node
return plan.NewTriggerRollback(n), transform.NewTree, nil
}
100 changes: 0 additions & 100 deletions sql/plan/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,103 +103,3 @@ func (t *TriggerExecutor) CheckPrivileges(ctx *sql.Context, opChecker sql.Privil
func (t *TriggerExecutor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.left)
}

// TriggerRollback is a node that wraps the entire tree iff it contains a trigger, creates a savepoint, and performs a
// rollback if something went wrong during execution
type TriggerRollback struct {
UnaryNode
}

var _ sql.Node = (*TriggerRollback)(nil)
var _ sql.CollationCoercible = (*TriggerRollback)(nil)

func NewTriggerRollback(child sql.Node) *TriggerRollback {
return &TriggerRollback{
UnaryNode: UnaryNode{Child: child},
}
}

func (t *TriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
}

return NewTriggerRollback(children[0]), nil
}

// CheckPrivileges implements the interface sql.Node.
func (t *TriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return t.Child.CheckPrivileges(ctx, opChecker)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (t *TriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.Child)
}

func (t *TriggerRollback) IsReadOnly() bool {
return t.Child.IsReadOnly()
}

func (t *TriggerRollback) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback()")
_ = pr.WriteChildren(t.Child.String())
return pr.String()
}

func (t *TriggerRollback) DebugString() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback")
_ = pr.WriteChildren(sql.DebugString(t.Child))
return pr.String()
}

type NoopTriggerRollback struct {
UnaryNode
}

var _ sql.Node = (*NoopTriggerRollback)(nil)
var _ sql.CollationCoercible = (*NoopTriggerRollback)(nil)

func NewNoopTriggerRollback(child sql.Node) *NoopTriggerRollback {
return &NoopTriggerRollback{
UnaryNode: UnaryNode{Child: child},
}
}

func (t *NoopTriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
}

return NewNoopTriggerRollback(children[0]), nil
}

// CheckPrivileges implements the interface sql.Node.
func (t *NoopTriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return t.Child.CheckPrivileges(ctx, opChecker)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (t *NoopTriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, t.Child)
}

func (t *NoopTriggerRollback) IsReadOnly() bool {
return true
}

func (t *NoopTriggerRollback) String() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback()")
_ = pr.WriteChildren(t.Child.String())
return pr.String()
}

func (t *NoopTriggerRollback) DebugString() string {
pr := sql.NewTreePrinter()
_ = pr.WriteNode("TriggerRollback")
_ = pr.WriteChildren(sql.DebugString(t.Child))
return pr.String()
}
4 changes: 4 additions & 0 deletions sql/query_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const (
QFlagDeferProjections
// QFlagUndeferrableExprs indicates that the query has expressions that cannot be deferred
QFlagUndeferrableExprs
QFlagTrigger
)

type QueryFlags struct {
Expand All @@ -69,6 +70,9 @@ func (qp *QueryFlags) Unset(flag int) {
}

func (qp *QueryFlags) IsSet(flag int) bool {
if qp == nil {
return false
}
return qp.Flags.Contains(flag)
}

Expand Down
3 changes: 1 addition & 2 deletions sql/rowexec/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ type ExecBuilderFunc func(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter,
// sql.ExecSourceRel are also built into the tree.
type BaseBuilder struct {
// if override is provided, we try to build executor with this first
override sql.NodeExecBuilder
triggerSavePointCounter int // tracks the number of save points that have been created by triggers
override sql.NodeExecBuilder
}

func (b *BaseBuilder) Build(ctx *sql.Context, n sql.Node, r sql.Row) (sql.RowIter, error) {
Expand Down
26 changes: 0 additions & 26 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,32 +256,6 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sq
return rowIterWithOkResultWithZeroRowsAffected(), nil
}

func (b *BaseBuilder) buildTriggerRollback(ctx *sql.Context, n *plan.TriggerRollback, row sql.Row) (sql.RowIter, error) {
childIter, err := b.buildNodeExec(ctx, n.Child, row)
if err != nil {
return nil, err
}

savePointCounter := b.triggerSavePointCounter + 1
savePointName := fmt.Sprintf("%s%v", TriggerSavePointPrefix, savePointCounter)
ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", savePointName)

ts, ok := ctx.Session.(sql.TransactionSession)
if !ok {
return nil, fmt.Errorf("expected a sql.TransactionSession, but got %T", ctx.Session)
}

if err := ts.CreateSavepoint(ctx, ctx.GetTransaction(), savePointName); err != nil {
ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed")
}
b.triggerSavePointCounter = savePointCounter

return &triggerRollbackIter{
child: childIter,
savePointName: savePointName,
}, nil
}

func (b *BaseBuilder) buildAlterIndex(ctx *sql.Context, n *plan.AlterIndex, row sql.Row) (sql.RowIter, error) {
err := b.executeAlterIndex(ctx, n)
if err != nil {
Expand Down
23 changes: 22 additions & 1 deletion sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Dolthub, Inc.
// Copyright 2023-2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,27 @@ type triggerRollbackIter struct {
savePointName string
}

func AddTriggerRollbackIter(ctx *sql.Context, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
if !qFlags.IsSet(sql.QFlagTrigger) {
return iter
}

transSess, isTransSess := ctx.Session.(sql.TransactionSession)
if !isTransSess {
return iter
}

ctx.GetLogger().Tracef("TriggerRollback creating savepoint: %s", TriggerSavePointPrefix)
if err := transSess.CreateSavepoint(ctx, ctx.GetTransaction(), TriggerSavePointPrefix); err != nil {
ctx.GetLogger().WithError(err).Errorf("CreateSavepoint failed")
}

return &triggerRollbackIter{
child: iter,
savePointName: TriggerSavePointPrefix,
}
}

func (t *triggerRollbackIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
childRow, err := t.child.Next(ctx)

Expand Down
4 changes: 0 additions & 4 deletions sql/rowexec/node_builder.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s
return b.buildHaving(ctx, n, row)
case *plan.Signal:
return b.buildSignal(ctx, n, row)
case *plan.TriggerRollback:
return b.buildTriggerRollback(ctx, n, row)
case *plan.ExternalProcedure:
return b.buildExternalProcedure(ctx, n, row)
case *plan.Into:
Expand Down Expand Up @@ -246,8 +244,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s
return b.buildCreateIndex(ctx, n, row)
case *plan.Procedure:
return b.buildProcedure(ctx, n, row)
case *plan.NoopTriggerRollback:
return b.buildNoopTriggerRollback(ctx, n, row)
case *plan.With:
return b.buildWith(ctx, n, row)
case *plan.Project:
Expand Down
5 changes: 0 additions & 5 deletions sql/rowexec/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,6 @@ func (b *BaseBuilder) buildCommit(ctx *sql.Context, n *plan.Commit, row sql.Row)
return sql.RowsToRowIter(), nil
}

func (b *BaseBuilder) buildNoopTriggerRollback(ctx *sql.Context, n *plan.NoopTriggerRollback, row sql.Row) (sql.RowIter, error) {
return b.buildNodeExec(ctx, n.Child, row)

}

func (b *BaseBuilder) buildKill(ctx *sql.Context, n *plan.Kill, row sql.Row) (sql.RowIter, error) {
return &lazyRowIter{
func(ctx *sql.Context) (sql.Row, error) {
Expand Down

0 comments on commit 2c6e3dd

Please sign in to comment.