diff --git a/executor/insert_common.go b/executor/insert_common.go index dbd4a5ae264cd..751f83c071eda 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -1166,6 +1166,10 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } } else { e.ctx.GetSessionVars().StmtCtx.AppendWarning(r.handleKey.dupErr) + if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic { + // lock duplicated row key on insert-ignore + txnCtx.AddUnchangedRowKey(r.handleKey.newKey) + } continue } } else if !kv.IsErrNotFound(err) { @@ -1177,6 +1181,10 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D if err == nil { // If duplicate keys were found in BatchGet, mark row = nil. e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) + if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic { + // lock duplicated unique key on insert-ignore + txnCtx.AddUnchangedRowKey(uk.newKey) + } skip = true break } @@ -1225,6 +1233,10 @@ func (e *InsertValues) removeRow(ctx context.Context, txn kv.Transaction, r toBe return err } if identical { + _, err := appendUnchangedRowForLock(e.ctx, r.t, handle, oldRow) + if err != nil { + return err + } return nil } diff --git a/executor/insert_test.go b/executor/insert_test.go index b55c3a63765e3..293b888629846 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -1481,3 +1481,69 @@ func TestIssue32213(t *testing.T) { tk.MustQuery("select cast(test.t1.c1 as decimal(5, 3)) from test.t1").Check(testkit.Rows("99.999")) tk.MustQuery("select cast(test.t1.c1 as decimal(6, 3)) from test.t1").Check(testkit.Rows("100.000")) } + +func TestInsertLock(t *testing.T) { + store := testkit.CreateMockStore(t) + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk2.MustExec("use test") + + for _, tt := range []struct { + name string + ddl string + dml string + }{ + { + "replace-pk", + "create table t (c int primary key clustered)", + "replace into t values (1)", + }, + { + "replace-uk", + "create table t (c int unique key)", + "replace into t values (1)", + }, + { + "insert-ingore-pk", + "create table t (c int primary key clustered)", + "insert ignore into t values (1)", + }, + { + "insert-ingore-uk", + "create table t (c int unique key)", + "insert ignore into t values (1)", + }, + { + "insert-update-pk", + "create table t (c int primary key clustered)", + "insert into t values (1) on duplicate key update c = values(c)", + }, + { + "insert-update-uk", + "create table t (c int unique key)", + "insert into t values (1) on duplicate key update c = values(c)", + }, + } { + t.Run(tt.name, func(t *testing.T) { + tk1.MustExec("drop table if exists t") + tk1.MustExec(tt.ddl) + tk1.MustExec("insert into t values (1)") + tk1.MustExec("begin") + tk1.MustExec(tt.dml) + done := make(chan struct{}) + go func() { + tk2.MustExec("delete from t") + done <- struct{}{} + }() + select { + case <-done: + require.Failf(t, "txn2 is not blocked by %q", tt.dml) + case <-time.After(100 * time.Millisecond): + } + tk1.MustExec("commit") + <-done + tk1.MustQuery("select * from t").Check([][]interface{}{}) + }) + } +} diff --git a/executor/replace.go b/executor/replace.go index 4093b0773db8d..c81b36d6e0abc 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -86,6 +86,10 @@ func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle } if rowUnchanged { e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1) + _, err := appendUnchangedRowForLock(e.ctx, r.t, handle, oldRow) + if err != nil { + return false, err + } return true, nil } diff --git a/executor/write.go b/executor/write.go index e32ff1770a77f..6277cd1d09941 100644 --- a/executor/write.go +++ b/executor/write.go @@ -139,22 +139,8 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old if sctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 { sc.AddAffectedRows(1) } - - physicalID := t.Meta().ID - if pt, ok := t.(table.PartitionedTable); ok { - p, err := pt.GetPartitionByRow(sctx, oldData) - if err != nil { - return false, err - } - physicalID = p.GetPhysicalID() - } - - unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h) - txnCtx := sctx.GetSessionVars().TxnCtx - if txnCtx.IsPessimistic { - txnCtx.AddUnchangedRowKey(unchangedRowKey) - } - return false, nil + _, err := appendUnchangedRowForLock(sctx, t, h, oldData) + return false, err } // Fill values into on-update-now fields, only if they are really changed. @@ -231,6 +217,24 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old return true, nil } +func appendUnchangedRowForLock(sctx sessionctx.Context, t table.Table, h kv.Handle, row []types.Datum) (bool, error) { + txnCtx := sctx.GetSessionVars().TxnCtx + if !txnCtx.IsPessimistic { + return false, nil + } + physicalID := t.Meta().ID + if pt, ok := t.(table.PartitionedTable); ok { + p, err := pt.GetPartitionByRow(sctx, row) + if err != nil { + return false, err + } + physicalID = p.GetPhysicalID() + } + unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h) + txnCtx.AddUnchangedRowKey(unchangedRowKey) + return true, nil +} + func rebaseAutoRandomValue(ctx context.Context, sctx sessionctx.Context, t table.Table, newData *types.Datum, col *table.Column) error { tableInfo := t.Meta() if !tableInfo.ContainsAutoRandomBits() { diff --git a/tests/realtikvtest/pessimistictest/pessimistic_test.go b/tests/realtikvtest/pessimistictest/pessimistic_test.go index ae7545e0e91f6..2b04c9a391d0d 100644 --- a/tests/realtikvtest/pessimistictest/pessimistic_test.go +++ b/tests/realtikvtest/pessimistictest/pessimistic_test.go @@ -534,7 +534,7 @@ func TestOptimisticConflicts(t *testing.T) { tk.MustExec("begin pessimistic") // This SQL use BatchGet and cache data in the txn snapshot. // It can be changed to other SQLs that use BatchGet. - tk.MustExec("insert ignore into conflict values (1, 2)") + tk.MustExec("select * from conflict where id in (1, 2, 3)") tk2.MustExec("update conflict set c = c - 1")