diff --git a/executor/insert_common.go b/executor/insert_common.go index 87938ff9af27b..a0e54ba244218 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -1198,6 +1198,10 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } } else { e.ctx.GetSessionVars().StmtCtx.AppendWarning(r.handleKey.dupErr) + if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic { + // lock duplicated row key on insert-ignore + txnCtx.AddUnchangedRowKey(r.handleKey.newKey) + } continue } } else if !kv.IsErrNotFound(err) { @@ -1207,12 +1211,45 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D for _, uk := range r.uniqueKeys { _, err := txn.Get(ctx, uk.newKey) if err == nil { +<<<<<<< HEAD // If duplicate keys were found in BatchGet, mark row = nil. e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) skip = true break } if !kv.IsErrNotFound(err) { +======= + if replace { + _, handle, err := tables.FetchDuplicatedHandle( + ctx, + uk.newKey, + true, + txn, + e.Table.Meta().ID, + uk.commonHandle, + ) + if err != nil { + return err + } + if handle == nil { + continue + } + _, err = e.removeRow(ctx, txn, handle, r, true) + if err != nil { + return err + } + } else { + // If duplicate keys were found in BatchGet, mark row = nil. + e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) + if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic { + // lock duplicated unique key on insert-ignore + txnCtx.AddUnchangedRowKey(uk.newKey) + } + skip = true + break + } + } else if !kv.IsErrNotFound(err) { +>>>>>>> 24cd54c7462 (executor: lock duplicated keys on insert-ignore & replace-nothing (#42210)) return err } } @@ -1257,7 +1294,18 @@ func (e *InsertValues) removeRow(ctx context.Context, txn kv.Transaction, r toBe return err } if identical { +<<<<<<< HEAD return nil +======= + if inReplace { + e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1) + } + _, err := appendUnchangedRowForLock(e.ctx, r.t, handle, oldRow) + if err != nil { + return false, err + } + return true, nil +>>>>>>> 24cd54c7462 (executor: lock duplicated keys on insert-ignore & replace-nothing (#42210)) } err = r.t.RemoveRecord(e.ctx, handle, oldRow) diff --git a/executor/insert_test.go b/executor/insert_test.go index acd8d6736b219..d6e4b08f93252 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -1509,3 +1509,69 @@ func TestIssue32213(t *testing.T) { tk.MustQuery("select cast(test.t1.c1 as decimal(5, 3)) from test.t1").Check(testkit.Rows("99.999")) tk.MustQuery("select cast(test.t1.c1 as decimal(6, 3)) from test.t1").Check(testkit.Rows("100.000")) } + +func TestInsertLock(t *testing.T) { + store := testkit.CreateMockStore(t) + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk2.MustExec("use test") + + for _, tt := range []struct { + name string + ddl string + dml string + }{ + { + "replace-pk", + "create table t (c int primary key clustered)", + "replace into t values (1)", + }, + { + "replace-uk", + "create table t (c int unique key)", + "replace into t values (1)", + }, + { + "insert-ingore-pk", + "create table t (c int primary key clustered)", + "insert ignore into t values (1)", + }, + { + "insert-ingore-uk", + "create table t (c int unique key)", + "insert ignore into t values (1)", + }, + { + "insert-update-pk", + "create table t (c int primary key clustered)", + "insert into t values (1) on duplicate key update c = values(c)", + }, + { + "insert-update-uk", + "create table t (c int unique key)", + "insert into t values (1) on duplicate key update c = values(c)", + }, + } { + t.Run(tt.name, func(t *testing.T) { + tk1.MustExec("drop table if exists t") + tk1.MustExec(tt.ddl) + tk1.MustExec("insert into t values (1)") + tk1.MustExec("begin") + tk1.MustExec(tt.dml) + done := make(chan struct{}) + go func() { + tk2.MustExec("delete from t") + done <- struct{}{} + }() + select { + case <-done: + require.Failf(t, "txn2 is not blocked by %q", tt.dml) + case <-time.After(100 * time.Millisecond): + } + tk1.MustExec("commit") + <-done + tk1.MustQuery("select * from t").Check([][]interface{}{}) + }) + } +} diff --git a/executor/write.go b/executor/write.go index 0a3fbfda4795e..d333f5564a594 100644 --- a/executor/write.go +++ b/executor/write.go @@ -137,22 +137,8 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old if sctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 { sc.AddAffectedRows(1) } - - physicalID := t.Meta().ID - if pt, ok := t.(table.PartitionedTable); ok { - p, err := pt.GetPartitionByRow(sctx, oldData) - if err != nil { - return false, err - } - physicalID = p.GetPhysicalID() - } - - unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h) - txnCtx := sctx.GetSessionVars().TxnCtx - if txnCtx.IsPessimistic { - txnCtx.AddUnchangedRowKey(unchangedRowKey) - } - return false, nil + _, err := appendUnchangedRowForLock(sctx, t, h, oldData) + return false, err } // Fill values into on-update-now fields, only if they are really changed. @@ -229,6 +215,24 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old return true, nil } +func appendUnchangedRowForLock(sctx sessionctx.Context, t table.Table, h kv.Handle, row []types.Datum) (bool, error) { + txnCtx := sctx.GetSessionVars().TxnCtx + if !txnCtx.IsPessimistic { + return false, nil + } + physicalID := t.Meta().ID + if pt, ok := t.(table.PartitionedTable); ok { + p, err := pt.GetPartitionByRow(sctx, row) + if err != nil { + return false, err + } + physicalID = p.GetPhysicalID() + } + unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h) + txnCtx.AddUnchangedRowKey(unchangedRowKey) + return true, nil +} + func rebaseAutoRandomValue(ctx context.Context, sctx sessionctx.Context, t table.Table, newData *types.Datum, col *table.Column) error { tableInfo := t.Meta() if !tableInfo.ContainsAutoRandomBits() { diff --git a/tests/realtikvtest/pessimistictest/pessimistic_test.go b/tests/realtikvtest/pessimistictest/pessimistic_test.go index a73a36de109cf..e4ff3d7c5a473 100644 --- a/tests/realtikvtest/pessimistictest/pessimistic_test.go +++ b/tests/realtikvtest/pessimistictest/pessimistic_test.go @@ -539,7 +539,7 @@ func TestOptimisticConflicts(t *testing.T) { tk.MustExec("begin pessimistic") // This SQL use BatchGet and cache data in the txn snapshot. // It can be changed to other SQLs that use BatchGet. - tk.MustExec("insert ignore into conflict values (1, 2)") + tk.MustExec("select * from conflict where id in (1, 2, 3)") tk2.MustExec("update conflict set c = c - 1")