diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteTransaction.cs b/src/Microsoft.Data.Sqlite.Core/SqliteTransaction.cs index f1db5bdf2cc..6a4b4faa76e 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteTransaction.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteTransaction.cs @@ -83,9 +83,15 @@ public override void Commit() throw new InvalidOperationException(Resources.TransactionCompleted); } - sqlite3_rollback_hook(_connection.Handle, null, null); - _connection.ExecuteNonQuery("COMMIT;"); - Complete(); + try + { + sqlite3_rollback_hook(_connection.Handle, null, null); + _connection.ExecuteNonQuery("COMMIT;"); + } + finally + { + Complete(); + } } /// @@ -213,7 +219,14 @@ private void Complete() { if (IsolationLevel == IsolationLevel.ReadUncommitted) { - _connection!.ExecuteNonQuery("PRAGMA read_uncommitted = 0;"); + try + { + _connection!.ExecuteNonQuery("PRAGMA read_uncommitted = 0;"); + } + catch + { + // Ignore failure attempting to clean up. + } } _connection!.Transaction = null; @@ -223,13 +236,19 @@ private void Complete() private void RollbackInternal() { - if (!ExternalRollback) + try + { + if (!ExternalRollback) + { + sqlite3_rollback_hook(_connection!.Handle, null, null); + _connection.ExecuteNonQuery("ROLLBACK;"); + } + } + finally { - sqlite3_rollback_hook(_connection!.Handle, null, null); - _connection.ExecuteNonQuery("ROLLBACK;"); + Complete(); } - Complete(); } private void RollbackExternal(object userData) diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteTransactionTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteTransactionTest.cs index 8236ff5d5f6..7289a44b93f 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteTransactionTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteTransactionTest.cs @@ -3,6 +3,8 @@ using System; using System.Data; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; using Microsoft.Data.Sqlite.Properties; using Xunit; using static SQLitePCL.raw; @@ -11,6 +13,144 @@ namespace Microsoft.Data.Sqlite; public class SqliteTransactionTest { + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SqliteTransaction_Dispose_does_not_leave_orphaned_transaction(bool async) // Issue #25119 + { + using var connection = new FakeConnection("Data Source=:memory:"); + + if (async) + { + await connection.OpenAsync(); + } + else + { + connection.Open(); + } + +#if NET5_0_OR_GREATER + using var transaction = async ? await connection.BeginTransactionAsync() : connection.BeginTransaction(); +#else + using var transaction = connection.BeginTransaction(); +#endif + + await AddNewTable("Table1"); + + connection.SimulateFailureOnRollback = true; + + try + { +#if NET5_0_OR_GREATER + if (async) + { + await transaction.DisposeAsync(); + } + else + { + transaction.Dispose(); + } +#else + transaction.Dispose(); +#endif + + Assert.Fail(); + } + catch (Exception) + { + // Expected to throw. + } + + Assert.Null(connection.Transaction); + + connection.SimulateFailureOnRollback = false; + +#if NET5_0_OR_GREATER + using var transaction2 = async ? await connection.BeginTransactionAsync() : connection.BeginTransaction(); +#else + using var transaction2 = connection.BeginTransaction(); +#endif + + await AddNewTable("Table2"); + +#if NET5_0_OR_GREATER + if (async) + { + await transaction2.DisposeAsync(); + } + else + { + transaction2.Dispose(); + } +#else + transaction2.Dispose(); +#endif + + Assert.Null(connection.Transaction); + + async Task AddNewTable(string tableName) + { + using var command = connection.CreateCommand(); + command.CommandText = $"CREATE TABLE {tableName} (ID INT PRIMARY KEY NOT NULL);"; + _ = async ? await command.ExecuteNonQueryAsync() : command.ExecuteNonQuery(); + } + } + + private class FakeCommand : SqliteCommand + { + private readonly FakeConnection _connection; + private readonly SqliteCommand _realCommand; + + public FakeCommand(FakeConnection connection, SqliteCommand realCommand) + { + _connection = connection; + _realCommand = realCommand; + } + + public override int ExecuteNonQuery() + { + var result = _realCommand.ExecuteNonQuery(); + + if (_connection.SimulateFailureOnRollback && CommandText.Contains("ROLLBACK")) + { + throw new SqliteException("Simulated failure", 1); + } + + return result; + } + + [AllowNull] + public override string CommandText { get => _realCommand.CommandText; set => _realCommand.CommandText = value; } + public override int CommandTimeout { get => _realCommand.CommandTimeout; set => _realCommand.CommandTimeout = value; } + public override CommandType CommandType { get => _realCommand.CommandType; set => _realCommand.CommandType = value; } + public override bool DesignTimeVisible { get => _realCommand.DesignTimeVisible; set => _realCommand.DesignTimeVisible = value; } + + public override UpdateRowSource UpdatedRowSource + { + get => _realCommand.UpdatedRowSource; + set => _realCommand.UpdatedRowSource = value; + } + + public override void Cancel() + => _realCommand.Cancel(); + + public override object? ExecuteScalar() + => _realCommand.ExecuteScalar(); + + public override void Prepare() + => _realCommand.Prepare(); + } + + private class FakeConnection(string connectionString) : SqliteConnection(connectionString) + { + public bool SimulateFailureOnRollback { get; set; } + + public override SqliteCommand CreateCommand() + => new FakeCommand(this, base.CreateCommand()); + + public new SqliteTransaction? Transaction => base.Transaction; + } + [Fact] public void Ctor_sets_read_uncommitted() {