From 9d76e6a491f287968b7971a90eb70acb0f4a4670 Mon Sep 17 00:00:00 2001 From: Davoud Eshtehari Date: Fri, 13 Oct 2023 17:27:24 +0000 Subject: [PATCH] Merged PR 4056: [4.0.4] | Fix AE enclave retry logic not working for async queries (#1988) Ports [#1988](https://github.com/dotnet/SqlClient/pull/1988) --- .../SqlColumnEncryptionEnclaveProvider.xml | 13 +- ...umnEncryptionEnclaveProvider.NetCoreApp.cs | 4 +- .../SqlColumnEncryptionEnclaveProvider.cs | 2 +- .../Microsoft/Data/SqlClient/SqlCommand.cs | 133 ++++++++++++------ .../SqlColumnEncryptionEnclaveProvider.cs | 2 +- .../Microsoft/Data/SqlClient/SqlCommand.cs | 104 +++++++++----- .../AzureAttestationBasedEnclaveProvider.cs | 4 +- .../Data/SqlClient/EnclaveDelegate.Crypto.cs | 17 ++- .../SqlClient/EnclaveDelegate.NotSupported.cs | 9 +- .../Data/SqlClient/EnclaveProviderBase.cs | 6 +- .../VirtualSecureModeEnclaveProviderBase.cs | 4 +- .../ManualTests/AlwaysEncrypted/ApiShould.cs | 86 +++++++++++ .../SystemDataInternals/CommandHelper.cs | 17 +++ 13 files changed, 292 insertions(+), 109 deletions(-) diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml index 261f5e57ed..dfbdbe8782 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml @@ -20,8 +20,8 @@ the enclave attestation protocol as well as the logic for creating and caching e The information the provider uses to attest the enclave and generate a symmetric key for the session. The format of this information is specific to the enclave attestation protocol. A Diffie-Hellman algorithm object that encapsulates a client-side key pair. The set of parameters required for an enclave session. - The set of extra data needed for attestating the enclave. - The length of the extra data needed for attestating the enclave. + The set of extra data needed for attesting the enclave. + The length of the extra data needed for attesting the enclave. The requested enclave session or if the provider doesn't implement session caching. A counter that the enclave provider is expected to increment each time SqlClient retrieves the session from the cache. The purpose of this field is to prevent replay attacks. When overridden in a derived class, performs enclave attestation, generates a symmetric key for the session, creates a an enclave session and stores the session information in the cache. @@ -29,8 +29,8 @@ the enclave attestation protocol as well as the logic for creating and caching e The endpoint of an attestation service for attesting the enclave. - A set of extra data needed for attestating the enclave. - The length of the extra data needed for attestating the enclave. + A set of extra data needed for attesting the enclave. + The length of the extra data needed for attesting the enclave. Gets the information that SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave. The information SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave. To be added. @@ -38,10 +38,11 @@ the enclave attestation protocol as well as the logic for creating and caching e The set of parameters required for enclave session. to indicate that a set of extra data needs to be generated for attestation; otherwise, . + Indicates if this is a retry from a failed call. When this method returns, the requested enclave session or if the provider doesn't implement session caching. This parameter is treated as uninitialized. A counter that the enclave provider is expected to increment each time SqlClient retrieves the session from the cache. The purpose of this field is to prevent replay attacks. - A set of extra data needed for attestating the enclave. - The length of the extra data needed for attestating the enclave. + A set of extra data needed for attesting the enclave. + The length of the extra data needed for attesting the enclave. When overridden in a derived class, looks up an existing enclave session information in the enclave session cache. If the enclave provider doesn't implement enclave session caching, this method is expected to return in the parameter. To be added. diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.NetCoreApp.cs index cca04fc323..fd81db557d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.NetCoreApp.cs @@ -15,8 +15,8 @@ internal abstract partial class SqlColumnEncryptionEnclaveProvider /// The information the provider uses to attest the enclave and generate a symmetric key for the session. The format of this information is specific to the enclave attestation protocol. /// A Diffie-Hellman algorithm object encapsulating a client-side key pair. /// The set of parameters required for enclave session. - /// The set of extra data needed for attestating the enclave. - /// The length of the extra data needed for attestating the enclave. + /// The set of extra data needed for attesting the enclave. + /// The length of the extra data needed for attesting the enclave. /// The requested enclave session or null if the provider does not implement session caching. /// A counter that the enclave provider is expected to increment each time SqlClient retrieves the session from the cache. The purpose of this field is to prevent replay attacks. internal abstract void CreateEnclaveSession(byte[] enclaveAttestationInfo, ECDiffieHellman clientDiffieHellmanKey, EnclaveSessionParameters enclaveSessionParameters, byte[] customData, int customDataLength, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs index e374d43665..9223702509 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs @@ -8,7 +8,7 @@ namespace Microsoft.Data.SqlClient internal abstract partial class SqlColumnEncryptionEnclaveProvider { /// - internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength); + internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength); /// internal abstract SqlEnclaveAttestationParameters GetAttestationParameters(string attestationUrl, byte[] customData, int customDataLength); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 8fc2d5735b..1080b0c7c2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -112,6 +112,11 @@ protected override void AfterCleared(SqlCommand owner) /// Internal flag for testing purposes that forces all queries to internally end async calls. /// private static bool _forceInternalEndQuery = false; + + /// + /// Internal flag for testing purposes that forces one RetryableEnclaveQueryExecutionException during GenerateEnclavePackage + /// + private static bool _forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage = false; #endif private static readonly SqlDiagnosticListener _diagnosticListener = new SqlDiagnosticListener(SqlClientDiagnosticListenerExtensions.DiagnosticListenerName); @@ -2198,7 +2203,7 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC // back into pool when we should not. } - bool usedCache; + bool usedCache = false; Task writeTask = null; try { @@ -2215,7 +2220,10 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC // For async, RunExecuteReader will never put the stateObj back into the pool, so do so now. ReliablePutStateObject(); - throw; + if (inRetry || e is not EnclaveDelegate.RetryableEnclaveQueryExecutionException) + { + throw; + } } if (writeTask != null) @@ -2416,12 +2424,7 @@ long firstAttemptStart // Remove the entry from the cache since it was inconsistent. SqlQueryMetadataCache.GetInstance().InvalidateCacheEntry(this); - if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) - { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.InvalidateEnclaveSession(this._activeConnection.AttestationProtocol, this._activeConnection.Parser.EnclaveType, - enclaveSessionParameters, this.enclavePackage.EnclaveSession); - } + InvalidateEnclaveSession(); try { @@ -2457,6 +2460,26 @@ long firstAttemptStart }, TaskScheduler.Default); } + private void InvalidateEnclaveSession() + { + if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) + { + EnclaveDelegate.Instance.InvalidateEnclaveSession( + this._activeConnection.AttestationProtocol, + this._activeConnection.Parser.EnclaveType, + GetEnclaveSessionParameters(), + this.enclavePackage.EnclaveSession); + } + } + + private EnclaveSessionParameters GetEnclaveSessionParameters() + { + return new EnclaveSessionParameters( + this._activeConnection.DataSource, + this._activeConnection.EnclaveAttestationUrl, + this._activeConnection.Database); + } + private void BeginExecuteReaderInternalReadStage(TaskCompletionSource completion) { Debug.Assert(completion != null, "CompletionSource should not be null"); @@ -3610,7 +3633,13 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r try { // Fetch the encryption information that applies to any of the input parameters. - describeParameterEncryptionDataReader = TryFetchInputParameterEncryptionInfo(timeout, isAsync, asyncWrite, out describeParameterEncryptionNeeded, out fetchInputParameterEncryptionInfoTask, out describeParameterEncryptionRpcOriginalRpcMap); + describeParameterEncryptionDataReader = TryFetchInputParameterEncryptionInfo(timeout, + isAsync, + asyncWrite, + out describeParameterEncryptionNeeded, + out fetchInputParameterEncryptionInfoTask, + out describeParameterEncryptionRpcOriginalRpcMap, + inRetry); Debug.Assert(describeParameterEncryptionNeeded || describeParameterEncryptionDataReader == null, "describeParameterEncryptionDataReader should be null if we don't need to request describe parameter encryption request."); @@ -3645,7 +3674,13 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r // Mark that we should not process the finally block since we have async execution pending. // Note that this should be done outside the task's continuation delegate. processFinallyBlock = false; - describeParameterEncryptionDataReader = GetParameterEncryptionDataReader(out returnTask, fetchInputParameterEncryptionInfoTask, describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, describeParameterEncryptionNeeded); + describeParameterEncryptionDataReader = GetParameterEncryptionDataReader( + out returnTask, + fetchInputParameterEncryptionInfoTask, + describeParameterEncryptionDataReader, + describeParameterEncryptionRpcOriginalRpcMap, + describeParameterEncryptionNeeded, + inRetry); decrementAsyncCountInFinallyBlock = false; } @@ -3657,14 +3692,22 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r // Mark that we should not process the finally block since we have async execution pending. // Note that this should be done outside the task's continuation delegate. processFinallyBlock = false; - describeParameterEncryptionDataReader = GetParameterEncryptionDataReaderAsync(out returnTask, describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, describeParameterEncryptionNeeded); + describeParameterEncryptionDataReader = GetParameterEncryptionDataReaderAsync( + out returnTask, + describeParameterEncryptionDataReader, + describeParameterEncryptionRpcOriginalRpcMap, + describeParameterEncryptionNeeded, + inRetry); decrementAsyncCountInFinallyBlock = false; } else { // For synchronous execution, read the results of describe parameter encryption here. - ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap); + ReadDescribeEncryptionParameterResults( + describeParameterEncryptionDataReader, + describeParameterEncryptionRpcOriginalRpcMap, + inRetry); } #if DEBUG @@ -3711,7 +3754,7 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task fetchInputParameterEncryptionInfoTask, SqlDataReader describeParameterEncryptionDataReader, - ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool describeParameterEncryptionNeeded) + ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool describeParameterEncryptionNeeded, bool inRetry) { returnTask = AsyncHelper.CreateContinuationTaskWithState(fetchInputParameterEncryptionInfoTask, this, (object state) => @@ -3740,7 +3783,7 @@ private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task Debug.Assert(null == command._stateObj, "non-null state object in PrepareForTransparentEncryption."); // Read the results of describe parameter encryption. - command.ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap); + command.ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, inRetry); #if DEBUG // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. @@ -3785,7 +3828,7 @@ private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task private SqlDataReader GetParameterEncryptionDataReaderAsync(out Task returnTask, SqlDataReader describeParameterEncryptionDataReader, - ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool describeParameterEncryptionNeeded) + ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool describeParameterEncryptionNeeded, bool inRetry) { returnTask = Task.Run(() => { @@ -3813,7 +3856,7 @@ private SqlDataReader GetParameterEncryptionDataReaderAsync(out Task returnTask, // Read the results of describe parameter encryption. ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, - describeParameterEncryptionRpcOriginalRpcMap); + describeParameterEncryptionRpcOriginalRpcMap, inRetry); #if DEBUG // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. if (_sleepAfterReadDescribeEncryptionParameterResults) @@ -3850,13 +3893,15 @@ private SqlDataReader GetParameterEncryptionDataReaderAsync(out Task returnTask, /// /// /// + /// Indicates if this is a retry from a failed call. /// private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, bool isAsync, bool asyncWrite, out bool inputParameterEncryptionNeeded, out Task task, - out ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap) + out ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, + bool isRetry) { inputParameterEncryptionNeeded = false; task = null; @@ -3868,10 +3913,10 @@ private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, SqlConnectionAttestationProtocol attestationProtocol = this._activeConnection.AttestationProtocol; string enclaveType = this._activeConnection.Parser.EnclaveType; - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); + EnclaveSessionParameters enclaveSessionParameters = GetEnclaveSessionParameters(); SqlEnclaveSession sqlEnclaveSession = null; - EnclaveDelegate.Instance.GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, true, out sqlEnclaveSession, out customData, out customDataLength); + EnclaveDelegate.Instance.GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, true, isRetry, out sqlEnclaveSession, out customData, out customDataLength); if (sqlEnclaveSession == null) { enclaveAttestationParameters = EnclaveDelegate.Instance.GetAttestationParameters(attestationProtocol, enclaveType, enclaveSessionParameters.AttestationUrl, customData, customDataLength); @@ -4113,7 +4158,8 @@ private void PrepareDescribeParameterEncryptionRequest(_SqlRPC originalRpcReques /// /// Resultset from calling to sp_describe_parameter_encryption /// Readonly dictionary with the map of parameter encryption rpc requests with the corresponding original rpc requests. - private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap) + /// Indicates if this is a retry from a failed call. + private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool isRetry) { _SqlRPC rpc = null; int currentOrdinal = -1; @@ -4392,9 +4438,16 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi SqlConnectionAttestationProtocol attestationProtocol = this._activeConnection.AttestationProtocol; string enclaveType = this._activeConnection.Parser.EnclaveType; - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.CreateEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, attestationInfo, enclaveAttestationParameters, customData, customDataLength); + EnclaveDelegate.Instance.CreateEnclaveSession( + attestationProtocol, + enclaveType, + GetEnclaveSessionParameters(), + attestationInfo, + enclaveAttestationParameters, + customData, + customDataLength, + isRetry); enclaveAttestationParameters = null; attestationInfoRead = true; } @@ -4495,7 +4548,7 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior catch (EnclaveDelegate.RetryableEnclaveQueryExecutionException) { - if (inRetry || isAsync) + if (inRetry) { throw; } @@ -4504,22 +4557,16 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior // First invalidate the entry from the cache, so that we refresh our encryption MD. SqlQueryMetadataCache.GetInstance().InvalidateCacheEntry(this); - if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) - { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.InvalidateEnclaveSession(this._activeConnection.AttestationProtocol, this._activeConnection.Parser.EnclaveType, - enclaveSessionParameters, this.enclavePackage.EnclaveSession); - } + InvalidateEnclaveSession(); - return RunExecuteReader(cmdBehavior, runBehavior, returnStream, null, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, isAsync, inRetry: true, method: method); + return RunExecuteReader(cmdBehavior, runBehavior, returnStream, completion, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, isAsync, inRetry: true, method: method); } catch (SqlException ex) { // We only want to retry once, so don't retry if we are already in retry. // If we didn't use the cache, we don't want to retry. - // The async retried are handled separately, handle only sync calls here. - if (inRetry || isAsync || (!usedCache && !ShouldUseEnclaveBasedWorkflow)) + if (inRetry || (!usedCache && !ShouldUseEnclaveBasedWorkflow)) { throw; } @@ -4548,14 +4595,9 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior // First invalidate the entry from the cache, so that we refresh our encryption MD. SqlQueryMetadataCache.GetInstance().InvalidateCacheEntry(this); - if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) - { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.InvalidateEnclaveSession(this._activeConnection.AttestationProtocol, this._activeConnection.Parser.EnclaveType, - enclaveSessionParameters, this.enclavePackage.EnclaveSession); - } + InvalidateEnclaveSession(); - return RunExecuteReader(cmdBehavior, runBehavior, returnStream, null, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, isAsync, inRetry: true, method: method); + return RunExecuteReader(cmdBehavior, runBehavior, returnStream, completion, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, isAsync, inRetry: true, method: method); } } } @@ -4652,9 +4694,15 @@ private void GenerateEnclavePackage() try { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); +#if DEBUG + if (_forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage) + { + _forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage = false; + throw new EnclaveDelegate.RetryableEnclaveQueryExecutionException("testing", null); + } +#endif this.enclavePackage = EnclaveDelegate.Instance.GenerateEnclavePackage(attestationProtocol, keysToBeSentToEnclave, - this.CommandText, enclaveType, enclaveSessionParameters, _activeConnection, this); + this.CommandText, enclaveType, GetEnclaveSessionParameters(), _activeConnection, this); } catch (EnclaveDelegate.RetryableEnclaveQueryExecutionException) { @@ -4712,8 +4760,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi bool processFinallyBlock = true; bool decrementAsyncCountOnFailure = false; - // If we are in retry, don't increment the Async count. This should have already been set. - if (isAsync && !inRetry) + if (isAsync) { _activeConnection.GetOpenTdsConnection().IncrementAsyncCount(); decrementAsyncCountOnFailure = true; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs index 9017b84717..de26b79232 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlColumnEncryptionEnclaveProvider.cs @@ -11,7 +11,7 @@ internal abstract class SqlColumnEncryptionEnclaveProvider { /// - internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength); + internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength); /// internal abstract SqlEnclaveAttestationParameters GetAttestationParameters(string attestationUrl, byte[] customData, int customDataLength); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index 878bdc2c7b..bfb51ac053 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -83,6 +83,11 @@ public sealed class SqlCommand : DbCommand, ICloneable /// Internal flag for testing purposes that forces all queries to internally end async calls. /// private static bool _forceInternalEndQuery = false; + + /// + /// Internal flag for testing purposes that forces one RetryableEnclaveQueryExecutionException during GenerateEnclavePackage + /// + private static bool _forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage = false; #endif internal static readonly Action s_cancelIgnoreFailure = CancelIgnoreFailureCallback; @@ -2650,7 +2655,7 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC // back into pool when we should not. } - bool usedCache; + bool usedCache = false; Task writeTask = null; try { @@ -2783,11 +2788,7 @@ private bool TriggerInternalEndAndRetryIfNecessary(CommandBehavior behavior, obj // Remove the enrty from the cache since it was inconsistent. SqlQueryMetadataCache.GetInstance().InvalidateCacheEntry(this); - if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) - { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.InvalidateEnclaveSession(this._activeConnection.AttestationProtocol, this._activeConnection.Parser.EnclaveType, enclaveSessionParameters, this.enclavePackage.EnclaveSession); - } + InvalidateEnclaveSession(); try { @@ -2828,6 +2829,26 @@ private bool TriggerInternalEndAndRetryIfNecessary(CommandBehavior behavior, obj } } + private void InvalidateEnclaveSession() + { + if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) + { + EnclaveDelegate.Instance.InvalidateEnclaveSession( + this._activeConnection.AttestationProtocol, + this._activeConnection.Parser.EnclaveType, + GetEnclaveSessionParameters(), + this.enclavePackage.EnclaveSession); + } + } + + private EnclaveSessionParameters GetEnclaveSessionParameters() + { + return new EnclaveSessionParameters( + this._activeConnection.DataSource, + this._activeConnection.EnclaveAttestationUrl, + this._activeConnection.Database); + } + private void BeginExecuteReaderInternalReadStage(TaskCompletionSource completion) { Debug.Assert(completion != null, "CompletionSource should not be null"); @@ -4144,7 +4165,8 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r asyncWrite, out describeParameterEncryptionNeeded, out fetchInputParameterEncryptionInfoTask, - out describeParameterEncryptionRpcOriginalRpcMap); + out describeParameterEncryptionRpcOriginalRpcMap, + inRetry); Debug.Assert(describeParameterEncryptionNeeded || describeParameterEncryptionDataReader == null, "describeParameterEncryptionDataReader should be null if we don't need to request describe parameter encryption request."); @@ -4211,7 +4233,10 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r Debug.Assert(null == _stateObj, "non-null state object in PrepareForTransparentEncryption."); // Read the results of describe parameter encryption. - ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap); + ReadDescribeEncryptionParameterResults( + describeParameterEncryptionDataReader, + describeParameterEncryptionRpcOriginalRpcMap, + inRetry); #if DEBUG // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. @@ -4296,7 +4321,7 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r Debug.Assert(null == _stateObj, "non-null state object in PrepareForTransparentEncryption."); // Read the results of describe parameter encryption. - ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap); + ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, inRetry); #if DEBUG // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. if (_sleepAfterReadDescribeEncryptionParameterResults) @@ -4333,7 +4358,7 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r else { // For synchronous execution, read the results of describe parameter encryption here. - ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap); + ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, inRetry); } #if DEBUG @@ -4411,13 +4436,15 @@ private void PrepareForTransparentEncryption(CommandBehavior cmdBehavior, bool r /// /// /// + /// Indicates if this is a retry from a failed call. /// private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, bool async, bool asyncWrite, out bool inputParameterEncryptionNeeded, out Task task, - out ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap) + out ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, + bool inRetry) { inputParameterEncryptionNeeded = false; task = null; @@ -4429,10 +4456,10 @@ private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, SqlConnectionAttestationProtocol attestationProtocol = this._activeConnection.AttestationProtocol; string enclaveType = this._activeConnection.Parser.EnclaveType; - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); + EnclaveSessionParameters enclaveSessionParameters = GetEnclaveSessionParameters(); SqlEnclaveSession sqlEnclaveSession = null; - EnclaveDelegate.Instance.GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, true, out sqlEnclaveSession, out customData, out customDataLength); + EnclaveDelegate.Instance.GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, true, inRetry, out sqlEnclaveSession, out customData, out customDataLength); if (sqlEnclaveSession == null) { enclaveAttestationParameters = EnclaveDelegate.Instance.GetAttestationParameters(attestationProtocol, enclaveType, enclaveSessionParameters.AttestationUrl, customData, customDataLength); @@ -4682,7 +4709,8 @@ private void PrepareDescribeParameterEncryptionRequest(_SqlRPC originalRpcReques /// /// Resultset from calling to sp_describe_parameter_encryption /// Readonly dictionary with the map of parameter encryption rpc requests with the corresponding original rpc requests. - private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap) + /// Indicates if this is a retry from a failed call. + private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool inRetry) { _SqlRPC rpc = null; int currentOrdinal = -1; @@ -4956,9 +4984,16 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi SqlConnectionAttestationProtocol attestationProtocol = this._activeConnection.AttestationProtocol; string enclaveType = this._activeConnection.Parser.EnclaveType; - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.CreateEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, attestationInfo, enclaveAttestationParameters, customData, customDataLength); + EnclaveDelegate.Instance.CreateEnclaveSession( + attestationProtocol, + enclaveType, + GetEnclaveSessionParameters(), + attestationInfo, + enclaveAttestationParameters, + customData, + customDataLength, + inRetry); enclaveAttestationParameters = null; attestationInfoRead = true; } @@ -5080,7 +5115,7 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior catch (EnclaveDelegate.RetryableEnclaveQueryExecutionException) { - if (inRetry || async) + if (inRetry) { throw; } @@ -5089,22 +5124,16 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior // First invalidate the entry from the cache, so that we refresh our encryption MD. SqlQueryMetadataCache.GetInstance().InvalidateCacheEntry(this); - if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) - { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.InvalidateEnclaveSession(this._activeConnection.AttestationProtocol, this._activeConnection.Parser.EnclaveType, - enclaveSessionParameters, this.enclavePackage.EnclaveSession); - } + InvalidateEnclaveSession(); - return RunExecuteReader(cmdBehavior, runBehavior, returnStream, method, null, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, async, inRetry: true); + return RunExecuteReader(cmdBehavior, runBehavior, returnStream, method, completion, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, async, inRetry: true); } catch (SqlException ex) { // We only want to retry once, so don't retry if we are already in retry. // If we didn't use the cache, we don't want to retry. - // The async retried are handled separately, handle only sync calls here. - if (inRetry || async || (!usedCache && !ShouldUseEnclaveBasedWorkflow)) + if (inRetry || (!usedCache && !ShouldUseEnclaveBasedWorkflow)) { throw; } @@ -5133,14 +5162,9 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior // First invalidate the entry from the cache, so that we refresh our encryption MD. SqlQueryMetadataCache.GetInstance().InvalidateCacheEntry(this); - if (ShouldUseEnclaveBasedWorkflow && this.enclavePackage != null) - { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); - EnclaveDelegate.Instance.InvalidateEnclaveSession(this._activeConnection.AttestationProtocol, this._activeConnection.Parser.EnclaveType, - enclaveSessionParameters, this.enclavePackage.EnclaveSession); - } + InvalidateEnclaveSession(); - return RunExecuteReader(cmdBehavior, runBehavior, returnStream, method, null, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, async, inRetry: true); + return RunExecuteReader(cmdBehavior, runBehavior, returnStream, method, completion, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), out task, out usedCache, async, inRetry: true); } } } @@ -5148,7 +5172,6 @@ internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior { return RunExecuteReaderTds(cmdBehavior, runBehavior, returnStream, async, timeout, out task, asyncWrite && async, inRetry: inRetry); } - } #if DEBUG finally @@ -5260,9 +5283,15 @@ private void GenerateEnclavePackage() try { - EnclaveSessionParameters enclaveSessionParameters = new EnclaveSessionParameters(this._activeConnection.DataSource, this._activeConnection.EnclaveAttestationUrl, this._activeConnection.Database); +#if DEBUG + if (_forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage) + { + _forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage = false; + throw new EnclaveDelegate.RetryableEnclaveQueryExecutionException("testing", null); + } +#endif this.enclavePackage = EnclaveDelegate.Instance.GenerateEnclavePackage(attestationProtocol, keysToBeSentToEnclave, - this.CommandText, enclaveType, enclaveSessionParameters, _activeConnection, this); + this.CommandText, enclaveType, GetEnclaveSessionParameters(), _activeConnection, this); } catch (EnclaveDelegate.RetryableEnclaveQueryExecutionException) { @@ -5343,8 +5372,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi bool processFinallyBlock = true; bool decrementAsyncCountOnFailure = false; - // If we are in retry, don't increment the Async count. This should have already been set. - if (async && !inRetry) + if (async) { _activeConnection.GetOpenTdsConnection().IncrementAsyncCount(); decrementAsyncCountOnFailure = true; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs index d08db25036..987d12aaa2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs @@ -65,9 +65,9 @@ internal class AzureAttestationEnclaveProvider : EnclaveProviderBase #region Internal methods // When overridden in a derived class, looks up an existing enclave session information in the enclave session cache. // If the enclave provider doesn't implement enclave session caching, this method is expected to return null in the sqlEnclaveSession parameter. - internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength) + internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength) { - GetEnclaveSessionHelper(enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out counter, out customData, out customDataLength); + GetEnclaveSessionHelper(enclaveSessionParameters, generateCustomData, isRetry, out sqlEnclaveSession, out counter, out customData, out customDataLength); } // Gets the information that SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs index bf8786ef7d..3b58689220 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs @@ -20,10 +20,11 @@ internal sealed partial class EnclaveDelegate /// The set of parameters required for enclave session. /// attestation info from SQL Server /// attestation parameters - /// A set of extra data needed for attestating the enclave. - /// The length of the extra data needed for attestating the enclave. + /// A set of extra data needed for attesting the enclave. + /// The length of the extra data needed for attesting the enclave. + /// Indicates if this is a retry from a failed call. internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, - byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength) + byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength, bool isRetry) { lock (_lock) { @@ -32,6 +33,7 @@ internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationP sqlColumnEncryptionEnclaveProvider.GetEnclaveSession( enclaveSessionParameters, generateCustomData: false, + isRetry: isRetry, sqlEnclaveSession: out SqlEnclaveSession sqlEnclaveSession, counter: out _, customData: out _, @@ -60,15 +62,15 @@ internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationP } } - internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength) + internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength) { - GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out _, out customData, out customDataLength, throwIfNull: false); + GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, generateCustomData, isRetry, out sqlEnclaveSession, out _, out customData, out customDataLength, throwIfNull: false); } - private void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool throwIfNull) + private void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool throwIfNull) { SqlColumnEncryptionEnclaveProvider sqlColumnEncryptionEnclaveProvider = GetEnclaveProvider(attestationProtocol, enclaveType); - sqlColumnEncryptionEnclaveProvider.GetEnclaveSession(enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out counter, out customData, out customDataLength); + sqlColumnEncryptionEnclaveProvider.GetEnclaveSession(enclaveSessionParameters, generateCustomData, isRetry, out sqlEnclaveSession, out counter, out customData, out customDataLength); if (throwIfNull && sqlEnclaveSession == null) { @@ -145,6 +147,7 @@ internal EnclavePackage GenerateEnclavePackage(SqlConnectionAttestationProtocol enclaveType, enclaveSessionParameters, generateCustomData: false, + isRetry: false, sqlEnclaveSession: out sqlEnclaveSession, counter: out counter, customData: out _, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.NotSupported.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.NotSupported.cs index 7b43c8cb66..461bfdf6d7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.NotSupported.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.NotSupported.cs @@ -24,15 +24,16 @@ internal byte[] GetSerializedAttestationParameters( /// The set of parameters required for enclave session. /// attestation info from SQL Server /// attestation parameters - /// A set of extra data needed for attestating the enclave. - /// The length of the extra data needed for attestating the enclave. + /// A set of extra data needed for attesting the enclave. + /// The length of the extra data needed for attesting the enclave. + /// Indicates if this is a retry from a failed call. internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, - byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength) + byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength, bool isRetry) { throw new PlatformNotSupportedException(); } - internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength) + internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength) { throw new PlatformNotSupportedException(); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveProviderBase.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveProviderBase.cs index 0c72beac9d..b8a52b9e4b 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveProviderBase.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveProviderBase.cs @@ -89,7 +89,7 @@ internal abstract class EnclaveProviderBase : SqlColumnEncryptionEnclaveProvider #region protected methods // Helper method to get the enclave session from the cache if present - protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionParameters, bool shouldGenerateNonce, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength) + protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionParameters, bool shouldGenerateNonce, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength) { customData = null; customDataLength = 0; @@ -107,7 +107,7 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa { sameThreadRetry = true; } - else + else if (!isRetry) { // We are explicitly not signalling the event here, as we want to hold the event till driver calls CreateEnclaveSession // If we signal the event now, then multiple thread end up calling GetAttestationParameters which triggers the attestation workflow. @@ -124,7 +124,7 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa // In case of multi-threaded application, first thread will set the event and all the subsequent threads will wait here either until the enclave // session is created or timeout happens. - if (sessionCacheLockTaken || sameThreadRetry) + if (sessionCacheLockTaken || sameThreadRetry || isRetry) { // While the current thread is waiting for event to be signaled and in the meanwhile we already completed the attestation on different thread // then we need to signal the event here diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs index f71047d965..0ea0fdc91d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs @@ -86,9 +86,9 @@ internal abstract class VirtualizationBasedSecurityEnclaveProviderBase : Enclave // When overridden in a derived class, looks up an existing enclave session information in the enclave session cache. // If the enclave provider doesn't implement enclave session caching, this method is expected to return null in the sqlEnclaveSession parameter. - internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength) + internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, bool isRetry, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength) { - GetEnclaveSessionHelper(enclaveSessionParameters, false, out sqlEnclaveSession, out counter, out customData, out customDataLength); + GetEnclaveSessionHelper(enclaveSessionParameters, false, isRetry, out sqlEnclaveSession, out counter, out customData, out customDataLength); } // Gets the information that SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave. diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs index 855cf590b0..e1b1094f7d 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs @@ -2372,6 +2372,92 @@ public void TestRetryWhenAEParameterMetadataCacheIsStale(string connectionString cmd.ExecuteNonQuery(); } + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringSetupForAE), nameof(DataTestUtility.EnclaveEnabled))] + [ClassData(typeof(AEConnectionStringProvider))] + public void TestRetryWhenAEEnclaveCacheIsStale(string connectionString) + { + CleanUpTable(connectionString, _tableName); + + const int customerId = 50; + IList values = GetValues(dataHint: customerId); + InsertRows(tableName: _tableName, numberofRows: 1, values: values, connection: connectionString); + + ApiTestTable table = _fixture.ApiTestTable as ApiTestTable; + string enclaveSelectQuery = $@"SELECT CustomerId, FirstName, LastName FROM [{_tableName}] WHERE CustomerId > @CustomerId"; + string alterCekQueryFormatString = "ALTER TABLE [{0}] " + + "ALTER COLUMN [CustomerId] [int] " + + "ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{1}], " + + "ENCRYPTION_TYPE = Randomized, " + + "ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256'); " + + "ALTER DATABASE SCOPED CONFIGURATION CLEAR PROCEDURE_CACHE;"; + + using SqlConnection sqlConnection = new(connectionString); + sqlConnection.Open(); + + // change the CEK and encryption type to randomized for the CustomerId column to ensure enclaves are used + using SqlCommand cmd = new SqlCommand( + string.Format(alterCekQueryFormatString, _tableName, table.columnEncryptionKey2.Name), + sqlConnection, + null, + SqlCommandColumnEncryptionSetting.Enabled); + cmd.ExecuteNonQuery(); + + // execute the select query to create the cache entry + cmd.CommandText = enclaveSelectQuery; + cmd.Parameters.AddWithValue("@CustomerId", 0); + using (SqlDataReader reader = cmd.ExecuteReader()) + { + while (reader.Read()) + { + Assert.Equal(customerId, (int)reader[0]); + } + reader.Close(); + } + + CommandHelper.InvalidateEnclaveSession(cmd); + + // Execute again to exercise the session retry logic + using (SqlDataReader reader = cmd.ExecuteReader()) + { + while (reader.Read()) + { + Assert.Equal(customerId, (int)reader[0]); + } + reader.Close(); + } + + CommandHelper.InvalidateEnclaveSession(cmd); + + // Execute again to exercise the async session retry logic + Task readAsyncTask = ReadAsync(cmd, values, CommandBehavior.Default); + readAsyncTask.GetAwaiter().GetResult(); + +#if DEBUG + CommandHelper.ForceThrowDuringGenerateEnclavePackage(cmd); + + // Execute again to exercise the session retry logic + using (SqlDataReader reader = cmd.ExecuteReader()) + { + while (reader.Read()) + { + Assert.Equal(customerId, (int)reader[0]); + } + reader.Close(); + } + + CommandHelper.ForceThrowDuringGenerateEnclavePackage(cmd); + + // Execute again to exercise the async session retry logic + Task readAsyncTask2 = ReadAsync(cmd, values, CommandBehavior.Default); + readAsyncTask2.GetAwaiter().GetResult(); +#endif + + // revert the CEK change to the CustomerId column + cmd.Parameters.Clear(); + cmd.CommandText = string.Format(alterCekQueryFormatString, _tableName, table.columnEncryptionKey1.Name); + cmd.ExecuteNonQuery(); + } + private void ExecuteQueryThatRequiresCustomKeyStoreProvider(SqlConnection connection) { using (SqlCommand command = CreateCommandThatRequiresCustomKeyStoreProvider(connection)) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/CommandHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/CommandHelper.cs index 6d56c30e49..ba3d578673 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/CommandHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/CommandHelper.cs @@ -12,6 +12,11 @@ internal static class CommandHelper private static Type s_sqlCommand = typeof(SqlCommand); private static MethodInfo s_completePendingReadWithSuccess = s_sqlCommand.GetMethod("CompletePendingReadWithSuccess", BindingFlags.NonPublic | BindingFlags.Instance); private static MethodInfo s_completePendingReadWithFailure = s_sqlCommand.GetMethod("CompletePendingReadWithFailure", BindingFlags.NonPublic | BindingFlags.Instance); + private static MethodInfo s_invalidateEnclaveSession = s_sqlCommand.GetMethod("InvalidateEnclaveSession", BindingFlags.NonPublic | BindingFlags.Instance); +#if DEBUG + private static FieldInfo s_forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage = + s_sqlCommand.GetField(@"_forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage", BindingFlags.NonPublic | BindingFlags.Static); +#endif public static PropertyInfo s_debugForceAsyncWriteDelay = s_sqlCommand.GetProperty("DebugForceAsyncWriteDelay", BindingFlags.NonPublic | BindingFlags.Static); public static FieldInfo s_sleepDuringTryFetchInputParameterEncryptionInfo = s_sqlCommand.GetField(@"_sleepDuringTryFetchInputParameterEncryptionInfo", BindingFlags.Static | BindingFlags.NonPublic); public static PropertyInfo s_isDescribeParameterEncryptionRPCCurrentlyInProgress = s_sqlCommand.GetProperty(@"IsDescribeParameterEncryptionRPCCurrentlyInProgress", BindingFlags.Instance | BindingFlags.NonPublic); @@ -31,6 +36,18 @@ internal static void CompletePendingReadWithFailure(SqlCommand command, int erro s_completePendingReadWithFailure.Invoke(command, new object[] { errorCode, resetForcePendingReadsToWait }); } + internal static void InvalidateEnclaveSession(SqlCommand command) + { + s_invalidateEnclaveSession.Invoke(command, null); + } + +#if DEBUG + internal static void ForceThrowDuringGenerateEnclavePackage(SqlCommand command) + { + s_forceRetryableEnclaveQueryExecutionExceptionDuringGenerateEnclavePackage.SetValue(command, true); + } +#endif + internal static int ForceAsyncWriteDelay { get { return (int)s_debugForceAsyncWriteDelay.GetValue(null); }