diff --git a/pkg/ccl/logictestccl/testdata/logic_test/triggers b/pkg/ccl/logictestccl/testdata/logic_test/triggers index c7cd6a8ed8e9..7e5ef35e580b 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/triggers +++ b/pkg/ccl/logictestccl/testdata/logic_test/triggers @@ -3642,6 +3642,335 @@ statement ok DROP TABLE t; DROP FUNCTION g; +# ============================================================================== +# Test restrictions on multiple mutations to the same table. +# ============================================================================== + +subtest multiple_mutations + +statement ok +CREATE TABLE t1 (a INT, b INT); +CREATE TABLE t2 (a INT, b INT); +CREATE TABLE parent (k INT PRIMARY KEY); +CREATE TABLE child (k INT PRIMARY KEY, v INT REFERENCES parent(k) ON DELETE CASCADE); + +statement ok +CREATE FUNCTION g() RETURNS TRIGGER LANGUAGE PLpgSQL AS $$ + BEGIN + INSERT INTO t2 VALUES (100, 200); + RETURN COALESCE(NEW, OLD); + END +$$; + +statement ok +CREATE FUNCTION h() RETURNS TRIGGER LANGUAGE PLpgSQL AS $$ + BEGIN + UPDATE t2 SET b = b + 100 WHERE a = 100; + RETURN COALESCE(NEW, OLD); + END +$$; + +statement ok +CREATE FUNCTION insert_t1() RETURNS INT LANGUAGE PLpgSQL AS $$ BEGIN INSERT INTO t1 VALUES (1, 1); RETURN 0; END $$; +CREATE FUNCTION delete_parent() RETURNS INT LANGUAGE PLpgSQL AS $$ BEGIN DELETE FROM parent WHERE k = 1; RETURN 0; END $$; + +# ------------------------------------------------------------------------------ +# Test a BEFORE trigger with an INSERT statement. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TRIGGER foo BEFORE INSERT OR UPDATE ON t1 FOR EACH ROW EXECUTE FUNCTION g(); + +statement ok +CREATE TRIGGER bar BEFORE DELETE ON child FOR EACH ROW EXECUTE FUNCTION g(); + +# Multiple mutations of the same table are allowed if they all use INSERT +# without ON CONFLICT. +statement ok +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) INSERT INTO t1 VALUES (1, 1); + +statement ok +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT insert_t1(); + +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) DELETE FROM parent WHERE k = 1; + +# Wrapping the DELETE on parent in a UDF causes the FK cascade (and therefore +# the BEFORE trigger) to execute as part of the main query. However, INSERT +# statements do not conflict with one another. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT delete_parent(); + +# The triggered INSERT conflicts with the outer UPDATE. +statement error pgcode 0A000 pq: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) INSERT INTO t1 VALUES (1, 1); + +statement error pgcode 0A000 pq: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) SELECT insert_t1(); + +# The triggered INSERT does not conflict with the outer UPDATE because it is run +# as part of the FK cascade after the main query. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) DELETE FROM parent WHERE k = 1; + +# Wrapping the DELETE on parent in a UDF causes the FK cascade (and therefore +# the BEFORE trigger) to execute as part of the main query. As a result, the +# triggered INSERT conflicts with the outer UPDATE. +statement error pgcode 0A000 pq: while building cascade expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) SELECT delete_parent(); + +statement ok +DROP TRIGGER foo ON t1; + +statement ok +DROP TRIGGER bar ON child; + +# ------------------------------------------------------------------------------ +# Test a BEFORE trigger with an UPDATE statement. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TRIGGER foo BEFORE INSERT OR UPDATE ON t1 FOR EACH ROW EXECUTE FUNCTION h(); + +statement ok +CREATE TRIGGER bar BEFORE DELETE ON child FOR EACH ROW EXECUTE FUNCTION h(); + +# The triggered UPDATE conflicts with the outer INSERT. +statement error pgcode 0A000 pq: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) INSERT INTO t1 VALUES (1, 2); + +statement error pgcode 0A000 pq: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT insert_t1(); + +# The triggered UPDATE does not conflict with the outer INSERT because it is run +# as part of the FK cascade after the main query. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) DELETE FROM parent WHERE k = 1; + +# Wrapping the DELETE on parent in a UDF causes the FK cascade (and therefore +# the BEFORE trigger) to execute as part of the main query. As a result, the +# triggered UPDATE conflicts with the outer INSERT. +statement error pgcode 0A000 pq: while building cascade expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT delete_parent(); + +# Even though the triggered UPDATE executes twice, the mutations are allowed +# because they are "siblings". +statement ok +WITH foo AS (INSERT INTO t1 VALUES (1, 2) RETURNING a) INSERT INTO t1 VALUES (1, 1); + +statement ok +WITH foo AS (INSERT INTO t1 VALUES (1, 2) RETURNING a) SELECT insert_t1(); + +statement ok +WITH foo AS (SELECT insert_t1()) SELECT insert_t1(); + +statement ok +DROP TRIGGER foo ON t1; + +statement ok +DROP TRIGGER bar ON child; + +# ------------------------------------------------------------------------------ +# Test an AFTER trigger with an INSERT statement. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TRIGGER foo AFTER INSERT OR UPDATE ON t1 FOR EACH ROW EXECUTE FUNCTION g(); + +statement ok +CREATE TRIGGER bar AFTER DELETE ON child FOR EACH ROW EXECUTE FUNCTION g(); + +# INSERT without ON CONFLICT is always allowed. +statement ok +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) INSERT INTO t1 VALUES (1, 1); + +statement ok +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT insert_t1(); + +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) DELETE FROM parent WHERE k = 1; + +# Wrapping the DELETE on parent in a UDF causes the FK cascade (and therefore +# the AFTER trigger) to execute as part of the main query. However, INSERT +# statements do not conflict with one another. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT delete_parent(); + +# The triggered INSERT does not conflict with the outer UPDATE on t2 because the +# trigger is run as a post-query. +statement ok +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) INSERT INTO t1 VALUES (1, 1); + +# When the INSERT into t1 is wrapped in a UDF, the AFTER trigger is run within +# the UDF, and so the triggered INSERT on t2 conflicts with the outer UPDATE. +statement error pgcode 0A000 pq: while building trigger expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) SELECT insert_t1(); + +# The triggered INSERT does not conflict with the outer UPDATE because it is run +# after the FK cascade, which runs after the main query. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) DELETE FROM parent WHERE k = 1; + +# Wrapping the DELETE on parent in a UDF causes the FK cascade (and therefore +# the AFTER trigger) to execute as part of the main query. As a result, the +# triggered INSERT conflicts with the outer UPDATE. +statement error pgcode 0A000 pq: while building trigger expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) SELECT delete_parent(); + +statement ok +DROP TRIGGER foo ON t1; + +statement ok +DROP TRIGGER bar ON child; + +# ------------------------------------------------------------------------------ +# Test an AFTER trigger with an UPDATE statement. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TRIGGER foo AFTER INSERT OR UPDATE ON t1 FOR EACH ROW EXECUTE FUNCTION h(); + +statement ok +CREATE TRIGGER bar AFTER DELETE ON child FOR EACH ROW EXECUTE FUNCTION h(); + +# The triggered UPDATE does not conflict with the outer INSERT on t2 because the +# trigger is run as a post-query. +statement ok +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) INSERT INTO t1 VALUES (1, 1); + +# When the INSERT into t1 is wrapped in a UDF, the AFTER trigger is run within +# the UDF, and so the triggered UPDATE on t2 conflicts with the outer INSERT. +statement error pgcode 0A000 pq: while building trigger expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT insert_t1(); + +# The triggered UPDATE does not conflict with the outer INSERT because it is run +# after the FK cascade, which runs after the main query. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) DELETE FROM parent WHERE k = 1; + +# Wrapping the DELETE on parent in a UDF causes the FK cascade (and therefore +# the AFTER trigger) to execute as part of the main query. As a result, the +# triggered UPDATE conflicts with the outer INSERT. +statement error pgcode 0A000 pq: while building trigger expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (INSERT INTO t2 VALUES (1, 1) RETURNING a) SELECT delete_parent(); + +# The triggered UPDATE does not conflict with the outer UPDATE on t2 because the +# trigger is run as a post-query. +statement ok +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) INSERT INTO t1 VALUES (1, 1); + +# Wrapping the INSERT on t1 in a UDF means the trigger is run within the scope +# of the UDF, so the triggered UPDATE on t2 conflicts with the outer UPDATE. +statement error pgcode 0A000 pq: while building trigger expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) SELECT insert_t1(); + +# The triggered UPDATE does not conflict with the outer UPDATE because it is run +# after the FK cascade, which runs after the main query. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) DELETE FROM parent WHERE k = 1; + +# Wrapping the DELETE on parent in a UDF causes the FK cascade (and therefore +# the AFTER trigger) to execute as part of the main query. As a result, the +# triggered UPDATE conflicts with the outer UPDATE. +statement error pgcode 0A000 pq: while building trigger expression: multiple mutations of the same table "t2" are not supported unless they all use INSERT without ON CONFLICT +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +WITH foo AS (UPDATE t2 SET b = 1 WHERE a = 1 RETURNING a) SELECT delete_parent(); + +# Even though the trigger executes an UPDATE on t2 twice, the mutations are +# allowed because they are executed as post-queries. +statement ok +WITH foo AS (INSERT INTO t1 VALUES (1, 2) RETURNING a) INSERT INTO t1 VALUES (1, 1); + +# Even though the trigger executes an UPDATE on t2 twice, the mutations are +# allowed because one is executed as a post-query. +statement ok +WITH foo AS (INSERT INTO t1 VALUES (1, 2) RETURNING a) SELECT insert_t1(); + +# Even though the trigger executes an UPDATE on t2 twice, the mutations are +# allowed because they are "siblings". +statement ok +WITH foo AS (SELECT insert_t1()) SELECT insert_t1(); + +statement ok +DROP TRIGGER foo ON t1; + +statement ok +DROP TRIGGER bar ON child; + +statement ok +DROP FUNCTION insert_t1; +DROP TABLE t1; +DROP TABLE t2; +DROP FUNCTION g; +DROP FUNCTION h; + +# ------------------------------------------------------------------------------ +# Test a trigger conflicting with a FK cascade mutation. +# ------------------------------------------------------------------------------ + +statement ok +CREATE FUNCTION g() RETURNS TRIGGER LANGUAGE PLpgSQL AS $$ + BEGIN + UPDATE child SET k = k + 1 WHERE True; + RETURN OLD; + END +$$; + +statement ok +CREATE TRIGGER foo BEFORE DELETE ON child FOR EACH ROW EXECUTE FUNCTION g(); + +# The triggered UPDATE conflicts with the cascaded DELETE on child. +statement error pgcode 0A000 pq: while building cascade expression: multiple mutations of the same table "child" are not supported unless they all use INSERT without ON CONFLICT +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +DELETE FROM parent WHERE k = 1; + +statement ok +DROP TRIGGER foo ON child; + +statement ok +CREATE TRIGGER foo AFTER DELETE ON child FOR EACH ROW EXECUTE FUNCTION g(); + +# The triggered UPDATE does not conflict with the cascaded DELETE on child +# because the trigger is run as a post-query. +statement ok +UPSERT INTO parent VALUES (1); +UPSERT INTO child VALUES (1, 1); +DELETE FROM parent WHERE k = 1; + +statement ok +DROP FUNCTION delete_parent; +DROP TABLE child; +DROP TABLE parent; +DROP FUNCTION g; + # ============================================================================== # Test unsupported syntax. # ============================================================================== diff --git a/pkg/sql/opt/optbuilder/fk_cascade.go b/pkg/sql/opt/optbuilder/fk_cascade.go index 8cdfa8c3c067..e9229f354ca2 100644 --- a/pkg/sql/opt/optbuilder/fk_cascade.go +++ b/pkg/sql/opt/optbuilder/fk_cascade.go @@ -61,18 +61,23 @@ type onDeleteCascadeBuilder struct { // Note that the columns must be remapped to the new memo when the cascade is // built. oldValues opt.ColList + + // stmtTreeInitFn returns a statementTree that tracks the mutations in + // ancestor statements. It may be unset if there are no ancestor statements. + stmtTreeInitFn func() statementTree } var _ memo.PostQueryBuilder = &onDeleteCascadeBuilder{} -func newOnDeleteCascadeBuilder( - mutatedTable cat.Table, fkInboundOrdinal int, childTable cat.Table, oldValues opt.ColList, +func (mb *mutationBuilder) newOnDeleteCascadeBuilder( + fkInboundOrdinal int, childTable cat.Table, oldValues opt.ColList, ) *onDeleteCascadeBuilder { return &onDeleteCascadeBuilder{ - mutatedTable: mutatedTable, + mutatedTable: mb.tab, fkInboundOrdinal: fkInboundOrdinal, childTable: childTable, oldValues: oldValues, + stmtTreeInitFn: mb.b.stmtTree.GetInitFnForPostQuery(), } } @@ -87,38 +92,42 @@ func (cb *onDeleteCascadeBuilder) Build( bindingProps *props.Relational, colMap opt.ColMap, ) (_ memo.RelExpr, err error) { - return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, func(b *Builder) memo.RelExpr { - opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) - - fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) - - dep := opt.DepByID(fk.OriginTableID()) - b.checkPrivilege(dep, cb.childTable, privilege.DELETE) - b.checkPrivilege(dep, cb.childTable, privilege.SELECT) - - var mb mutationBuilder - mb.init(b, "delete", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) - - // Build a semi join of the table with the mutation input. - // - // The scope returned by buildDeleteCascadeMutationInput has one column - // for each public table column, making it appropriate to set it as - // mb.fetchScope. - oldValues := cb.oldValues.RemapColumns(colMap) - mb.fetchScope = b.buildDeleteCascadeMutationInput( - cb.childTable, &mb.alias, fk, binding, bindingProps, oldValues, - ) - mb.outScope = mb.fetchScope + return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, cb.stmtTreeInitFn, + func(b *Builder) memo.RelExpr { + opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) - // Set list of columns that will be fetched by the input expression. - mb.setFetchColIDs(mb.outScope.cols) + fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) - // Cascades can fire triggers on the child table. - mb.buildRowLevelBeforeTriggers(tree.TriggerEventDelete, true /* cascade */) + dep := opt.DepByID(fk.OriginTableID()) + b.checkPrivilege(dep, cb.childTable, privilege.DELETE) + b.checkPrivilege(dep, cb.childTable, privilege.SELECT) - mb.buildDelete(nil /* returning */) - return mb.outScope.expr - }) + var mb mutationBuilder + mb.init(b, "delete", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) + + // Build a semi join of the table with the mutation input. + // + // The scope returned by buildDeleteCascadeMutationInput has one column + // for each public table column, making it appropriate to set it as + // mb.fetchScope. + oldValues := cb.oldValues.RemapColumns(colMap) + mb.fetchScope = b.buildDeleteCascadeMutationInput( + cb.childTable, &mb.alias, fk, binding, bindingProps, oldValues, + ) + mb.outScope = mb.fetchScope + + // Set list of columns that will be fetched by the input expression. + mb.setFetchColIDs(mb.outScope.cols) + + // Register the mutation with the statementTree + b.checkMultipleMutations(mb.tab, generalMutation) + + // Cascades can fire triggers on the child table. + mb.buildRowLevelBeforeTriggers(tree.TriggerEventDelete, true /* cascade */) + + mb.buildDelete(nil /* returning */) + return mb.outScope.expr + }) } // onDeleteFastCascadeBuilder is a memo.PostQueryBuilder implementation for @@ -155,6 +164,10 @@ type onDeleteFastCascadeBuilder struct { origFilters memo.FiltersExpr origFKCols opt.ColList + + // stmtTreeInitFn returns a statementTree that tracks the mutations in + // ancestor statements. It may be unset if there are no ancestor statements. + stmtTreeInitFn func() statementTree } var _ memo.PostQueryBuilder = &onDeleteFastCascadeBuilder{} @@ -162,15 +175,11 @@ var _ memo.PostQueryBuilder = &onDeleteFastCascadeBuilder{} // tryNewOnDeleteFastCascadeBuilder checks if the fast path cascade is // applicable to the given mutation, and if yes it returns an instance of // onDeleteFastCascadeBuilder. -func tryNewOnDeleteFastCascadeBuilder( - ctx context.Context, - md *opt.Metadata, - catalog cat.Catalog, - fk cat.ForeignKeyConstraint, - fkInboundOrdinal int, - parentTab, childTab cat.Table, - mutationInputScope *scope, +func (mb *mutationBuilder) tryNewOnDeleteFastCascadeBuilder( + fk cat.ForeignKeyConstraint, fkInboundOrdinal int, childTab cat.Table, ) (_ *onDeleteFastCascadeBuilder, ok bool) { + parentTab := mb.tab + mutationInputScope := mb.outScope fkCols := make(opt.ColList, fk.ColumnCount()) for i := range fkCols { tabOrd := fk.ReferencedColumnOrdinal(parentTab, i) @@ -212,7 +221,7 @@ func tryNewOnDeleteFastCascadeBuilder( } // Check that the scan retrieves all table data (currently, this should always // be the case in a normalized expression). - if !scan.IsUnfiltered(md) { + if !scan.IsUnfiltered(mb.md) { return nil, false } @@ -245,7 +254,7 @@ func tryNewOnDeleteFastCascadeBuilder( case childTabID: tab = childTab default: - tab = resolveTable(ctx, catalog, tabID) + tab = resolveTable(mb.b.ctx, mb.b.catalog, tabID) } // If the table could not be resolved, then we cannot confirm that FK // references form a simple tree, so we return false. This is possible @@ -271,6 +280,7 @@ func tryNewOnDeleteFastCascadeBuilder( childTable: childTab, origFilters: filters, origFKCols: fkCols, + stmtTreeInitFn: mb.b.stmtTree.GetInitFnForPostQuery(), }, true } @@ -285,86 +295,90 @@ func (cb *onDeleteFastCascadeBuilder) Build( _ *props.Relational, _ opt.ColMap, ) (_ memo.RelExpr, err error) { - return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, func(b *Builder) memo.RelExpr { - opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) - - fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) - - dep := opt.DepByID(fk.OriginTableID()) - b.checkPrivilege(dep, cb.childTable, privilege.DELETE) - b.checkPrivilege(dep, cb.childTable, privilege.SELECT) - - var mb mutationBuilder - mb.init(b, "delete", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) - - // Build the input to the delete mutation, which is simply a Scan with a - // Select on top. - mb.fetchScope = b.buildScan( - b.addTable(cb.childTable, &mb.alias), - tableOrdinals(cb.childTable, columnKinds{ - includeMutations: false, - includeSystem: false, - includeInverted: false, - }), - nil, /* indexFlags */ - noRowLocking, - b.allocScope(), - true, /* disableNotVisibleIndex */ - ) - mb.outScope = mb.fetchScope - - var filters memo.FiltersExpr - - // Build the filters by copying the original filters and replacing all - // variable references. - if len(cb.origFilters) > 0 { - var replaceFn norm.ReplaceFunc - replaceFn = func(e opt.Expr) opt.Expr { - if v, ok := e.(*memo.VariableExpr); ok { - idx, found := cb.origFKCols.Find(v.Col) - if !found { - panic(errors.AssertionFailedf("non-FK variable in filter")) + return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, cb.stmtTreeInitFn, + func(b *Builder) memo.RelExpr { + opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) + + fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) + + dep := opt.DepByID(fk.OriginTableID()) + b.checkPrivilege(dep, cb.childTable, privilege.DELETE) + b.checkPrivilege(dep, cb.childTable, privilege.SELECT) + + var mb mutationBuilder + mb.init(b, "delete", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) + + // Build the input to the delete mutation, which is simply a Scan with a + // Select on top. + mb.fetchScope = b.buildScan( + b.addTable(cb.childTable, &mb.alias), + tableOrdinals(cb.childTable, columnKinds{ + includeMutations: false, + includeSystem: false, + includeInverted: false, + }), + nil, /* indexFlags */ + noRowLocking, + b.allocScope(), + true, /* disableNotVisibleIndex */ + ) + mb.outScope = mb.fetchScope + + var filters memo.FiltersExpr + + // Build the filters by copying the original filters and replacing all + // variable references. + if len(cb.origFilters) > 0 { + var replaceFn norm.ReplaceFunc + replaceFn = func(e opt.Expr) opt.Expr { + if v, ok := e.(*memo.VariableExpr); ok { + idx, found := cb.origFKCols.Find(v.Col) + if !found { + panic(errors.AssertionFailedf("non-FK variable in filter")) + } + tabOrd := fk.OriginColumnOrdinal(cb.childTable, idx) + col := mb.outScope.getColumnForTableOrdinal(tabOrd) + return b.factory.ConstructVariable(col.id) } - tabOrd := fk.OriginColumnOrdinal(cb.childTable, idx) - col := mb.outScope.getColumnForTableOrdinal(tabOrd) - return b.factory.ConstructVariable(col.id) + return b.factory.CopyAndReplaceDefault(e, replaceFn) } - return b.factory.CopyAndReplaceDefault(e, replaceFn) + filters = *replaceFn(&cb.origFilters).(*memo.FiltersExpr) } - filters = *replaceFn(&cb.origFilters).(*memo.FiltersExpr) - } - // We have to filter out rows that have NULL values; add an IS NOT NULL - // filter for each FK column, unless the column is not-nullable (this is - // a minor optimization, as normalization rules would have removed the - // filter anyway). - notNullCols := mb.outScope.expr.Relational().NotNullCols - for i := range cb.origFKCols { - tabOrd := fk.OriginColumnOrdinal(cb.childTable, i) - col := mb.outScope.getColumnForTableOrdinal(tabOrd) - if !notNullCols.Contains(col.id) { - filters = append(filters, b.factory.ConstructFiltersItem( - b.factory.ConstructIsNot( - b.factory.ConstructVariable(col.id), - b.factory.ConstructNull(col.typ), - ), - )) + // We have to filter out rows that have NULL values; add an IS NOT NULL + // filter for each FK column, unless the column is not-nullable (this is + // a minor optimization, as normalization rules would have removed the + // filter anyway). + notNullCols := mb.outScope.expr.Relational().NotNullCols + for i := range cb.origFKCols { + tabOrd := fk.OriginColumnOrdinal(cb.childTable, i) + col := mb.outScope.getColumnForTableOrdinal(tabOrd) + if !notNullCols.Contains(col.id) { + filters = append(filters, b.factory.ConstructFiltersItem( + b.factory.ConstructIsNot( + b.factory.ConstructVariable(col.id), + b.factory.ConstructNull(col.typ), + ), + )) + } } - } - if len(filters) > 0 { - mb.outScope.expr = b.factory.ConstructSelect(mb.outScope.expr, filters) - } + if len(filters) > 0 { + mb.outScope.expr = b.factory.ConstructSelect(mb.outScope.expr, filters) + } - // Set list of columns that will be fetched by the input expression. - mb.setFetchColIDs(mb.outScope.cols) + // Set list of columns that will be fetched by the input expression. + mb.setFetchColIDs(mb.outScope.cols) - // Cascades can fire triggers on the child table. - mb.buildRowLevelBeforeTriggers(tree.TriggerEventDelete, true /* cascade */) + // Register the mutation with the statementTree + b.checkMultipleMutations(mb.tab, generalMutation) - mb.buildDelete(nil /* returning */) - return mb.outScope.expr - }) + // Cascades can fire triggers on the child table. + mb.buildRowLevelBeforeTriggers(tree.TriggerEventDelete, true /* cascade */) + + mb.buildDelete(nil /* returning */) + return mb.outScope.expr + }) } // onDeleteSetBuilder is a memo.PostQueryBuilder implementation for @@ -420,23 +434,24 @@ type onDeleteSetBuilder struct { // Note that the columns must be remapped to the new memo when the cascade is // built. oldValues opt.ColList + + // stmtTreeInitFn returns a statementTree that tracks the mutations in + // ancestor statements. It may be unset if there are no ancestor statements. + stmtTreeInitFn func() statementTree } var _ memo.PostQueryBuilder = &onDeleteSetBuilder{} -func newOnDeleteSetBuilder( - mutatedTable cat.Table, - fkInboundOrdinal int, - childTable cat.Table, - action tree.ReferenceAction, - oldValues opt.ColList, +func (mb *mutationBuilder) newOnDeleteSetBuilder( + fkInboundOrdinal int, childTable cat.Table, action tree.ReferenceAction, oldValues opt.ColList, ) *onDeleteSetBuilder { return &onDeleteSetBuilder{ - mutatedTable: mutatedTable, + mutatedTable: mb.tab, fkInboundOrdinal: fkInboundOrdinal, childTable: childTable, action: action, oldValues: oldValues, + stmtTreeInitFn: mb.b.stmtTree.GetInitFnForPostQuery(), } } @@ -451,60 +466,64 @@ func (cb *onDeleteSetBuilder) Build( bindingProps *props.Relational, colMap opt.ColMap, ) (_ memo.RelExpr, err error) { - return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, func(b *Builder) memo.RelExpr { - opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) - - fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) - - dep := opt.DepByID(fk.OriginTableID()) - b.checkPrivilege(dep, cb.childTable, privilege.UPDATE) - b.checkPrivilege(dep, cb.childTable, privilege.SELECT) - - var mb mutationBuilder - mb.init(b, "update", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) - - // Build a semi join of the table with the mutation input. - // - // The scope returned by buildDeleteCascadeMutationInput has one column - // for each public table column, making it appropriate to set it as - // mb.fetchScope. - oldValues := cb.oldValues.RemapColumns(colMap) - mb.fetchScope = b.buildDeleteCascadeMutationInput( - cb.childTable, &mb.alias, fk, binding, bindingProps, oldValues, - ) - mb.outScope = mb.fetchScope - - // Set list of columns that will be fetched by the input expression. - mb.setFetchColIDs(mb.outScope.cols) - // Add target columns. - numFKCols := fk.ColumnCount() - for i := 0; i < numFKCols; i++ { - tabOrd := fk.OriginColumnOrdinal(cb.childTable, i) - mb.addTargetCol(tabOrd) - } + return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, cb.stmtTreeInitFn, + func(b *Builder) memo.RelExpr { + opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) + + fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) + + dep := opt.DepByID(fk.OriginTableID()) + b.checkPrivilege(dep, cb.childTable, privilege.UPDATE) + b.checkPrivilege(dep, cb.childTable, privilege.SELECT) + + var mb mutationBuilder + mb.init(b, "update", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) + + // Build a semi join of the table with the mutation input. + // + // The scope returned by buildDeleteCascadeMutationInput has one column + // for each public table column, making it appropriate to set it as + // mb.fetchScope. + oldValues := cb.oldValues.RemapColumns(colMap) + mb.fetchScope = b.buildDeleteCascadeMutationInput( + cb.childTable, &mb.alias, fk, binding, bindingProps, oldValues, + ) + mb.outScope = mb.fetchScope + + // Set list of columns that will be fetched by the input expression. + mb.setFetchColIDs(mb.outScope.cols) + // Add target columns. + numFKCols := fk.ColumnCount() + for i := 0; i < numFKCols; i++ { + tabOrd := fk.OriginColumnOrdinal(cb.childTable, i) + mb.addTargetCol(tabOrd) + } - // Add the SET expressions. - updateExprs := make(tree.UpdateExprs, numFKCols) - for i := range updateExprs { - updateExprs[i] = &tree.UpdateExpr{} - if cb.action == tree.SetNull { - updateExprs[i].Expr = tree.DNull - } else { - updateExprs[i].Expr = tree.DefaultVal{} + // Add the SET expressions. + updateExprs := make(tree.UpdateExprs, numFKCols) + for i := range updateExprs { + updateExprs[i] = &tree.UpdateExpr{} + if cb.action == tree.SetNull { + updateExprs[i].Expr = tree.DNull + } else { + updateExprs[i].Expr = tree.DefaultVal{} + } } - } - mb.addUpdateCols(updateExprs) + mb.addUpdateCols(updateExprs) - // Cascades can fire triggers on the child table. - mb.buildRowLevelBeforeTriggers(tree.TriggerEventUpdate, true /* cascade */) + // Register the mutation with the statementTree + b.checkMultipleMutations(mb.tab, generalMutation) - // TODO(radu): consider plumbing a flag to prevent building the FK check - // against the parent we are cascading from. Need to investigate in which - // cases this is safe (e.g. other cascades could have messed with the parent - // table in the meantime). - mb.buildUpdate(nil /* returning */) - return mb.outScope.expr - }) + // Cascades can fire triggers on the child table. + mb.buildRowLevelBeforeTriggers(tree.TriggerEventUpdate, true /* cascade */) + + // TODO(radu): consider plumbing a flag to prevent building the FK check + // against the parent we are cascading from. Need to investigate in which + // cases this is safe (e.g. other cascades could have messed with the parent + // table in the meantime). + mb.buildUpdate(nil /* returning */) + return mb.outScope.expr + }) } // buildDeleteCascadeMutationInput constructs a semi-join between the child @@ -656,24 +675,28 @@ type onUpdateCascadeBuilder struct { // table. Note that the columns must be remapped to the new memo when the // cascade is built. newValues opt.ColList + + // stmtTreeInitFn returns a statementTree that tracks the mutations in + // ancestor statements. It may be unset if there are no ancestor statements. + stmtTreeInitFn func() statementTree } var _ memo.PostQueryBuilder = &onUpdateCascadeBuilder{} -func newOnUpdateCascadeBuilder( - mutatedTable cat.Table, +func (mb *mutationBuilder) newOnUpdateCascadeBuilder( fkInboundOrdinal int, childTable cat.Table, action tree.ReferenceAction, oldValues, newValues opt.ColList, ) *onUpdateCascadeBuilder { return &onUpdateCascadeBuilder{ - mutatedTable: mutatedTable, + mutatedTable: mb.tab, fkInboundOrdinal: fkInboundOrdinal, childTable: childTable, action: action, oldValues: oldValues, newValues: newValues, + stmtTreeInitFn: mb.b.stmtTree.GetInitFnForPostQuery(), } } @@ -688,64 +711,68 @@ func (cb *onUpdateCascadeBuilder) Build( bindingProps *props.Relational, colMap opt.ColMap, ) (_ memo.RelExpr, err error) { - return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, func(b *Builder) memo.RelExpr { - opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) - - fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) - - dep := opt.DepByID(fk.OriginTableID()) - b.checkPrivilege(dep, cb.childTable, privilege.UPDATE) - b.checkPrivilege(dep, cb.childTable, privilege.SELECT) - - var mb mutationBuilder - mb.init(b, "update", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) - - // Build a join of the table with the mutation input. - oldValues := cb.oldValues.RemapColumns(colMap) - newValues := cb.newValues.RemapColumns(colMap) - mb.outScope = b.buildUpdateCascadeMutationInput( - cb.childTable, &mb.alias, fk, binding, bindingProps, oldValues, newValues, - ) - - // The scope created by b.buildUpdateCascadeMutationInput has the table - // columns, followed by the old FK values, followed by the new FK values. - numFKCols := fk.ColumnCount() - tableScopeCols := mb.outScope.cols[:len(mb.outScope.cols)-2*numFKCols] - newValScopeCols := mb.outScope.cols[len(mb.outScope.cols)-numFKCols:] - mb.fetchScope = b.allocScope() - mb.fetchScope.appendColumns(tableScopeCols) - - // Set list of columns that will be fetched by the input expression. - mb.setFetchColIDs(tableScopeCols) - // Add target columns. - for i := 0; i < numFKCols; i++ { - tabOrd := fk.OriginColumnOrdinal(cb.childTable, i) - mb.addTargetCol(tabOrd) - } + return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, cb.stmtTreeInitFn, + func(b *Builder) memo.RelExpr { + opt.MaybeInjectOptimizerTestingPanic(ctx, evalCtx) + + fk := cb.mutatedTable.InboundForeignKey(cb.fkInboundOrdinal) + + dep := opt.DepByID(fk.OriginTableID()) + b.checkPrivilege(dep, cb.childTable, privilege.UPDATE) + b.checkPrivilege(dep, cb.childTable, privilege.SELECT) + + var mb mutationBuilder + mb.init(b, "update", cb.childTable, tree.MakeUnqualifiedTableName(cb.childTable.Name())) + + // Build a join of the table with the mutation input. + oldValues := cb.oldValues.RemapColumns(colMap) + newValues := cb.newValues.RemapColumns(colMap) + mb.outScope = b.buildUpdateCascadeMutationInput( + cb.childTable, &mb.alias, fk, binding, bindingProps, oldValues, newValues, + ) + + // The scope created by b.buildUpdateCascadeMutationInput has the table + // columns, followed by the old FK values, followed by the new FK values. + numFKCols := fk.ColumnCount() + tableScopeCols := mb.outScope.cols[:len(mb.outScope.cols)-2*numFKCols] + newValScopeCols := mb.outScope.cols[len(mb.outScope.cols)-numFKCols:] + mb.fetchScope = b.allocScope() + mb.fetchScope.appendColumns(tableScopeCols) + + // Set list of columns that will be fetched by the input expression. + mb.setFetchColIDs(tableScopeCols) + // Add target columns. + for i := 0; i < numFKCols; i++ { + tabOrd := fk.OriginColumnOrdinal(cb.childTable, i) + mb.addTargetCol(tabOrd) + } - // Add the SET expressions. - updateExprs := make(tree.UpdateExprs, numFKCols) - for i := range updateExprs { - updateExprs[i] = &tree.UpdateExpr{} - switch cb.action { - case tree.Cascade: - updateExprs[i].Expr = &newValScopeCols[i] - case tree.SetNull: - updateExprs[i].Expr = tree.DNull - case tree.SetDefault: - updateExprs[i].Expr = tree.DefaultVal{} - default: - panic(errors.AssertionFailedf("unsupported action")) + // Add the SET expressions. + updateExprs := make(tree.UpdateExprs, numFKCols) + for i := range updateExprs { + updateExprs[i] = &tree.UpdateExpr{} + switch cb.action { + case tree.Cascade: + updateExprs[i].Expr = &newValScopeCols[i] + case tree.SetNull: + updateExprs[i].Expr = tree.DNull + case tree.SetDefault: + updateExprs[i].Expr = tree.DefaultVal{} + default: + panic(errors.AssertionFailedf("unsupported action")) + } } - } - mb.addUpdateCols(updateExprs) + mb.addUpdateCols(updateExprs) - // Cascades can fire triggers on the child table. - mb.buildRowLevelBeforeTriggers(tree.TriggerEventUpdate, true /* cascade */) + // Register the mutation with the statementTree + b.checkMultipleMutations(mb.tab, generalMutation) - mb.buildUpdate(nil /* returning */) - return mb.outScope.expr - }) + // Cascades can fire triggers on the child table. + mb.buildRowLevelBeforeTriggers(tree.TriggerEventUpdate, true /* cascade */) + + mb.buildUpdate(nil /* returning */) + return mb.outScope.expr + }) } // buildUpdateCascadeMutationInput constructs an inner-join between the child @@ -939,16 +966,28 @@ func (b *Builder) buildUpdateCascadeMutationInput( // buildTriggerCascadeHelper contains boilerplate for PostQueryBuilder.Build // implementations. It creates a Builder, sets up panic-to-error conversion, // and executes the given function. +// +// - stmtTreeInitFn is a closure that returns a statement tree describing +// mutations which might conflict with AFTER triggers. It may be nil if there +// is no "outer" mutation that might conflict. Cascades must propagate this +// information as well, since they can themselves fire triggers. func buildTriggerCascadeHelper( ctx context.Context, semaCtx *tree.SemaContext, evalCtx *eval.Context, catalog cat.Catalog, factoryI interface{}, + stmtTreeInitFn func() statementTree, fn func(b *Builder) memo.RelExpr, ) (_ memo.RelExpr, err error) { factory := factoryI.(*norm.Factory) b := New(ctx, semaCtx, evalCtx, catalog, factory, nil /* stmt */) + if stmtTreeInitFn != nil { + b.stmtTree = stmtTreeInitFn() + } + // Push a new statement onto the statement tree. + b.stmtTree.Push() + defer b.stmtTree.Pop() // Enact panic handling similar to Builder.Build(). defer func() { diff --git a/pkg/sql/opt/optbuilder/mutation_builder_fk.go b/pkg/sql/opt/optbuilder/mutation_builder_fk.go index 27b56b68ee18..f26bd51baf1d 100644 --- a/pkg/sql/opt/optbuilder/mutation_builder_fk.go +++ b/pkg/sql/opt/optbuilder/mutation_builder_fk.go @@ -149,17 +149,15 @@ func (mb *mutationBuilder) buildFKChecksAndCascadesForDelete() { case tree.Cascade: // Try the fast builder first; if it cannot be used, use the regular builder. var ok bool - builder, ok = tryNewOnDeleteFastCascadeBuilder( - mb.b.ctx, mb.md, mb.b.catalog, h.fk, i, mb.tab, h.otherTab, mb.outScope, - ) + builder, ok = mb.tryNewOnDeleteFastCascadeBuilder(h.fk, i, h.otherTab) if !ok { mb.ensureWithID() - builder = newOnDeleteCascadeBuilder(mb.tab, i, h.otherTab, cols) + builder = mb.newOnDeleteCascadeBuilder(i, h.otherTab, cols) } triggerEventType = tree.TriggerEventDelete case tree.SetNull, tree.SetDefault: mb.ensureWithID() - builder = newOnDeleteSetBuilder(mb.tab, i, h.otherTab, a, cols) + builder = mb.newOnDeleteSetBuilder(i, h.otherTab, a, cols) triggerEventType = tree.TriggerEventUpdate default: panic(errors.AssertionFailedf("unhandled action type %s", a)) @@ -313,7 +311,7 @@ func (mb *mutationBuilder) buildFKChecksForUpdate() { hasBeforeTriggers := cat.HasRowLevelTriggers( h.otherTab, tree.TriggerActionTimeBefore, tree.TriggerEventUpdate, ) - builder := newOnUpdateCascadeBuilder(mb.tab, i, h.otherTab, a, oldCols, newCols) + builder := mb.newOnUpdateCascadeBuilder(i, h.otherTab, a, oldCols, newCols) mb.cascades = append(mb.cascades, memo.FKCascade{ FKConstraint: h.fk, HasBeforeTriggers: hasBeforeTriggers, @@ -435,7 +433,7 @@ func (mb *mutationBuilder) buildFKChecksForUpsert() { hasBeforeTriggers := cat.HasRowLevelTriggers( h.otherTab, tree.TriggerActionTimeBefore, tree.TriggerEventUpdate, ) - builder := newOnUpdateCascadeBuilder(mb.tab, i, h.otherTab, a, oldCols, newCols) + builder := mb.newOnUpdateCascadeBuilder(i, h.otherTab, a, oldCols, newCols) mb.cascades = append(mb.cascades, memo.FKCascade{ FKConstraint: h.fk, HasBeforeTriggers: hasBeforeTriggers, diff --git a/pkg/sql/opt/optbuilder/statement_tree.go b/pkg/sql/opt/optbuilder/statement_tree.go index 74ef08b8aad8..b30b61a792da 100644 --- a/pkg/sql/opt/optbuilder/statement_tree.go +++ b/pkg/sql/opt/optbuilder/statement_tree.go @@ -42,8 +42,30 @@ import ( // This set of children mutations is checked for conflicts during // CanMutateTable, which is equivalent to maintaining and traversing the entire // sub-tree of the popped statement. +// +// +----------------+ +// | After Triggers | +// +----------------+ +// After triggers are an exceptional case because they are planned lazily during +// execution, and unlike cascades, can include arbitrary logic. This means that +// we can't know ahead of time which tables will be mutated by AFTER triggers. +// Since AFTER triggers are post-queries, they do not conflict with the current +// statement. However, they can conflict with ancestor statements, e.g., if they +// are fired within a routine. +// +// We handle this by deferring the check for AFTER triggers until they are +// actually built. We maintain references to all ancestor statementTreeNodes, +// and use those to initialize the statement tree when building the AFTER +// triggers. Since AFTER triggers are built after the outer query has finished +// planning, all potentially conflicting mutations are known and reflected in +// the initialized statement tree. Note that some care must be taken to ensure +// that the statementTreeNode references remain valid and up-to-date (see the +// stmts comment below). type statementTree struct { - stmts []statementTreeNode + // stmts is a stack of statement nodes, as described in the struct comment. + // It is a slice of pointers to ensure that slice appends don't invalidate + // references to existing nodes. + stmts []*statementTreeNode } // mutationType represents a set of mutation types that can be applied to a @@ -86,17 +108,17 @@ func (n *statementTreeNode) childrenConflictWithMutation( // Push pushes a new statement onto the tree as a descendent of the current // statement. The newly pushed statement becomes the current statement. func (st *statementTree) Push() { - st.stmts = append(st.stmts, statementTreeNode{}) + st.stmts = append(st.stmts, &statementTreeNode{}) } // Pop sets the parent of the current statement as the new current statement. func (st *statementTree) Pop() { - popped := &st.stmts[len(st.stmts)-1] + popped := st.stmts[len(st.stmts)-1] st.stmts = st.stmts[:len(st.stmts)-1] if len(st.stmts) > 0 { // Combine the popped statement's mutations and child mutations into the // child statements of its parent (the new current statement). - curr := &st.stmts[len(st.stmts)-1] + curr := st.stmts[len(st.stmts)-1] curr.childrenSimpleInsertTables.UnionWith(popped.simpleInsertTables) curr.childrenSimpleInsertTables.UnionWith(popped.childrenSimpleInsertTables) curr.childrenGeneralMutationTables.UnionWith(popped.generalMutationTables) @@ -152,7 +174,7 @@ func (st *statementTree) CanMutateTable( // visible. offset = 2 } - curr := &st.stmts[len(st.stmts)-offset] + curr := st.stmts[len(st.stmts)-offset] // Check the children of the current statement for a conflict. if !isPostStmt && curr.childrenConflictWithMutation(tabID, typ) { return false @@ -164,7 +186,7 @@ func (st *statementTree) CanMutateTable( // mutation to the parent statement. break } - n := &st.stmts[i] + n := st.stmts[i] if n.conflictsWithMutation(tabID, typ) { return false } @@ -186,3 +208,41 @@ func (st *statementTree) CanMutateTable( } return true } + +// GetInitFnForPostQuery returns a function that can be used to initialize the +// statement tree for a post-query that is a child of the current statement. +// This is necessary because the post-query is not built until after the main +// statement has finished executing. The returned function may be nil if the +// statement tree does not need to be initialized for the post-query. +// +// NOTE: cascades are checked up-front by Builder.checkMultipleMutationsCascade, +// but the statement tree still must be propagated to them via this function in +// case the cascade causes a trigger to fire. +func (st *statementTree) GetInitFnForPostQuery() func() statementTree { + if len(st.stmts) <= 1 { + return nil + } + // Save references to the ancestor statementTreeNodes. Modifications to them + // after this point should be reflected in the references, ensuring that all + // ancestor mutations are visible by the time the post-query plan is built. + // + // This is necessary because the full set of mutations in the current + // statement may not be known at the time the statement tree is saved. For + // example, a CTE in which the first branch triggers a post-query, and the + // second is a mutation. + ancestorStatements := make([]*statementTreeNode, len(st.stmts)-1) + copy(ancestorStatements, st.stmts[:len(st.stmts)-1]) + return func() statementTree { + // Combine the non-child mutated tables for all ancestor nodes into a single + // ancestor node. This provides all the information needed to check for + // conflicts in a trigger run as a post-query. We can omit the child + // mutation tables because they can only conflict with the ancestor nodes, + // and that case has already been checked. + var node statementTreeNode + for i := range ancestorStatements { + node.simpleInsertTables.UnionWith(ancestorStatements[i].simpleInsertTables) + node.generalMutationTables.UnionWith(ancestorStatements[i].generalMutationTables) + } + return statementTree{stmts: []*statementTreeNode{&node}} + } +} diff --git a/pkg/sql/opt/optbuilder/statement_tree_test.go b/pkg/sql/opt/optbuilder/statement_tree_test.go index 6a8ae7feed15..b92331e6e055 100644 --- a/pkg/sql/opt/optbuilder/statement_tree_test.go +++ b/pkg/sql/opt/optbuilder/statement_tree_test.go @@ -12,13 +12,15 @@ import ( ) func TestStatementTree(t *testing.T) { - type cmd uint8 + type cmd uint16 const ( push cmd = 1 << iota pop mut simple post + getInit + init t1 t2 fail @@ -419,10 +421,177 @@ func TestStatementTree(t *testing.T) { pop, }, }, + // 24. + // Original: + // Push + // CanMutateTable(t1, default) + // GetInitFnForPostQuery() + // Pop + // + // Post-Query: + // initFn() + // Push + // CanMutateTable(t1, default) + // Pop + { + cmds: []cmd{ + push, + mut | t1, + getInit, + pop, + init, + push, + mut | t1, + pop, + }, + }, + // 25. + // Original: + // Push + // Push + // CanMutateTable(t1, default) + // Pop + // GetInitFnForPostQuery() + // Pop + // + // Post-Query: + // initFn() + // Push + // CanMutateTable(t1, default) + // Pop + { + cmds: []cmd{ + push, + push, + mut | t1, + pop, + getInit, + pop, + init, + push, + mut | t1, + pop, + }, + }, + // 25. + // Original: + // Push + // Push + // CanMutateTable(t1, default) + // Pop + // Push + // GetInitFnForPostQuery() + // Pop + // Pop + // + // Post-Query: + // initFn() + // Push + // CanMutateTable(t1, default) + // Pop + { + cmds: []cmd{ + push, + push, + mut | t1, + pop, + push, + getInit, + pop, + pop, + init, + push, + mut | t1, + pop, + }, + }, + // 26. + // Original: + // Push + // CanMutateTable(t1, default) + // Push + // GetInitFnForPostQuery() + // Pop + // Pop + // + // Post-Query: + // initFn() + // Push + // CanMutateTable(t1, default) FAIL + { + cmds: []cmd{ + push, + mut | t1, + push, + getInit, + pop, + pop, + init, + push, + mut | t1 | fail, + }, + }, + // 27. + // Original: + // Push + // Push + // GetInitFnForPostQuery() + // Pop + // CanMutateTable(t1, default) + // Pop + // + // Post-Query: + // initFn() + // Push + // CanMutateTable(t1, default) FAIL + { + cmds: []cmd{ + push, + push, + getInit, + pop, + mut | t1, + pop, + init, + push, + mut | t1 | fail, + }, + }, + // 28. + // Original: + // Push + // CanMutateTable(t1, default) + // Push + // Push + // GetInitFnForPostQuery() + // Pop + // Pop + // Pop + // + // Post-Query: + // initFn() + // Push + // CanMutateTable(t1, default) FAIL + { + cmds: []cmd{ + push, + mut | t1, + push, + push, + getInit, + pop, + pop, + pop, + init, + push, + mut | t1 | fail, + }, + }, } for i, tc := range testCases { var mt statementTree + var pqTreeFn func() statementTree for j, c := range tc.cmds { switch { case c&push == push: @@ -431,6 +600,19 @@ func TestStatementTree(t *testing.T) { case c&pop == pop: mt.Pop() + case c&getInit == getInit: + if pqTreeFn != nil { + t.Fatalf("test case %d: GetInitFnForPostQuery called twice", i) + } + pqTreeFn = mt.GetInitFnForPostQuery() + + case c&init == init: + if pqTreeFn == nil { + mt = statementTree{} + } else { + mt = pqTreeFn() + } + case c&mut == mut: var tabID cat.StableID switch { diff --git a/pkg/sql/opt/optbuilder/trigger.go b/pkg/sql/opt/optbuilder/trigger.go index 6e0c78af1cfc..173bbc32a221 100644 --- a/pkg/sql/opt/optbuilder/trigger.go +++ b/pkg/sql/opt/optbuilder/trigger.go @@ -447,8 +447,8 @@ func (mb *mutationBuilder) buildRowLevelAfterTriggers(mutation opt.Operator) { } mb.afterTriggers = &memo.AfterTriggers{ Triggers: triggers, - Builder: newRowLevelAfterTriggerBuilder( - mutation, mb.tab, triggers, fetchCols, updateCols, insertCols, mb.canaryColID, + Builder: mb.newRowLevelAfterTriggerBuilder( + mutation, triggers, fetchCols, updateCols, insertCols, ), WithID: mb.withID, } @@ -490,6 +490,10 @@ type rowLevelAfterTriggerBuilder struct { mutatedTable cat.Table triggers []cat.Trigger + // stmtTreeInitFn returns a statementTree that tracks the mutations in + // ancestor statements. It may be unset if there are no ancestor statements. + stmtTreeInitFn func() statementTree + // The following fields contain the columns from the mutation input needed to // build the triggers. The columns must be remapped to the new memo when the // triggers are built. If fetchCols, updateCols, or insertCols is set, then @@ -511,21 +515,18 @@ type rowLevelAfterTriggerBuilder struct { var _ memo.PostQueryBuilder = &rowLevelAfterTriggerBuilder{} -func newRowLevelAfterTriggerBuilder( - mutation opt.Operator, - mutatedTable cat.Table, - triggers []cat.Trigger, - fetchCols, updateCols, insertCols opt.ColList, - canaryCol opt.ColumnID, +func (mb *mutationBuilder) newRowLevelAfterTriggerBuilder( + mutation opt.Operator, triggers []cat.Trigger, fetchCols, updateCols, insertCols opt.ColList, ) *rowLevelAfterTriggerBuilder { return &rowLevelAfterTriggerBuilder{ - mutation: mutation, - mutatedTable: mutatedTable, - triggers: triggers, - fetchCols: fetchCols, - updateCols: updateCols, - insertCols: insertCols, - canaryCol: canaryCol, + mutation: mutation, + mutatedTable: mb.tab, + triggers: triggers, + stmtTreeInitFn: mb.b.stmtTree.GetInitFnForPostQuery(), + fetchCols: fetchCols, + updateCols: updateCols, + insertCols: insertCols, + canaryCol: mb.canaryColID, } } @@ -540,221 +541,222 @@ func (tb *rowLevelAfterTriggerBuilder) Build( bindingProps *props.Relational, colMap opt.ColMap, ) (_ memo.RelExpr, err error) { - return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, func(b *Builder) memo.RelExpr { - f := b.factory - md := f.Metadata() - - typeID := typedesc.TableIDToImplicitTypeOID(descpb.ID(tb.mutatedTable.ID())) - tableTyp, err := semaCtx.TypeResolver.ResolveTypeByOID(ctx, typeID) - if err != nil { - panic(err) - } + return buildTriggerCascadeHelper(ctx, semaCtx, evalCtx, catalog, factoryI, tb.stmtTreeInitFn, + func(b *Builder) memo.RelExpr { + f := b.factory + md := f.Metadata() + + typeID := typedesc.TableIDToImplicitTypeOID(descpb.ID(tb.mutatedTable.ID())) + tableTyp, err := semaCtx.TypeResolver.ResolveTypeByOID(ctx, typeID) + if err != nil { + panic(err) + } - // Map the columns from the original memo to the new one using colMap. - inFetchCols := tb.fetchCols.RemapColumns(colMap) - inUpdateCols := tb.updateCols.RemapColumns(colMap) - inInsertCols := tb.insertCols.RemapColumns(colMap) - colCount := len(inFetchCols) + len(inUpdateCols) + len(inInsertCols) - if tb.canaryCol != 0 { - // Make space for the canary column. - colCount++ - } - inCols := make(opt.ColList, 0, colCount) - outCols := make(opt.ColList, 0, colCount) - - // Allocate a new scope to build the expression that will call the trigger - // functions for each row scanned from the buffer. - triggerScope := b.allocScope() - var inCanaryCol, outCanaryCol opt.ColumnID - if tb.canaryCol != 0 { - inCanaryColID, ok := colMap.Get(int(tb.canaryCol)) - if !ok { - panic(errors.AssertionFailedf("column %d not in mapping %s\n", - tb.canaryCol, colMap.String())) + // Map the columns from the original memo to the new one using colMap. + inFetchCols := tb.fetchCols.RemapColumns(colMap) + inUpdateCols := tb.updateCols.RemapColumns(colMap) + inInsertCols := tb.insertCols.RemapColumns(colMap) + colCount := len(inFetchCols) + len(inUpdateCols) + len(inInsertCols) + if tb.canaryCol != 0 { + // Make space for the canary column. + colCount++ } - inCanaryCol = opt.ColumnID(inCanaryColID) - colType := md.ColumnMeta(inCanaryCol).Type - colName := scopeColName("").WithMetadataName("canary") - col := b.synthesizeColumn(triggerScope, colName, colType, nil /* expr */, nil /* scalar */) - outCanaryCol = col.id - inCols = append(inCols, inCanaryCol) - outCols = append(outCols, outCanaryCol) - } - addCols := func(cols opt.ColList, suffix string) opt.ColList { - startIdx := len(outCols) - for _, col := range cols { - colMeta := md.ColumnMeta(col) - name := scopeColName("").WithMetadataName(fmt.Sprintf("%s_%s", colMeta.Alias, suffix)) - outCol := b.synthesizeColumn( - triggerScope, name, colMeta.Type, nil /* expr */, nil, /* scalar */ - ) - inCols = append(inCols, col) - outCols = append(outCols, outCol.id) + inCols := make(opt.ColList, 0, colCount) + outCols := make(opt.ColList, 0, colCount) + + // Allocate a new scope to build the expression that will call the trigger + // functions for each row scanned from the buffer. + triggerScope := b.allocScope() + var inCanaryCol, outCanaryCol opt.ColumnID + if tb.canaryCol != 0 { + inCanaryColID, ok := colMap.Get(int(tb.canaryCol)) + if !ok { + panic(errors.AssertionFailedf("column %d not in mapping %s\n", + tb.canaryCol, colMap.String())) + } + inCanaryCol = opt.ColumnID(inCanaryColID) + colType := md.ColumnMeta(inCanaryCol).Type + colName := scopeColName("").WithMetadataName("canary") + col := b.synthesizeColumn(triggerScope, colName, colType, nil /* expr */, nil /* scalar */) + outCanaryCol = col.id + inCols = append(inCols, inCanaryCol) + outCols = append(outCols, outCanaryCol) } - return outCols[startIdx:len(outCols):len(outCols)] - } - outFetchCols := addCols(inFetchCols, "old") - outUpdateCols := addCols(inUpdateCols, "new") - outInsertCols := addCols(inInsertCols, "new") - md.AddWithBinding(binding, b.factory.ConstructFakeRel(&memo.FakeRelPrivate{ - Props: bindingProps, - })) - triggerScope.expr = f.ConstructWithScan(&memo.WithScanPrivate{ - With: binding, - InCols: inCols, - OutCols: outCols, - ID: md.NextUniqueID(), - }) - - // Project the old and new values into tuples. These will become the OLD and - // NEW arguments to the trigger functions. - makeTuple := func(cols opt.ColList) opt.ScalarExpr { - elems := make([]opt.ScalarExpr, len(cols)) - for i, col := range cols { - elems[i] = f.ConstructVariable(col) + addCols := func(cols opt.ColList, suffix string) opt.ColList { + startIdx := len(outCols) + for _, col := range cols { + colMeta := md.ColumnMeta(col) + name := scopeColName("").WithMetadataName(fmt.Sprintf("%s_%s", colMeta.Alias, suffix)) + outCol := b.synthesizeColumn( + triggerScope, name, colMeta.Type, nil /* expr */, nil, /* scalar */ + ) + inCols = append(inCols, col) + outCols = append(outCols, outCol.id) + } + return outCols[startIdx:len(outCols):len(outCols)] + } + outFetchCols := addCols(inFetchCols, "old") + outUpdateCols := addCols(inUpdateCols, "new") + outInsertCols := addCols(inInsertCols, "new") + md.AddWithBinding(binding, b.factory.ConstructFakeRel(&memo.FakeRelPrivate{ + Props: bindingProps, + })) + triggerScope.expr = f.ConstructWithScan(&memo.WithScanPrivate{ + With: binding, + InCols: inCols, + OutCols: outCols, + ID: md.NextUniqueID(), + }) + + // Project the old and new values into tuples. These will become the OLD and + // NEW arguments to the trigger functions. + makeTuple := func(cols opt.ColList) opt.ScalarExpr { + elems := make([]opt.ScalarExpr, len(cols)) + for i, col := range cols { + elems[i] = f.ConstructVariable(col) + } + return f.ConstructTuple(elems, tableTyp) + } + var canaryCheck opt.ScalarExpr + if tb.canaryCol != 0 { + canaryCheck = f.ConstructIs(f.ConstructVariable(outCanaryCol), memo.NullSingleton) } - return f.ConstructTuple(elems, tableTyp) - } - var canaryCheck opt.ScalarExpr - if tb.canaryCol != 0 { - canaryCheck = f.ConstructIs(f.ConstructVariable(outCanaryCol), memo.NullSingleton) - } - // Build an expression for the old values of each row. - oldScalar := opt.ScalarExpr(memo.NullSingleton) - if len(outFetchCols) > 0 { - oldScalar = makeTuple(outFetchCols) - if outCanaryCol != 0 { - // For an UPSERT/ON CONFLICT, the OLD column is non-null only for the - // conflicting rows, which are identified by the canary column. - oldScalar = f.ConstructCase( - memo.TrueSingleton, - memo.ScalarListExpr{f.ConstructWhen(canaryCheck, f.ConstructNull(tableTyp))}, - oldScalar, - ) + // Build an expression for the old values of each row. + oldScalar := opt.ScalarExpr(memo.NullSingleton) + if len(outFetchCols) > 0 { + oldScalar = makeTuple(outFetchCols) + if outCanaryCol != 0 { + // For an UPSERT/ON CONFLICT, the OLD column is non-null only for the + // conflicting rows, which are identified by the canary column. + oldScalar = f.ConstructCase( + memo.TrueSingleton, + memo.ScalarListExpr{f.ConstructWhen(canaryCheck, f.ConstructNull(tableTyp))}, + oldScalar, + ) + } } - } - // Build an expression for the new values of each row. - newScalar := opt.ScalarExpr(memo.NullSingleton) - if outCanaryCol != 0 { - // For an UPSERT/ON CONFLICT, the NEW column contains either inserted or - // updated values, depending on the canary column. - newScalar = f.ConstructCase( - memo.TrueSingleton, - memo.ScalarListExpr{f.ConstructWhen(canaryCheck, makeTuple(outInsertCols))}, - makeTuple(outUpdateCols), - ) - } else if len(outUpdateCols) > 0 { - newScalar = makeTuple(outUpdateCols) - } else if len(outInsertCols) > 0 { - newScalar = makeTuple(outInsertCols) - } - oldColID := b.projectColWithMetadataName(triggerScope, triggerColOld, tableTyp, oldScalar) - newColID := b.projectColWithMetadataName(triggerScope, triggerColNew, tableTyp, newScalar) - tgWhen := tree.NewDString("AFTER") - tgLevel := tree.NewDString("ROW") - tgRelID := tree.NewDOid(oid.Oid(tb.mutatedTable.ID())) - tgTableName := tree.NewDString(string(tb.mutatedTable.Name())) - fqName, err := b.catalog.FullyQualifiedName(ctx, tb.mutatedTable) - if err != nil { - panic(err) - } - tgTableSchema := tree.NewDString(fqName.Schema()) - var tgOp opt.ScalarExpr - switch tb.mutation { - case opt.InsertOp: - tgOp = f.ConstructConstVal(tree.NewDString("INSERT"), types.String) + // Build an expression for the new values of each row. + newScalar := opt.ScalarExpr(memo.NullSingleton) if outCanaryCol != 0 { - tgOp = f.ConstructCase( + // For an UPSERT/ON CONFLICT, the NEW column contains either inserted or + // updated values, depending on the canary column. + newScalar = f.ConstructCase( memo.TrueSingleton, - memo.ScalarListExpr{f.ConstructWhen(canaryCheck, tgOp)}, - f.ConstructConstVal(tree.NewDString("UPDATE"), types.String), + memo.ScalarListExpr{f.ConstructWhen(canaryCheck, makeTuple(outInsertCols))}, + makeTuple(outUpdateCols), ) + } else if len(outUpdateCols) > 0 { + newScalar = makeTuple(outUpdateCols) + } else if len(outInsertCols) > 0 { + newScalar = makeTuple(outInsertCols) } - case opt.UpdateOp: - tgOp = f.ConstructConstVal(tree.NewDString("UPDATE"), types.String) - case opt.DeleteOp: - tgOp = f.ConstructConstVal(tree.NewDString("DELETE"), types.String) - default: - panic(errors.AssertionFailedf("unexpected mutation type: %v", tb.mutation)) - } - - for i, trigger := range tb.triggers { - if i > 0 { - // No need to place a barrier below the first trigger. - triggerScope.expr = f.ConstructBarrier(triggerScope.expr) + oldColID := b.projectColWithMetadataName(triggerScope, triggerColOld, tableTyp, oldScalar) + newColID := b.projectColWithMetadataName(triggerScope, triggerColNew, tableTyp, newScalar) + tgWhen := tree.NewDString("AFTER") + tgLevel := tree.NewDString("ROW") + tgRelID := tree.NewDOid(oid.Oid(tb.mutatedTable.ID())) + tgTableName := tree.NewDString(string(tb.mutatedTable.Name())) + fqName, err := b.catalog.FullyQualifiedName(ctx, tb.mutatedTable) + if err != nil { + panic(err) } - - tgName := tree.NewDName(string(trigger.Name())) - tgNumArgs := tree.NewDInt(tree.DInt(len(trigger.FuncArgs()))) - tgArgV := tree.NewDArray(types.String) - for _, arg := range trigger.FuncArgs() { - err = tgArgV.Append(arg) - if err != nil { - panic(err) + tgTableSchema := tree.NewDString(fqName.Schema()) + var tgOp opt.ScalarExpr + switch tb.mutation { + case opt.InsertOp: + tgOp = f.ConstructConstVal(tree.NewDString("INSERT"), types.String) + if outCanaryCol != 0 { + tgOp = f.ConstructCase( + memo.TrueSingleton, + memo.ScalarListExpr{f.ConstructWhen(canaryCheck, tgOp)}, + f.ConstructConstVal(tree.NewDString("UPDATE"), types.String), + ) } - } - args := memo.ScalarListExpr{ - f.ConstructVariable(newColID), // NEW - f.ConstructVariable(oldColID), // OLD - f.ConstructConstVal(tgName, types.Name), // TG_NAME - f.ConstructConstVal(tgWhen, types.String), // TG_WHEN - f.ConstructConstVal(tgLevel, types.String), // TG_LEVEL - tgOp, // TG_OP - f.ConstructConstVal(tgRelID, types.Oid), // TG_RELIID - f.ConstructConstVal(tgTableName, types.String), // TG_RELNAME - f.ConstructConstVal(tgTableName, types.String), // TG_TABLE_NAME - f.ConstructConstVal(tgTableSchema, types.String), // TG_TABLE_SCHEMA - f.ConstructConstVal(tgNumArgs, types.Int), // TG_NARGS - f.ConstructConstVal(tgArgV, types.StringArray), // TG_ARGV + case opt.UpdateOp: + tgOp = f.ConstructConstVal(tree.NewDString("UPDATE"), types.String) + case opt.DeleteOp: + tgOp = f.ConstructConstVal(tree.NewDString("DELETE"), types.String) + default: + panic(errors.AssertionFailedf("unexpected mutation type: %v", tb.mutation)) } - // Resolve the trigger function and build the invocation. - triggerFn, def := b.buildTriggerFunction(trigger, tb.mutatedTable.ID(), tableTyp, args) - - // If there is a WHEN condition, wrap the trigger function invocation in a - // CASE WHEN statement that checks the WHEN condition. - if trigger.WhenExpr() != "" { - triggerFn = b.buildTriggerWhen( - trigger, triggerScope, oldColID, newColID, triggerFn, f.ConstructNull(tableTyp), - ) - } + for i, trigger := range tb.triggers { + if i > 0 { + // No need to place a barrier below the first trigger. + triggerScope.expr = f.ConstructBarrier(triggerScope.expr) + } - // For UPSERT and INSERT ON CONFLICT, UPDATE triggers should only fire for - // the conflicting rows, which are identified by the canary column. INSERT - // triggers should only fire for non-conflicting rows. A trigger that - // matches both operations can fire unconditionally. - if outCanaryCol != 0 { - var hasInsert, hasUpdate bool - for j := 0; j < trigger.EventCount(); j++ { - if trigger.Event(j).EventType == tree.TriggerEventInsert { - hasInsert = true - } else if trigger.Event(j).EventType == tree.TriggerEventUpdate { - hasUpdate = true + tgName := tree.NewDName(string(trigger.Name())) + tgNumArgs := tree.NewDInt(tree.DInt(len(trigger.FuncArgs()))) + tgArgV := tree.NewDArray(types.String) + for _, arg := range trigger.FuncArgs() { + err = tgArgV.Append(arg) + if err != nil { + panic(err) } } - if hasInsert && !hasUpdate { - triggerFn = f.ConstructCase( - memo.TrueSingleton, - memo.ScalarListExpr{f.ConstructWhen(canaryCheck, triggerFn)}, - f.ConstructNull(tableTyp), - ) - } else if hasUpdate && !hasInsert { - triggerFn = f.ConstructCase( - memo.TrueSingleton, - memo.ScalarListExpr{f.ConstructWhen(canaryCheck, f.ConstructNull(tableTyp))}, - triggerFn, + args := memo.ScalarListExpr{ + f.ConstructVariable(newColID), // NEW + f.ConstructVariable(oldColID), // OLD + f.ConstructConstVal(tgName, types.Name), // TG_NAME + f.ConstructConstVal(tgWhen, types.String), // TG_WHEN + f.ConstructConstVal(tgLevel, types.String), // TG_LEVEL + tgOp, // TG_OP + f.ConstructConstVal(tgRelID, types.Oid), // TG_RELIID + f.ConstructConstVal(tgTableName, types.String), // TG_RELNAME + f.ConstructConstVal(tgTableName, types.String), // TG_TABLE_NAME + f.ConstructConstVal(tgTableSchema, types.String), // TG_TABLE_SCHEMA + f.ConstructConstVal(tgNumArgs, types.Int), // TG_NARGS + f.ConstructConstVal(tgArgV, types.StringArray), // TG_ARGV + } + + // Resolve the trigger function and build the invocation. + triggerFn, def := b.buildTriggerFunction(trigger, tb.mutatedTable.ID(), tableTyp, args) + + // If there is a WHEN condition, wrap the trigger function invocation in a + // CASE WHEN statement that checks the WHEN condition. + if trigger.WhenExpr() != "" { + triggerFn = b.buildTriggerWhen( + trigger, triggerScope, oldColID, newColID, triggerFn, f.ConstructNull(tableTyp), ) } - } - // Finally, project a column that invokes the trigger function. - b.projectColWithMetadataName(triggerScope, def.Name, tableTyp, triggerFn) - } - // Always wrap the expression in a barrier, or else the projections will be - // pruned and the triggers will not be executed. - return f.ConstructBarrier(triggerScope.expr) - }) + // For UPSERT and INSERT ON CONFLICT, UPDATE triggers should only fire for + // the conflicting rows, which are identified by the canary column. INSERT + // triggers should only fire for non-conflicting rows. A trigger that + // matches both operations can fire unconditionally. + if outCanaryCol != 0 { + var hasInsert, hasUpdate bool + for j := 0; j < trigger.EventCount(); j++ { + if trigger.Event(j).EventType == tree.TriggerEventInsert { + hasInsert = true + } else if trigger.Event(j).EventType == tree.TriggerEventUpdate { + hasUpdate = true + } + } + if hasInsert && !hasUpdate { + triggerFn = f.ConstructCase( + memo.TrueSingleton, + memo.ScalarListExpr{f.ConstructWhen(canaryCheck, triggerFn)}, + f.ConstructNull(tableTyp), + ) + } else if hasUpdate && !hasInsert { + triggerFn = f.ConstructCase( + memo.TrueSingleton, + memo.ScalarListExpr{f.ConstructWhen(canaryCheck, f.ConstructNull(tableTyp))}, + triggerFn, + ) + } + } + + // Finally, project a column that invokes the trigger function. + b.projectColWithMetadataName(triggerScope, def.Name, tableTyp, triggerFn) + } + // Always wrap the expression in a barrier, or else the projections will be + // pruned and the triggers will not be executed. + return f.ConstructBarrier(triggerScope.expr) + }) } // ============================================================================